// network/network.go
//
// 380 lines
// 7.4 KiB
// Go

package network
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
)
// ========== 框架核心接口 ==========
// Packet is the interface every user-defined packet type must implement.
// The framework handles the 4-byte length prefix itself (see readLoop and
// writeLoop); Encode and Decode deal only with the packet body.
type Packet interface {
// Encode serializes the packet into a byte slice.
Encode() ([]byte, error)
// Decode populates the packet from a byte slice.
Decode([]byte) error
}
// Handler is the type of the user-defined packet processing function.
// processPacket calls the server's handler (presumably a value of this type)
// with the underlying net.Conn and the decoded packet. A non-nil returned
// Packet is sent back on the connection; a non-nil error suppresses the reply.
type Handler func(conn net.Conn, p Packet) (Packet, error)
// ========== 框架核心结构 ==========
// ServerConfig holds the server settings.
type ServerConfig struct {
Network string // tcp, tcp4, tcp6, udp, udp4, udp6
Address string // listen address
MaxConn int // maximum number of connections (TCP only)
WorkerNum int // number of worker goroutines
QueueSize int // task queue size
ReadTimeout time.Duration // read timeout; when > 0, readLoop sets it as the read deadline before each packet
IdleTimeout time.Duration // idle timeout; when > 0, writeLoop signals shutdown once lastActive is older than this
}
// connectionShard is one shard of the set of tracked connections.
type connectionShard struct {
conns sync.Map // map[net.Conn]*Connection
lock sync.RWMutex // held by addConn/removeConn while they update conns and lastUsed
lastUsed time.Time // time of the most recent addConn or removeConn call
}
// Connection wraps a net.Conn together with the per-connection state used by
// readLoop, writeLoop, Send and processPacket.
type Connection struct {
net.Conn
readBuffer *bytes.Buffer // not referenced by the code visible in this section
writeBuffer *bytes.Buffer // scratch buffer in which writeLoop assembles each framed batch
writeChan chan []byte // encoded payloads queued by Send and drained by writeLoop
closeChan chan struct{} // closed to signal that the connection is shutting down
server *HighConcurrentServer // owning server; supplies config, worker pool, packet factory and handler
lastActive time.Time // set by readLoop after each complete packet; read by writeLoop for the idle check
}
// WorkerPool is a pool of worker goroutines that run queued tasks.
type WorkerPool struct {
taskQueue chan func() // buffered queue of pending tasks; closed by Stop
size int // number of worker goroutines started by NewWorkerPool
}
// ========== 框架实现 ==========
// readLoop reads length-prefixed packets from the connection until a read
// fails, an oversized frame arrives, or closeChan is closed.
//
// Wire format: a 4-byte big-endian length followed by that many payload
// bytes. A frame longer than maxPacketSize terminates the connection. Each
// complete payload is submitted to the server's worker pool, which runs
// processPacket on it. Submit does not block: if the pool's queue is full the
// packet is dropped.
//
// Any read error (io.EOF and read timeouts included) makes the loop signal
// shutdown through closeChan and return.
//
// NOTE(review): lastActive is written here and read by writeLoop without
// synchronization; if the two loops run on different goroutines (as their
// structure suggests) that is a data race — confirm and guard it.
func (c *Connection) readLoop() {
	// Upper bound on a single frame's payload, to stop a bogus header from
	// triggering a huge allocation.
	const maxPacketSize = 10 * 1024 * 1024 // 10 MiB

	// shutdown closes closeChan unless it is already closed. writeLoop closes
	// the same channel, and closing a closed channel panics, so a bare close
	// here would crash the process whenever writeLoop shut down first.
	// NOTE(review): the check and the close are not one atomic step; a
	// sync.Once on Connection would remove the remaining window in which both
	// loops fail at the same instant.
	shutdown := func() {
		select {
		case <-c.closeChan:
		default:
			close(c.closeChan)
		}
	}

	reader := bufio.NewReader(c.Conn)
	var header [4]byte

	for {
		// Stop promptly if the connection was shut down elsewhere.
		select {
		case <-c.closeChan:
			return
		default:
		}

		// Refresh the read deadline before every frame. The result is
		// deliberately ignored: a broken connection is detected by the read
		// that follows.
		if timeout := c.server.config.ReadTimeout; timeout > 0 {
			_ = c.SetReadDeadline(time.Now().Add(timeout))
		}

		// Frame header: 4-byte payload length.
		if _, err := io.ReadFull(reader, header[:]); err != nil {
			shutdown()
			return
		}

		length := binary.BigEndian.Uint32(header[:])
		if length > maxPacketSize {
			shutdown()
			return
		}

		// Frame body. A fresh slice per packet is required because the
		// worker may still be using it after the next frame is read.
		packetData := make([]byte, length)
		if _, err := io.ReadFull(reader, packetData); err != nil {
			shutdown()
			return
		}

		c.lastActive = time.Now()

		c.server.workerPool.Submit(func() {
			c.processPacket(packetData)
		})
	}
}
// processPacket decodes one packet body, passes it to the server's handler
// and, if the handler returns a reply, queues that reply for sending.
//
// Decode failures and handler errors end processing silently. The result of
// Send is not inspected, so a reply that Send rejects is lost.
func (c *Connection) processPacket(data []byte) {
	// Every packet is decoded into its own instance from the server's factory.
	pkt := c.server.packetType()
	if decodeErr := pkt.Decode(data); decodeErr != nil {
		return
	}

	// Hand the decoded packet to the user handler; nothing is sent back when
	// it fails or has no reply.
	reply, handlerErr := c.server.handler(c.Conn, pkt)
	if handlerErr != nil || reply == nil {
		return
	}

	c.Send(reply)
}
// Send encodes p and queues the resulting bytes on writeChan without
// blocking.
//
// It returns the error from p.Encode if encoding fails, or a
// "write channel full" error if writeChan cannot accept the payload right
// now. A nil result means the payload was queued, not that it was written to
// the connection.
func (c *Connection) Send(p Packet) error {
	payload, encodeErr := p.Encode()
	if encodeErr != nil {
		return encodeErr
	}

	select {
	case c.writeChan <- payload:
	default:
		return errors.New("write channel full")
	}
	return nil
}
// writeLoop drains writeChan and writes the queued payloads to the
// connection, each prefixed with its 4-byte big-endian length.
//
// Payloads that are already queued are coalesced, up to maxBatch at a time,
// into a single Write call. A 10ms ticker drives the idle check: when
// IdleTimeout is positive and lastActive is older than it, the loop signals
// shutdown through closeChan.
//
// The loop returns on a write error, on idle timeout, or once closeChan is
// closed.
//
// NOTE(review): lastActive is written by readLoop and read here without
// synchronization; if the two loops run on different goroutines (as their
// structure suggests) that is a data race — confirm and guard it.
func (c *Connection) writeLoop() {
	// Maximum number of payloads merged into one Write call.
	const maxBatch = 32

	// shutdown closes closeChan unless it is already closed. readLoop closes
	// the same channel, and closing a closed channel panics, so a bare close
	// here would crash the process whenever readLoop shut down first.
	// NOTE(review): the check and the close are not one atomic step; a
	// sync.Once on Connection would remove the remaining window in which both
	// loops fail at the same instant.
	shutdown := func() {
		select {
		case <-c.closeChan:
		default:
			close(c.closeChan)
		}
	}

	var header [4]byte
	batch := make([][]byte, 0, maxBatch)

	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case data := <-c.writeChan:
			batch = append(batch[:0], data)

			// Pick up whatever else is already queued. The receive cannot
			// block as long as this loop is the only receiver on writeChan
			// (true for the code visible here — confirm for the rest of the
			// package).
			for len(c.writeChan) > 0 && len(batch) < maxBatch {
				batch = append(batch, <-c.writeChan)
			}

			// Frame every payload into the scratch buffer, then write once.
			c.writeBuffer.Reset()
			for i, payload := range batch {
				binary.BigEndian.PutUint32(header[:], uint32(len(payload)))
				c.writeBuffer.Write(header[:])
				c.writeBuffer.Write(payload)
				// Drop the reference so the payload can be collected while
				// the batch's backing array is reused.
				batch[i] = nil
			}

			if _, err := c.Conn.Write(c.writeBuffer.Bytes()); err != nil {
				shutdown()
				return
			}

		case <-ticker.C:
			// Each batch is written in the iteration that collected it, so
			// nothing is pending when the ticker fires; the tick only drives
			// the idle check.
			idle := c.server.config.IdleTimeout
			if idle > 0 && time.Since(c.lastActive) > idle {
				shutdown()
				return
			}

		case <-c.closeChan:
			return
		}
	}
}
// ========== 辅助结构和方法 ==========
// NewWorkerPool builds a WorkerPool whose task queue can buffer queueSize
// pending tasks and starts size worker goroutines that consume from it.
// The workers keep running until the queue is closed (see Stop).
func NewWorkerPool(size, queueSize int) *WorkerPool {
	wp := new(WorkerPool)
	wp.size = size
	wp.taskQueue = make(chan func(), queueSize)

	// Launch one goroutine per requested worker.
	for remaining := size; remaining > 0; remaining-- {
		go wp.worker()
	}
	return wp
}
// worker runs tasks taken from the queue, one at a time, and returns once
// the queue has been closed and fully drained.
func (p *WorkerPool) worker() {
	for {
		job, open := <-p.taskQueue
		if !open {
			return
		}
		job()
	}
}
// Submit offers task to the pool without blocking. If the queue has no free
// slot the task is silently discarded and never runs; the caller is not
// informed. Calling Submit after Stop panics, because Stop closes the queue
// channel and sending on a closed channel panics.
func (p *WorkerPool) Submit(task func()) {
select {
case p.taskQueue <- task:
default:
// Task queue is full: the task is dropped here.
// NOTE(review): callers cannot detect the drop — consider reporting it.
}
}
// Stop closes the task queue. Each worker ranges over the queue, so workers
// finish the tasks that are already queued and then exit. Stop must be called
// at most once: closing an already-closed channel panics.
func (p *WorkerPool) Stop() {
close(p.taskQueue)
}
// errUDPReadNotImplemented is the error returned by every udpConn.Read call.
// It is a single package-level value so callers can compare against it
// instead of matching on the message text.
var errUDPReadNotImplemented = errors.New("not implemented")

