// Package network 提供了一个高性能的网络通信框架,支持TCP和UDP协议。 // 该框架采用工作池和连接分片管理来实现高并发,支持自定义数据包格式和处理函数。 // 主要特点包括:高并发连接处理、连接生命周期管理、优雅关闭机制和缓冲区复用。 package network import ( "bufio" "bytes" "encoding/binary" "errors" "io" "net" "sync" "time" ) // ========== 框架核心接口 ========== // Packet 用户自定义数据包必须实现的接口 // 用户需要实现这个接口来定义自己的数据包格式,包括编码和解码方法 type Packet interface { // Encode 将数据包编码为字节流 // 返回编码后的字节切片和可能的错误 Encode() ([]byte, error) // Decode 从字节流解码为数据包 // 参数data包含要解码的字节数据 // 返回可能的解码错误 Decode([]byte) error } // Handler 用户自定义的数据处理函数类型 // 当收到数据包时,框架会调用这个函数来处理数据 // 参数conn是产生数据的网络连接,p是解码后的数据包 // 返回响应数据包和可能的错误 type Handler func(conn net.Conn, p Packet) (Packet, error) // ========== 框架核心结构 ========== // ServerConfig 服务器配置 // 包含服务器运行所需的各种参数设置 type ServerConfig struct { Network string // 网络类型: tcp, tcp4, tcp6, udp, udp4, udp6 Address string // 监听地址,格式为 "ip:port",如 ":8080" MaxConn int // 最大连接数 (仅TCP有效) WorkerNum int // 工作协程数量,默认为CPU核心数的2倍 QueueSize int // 任务队列大小,默认为1024 ReadTimeout time.Duration // 读取超时时间,0表示不设置超时 IdleTimeout time.Duration // 空闲连接超时时间,0表示不清理空闲连接 } // connectionShard 连接分片 // 用于高效管理大量连接,减少锁竞争 type connectionShard struct { conns sync.Map // 存储连接的并发安全map,键为net.Conn,值为*Connection lock sync.RWMutex // 读写锁,保护分片操作 lastUsed time.Time // 最后使用时间,用于分片管理 } // Connection 连接封装 // 封装了底层网络连接,提供了缓冲区和通道管理 type Connection struct { net.Conn // 嵌入底层网络连接 readBuffer *bytes.Buffer // 读缓冲区 writeBuffer *bytes.Buffer // 写缓冲区 writeChan chan []byte // 写入通道,用于异步写入 closeChan chan struct{} // 关闭通道,用于通知协程退出 server *HighConcurrentServer // 所属服务器 lastActive time.Time // 最后活动时间,用于空闲检测 } // WorkerPool 工作协程池 // 用于高效处理网络请求,避免为每个连接创建协程 type WorkerPool struct { taskQueue chan func() // 任务队列,存储待执行的函数 size int // 工作协程数量 } // ========== 框架实现 ========== // readLoop 连接读循环 func (c *Connection) readLoop() { reader := bufio.NewReader(c.Conn) var header [4]byte for { select { case <-c.closeChan: return default: } // 设置读超时 if c.server.config.ReadTimeout > 0 { c.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout)) } // 读取包头 (4字节长度) _, err := io.ReadFull(reader, header[:]) if err != nil { if err != io.EOF { // 处理错误 } close(c.closeChan) return } // 解析包长度 length := binary.BigEndian.Uint32(header[:]) // 检查包长度是否合理 if length > 10*1024*1024 { // 10MB close(c.closeChan) return } // 读取包体 packetData := make([]byte, length) _, err = io.ReadFull(reader, packetData) if err != nil { close(c.closeChan) return } c.lastActive = time.Now() // 提交给工作池处理 c.server.workerPool.Submit(func() { c.processPacket(packetData) }) } } // processPacket 处理数据包 func (c *Connection) processPacket(data []byte) { // 创建新的数据包实例 packet := c.server.packetType() // 解码数据 if err := packet.Decode(data); err != nil { // 处理解码错误 return } // 调用用户处理函数 response, err := c.server.handler(c.Conn, packet) if err != nil { // 处理错误 return } // 如果有响应,发送回去 if response != nil { c.Send(response) } } // Send 发送数据包 func (c *Connection) Send(p Packet) error { data, err := p.Encode() if err != nil { return err } select { case c.writeChan <- data: return nil default: return errors.New("write channel full") } } // writeLoop 连接写循环 func (c *Connection) writeLoop() { // 批量写入优化 var batch [][]byte ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for { select { case data := <-c.writeChan: batch = append(batch, data) // 尝试批量处理 for len(c.writeChan) > 0 && len(batch) < 32 { batch = append(batch, <-c.writeChan) } // 合并写入 c.writeBuffer.Reset() for _, d := range batch { // 写入长度头 header := make([]byte, 4) binary.BigEndian.PutUint32(header, uint32(len(d))) c.writeBuffer.Write(header) c.writeBuffer.Write(d) } // 发送数据 if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil { close(c.closeChan) return } batch = batch[:0] case <-ticker.C: // 定时刷新 if len(batch) > 0 { c.writeBuffer.Reset() for _, d := range batch { header := make([]byte, 4) binary.BigEndian.PutUint32(header, uint32(len(d))) c.writeBuffer.Write(header) c.writeBuffer.Write(d) } if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil { close(c.closeChan) return } batch = batch[:0] } // 检查空闲超时 if c.server.config.IdleTimeout > 0 && time.Since(c.lastActive) > c.server.config.IdleTimeout { close(c.closeChan) return } case <-c.closeChan: return } } } // ========== 辅助结构和方法 ========== // NewWorkerPool 创建工作池 func NewWorkerPool(size, queueSize int) *WorkerPool { pool := &WorkerPool{ taskQueue: make(chan func(), queueSize), size: size, } for i := 0; i < size; i++ { go pool.worker() } return pool } func (p *WorkerPool) worker() { for task := range p.taskQueue { task() } } func (p *WorkerPool) Submit(task func()) { select { case p.taskQueue <- task: default: // 任务队列满时的处理 } } func (p *WorkerPool) Stop() { close(p.taskQueue) } // udpConn UDP虚拟连接 type udpConn struct { net.PacketConn addr net.Addr } func (c *udpConn) Read(b []byte) (int, error) { return 0, errors.New("not implemented") } func (c *udpConn) Write(b []byte) (int, error) { return c.PacketConn.WriteTo(b, c.addr) } func (c *udpConn) Close() error { return nil } func (c *udpConn) LocalAddr() net.Addr { return c.PacketConn.LocalAddr() } func (c *udpConn) RemoteAddr() net.Addr { return c.addr } func (c *udpConn) SetDeadline(t time.Time) error { return c.PacketConn.SetDeadline(t) } func (c *udpConn) SetReadDeadline(t time.Time) error { return c.PacketConn.SetReadDeadline(t) } func (c *udpConn) SetWriteDeadline(t time.Time) error { return c.PacketConn.SetWriteDeadline(t) } func (s *connectionShard) addConn(conn *Connection) { s.lock.Lock() defer s.lock.Unlock() s.conns.Store(conn.Conn, conn) s.lastUsed = time.Now() } func (s *connectionShard) removeConn(conn *Connection) { s.lock.Lock() defer s.lock.Unlock() s.conns.Delete(conn.Conn) s.lastUsed = time.Now() } // ========== 用户使用示例 ========== /* // 用户自定义数据包 type MyPacket struct { ID uint32 Payload string } func (p *MyPacket) Encode() ([]byte, error) { buf := bytes.NewBuffer(nil) binary.Write(buf, binary.BigEndian, p.ID) binary.Write(buf, binary.BigEndian, uint32(len(p.Payload))) buf.WriteString(p.Payload) return buf.Bytes(), nil } func (p *MyPacket) Decode(data []byte) error { buf := bytes.NewBuffer(data) binary.Read(buf, binary.BigEndian, &p.ID) var length uint32 binary.Read(buf, binary.BigEndian, &length) p.Payload = string(buf.Next(int(length))) return nil } // 用户自定义处理函数 func myHandler(conn net.Conn, p Packet) (Packet, error) { myPacket, ok := p.(*MyPacket) if !ok { return nil, errors.New("invalid packet type") } // 处理业务逻辑 response := &MyPacket{ ID: myPacket.ID + 1, Payload: "Response: " + myPacket.Payload, } return response, nil } func main() { // 创建服务器配置 config := network.ServerConfig{ Network: "tcp", Address: ":8080", MaxConn: 10000, WorkerNum: 100, QueueSize: 10000, ReadTimeout: 30 * time.Second, IdleTimeout: 300 * time.Second, } // 创建服务器实例 server := network.NewServer(config, myHandler, func() network.Packet { return &MyPacket{} }) // 启动服务器 if err := server.Start(); err != nil { panic(err) } // 等待关闭信号 sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) <-sig // 优雅关闭 server.Stop() } */