Compare commits

..

2 Commits

1 changed files with 542 additions and 56 deletions

View File

@ -8,23 +8,108 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// NATType 定义NAT类型
type NATType string
const (
SNAT NATType = "SNAT" // 源地址转换
DNAT NATType = "DNAT" // 目标地址转换
BINAT NATType = "BINAT" // 双向地址转换
)
// ConnectionKey 定义连接跟踪的键
type ConnectionKey struct {
SrcIP string
DstIP string
SrcPort uint16
DstPort uint16
Protocol uint8
}
// ConnectionInfo 定义连接跟踪的信息
type ConnectionInfo struct {
OriginalSrcIP string
OriginalDstIP string
OriginalSrcPort uint16
OriginalDstPort uint16
TranslatedSrcIP string
TranslatedDstIP string
TranslatedSrcPort uint16
TranslatedDstPort uint16
LastSeen time.Time
}
// Forwarder 流量转发器
type Forwarder struct {
enabled bool
natTable map[string]string // 简单的NAT映射表key: 原始地址:端口, value: 转发后地址:端口
enabled bool
natRules []ForwardRule // NAT规则列表
natTable map[string]string // 兼容旧版本的NAT映射表
connTrackTable map[ConnectionKey]ConnectionInfo // 连接跟踪表
connTrackMutex sync.RWMutex // 保护连接跟踪表的互斥锁
cleanupTicker *time.Ticker // 定期清理过期连接
connTimeout time.Duration // 连接超时时间
}
// NewForwarder 创建新的流量转发器
func NewForwarder() *Forwarder {
return &Forwarder{
enabled: false,
natTable: make(map[string]string),
f := &Forwarder{
enabled: false,
natRules: []ForwardRule{},
natTable: make(map[string]string),
connTrackTable: make(map[ConnectionKey]ConnectionInfo),
connTimeout: 5 * time.Minute, // 默认连接超时时间为5分钟
}
// 启动定期清理过期连接的定时器
f.cleanupTicker = time.NewTicker(1 * time.Minute)
go f.cleanupExpiredConnections()
return f
}
// cleanupExpiredConnections 清理过期的连接
func (f *Forwarder) cleanupExpiredConnections() {
for range f.cleanupTicker.C {
now := time.Now()
expiredKeys := []ConnectionKey{}
// 查找过期的连接
f.connTrackMutex.RLock()
for key, info := range f.connTrackTable {
if now.Sub(info.LastSeen) > f.connTimeout {
expiredKeys = append(expiredKeys, key)
}
}
f.connTrackMutex.RUnlock()
// 删除过期的连接
if len(expiredKeys) > 0 {
f.connTrackMutex.Lock()
for _, key := range expiredKeys {
delete(f.connTrackTable, key)
log.Printf("Removed expired connection: %s:%d -> %s:%d",
key.SrcIP, key.SrcPort, key.DstIP, key.DstPort)
}
f.connTrackMutex.Unlock()
}
}
}
// Enable 启用转发
func (f *Forwarder) Enable() {
f.enabled = true
}
// Disable 禁用转发
func (f *Forwarder) Disable() {
f.enabled = false
}
// Start 启动转发服务
@ -40,10 +125,62 @@ func (f *Forwarder) Stop() {
log.Println("Forwarding service stopped")
}
// Close 关闭转发器,停止清理定时器
func (f *Forwarder) Close() {
if f.cleanupTicker != nil {
f.cleanupTicker.Stop()
}
}
// GetConnectionStats 获取连接统计信息
func (f *Forwarder) GetConnectionStats() map[string]interface{} {
f.connTrackMutex.RLock()
defer f.connTrackMutex.RUnlock()
stats := make(map[string]interface{})
stats["total_connections"] = len(f.connTrackTable)
// 按协议统计连接数
protocolStats := make(map[uint8]int)
for key := range f.connTrackTable {
protocolStats[key.Protocol]++
}
stats["protocol_stats"] = protocolStats
// 获取最近活动的连接
type connSummary struct {
SrcIP string `json:"src_ip"`
SrcPort uint16 `json:"src_port"`
DstIP string `json:"dst_ip"`
DstPort uint16 `json:"dst_port"`
Protocol uint8 `json:"protocol"`
LastSeen string `json:"last_seen"`
}
recentConns := []connSummary{}
count := 0
for key, info := range f.connTrackTable {
if count >= 10 {
break
}
recentConns = append(recentConns, connSummary{
SrcIP: key.SrcIP,
SrcPort: key.SrcPort,
DstIP: key.DstIP,
DstPort: key.DstPort,
Protocol: key.Protocol,
LastSeen: info.LastSeen.Format(time.RFC3339),
})
count++
}
stats["recent_connections"] = recentConns
return stats
}
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
content, err := os.ReadFile(ruleFile)
if err != nil {
return err
}
decoder := json.NewDecoder(strings.NewReader(string(content)))
@ -60,23 +197,257 @@ func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
// ForwardRule 定义转发规则结构
type ForwardRule struct {
SrcIP string // 源IP
SrcPort int // 源端口
DstIP string // 目标IP
DstPort int // 目标端口
Type NATType `json:"type"` // NAT类型SNAT, DNAT, BINAT
SrcIP string `json:"src_ip"` // 源IP
SrcPort int `json:"src_port"` // 源端口
DstIP string `json:"dst_ip"` // 目标IP
DstPort int `json:"dst_port"` // 目标端口
NewSrcIP string `json:"new_src_ip"` // 新的源IP (用于SNAT和BINAT)
NewSrcPort int `json:"new_src_port"` // 新的源端口 (用于SNAT和BINAT)
NewDstIP string `json:"new_dst_ip"` // 新的目标IP (用于DNAT和BINAT)
NewDstPort int `json:"new_dst_port"` // 新的目标端口 (用于DNAT和BINAT)
ID string `json:"id"` // 规则ID用于动态删除
}
// LoadConfig 从配置文件加载转发规则
func (f *Forwarder) LoadConfig(configFile string) error {
data, err := os.ReadFile(configFile)
if err != nil {
return err
}
var config struct {
ForwardEnabled bool `json:"forward_enabled"`
NATRules []ForwardRule `json:"nat_rules"`
ForwardRules map[string]string `json:"forward_rules"` // 兼容旧版本
}
err = json.Unmarshal(data, &config)
if err != nil {
return err
}
f.enabled = config.ForwardEnabled
// 加载新版本NAT规则
if len(config.NATRules) > 0 {
f.natRules = config.NATRules
log.Printf("Loaded %d NAT rules", len(f.natRules))
for i, rule := range f.natRules {
log.Printf("Rule %d: Type=%s, SrcIP=%s, SrcPort=%d, DstIP=%s, DstPort=%d",
i+1, rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort)
}
}
// 兼容旧版本的转发规则
if len(config.ForwardRules) > 0 {
f.natTable = config.ForwardRules
log.Printf("Loaded %d legacy forward rules", len(f.natTable))
// 将旧版本规则转换为新版本规则
for src, dst := range f.natTable {
parts := strings.Split(src, ":")
if len(parts) != 2 {
continue
}
srcIP := parts[0]
srcPort, err := strconv.Atoi(parts[1])
if err != nil {
continue
}
dstParts := strings.Split(dst, ":")
if len(dstParts) != 2 {
continue
}
dstIP := dstParts[0]
dstPort, err := strconv.Atoi(dstParts[1])
if err != nil {
continue
}
// 创建DNAT规则
rule := ForwardRule{
Type: DNAT,
SrcIP: "",
SrcPort: 0,
DstIP: srcIP,
DstPort: srcPort,
NewDstIP: dstIP,
NewDstPort: dstPort,
ID: fmt.Sprintf("legacy-%s:%d", srcIP, srcPort),
}
f.natRules = append(f.natRules, rule)
}
}
return nil
}
// AddForwardRule 添加转发规则
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
f.natTable[key] = value
// 生成规则ID如果没有提供
if rule.ID == "" {
rule.ID = fmt.Sprintf("%s-%s:%d-%s:%d-%d",
rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort, time.Now().UnixNano())
}
// 添加到规则列表
f.natRules = append(f.natRules, rule)
// 兼容旧版本如果是DNAT规则也添加到natTable
if rule.Type == DNAT {
key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
value := fmt.Sprintf("%s:%d", rule.NewDstIP, rule.NewDstPort)
f.natTable[key] = value
} else if rule.Type == "" {
// 兼容旧版本的规则格式
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
f.natTable[key] = value
}
log.Printf("Added NAT rule: %s, ID=%s", rule.Type, rule.ID)
}
// RemoveForwardRule 移除转发规则
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
// 兼容旧版本
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
delete(f.natTable, key)
// 从natRules中删除
if rule.ID != "" {
for i, r := range f.natRules {
if r.ID == rule.ID {
f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
log.Printf("Removed NAT rule: %s, ID=%s", r.Type, r.ID)
break
}
}
}
}
// RemoveForwardRuleByID 通过ID删除转发规则
func (f *Forwarder) RemoveForwardRuleByID(ruleID string) bool {
for i, rule := range f.natRules {
if rule.ID == ruleID {
// 从规则列表中删除
f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
// 如果是DNAT规则也从natTable中删除
if rule.Type == DNAT {
key := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
delete(f.natTable, key)
}
log.Printf("Removed NAT rule: %s, ID=%s", rule.Type, rule.ID)
return true
}
}
return false
}
// ListNATRules 列出所有NAT规则
func (f *Forwarder) ListNATRules() []ForwardRule {
return f.natRules
}
// getConnectionKey 从数据包中获取连接键
func (f *Forwarder) getConnectionKey(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, isReply bool) ConnectionKey {
var srcIP, dstIP string
var srcPort, dstPort uint16
var protocol uint8
srcIP = ipLayer.SrcIP.String()
dstIP = ipLayer.DstIP.String()
protocol = uint8(ipLayer.Protocol)
switch t := transportLayer.(type) {
case *layers.TCP:
srcPort = uint16(t.SrcPort)
dstPort = uint16(t.DstPort)
case *layers.UDP:
srcPort = uint16(t.SrcPort)
dstPort = uint16(t.DstPort)
}
// 如果是回复包,交换源和目标
if isReply {
return ConnectionKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
Protocol: protocol,
}
}
return ConnectionKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
Protocol: protocol,
}
}
// findMatchingRule 查找匹配的NAT规则
func (f *Forwarder) findMatchingRule(srcIP string, srcPort int, dstIP string, dstPort int) *ForwardRule {
for i, rule := range f.natRules {
// 检查源IP匹配
if rule.SrcIP != "" && rule.SrcIP != srcIP {
continue
}
// 检查源端口匹配
if rule.SrcPort != 0 && rule.SrcPort != srcPort {
continue
}
// 检查目标IP匹配
if rule.DstIP != "" && rule.DstIP != dstIP {
continue
}
// 检查目标端口匹配
if rule.DstPort != 0 && rule.DstPort != dstPort {
continue
}
// 找到匹配的规则
return &f.natRules[i]
}
// 兼容旧版本检查natTable
key := fmt.Sprintf("%s:%d", dstIP, dstPort)
if forwardAddr, exists := f.natTable[key]; exists {
// 解析转发目标地址
parts := strings.Split(forwardAddr, ":")
if len(parts) != 2 {
return nil
}
newDstIP := parts[0]
newDstPort, err := strconv.Atoi(parts[1])
if err != nil {
return nil
}
// 创建临时DNAT规则
return &ForwardRule{
Type: DNAT,
DstIP: dstIP,
DstPort: dstPort,
NewDstIP: newDstIP,
NewDstPort: newDstPort,
}
}
return nil
}
// ForwardPacket 转发数据包
@ -87,60 +458,175 @@ func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.
// 获取源IP和端口
srcIP := ipLayer.SrcIP.String()
var srcPort int
dstIP := ipLayer.DstIP.String()
var srcPort, dstPort int
// 根据传输层协议获取端口
switch t := transportLayer.(type) {
case *layers.TCP:
srcPort = int(t.SrcPort)
// dstPort = int(t.DstPort)
dstPort = int(t.DstPort)
case *layers.UDP:
srcPort = int(t.SrcPort)
// dstPort = int(t.DstPort)
dstPort = int(t.DstPort)
default:
// 不支持的传输层协议
return nil
}
// 查找转发规则
key := fmt.Sprintf("%s:%d", srcIP, srcPort)
if forwardAddr, exists := f.natTable[key]; exists {
// 解析转发目标地址
addr, port, err := net.SplitHostPort(forwardAddr)
if err != nil {
return err
}
// 创建连接键
connKey := f.getConnectionKey(ipLayer, transportLayer, false)
// 更新IP层目标地址
newDstIP := net.ParseIP(addr)
if newDstIP == nil {
return fmt.Errorf("invalid forward IP address: %s", addr)
}
ipLayer.DstIP = newDstIP
// 检查是否是已建立连接的回复包
f.connTrackMutex.RLock()
connInfo, isReply := f.connTrackTable[connKey]
// 更新传输层目标端口
newDstPort, err := strconv.Atoi(port)
if err != nil {
return err
// 如果不是回复包,尝试查找反向连接
if !isReply {
reverseKey := f.getConnectionKey(ipLayer, transportLayer, true)
if info, found := f.connTrackTable[reverseKey]; found {
isReply = true
connInfo = info
}
switch t := transportLayer.(type) {
case *layers.TCP:
t.DstPort = layers.TCPPort(newDstPort)
case *layers.UDP:
t.DstPort = layers.UDPPort(newDstPort)
}
// 重新计算校验和
switch t := transportLayer.(type) {
case *layers.TCP:
t.SetNetworkLayerForChecksum(ipLayer)
case *layers.UDP:
t.SetNetworkLayerForChecksum(ipLayer)
}
log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort)
}
f.connTrackMutex.RUnlock()
// 如果是已建立连接的回复包应用已有的NAT转换
if isReply {
// 更新连接的最后活动时间
f.connTrackMutex.Lock()
connInfo.LastSeen = time.Now()
f.connTrackTable[connKey] = connInfo
f.connTrackMutex.Unlock()
// 应用已有的NAT转换
if connInfo.TranslatedSrcIP != "" && connInfo.TranslatedSrcIP != srcIP {
ipLayer.SrcIP = net.ParseIP(connInfo.TranslatedSrcIP)
}
if connInfo.TranslatedDstIP != "" && connInfo.TranslatedDstIP != dstIP {
ipLayer.DstIP = net.ParseIP(connInfo.TranslatedDstIP)
}
switch t := transportLayer.(type) {
case *layers.TCP:
if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) {
t.SrcPort = layers.TCPPort(connInfo.TranslatedSrcPort)
}
if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) {
t.DstPort = layers.TCPPort(connInfo.TranslatedDstPort)
}
t.SetNetworkLayerForChecksum(ipLayer)
case *layers.UDP:
if connInfo.TranslatedSrcPort != 0 && connInfo.TranslatedSrcPort != uint16(t.SrcPort) {
t.SrcPort = layers.UDPPort(connInfo.TranslatedSrcPort)
}
if connInfo.TranslatedDstPort != 0 && connInfo.TranslatedDstPort != uint16(t.DstPort) {
t.DstPort = layers.UDPPort(connInfo.TranslatedDstPort)
}
t.SetNetworkLayerForChecksum(ipLayer)
}
return nil
}
// 查找匹配的NAT规则
rule := f.findMatchingRule(srcIP, srcPort, dstIP, dstPort)
if rule == nil {
return nil
}
// 创建连接跟踪信息
newConnInfo := ConnectionInfo{
OriginalSrcIP: srcIP,
OriginalDstIP: dstIP,
OriginalSrcPort: uint16(srcPort),
OriginalDstPort: uint16(dstPort),
TranslatedSrcIP: srcIP,
TranslatedDstIP: dstIP,
TranslatedSrcPort: uint16(srcPort),
TranslatedDstPort: uint16(dstPort),
LastSeen: time.Now(),
}
// 应用NAT规则
switch rule.Type {
case SNAT:
// 源地址转换
if rule.NewSrcIP != "" {
ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP)
newConnInfo.TranslatedSrcIP = rule.NewSrcIP
}
if rule.NewSrcPort != 0 {
switch t := transportLayer.(type) {
case *layers.TCP:
t.SrcPort = layers.TCPPort(rule.NewSrcPort)
t.SetNetworkLayerForChecksum(ipLayer)
case *layers.UDP:
t.SrcPort = layers.UDPPort(rule.NewSrcPort)
t.SetNetworkLayerForChecksum(ipLayer)
}
newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort)
}
case DNAT:
// 目标地址转换
if rule.NewDstIP != "" {
ipLayer.DstIP = net.ParseIP(rule.NewDstIP)
newConnInfo.TranslatedDstIP = rule.NewDstIP
}
if rule.NewDstPort != 0 {
switch t := transportLayer.(type) {
case *layers.TCP:
t.DstPort = layers.TCPPort(rule.NewDstPort)
t.SetNetworkLayerForChecksum(ipLayer)
case *layers.UDP:
t.DstPort = layers.UDPPort(rule.NewDstPort)
t.SetNetworkLayerForChecksum(ipLayer)
}
newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort)
}
case BINAT:
// 双向地址转换
if rule.NewSrcIP != "" {
ipLayer.SrcIP = net.ParseIP(rule.NewSrcIP)
newConnInfo.TranslatedSrcIP = rule.NewSrcIP
}
if rule.NewSrcPort != 0 {
switch t := transportLayer.(type) {
case *layers.TCP:
t.SrcPort = layers.TCPPort(rule.NewSrcPort)
case *layers.UDP:
t.SrcPort = layers.UDPPort(rule.NewSrcPort)
}
newConnInfo.TranslatedSrcPort = uint16(rule.NewSrcPort)
}
if rule.NewDstIP != "" {
ipLayer.DstIP = net.ParseIP(rule.NewDstIP)
newConnInfo.TranslatedDstIP = rule.NewDstIP
}
if rule.NewDstPort != 0 {
switch t := transportLayer.(type) {
case *layers.TCP:
t.DstPort = layers.TCPPort(rule.NewDstPort)
t.SetNetworkLayerForChecksum(ipLayer)
case *layers.UDP:
t.DstPort = layers.UDPPort(rule.NewDstPort)
t.SetNetworkLayerForChecksum(ipLayer)
}
newConnInfo.TranslatedDstPort = uint16(rule.NewDstPort)
}
}
// 更新连接跟踪表
f.connTrackMutex.Lock()
f.connTrackTable[connKey] = newConnInfo
f.connTrackMutex.Unlock()
log.Printf("Applied NAT rule: %s, %s:%d -> %s:%d => %s:%d -> %s:%d",
rule.Type,
srcIP, srcPort, dstIP, dstPort,
newConnInfo.TranslatedSrcIP, newConnInfo.TranslatedSrcPort,
newConnInfo.TranslatedDstIP, newConnInfo.TranslatedDstPort)
return nil
}