// Package network 提供了一个高性能的网络通信框架,支持TCP和UDP协议。 package network import ( "bytes" "fmt" "net" "runtime" "sync" "sync/atomic" "time" ) // HighConcurrentServer 高并发服务器 // 实现了一个高性能、高并发的网络服务器,支持TCP和UDP协议 // 主要特点包括: // - 连接分片管理,减少锁竞争 // - 工作协程池,高效处理请求 // - 自动管理连接生命周期 // - 优雅关闭机制 type HighConcurrentServer struct { config ServerConfig // 服务器配置 listener net.Listener // TCP监听器 packetConn net.PacketConn // UDP包连接 workerPool *WorkerPool // 工作协程池 shards []*connectionShard // 连接分片数组 shardCount int // 分片数量 handler Handler // 用户自定义处理函数 packetType func() Packet // 用于创建新Packet实例的工厂函数 activeConns int64 // 当前活动连接数 shutdown chan struct{} // 关闭信号通道 wg sync.WaitGroup // 等待组,用于优雅关闭 } // NewServer 创建新的高并发服务器 // config: 服务器配置 // handler: 用户自定义的数据包处理函数 // packetType: 用于创建新的数据包实例的工厂函数 // 返回初始化好的服务器实例,但尚未启动 func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer { // 设置默认值 if config.WorkerNum <= 0 { config.WorkerNum = runtime.NumCPU() * 2 // 默认为CPU核心数的2倍 } if config.QueueSize <= 0 { config.QueueSize = 1024 // 默认队列大小为1024 } if config.MaxConn <= 0 { config.MaxConn = 100000 // 默认最大连接数为10万 } // 计算合适的分片数量,至少为32,且不小于CPU核心数 // 分片数量为2的幂,便于哈希分配 shardCount := 32 for shardCount < runtime.NumCPU() { shardCount *= 2 } server := &HighConcurrentServer{ config: config, shardCount: shardCount, handler: handler, packetType: packetType, shutdown: make(chan struct{}), shards: make([]*connectionShard, shardCount), } // 初始化连接分片 for i := 0; i < shardCount; i++ { server.shards[i] = &connectionShard{} } // 初始化工作协程池 server.workerPool = NewWorkerPool(config.WorkerNum, config.QueueSize) return server } // Start 启动服务器 // 根据配置的网络类型启动相应的服务器(TCP或UDP) // 返回可能的启动错误 func (s *HighConcurrentServer) Start() error { switch s.config.Network { case "tcp", "tcp4", "tcp6": return s.startTCP() // 启动TCP服务器 case "udp", "udp4", "udp6": return s.startUDP() // 启动UDP服务器 default: return fmt.Errorf("unsupported network type: %s", s.config.Network) } } // startTCP 启动TCP服务器 // 创建监听器并启动接受连接循环和连接管理协程 // 返回可能的启动错误 func (s *HighConcurrentServer) startTCP() error { // 创建TCP监听器 ln, err := net.Listen(s.config.Network, s.config.Address) if err != nil { return err } s.listener = ln // 启动接受连接循环 s.wg.Add(1) go s.acceptLoop() // 启动连接管理(清理空闲连接) s.wg.Add(1) go s.manageConnections() return nil } // startUDP 启动UDP服务器 // 创建UDP包连接并启动读取循环 // 返回可能的启动错误 func (s *HighConcurrentServer) startUDP() error { // 创建UDP包连接 pc, err := net.ListenPacket(s.config.Network, s.config.Address) if err != nil { return err } s.packetConn = pc // 启动UDP读取循环 s.wg.Add(1) go s.udpReadLoop() return nil } // acceptLoop TCP接受连接循环 // 持续接受新的TCP连接并为每个连接创建处理协程 // 当收到关闭信号或发生错误时退出 func (s *HighConcurrentServer) acceptLoop() { defer s.wg.Done() for { // 检查是否收到关闭信号 select { case <-s.shutdown: return default: } // 接受新连接 conn, err := s.listener.Accept() if err != nil { // 处理临时错误,如"too many open files" if ne, ok := err.(net.Error); ok && ne.Temporary() { time.Sleep(100 * time.Millisecond) // 短暂休眠后重试 continue } return // 非临时错误,退出循环 } // 检查是否超过最大连接数限制 if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) { conn.Close() // 超过限制,直接关闭连接 continue } // 增加活动连接计数并启动连接处理协程 atomic.AddInt64(&s.activeConns, 1) s.wg.Add(1) go s.handleNewConnection(conn) } } // handleNewConnection 处理新连接 // conn: 新接受的TCP连接 // 为新连接创建Connection对象,启动读写循环,并管理连接的生命周期 func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) { defer s.wg.Done() defer atomic.AddInt64(&s.activeConns, -1) // 减少活动连接计数 defer conn.Close() // 确保连接关闭 // 获取连接对应的分片,用于减少锁竞争 shard := s.getShard(conn) c := &Connection{ Conn: conn, readBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // 初始化读缓冲区,容量为4KB writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // 初始化写缓冲区,容量为4KB writeChan: make(chan []byte, 32), // 写入通道,缓冲大小为32 closeChan: make(chan struct{}), // 关闭信号通道 server: s, lastActive: time.Now(), // 记录初始活动时间 } // 将连接添加到分片管理中 shard.addConn(c) // 启动读写协程 go c.readLoop() // 读取循环在单独的协程中运行 c.writeLoop() // 写入循环在当前协程中运行 // 等待关闭信号 <-c.closeChan shard.removeConn(c) // 从分片中移除连接 } // getShard 获取连接对应的分片 // conn: 网络连接 // 返回该连接应该被分配到的分片 // 使用连接的远程IP地址进行哈希来确定分片,以实现负载均衡 func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard { // 尝试获取TCP地址 addr, ok := conn.RemoteAddr().(*net.TCPAddr) if !ok { return s.shards[0] // 非TCP连接返回第一个分片 } // 使用IP地址的四个字节进行简单哈希 // 将IP地址的四个字节相加作为哈希值 hash := addr.IP.To4()[0] + addr.IP.To4()[1] + addr.IP.To4()[2] + addr.IP.To4()[3] return s.shards[int(hash)%s.shardCount] // 使用取模运算确定分片索引 } // udpReadLoop UDP读取循环 // 持续从UDP连接读取数据包并处理 // 使用sync.Pool复用缓冲区以减少内存分配 func (s *HighConcurrentServer) udpReadLoop() { defer s.wg.Done() defer s.packetConn.Close() // 创建缓冲区池,每个缓冲区大小为64KB(UDP最大包大小) bufPool := sync.Pool{ New: func() interface{} { return make([]byte, 65536) }, } for { // 检查是否收到关闭信号 select { case <-s.shutdown: return default: } // 从池中获取缓冲区 buf := bufPool.Get().([]byte) n, addr, err := s.packetConn.ReadFrom(buf) if err != nil { // 处理临时错误 if ne, ok := err.(net.Error); ok && ne.Temporary() { time.Sleep(100 * time.Millisecond) continue } return } // 启动新的协程处理数据包 s.wg.Add(1) go func(data []byte, addr net.Addr) { defer s.wg.Done() defer bufPool.Put(data) // 将缓冲区放回池中 // 解码数据包 packet := s.packetType() if err := packet.Decode(data[:n]); err != nil { return } // 创建虚拟UDP连接 conn := &udpConn{PacketConn: s.packetConn, addr: addr} // 调用用户处理函数 response, err := s.handler(conn, packet) if err != nil { return } // 发送响应(如果有) if response != nil { respData, err := response.Encode() if err == nil { s.packetConn.WriteTo(respData, addr) } } }(buf, addr) } } // manageConnections 管理连接(关闭空闲连接) // 定期检查并关闭超过空闲超时时间的连接 // 当服务器关闭时退出 func (s *HighConcurrentServer) manageConnections() { defer s.wg.Done() // 创建定时器,每30秒检查一次 ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-s.shutdown: return case <-ticker.C: // 如果未设置空闲超时,跳过检查 if s.config.IdleTimeout <= 0 { continue } now := time.Now() // 遍历所有分片 for _, shard := range s.shards { shard.lock.RLock() // 遍历分片中的所有连接 shard.conns.Range(func(key, value interface{}) bool { conn := value.(*Connection) // 检查连接是否超过空闲超时时间 if now.Sub(conn.lastActive) > s.config.IdleTimeout { // 发送关闭信号 select { case conn.closeChan <- struct{}{}: default: } } return true }) shard.lock.RUnlock() } } } } // Stop 停止服务器 func (s *HighConcurrentServer) Stop() { close(s.shutdown) if s.listener != nil { s.listener.Close() } if s.packetConn != nil { s.packetConn.Close() } // 关闭所有连接 for _, shard := range s.shards { shard.lock.RLock() shard.conns.Range(func(key, value interface{}) bool { conn := value.(*Connection) select { case conn.closeChan <- struct{}{}: default: } return true }) shard.lock.RUnlock() } s.wg.Wait() s.workerPool.Stop() }