395 lines
9.0 KiB
Go
395 lines
9.0 KiB
Go
// Package network 提供了一个高性能的网络通信框架,支持TCP和UDP协议。
|
||
// 该框架采用工作池和连接分片管理来实现高并发,支持自定义数据包格式和处理函数。
|
||
// 主要特点包括:高并发连接处理、连接生命周期管理、优雅关闭机制和缓冲区复用。
|
||
package network
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"encoding/binary"
|
||
"errors"
|
||
"io"
|
||
"net"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// ========== 框架核心接口 ==========
|
||
|
||
// Packet 用户自定义数据包必须实现的接口
|
||
// 用户需要实现这个接口来定义自己的数据包格式,包括编码和解码方法
|
||
type Packet interface {
|
||
// Encode 将数据包编码为字节流
|
||
// 返回编码后的字节切片和可能的错误
|
||
Encode() ([]byte, error)
|
||
|
||
// Decode 从字节流解码为数据包
|
||
// 参数data包含要解码的字节数据
|
||
// 返回可能的解码错误
|
||
Decode([]byte) error
|
||
}
|
||
|
||
// Handler 用户自定义的数据处理函数类型
|
||
// 当收到数据包时,框架会调用这个函数来处理数据
|
||
// 参数conn是产生数据的网络连接,p是解码后的数据包
|
||
// 返回响应数据包和可能的错误
|
||
type Handler func(conn net.Conn, p Packet) (Packet, error)
|
||
|
||
// ========== 框架核心结构 ==========
|
||
|
||
// ServerConfig 服务器配置
|
||
// 包含服务器运行所需的各种参数设置
|
||
type ServerConfig struct {
|
||
Network string // 网络类型: tcp, tcp4, tcp6, udp, udp4, udp6
|
||
Address string // 监听地址,格式为 "ip:port",如 ":8080"
|
||
MaxConn int // 最大连接数 (仅TCP有效)
|
||
WorkerNum int // 工作协程数量,默认为CPU核心数的2倍
|
||
QueueSize int // 任务队列大小,默认为1024
|
||
ReadTimeout time.Duration // 读取超时时间,0表示不设置超时
|
||
IdleTimeout time.Duration // 空闲连接超时时间,0表示不清理空闲连接
|
||
}
|
||
|
||
// connectionShard 连接分片
|
||
// 用于高效管理大量连接,减少锁竞争
|
||
type connectionShard struct {
|
||
conns sync.Map // 存储连接的并发安全map,键为net.Conn,值为*Connection
|
||
lock sync.RWMutex // 读写锁,保护分片操作
|
||
lastUsed time.Time // 最后使用时间,用于分片管理
|
||
}
|
||
|
||
// Connection 连接封装
|
||
// 封装了底层网络连接,提供了缓冲区和通道管理
|
||
type Connection struct {
|
||
net.Conn // 嵌入底层网络连接
|
||
readBuffer *bytes.Buffer // 读缓冲区
|
||
writeBuffer *bytes.Buffer // 写缓冲区
|
||
writeChan chan []byte // 写入通道,用于异步写入
|
||
closeChan chan struct{} // 关闭通道,用于通知协程退出
|
||
server *HighConcurrentServer // 所属服务器
|
||
lastActive time.Time // 最后活动时间,用于空闲检测
|
||
}
|
||
|
||
// WorkerPool 工作协程池
|
||
// 用于高效处理网络请求,避免为每个连接创建协程
|
||
type WorkerPool struct {
|
||
taskQueue chan func() // 任务队列,存储待执行的函数
|
||
size int // 工作协程数量
|
||
}
|
||
|
||
// ========== 框架实现 ==========
|
||
|
||
// readLoop 连接读循环
|
||
func (c *Connection) readLoop() {
|
||
reader := bufio.NewReader(c.Conn)
|
||
var header [4]byte
|
||
|
||
for {
|
||
select {
|
||
case <-c.closeChan:
|
||
return
|
||
default:
|
||
}
|
||
|
||
// 设置读超时
|
||
if c.server.config.ReadTimeout > 0 {
|
||
c.SetReadDeadline(time.Now().Add(c.server.config.ReadTimeout))
|
||
}
|
||
|
||
// 读取包头 (4字节长度)
|
||
_, err := io.ReadFull(reader, header[:])
|
||
if err != nil {
|
||
if err != io.EOF {
|
||
// 处理错误
|
||
}
|
||
close(c.closeChan)
|
||
return
|
||
}
|
||
|
||
// 解析包长度
|
||
length := binary.BigEndian.Uint32(header[:])
|
||
|
||
// 检查包长度是否合理
|
||
if length > 10*1024*1024 { // 10MB
|
||
close(c.closeChan)
|
||
return
|
||
}
|
||
|
||
// 读取包体
|
||
packetData := make([]byte, length)
|
||
_, err = io.ReadFull(reader, packetData)
|
||
if err != nil {
|
||
close(c.closeChan)
|
||
return
|
||
}
|
||
|
||
c.lastActive = time.Now()
|
||
|
||
// 提交给工作池处理
|
||
c.server.workerPool.Submit(func() {
|
||
c.processPacket(packetData)
|
||
})
|
||
}
|
||
}
|
||
|
||
// processPacket 处理数据包
|
||
func (c *Connection) processPacket(data []byte) {
|
||
// 创建新的数据包实例
|
||
packet := c.server.packetType()
|
||
|
||
// 解码数据
|
||
if err := packet.Decode(data); err != nil {
|
||
// 处理解码错误
|
||
return
|
||
}
|
||
|
||
// 调用用户处理函数
|
||
response, err := c.server.handler(c.Conn, packet)
|
||
if err != nil {
|
||
// 处理错误
|
||
return
|
||
}
|
||
|
||
// 如果有响应,发送回去
|
||
if response != nil {
|
||
c.Send(response)
|
||
}
|
||
}
|
||
|
||
// Send 发送数据包
|
||
func (c *Connection) Send(p Packet) error {
|
||
data, err := p.Encode()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
select {
|
||
case c.writeChan <- data:
|
||
return nil
|
||
default:
|
||
return errors.New("write channel full")
|
||
}
|
||
}
|
||
|
||
// writeLoop 连接写循环
|
||
func (c *Connection) writeLoop() {
|
||
// 批量写入优化
|
||
var batch [][]byte
|
||
ticker := time.NewTicker(10 * time.Millisecond)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case data := <-c.writeChan:
|
||
batch = append(batch, data)
|
||
// 尝试批量处理
|
||
for len(c.writeChan) > 0 && len(batch) < 32 {
|
||
batch = append(batch, <-c.writeChan)
|
||
}
|
||
|
||
// 合并写入
|
||
c.writeBuffer.Reset()
|
||
for _, d := range batch {
|
||
// 写入长度头
|
||
header := make([]byte, 4)
|
||
binary.BigEndian.PutUint32(header, uint32(len(d)))
|
||
c.writeBuffer.Write(header)
|
||
c.writeBuffer.Write(d)
|
||
}
|
||
|
||
// 发送数据
|
||
if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
|
||
close(c.closeChan)
|
||
return
|
||
}
|
||
batch = batch[:0]
|
||
|
||
case <-ticker.C:
|
||
// 定时刷新
|
||
if len(batch) > 0 {
|
||
c.writeBuffer.Reset()
|
||
for _, d := range batch {
|
||
header := make([]byte, 4)
|
||
binary.BigEndian.PutUint32(header, uint32(len(d)))
|
||
c.writeBuffer.Write(header)
|
||
c.writeBuffer.Write(d)
|
||
}
|
||
if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
|
||
close(c.closeChan)
|
||
return
|
||
}
|
||
batch = batch[:0]
|
||
}
|
||
|
||
// 检查空闲超时
|
||
if c.server.config.IdleTimeout > 0 &&
|
||
time.Since(c.lastActive) > c.server.config.IdleTimeout {
|
||
close(c.closeChan)
|
||
return
|
||
}
|
||
|
||
case <-c.closeChan:
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// ========== 辅助结构和方法 ==========
|
||
|
||
// NewWorkerPool 创建工作池
|
||
func NewWorkerPool(size, queueSize int) *WorkerPool {
|
||
pool := &WorkerPool{
|
||
taskQueue: make(chan func(), queueSize),
|
||
size: size,
|
||
}
|
||
|
||
for i := 0; i < size; i++ {
|
||
go pool.worker()
|
||
}
|
||
|
||
return pool
|
||
}
|
||
|
||
func (p *WorkerPool) worker() {
|
||
for task := range p.taskQueue {
|
||
task()
|
||
}
|
||
}
|
||
|
||
func (p *WorkerPool) Submit(task func()) {
|
||
select {
|
||
case p.taskQueue <- task:
|
||
default:
|
||
// 任务队列满时的处理
|
||
}
|
||
}
|
||
|
||
func (p *WorkerPool) Stop() {
|
||
close(p.taskQueue)
|
||
}
|
||
|
||
// udpConn UDP虚拟连接
|
||
type udpConn struct {
|
||
net.PacketConn
|
||
addr net.Addr
|
||
}
|
||
|
||
func (c *udpConn) Read(b []byte) (int, error) {
|
||
return 0, errors.New("not implemented")
|
||
}
|
||
|
||
func (c *udpConn) Write(b []byte) (int, error) {
|
||
return c.PacketConn.WriteTo(b, c.addr)
|
||
}
|
||
|
||
func (c *udpConn) Close() error {
|
||
return nil
|
||
}
|
||
|
||
func (c *udpConn) LocalAddr() net.Addr {
|
||
return c.PacketConn.LocalAddr()
|
||
}
|
||
|
||
func (c *udpConn) RemoteAddr() net.Addr {
|
||
return c.addr
|
||
}
|
||
|
||
func (c *udpConn) SetDeadline(t time.Time) error {
|
||
return c.PacketConn.SetDeadline(t)
|
||
}
|
||
|
||
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||
return c.PacketConn.SetReadDeadline(t)
|
||
}
|
||
|
||
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||
return c.PacketConn.SetWriteDeadline(t)
|
||
}
|
||
|
||
func (s *connectionShard) addConn(conn *Connection) {
|
||
s.lock.Lock()
|
||
defer s.lock.Unlock()
|
||
s.conns.Store(conn.Conn, conn)
|
||
s.lastUsed = time.Now()
|
||
}
|
||
|
||
func (s *connectionShard) removeConn(conn *Connection) {
|
||
s.lock.Lock()
|
||
defer s.lock.Unlock()
|
||
s.conns.Delete(conn.Conn)
|
||
s.lastUsed = time.Now()
|
||
}
|
||
|
||
// ========== 用户使用示例 ==========
|
||
|
||
/*
|
||
// 用户自定义数据包
|
||
type MyPacket struct {
|
||
ID uint32
|
||
Payload string
|
||
}
|
||
|
||
func (p *MyPacket) Encode() ([]byte, error) {
|
||
buf := bytes.NewBuffer(nil)
|
||
binary.Write(buf, binary.BigEndian, p.ID)
|
||
binary.Write(buf, binary.BigEndian, uint32(len(p.Payload)))
|
||
buf.WriteString(p.Payload)
|
||
return buf.Bytes(), nil
|
||
}
|
||
|
||
func (p *MyPacket) Decode(data []byte) error {
|
||
buf := bytes.NewBuffer(data)
|
||
binary.Read(buf, binary.BigEndian, &p.ID)
|
||
var length uint32
|
||
binary.Read(buf, binary.BigEndian, &length)
|
||
p.Payload = string(buf.Next(int(length)))
|
||
return nil
|
||
}
|
||
|
||
// 用户自定义处理函数
|
||
func myHandler(conn net.Conn, p Packet) (Packet, error) {
|
||
myPacket, ok := p.(*MyPacket)
|
||
if !ok {
|
||
return nil, errors.New("invalid packet type")
|
||
}
|
||
|
||
// 处理业务逻辑
|
||
response := &MyPacket{
|
||
ID: myPacket.ID + 1,
|
||
Payload: "Response: " + myPacket.Payload,
|
||
}
|
||
|
||
return response, nil
|
||
}
|
||
|
||
func main() {
|
||
// 创建服务器配置
|
||
config := network.ServerConfig{
|
||
Network: "tcp",
|
||
Address: ":8080",
|
||
MaxConn: 10000,
|
||
WorkerNum: 100,
|
||
QueueSize: 10000,
|
||
ReadTimeout: 30 * time.Second,
|
||
IdleTimeout: 300 * time.Second,
|
||
}
|
||
|
||
// 创建服务器实例
|
||
server := network.NewServer(config, myHandler, func() network.Packet {
|
||
return &MyPacket{}
|
||
})
|
||
|
||
// 启动服务器
|
||
if err := server.Start(); err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// 等待关闭信号
|
||
sig := make(chan os.Signal, 1)
|
||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||
<-sig
|
||
|
||
// 优雅关闭
|
||
server.Stop()
|
||
}
|
||
*/
|