Compare commits
No commits in common. "a73a0514ba1b3d50fa450d110c8bda5a6aad0dad" and "f27002ddff260496ea8a6687d0aa13e4053607ca" have entirely different histories.
a73a0514ba
...
f27002ddff
562
forwarder.go
562
forwarder.go
|
|
@ -8,108 +8,23 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// NATType 定义NAT类型
|
||||
type NATType string
|
||||
|
||||
const (
|
||||
SNAT NATType = "SNAT" // 源地址转换
|
||||
DNAT NATType = "DNAT" // 目标地址转换
|
||||
BINAT NATType = "BINAT" // 双向地址转换
|
||||
)
|
||||
|
||||
// ConnectionKey 定义连接跟踪的键
|
||||
type ConnectionKey struct {
|
||||
SrcIP string
|
||||
DstIP string
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
Protocol uint8
|
||||
}
|
||||
|
||||
// ConnectionInfo 定义连接跟踪的信息
|
||||
type ConnectionInfo struct {
|
||||
OriginalSrcIP string
|
||||
OriginalDstIP string
|
||||
OriginalSrcPort uint16
|
||||
OriginalDstPort uint16
|
||||
TranslatedSrcIP string
|
||||
TranslatedDstIP string
|
||||
TranslatedSrcPort uint16
|
||||
TranslatedDstPort uint16
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// Forwarder 流量转发器
|
||||
type Forwarder struct {
|
||||
enabled bool
|
||||
natRules []ForwardRule // NAT规则列表
|
||||
natTable map[string]string // 兼容旧版本的NAT映射表
|
||||
connTrackTable map[ConnectionKey]ConnectionInfo // 连接跟踪表
|
||||
connTrackMutex sync.RWMutex // 保护连接跟踪表的互斥锁
|
||||
cleanupTicker *time.Ticker // 定期清理过期连接
|
||||
connTimeout time.Duration // 连接超时时间
|
||||
natTable map[string]string // 简单的NAT映射表,key: 原始地址:端口, value: 转发后地址:端口
|
||||
}
|
||||
|
||||
// NewForwarder 创建新的流量转发器
|
||||
func NewForwarder() *Forwarder {
|
||||
f := &Forwarder{
|
||||
return &Forwarder{
|
||||
enabled: false,
|
||||
natRules: []ForwardRule{},
|
||||
natTable: make(map[string]string),
|
||||
connTrackTable: make(map[ConnectionKey]ConnectionInfo),
|
||||
connTimeout: 5 * time.Minute, // 默认连接超时时间为5分钟
|
||||
}
|
||||
|
||||
// 启动定期清理过期连接的定时器
|
||||
f.cleanupTicker = time.NewTicker(1 * time.Minute)
|
||||
go f.cleanupExpiredConnections()
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// cleanupExpiredConnections 清理过期的连接
|
||||
func (f *Forwarder) cleanupExpiredConnections() {
|
||||
for range f.cleanupTicker.C {
|
||||
now := time.Now()
|
||||
expiredKeys := []ConnectionKey{}
|
||||
|
||||
// 查找过期的连接
|
||||
f.connTrackMutex.RLock()
|
||||
for key, info := range f.connTrackTable {
|
||||
if now.Sub(info.LastSeen) > f.connTimeout {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
}
|
||||
f.connTrackMutex.RUnlock()
|
||||
|
||||
// 删除过期的连接
|
||||
if len(expiredKeys) > 0 {
|
||||
f.connTrackMutex.Lock()
|
||||
for _, key := range expiredKeys {
|
||||
delete(f.connTrackTable, key)
|
||||
log.Printf("Removed expired connection: %s:%d -> %s:%d",
|
||||
key.SrcIP, key.SrcPort, key.DstIP, key.DstPort)
|
||||
}
|
||||
f.connTrackMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enable 启用转发
|
||||
func (f *Forwarder) Enable() {
|
||||
f.enabled = true
|
||||
}
|
||||
|
||||
// Disable 禁用转发
|
||||
func (f *Forwarder) Disable() {
|
||||
f.enabled = false
|
||||
}
|
||||
|
||||
// Start 启动转发服务
|
||||
|
|
@ -125,62 +40,10 @@ func (f *Forwarder) Stop() {
|
|||
log.Println("Forwarding service stopped")
|
||||
}
|
||||
|
||||
// Close 关闭转发器,停止清理定时器
|
||||
func (f *Forwarder) Close() {
|
||||
if f.cleanupTicker != nil {
|
||||
f.cleanupTicker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectionStats 获取连接统计信息
|
||||
func (f *Forwarder) GetConnectionStats() map[string]interface{} {
|
||||
f.connTrackMutex.RLock()
|
||||
defer f.connTrackMutex.RUnlock()
|
||||
|
||||
stats := make(map[string]interface{})
|
||||
stats["total_connections"] = len(f.connTrackTable)
|
||||
|
||||
// 按协议统计连接数
|
||||
protocolStats := make(map[uint8]int)
|
||||
for key := range f.connTrackTable {
|
||||
protocolStats[key.Protocol]++
|
||||
}
|
||||
stats["protocol_stats"] = protocolStats
|
||||
|
||||
// 获取最近活动的连接
|
||||
type connSummary struct {
|
||||
SrcIP string `json:"src_ip"`
|
||||
SrcPort uint16 `json:"src_port"`
|
||||
DstIP string `json:"dst_ip"`
|
||||
DstPort uint16 `json:"dst_port"`
|
||||
Protocol uint8 `json:"protocol"`
|
||||
LastSeen string `json:"last_seen"`
|
||||
}
|
||||
|
||||
recentConns := []connSummary{}
|
||||
count := 0
|
||||
for key, info := range f.connTrackTable {
|
||||
if count >= 10 {
|
||||
break
|
||||
}
|
||||
recentConns = append(recentConns, connSummary{
|
||||
SrcIP: key.SrcIP,
|
||||
SrcPort: key.SrcPort,
|
||||
DstIP: key.DstIP,
|
||||
DstPort: key.DstPort,
|
||||
Protocol: key.Protocol,
|
||||
LastSeen: info.LastSeen.Format(time.RFC3339),
|
||||
})
|
||||
count++
|
||||
}
|
||||
stats["recent_connections"] = recentConns
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
||||
content, err := os.ReadFile(ruleFile)
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
}
|
||||
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||||
|
|
@ -197,257 +60,23 @@ func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
|||
|
||||
// ForwardRule 定义转发规则结构
|
||||
type ForwardRule struct {
|
||||
Type NATType `json:"type"` // NAT类型:SNAT, DNAT, BINAT
|
||||
SrcIP string `json:"src_ip"` // 源IP
|
||||
SrcPort int `json:"src_port"` // 源端口
|
||||
DstIP string `json:"dst_ip"` // 目标IP
|
||||
DstPort int `json:"dst_port"` // 目标端口
|
||||
NewSrcIP string `json:"new_src_ip"` // 新的源IP (用于SNAT和BINAT)
|
||||
NewSrcPort int `json:"new_src_port"` // 新的源端口 (用于SNAT和BINAT)
|
||||
NewDstIP string `json:"new_dst_ip"` // 新的目标IP (用于DNAT和BINAT)
|
||||
NewDstPort int `json:"new_dst_port"` // 新的目标端口 (用于DNAT和BINAT)
|
||||
ID string `json:"id"` // 规则ID,用于动态删除
|
||||
}
|
||||
|
||||
// LoadConfig 从配置文件加载转发规则
|
||||
func (f *Forwarder) LoadConfig(configFile string) error {
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var config struct {
|
||||
ForwardEnabled bool `json:"forward_enabled"`
|
||||
NATRules []ForwardRule `json:"nat_rules"`
|
||||
ForwardRules map[string]string `json:"forward_rules"` // 兼容旧版本
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.enabled = config.ForwardEnabled
|
||||
|
||||
// 加载新版本NAT规则
|
||||
if len(config.NATRules) > 0 {
|
||||
f.natRules = config.NATRules
|
||||
log.Printf("Loaded %d NAT rules", len(f.natRules))
|
||||
for i, rule := range f.natRules {
|
||||
log.Printf("Rule %d: Type=%s, SrcIP=%s, SrcPort=%d, DstIP=%s, DstPort=%d",
|
||||
i+1, rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort)
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容旧版本的转发规则
|
||||
if len(config.ForwardRules) > 0 {
|
||||
f.natTable = config.ForwardRules
|
||||
log.Printf("Loaded %d legacy forward rules", len(f.natTable))
|
||||
|
||||
// 将旧版本规则转换为新版本规则
|
||||
for src, dst := range f.natTable {
|
||||
parts := strings.Split(src, ":")
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
srcIP := parts[0]
|
||||
srcPort, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
dstParts := strings.Split(dst, ":")
|
||||
if len(dstParts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
dstIP := dstParts[0]
|
||||
dstPort, err := strconv.Atoi(dstParts[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建DNAT规则
|
||||
rule := ForwardRule{
|
||||
Type: DNAT,
|
||||
SrcIP: "",
|
||||
SrcPort: 0,
|
||||
DstIP: srcIP,
|
||||
DstPort: srcPort,
|
||||
NewDstIP: dstIP,
|
||||
NewDstPort: dstPort,
|
||||
ID: fmt.Sprintf("legacy-%s:%d", srcIP, srcPort),
|
||||
}
|
||||
|
||||
f.natRules = append(f.natRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
SrcIP string // 源IP
|
||||
SrcPort int // 源端口
|
||||
DstIP string // 目标IP
|
||||
DstPort int // 目标端口
|
||||
}
|
||||
|
||||
// AddForwardRule 添加转发规则
|
||||
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
|
||||
// 生成规则ID(如果没有提供)
|
||||
if rule.ID == "" {
|
||||
rule.ID = fmt.Sprintf("%s-%s:%d-%s:%d-%d",
|
||||
rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// 添加到规则列表
|
||||
f.natRules = append(f.natRules, rule)
|
||||
|
||||
// 兼容旧版本:如果是DNAT规则,也添加到natTable
|
||||
if rule.Type == DNAT {
|
||||
key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||
value := fmt.Sprintf("%s:%d", rule.NewDstIP, rule.NewDstPort)
|
||||
f.natTable[key] = value
|
||||
} else if rule.Type == "" {
|
||||
// 兼容旧版本的规则格式
|
||||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||||
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||
f.natTable[key] = value
|
||||
}
|
||||
|
||||
log.Printf("Added NAT rule: %s, ID=%s", rule.Type, rule.ID)
|
||||
}
|
||||
|
||||
// RemoveForwardRule 移除转发规则
|
||||
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
|
||||
// 兼容旧版本
|
||||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||||
delete(f.natTable, key)
|
||||
|
||||
// 从natRules中删除
|
||||
if rule.ID != "" {
|
||||
for i, r := range f.natRules {
|
||||
if r.ID == rule.ID {
|
||||
f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
|
||||
log.Printf("Removed NAT rule: %s, ID=%s", r.Type, r.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveForwardRuleByID 通过ID删除转发规则
|
||||
func (f *Forwarder) RemoveForwardRuleByID(ruleID string) bool {
|
||||
for i, rule := range f.natRules {
|
||||
if rule.ID == ruleID {
|
||||
// 从规则列表中删除
|
||||
f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
|
||||
|
||||
// 如果是DNAT规则,也从natTable中删除
|
||||
if rule.Type == DNAT {
|
||||
key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||
delete(f.natTable, key)
|
||||
}
|
||||
|
||||
log.Printf("Removed NAT rule: %s, ID=%s", rule.Type, rule.ID)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ListNATRules 列出所有NAT规则
|
||||
func (f *Forwarder) ListNATRules() []ForwardRule {
|
||||
return f.natRules
|
||||
}
|
||||
|
||||
// getConnectionKey 从数据包中获取连接键
|
||||
func (f *Forwarder) getConnectionKey(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, isReply bool) ConnectionKey {
|
||||
var srcIP, dstIP string
|
||||
var srcPort, dstPort uint16
|
||||
var protocol uint8
|
||||
|
||||
srcIP = ipLayer.SrcIP.String()
|
||||
dstIP = ipLayer.DstIP.String()
|
||||
protocol = uint8(ipLayer.Protocol)
|
||||
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
srcPort = uint16(t.SrcPort)
|
||||
dstPort = uint16(t.DstPort)
|
||||
case *layers.UDP:
|
||||
srcPort = uint16(t.SrcPort)
|
||||
dstPort = uint16(t.DstPort)
|
||||
}
|
||||
|
||||
// 如果是回复包,交换源和目标
|
||||
if isReply {
|
||||
return ConnectionKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
Protocol: protocol,
|
||||
}
|
||||
}
|
||||
|
||||
return ConnectionKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Protocol: protocol,
|
||||
}
|
||||
}
|
||||
|
||||
// findMatchingRule 查找匹配的NAT规则
|
||||
func (f *Forwarder) findMatchingRule(srcIP string, srcPort int, dstIP string, dstPort int) *ForwardRule {
|
||||
for i, rule := range f.natRules {
|
||||
// 检查源IP匹配
|
||||
if rule.SrcIP != "" && rule.SrcIP != srcIP {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查源端口匹配
|
||||
if rule.SrcPort != 0 && rule.SrcPort != srcPort {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查目标IP匹配
|
||||
if rule.DstIP != "" && rule.DstIP != dstIP {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查目标端口匹配
|
||||
if rule.DstPort != 0 && rule.DstPort != dstPort {
|
||||
continue
|
||||
}
|
||||
|
||||
// 找到匹配的规则
|
||||
return &f.natRules[i]
|
||||
}
|
||||
|
||||
// 兼容旧版本:检查natTable
|
||||
key := fmt.Sprintf("%s:%d", dstIP, dstPort)
|
||||
if forwardAddr, exists := f.natTable[key]; exists {
|
||||
// 解析转发目标地址
|
||||
parts := strings.Split(forwardAddr, ":")
|
||||
if len(parts) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
newDstIP := parts[0]
|
||||
newDstPort, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建临时DNAT规则
|
||||
return &ForwardRule{
|
||||
Type: DNAT,
|
||||
DstIP: dstIP,
|
||||
DstPort: dstPort,
|
||||
NewDstIP: newDstIP,
|
||||
NewDstPort: newDstPort,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardPacket 转发数据包
|
||||
|
|
@ -458,175 +87,60 @@ func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.
|
|||
|
||||
// 获取源IP和端口
|
||||
srcIP := ipLayer.SrcIP.String()
|
||||
dstIP := ipLayer.DstIP.String()
|
||||
var srcPort, dstPort int
|
||||
var srcPort int
|
||||
|
||||
// 根据传输层协议获取端口
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
srcPort = int(t.SrcPort)
|
||||
dstPort = int(t.DstPort)
|
||||
// dstPort = int(t.DstPort)
|
||||
case *layers.UDP:
|
||||
srcPort = int(t.SrcPort)
|
||||
dstPort = int(t.DstPort)
|
||||
// dstPort = int(t.DstPort)
|
||||
default:
|
||||
// 不支持的传输层协议
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建连接键
|
||||
connKey := f.getConnectionKey(ipLayer, transportLayer, false)
|
||||
|
||||
// 检查是否是已建立连接的回复包
|
||||
f.connTrackMutex.RLock()
|
||||
connInfo, isReply := f.connTrackTable[connKey]
|
||||
|
||||
// 如果不是回复包,尝试查找反向连接
|
||||
if !isReply {
|
||||
reverseKey := f.getConnectionKey(ipLayer, transportLayer, true)
|
||||
if info, found := f.connTrackTable[reverseKey]; found {
|
||||
isReply = true
|
||||
connInfo = info
|
||||
}
|
||||
}
|
||||
f.connTrackMutex.RUnlock()
|
||||
|
||||
// 如果是已建立连接的回复包,应用已有的NAT转换
|
||||
if isReply {
|
||||
// 更新连接的最后活动时间
|
||||
f.connTrackMutex.Lock()
|
||||
connInfo.LastSeen = time.Now()
|
||||
f.connTrackTable[connKey] = connInfo
|
||||
f.connTrackMutex.Unlock()
|
||||
|
||||
// 应用已有的NAT转换
|
||||
if connInfo.TranslatedSrcIP != "" && connInfo.TranslatedSrcIP != srcIP {
|
||||
ipLayer.SrcIP = net.ParseIP(connInfo.TranslatedSrcIP)
|
||||
// 查找转发规则
|
||||
key := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||||
if forwardAddr, exists := f.natTable[key]; exists {
|
||||
// 解析转发目标地址
|
||||
addr, port, err := net.SplitHostPort(forwardAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if connInfo.TranslatedDstIP != "" && connInfo.TranslatedDstIP != dstIP {
|
||||
ipLayer.DstIP = net.ParseIP(connInfo.TranslatedDstIP)
|
||||
// 更新IP层目标地址
|
||||
newDstIP := net.ParseIP(addr)
|
||||
if newDstIP == nil {
|
||||
return fmt.Errorf("invalid forward IP address: %s", addr)
|
||||
}
|
||||
ipLayer.DstIP = newDstIP
|
||||
|
||||
// 更新传输层目标端口
|
||||
newDstPort, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) {
|
||||
t.SrcPort = layers.TCPPort(connInfo.TranslatedSrcPort)
|
||||
}
|
||||
if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) {
|
||||
t.DstPort = layers.TCPPort(connInfo.TranslatedDstPort)
|
||||
t.DstPort = layers.TCPPort(newDstPort)
|
||||
case *layers.UDP:
|
||||
t.DstPort = layers.UDPPort(newDstPort)
|
||||
}
|
||||
|
||||
// 重新计算校验和
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
case *layers.UDP:
|
||||
if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) {
|
||||
t.SrcPort = layers.UDPPort(connInfo.TranslatedSrcPort)
|
||||
}
|
||||
if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) {
|
||||
t.DstPort = layers.UDPPort(connInfo.TranslatedDstPort)
|
||||
}
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
}
|
||||
|
||||
log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查找匹配的NAT规则
|
||||
rule := f.findMatchingRule(srcIP, srcPort, dstIP, dstPort)
|
||||
if rule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建连接跟踪信息
|
||||
newConnInfo := ConnectionInfo{
|
||||
OriginalSrcIP: srcIP,
|
||||
OriginalDstIP: dstIP,
|
||||
OriginalSrcPort: uint16(srcPort),
|
||||
OriginalDstPort: uint16(dstPort),
|
||||
TranslatedSrcIP: srcIP,
|
||||
TranslatedDstIP: dstIP,
|
||||
TranslatedSrcPort: uint16(srcPort),
|
||||
TranslatedDstPort: uint16(dstPort),
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
|
||||
// 应用NAT规则
|
||||
switch rule.Type {
|
||||
case SNAT:
|
||||
// 源地址转换
|
||||
if rule.NewSrcIP != "" {
|
||||
ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP)
|
||||
newConnInfo.TranslatedSrcIP = rule.NewSrcIP
|
||||
}
|
||||
if rule.NewSrcPort != 0 {
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.SrcPort = layers.TCPPort(rule.NewSrcPort)
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
case *layers.UDP:
|
||||
t.SrcPort = layers.UDPPort(rule.NewSrcPort)
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
}
|
||||
newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort)
|
||||
}
|
||||
case DNAT:
|
||||
// 目标地址转换
|
||||
if rule.NewDstIP != "" {
|
||||
ipLayer.DstIP = net.ParseIP(rule.NewDstIP)
|
||||
newConnInfo.TranslatedDstIP = rule.NewDstIP
|
||||
}
|
||||
if rule.NewDstPort != 0 {
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.DstPort = layers.TCPPort(rule.NewDstPort)
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
case *layers.UDP:
|
||||
t.DstPort = layers.UDPPort(rule.NewDstPort)
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
}
|
||||
newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort)
|
||||
}
|
||||
case BINAT:
|
||||
// 双向地址转换
|
||||
if rule.NewSrcIP != "" {
|
||||
ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP)
|
||||
newConnInfo.TranslatedSrcIP = rule.NewSrcIP
|
||||
}
|
||||
if rule.NewSrcPort != 0 {
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.SrcPort = layers.TCPPort(rule.NewSrcPort)
|
||||
case *layers.UDP:
|
||||
t.SrcPort = layers.UDPPort(rule.NewSrcPort)
|
||||
}
|
||||
newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort)
|
||||
}
|
||||
if rule.NewDstIP != "" {
|
||||
ipLayer.DstIP = net.ParseIP(rule.NewDstIP)
|
||||
newConnInfo.TranslatedDstIP = rule.NewDstIP
|
||||
}
|
||||
if rule.NewDstPort != 0 {
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.DstPort = layers.TCPPort(rule.NewDstPort)
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
case *layers.UDP:
|
||||
t.DstPort = layers.UDPPort(rule.NewDstPort)
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
}
|
||||
newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新连接跟踪表
|
||||
f.connTrackMutex.Lock()
|
||||
f.connTrackTable[connKey] = newConnInfo
|
||||
f.connTrackMutex.Unlock()
|
||||
|
||||
log.Printf("Applied NAT rule: %s, %s:%d -> %s:%d => %s:%d -> %s:%d",
|
||||
rule.Type,
|
||||
srcIP, srcPort, dstIP, dstPort,
|
||||
newConnInfo.TranslatedSrcIP, newConnInfo.TranslatedSrcPort,
|
||||
newConnInfo.TranslatedDstIP, newConnInfo.TranslatedDstPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
Loading…
Reference in New Issue