diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..1b8be03
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module git.kingecg.top/kingecg/network
+
+go 1.23.1
diff --git a/network.go b/network.go
new file mode 100644
index 0000000..db058eb
--- /dev/null
+++ b/network.go
@@ -0,0 +1,379 @@
+package network
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// ========== Core framework interfaces ==========
+
+// Packet is the interface every user-defined packet type must implement.
+type Packet interface {
+	// Encode serializes the packet into a byte slice.
+	Encode() ([]byte, error)
+	// Decode deserializes the packet from a byte slice.
+	Decode([]byte) error
+}
+
+// Handler is the user-supplied function that processes a decoded packet
+// and optionally returns a response packet.
+type Handler func(conn net.Conn, p Packet) (Packet, error)
+
+// ========== Core framework structures ==========
+
+// ServerConfig holds the server configuration.
+type ServerConfig struct {
+	Network     string        // tcp, tcp4, tcp6, udp, udp4, udp6
+	Address     string        // listen address
+	MaxConn     int           // maximum number of connections (TCP only)
+	WorkerNum   int           // number of worker goroutines
+	QueueSize   int           // task queue size
+	ReadTimeout time.Duration // read timeout
+	IdleTimeout time.Duration // idle timeout
+}
+
+// connectionShard is one shard of the connection table.
+type connectionShard struct {
+	conns    sync.Map // map[net.Conn]*Connection
+	lock     sync.RWMutex
+	lastUsed time.Time
+}
+
+// Connection wraps a net.Conn together with its buffers and channels.
+type Connection struct {
+	net.Conn
+	readBuffer  *bytes.Buffer
+	writeBuffer *bytes.Buffer
+	writeChan   chan []byte
+	closeChan   chan struct{}
+	closeOnce   sync.Once
+	server      *HighConcurrentServer
+	lastActive  atomic.Int64 // last activity time as Unix nanoseconds
+}
+
+// WorkerPool is a fixed-size pool of worker goroutines.
+type WorkerPool struct {
+	taskQueue chan func()
+	size      int
+}
+
+// ========== Framework implementation ==========
+
+// close signals the connection's goroutines to stop. It is safe to call
+// from multiple goroutines; the underlying channel is closed exactly once.
+func (c *Connection) close() {
+	c.closeOnce.Do(func() { close(c.closeChan) })
+}
+
+// touch records the current time as the connection's last activity.
+func (c *Connection) touch() {
+	c.lastActive.Store(time.Now().UnixNano())
+}
+
+// idleSince reports how long the connection has been idle.
+func (c *Connection) idleSince() time.Duration {
+	return time.Since(time.Unix(0, c.lastActive.Load()))
+}
+
+// readLoop is the per-connection read loop.
+func (c *Connection) readLoop() {
+	reader := bufio.NewReader(c.Conn)
+	var header [4]byte
+
+	for {
+		select {
+		case <-c.closeChan:
+			return
+		default:
+		}
+
+		// Apply the read timeout, if configured.
+		if c.server.config.ReadTimeout > 0 {
+			c.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout))
+		}
+
+		// Read the packet header (4-byte big-endian length).
+		_, err := io.ReadFull(reader, header[:])
+		if err != nil {
+			if err != io.EOF {
+				// Log or otherwise report the read error here.
+			}
+			c.close()
+			return
+		}
+
+		// Parse the packet length.
+		length := binary.BigEndian.Uint32(header[:])
+
+		// Reject unreasonably large packets.
+		if length > 10*1024*1024 { // 10MB
+			c.close()
+			return
+		}
+
+		// Read the packet body.
+		packetData := make([]byte, length)
+		_, err = io.ReadFull(reader, packetData)
+		if err != nil {
+			c.close()
+			return
+		}
+
+		c.touch()
+
+		// Hand the packet to the worker pool.
+		c.server.workerPool.Submit(func() {
+			c.processPacket(packetData)
+		})
+	}
+}
+
+// processPacket decodes a packet, runs the user handler and sends any response.
+func (c *Connection) processPacket(data []byte) {
+	// Create a fresh packet instance.
+	packet := c.server.packetType()
+
+	// Decode the payload.
+	if err := packet.Decode(data); err != nil {
+		// Log or otherwise report the decode error here.
+		return
+	}
+
+	// Invoke the user handler.
+	response, err := c.server.handler(c.Conn, packet)
+	if err != nil {
+		// Log or otherwise report the handler error here.
+		return
+	}
+
+	// Send the response, if any.
+	if response != nil {
+		c.Send(response)
+	}
+}
+
+// Send queues a packet for writing.
+func (c *Connection) Send(p Packet) error {
+	data, err := p.Encode()
+	if err != nil {
+		return err
+	}
+
+	select {
+	case c.writeChan <- data:
+		return nil
+	default:
+		return errors.New("write channel full")
+	}
+}
+
+// writeLoop is the per-connection write loop.
+func (c *Connection) writeLoop() {
+	// Batch writes to reduce syscalls.
+	var batch [][]byte
+	ticker := time.NewTicker(10 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case data := <-c.writeChan:
+			batch = append(batch, data)
+			// Drain whatever else is already queued, up to the batch limit.
+			for len(c.writeChan) > 0 && len(batch) < 32 {
+				batch = append(batch, <-c.writeChan)
+			}
+
+			// Merge the batch into a single buffer.
+			c.writeBuffer.Reset()
+			for _, d := range batch {
+				// Prepend the 4-byte length header.
+				header := make([]byte, 4)
+				binary.BigEndian.PutUint32(header, uint32(len(d)))
+				c.writeBuffer.Write(header)
+				c.writeBuffer.Write(d)
+			}
+
+			// Flush the merged batch to the socket.
+			if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
+				c.close()
+				return
+			}
+			batch = batch[:0]
+
+		case <-ticker.C:
+			// Periodic flush of anything left in the batch.
+			if len(batch) > 0 {
+				c.writeBuffer.Reset()
+				for _, d := range batch {
+					header := make([]byte, 4)
+					binary.BigEndian.PutUint32(header, uint32(len(d)))
+					c.writeBuffer.Write(header)
+					c.writeBuffer.Write(d)
+				}
+				if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
+					c.close()
+					return
+				}
+				batch = batch[:0]
+			}
+
+			// Check the idle timeout.
+			if c.server.config.IdleTimeout > 0 &&
+				c.idleSince() > c.server.config.IdleTimeout {
+				c.close()
+				return
+			}
+
+		case <-c.closeChan:
+			return
+		}
+	}
+}
+
+// ========== Helper structures and methods ==========
+
+// NewWorkerPool creates a worker pool with the given size and queue capacity.
+func NewWorkerPool(size, queueSize int) *WorkerPool {
+	pool := &WorkerPool{
+		taskQueue: make(chan func(), queueSize),
+		size:      size,
+	}
+
+	for i := 0; i < size; i++ {
+		go pool.worker()
+	}
+
+	return pool
+}
+
+func (p *WorkerPool) worker() {
+	for task := range p.taskQueue {
+		task()
+	}
+}
+
+func (p *WorkerPool) Submit(task func()) {
+	select {
+	case p.taskQueue <- task:
+	default:
+		// The task queue is full; the task is dropped. Callers that cannot
+		// tolerate drops should add their own backpressure here.
+	}
+}
+
+func (p *WorkerPool) Stop() {
+	close(p.taskQueue)
+}
+
+// udpConn is a virtual connection representing a single UDP peer.
+type udpConn struct {
+	net.PacketConn
+	addr net.Addr
+}
+
+func (c *udpConn) Read(b []byte) (int, error) {
+	return 0, errors.New("not implemented")
+}
+
+func (c *udpConn) Write(b []byte) (int, error) {
+	return c.PacketConn.WriteTo(b, c.addr)
+}
+
+func (c *udpConn) Close() error {
+	return nil
+}
+
+func (c *udpConn) LocalAddr() net.Addr {
+	return c.PacketConn.LocalAddr()
+}
+
+func (c *udpConn) RemoteAddr() net.Addr {
+	return c.addr
+}
+
+func (c *udpConn) SetDeadline(t time.Time) error {
+	return c.PacketConn.SetDeadline(t)
+}
+
+func (c *udpConn) SetReadDeadline(t time.Time) error {
+	return c.PacketConn.SetReadDeadline(t)
+}
+
+func (c *udpConn) SetWriteDeadline(t time.Time) error {
+	return c.PacketConn.SetWriteDeadline(t)
+}
+
+func (s *connectionShard) addConn(conn *Connection) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	s.conns.Store(conn.Conn, conn)
+	s.lastUsed = time.Now()
+}
+
+func (s *connectionShard) removeConn(conn *Connection) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	s.conns.Delete(conn.Conn)
+	s.lastUsed = time.Now()
+}
+
+// ========== Usage example ==========
+
+/*
+// A user-defined packet type.
+type MyPacket struct {
+	ID      uint32
+	Payload string
+}
+
+func (p *MyPacket) Encode() ([]byte, error) {
+	buf := bytes.NewBuffer(nil)
+	binary.Write(buf, binary.BigEndian, p.ID)
+	binary.Write(buf, binary.BigEndian, uint32(len(p.Payload)))
+	buf.WriteString(p.Payload)
+	return buf.Bytes(), nil
+}
+
+func (p *MyPacket) Decode(data []byte) error {
+	buf := bytes.NewBuffer(data)
+	binary.Read(buf, binary.BigEndian, &p.ID)
+	var length uint32
+	binary.Read(buf, binary.BigEndian, &length)
+	p.Payload = string(buf.Next(int(length)))
+	return nil
+}
+
+// A user-defined handler.
+func myHandler(conn net.Conn, p Packet) (Packet, error) {
+	myPacket, ok := p.(*MyPacket)
+	if !ok {
+		return nil, errors.New("invalid packet type")
+	}
+
+	// Business logic goes here.
+	response := &MyPacket{
+		ID:      myPacket.ID + 1,
+		Payload: "Response: " + myPacket.Payload,
+	}
+
+	return response, nil
+}
+
+func main() {
+	// Build the server configuration.
+	config := network.ServerConfig{
+		Network:     "tcp",
+		Address:     ":8080",
+		MaxConn:     10000,
+		WorkerNum:   100,
+		QueueSize:   10000,
+		ReadTimeout: 30 * time.Second,
+		IdleTimeout: 300 * time.Second,
+	}
+
+	// Create the server instance.
+	server := network.NewServer(config, myHandler, func() network.Packet {
+		return &MyPacket{}
+	})
+
+	// Start the server.
+	if err := server.Start(); err != nil {
+		panic(err)
+	}
+
+	// Wait for a shutdown signal.
+	sig := make(chan os.Signal, 1)
+	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
+	<-sig
+
+	// Shut down gracefully.
+	server.Stop()
+}
+*/
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..1ac91ca
--- /dev/null
+++ b/server.go
@@ -0,0 +1,306 @@
+package network
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// HighConcurrentServer is a high-concurrency TCP/UDP server.
+type HighConcurrentServer struct {
+	config      ServerConfig
+	listener    net.Listener
+	packetConn  net.PacketConn
+	workerPool  *WorkerPool
+	shards      []*connectionShard
+	shardCount  int
+	handler     Handler
+	packetType  func() Packet // factory for new Packet instances
+	activeConns int64
+	shutdown    chan struct{}
+	wg          sync.WaitGroup
+}
+
+// NewServer creates a new high-concurrency server.
+func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer {
+	// Apply defaults.
+	if config.WorkerNum <= 0 {
+		config.WorkerNum = runtime.NumCPU() * 2
+	}
+	if config.QueueSize <= 0 {
+		config.QueueSize = 1024
+	}
+	if config.MaxConn <= 0 {
+		config.MaxConn = 100000
+	}
+
+	// Pick a shard count that is a power of two and at least the CPU count.
+	shardCount := 32
+	for shardCount < runtime.NumCPU() {
+		shardCount *= 2
+	}
+
+	server := &HighConcurrentServer{
+		config:     config,
+		shardCount: shardCount,
+		handler:    handler,
+		packetType: packetType,
+		shutdown:   make(chan struct{}),
+		shards:     make([]*connectionShard, shardCount),
+	}
+
+	// Initialize the shards.
+	for i := 0; i < shardCount; i++ {
+		server.shards[i] = &connectionShard{}
+	}
+
+	// Initialize the worker pool.
+	server.workerPool = NewWorkerPool(config.WorkerNum, config.QueueSize)
+
+	return server
+}
+
+// Start starts the server.
+func (s *HighConcurrentServer) Start() error {
+	switch s.config.Network {
+	case "tcp", "tcp4", "tcp6":
+		return s.startTCP()
+	case "udp", "udp4", "udp6":
+		return s.startUDP()
+	default:
+		return fmt.Errorf("unsupported network type: %s", s.config.Network)
+	}
+}
+
+// startTCP starts the TCP listener.
+func (s *HighConcurrentServer) startTCP() error {
+	ln, err := net.Listen(s.config.Network, s.config.Address)
+	if err != nil {
+		return err
+	}
+	s.listener = ln
+
+	s.wg.Add(1)
+	go s.acceptLoop()
+
+	// Start the connection manager.
+	s.wg.Add(1)
+	go s.manageConnections()
+
+	return nil
+}
+
+// startUDP starts the UDP listener.
+func (s *HighConcurrentServer) startUDP() error {
+	pc, err := net.ListenPacket(s.config.Network, s.config.Address)
+	if err != nil {
+		return err
+	}
+	s.packetConn = pc
+
+	s.wg.Add(1)
+	go s.udpReadLoop()
+
+	return nil
+}
+
+// acceptLoop accepts incoming TCP connections.
+func (s *HighConcurrentServer) acceptLoop() {
+	defer s.wg.Done()
+
+	for {
+		select {
+		case <-s.shutdown:
+			return
+		default:
+		}
+
+		conn, err := s.listener.Accept()
+		if err != nil {
+			if ne, ok := err.(net.Error); ok && ne.Temporary() {
+				time.Sleep(100 * time.Millisecond)
+				continue
+			}
+			return
+		}
+
+		// Enforce the connection limit.
+		if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) {
+			conn.Close()
+			continue
+		}
+
+		atomic.AddInt64(&s.activeConns, 1)
+		s.wg.Add(1)
+		go s.handleNewConnection(conn)
+	}
+}
+
+// handleNewConnection sets up and runs a new TCP connection.
+func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) {
+	defer s.wg.Done()
+	defer atomic.AddInt64(&s.activeConns, -1)
+	defer conn.Close()
+
+	// Pick the shard for this connection.
+	shard := s.getShard(conn)
+	c := &Connection{
+		Conn:        conn,
+		readBuffer:  bytes.NewBuffer(make([]byte, 0, 4096)),
+		writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)),
+		writeChan:   make(chan []byte, 32),
+		closeChan:   make(chan struct{}),
+		server:      s,
+	}
+	c.touch()
+
+	// Register the connection.
+	shard.addConn(c)
+
+	// Run the read loop in its own goroutine and the write loop here.
+	go c.readLoop()
+	c.writeLoop()
+
+	// Wait for shutdown and deregister.
+	<-c.closeChan
+	shard.removeConn(c)
+}
+
+// getShard returns the shard responsible for a connection.
+func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard {
+	// Shard by the remote address.
+	addr, ok := conn.RemoteAddr().(*net.TCPAddr)
+	if !ok {
+		return s.shards[0]
+	}
+
+	// Simple hash over the IP bytes (IPv4 or IPv6) and the port.
+	hash := addr.Port
+	for _, b := range addr.IP {
+		hash += int(b)
+	}
+	return s.shards[hash%s.shardCount]
+}
+
+// udpReadLoop reads and dispatches UDP datagrams.
+func (s *HighConcurrentServer) udpReadLoop() {
+	defer s.wg.Done()
+	defer s.packetConn.Close()
+
+	bufPool := sync.Pool{
+		New: func() interface{} { return make([]byte, 65536) },
+	}
+
+	for {
+		select {
+		case <-s.shutdown:
+			return
+		default:
+		}
+
+		buf := bufPool.Get().([]byte)
+		n, addr, err := s.packetConn.ReadFrom(buf)
+		if err != nil {
+			if ne, ok := err.(net.Error); ok && ne.Temporary() {
+				time.Sleep(100 * time.Millisecond)
+				continue
+			}
+			return
+		}
+
+		s.wg.Add(1)
+		go func(data []byte, addr net.Addr) {
+			defer s.wg.Done()
+			defer bufPool.Put(data)
+
+			// Decode the datagram.
+			packet := s.packetType()
+			if err := packet.Decode(data[:n]); err != nil {
+				return
+			}
+
+			// Wrap the peer address in a virtual connection.
+			conn := &udpConn{PacketConn: s.packetConn, addr: addr}
+
+			// Invoke the user handler.
+			response, err := s.handler(conn, packet)
+			if err != nil {
+				return
+			}
+
+			// Send the response, if any.
+			if response != nil {
+				respData, err := response.Encode()
+				if err == nil {
+					s.packetConn.WriteTo(respData, addr)
+				}
+			}
+		}(buf, addr)
+	}
+}
+
+// manageConnections periodically closes idle connections.
+func (s *HighConcurrentServer) manageConnections() {
+	defer s.wg.Done()
+
+	ticker := time.NewTicker(30 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-s.shutdown:
+			return
+		case <-ticker.C:
+			if s.config.IdleTimeout <= 0 {
+				continue
+			}
+
+			for _, shard := range s.shards {
+				shard.lock.RLock()
+				shard.conns.Range(func(key, value interface{}) bool {
+					conn := value.(*Connection)
+					if conn.idleSince() > s.config.IdleTimeout {
+						conn.close()
+					}
+					return true
+				})
+				shard.lock.RUnlock()
+			}
+		}
+	}
+}
+
+// Stop stops the server and waits for all goroutines to finish.
+func (s *HighConcurrentServer) Stop() {
+	close(s.shutdown)
+
+	if s.listener != nil {
+		s.listener.Close()
+	}
+
+	if s.packetConn != nil {
+		s.packetConn.Close()
+	}
+
+	// Close all connections.
+	for _, shard := range s.shards {
+		shard.lock.RLock()
+		shard.conns.Range(func(key, value interface{}) bool {
+			conn := value.(*Connection)
+			conn.close()
+			return true
+		})
+		shard.lock.RUnlock()
+	}
+
+	s.wg.Wait()
+	s.workerPool.Stop()
+}
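
For reference, the wire format the TCP path expects is a 4-byte big-endian length header followed by the bytes produced by Packet.Encode. Below is a minimal client-side sketch (not part of this change) that exercises that framing against the commented-out MyPacket example above; the package name, the ":8080" address, and the field layout mirroring MyPacket are illustrative assumptions, not code from this repository.

// Minimal client sketch for the length-prefixed framing used by readLoop/writeLoop.
// Assumes a server built from the usage example is listening on 127.0.0.1:8080
// and that MyPacket (ID, payload length, payload) is the packet type on both sides.
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"net"
)

func main() {
	conn, err := net.Dial("tcp", "127.0.0.1:8080")
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Encode a request the same way MyPacket.Encode would: ID, payload length, payload.
	payload := "hello"
	buf := bytes.NewBuffer(nil)
	binary.Write(buf, binary.BigEndian, uint32(1))            // ID
	binary.Write(buf, binary.BigEndian, uint32(len(payload))) // payload length
	buf.WriteString(payload)
	req := buf.Bytes()

	// Write the 4-byte big-endian length header, then the packet body.
	header := make([]byte, 4)
	binary.BigEndian.PutUint32(header, uint32(len(req)))
	if _, err := conn.Write(append(header, req...)); err != nil {
		panic(err)
	}

	// Read the response: length header first, then the body.
	if _, err := io.ReadFull(conn, header); err != nil {
		panic(err)
	}
	resp := make([]byte, binary.BigEndian.Uint32(header))
	if _, err := io.ReadFull(conn, resp); err != nil {
		panic(err)
	}

	// Decode the response the same way MyPacket.Decode would.
	id := binary.BigEndian.Uint32(resp[:4])
	n := binary.BigEndian.Uint32(resp[4:8])
	fmt.Println("response id:", id, "payload:", string(resp[8:8+n]))
}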