package network

import (
	"bytes"
	"fmt"
	"net"
	"runtime"
	"sync"
	"sync/atomic"
	"time"
)

// HighConcurrentServer is a TCP/UDP server that tracks its TCP connections in
// a fixed set of shards and hands decoded UDP packets to a user handler.
type HighConcurrentServer struct {
	// activeConns counts live TCP connections and is only accessed through
	// sync/atomic. It is deliberately the first field: the first word of an
	// allocated struct is guaranteed to be 64-bit aligned, which the 64-bit
	// atomic functions require on 32-bit platforms.
	activeConns int64

	config     ServerConfig
	listener   net.Listener
	packetConn net.PacketConn
	workerPool *WorkerPool
	shards     []*connectionShard
	shardCount int
	handler    Handler
	packetType func() Packet // factory returning a fresh Packet to decode into

	shutdown chan struct{} // closed exactly once by Stop
	stopOnce sync.Once     // makes Stop safe to call more than once
	wg       sync.WaitGroup
}

// NewServer builds a server from config, the user handler and a Packet
// factory. Non-positive WorkerNum, QueueSize and MaxConn values are replaced
// with defaults (2*NumCPU, 1024 and 100000 respectively). The server does not
// listen until Start is called.
func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer {
	// Apply defaults for unset values.
	if config.WorkerNum <= 0 {
		config.WorkerNum = runtime.NumCPU() * 2
	}
	if config.QueueSize <= 0 {
		config.QueueSize = 1024
	}
	if config.MaxConn <= 0 {
		config.MaxConn = 100000
	}

	// Use at least 32 shards, doubling until there is at least one per CPU.
	shardCount := 32
	for shardCount < runtime.NumCPU() {
		shardCount *= 2
	}

	server := &HighConcurrentServer{
		config:     config,
		shardCount: shardCount,
		handler:    handler,
		packetType: packetType,
		shutdown:   make(chan struct{}),
		shards:     make([]*connectionShard, shardCount),
	}

	for i := range server.shards {
		server.shards[i] = &connectionShard{}
	}

	server.workerPool = NewWorkerPool(config.WorkerNum, config.QueueSize)

	return server
}

// Start begins serving on the configured network. It returns an error if the
// network type is not one of the tcp*/udp* variants or if listening fails;
// on success the serving goroutines run in the background until Stop.
func (s *HighConcurrentServer) Start() error {
	switch s.config.Network {
	case "tcp", "tcp4", "tcp6":
		return s.startTCP()
	case "udp", "udp4", "udp6":
		return s.startUDP()
	default:
		return fmt.Errorf("unsupported network type: %s", s.config.Network)
	}
}

// startTCP opens the TCP listener and launches the accept loop and the
// idle-connection manager.
func (s *HighConcurrentServer) startTCP() error {
	ln, err := net.Listen(s.config.Network, s.config.Address)
	if err != nil {
		// net errors already carry the op, network and address.
		return err
	}

	s.listener = ln

	s.wg.Add(1)
	go s.acceptLoop()

	s.wg.Add(1)
	go s.manageConnections()

	return nil
}

// startUDP opens the UDP socket and launches the datagram read loop.
func (s *HighConcurrentServer) startUDP() error {
	pc, err := net.ListenPacket(s.config.Network, s.config.Address)
	if err != nil {
		// net errors already carry the op, network and address.
		return err
	}

	s.packetConn = pc

	s.wg.Add(1)
	go s.udpReadLoop()

	return nil
}

// acceptLoop accepts TCP connections until shutdown is signalled or Accept
// fails with a non-temporary error. Connections beyond config.MaxConn are
// closed immediately.
func (s *HighConcurrentServer) acceptLoop() {
	defer s.wg.Done()

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			// Back off briefly on transient failures; anything else
			// (including the listener being closed by Stop) ends the loop.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				time.Sleep(100 * time.Millisecond)
				continue
			}
			return
		}

		// Only this goroutine increments activeConns, so load-then-add
		// cannot overshoot the limit.
		if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) {
			conn.Close()
			continue
		}

		atomic.AddInt64(&s.activeConns, 1)
		s.wg.Add(1)
		go s.handleNewConnection(conn)
	}
}

// handleNewConnection owns one accepted TCP connection: it registers a
// Connection in the matching shard, starts readLoop on its own goroutine,
// runs writeLoop on this goroutine, and then waits for either the
// connection's close signal or server shutdown before cleaning up.
func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) {
	defer s.wg.Done()
	defer atomic.AddInt64(&s.activeConns, -1)
	defer conn.Close()

	shard := s.getShard(conn)

	c := &Connection{
		Conn:        conn,
		readBuffer:  bytes.NewBuffer(make([]byte, 0, 4096)),
		writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)),
		writeChan:   make(chan []byte, 32),
		closeChan:   make(chan struct{}),
		server:      s,
		lastActive:  time.Now(),
	}

	shard.addConn(c)
	// Deferred so the shard entry is removed on every exit path; it runs
	// before the deferred conn.Close above, matching the original order.
	defer shard.removeConn(c)

	// Stop closes s.shutdown before it sweeps the shards. If this connection
	// was registered after that sweep nobody will ever signal it, so bail out
	// here instead of leaving Stop blocked in wg.Wait.
	select {
	case <-s.shutdown:
		return
	default:
	}

	go c.readLoop()
	c.writeLoop()

	// Wait for the close signal, but never outlive a server shutdown.
	select {
	case <-c.closeChan:
	case <-s.shutdown:
	}
}

// getShard picks the shard for conn by hashing the remote IP. Connections
// whose remote address is not a *net.TCPAddr (or has no IP) use shard 0.
func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard {
	addr, ok := conn.RemoteAddr().(*net.TCPAddr)
	if !ok || len(addr.IP) == 0 {
		return s.shards[0]
	}

	// To4 returns nil for IPv6 peers, so only narrow to 4 bytes when it
	// succeeds and otherwise hash the full address.
	ip := addr.IP
	if v4 := ip.To4(); v4 != nil {
		ip = v4
	}

	// Inline 32-bit FNV-1a: allocation-free and spreads addresses over all
	// shards, unlike a wrapping byte sum.
	const (
		offset32 uint32 = 2166136261
		prime32  uint32 = 16777619
	)
	hash := offset32
	for _, b := range ip {
		hash ^= uint32(b)
		hash *= prime32
	}

	return s.shards[hash%uint32(s.shardCount)]
}

// udpReadLoop reads datagrams until shutdown is signalled or ReadFrom fails
// with a non-temporary error, handling each datagram on its own goroutine.
func (s *HighConcurrentServer) udpReadLoop() {
	defer s.wg.Done()
	defer s.packetConn.Close()

	// Pool pointers to slices: putting a bare []byte into a sync.Pool
	// allocates a new interface box on every Put.
	bufPool := sync.Pool{
		New: func() interface{} {
			buf := make([]byte, 65536)
			return &buf
		},
	}

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		bufPtr := bufPool.Get().(*[]byte)
		n, addr, err := s.packetConn.ReadFrom(*bufPtr)
		if err != nil {
			// Hand the buffer back on the error path too.
			bufPool.Put(bufPtr)
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				time.Sleep(100 * time.Millisecond)
				continue
			}
			return
		}

		s.wg.Add(1)
		go s.handleDatagram(&bufPool, bufPtr, n, addr)
	}
}

// handleDatagram decodes one datagram held in the first n bytes of *bufPtr,
// passes it to the user handler, and writes any response back to addr. The
// buffer is returned to pool when the call finishes.
//
// NOTE(review): the buffer is recycled after the handler returns; this
// assumes Decode/handler do not retain the slice - confirm against the
// Packet implementations.
func (s *HighConcurrentServer) handleDatagram(pool *sync.Pool, bufPtr *[]byte, n int, addr net.Addr) {
	defer s.wg.Done()
	defer pool.Put(bufPtr)

	packet := s.packetType()
	if err := packet.Decode((*bufPtr)[:n]); err != nil {
		return
	}

	// Wrap the shared socket and the sender address for the handler.
	conn := &udpConn{PacketConn: s.packetConn, addr: addr}

	response, err := s.handler(conn, packet)
	if err != nil {
		return
	}
	if response == nil {
		return
	}

	respData, err := response.Encode()
	if err != nil {
		return
	}
	// Best effort: UDP replies are fire-and-forget, so a failed write is
	// intentionally not reported.
	_, _ = s.packetConn.WriteTo(respData, addr)
}

// manageConnections wakes every 30 seconds and, when config.IdleTimeout is
// positive, terminates connections whose lastActive is older than that
// timeout. It exits when shutdown is signalled.
//
// NOTE(review): lastActive is read here without synchronization; if the
// connection's goroutines update it concurrently this is a data race -
// confirm against the Connection implementation.
func (s *HighConcurrentServer) manageConnections() {
	defer s.wg.Done()

	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-s.shutdown:
			return
		case <-ticker.C:
			if s.config.IdleTimeout <= 0 {
				continue
			}

			now := time.Now()
			s.forEachConn(func(conn *Connection) {
				if now.Sub(conn.lastActive) > s.config.IdleTimeout {
					s.terminateConn(conn)
				}
			})
		}
	}
}

// forEachConn calls fn for every registered connection, holding each shard's
// read lock while that shard is being walked.
func (s *HighConcurrentServer) forEachConn(fn func(*Connection)) {
	for _, shard := range s.shards {
		shard.lock.RLock()
		shard.conns.Range(func(_, value interface{}) bool {
			fn(value.(*Connection))
			return true
		})
		shard.lock.RUnlock()
	}
}

// terminateConn asks conn to shut down. It first tries a non-blocking send on
// closeChan; because that channel is unbuffered the send is dropped whenever
// no goroutine is currently receiving, so in that case the underlying socket
// is closed instead to unblock any pending read or write.
func (s *HighConcurrentServer) terminateConn(conn *Connection) {
	select {
	case conn.closeChan <- struct{}{}:
	default:
		conn.Conn.Close()
	}
}

// Stop shuts the server down: it signals shutdown, closes the listener and
// packet socket, terminates every tracked connection, waits for all server
// goroutines to finish and finally stops the worker pool. It is safe to call
// more than once; later calls wait for the first to complete and then return.
func (s *HighConcurrentServer) Stop() {
	s.stopOnce.Do(func() {
		close(s.shutdown)

		if s.listener != nil {
			s.listener.Close()
		}
		if s.packetConn != nil {
			s.packetConn.Close()
		}

		s.forEachConn(s.terminateConn)

		s.wg.Wait()
		s.workerPool.Stop()
	})
}