356 lines
9.3 KiB
Go
356 lines
9.3 KiB
Go
// Package network 提供了一个高性能的网络通信框架,支持TCP和UDP协议。
|
||
package network
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"net"
|
||
"runtime"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
)
|
||
|
||
// HighConcurrentServer 高并发服务器
|
||
// 实现了一个高性能、高并发的网络服务器,支持TCP和UDP协议
|
||
// 主要特点包括:
|
||
// - 连接分片管理,减少锁竞争
|
||
// - 工作协程池,高效处理请求
|
||
// - 自动管理连接生命周期
|
||
// - 优雅关闭机制
|
||
type HighConcurrentServer struct {
|
||
config ServerConfig // 服务器配置
|
||
listener net.Listener // TCP监听器
|
||
packetConn net.PacketConn // UDP包连接
|
||
workerPool *WorkerPool // 工作协程池
|
||
shards []*connectionShard // 连接分片数组
|
||
shardCount int // 分片数量
|
||
handler Handler // 用户自定义处理函数
|
||
packetType func() Packet // 用于创建新Packet实例的工厂函数
|
||
activeConns int64 // 当前活动连接数
|
||
shutdown chan struct{} // 关闭信号通道
|
||
wg sync.WaitGroup // 等待组,用于优雅关闭
|
||
}
|
||
|
||
// NewServer 创建新的高并发服务器
|
||
// config: 服务器配置
|
||
// handler: 用户自定义的数据包处理函数
|
||
// packetType: 用于创建新的数据包实例的工厂函数
|
||
// 返回初始化好的服务器实例,但尚未启动
|
||
func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer {
|
||
// 设置默认值
|
||
if config.WorkerNum <= 0 {
|
||
config.WorkerNum = runtime.NumCPU() * 2 // 默认为CPU核心数的2倍
|
||
}
|
||
if config.QueueSize <= 0 {
|
||
config.QueueSize = 1024 // 默认队列大小为1024
|
||
}
|
||
if config.MaxConn <= 0 {
|
||
config.MaxConn = 100000 // 默认最大连接数为10万
|
||
}
|
||
|
||
// 计算合适的分片数量,至少为32,且不小于CPU核心数
|
||
// 分片数量为2的幂,便于哈希分配
|
||
shardCount := 32
|
||
for shardCount < runtime.NumCPU() {
|
||
shardCount *= 2
|
||
}
|
||
|
||
server := &HighConcurrentServer{
|
||
config: config,
|
||
shardCount: shardCount,
|
||
handler: handler,
|
||
packetType: packetType,
|
||
shutdown: make(chan struct{}),
|
||
shards: make([]*connectionShard, shardCount),
|
||
}
|
||
|
||
// 初始化连接分片
|
||
for i := 0; i < shardCount; i++ {
|
||
server.shards[i] = &connectionShard{}
|
||
}
|
||
|
||
// 初始化工作协程池
|
||
server.workerPool = NewWorkerPool(config.WorkerNum, config.QueueSize)
|
||
|
||
return server
|
||
}
|
||
|
||
// Start 启动服务器
|
||
// 根据配置的网络类型启动相应的服务器(TCP或UDP)
|
||
// 返回可能的启动错误
|
||
func (s *HighConcurrentServer) Start() error {
|
||
switch s.config.Network {
|
||
case "tcp", "tcp4", "tcp6":
|
||
return s.startTCP() // 启动TCP服务器
|
||
case "udp", "udp4", "udp6":
|
||
return s.startUDP() // 启动UDP服务器
|
||
default:
|
||
return fmt.Errorf("unsupported network type: %s", s.config.Network)
|
||
}
|
||
}
|
||
|
||
// startTCP 启动TCP服务器
|
||
// 创建监听器并启动接受连接循环和连接管理协程
|
||
// 返回可能的启动错误
|
||
func (s *HighConcurrentServer) startTCP() error {
|
||
// 创建TCP监听器
|
||
ln, err := net.Listen(s.config.Network, s.config.Address)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
s.listener = ln
|
||
|
||
// 启动接受连接循环
|
||
s.wg.Add(1)
|
||
go s.acceptLoop()
|
||
|
||
// 启动连接管理(清理空闲连接)
|
||
s.wg.Add(1)
|
||
go s.manageConnections()
|
||
|
||
return nil
|
||
}
|
||
|
||
// startUDP 启动UDP服务器
|
||
// 创建UDP包连接并启动读取循环
|
||
// 返回可能的启动错误
|
||
func (s *HighConcurrentServer) startUDP() error {
|
||
// 创建UDP包连接
|
||
pc, err := net.ListenPacket(s.config.Network, s.config.Address)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
s.packetConn = pc
|
||
|
||
// 启动UDP读取循环
|
||
s.wg.Add(1)
|
||
go s.udpReadLoop()
|
||
|
||
return nil
|
||
}
|
||
|
||
// acceptLoop TCP接受连接循环
|
||
// 持续接受新的TCP连接并为每个连接创建处理协程
|
||
// 当收到关闭信号或发生错误时退出
|
||
func (s *HighConcurrentServer) acceptLoop() {
|
||
defer s.wg.Done()
|
||
|
||
for {
|
||
// 检查是否收到关闭信号
|
||
select {
|
||
case <-s.shutdown:
|
||
return
|
||
default:
|
||
}
|
||
|
||
// 接受新连接
|
||
conn, err := s.listener.Accept()
|
||
if err != nil {
|
||
// 处理临时错误,如"too many open files"
|
||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||
time.Sleep(100 * time.Millisecond) // 短暂休眠后重试
|
||
continue
|
||
}
|
||
return // 非临时错误,退出循环
|
||
}
|
||
|
||
// 检查是否超过最大连接数限制
|
||
if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) {
|
||
conn.Close() // 超过限制,直接关闭连接
|
||
continue
|
||
}
|
||
|
||
// 增加活动连接计数并启动连接处理协程
|
||
atomic.AddInt64(&s.activeConns, 1)
|
||
s.wg.Add(1)
|
||
go s.handleNewConnection(conn)
|
||
}
|
||
}
|
||
|
||
// handleNewConnection 处理新连接
|
||
// conn: 新接受的TCP连接
|
||
// 为新连接创建Connection对象,启动读写循环,并管理连接的生命周期
|
||
func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) {
|
||
defer s.wg.Done()
|
||
defer atomic.AddInt64(&s.activeConns, -1) // 减少活动连接计数
|
||
defer conn.Close() // 确保连接关闭
|
||
|
||
// 获取连接对应的分片,用于减少锁竞争
|
||
shard := s.getShard(conn)
|
||
c := &Connection{
|
||
Conn: conn,
|
||
readBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // 初始化读缓冲区,容量为4KB
|
||
writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // 初始化写缓冲区,容量为4KB
|
||
writeChan: make(chan []byte, 32), // 写入通道,缓冲大小为32
|
||
closeChan: make(chan struct{}), // 关闭信号通道
|
||
server: s,
|
||
lastActive: time.Now(), // 记录初始活动时间
|
||
}
|
||
|
||
// 将连接添加到分片管理中
|
||
shard.addConn(c)
|
||
|
||
// 启动读写协程
|
||
go c.readLoop() // 读取循环在单独的协程中运行
|
||
c.writeLoop() // 写入循环在当前协程中运行
|
||
|
||
// 等待关闭信号
|
||
<-c.closeChan
|
||
shard.removeConn(c) // 从分片中移除连接
|
||
}
|
||
|
||
// getShard 获取连接对应的分片
|
||
// conn: 网络连接
|
||
// 返回该连接应该被分配到的分片
|
||
// 使用连接的远程IP地址进行哈希来确定分片,以实现负载均衡
|
||
func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard {
|
||
// 尝试获取TCP地址
|
||
addr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||
if !ok {
|
||
return s.shards[0] // 非TCP连接返回第一个分片
|
||
}
|
||
|
||
// 使用IP地址的四个字节进行简单哈希
|
||
// 将IP地址的四个字节相加作为哈希值
|
||
hash := addr.IP.To4()[0] + addr.IP.To4()[1] + addr.IP.To4()[2] + addr.IP.To4()[3]
|
||
return s.shards[int(hash)%s.shardCount] // 使用取模运算确定分片索引
|
||
}
|
||
|
||
// udpReadLoop UDP读取循环
|
||
// 持续从UDP连接读取数据包并处理
|
||
// 使用sync.Pool复用缓冲区以减少内存分配
|
||
func (s *HighConcurrentServer) udpReadLoop() {
|
||
defer s.wg.Done()
|
||
defer s.packetConn.Close()
|
||
|
||
// 创建缓冲区池,每个缓冲区大小为64KB(UDP最大包大小)
|
||
bufPool := sync.Pool{
|
||
New: func() interface{} { return make([]byte, 65536) },
|
||
}
|
||
|
||
for {
|
||
// 检查是否收到关闭信号
|
||
select {
|
||
case <-s.shutdown:
|
||
return
|
||
default:
|
||
}
|
||
|
||
// 从池中获取缓冲区
|
||
buf := bufPool.Get().([]byte)
|
||
n, addr, err := s.packetConn.ReadFrom(buf)
|
||
if err != nil {
|
||
// 处理临时错误
|
||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||
time.Sleep(100 * time.Millisecond)
|
||
continue
|
||
}
|
||
return
|
||
}
|
||
|
||
// 启动新的协程处理数据包
|
||
s.wg.Add(1)
|
||
go func(data []byte, addr net.Addr) {
|
||
defer s.wg.Done()
|
||
defer bufPool.Put(data) // 将缓冲区放回池中
|
||
|
||
// 解码数据包
|
||
packet := s.packetType()
|
||
if err := packet.Decode(data[:n]); err != nil {
|
||
return
|
||
}
|
||
|
||
// 创建虚拟UDP连接
|
||
conn := &udpConn{PacketConn: s.packetConn, addr: addr}
|
||
|
||
// 调用用户处理函数
|
||
response, err := s.handler(conn, packet)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 发送响应(如果有)
|
||
if response != nil {
|
||
respData, err := response.Encode()
|
||
if err == nil {
|
||
s.packetConn.WriteTo(respData, addr)
|
||
}
|
||
}
|
||
}(buf, addr)
|
||
}
|
||
}
|
||
|
||
// manageConnections 管理连接(关闭空闲连接)
|
||
// 定期检查并关闭超过空闲超时时间的连接
|
||
// 当服务器关闭时退出
|
||
func (s *HighConcurrentServer) manageConnections() {
|
||
defer s.wg.Done()
|
||
|
||
// 创建定时器,每30秒检查一次
|
||
ticker := time.NewTicker(30 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-s.shutdown:
|
||
return
|
||
case <-ticker.C:
|
||
// 如果未设置空闲超时,跳过检查
|
||
if s.config.IdleTimeout <= 0 {
|
||
continue
|
||
}
|
||
|
||
now := time.Now()
|
||
// 遍历所有分片
|
||
for _, shard := range s.shards {
|
||
shard.lock.RLock()
|
||
// 遍历分片中的所有连接
|
||
shard.conns.Range(func(key, value interface{}) bool {
|
||
conn := value.(*Connection)
|
||
// 检查连接是否超过空闲超时时间
|
||
if now.Sub(conn.lastActive) > s.config.IdleTimeout {
|
||
// 发送关闭信号
|
||
select {
|
||
case conn.closeChan <- struct{}{}:
|
||
default:
|
||
}
|
||
}
|
||
return true
|
||
})
|
||
shard.lock.RUnlock()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Stop 停止服务器
|
||
func (s *HighConcurrentServer) Stop() {
|
||
close(s.shutdown)
|
||
|
||
if s.listener != nil {
|
||
s.listener.Close()
|
||
}
|
||
|
||
if s.packetConn != nil {
|
||
s.packetConn.Close()
|
||
}
|
||
|
||
// 关闭所有连接
|
||
for _, shard := range s.shards {
|
||
shard.lock.RLock()
|
||
shard.conns.Range(func(key, value interface{}) bool {
|
||
conn := value.(*Connection)
|
||
select {
|
||
case conn.closeChan <- struct{}{}:
|
||
default:
|
||
}
|
||
return true
|
||
})
|
||
shard.lock.RUnlock()
|
||
}
|
||
|
||
s.wg.Wait()
|
||
s.workerPool.Stop()
|
||
}
|