diff --git a/forwarder.go b/forwarder.go index 72433bd..95d7c6c 100644 --- a/forwarder.go +++ b/forwarder.go @@ -8,23 +8,108 @@ import ( "os" "strconv" "strings" + "sync" + "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" ) +// NATType 定义NAT类型 +type NATType string + +const ( + SNAT NATType = "SNAT" // 源地址转换 + DNAT NATType = "DNAT" // 目标地址转换 + BINAT NATType = "BINAT" // 双向地址转换 +) + +// ConnectionKey 定义连接跟踪的键 +type ConnectionKey struct { + SrcIP string + DstIP string + SrcPort uint16 + DstPort uint16 + Protocol uint8 +} + +// ConnectionInfo 定义连接跟踪的信息 +type ConnectionInfo struct { + OriginalSrcIP string + OriginalDstIP string + OriginalSrcPort uint16 + OriginalDstPort uint16 + TranslatedSrcIP string + TranslatedDstIP string + TranslatedSrcPort uint16 + TranslatedDstPort uint16 + LastSeen time.Time +} + // Forwarder 流量转发器 type Forwarder struct { - enabled bool - natTable map[string]string // 简单的NAT映射表,key: 原始地址:端口, value: 转发后地址:端口 + enabled bool + natRules []ForwardRule // NAT规则列表 + natTable map[string]string // 兼容旧版本的NAT映射表 + connTrackTable map[ConnectionKey]ConnectionInfo // 连接跟踪表 + connTrackMutex sync.RWMutex // 保护连接跟踪表的互斥锁 + cleanupTicker *time.Ticker // 定期清理过期连接 + connTimeout time.Duration // 连接超时时间 } // NewForwarder 创建新的流量转发器 func NewForwarder() *Forwarder { - return &Forwarder{ - enabled: false, - natTable: make(map[string]string), + f := &Forwarder{ + enabled: false, + natRules: []ForwardRule{}, + natTable: make(map[string]string), + connTrackTable: make(map[ConnectionKey]ConnectionInfo), + connTimeout: 5 * time.Minute, // 默认连接超时时间为5分钟 } + + // 启动定期清理过期连接的定时器 + f.cleanupTicker = time.NewTicker(1 * time.Minute) + go f.cleanupExpiredConnections() + + return f +} + +// cleanupExpiredConnections 清理过期的连接 +func (f *Forwarder) cleanupExpiredConnections() { + for range f.cleanupTicker.C { + now := time.Now() + expiredKeys := []ConnectionKey{} + + // 查找过期的连接 + f.connTrackMutex.RLock() + for key, info := range f.connTrackTable { + if now.Sub(info.LastSeen) > f.connTimeout { + expiredKeys = append(expiredKeys, key) + } + } + f.connTrackMutex.RUnlock() + + // 删除过期的连接 + if len(expiredKeys) > 0 { + f.connTrackMutex.Lock() + for _, key := range expiredKeys { + delete(f.connTrackTable, key) + log.Printf("Removed expired connection: %s:%d -> %s:%d", + key.SrcIP, key.SrcPort, key.DstIP, key.DstPort) + } + f.connTrackMutex.Unlock() + } + } +} + +// Enable 启用转发 +func (f *Forwarder) Enable() { + f.enabled = true +} + +// Disable 禁用转发 +func (f *Forwarder) Disable() { + f.enabled = false } // Start 启动转发服务 @@ -40,10 +125,62 @@ func (f *Forwarder) Stop() { log.Println("Forwarding service stopped") } +// Close 关闭转发器,停止清理定时器 +func (f *Forwarder) Close() { + if f.cleanupTicker != nil { + f.cleanupTicker.Stop() + } +} + +// GetConnectionStats 获取连接统计信息 +func (f *Forwarder) GetConnectionStats() map[string]interface{} { + f.connTrackMutex.RLock() + defer f.connTrackMutex.RUnlock() + + stats := make(map[string]interface{}) + stats["total_connections"] = len(f.connTrackTable) + + // 按协议统计连接数 + protocolStats := make(map[uint8]int) + for key := range f.connTrackTable { + protocolStats[key.Protocol]++ + } + stats["protocol_stats"] = protocolStats + + // 获取最近活动的连接 + type connSummary struct { + SrcIP string `json:"src_ip"` + SrcPort uint16 `json:"src_port"` + DstIP string `json:"dst_ip"` + DstPort uint16 `json:"dst_port"` + Protocol uint8 `json:"protocol"` + LastSeen string `json:"last_seen"` + } + + recentConns := []connSummary{} + count := 0 + for key, info := range f.connTrackTable { + if count >= 10 { + break + } + recentConns = append(recentConns, connSummary{ + SrcIP: key.SrcIP, + SrcPort: key.SrcPort, + DstIP: key.DstIP, + DstPort: key.DstPort, + Protocol: key.Protocol, + LastSeen: info.LastSeen.Format(time.RFC3339), + }) + count++ + } + stats["recent_connections"] = recentConns + + return stats +} + func (f *Forwarder) LoadRulesFromFile(ruleFile string) error { content, err := os.ReadFile(ruleFile) if err != nil { - return err } decoder := json.NewDecoder(strings.NewReader(string(content))) @@ -60,23 +197,257 @@ func (f *Forwarder) LoadRulesFromFile(ruleFile string) error { // ForwardRule 定义转发规则结构 type ForwardRule struct { - SrcIP string // 源IP - SrcPort int // 源端口 - DstIP string // 目标IP - DstPort int // 目标端口 + Type NATType `json:"type"` // NAT类型:SNAT, DNAT, BINAT + SrcIP string `json:"src_ip"` // 源IP + SrcPort int `json:"src_port"` // 源端口 + DstIP string `json:"dst_ip"` // 目标IP + DstPort int `json:"dst_port"` // 目标端口 + NewSrcIP string `json:"new_src_ip"` // 新的源IP (用于SNAT和BINAT) + NewSrcPort int `json:"new_src_port"` // 新的源端口 (用于SNAT和BINAT) + NewDstIP string `json:"new_dst_ip"` // 新的目标IP (用于DNAT和BINAT) + NewDstPort int `json:"new_dst_port"` // 新的目标端口 (用于DNAT和BINAT) + ID string `json:"id"` // 规则ID,用于动态删除 +} + +// LoadConfig 从配置文件加载转发规则 +func (f *Forwarder) LoadConfig(configFile string) error { + data, err := os.ReadFile(configFile) + if err != nil { + return err + } + + var config struct { + ForwardEnabled bool `json:"forward_enabled"` + NATRules []ForwardRule `json:"nat_rules"` + ForwardRules map[string]string `json:"forward_rules"` // 兼容旧版本 + } + + err = json.Unmarshal(data, &config) + if err != nil { + return err + } + + f.enabled = config.ForwardEnabled + + // 加载新版本NAT规则 + if len(config.NATRules) > 0 { + f.natRules = config.NATRules + log.Printf("Loaded %d NAT rules", len(f.natRules)) + for i, rule := range f.natRules { + log.Printf("Rule %d: Type=%s, SrcIP=%s, SrcPort=%d, DstIP=%s, DstPort=%d", + i+1, rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort) + } + } + + // 兼容旧版本的转发规则 + if len(config.ForwardRules) > 0 { + f.natTable = config.ForwardRules + log.Printf("Loaded %d legacy forward rules", len(f.natTable)) + + // 将旧版本规则转换为新版本规则 + for src, dst := range f.natTable { + parts := strings.Split(src, ":") + if len(parts) != 2 { + continue + } + + srcIP := parts[0] + srcPort, err := strconv.Atoi(parts[1]) + if err != nil { + continue + } + + dstParts := strings.Split(dst, ":") + if len(dstParts) != 2 { + continue + } + + dstIP := dstParts[0] + dstPort, err := strconv.Atoi(dstParts[1]) + if err != nil { + continue + } + + // 创建DNAT规则 + rule := ForwardRule{ + Type: DNAT, + SrcIP: "", + SrcPort: 0, + DstIP: srcIP, + DstPort: srcPort, + NewDstIP: dstIP, + NewDstPort: dstPort, + ID: fmt.Sprintf("legacy-%s:%d", srcIP, srcPort), + } + + f.natRules = append(f.natRules, rule) + } + } + + return nil } // AddForwardRule 添加转发规则 func (f *Forwarder) AddForwardRule(rule ForwardRule) { - key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) - value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) - f.natTable[key] = value + // 生成规则ID(如果没有提供) + if rule.ID == "" { + rule.ID = fmt.Sprintf("%s-%s:%d-%s:%d-%d", + rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort, time.Now().UnixNano()) + } + + // 添加到规则列表 + f.natRules = append(f.natRules, rule) + + // 兼容旧版本:如果是DNAT规则,也添加到natTable + if rule.Type == DNAT { + key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) + value := fmt.Sprintf("%s:%d", rule.NewDstIP, rule.NewDstPort) + f.natTable[key] = value + } else if rule.Type == "" { + // 兼容旧版本的规则格式 + key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) + value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) + f.natTable[key] = value + } + + log.Printf("Added NAT rule: %s, ID=%s", rule.Type, rule.ID) } // RemoveForwardRule 移除转发规则 func (f *Forwarder) RemoveForwardRule(rule ForwardRule) { + // 兼容旧版本 key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) delete(f.natTable, key) + + // 从natRules中删除 + if rule.ID != "" { + for i, r := range f.natRules { + if r.ID == rule.ID { + f.natRules = append(f.natRules[:i], f.natRules[i+1:]...) + log.Printf("Removed NAT rule: %s, ID=%s", r.Type, r.ID) + break + } + } + } +} + +// RemoveForwardRuleByID 通过ID删除转发规则 +func (f *Forwarder) RemoveForwardRuleByID(ruleID string) bool { + for i, rule := range f.natRules { + if rule.ID == ruleID { + // 从规则列表中删除 + f.natRules = append(f.natRules[:i], f.natRules[i+1:]...) + + // 如果是DNAT规则,也从natTable中删除 + if rule.Type == DNAT { + key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) + delete(f.natTable, key) + } + + log.Printf("Removed NAT rule: %s, ID=%s", rule.Type, rule.ID) + return true + } + } + return false +} + +// ListNATRules 列出所有NAT规则 +func (f *Forwarder) ListNATRules() []ForwardRule { + return f.natRules +} + +// getConnectionKey 从数据包中获取连接键 +func (f *Forwarder) getConnectionKey(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, isReply bool) ConnectionKey { + var srcIP, dstIP string + var srcPort, dstPort uint16 + var protocol uint8 + + srcIP = ipLayer.SrcIP.String() + dstIP = ipLayer.DstIP.String() + protocol = uint8(ipLayer.Protocol) + + switch t := transportLayer.(type) { + case *layers.TCP: + srcPort = uint16(t.SrcPort) + dstPort = uint16(t.DstPort) + case *layers.UDP: + srcPort = uint16(t.SrcPort) + dstPort = uint16(t.DstPort) + } + + // 如果是回复包,交换源和目标 + if isReply { + return ConnectionKey{ + SrcIP: dstIP, + DstIP: srcIP, + SrcPort: dstPort, + DstPort: srcPort, + Protocol: protocol, + } + } + + return ConnectionKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + Protocol: protocol, + } +} + +// findMatchingRule 查找匹配的NAT规则 +func (f *Forwarder) findMatchingRule(srcIP string, srcPort int, dstIP string, dstPort int) *ForwardRule { + for i, rule := range f.natRules { + // 检查源IP匹配 + if rule.SrcIP != "" && rule.SrcIP != srcIP { + continue + } + + // 检查源端口匹配 + if rule.SrcPort != 0 && rule.SrcPort != srcPort { + continue + } + + // 检查目标IP匹配 + if rule.DstIP != "" && rule.DstIP != dstIP { + continue + } + + // 检查目标端口匹配 + if rule.DstPort != 0 && rule.DstPort != dstPort { + continue + } + + // 找到匹配的规则 + return &f.natRules[i] + } + + // 兼容旧版本:检查natTable + key := fmt.Sprintf("%s:%d", dstIP, dstPort) + if forwardAddr, exists := f.natTable[key]; exists { + // 解析转发目标地址 + parts := strings.Split(forwardAddr, ":") + if len(parts) != 2 { + return nil + } + + newDstIP := parts[0] + newDstPort, err := strconv.Atoi(parts[1]) + if err != nil { + return nil + } + + // 创建临时DNAT规则 + return &ForwardRule{ + Type: DNAT, + DstIP: dstIP, + DstPort: dstPort, + NewDstIP: newDstIP, + NewDstPort: newDstPort, + } + } + + return nil } // ForwardPacket 转发数据包 @@ -87,6 +458,7 @@ func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket. // 获取源IP和端口 srcIP := ipLayer.SrcIP.String() + dstIP := ipLayer.DstIP.String() var srcPort, dstPort int // 根据传输层协议获取端口 @@ -101,47 +473,160 @@ func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket. // 不支持的传输层协议 return nil } - - // 查找转发规则, 按照目标端口查找并转发。 - // 相当于做了一个端口映射。 - key := fmt.Sprintf(":%d", dstPort) // srcIP, srcPort) - if forwardAddr, exists := f.natTable[key]; exists { - // 解析转发目标地址 - addr, port, err := net.SplitHostPort(forwardAddr) - if err != nil { - return err + + // 创建连接键 + connKey := f.getConnectionKey(ipLayer, transportLayer, false) + + // 检查是否是已建立连接的回复包 + f.connTrackMutex.RLock() + connInfo, isReply := f.connTrackTable[connKey] + + // 如果不是回复包,尝试查找反向连接 + if !isReply { + reverseKey := f.getConnectionKey(ipLayer, transportLayer, true) + if info, found := f.connTrackTable[reverseKey]; found { + isReply = true + connInfo = info } - - // 更新IP层目标地址 - newDstIP := net.ParseIP(addr) - if newDstIP == nil { - return fmt.Errorf("invalid forward IP address: %s", addr) - } - ipLayer.DstIP = newDstIP - - // 更新传输层目标端口 - newDstPort, err := strconv.Atoi(port) - if err != nil { - return err - } - - switch t := transportLayer.(type) { - case *layers.TCP: - t.DstPort = layers.TCPPort(newDstPort) - case *layers.UDP: - t.DstPort = layers.UDPPort(newDstPort) - } - - // 重新计算校验和 - switch t := transportLayer.(type) { - case *layers.TCP: - t.SetNetworkLayerForChecksum(ipLayer) - case *layers.UDP: - t.SetNetworkLayerForChecksum(ipLayer) - } - - log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort) } - + f.connTrackMutex.RUnlock() + + // 如果是已建立连接的回复包,应用已有的NAT转换 + if isReply { + // 更新连接的最后活动时间 + f.connTrackMutex.Lock() + connInfo.LastSeen = time.Now() + f.connTrackTable[connKey] = connInfo + f.connTrackMutex.Unlock() + + // 应用已有的NAT转换 + if connInfo.TranslatedSrcIP != "" && connInfo.TranslatedSrcIP != srcIP { + ipLayer.SrcIP = net.ParseIP(connInfo.TranslatedSrcIP) + } + + if connInfo.TranslatedDstIP != "" && connInfo.TranslatedDstIP != dstIP { + ipLayer.DstIP = net.ParseIP(connInfo.TranslatedDstIP) + } + + switch t := transportLayer.(type) { + case *layers.TCP: + if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) { + t.SrcPort = layers.TCPPort(connInfo.TranslatedSrcPort) + } + if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) { + t.DstPort = layers.TCPPort(connInfo.TranslatedDstPort) + } + t.SetNetworkLayerForChecksum(ipLayer) + case *layers.UDP: + if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) { + t.SrcPort = layers.UDPPort(connInfo.TranslatedSrcPort) + } + if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) { + t.DstPort = layers.UDPPort(connInfo.TranslatedDstPort) + } + t.SetNetworkLayerForChecksum(ipLayer) + } + + return nil + } + + // 查找匹配的NAT规则 + rule := f.findMatchingRule(srcIP, srcPort, dstIP, dstPort) + if rule == nil { + return nil + } + + // 创建连接跟踪信息 + newConnInfo := ConnectionInfo{ + OriginalSrcIP: srcIP, + OriginalDstIP: dstIP, + OriginalSrcPort: uint16(srcPort), + OriginalDstPort: uint16(dstPort), + TranslatedSrcIP: srcIP, + TranslatedDstIP: dstIP, + TranslatedSrcPort: uint16(srcPort), + TranslatedDstPort: uint16(dstPort), + LastSeen: time.Now(), + } + + // 应用NAT规则 + switch rule.Type { + case SNAT: + // 源地址转换 + if rule.NewSrcIP != "" { + ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP) + newConnInfo.TranslatedSrcIP = rule.NewSrcIP + } + if rule.NewSrcPort != 0 { + switch t := transportLayer.(type) { + case *layers.TCP: + t.SrcPort = layers.TCPPort(rule.NewSrcPort) + t.SetNetworkLayerForChecksum(ipLayer) + case *layers.UDP: + t.SrcPort = layers.UDPPort(rule.NewSrcPort) + t.SetNetworkLayerForChecksum(ipLayer) + } + newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort) + } + case DNAT: + // 目标地址转换 + if rule.NewDstIP != "" { + ipLayer.DstIP = net.ParseIP(rule.NewDstIP) + newConnInfo.TranslatedDstIP = rule.NewDstIP + } + if rule.NewDstPort != 0 { + switch t := transportLayer.(type) { + case *layers.TCP: + t.DstPort = layers.TCPPort(rule.NewDstPort) + t.SetNetworkLayerForChecksum(ipLayer) + case *layers.UDP: + t.DstPort = layers.UDPPort(rule.NewDstPort) + t.SetNetworkLayerForChecksum(ipLayer) + } + newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort) + } + case BINAT: + // 双向地址转换 + if rule.NewSrcIP != "" { + ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP) + newConnInfo.TranslatedSrcIP = rule.NewSrcIP + } + if rule.NewSrcPort != 0 { + switch t := transportLayer.(type) { + case *layers.TCP: + t.SrcPort = layers.TCPPort(rule.NewSrcPort) + case *layers.UDP: + t.SrcPort = layers.UDPPort(rule.NewSrcPort) + } + newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort) + } + if rule.NewDstIP != "" { + ipLayer.DstIP = net.ParseIP(rule.NewDstIP) + newConnInfo.TranslatedDstIP = rule.NewDstIP + } + if rule.NewDstPort != 0 { + switch t := transportLayer.(type) { + case *layers.TCP: + t.DstPort = layers.TCPPort(rule.NewDstPort) + t.SetNetworkLayerForChecksum(ipLayer) + case *layers.UDP: + t.DstPort = layers.UDPPort(rule.NewDstPort) + t.SetNetworkLayerForChecksum(ipLayer) + } + newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort) + } + } + + // 更新连接跟踪表 + f.connTrackMutex.Lock() + f.connTrackTable[connKey] = newConnInfo + f.connTrackMutex.Unlock() + + log.Printf("Applied NAT rule: %s, %s:%d -> %s:%d => %s:%d -> %s:%d", + rule.Type, + srcIP, srcPort, dstIP, dstPort, + newConnInfo.TranslatedSrcIP, newConnInfo.TranslatedSrcPort, + newConnInfo.TranslatedDstIP, newConnInfo.TranslatedDstPort) + return nil -} +} \ No newline at end of file