// Compile-time check that *udpConn satisfies net.Conn.
var _ net.Conn = (*udpConn)(nil)

// udpConn is a virtual UDP connection: it pairs a net.PacketConn with one
// remote address so that it can be used where a net.Conn is expected.
type udpConn struct {
	net.PacketConn
	addr net.Addr // remote peer: destination of Write, value returned by RemoteAddr
}

// Read is not supported; it always returns 0 and errUDPReadNotImplemented.
func (c *udpConn) Read(b []byte) (int, error) {
	return 0, errUDPReadNotImplemented
}

// Write sends b to the stored remote address through the underlying
// PacketConn.
func (c *udpConn) Write(b []byte) (int, error) {
	return c.PacketConn.WriteTo(b, c.addr)
}

// Close is a no-op that always returns nil; it does not close the underlying
// PacketConn.
// NOTE(review): presumably because the PacketConn is shared by several
// virtual connections — confirm.
func (c *udpConn) Close() error {
	return nil
}

// LocalAddr returns the local address of the underlying PacketConn.
func (c *udpConn) LocalAddr() net.Addr {
	return c.PacketConn.LocalAddr()
}

// RemoteAddr returns the stored remote address.
func (c *udpConn) RemoteAddr() net.Addr {
	return c.addr
}

// SetDeadline forwards t to the underlying PacketConn.
// NOTE(review): if the PacketConn is shared, the deadline applies to every
// virtual connection using it, not just this one — confirm.
func (c *udpConn) SetDeadline(t time.Time) error {
	return c.PacketConn.SetDeadline(t)
}

// SetReadDeadline forwards t to the underlying PacketConn.
func (c *udpConn) SetReadDeadline(t time.Time) error {
	return c.PacketConn.SetReadDeadline(t)
}

// SetWriteDeadline forwards t to the underlying PacketConn.
func (c *udpConn) SetWriteDeadline(t time.Time) error {
	return c.PacketConn.SetWriteDeadline(t)
}
// addConn registers conn in the shard, keyed by its underlying net.Conn, and
// records the current time in lastUsed. The shard lock is held for the whole
// operation.
func (s *connectionShard) addConn(conn *Connection) {
	s.lock.Lock()
	defer s.lock.Unlock()

	key := conn.Conn
	s.conns.Store(key, conn)
	s.lastUsed = time.Now()
}
// removeConn deletes the entry keyed by conn's underlying net.Conn from the
// shard and records the current time in lastUsed. The shard lock is held for
// the whole operation.
func (s *connectionShard) removeConn(conn *Connection) {
	s.lock.Lock()
	defer s.lock.Unlock()

	key := conn.Conn
	s.conns.Delete(key)
	s.lastUsed = time.Now()
}
// ========== 用户使用示例 ==========
/*
// 用户自定义数据包
type MyPacket struct {
ID uint32
Payload string
}
func (p *MyPacket) Encode() ([]byte, error) {
buf := bytes.NewBuffer(nil)
binary.Write(buf, binary.BigEndian, p.ID)
binary.Write(buf, binary.BigEndian, uint32(len(p.Payload)))
buf.WriteString(p.Payload)
return buf.Bytes(), nil
}
func (p *MyPacket) Decode(data []byte) error {
buf := bytes.NewBuffer(data)
binary.Read(buf, binary.BigEndian, &p.ID)
var length uint32
binary.Read(buf, binary.BigEndian, &length)
p.Payload = string(buf.Next(int(length)))
return nil
}
// 用户自定义处理函数
func myHandler(conn net.Conn, p Packet) (Packet, error) {
myPacket, ok := p.(*MyPacket)
if !ok {
return nil, errors.New("invalid packet type")
}
// 处理业务逻辑
response := &MyPacket{
ID: myPacket.ID + 1,
Payload: "Response: " + myPacket.Payload,
}
return response, nil
}
func main() {
// 创建服务器配置
config := network.ServerConfig{
Network: "tcp",
Address: ":8080",
MaxConn: 10000,
WorkerNum: 100,
QueueSize: 10000,
ReadTimeout: 30 * time.Second,
IdleTimeout: 300 * time.Second,
}
// 创建服务器实例
server := network.NewServer(config, myHandler, func() network.Packet {
return &MyPacket{}
})
// 启动服务器
if err := server.Start(); err != nil {
panic(err)
}
// 等待关闭信号
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig
// 优雅关闭
server.Stop()
}
*/