network/network.go

395 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package network 提供了一个高性能的网络通信框架支持TCP和UDP协议。
// 该框架采用工作池和连接分片管理来实现高并发,支持自定义数据包格式和处理函数。
// 主要特点包括:高并发连接处理、连接生命周期管理、优雅关闭机制和缓冲区复用。
package network
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
)
// ========== 框架核心接口 ==========
// Packet 用户自定义数据包必须实现的接口
// 用户需要实现这个接口来定义自己的数据包格式,包括编码和解码方法
type Packet interface {
// Encode 将数据包编码为字节流
// 返回编码后的字节切片和可能的错误
Encode() ([]byte, error)
// Decode 从字节流解码为数据包
// 参数data包含要解码的字节数据
// 返回可能的解码错误
Decode([]byte) error
}
// Handler 用户自定义的数据处理函数类型
// 当收到数据包时,框架会调用这个函数来处理数据
// 参数conn是产生数据的网络连接p是解码后的数据包
// 返回响应数据包和可能的错误
type Handler func(conn net.Conn, p Packet) (Packet, error)
// ========== 框架核心结构 ==========
// ServerConfig 服务器配置
// 包含服务器运行所需的各种参数设置
type ServerConfig struct {
Network string // 网络类型: tcp, tcp4, tcp6, udp, udp4, udp6
Address string // 监听地址,格式为 "ip:port",如 ":8080"
MaxConn int // 最大连接数 (仅TCP有效)
WorkerNum int // 工作协程数量默认为CPU核心数的2倍
QueueSize int // 任务队列大小默认为1024
ReadTimeout time.Duration // 读取超时时间0表示不设置超时
IdleTimeout time.Duration // 空闲连接超时时间0表示不清理空闲连接
}
// connectionShard 连接分片
// 用于高效管理大量连接,减少锁竞争
type connectionShard struct {
conns sync.Map // 存储连接的并发安全map键为net.Conn值为*Connection
lock sync.RWMutex // 读写锁,保护分片操作
lastUsed time.Time // 最后使用时间,用于分片管理
}
// Connection 连接封装
// 封装了底层网络连接,提供了缓冲区和通道管理
type Connection struct {
net.Conn // 嵌入底层网络连接
readBuffer *bytes.Buffer // 读缓冲区
writeBuffer *bytes.Buffer // 写缓冲区
writeChan chan []byte // 写入通道,用于异步写入
closeChan chan struct{} // 关闭通道,用于通知协程退出
server *HighConcurrentServer // 所属服务器
lastActive time.Time // 最后活动时间,用于空闲检测
}
// WorkerPool 工作协程池
// 用于高效处理网络请求,避免为每个连接创建协程
type WorkerPool struct {
taskQueue chan func() // 任务队列,存储待执行的函数
size int // 工作协程数量
}
// ========== 框架实现 ==========
// readLoop 连接读循环
func (c *Connection) readLoop() {
reader := bufio.NewReader(c.Conn)
var header [4]byte
for {
select {
case <-c.closeChan:
return
default:
}
// 设置读超时
if c.server.config.ReadTimeout > 0 {
c.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout))
}
// 读取包头 (4字节长度)
_, err := io.ReadFull(reader, header[:])
if err != nil {
if err != io.EOF {
// 处理错误
}
close(c.closeChan)
return
}
// 解析包长度
length := binary.BigEndian.Uint32(header[:])
// 检查包长度是否合理
if length > 10*1024*1024 { // 10MB
close(c.closeChan)
return
}
// 读取包体
packetData := make([]byte, length)
_, err = io.ReadFull(reader, packetData)
if err != nil {
close(c.closeChan)
return
}
c.lastActive = time.Now()
// 提交给工作池处理
c.server.workerPool.Submit(func() {
c.processPacket(packetData)
})
}
}
// processPacket 处理数据包
func (c *Connection) processPacket(data []byte) {
// 创建新的数据包实例
packet := c.server.packetType()
// 解码数据
if err := packet.Decode(data); err != nil {
// 处理解码错误
return
}
// 调用用户处理函数
response, err := c.server.handler(c.Conn, packet)
if err != nil {
// 处理错误
return
}
// 如果有响应,发送回去
if response != nil {
c.Send(response)
}
}
// Send 发送数据包
func (c *Connection) Send(p Packet) error {
data, err := p.Encode()
if err != nil {
return err
}
select {
case c.writeChan <- data:
return nil
default:
return errors.New("write channel full")
}
}
// writeLoop 连接写循环
func (c *Connection) writeLoop() {
// 批量写入优化
var batch [][]byte
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case data := <-c.writeChan:
batch = append(batch, data)
// 尝试批量处理
for len(c.writeChan) > 0 && len(batch) < 32 {
batch = append(batch, <-c.writeChan)
}
// 合并写入
c.writeBuffer.Reset()
for _, d := range batch {
// 写入长度头
header := make([]byte, 4)
binary.BigEndian.PutUint32(header, uint32(len(d)))
c.writeBuffer.Write(header)
c.writeBuffer.Write(d)
}
// 发送数据
if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
close(c.closeChan)
return
}
batch = batch[:0]
case <-ticker.C:
// 定时刷新
if len(batch) > 0 {
c.writeBuffer.Reset()
for _, d := range batch {
header := make([]byte, 4)
binary.BigEndian.PutUint32(header, uint32(len(d)))
c.writeBuffer.Write(header)
c.writeBuffer.Write(d)
}
if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
close(c.closeChan)
return
}
batch = batch[:0]
}
// 检查空闲超时
if c.server.config.IdleTimeout > 0 &&
time.Since(c.lastActive) > c.server.config.IdleTimeout {
close(c.closeChan)
return
}
case <-c.closeChan:
return
}
}
}
// ========== 辅助结构和方法 ==========
// NewWorkerPool 创建工作池
func NewWorkerPool(size, queueSize int) *WorkerPool {
pool := &WorkerPool{
taskQueue: make(chan func(), queueSize),
size: size,
}
for i := 0; i < size; i++ {
go pool.worker()
}
return pool
}
func (p *WorkerPool) worker() {
for task := range p.taskQueue {
task()
}
}
func (p *WorkerPool) Submit(task func()) {
select {
case p.taskQueue <- task:
default:
// 任务队列满时的处理
}
}
func (p *WorkerPool) Stop() {
close(p.taskQueue)
}
// udpConn UDP虚拟连接
type udpConn struct {
net.PacketConn
addr net.Addr
}
func (c *udpConn) Read(b []byte) (int, error) {
return 0, errors.New("not implemented")
}
func (c *udpConn) Write(b []byte) (int, error) {
return c.PacketConn.WriteTo(b, c.addr)
}
func (c *udpConn) Close() error {
return nil
}
func (c *udpConn) LocalAddr() net.Addr {
return c.PacketConn.LocalAddr()
}
func (c *udpConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *udpConn) SetDeadline(t time.Time) error {
return c.PacketConn.SetDeadline(t)
}
func (c *udpConn) SetReadDeadline(t time.Time) error {
return c.PacketConn.SetReadDeadline(t)
}
func (c *udpConn) SetWriteDeadline(t time.Time) error {
return c.PacketConn.SetWriteDeadline(t)
}
func (s *connectionShard) addConn(conn *Connection) {
s.lock.Lock()
defer s.lock.Unlock()
s.conns.Store(conn.Conn, conn)
s.lastUsed = time.Now()
}
func (s *connectionShard) removeConn(conn *Connection) {
s.lock.Lock()
defer s.lock.Unlock()
s.conns.Delete(conn.Conn)
s.lastUsed = time.Now()
}
// ========== 用户使用示例 ==========
/*
// 用户自定义数据包
type MyPacket struct {
ID uint32
Payload string
}
func (p *MyPacket) Encode() ([]byte, error) {
buf := bytes.NewBuffer(nil)
binary.Write(buf, binary.BigEndian, p.ID)
binary.Write(buf, binary.BigEndian, uint32(len(p.Payload)))
buf.WriteString(p.Payload)
return buf.Bytes(), nil
}
func (p *MyPacket) Decode(data []byte) error {
buf := bytes.NewBuffer(data)
binary.Read(buf, binary.BigEndian, &p.ID)
var length uint32
binary.Read(buf, binary.BigEndian, &length)
p.Payload = string(buf.Next(int(length)))
return nil
}
// 用户自定义处理函数
func myHandler(conn net.Conn, p Packet) (Packet, error) {
myPacket, ok := p.(*MyPacket)
if !ok {
return nil, errors.New("invalid packet type")
}
// 处理业务逻辑
response := &MyPacket{
ID: myPacket.ID + 1,
Payload: "Response: " + myPacket.Payload,
}
return response, nil
}
func main() {
// 创建服务器配置
config := network.ServerConfig{
Network: "tcp",
Address: ":8080",
MaxConn: 10000,
WorkerNum: 100,
QueueSize: 10000,
ReadTimeout: 30 * time.Second,
IdleTimeout: 300 * time.Second,
}
// 创建服务器实例
server := network.NewServer(config, myHandler, func() network.Packet {
return &MyPacket{}
})
// 启动服务器
if err := server.Start(); err != nil {
panic(err)
}
// 等待关闭信号
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig
// 优雅关闭
server.Stop()
}
*/