feat(forwarder): 增强NAT功能,添加连接跟踪和多种NAT类型支持
This commit is contained in:
parent
7013bd61f1
commit
a73a0514ba
557
forwarder.go
557
forwarder.go
|
|
@ -8,23 +8,108 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NATType 定义NAT类型
|
||||||
|
type NATType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SNAT NATType = "SNAT" // 源地址转换
|
||||||
|
DNAT NATType = "DNAT" // 目标地址转换
|
||||||
|
BINAT NATType = "BINAT" // 双向地址转换
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionKey 定义连接跟踪的键
|
||||||
|
type ConnectionKey struct {
|
||||||
|
SrcIP string
|
||||||
|
DstIP string
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
Protocol uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionInfo 定义连接跟踪的信息
|
||||||
|
type ConnectionInfo struct {
|
||||||
|
OriginalSrcIP string
|
||||||
|
OriginalDstIP string
|
||||||
|
OriginalSrcPort uint16
|
||||||
|
OriginalDstPort uint16
|
||||||
|
TranslatedSrcIP string
|
||||||
|
TranslatedDstIP string
|
||||||
|
TranslatedSrcPort uint16
|
||||||
|
TranslatedDstPort uint16
|
||||||
|
LastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// Forwarder 流量转发器
|
// Forwarder 流量转发器
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
enabled bool
|
enabled bool
|
||||||
natTable map[string]string // 简单的NAT映射表,key: 原始地址:端口, value: 转发后地址:端口
|
natRules []ForwardRule // NAT规则列表
|
||||||
|
natTable map[string]string // 兼容旧版本的NAT映射表
|
||||||
|
connTrackTable map[ConnectionKey]ConnectionInfo // 连接跟踪表
|
||||||
|
connTrackMutex sync.RWMutex // 保护连接跟踪表的互斥锁
|
||||||
|
cleanupTicker *time.Ticker // 定期清理过期连接
|
||||||
|
connTimeout time.Duration // 连接超时时间
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewForwarder 创建新的流量转发器
|
// NewForwarder 创建新的流量转发器
|
||||||
func NewForwarder() *Forwarder {
|
func NewForwarder() *Forwarder {
|
||||||
return &Forwarder{
|
f := &Forwarder{
|
||||||
enabled: false,
|
enabled: false,
|
||||||
|
natRules: []ForwardRule{},
|
||||||
natTable: make(map[string]string),
|
natTable: make(map[string]string),
|
||||||
|
connTrackTable: make(map[ConnectionKey]ConnectionInfo),
|
||||||
|
connTimeout: 5 * time.Minute, // 默认连接超时时间为5分钟
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 启动定期清理过期连接的定时器
|
||||||
|
f.cleanupTicker = time.NewTicker(1 * time.Minute)
|
||||||
|
go f.cleanupExpiredConnections()
|
||||||
|
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpiredConnections 清理过期的连接
|
||||||
|
func (f *Forwarder) cleanupExpiredConnections() {
|
||||||
|
for range f.cleanupTicker.C {
|
||||||
|
now := time.Now()
|
||||||
|
expiredKeys := []ConnectionKey{}
|
||||||
|
|
||||||
|
// 查找过期的连接
|
||||||
|
f.connTrackMutex.RLock()
|
||||||
|
for key, info := range f.connTrackTable {
|
||||||
|
if now.Sub(info.LastSeen) > f.connTimeout {
|
||||||
|
expiredKeys = append(expiredKeys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.connTrackMutex.RUnlock()
|
||||||
|
|
||||||
|
// 删除过期的连接
|
||||||
|
if len(expiredKeys) > 0 {
|
||||||
|
f.connTrackMutex.Lock()
|
||||||
|
for _, key := range expiredKeys {
|
||||||
|
delete(f.connTrackTable, key)
|
||||||
|
log.Printf("Removed expired connection: %s:%d -> %s:%d",
|
||||||
|
key.SrcIP, key.SrcPort, key.DstIP, key.DstPort)
|
||||||
|
}
|
||||||
|
f.connTrackMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable 启用转发
|
||||||
|
func (f *Forwarder) Enable() {
|
||||||
|
f.enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable 禁用转发
|
||||||
|
func (f *Forwarder) Disable() {
|
||||||
|
f.enabled = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start 启动转发服务
|
// Start 启动转发服务
|
||||||
|
|
@ -40,10 +125,62 @@ func (f *Forwarder) Stop() {
|
||||||
log.Println("Forwarding service stopped")
|
log.Println("Forwarding service stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close 关闭转发器,停止清理定时器
|
||||||
|
func (f *Forwarder) Close() {
|
||||||
|
if f.cleanupTicker != nil {
|
||||||
|
f.cleanupTicker.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnectionStats 获取连接统计信息
|
||||||
|
func (f *Forwarder) GetConnectionStats() map[string]interface{} {
|
||||||
|
f.connTrackMutex.RLock()
|
||||||
|
defer f.connTrackMutex.RUnlock()
|
||||||
|
|
||||||
|
stats := make(map[string]interface{})
|
||||||
|
stats["total_connections"] = len(f.connTrackTable)
|
||||||
|
|
||||||
|
// 按协议统计连接数
|
||||||
|
protocolStats := make(map[uint8]int)
|
||||||
|
for key := range f.connTrackTable {
|
||||||
|
protocolStats[key.Protocol]++
|
||||||
|
}
|
||||||
|
stats["protocol_stats"] = protocolStats
|
||||||
|
|
||||||
|
// 获取最近活动的连接
|
||||||
|
type connSummary struct {
|
||||||
|
SrcIP string `json:"src_ip"`
|
||||||
|
SrcPort uint16 `json:"src_port"`
|
||||||
|
DstIP string `json:"dst_ip"`
|
||||||
|
DstPort uint16 `json:"dst_port"`
|
||||||
|
Protocol uint8 `json:"protocol"`
|
||||||
|
LastSeen string `json:"last_seen"`
|
||||||
|
}
|
||||||
|
|
||||||
|
recentConns := []connSummary{}
|
||||||
|
count := 0
|
||||||
|
for key, info := range f.connTrackTable {
|
||||||
|
if count >= 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
recentConns = append(recentConns, connSummary{
|
||||||
|
SrcIP: key.SrcIP,
|
||||||
|
SrcPort: key.SrcPort,
|
||||||
|
DstIP: key.DstIP,
|
||||||
|
DstPort: key.DstPort,
|
||||||
|
Protocol: key.Protocol,
|
||||||
|
LastSeen: info.LastSeen.Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
stats["recent_connections"] = recentConns
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
||||||
content, err := os.ReadFile(ruleFile)
|
content, err := os.ReadFile(ruleFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||||||
|
|
@ -60,23 +197,257 @@ func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
||||||
|
|
||||||
// ForwardRule 定义转发规则结构
|
// ForwardRule 定义转发规则结构
|
||||||
type ForwardRule struct {
|
type ForwardRule struct {
|
||||||
SrcIP string // 源IP
|
Type NATType `json:"type"` // NAT类型:SNAT, DNAT, BINAT
|
||||||
SrcPort int // 源端口
|
SrcIP string `json:"src_ip"` // 源IP
|
||||||
DstIP string // 目标IP
|
SrcPort int `json:"src_port"` // 源端口
|
||||||
DstPort int // 目标端口
|
DstIP string `json:"dst_ip"` // 目标IP
|
||||||
|
DstPort int `json:"dst_port"` // 目标端口
|
||||||
|
NewSrcIP string `json:"new_src_ip"` // 新的源IP (用于SNAT和BINAT)
|
||||||
|
NewSrcPort int `json:"new_src_port"` // 新的源端口 (用于SNAT和BINAT)
|
||||||
|
NewDstIP string `json:"new_dst_ip"` // 新的目标IP (用于DNAT和BINAT)
|
||||||
|
NewDstPort int `json:"new_dst_port"` // 新的目标端口 (用于DNAT和BINAT)
|
||||||
|
ID string `json:"id"` // 规则ID,用于动态删除
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig 从配置文件加载转发规则
|
||||||
|
func (f *Forwarder) LoadConfig(configFile string) error {
|
||||||
|
data, err := os.ReadFile(configFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var config struct {
|
||||||
|
ForwardEnabled bool `json:"forward_enabled"`
|
||||||
|
NATRules []ForwardRule `json:"nat_rules"`
|
||||||
|
ForwardRules map[string]string `json:"forward_rules"` // 兼容旧版本
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(data, &config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.enabled = config.ForwardEnabled
|
||||||
|
|
||||||
|
// 加载新版本NAT规则
|
||||||
|
if len(config.NATRules) > 0 {
|
||||||
|
f.natRules = config.NATRules
|
||||||
|
log.Printf("Loaded %d NAT rules", len(f.natRules))
|
||||||
|
for i, rule := range f.natRules {
|
||||||
|
log.Printf("Rule %d: Type=%s, SrcIP=%s, SrcPort=%d, DstIP=%s, DstPort=%d",
|
||||||
|
i+1, rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 兼容旧版本的转发规则
|
||||||
|
if len(config.ForwardRules) > 0 {
|
||||||
|
f.natTable = config.ForwardRules
|
||||||
|
log.Printf("Loaded %d legacy forward rules", len(f.natTable))
|
||||||
|
|
||||||
|
// 将旧版本规则转换为新版本规则
|
||||||
|
for src, dst := range f.natTable {
|
||||||
|
parts := strings.Split(src, ":")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := parts[0]
|
||||||
|
srcPort, err := strconv.Atoi(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dstParts := strings.Split(dst, ":")
|
||||||
|
if len(dstParts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP := dstParts[0]
|
||||||
|
dstPort, err := strconv.Atoi(dstParts[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建DNAT规则
|
||||||
|
rule := ForwardRule{
|
||||||
|
Type: DNAT,
|
||||||
|
SrcIP: "",
|
||||||
|
SrcPort: 0,
|
||||||
|
DstIP: srcIP,
|
||||||
|
DstPort: srcPort,
|
||||||
|
NewDstIP: dstIP,
|
||||||
|
NewDstPort: dstPort,
|
||||||
|
ID: fmt.Sprintf("legacy-%s:%d", srcIP, srcPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
f.natRules = append(f.natRules, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddForwardRule 添加转发规则
|
// AddForwardRule 添加转发规则
|
||||||
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
|
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
|
||||||
|
// 生成规则ID(如果没有提供)
|
||||||
|
if rule.ID == "" {
|
||||||
|
rule.ID = fmt.Sprintf("%s-%s:%d-%s:%d-%d",
|
||||||
|
rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort, time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到规则列表
|
||||||
|
f.natRules = append(f.natRules, rule)
|
||||||
|
|
||||||
|
// 兼容旧版本:如果是DNAT规则,也添加到natTable
|
||||||
|
if rule.Type == DNAT {
|
||||||
|
key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||||
|
value := fmt.Sprintf("%s:%d", rule.NewDstIP, rule.NewDstPort)
|
||||||
|
f.natTable[key] = value
|
||||||
|
} else if rule.Type == "" {
|
||||||
|
// 兼容旧版本的规则格式
|
||||||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||||||
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||||
f.natTable[key] = value
|
f.natTable[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Added NAT rule: %s, ID=%s", rule.Type, rule.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveForwardRule 移除转发规则
|
// RemoveForwardRule 移除转发规则
|
||||||
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
|
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
|
||||||
|
// 兼容旧版本
|
||||||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||||||
delete(f.natTable, key)
|
delete(f.natTable, key)
|
||||||
|
|
||||||
|
// 从natRules中删除
|
||||||
|
if rule.ID != "" {
|
||||||
|
for i, r := range f.natRules {
|
||||||
|
if r.ID == rule.ID {
|
||||||
|
f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
|
||||||
|
log.Printf("Removed NAT rule: %s, ID=%s", r.Type, r.ID)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveForwardRuleByID 通过ID删除转发规则
|
||||||
|
func (f *Forwarder) RemoveForwardRuleByID(ruleID string) bool {
|
||||||
|
for i, rule := range f.natRules {
|
||||||
|
if rule.ID == ruleID {
|
||||||
|
// 从规则列表中删除
|
||||||
|
f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
|
||||||
|
|
||||||
|
// 如果是DNAT规则,也从natTable中删除
|
||||||
|
if rule.Type == DNAT {
|
||||||
|
key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||||
|
delete(f.natTable, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Removed NAT rule: %s, ID=%s", rule.Type, rule.ID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNATRules 列出所有NAT规则
|
||||||
|
func (f *Forwarder) ListNATRules() []ForwardRule {
|
||||||
|
return f.natRules
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnectionKey 从数据包中获取连接键
|
||||||
|
func (f *Forwarder) getConnectionKey(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, isReply bool) ConnectionKey {
|
||||||
|
var srcIP, dstIP string
|
||||||
|
var srcPort, dstPort uint16
|
||||||
|
var protocol uint8
|
||||||
|
|
||||||
|
srcIP = ipLayer.SrcIP.String()
|
||||||
|
dstIP = ipLayer.DstIP.String()
|
||||||
|
protocol = uint8(ipLayer.Protocol)
|
||||||
|
|
||||||
|
switch t := transportLayer.(type) {
|
||||||
|
case *layers.TCP:
|
||||||
|
srcPort = uint16(t.SrcPort)
|
||||||
|
dstPort = uint16(t.DstPort)
|
||||||
|
case *layers.UDP:
|
||||||
|
srcPort = uint16(t.SrcPort)
|
||||||
|
dstPort = uint16(t.DstPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是回复包,交换源和目标
|
||||||
|
if isReply {
|
||||||
|
return ConnectionKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
Protocol: protocol,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ConnectionKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
Protocol: protocol,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingRule 查找匹配的NAT规则
|
||||||
|
func (f *Forwarder) findMatchingRule(srcIP string, srcPort int, dstIP string, dstPort int) *ForwardRule {
|
||||||
|
for i, rule := range f.natRules {
|
||||||
|
// 检查源IP匹配
|
||||||
|
if rule.SrcIP != "" && rule.SrcIP != srcIP {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查源端口匹配
|
||||||
|
if rule.SrcPort != 0 && rule.SrcPort != srcPort {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查目标IP匹配
|
||||||
|
if rule.DstIP != "" && rule.DstIP != dstIP {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查目标端口匹配
|
||||||
|
if rule.DstPort != 0 && rule.DstPort != dstPort {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找到匹配的规则
|
||||||
|
return &f.natRules[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 兼容旧版本:检查natTable
|
||||||
|
key := fmt.Sprintf("%s:%d", dstIP, dstPort)
|
||||||
|
if forwardAddr, exists := f.natTable[key]; exists {
|
||||||
|
// 解析转发目标地址
|
||||||
|
parts := strings.Split(forwardAddr, ":")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newDstIP := parts[0]
|
||||||
|
newDstPort, err := strconv.Atoi(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建临时DNAT规则
|
||||||
|
return &ForwardRule{
|
||||||
|
Type: DNAT,
|
||||||
|
DstIP: dstIP,
|
||||||
|
DstPort: dstPort,
|
||||||
|
NewDstIP: newDstIP,
|
||||||
|
NewDstPort: newDstPort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardPacket 转发数据包
|
// ForwardPacket 转发数据包
|
||||||
|
|
@ -87,6 +458,7 @@ func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.
|
||||||
|
|
||||||
// 获取源IP和端口
|
// 获取源IP和端口
|
||||||
srcIP := ipLayer.SrcIP.String()
|
srcIP := ipLayer.SrcIP.String()
|
||||||
|
dstIP := ipLayer.DstIP.String()
|
||||||
var srcPort, dstPort int
|
var srcPort, dstPort int
|
||||||
|
|
||||||
// 根据传输层协议获取端口
|
// 根据传输层协议获取端口
|
||||||
|
|
@ -102,46 +474,159 @@ func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查找转发规则, 按照目标端口查找并转发。
|
// 创建连接键
|
||||||
// 相当于做了一个端口映射。
|
connKey := f.getConnectionKey(ipLayer, transportLayer, false)
|
||||||
key := fmt.Sprintf(":%d", dstPort) // srcIP, srcPort)
|
|
||||||
if forwardAddr, exists := f.natTable[key]; exists {
|
// 检查是否是已建立连接的回复包
|
||||||
// 解析转发目标地址
|
f.connTrackMutex.RLock()
|
||||||
addr, port, err := net.SplitHostPort(forwardAddr)
|
connInfo, isReply := f.connTrackTable[connKey]
|
||||||
if err != nil {
|
|
||||||
return err
|
// 如果不是回复包,尝试查找反向连接
|
||||||
|
if !isReply {
|
||||||
|
reverseKey := f.getConnectionKey(ipLayer, transportLayer, true)
|
||||||
|
if info, found := f.connTrackTable[reverseKey]; found {
|
||||||
|
isReply = true
|
||||||
|
connInfo = info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.connTrackMutex.RUnlock()
|
||||||
|
|
||||||
|
// 如果是已建立连接的回复包,应用已有的NAT转换
|
||||||
|
if isReply {
|
||||||
|
// 更新连接的最后活动时间
|
||||||
|
f.connTrackMutex.Lock()
|
||||||
|
connInfo.LastSeen = time.Now()
|
||||||
|
f.connTrackTable[connKey] = connInfo
|
||||||
|
f.connTrackMutex.Unlock()
|
||||||
|
|
||||||
|
// 应用已有的NAT转换
|
||||||
|
if connInfo.TranslatedSrcIP != "" && connInfo.TranslatedSrcIP != srcIP {
|
||||||
|
ipLayer.SrcIP = net.ParseIP(connInfo.TranslatedSrcIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新IP层目标地址
|
if connInfo.TranslatedDstIP != "" && connInfo.TranslatedDstIP != dstIP {
|
||||||
newDstIP := net.ParseIP(addr)
|
ipLayer.DstIP = net.ParseIP(connInfo.TranslatedDstIP)
|
||||||
if newDstIP == nil {
|
|
||||||
return fmt.Errorf("invalid forward IP address: %s", addr)
|
|
||||||
}
|
|
||||||
ipLayer.DstIP = newDstIP
|
|
||||||
|
|
||||||
// 更新传输层目标端口
|
|
||||||
newDstPort, err := strconv.Atoi(port)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch t := transportLayer.(type) {
|
switch t := transportLayer.(type) {
|
||||||
case *layers.TCP:
|
case *layers.TCP:
|
||||||
t.DstPort = layers.TCPPort(newDstPort)
|
if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) {
|
||||||
case *layers.UDP:
|
t.SrcPort = layers.TCPPort(connInfo.TranslatedSrcPort)
|
||||||
t.DstPort = layers.UDPPort(newDstPort)
|
}
|
||||||
|
if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) {
|
||||||
|
t.DstPort = layers.TCPPort(connInfo.TranslatedDstPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重新计算校验和
|
|
||||||
switch t := transportLayer.(type) {
|
|
||||||
case *layers.TCP:
|
|
||||||
t.SetNetworkLayerForChecksum(ipLayer)
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
case *layers.UDP:
|
case *layers.UDP:
|
||||||
t.SetNetworkLayerForChecksum(ipLayer)
|
if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) {
|
||||||
|
t.SrcPort = layers.UDPPort(connInfo.TranslatedSrcPort)
|
||||||
}
|
}
|
||||||
|
if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) {
|
||||||
log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort)
|
t.DstPort = layers.UDPPort(connInfo.TranslatedDstPort)
|
||||||
|
}
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找匹配的NAT规则
|
||||||
|
rule := f.findMatchingRule(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
if rule == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建连接跟踪信息
|
||||||
|
newConnInfo := ConnectionInfo{
|
||||||
|
OriginalSrcIP: srcIP,
|
||||||
|
OriginalDstIP: dstIP,
|
||||||
|
OriginalSrcPort: uint16(srcPort),
|
||||||
|
OriginalDstPort: uint16(dstPort),
|
||||||
|
TranslatedSrcIP: srcIP,
|
||||||
|
TranslatedDstIP: dstIP,
|
||||||
|
TranslatedSrcPort: uint16(srcPort),
|
||||||
|
TranslatedDstPort: uint16(dstPort),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用NAT规则
|
||||||
|
switch rule.Type {
|
||||||
|
case SNAT:
|
||||||
|
// 源地址转换
|
||||||
|
if rule.NewSrcIP != "" {
|
||||||
|
ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP)
|
||||||
|
newConnInfo.TranslatedSrcIP = rule.NewSrcIP
|
||||||
|
}
|
||||||
|
if rule.NewSrcPort != 0 {
|
||||||
|
switch t := transportLayer.(type) {
|
||||||
|
case *layers.TCP:
|
||||||
|
t.SrcPort = layers.TCPPort(rule.NewSrcPort)
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
case *layers.UDP:
|
||||||
|
t.SrcPort = layers.UDPPort(rule.NewSrcPort)
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
}
|
||||||
|
newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort)
|
||||||
|
}
|
||||||
|
case DNAT:
|
||||||
|
// 目标地址转换
|
||||||
|
if rule.NewDstIP != "" {
|
||||||
|
ipLayer.DstIP = net.ParseIP(rule.NewDstIP)
|
||||||
|
newConnInfo.TranslatedDstIP = rule.NewDstIP
|
||||||
|
}
|
||||||
|
if rule.NewDstPort != 0 {
|
||||||
|
switch t := transportLayer.(type) {
|
||||||
|
case *layers.TCP:
|
||||||
|
t.DstPort = layers.TCPPort(rule.NewDstPort)
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
case *layers.UDP:
|
||||||
|
t.DstPort = layers.UDPPort(rule.NewDstPort)
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
}
|
||||||
|
newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort)
|
||||||
|
}
|
||||||
|
case BINAT:
|
||||||
|
// 双向地址转换
|
||||||
|
if rule.NewSrcIP != "" {
|
||||||
|
ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP)
|
||||||
|
newConnInfo.TranslatedSrcIP = rule.NewSrcIP
|
||||||
|
}
|
||||||
|
if rule.NewSrcPort != 0 {
|
||||||
|
switch t := transportLayer.(type) {
|
||||||
|
case *layers.TCP:
|
||||||
|
t.SrcPort = layers.TCPPort(rule.NewSrcPort)
|
||||||
|
case *layers.UDP:
|
||||||
|
t.SrcPort = layers.UDPPort(rule.NewSrcPort)
|
||||||
|
}
|
||||||
|
newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort)
|
||||||
|
}
|
||||||
|
if rule.NewDstIP != "" {
|
||||||
|
ipLayer.DstIP = net.ParseIP(rule.NewDstIP)
|
||||||
|
newConnInfo.TranslatedDstIP = rule.NewDstIP
|
||||||
|
}
|
||||||
|
if rule.NewDstPort != 0 {
|
||||||
|
switch t := transportLayer.(type) {
|
||||||
|
case *layers.TCP:
|
||||||
|
t.DstPort = layers.TCPPort(rule.NewDstPort)
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
case *layers.UDP:
|
||||||
|
t.DstPort = layers.UDPPort(rule.NewDstPort)
|
||||||
|
t.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
}
|
||||||
|
newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新连接跟踪表
|
||||||
|
f.connTrackMutex.Lock()
|
||||||
|
f.connTrackTable[connKey] = newConnInfo
|
||||||
|
f.connTrackMutex.Unlock()
|
||||||
|
|
||||||
|
log.Printf("Applied NAT rule: %s, %s:%d -> %s:%d => %s:%d -> %s:%d",
|
||||||
|
rule.Type,
|
||||||
|
srcIP, srcPort, dstIP, dstPort,
|
||||||
|
newConnInfo.TranslatedSrcIP, newConnInfo.TranslatedSrcPort,
|
||||||
|
newConnInfo.TranslatedDstIP, newConnInfo.TranslatedDstPort)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
Loading…
Reference in New Issue