新增高并发网络框架实现,支持TCP/UDP协议和自定义数据包处理
This commit is contained in:
parent
03361a62a2
commit
377e51bb40
|
|
@ -0,0 +1,379 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ========== 框架核心接口 ==========
|
||||||
|
|
||||||
|
// Packet 用户自定义数据包必须实现的接口
|
||||||
|
type Packet interface {
|
||||||
|
// 编码为字节流
|
||||||
|
Encode() ([]byte, error)
|
||||||
|
// 从字节流解码
|
||||||
|
Decode([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler 用户自定义的数据处理函数类型
|
||||||
|
type Handler func(conn net.Conn, p Packet) (Packet, error)
|
||||||
|
|
||||||
|
// ========== 框架核心结构 ==========
|
||||||
|
|
||||||
|
// ServerConfig 服务器配置
|
||||||
|
type ServerConfig struct {
|
||||||
|
Network string // tcp, tcp4, tcp6, udp, udp4, udp6
|
||||||
|
Address string // 监听地址
|
||||||
|
MaxConn int // 最大连接数 (TCP only)
|
||||||
|
WorkerNum int // 工作协程数量
|
||||||
|
QueueSize int // 任务队列大小
|
||||||
|
ReadTimeout time.Duration // 读取超时
|
||||||
|
IdleTimeout time.Duration // 空闲超时
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectionShard 连接分片
|
||||||
|
type connectionShard struct {
|
||||||
|
conns sync.Map // map[net.Conn]*Connection
|
||||||
|
lock sync.RWMutex
|
||||||
|
lastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection 连接封装
|
||||||
|
type Connection struct {
|
||||||
|
net.Conn
|
||||||
|
readBuffer *bytes.Buffer
|
||||||
|
writeBuffer *bytes.Buffer
|
||||||
|
writeChan chan []byte
|
||||||
|
closeChan chan struct{}
|
||||||
|
server *HighConcurrentServer
|
||||||
|
lastActive time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkerPool 工作协程池
|
||||||
|
type WorkerPool struct {
|
||||||
|
taskQueue chan func()
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== 框架实现 ==========
|
||||||
|
|
||||||
|
// readLoop 连接读循环
|
||||||
|
func (c *Connection) readLoop() {
|
||||||
|
reader := bufio.NewReader(c.Conn)
|
||||||
|
var header [4]byte
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.closeChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置读超时
|
||||||
|
if c.server.config.ReadTimeout > 0 {
|
||||||
|
c.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取包头 (4字节长度)
|
||||||
|
_, err := io.ReadFull(reader, header[:])
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
// 处理错误
|
||||||
|
}
|
||||||
|
close(c.closeChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析包长度
|
||||||
|
length := binary.BigEndian.Uint32(header[:])
|
||||||
|
|
||||||
|
// 检查包长度是否合理
|
||||||
|
if length > 10*1024*1024 { // 10MB
|
||||||
|
close(c.closeChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取包体
|
||||||
|
packetData := make([]byte, length)
|
||||||
|
_, err = io.ReadFull(reader, packetData)
|
||||||
|
if err != nil {
|
||||||
|
close(c.closeChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lastActive = time.Now()
|
||||||
|
|
||||||
|
// 提交给工作池处理
|
||||||
|
c.server.workerPool.Submit(func() {
|
||||||
|
c.processPacket(packetData)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPacket 处理数据包
|
||||||
|
func (c *Connection) processPacket(data []byte) {
|
||||||
|
// 创建新的数据包实例
|
||||||
|
packet := c.server.packetType()
|
||||||
|
|
||||||
|
// 解码数据
|
||||||
|
if err := packet.Decode(data); err != nil {
|
||||||
|
// 处理解码错误
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用用户处理函数
|
||||||
|
response, err := c.server.handler(c.Conn, packet)
|
||||||
|
if err != nil {
|
||||||
|
// 处理错误
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有响应,发送回去
|
||||||
|
if response != nil {
|
||||||
|
c.Send(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 发送数据包
|
||||||
|
func (c *Connection) Send(p Packet) error {
|
||||||
|
data, err := p.Encode()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.writeChan <- data:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("write channel full")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeLoop 连接写循环
|
||||||
|
func (c *Connection) writeLoop() {
|
||||||
|
// 批量写入优化
|
||||||
|
var batch [][]byte
|
||||||
|
ticker := time.NewTicker(10 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case data := <-c.writeChan:
|
||||||
|
batch = append(batch, data)
|
||||||
|
// 尝试批量处理
|
||||||
|
for len(c.writeChan) > 0 && len(batch) < 32 {
|
||||||
|
batch = append(batch, <-c.writeChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并写入
|
||||||
|
c.writeBuffer.Reset()
|
||||||
|
for _, d := range batch {
|
||||||
|
// 写入长度头
|
||||||
|
header := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(header, uint32(len(d)))
|
||||||
|
c.writeBuffer.Write(header)
|
||||||
|
c.writeBuffer.Write(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送数据
|
||||||
|
if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
|
||||||
|
close(c.closeChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch = batch[:0]
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
// 定时刷新
|
||||||
|
if len(batch) > 0 {
|
||||||
|
c.writeBuffer.Reset()
|
||||||
|
for _, d := range batch {
|
||||||
|
header := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(header, uint32(len(d)))
|
||||||
|
c.writeBuffer.Write(header)
|
||||||
|
c.writeBuffer.Write(d)
|
||||||
|
}
|
||||||
|
if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
|
||||||
|
close(c.closeChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch = batch[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查空闲超时
|
||||||
|
if c.server.config.IdleTimeout > 0 &&
|
||||||
|
time.Since(c.lastActive) > c.server.config.IdleTimeout {
|
||||||
|
close(c.closeChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-c.closeChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== 辅助结构和方法 ==========
|
||||||
|
|
||||||
|
// NewWorkerPool 创建工作池
|
||||||
|
func NewWorkerPool(size, queueSize int) *WorkerPool {
|
||||||
|
pool := &WorkerPool{
|
||||||
|
taskQueue: make(chan func(), queueSize),
|
||||||
|
size: size,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
go pool.worker()
|
||||||
|
}
|
||||||
|
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WorkerPool) worker() {
|
||||||
|
for task := range p.taskQueue {
|
||||||
|
task()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WorkerPool) Submit(task func()) {
|
||||||
|
select {
|
||||||
|
case p.taskQueue <- task:
|
||||||
|
default:
|
||||||
|
// 任务队列满时的处理
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WorkerPool) Stop() {
|
||||||
|
close(p.taskQueue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpConn UDP虚拟连接
|
||||||
|
type udpConn struct {
|
||||||
|
net.PacketConn
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) Read(b []byte) (int, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) Write(b []byte) (int, error) {
|
||||||
|
return c.PacketConn.WriteTo(b, c.addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) LocalAddr() net.Addr {
|
||||||
|
return c.PacketConn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) RemoteAddr() net.Addr {
|
||||||
|
return c.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) SetDeadline(t time.Time) error {
|
||||||
|
return c.PacketConn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.PacketConn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.PacketConn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *connectionShard) addConn(conn *Connection) {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
s.conns.Store(conn.Conn, conn)
|
||||||
|
s.lastUsed = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *connectionShard) removeConn(conn *Connection) {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
s.conns.Delete(conn.Conn)
|
||||||
|
s.lastUsed = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== 用户使用示例 ==========
|
||||||
|
|
||||||
|
/*
|
||||||
|
// 用户自定义数据包
|
||||||
|
type MyPacket struct {
|
||||||
|
ID uint32
|
||||||
|
Payload string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MyPacket) Encode() ([]byte, error) {
|
||||||
|
buf := bytes.NewBuffer(nil)
|
||||||
|
binary.Write(buf, binary.BigEndian, p.ID)
|
||||||
|
binary.Write(buf, binary.BigEndian, uint32(len(p.Payload)))
|
||||||
|
buf.WriteString(p.Payload)
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MyPacket) Decode(data []byte) error {
|
||||||
|
buf := bytes.NewBuffer(data)
|
||||||
|
binary.Read(buf, binary.BigEndian, &p.ID)
|
||||||
|
var length uint32
|
||||||
|
binary.Read(buf, binary.BigEndian, &length)
|
||||||
|
p.Payload = string(buf.Next(int(length)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户自定义处理函数
|
||||||
|
func myHandler(conn net.Conn, p Packet) (Packet, error) {
|
||||||
|
myPacket, ok := p.(*MyPacket)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid packet type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理业务逻辑
|
||||||
|
response := &MyPacket{
|
||||||
|
ID: myPacket.ID + 1,
|
||||||
|
Payload: "Response: " + myPacket.Payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建服务器配置
|
||||||
|
config := network.ServerConfig{
|
||||||
|
Network: "tcp",
|
||||||
|
Address: ":8080",
|
||||||
|
MaxConn: 10000,
|
||||||
|
WorkerNum: 100,
|
||||||
|
QueueSize: 10000,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 300 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建服务器实例
|
||||||
|
server := network.NewServer(config, myHandler, func() network.Packet {
|
||||||
|
return &MyPacket{}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 启动服务器
|
||||||
|
if err := server.Start(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待关闭信号
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-sig
|
||||||
|
|
||||||
|
// 优雅关闭
|
||||||
|
server.Stop()
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
@ -0,0 +1,306 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HighConcurrentServer 高并发服务器
|
||||||
|
type HighConcurrentServer struct {
|
||||||
|
config ServerConfig
|
||||||
|
listener net.Listener
|
||||||
|
packetConn net.PacketConn
|
||||||
|
workerPool *WorkerPool
|
||||||
|
shards []*connectionShard
|
||||||
|
shardCount int
|
||||||
|
handler Handler
|
||||||
|
packetType func() Packet // 用于创建新Packet实例的函数
|
||||||
|
activeConns int64
|
||||||
|
shutdown chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer 创建新的高并发服务器
|
||||||
|
func NewServer(config ServerConfig, handler Handler, packetType func() Packet) *HighConcurrentServer {
|
||||||
|
// 设置默认值
|
||||||
|
if config.WorkerNum <= 0 {
|
||||||
|
config.WorkerNum = runtime.NumCPU() * 2
|
||||||
|
}
|
||||||
|
if config.QueueSize <= 0 {
|
||||||
|
config.QueueSize = 1024
|
||||||
|
}
|
||||||
|
if config.MaxConn <= 0 {
|
||||||
|
config.MaxConn = 100000
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算合适的分片数量
|
||||||
|
shardCount := 32
|
||||||
|
for shardCount < runtime.NumCPU() {
|
||||||
|
shardCount *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &HighConcurrentServer{
|
||||||
|
config: config,
|
||||||
|
shardCount: shardCount,
|
||||||
|
handler: handler,
|
||||||
|
packetType: packetType,
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
shards: make([]*connectionShard, shardCount),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化分片
|
||||||
|
for i := 0; i < shardCount; i++ {
|
||||||
|
server.shards[i] = &connectionShard{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化工作池
|
||||||
|
server.workerPool = NewWorkerPool(config.WorkerNum, config.QueueSize)
|
||||||
|
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动服务器
|
||||||
|
func (s *HighConcurrentServer) Start() error {
|
||||||
|
switch s.config.Network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
return s.startTCP()
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
return s.startUDP()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported network type: %s", s.config.Network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startTCP 启动TCP服务器
|
||||||
|
func (s *HighConcurrentServer) startTCP() error {
|
||||||
|
ln, err := net.Listen(s.config.Network, s.config.Address)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listener = ln
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.acceptLoop()
|
||||||
|
|
||||||
|
// 启动连接管理
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.manageConnections()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startUDP 启动UDP服务器
|
||||||
|
func (s *HighConcurrentServer) startUDP() error {
|
||||||
|
pc, err := net.ListenPacket(s.config.Network, s.config.Address)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.packetConn = pc
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.udpReadLoop()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptLoop TCP接受连接循环
|
||||||
|
func (s *HighConcurrentServer) acceptLoop() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.shutdown:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查最大连接数
|
||||||
|
if atomic.LoadInt64(&s.activeConns) >= int64(s.config.MaxConn) {
|
||||||
|
conn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.AddInt64(&s.activeConns, 1)
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.handleNewConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleNewConnection 处理新连接
|
||||||
|
func (s *HighConcurrentServer) handleNewConnection(conn net.Conn) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer atomic.AddInt64(&s.activeConns, -1)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// 获取分片
|
||||||
|
shard := s.getShard(conn)
|
||||||
|
c := &Connection{
|
||||||
|
Conn: conn,
|
||||||
|
readBuffer: bytes.NewBuffer(make([]byte, 0, 4096)),
|
||||||
|
writeBuffer: bytes.NewBuffer(make([]byte, 0, 4096)),
|
||||||
|
writeChan: make(chan []byte, 32),
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
server: s,
|
||||||
|
lastActive: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到连接管理
|
||||||
|
shard.addConn(c)
|
||||||
|
|
||||||
|
// 启动读写协程
|
||||||
|
go c.readLoop()
|
||||||
|
c.writeLoop()
|
||||||
|
|
||||||
|
// 等待关闭
|
||||||
|
<-c.closeChan
|
||||||
|
shard.removeConn(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getShard 获取连接对应的分片
|
||||||
|
func (s *HighConcurrentServer) getShard(conn net.Conn) *connectionShard {
|
||||||
|
// 使用连接的远程地址作为分片键
|
||||||
|
addr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
return s.shards[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简单哈希算法
|
||||||
|
hash := addr.IP.To4()[0] + addr.IP.To4()[1] + addr.IP.To4()[2] + addr.IP.To4()[3]
|
||||||
|
return s.shards[int(hash)%s.shardCount]
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpReadLoop UDP读取循环
|
||||||
|
func (s *HighConcurrentServer) udpReadLoop() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer s.packetConn.Close()
|
||||||
|
|
||||||
|
bufPool := sync.Pool{
|
||||||
|
New: func() interface{} { return make([]byte, 65536) },
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.shutdown:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bufPool.Get().([]byte)
|
||||||
|
n, addr, err := s.packetConn.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func(data []byte, addr net.Addr) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer bufPool.Put(data)
|
||||||
|
|
||||||
|
// 处理UDP数据包
|
||||||
|
packet := s.packetType()
|
||||||
|
if err := packet.Decode(data[:n]); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建虚拟连接
|
||||||
|
conn := &udpConn{PacketConn: s.packetConn, addr: addr}
|
||||||
|
|
||||||
|
// 调用用户处理函数
|
||||||
|
response, err := s.handler(conn, packet)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送响应
|
||||||
|
if response != nil {
|
||||||
|
respData, err := response.Encode()
|
||||||
|
if err == nil {
|
||||||
|
s.packetConn.WriteTo(respData, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(buf, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// manageConnections 管理连接(关闭空闲连接)
|
||||||
|
func (s *HighConcurrentServer) manageConnections() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.shutdown:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if s.config.IdleTimeout <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for _, shard := range s.shards {
|
||||||
|
shard.lock.RLock()
|
||||||
|
shard.conns.Range(func(key, value interface{}) bool {
|
||||||
|
conn := value.(*Connection)
|
||||||
|
if now.Sub(conn.lastActive) > s.config.IdleTimeout {
|
||||||
|
select {
|
||||||
|
case conn.closeChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
shard.lock.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止服务器
|
||||||
|
func (s *HighConcurrentServer) Stop() {
|
||||||
|
close(s.shutdown)
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
s.listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.packetConn != nil {
|
||||||
|
s.packetConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭所有连接
|
||||||
|
for _, shard := range s.shards {
|
||||||
|
shard.lock.RLock()
|
||||||
|
shard.conns.Range(func(key, value interface{}) bool {
|
||||||
|
conn := value.(*Connection)
|
||||||
|
select {
|
||||||
|
case conn.closeChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
shard.lock.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.wg.Wait()
|
||||||
|
s.workerPool.Stop()
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue