package main import ( "encoding/json" "fmt" "log" "net" "os" "strconv" "strings" "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" ) // NATType 定义NAT类型 type NATType string const ( SNAT NATType = "SNAT" // 源地址转换 DNAT NATType = "DNAT" // 目标地址转换 BINAT NATType = "BINAT" // 双向地址转换 ) // ConnectionKey 定义连接跟踪的键 type ConnectionKey struct { SrcIP string DstIP string SrcPort uint16 DstPort uint16 Protocol uint8 } // ConnectionInfo 定义连接跟踪的信息 type ConnectionInfo struct { OriginalSrcIP string OriginalDstIP string OriginalSrcPort uint16 OriginalDstPort uint16 TranslatedSrcIP string TranslatedDstIP string TranslatedSrcPort uint16 TranslatedDstPort uint16 LastSeen time.Time } // Forwarder 流量转发器 type Forwarder struct { enabled bool natRules []ForwardRule // NAT规则列表 natTable map[string]string // 兼容旧版本的NAT映射表 connTrackTable map[ConnectionKey]ConnectionInfo // 连接跟踪表 connTrackMutex sync.RWMutex // 保护连接跟踪表的互斥锁 cleanupTicker *time.Ticker // 定期清理过期连接 connTimeout time.Duration // 连接超时时间 } // NewForwarder 创建新的流量转发器 func NewForwarder() *Forwarder { f := &Forwarder{ enabled: false, natRules: []ForwardRule{}, natTable: make(map[string]string), connTrackTable: make(map[ConnectionKey]ConnectionInfo), connTimeout: 5 * time.Minute, // 默认连接超时时间为5分钟 } // 启动定期清理过期连接的定时器 f.cleanupTicker = time.NewTicker(1 * time.Minute) go f.cleanupExpiredConnections() return f } // cleanupExpiredConnections 清理过期的连接 func (f *Forwarder) cleanupExpiredConnections() { for range f.cleanupTicker.C { now := time.Now() expiredKeys := []ConnectionKey{} // 查找过期的连接 f.connTrackMutex.RLock() for key, info := range f.connTrackTable { if now.Sub(info.LastSeen) > f.connTimeout { expiredKeys = append(expiredKeys, key) } } f.connTrackMutex.RUnlock() // 删除过期的连接 if len(expiredKeys) > 0 { f.connTrackMutex.Lock() for _, key := range expiredKeys { delete(f.connTrackTable, key) log.Printf("Removed expired connection: %s:%d -> %s:%d", key.SrcIP, key.SrcPort, key.DstIP, key.DstPort) } f.connTrackMutex.Unlock() } } } // Enable 启用转发 func (f *Forwarder) Enable() { f.enabled = true } // Disable 禁用转发 func (f *Forwarder) Disable() { f.enabled = false } // Start 启动转发服务 func (f *Forwarder) Start() error { f.enabled = true log.Println("Forwarding service started") return nil } // Stop 停止转发服务 func (f *Forwarder) Stop() { f.enabled = false log.Println("Forwarding service stopped") } // Close 关闭转发器,停止清理定时器 func (f *Forwarder) Close() { if f.cleanupTicker != nil { f.cleanupTicker.Stop() } } // GetConnectionStats 获取连接统计信息 func (f *Forwarder) GetConnectionStats() map[string]interface{} { f.connTrackMutex.RLock() defer f.connTrackMutex.RUnlock() stats := make(map[string]interface{}) stats["total_connections"] = len(f.connTrackTable) // 按协议统计连接数 protocolStats := make(map[uint8]int) for key := range f.connTrackTable { protocolStats[key.Protocol]++ } stats["protocol_stats"] = protocolStats // 获取最近活动的连接 type connSummary struct { SrcIP string `json:"src_ip"` SrcPort uint16 `json:"src_port"` DstIP string `json:"dst_ip"` DstPort uint16 `json:"dst_port"` Protocol uint8 `json:"protocol"` LastSeen string `json:"last_seen"` } recentConns := []connSummary{} count := 0 for key, info := range f.connTrackTable { if count >= 10 { break } recentConns = append(recentConns, connSummary{ SrcIP: key.SrcIP, SrcPort: key.SrcPort, DstIP: key.DstIP, DstPort: key.DstPort, Protocol: key.Protocol, LastSeen: info.LastSeen.Format(time.RFC3339), }) count++ } stats["recent_connections"] = recentConns return stats } func (f *Forwarder) LoadRulesFromFile(ruleFile string) error { content, err := os.ReadFile(ruleFile) if err != nil { return err } decoder := json.NewDecoder(strings.NewReader(string(content))) for decoder.More() { var rule ForwardRule err := decoder.Decode(&rule) if err != nil { return err } f.AddForwardRule(rule) } return nil } // ForwardRule 定义转发规则结构 type ForwardRule struct { Type NATType `json:"type"` // NAT类型:SNAT, DNAT, BINAT SrcIP string `json:"src_ip"` // 源IP SrcPort int `json:"src_port"` // 源端口 DstIP string `json:"dst_ip"` // 目标IP DstPort int `json:"dst_port"` // 目标端口 NewSrcIP string `json:"new_src_ip"` // 新的源IP (用于SNAT和BINAT) NewSrcPort int `json:"new_src_port"` // 新的源端口 (用于SNAT和BINAT) NewDstIP string `json:"new_dst_ip"` // 新的目标IP (用于DNAT和BINAT) NewDstPort int `json:"new_dst_port"` // 新的目标端口 (用于DNAT和BINAT) ID string `json:"id"` // 规则ID,用于动态删除 } // LoadConfig 从配置文件加载转发规则 func (f *Forwarder) LoadConfig(configFile string) error { data, err := os.ReadFile(configFile) if err != nil { return err } var config struct { ForwardEnabled bool `json:"forward_enabled"` NATRules []ForwardRule `json:"nat_rules"` ForwardRules map[string]string `json:"forward_rules"` // 兼容旧版本 } err = json.Unmarshal(data, &config) if err != nil { return err } f.enabled = config.ForwardEnabled // 加载新版本NAT规则 if len(config.NATRules) > 0 { f.natRules = config.NATRules log.Printf("Loaded %d NAT rules", len(f.natRules)) for i, rule := range f.natRules { log.Printf("Rule %d: Type=%s, SrcIP=%s, SrcPort=%d, DstIP=%s, DstPort=%d", i+1, rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort) } } // 兼容旧版本的转发规则 if len(config.ForwardRules) > 0 { f.natTable = config.ForwardRules log.Printf("Loaded %d legacy forward rules", len(f.natTable)) // 将旧版本规则转换为新版本规则 for src, dst := range f.natTable { parts := strings.Split(src, ":") if len(parts) != 2 { continue } srcIP := parts[0] srcPort, err := strconv.Atoi(parts[1]) if err != nil { continue } dstParts := strings.Split(dst, ":") if len(dstParts) != 2 { continue } dstIP := dstParts[0] dstPort, err := strconv.Atoi(dstParts[1]) if err != nil { continue } // 创建DNAT规则 rule := ForwardRule{ Type: DNAT, SrcIP: "", SrcPort: 0, DstIP: srcIP, DstPort: srcPort, NewDstIP: dstIP, NewDstPort: dstPort, ID: fmt.Sprintf("legacy-%s:%d", srcIP, srcPort), } f.natRules = append(f.natRules, rule) } } return nil } // AddForwardRule 添加转发规则 func (f *Forwarder) AddForwardRule(rule ForwardRule) { // 生成规则ID(如果没有提供) if rule.ID == "" { rule.ID = fmt.Sprintf("%s-%s:%d-%s:%d-%d", rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort, time.Now().UnixNano()) } // 添加到规则列表 f.natRules = append(f.natRules, rule) // 兼容旧版本:如果是DNAT规则,也添加到natTable if rule.Type == DNAT { key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) value := fmt.Sprintf("%s:%d", rule.NewDstIP, rule.NewDstPort) f.natTable[key] = value } else if rule.Type == "" { // 兼容旧版本的规则格式 key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) f.natTable[key] = value } log.Printf("Added NAT rule: %s, ID=%s", rule.Type, rule.ID) } // RemoveForwardRule 移除转发规则 func (f *Forwarder) RemoveForwardRule(rule ForwardRule) { // 兼容旧版本 key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) delete(f.natTable, key) // 从natRules中删除 if rule.ID != "" { for i, r := range f.natRules { if r.ID == rule.ID { f.natRules = append(f.natRules[:i], f.natRules[i+1:]...) log.Printf("Removed NAT rule: %s, ID=%s", r.Type, r.ID) break } } } } // RemoveForwardRuleByID 通过ID删除转发规则 func (f *Forwarder) RemoveForwardRuleByID(ruleID string) bool { for i, rule := range f.natRules { if rule.ID == ruleID { // 从规则列表中删除 f.natRules = append(f.natRules[:i], f.natRules[i+1:]...) // 如果是DNAT规则,也从natTable中删除 if rule.Type == DNAT { key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) delete(f.natTable, key) } log.Printf("Removed NAT rule: %s, ID=%s", rule.Type, rule.ID) return true } } return false } // ListNATRules 列出所有NAT规则 func (f *Forwarder) ListNATRules() []ForwardRule { return f.natRules } // getConnectionKey 从数据包中获取连接键 func (f *Forwarder) getConnectionKey(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, isReply bool) ConnectionKey { var srcIP, dstIP string var srcPort, dstPort uint16 var protocol uint8 srcIP = ipLayer.SrcIP.String() dstIP = ipLayer.DstIP.String() protocol = uint8(ipLayer.Protocol) switch t := transportLayer.(type) { case *layers.TCP: srcPort = uint16(t.SrcPort) dstPort = uint16(t.DstPort) case *layers.UDP: srcPort = uint16(t.SrcPort) dstPort = uint16(t.DstPort) } // 如果是回复包,交换源和目标 if isReply { return ConnectionKey{ SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort, Protocol: protocol, } } return ConnectionKey{ SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort, Protocol: protocol, } } // findMatchingRule 查找匹配的NAT规则 func (f *Forwarder) findMatchingRule(srcIP string, srcPort int, dstIP string, dstPort int) *ForwardRule { for i, rule := range f.natRules { // 检查源IP匹配 if rule.SrcIP != "" && rule.SrcIP != srcIP { continue } // 检查源端口匹配 if rule.SrcPort != 0 && rule.SrcPort != srcPort { continue } // 检查目标IP匹配 if rule.DstIP != "" && rule.DstIP != dstIP { continue } // 检查目标端口匹配 if rule.DstPort != 0 && rule.DstPort != dstPort { continue } // 找到匹配的规则 return &f.natRules[i] } // 兼容旧版本:检查natTable key := fmt.Sprintf("%s:%d", dstIP, dstPort) if forwardAddr, exists := f.natTable[key]; exists { // 解析转发目标地址 parts := strings.Split(forwardAddr, ":") if len(parts) != 2 { return nil } newDstIP := parts[0] newDstPort, err := strconv.Atoi(parts[1]) if err != nil { return nil } // 创建临时DNAT规则 return &ForwardRule{ Type: DNAT, DstIP: dstIP, DstPort: dstPort, NewDstIP: newDstIP, NewDstPort: newDstPort, } } return nil } // ForwardPacket 转发数据包 func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error { if !f.enabled { return nil } // 获取源IP和端口 srcIP := ipLayer.SrcIP.String() dstIP := ipLayer.DstIP.String() var srcPort, dstPort int // 根据传输层协议获取端口 switch t := transportLayer.(type) { case *layers.TCP: srcPort = int(t.SrcPort) dstPort = int(t.DstPort) case *layers.UDP: srcPort = int(t.SrcPort) dstPort = int(t.DstPort) default: // 不支持的传输层协议 return nil } // 创建连接键 connKey := f.getConnectionKey(ipLayer, transportLayer, false) // 检查是否是已建立连接的回复包 f.connTrackMutex.RLock() connInfo, isReply := f.connTrackTable[connKey] // 如果不是回复包,尝试查找反向连接 if !isReply { reverseKey := f.getConnectionKey(ipLayer, transportLayer, true) if info, found := f.connTrackTable[reverseKey]; found { isReply = true connInfo = info } } f.connTrackMutex.RUnlock() // 如果是已建立连接的回复包,应用已有的NAT转换 if isReply { // 更新连接的最后活动时间 f.connTrackMutex.Lock() connInfo.LastSeen = time.Now() f.connTrackTable[connKey] = connInfo f.connTrackMutex.Unlock() // 应用已有的NAT转换 if connInfo.TranslatedSrcIP != "" && connInfo.TranslatedSrcIP != srcIP { ipLayer.SrcIP = net.ParseIP(connInfo.TranslatedSrcIP) } if connInfo.TranslatedDstIP != "" && connInfo.TranslatedDstIP != dstIP { ipLayer.DstIP = net.ParseIP(connInfo.TranslatedDstIP) } switch t := transportLayer.(type) { case *layers.TCP: if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) { t.SrcPort = layers.TCPPort(connInfo.TranslatedSrcPort) } if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) { t.DstPort = layers.TCPPort(connInfo.TranslatedDstPort) } t.SetNetworkLayerForChecksum(ipLayer) case *layers.UDP: if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) { t.SrcPort = layers.UDPPort(connInfo.TranslatedSrcPort) } if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) { t.DstPort = layers.UDPPort(connInfo.TranslatedDstPort) } t.SetNetworkLayerForChecksum(ipLayer) } return nil } // 查找匹配的NAT规则 rule := f.findMatchingRule(srcIP, srcPort, dstIP, dstPort) if rule == nil { return nil } // 创建连接跟踪信息 newConnInfo := ConnectionInfo{ OriginalSrcIP: srcIP, OriginalDstIP: dstIP, OriginalSrcPort: uint16(srcPort), OriginalDstPort: uint16(dstPort), TranslatedSrcIP: srcIP, TranslatedDstIP: dstIP, TranslatedSrcPort: uint16(srcPort), TranslatedDstPort: uint16(dstPort), LastSeen: time.Now(), } // 应用NAT规则 switch rule.Type { case SNAT: // 源地址转换 if rule.NewSrcIP != "" { ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP) newConnInfo.TranslatedSrcIP = rule.NewSrcIP } if rule.NewSrcPort != 0 { switch t := transportLayer.(type) { case *layers.TCP: t.SrcPort = layers.TCPPort(rule.NewSrcPort) t.SetNetworkLayerForChecksum(ipLayer) case *layers.UDP: t.SrcPort = layers.UDPPort(rule.NewSrcPort) t.SetNetworkLayerForChecksum(ipLayer) } newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort) } case DNAT: // 目标地址转换 if rule.NewDstIP != "" { ipLayer.DstIP = net.ParseIP(rule.NewDstIP) newConnInfo.TranslatedDstIP = rule.NewDstIP } if rule.NewDstPort != 0 { switch t := transportLayer.(type) { case *layers.TCP: t.DstPort = layers.TCPPort(rule.NewDstPort) t.SetNetworkLayerForChecksum(ipLayer) case *layers.UDP: t.DstPort = layers.UDPPort(rule.NewDstPort) t.SetNetworkLayerForChecksum(ipLayer) } newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort) } case BINAT: // 双向地址转换 if rule.NewSrcIP != "" { ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP) newConnInfo.TranslatedSrcIP = rule.NewSrcIP } if rule.NewSrcPort != 0 { switch t := transportLayer.(type) { case *layers.TCP: t.SrcPort = layers.TCPPort(rule.NewSrcPort) case *layers.UDP: t.SrcPort = layers.UDPPort(rule.NewSrcPort) } newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort) } if rule.NewDstIP != "" { ipLayer.DstIP = net.ParseIP(rule.NewDstIP) newConnInfo.TranslatedDstIP = rule.NewDstIP } if rule.NewDstPort != 0 { switch t := transportLayer.(type) { case *layers.TCP: t.DstPort = layers.TCPPort(rule.NewDstPort) t.SetNetworkLayerForChecksum(ipLayer) case *layers.UDP: t.DstPort = layers.UDPPort(rule.NewDstPort) t.SetNetworkLayerForChecksum(ipLayer) } newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort) } } // 更新连接跟踪表 f.connTrackMutex.Lock() f.connTrackTable[connKey] = newConnInfo f.connTrackMutex.Unlock() log.Printf("Applied NAT rule: %s, %s:%d -> %s:%d => %s:%d -> %s:%d", rule.Type, srcIP, srcPort, dstIP, dstPort, newConnInfo.TranslatedSrcIP, newConnInfo.TranslatedSrcPort, newConnInfo.TranslatedDstIP, newConnInfo.TranslatedDstPort) return nil }