network/server.go

356 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package network 提供了一个高性能的网络通信框架支持TCP和UDP协议。
package network
import (
"bytes"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
// HighConcurrentServer 高并发服务器
// 实现了一个高性能、高并发的网络服务器支持TCP和UDP协议
// 主要特点包括:
// - 连接分片管理,减少锁竞争
// - 工作协程池,高效处理请求
// - 自动管理连接生命周期
// - 优雅关闭机制
type HighConcurrentServer struct {
config ServerConfig // 服务器配置
listener net.Listener // TCP监听器
packetConn net.PacketConn // UDP包连接
workerPool *WorkerPool // 工作协程池
shards []*connectionShard // 连接分片数组
shardCount int // 分片数量
handler Handler // 用户自定义处理函数
packetType func() Packet // 用于创建新Packet实例的工厂函数
activeConns int64 // 当前活动连接数
shutdown chan struct{} // 关闭信号通道
wg sync.WaitGroup // 等待组,用于优雅关闭
}
// NewServer 创建新的高并发服务器
// config: 服务器配置
// handler: 用户自定义的数据包处理函数
// packetType: 用于创建新的数据包实例的工厂函数
// 返回初始化好的服务器实例,但尚未启动
func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer {
// 设置默认值
if config.WorkerNum <= 0 {
config.WorkerNum = runtime.NumCPU() * 2 // 默认为CPU核心数的2倍
}
if config.QueueSize <= 0 {
config.QueueSize = 1024 // 默认队列大小为1024
}
if config.MaxConn <= 0 {
config.MaxConn = 100000 // 默认最大连接数为10万
}
// 计算合适的分片数量至少为32且不小于CPU核心数
// 分片数量为2的幂便于哈希分配
shardCount := 32
for shardCount < runtime.NumCPU() {
shardCount *= 2
}
server := &HighConcurrentServer{
config: config,
shardCount: shardCount,
handler: handler,
packetType: packetType,
shutdown: make(chan struct{}),
shards: make([]*connectionShard, shardCount),
}
// 初始化连接分片
for i := 0; i < shardCount; i++ {
server.shards[i] = &connectionShard{}
}
// 初始化工作协程池
server.workerPool = NewWorkerPool(config.WorkerNum, config.QueueSize)
return server
}
// Start 启动服务器
// 根据配置的网络类型启动相应的服务器TCP或UDP
// 返回可能的启动错误
func (s *HighConcurrentServer) Start() error {
switch s.config.Network {
case "tcp", "tcp4", "tcp6":
return s.startTCP() // 启动TCP服务器
case "udp", "udp4", "udp6":
return s.startUDP() // 启动UDP服务器
default:
return fmt.Errorf("unsupported network type: %s", s.config.Network)
}
}
// startTCP 启动TCP服务器
// 创建监听器并启动接受连接循环和连接管理协程
// 返回可能的启动错误
func (s *HighConcurrentServer) startTCP() error {
// 创建TCP监听器
ln, err := net.Listen(s.config.Network, s.config.Address)
if err != nil {
return err
}
s.listener = ln
// 启动接受连接循环
s.wg.Add(1)
go s.acceptLoop()
// 启动连接管理(清理空闲连接)
s.wg.Add(1)
go s.manageConnections()
return nil
}
// startUDP 启动UDP服务器
// 创建UDP包连接并启动读取循环
// 返回可能的启动错误
func (s *HighConcurrentServer) startUDP() error {
// 创建UDP包连接
pc, err := net.ListenPacket(s.config.Network, s.config.Address)
if err != nil {
return err
}
s.packetConn = pc
// 启动UDP读取循环
s.wg.Add(1)
go s.udpReadLoop()
return nil
}
// acceptLoop TCP接受连接循环
// 持续接受新的TCP连接并为每个连接创建处理协程
// 当收到关闭信号或发生错误时退出
func (s *HighConcurrentServer) acceptLoop() {
defer s.wg.Done()
for {
// 检查是否收到关闭信号
select {
case <-s.shutdown:
return
default:
}
// 接受新连接
conn, err := s.listener.Accept()
if err != nil {
// 处理临时错误,如"too many open files"
if ne, ok := err.(net.Error); ok && ne.Temporary() {
time.Sleep(100 * time.Millisecond) // 短暂休眠后重试
continue
}
return // 非临时错误,退出循环
}
// 检查是否超过最大连接数限制
if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) {
conn.Close() // 超过限制,直接关闭连接
continue
}
// 增加活动连接计数并启动连接处理协程
atomic.AddInt64(&s.activeConns, 1)
s.wg.Add(1)
go s.handleNewConnection(conn)
}
}
// handleNewConnection 处理新连接
// conn: 新接受的TCP连接
// 为新连接创建Connection对象启动读写循环并管理连接的生命周期
func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) {
defer s.wg.Done()
defer atomic.AddInt64(&s.activeConns, -1) // 减少活动连接计数
defer conn.Close() // 确保连接关闭
// 获取连接对应的分片,用于减少锁竞争
shard := s.getShard(conn)
c := &Connection{
Conn: conn,
readBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // 初始化读缓冲区容量为4KB
writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // 初始化写缓冲区容量为4KB
writeChan: make(chan []byte, 32), // 写入通道缓冲大小为32
closeChan: make(chan struct{}), // 关闭信号通道
server: s,
lastActive: time.Now(), // 记录初始活动时间
}
// 将连接添加到分片管理中
shard.addConn(c)
// 启动读写协程
go c.readLoop() // 读取循环在单独的协程中运行
c.writeLoop() // 写入循环在当前协程中运行
// 等待关闭信号
<-c.closeChan
shard.removeConn(c) // 从分片中移除连接
}
// getShard 获取连接对应的分片
// conn: 网络连接
// 返回该连接应该被分配到的分片
// 使用连接的远程IP地址进行哈希来确定分片以实现负载均衡
func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard {
// 尝试获取TCP地址
addr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
return s.shards[0] // 非TCP连接返回第一个分片
}
// 使用IP地址的四个字节进行简单哈希
// 将IP地址的四个字节相加作为哈希值
hash := addr.IP.To4()[0] + addr.IP.To4()[1] + addr.IP.To4()[2] + addr.IP.To4()[3]
return s.shards[int(hash)%s.shardCount] // 使用取模运算确定分片索引
}
// udpReadLoop UDP读取循环
// 持续从UDP连接读取数据包并处理
// 使用sync.Pool复用缓冲区以减少内存分配
func (s *HighConcurrentServer) udpReadLoop() {
defer s.wg.Done()
defer s.packetConn.Close()
// 创建缓冲区池每个缓冲区大小为64KBUDP最大包大小
bufPool := sync.Pool{
New: func() interface{} { return make([]byte, 65536) },
}
for {
// 检查是否收到关闭信号
select {
case <-s.shutdown:
return
default:
}
// 从池中获取缓冲区
buf := bufPool.Get().([]byte)
n, addr, err := s.packetConn.ReadFrom(buf)
if err != nil {
// 处理临时错误
if ne, ok := err.(net.Error); ok && ne.Temporary() {
time.Sleep(100 * time.Millisecond)
continue
}
return
}
// 启动新的协程处理数据包
s.wg.Add(1)
go func(data []byte, addr net.Addr) {
defer s.wg.Done()
defer bufPool.Put(data) // 将缓冲区放回池中
// 解码数据包
packet := s.packetType()
if err := packet.Decode(data[:n]); err != nil {
return
}
// 创建虚拟UDP连接
conn := &udpConn{PacketConn: s.packetConn, addr: addr}
// 调用用户处理函数
response, err := s.handler(conn, packet)
if err != nil {
return
}
// 发送响应(如果有)
if response != nil {
respData, err := response.Encode()
if err == nil {
s.packetConn.WriteTo(respData, addr)
}
}
}(buf, addr)
}
}
// manageConnections 管理连接(关闭空闲连接)
// 定期检查并关闭超过空闲超时时间的连接
// 当服务器关闭时退出
func (s *HighConcurrentServer) manageConnections() {
defer s.wg.Done()
// 创建定时器每30秒检查一次
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.shutdown:
return
case <-ticker.C:
// 如果未设置空闲超时,跳过检查
if s.config.IdleTimeout <= 0 {
continue
}
now := time.Now()
// 遍历所有分片
for _, shard := range s.shards {
shard.lock.RLock()
// 遍历分片中的所有连接
shard.conns.Range(func(key, value interface{}) bool {
conn := value.(*Connection)
// 检查连接是否超过空闲超时时间
if now.Sub(conn.lastActive) > s.config.IdleTimeout {
// 发送关闭信号
select {
case conn.closeChan <- struct{}{}:
default:
}
}
return true
})
shard.lock.RUnlock()
}
}
}
}
// Stop 停止服务器
func (s *HighConcurrentServer) Stop() {
close(s.shutdown)
if s.listener != nil {
s.listener.Close()
}
if s.packetConn != nil {
s.packetConn.Close()
}
// 关闭所有连接
for _, shard := range s.shards {
shard.lock.RLock()
shard.conns.Range(func(key, value interface{}) bool {
conn := value.(*Connection)
select {
case conn.closeChan <- struct{}{}:
default:
}
return true
})
shard.lock.RUnlock()
}
s.wg.Wait()
s.workerPool.Stop()
}