network/server.go

307 lines
6.0 KiB
Go

package network
import (
"bytes"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
// HighConcurrentServer is a high-concurrency TCP/UDP server.
type HighConcurrentServer struct {
	// activeConns counts live TCP connections and is only accessed through
	// sync/atomic's 64-bit functions. It must remain the FIRST field: on 386,
	// ARM and 32-bit MIPS those functions require 8-byte alignment, which Go
	// guarantees only for the first word of an allocated struct.
	activeConns int64

	config     ServerConfig
	listener   net.Listener       // TCP listener, set by startTCP
	packetConn net.PacketConn     // UDP socket, set by startUDP
	workerPool *WorkerPool        // created in NewServer, stopped in Stop
	shards     []*connectionShard // connection registry, split to reduce lock contention
	shardCount int                // len(shards)
	handler    Handler            // user callback invoked for each decoded packet
	packetType func() Packet      // factory that creates a fresh Packet instance
	shutdown   chan struct{}      // closed by Stop to signal background goroutines
	wg         sync.WaitGroup     // tracks goroutines started by the server
}
// NewServer builds a HighConcurrentServer from config, substituting defaults
// for any non-positive limit:
//   - WorkerNum: 2 × runtime.NumCPU()
//   - QueueSize: 1024
//   - MaxConn:   100000
//
// The connection registry is split into a power-of-two number of shards: at
// least 32, doubled until it is no smaller than the CPU count. The worker pool
// is created here; nothing listens until Start is called.
func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer {
	cpus := runtime.NumCPU()

	if config.WorkerNum <= 0 {
		config.WorkerNum = cpus * 2
	}
	if config.QueueSize <= 0 {
		config.QueueSize = 1024
	}
	if config.MaxConn <= 0 {
		config.MaxConn = 100000
	}

	shardCount := 32
	for shardCount < cpus {
		shardCount <<= 1
	}
	shards := make([]*connectionShard, shardCount)
	for i := range shards {
		shards[i] = &connectionShard{}
	}

	return &HighConcurrentServer{
		config:     config,
		shardCount: shardCount,
		handler:    handler,
		packetType: packetType,
		shutdown:   make(chan struct{}),
		shards:     shards,
		workerPool: NewWorkerPool(config.WorkerNum, config.QueueSize),
	}
}
// Start opens the socket described by s.config.Network / s.config.Address.
//
// The "tcp*" networks go through startTCP and the "udp*" networks through
// startUDP. Serving then continues in background goroutines, so Start returns
// as soon as the socket is open. Any other network name yields an error.
func (s *HighConcurrentServer) Start() error {
	var launch func() error
	switch s.config.Network {
	case "tcp", "tcp4", "tcp6":
		launch = s.startTCP
	case "udp", "udp4", "udp6":
		launch = s.startUDP
	default:
		return fmt.Errorf("unsupported network type: %s", s.config.Network)
	}
	return launch()
}
// startTCP binds the TCP listener and launches two goroutines tracked by s.wg:
// acceptLoop, which accepts connections, and manageConnections, which evicts
// idle ones. The listen error, if any, is returned unchanged.
func (s *HighConcurrentServer) startTCP() error {
	listener, listenErr := net.Listen(s.config.Network, s.config.Address)
	if listenErr != nil {
		return listenErr
	}
	s.listener = listener

	// Register both goroutines with the WaitGroup before starting them.
	s.wg.Add(2)
	go s.acceptLoop()
	go s.manageConnections()
	return nil
}
// startUDP binds the UDP socket and starts udpReadLoop, which s.wg tracks.
// The listen error, if any, is returned unchanged.
func (s *HighConcurrentServer) startUDP() error {
	socket, listenErr := net.ListenPacket(s.config.Network, s.config.Address)
	if listenErr != nil {
		return listenErr
	}
	s.packetConn = socket

	s.wg.Add(1)
	go s.udpReadLoop()
	return nil
}
// acceptLoop accepts TCP connections until s.shutdown is closed or Accept
// fails with a non-temporary error (for example once Stop has closed the
// listener).
//
// Temporary Accept errors are retried after a 100ms pause. A connection that
// arrives while s.activeConns is already at s.config.MaxConn is closed at once.
// Otherwise the counter is incremented and the connection is handed to
// handleNewConnection on a new goroutine that s.wg tracks.
func (s *HighConcurrentServer) acceptLoop() {
	defer s.wg.Done()

	shuttingDown := func() bool {
		select {
		case <-s.shutdown:
			return true
		default:
			return false
		}
	}

	for !shuttingDown() {
		conn, err := s.listener.Accept()
		if err != nil {
			netErr, isNetErr := err.(net.Error)
			if !isNetErr || !netErr.Temporary() {
				return
			}
			time.Sleep(100 * time.Millisecond)
			continue
		}

		// Enforce the connection cap before taking ownership of conn.
		if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) {
			conn.Close()
			continue
		}
		atomic.AddInt64(&s.activeConns, 1)

		s.wg.Add(1)
		go s.handleNewConnection(conn)
	}
}
// handleNewConnection owns one accepted TCP connection for its whole lifetime.
//
// It wraps conn in a *Connection and registers it in the shard chosen by
// getShard. It then starts c.readLoop on a separate goroutine, runs
// c.writeLoop on this goroutine, and waits for a value on c.closeChan before
// unregistering the connection. On return, the deferred calls run in this
// order: close conn, decrement s.activeConns, and mark this goroutine done in
// s.wg.
//
// NOTE(review): the readLoop goroutine is not registered with s.wg, so Stop's
// s.wg.Wait() does not wait for it. Confirm readLoop exits once conn is closed.
// NOTE(review): presumably writeLoop returns only when the connection is
// finished. If writeLoop itself receives the closeChan signal, the receive
// below needs a second signal before it can proceed. closeChan is unbuffered
// and senders use non-blocking sends, so that second signal may never come.
// Verify against Connection's implementation.
func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) {
	defer s.wg.Done()
	defer atomic.AddInt64(&s.activeConns, -1)
	defer conn.Close()
	// Pick the shard once so the same shard is used for add and remove.
	shard := s.getShard(conn)
	c := &Connection{
		Conn:        conn,
		readBuffer:  bytes.NewBuffer(make([]byte, 0, 4096)), // empty, 4 KiB capacity preallocated
		writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)), // empty, 4 KiB capacity preallocated
		writeChan:   make(chan []byte, 32),
		closeChan:   make(chan struct{}), // unbuffered close signal
		server:      s,
		lastActive:  time.Now(),
	}
	// Register the connection (presumably into shard.conns, which
	// manageConnections and Stop iterate). TODO confirm addConn's behaviour.
	shard.addConn(c)
	// Reads run in the background; writes run on this goroutine.
	go c.readLoop()
	c.writeLoop()
	// Wait for a close signal before unregistering.
	<-c.closeChan
	shard.removeConn(c)
}
// getShard returns the connection shard that conn belongs to.
//
// The shard index is a 32-bit FNV-1a hash of the peer's address string
// (host:port for TCP) modulo s.shardCount, so:
//   - IPv4 and IPv6 peers are handled alike (no per-octet indexing);
//   - every shard index is reachable, however many shards there are;
//   - many connections from one host (e.g. behind a proxy or NAT) are
//     spread across shards because the port is part of the key.
//
// A connection with no remote address uses shard 0. The result depends only
// on conn.RemoteAddr().
func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard {
	addr := conn.RemoteAddr()
	if addr == nil {
		return s.shards[0]
	}

	// 32-bit FNV-1a, inlined to avoid a new import for a few lines.
	const (
		offset32 = 2166136261
		prime32  = 16777619
	)
	key := addr.String()
	h := uint32(offset32)
	for i := 0; i < len(key); i++ {
		h ^= uint32(key[i])
		h *= prime32
	}
	return s.shards[h%uint32(s.shardCount)]
}
// udpReadLoop reads datagrams from s.packetConn until s.shutdown is closed or
// ReadFrom fails with a non-temporary error. Temporary errors are retried
// after a 100ms pause. The socket is closed when the loop returns.
//
// Each datagram is handled on its own goroutine, which s.wg tracks:
//  1. decode it into a fresh Packet from s.packetType;
//  2. pass it to s.handler with a udpConn bound to the sender's address;
//  3. encode any non-nil response and write it back to that sender.
//
// Decode, handler, encode and write errors drop that datagram silently
// (best effort).
//
// NOTE(review): goroutines are spawned per datagram with no upper bound, and
// s.workerPool is not used here. Consider bounding concurrency.
// NOTE(review): the read buffer goes back to the pool once the handler
// returns, so Packet.Decode and the handler must not retain the input slice
// beyond that point.
func (s *HighConcurrentServer) udpReadLoop() {
	defer s.wg.Done()
	defer s.packetConn.Close()

	// The pool holds *[]byte rather than []byte. Storing a slice header in the
	// pool's interface{} would allocate on every Put (staticcheck SA6002).
	// 64 KiB is large enough for any UDP datagram.
	bufPool := sync.Pool{
		New: func() interface{} {
			b := make([]byte, 65536)
			return &b
		},
	}

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		bufp := bufPool.Get().(*[]byte)
		n, addr, err := s.packetConn.ReadFrom(*bufp)
		if err != nil {
			// The buffer was never handed off, so recycle it on every error path.
			bufPool.Put(bufp)
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				time.Sleep(100 * time.Millisecond)
				continue
			}
			return
		}

		s.wg.Add(1)
		go func(bufp *[]byte, n int, addr net.Addr) {
			defer s.wg.Done()
			defer bufPool.Put(bufp)

			packet := s.packetType()
			if err := packet.Decode((*bufp)[:n]); err != nil {
				return
			}

			// udpConn presents the shared socket plus the sender's address
			// to the handler as a per-peer "connection".
			conn := &udpConn{PacketConn: s.packetConn, addr: addr}
			response, err := s.handler(conn, packet)
			if err != nil || response == nil {
				return
			}

			respData, err := response.Encode()
			if err != nil {
				return
			}
			// Best-effort reply: a failed write affects only this datagram.
			_, _ = s.packetConn.WriteTo(respData, addr)
		}(bufp, n, addr)
	}
}
// manageConnections evicts idle TCP connections until s.shutdown is closed.
//
// Every 30 seconds, if s.config.IdleTimeout is positive, it walks every shard
// and sends a non-blocking close signal on the closeChan of each connection
// whose lastActive is older than IdleTimeout.
//
// NOTE(review): because of the 30s tick, a connection can stay open for up to
// IdleTimeout + 30s. Because the send is non-blocking, a connection that is
// not receiving at that moment misses the signal until the next tick.
func (s *HighConcurrentServer) manageConnections() {
	defer s.wg.Done()

	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	evictIdle := func(now time.Time) {
		for _, shard := range s.shards {
			shard.lock.RLock()
			shard.conns.Range(func(_, v interface{}) bool {
				c := v.(*Connection)
				if now.Sub(c.lastActive) > s.config.IdleTimeout {
					select {
					case c.closeChan <- struct{}{}:
					default:
					}
				}
				return true
			})
			shard.lock.RUnlock()
		}
	}

	for {
		select {
		case <-s.shutdown:
			return
		case <-ticker.C:
		}

		if s.config.IdleTimeout > 0 {
			evictIdle(time.Now())
		}
	}
}
// Stop shuts the server down, waits for every goroutine tracked by s.wg to
// exit, and then stops the worker pool.
//
// It closes s.shutdown and closes the TCP listener and/or UDP socket so that
// blocked Accept/ReadFrom calls return. It also sends a non-blocking close
// signal to every registered connection.
//
// Calling Stop again after s.shutdown has been closed is a no-op; previously
// it panicked with "close of closed channel". Concurrent calls to Stop are
// still not safe.
//
// NOTE(review): c.closeChan is unbuffered and the send is non-blocking. A
// connection goroutine that is not receiving at that instant misses the
// signal, and s.wg.Wait() may then block until that connection ends on its
// own. Confirm against Connection's read/write loops.
func (s *HighConcurrentServer) Stop() {
	// Closing s.shutdown twice would panic, so bail out if already stopped.
	select {
	case <-s.shutdown:
		return
	default:
	}
	close(s.shutdown)

	// Unblock Accept / ReadFrom so the loops observe shutdown and exit.
	// Close errors are ignored: the server is going away regardless.
	if s.listener != nil {
		s.listener.Close()
	}
	if s.packetConn != nil {
		s.packetConn.Close()
	}

	// Signal every registered connection to close.
	for _, shard := range s.shards {
		shard.lock.RLock()
		shard.conns.Range(func(key, value interface{}) bool {
			conn := value.(*Connection)
			select {
			case conn.closeChan <- struct{}{}:
			default:
			}
			return true
		})
		shard.lock.RUnlock()
	}

	s.wg.Wait()
	s.workerPool.Stop()
}