gofirewall/forwarder.go

632 lines
16 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// NATType identifies which part of a packet a ForwardRule rewrites.
// AddForwardRule treats an empty NATType as a legacy, untyped rule.
type NATType string
const (
SNAT NATType = "SNAT" // source NAT: rewrite the source IP/port
DNAT NATType = "DNAT" // destination NAT: rewrite the destination IP/port
BINAT NATType = "BINAT" // bidirectional NAT: rewrite both source and destination
)
// ConnectionKey identifies one direction of a tracked flow; it is the key
// type of Forwarder.connTrackTable. Keys describe the packet's tuple as
// observed before any NAT rewriting is applied.
type ConnectionKey struct {
SrcIP string // source IP in net.IP.String() form
DstIP string // destination IP in net.IP.String() form
SrcPort uint16 // TCP/UDP source port (0 for other transports)
DstPort uint16 // TCP/UDP destination port (0 for other transports)
Protocol uint8 // IP protocol number from the IPv4 header
}
// ConnectionInfo is the value stored per tracked flow direction. Original*
// is the tuple observed on the packet; Translated* is the tuple such packets
// are rewritten to. ForwardPacket re-applies Translated* to later packets of
// the flow found in the tracking table.
type ConnectionInfo struct {
OriginalSrcIP string
OriginalDstIP string
OriginalSrcPort uint16
OriginalDstPort uint16
TranslatedSrcIP string
TranslatedDstIP string
TranslatedSrcPort uint16
TranslatedDstPort uint16
// LastSeen is refreshed when a packet of the flow is processed; entries idle
// for longer than Forwarder.connTimeout are removed by the cleanup goroutine.
LastSeen time.Time
}
// Forwarder applies NAT rules to decoded IPv4 TCP/UDP packets and tracks
// translated flows so later packets of a flow receive the same rewrite.
//
// NOTE(review): enabled, natRules and natTable are read and written without
// synchronization; only connTrackTable is guarded by connTrackMutex.
type Forwarder struct {
	enabled        bool
	natRules       []ForwardRule                    // NAT rules, matched in order
	natTable       map[string]string                // legacy "ip:port" -> "ip:port" mapping
	connTrackTable map[ConnectionKey]ConnectionInfo // connection-tracking table
	connTrackMutex sync.RWMutex                     // guards connTrackTable
	cleanupTicker  *time.Ticker                     // drives periodic expiry of idle flows
	connTimeout    time.Duration                    // idle time after which a flow is expired

	done      chan struct{} // closed by Close to stop the cleanup goroutine
	closeOnce sync.Once     // makes Close safe to call more than once
}

// NewForwarder returns a disabled Forwarder with empty rule tables and a
// 5-minute idle timeout for tracked connections. It starts a background
// goroutine that expires idle connections once a minute; call Close to stop it.
func NewForwarder() *Forwarder {
	f := &Forwarder{
		enabled:        false,
		natRules:       []ForwardRule{},
		natTable:       make(map[string]string),
		connTrackTable: make(map[ConnectionKey]ConnectionInfo),
		connTimeout:    5 * time.Minute,
		cleanupTicker:  time.NewTicker(1 * time.Minute),
		done:           make(chan struct{}),
	}
	go f.cleanupExpiredConnections()
	return f
}

// cleanupExpiredConnections runs until Close is called. On every ticker tick
// it removes tracked connections idle for longer than connTimeout.
//
// Ticker.Stop does not close the ticker channel, so ranging over it alone
// would never terminate; the done channel provides the exit path.
func (f *Forwarder) cleanupExpiredConnections() {
	for {
		select {
		case <-f.done:
			return
		case <-f.cleanupTicker.C:
		}
		now := time.Now()

		// Collect candidates under the read lock so packet processing is not
		// blocked for the whole scan.
		var expired []ConnectionKey
		f.connTrackMutex.RLock()
		for key, info := range f.connTrackTable {
			if now.Sub(info.LastSeen) > f.connTimeout {
				expired = append(expired, key)
			}
		}
		f.connTrackMutex.RUnlock()
		if len(expired) == 0 {
			continue
		}

		// Re-check under the write lock: a packet may have refreshed an entry
		// between the two critical sections, and then it must be kept.
		f.connTrackMutex.Lock()
		for _, key := range expired {
			info, ok := f.connTrackTable[key]
			if !ok || now.Sub(info.LastSeen) <= f.connTimeout {
				continue
			}
			delete(f.connTrackTable, key)
			log.Printf("Removed expired connection: %s:%d -> %s:%d",
				key.SrcIP, key.SrcPort, key.DstIP, key.DstPort)
		}
		f.connTrackMutex.Unlock()
	}
}

// Enable turns packet translation on.
func (f *Forwarder) Enable() {
	f.enabled = true
}

// Disable turns packet translation off; ForwardPacket then returns immediately.
func (f *Forwarder) Disable() {
	f.enabled = false
}

// Start enables forwarding and logs it. It always returns nil.
func (f *Forwarder) Start() error {
	f.enabled = true
	log.Println("Forwarding service started")
	return nil
}

// Stop disables forwarding and logs it. The cleanup goroutine keeps running;
// use Close to stop it.
func (f *Forwarder) Stop() {
	f.enabled = false
	log.Println("Forwarding service stopped")
}

// Close stops the cleanup ticker and terminates the cleanup goroutine.
// It is safe to call more than once.
func (f *Forwarder) Close() {
	f.closeOnce.Do(func() {
		if f.cleanupTicker != nil {
			f.cleanupTicker.Stop()
		}
		// done is nil for a Forwarder not built by NewForwarder; closing a
		// nil channel would panic.
		if f.done != nil {
			close(f.done)
		}
	})
}
// GetConnectionStats returns a snapshot of the connection-tracking table.
//
// Keys of the returned map:
//   - "total_connections" (int): number of entries in the tracking table.
//   - "protocol_stats" (map[uint8]int): entry count per IP protocol number.
//   - "recent_connections": up to 10 entries ordered by LastSeen, most
//     recently active first.
//
// The previous version took the first 10 entries in map iteration order,
// which Go randomizes, so the result was not actually the recent ones. A
// bounded top-k list keeps memory at O(10) regardless of table size.
func (f *Forwarder) GetConnectionStats() map[string]interface{} {
	const maxRecent = 10

	type entry struct {
		key  ConnectionKey
		info ConnectionInfo
	}
	type connSummary struct {
		SrcIP    string `json:"src_ip"`
		SrcPort  uint16 `json:"src_port"`
		DstIP    string `json:"dst_ip"`
		DstPort  uint16 `json:"dst_port"`
		Protocol uint8  `json:"protocol"`
		LastSeen string `json:"last_seen"`
	}

	f.connTrackMutex.RLock()
	defer f.connTrackMutex.RUnlock()

	protocolStats := make(map[uint8]int)
	// newest holds at most maxRecent entries ordered by LastSeen, newest first.
	newest := make([]entry, 0, maxRecent+1)
	for key, info := range f.connTrackTable {
		protocolStats[key.Protocol]++

		// Walk back past every entry older than this one to find its slot.
		pos := len(newest)
		for pos > 0 && newest[pos-1].info.LastSeen.Before(info.LastSeen) {
			pos--
		}
		if pos >= maxRecent {
			continue // older than everything in a full list
		}
		newest = append(newest, entry{})
		copy(newest[pos+1:], newest[pos:])
		newest[pos] = entry{key: key, info: info}
		if len(newest) > maxRecent {
			newest = newest[:maxRecent]
		}
	}

	// Non-nil so it encodes as [] rather than null.
	recentConns := make([]connSummary, 0, len(newest))
	for _, e := range newest {
		recentConns = append(recentConns, connSummary{
			SrcIP:    e.key.SrcIP,
			SrcPort:  e.key.SrcPort,
			DstIP:    e.key.DstIP,
			DstPort:  e.key.DstPort,
			Protocol: e.key.Protocol,
			LastSeen: e.info.LastSeen.Format(time.RFC3339),
		})
	}

	return map[string]interface{}{
		"total_connections":  len(f.connTrackTable),
		"protocol_stats":     protocolStats,
		"recent_connections": recentConns,
	}
}
// LoadRulesFromFile reads ruleFile as a stream of JSON ForwardRule objects
// and adds each one via AddForwardRule.
//
// Every rule is decoded before any is added. A malformed file therefore
// leaves the current rule set untouched, instead of applying only the rules
// that preceded the bad one. The read or decode error is returned unchanged.
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
	content, err := os.ReadFile(ruleFile)
	if err != nil {
		return err
	}

	var rules []ForwardRule
	decoder := json.NewDecoder(strings.NewReader(string(content)))
	for decoder.More() {
		var rule ForwardRule
		if err := decoder.Decode(&rule); err != nil {
			return err
		}
		rules = append(rules, rule)
	}

	for _, rule := range rules {
		f.AddForwardRule(rule)
	}
	return nil
}
// ForwardRule is one NAT rule. When matching, empty SrcIP/DstIP and zero
// SrcPort/DstPort act as wildcards (findMatchingRule). When translating, an
// empty New*IP or zero New*Port leaves that part of the packet unchanged.
type ForwardRule struct {
Type NATType `json:"type"` // NAT type: SNAT, DNAT or BINAT
SrcIP string `json:"src_ip"` // source IP to match
SrcPort int `json:"src_port"` // source port to match
DstIP string `json:"dst_ip"` // destination IP to match
DstPort int `json:"dst_port"` // destination port to match
NewSrcIP string `json:"new_src_ip"` // replacement source IP (SNAT and BINAT)
NewSrcPort int `json:"new_src_port"` // replacement source port (SNAT and BINAT)
NewDstIP string `json:"new_dst_ip"` // replacement destination IP (DNAT and BINAT)
NewDstPort int `json:"new_dst_port"` // replacement destination port (DNAT and BINAT)
ID string `json:"id"` // rule ID, used to remove the rule at runtime; generated by AddForwardRule if empty
}
// LoadConfig loads forwarding settings from the JSON file configFile.
//
// Recognized top-level keys:
//   - "forward_enabled": enables or disables forwarding.
//   - "nat_rules": when non-empty, replaces the current NAT rule list.
//   - "forward_rules": legacy map of "ip:port" -> "ip:port" DNAT mappings.
//     When non-empty, it replaces the legacy natTable. Each valid entry is
//     also appended to the rule list as a DNAT rule with ID
//     "legacy-<ip>:<port>", unless a rule with that ID is already present.
//     This prevents duplicate rules when the same file is loaded twice.
//     Entries that are not single-colon "ip:port" strings, or whose port is
//     outside 0-65535, are logged and skipped.
//
// The error from reading or decoding the file is returned unchanged.
func (f *Forwarder) LoadConfig(configFile string) error {
	data, err := os.ReadFile(configFile)
	if err != nil {
		return err
	}

	var config struct {
		ForwardEnabled bool              `json:"forward_enabled"`
		NATRules       []ForwardRule     `json:"nat_rules"`
		ForwardRules   map[string]string `json:"forward_rules"` // legacy format
	}
	if err := json.Unmarshal(data, &config); err != nil {
		return err
	}

	f.enabled = config.ForwardEnabled

	if len(config.NATRules) > 0 {
		f.natRules = config.NATRules
		log.Printf("Loaded %d NAT rules", len(f.natRules))
		for i, rule := range f.natRules {
			log.Printf("Rule %d: Type=%s, SrcIP=%s, SrcPort=%d, DstIP=%s, DstPort=%d",
				i+1, rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort)
		}
	}

	if len(config.ForwardRules) > 0 {
		f.natTable = config.ForwardRules
		log.Printf("Loaded %d legacy forward rules", len(f.natTable))

		// parseAddr splits a legacy "ip:port" string. The port must fit in
		// 16 bits, otherwise it would silently wrap when written to a header.
		parseAddr := func(addr string) (string, int, bool) {
			parts := strings.Split(addr, ":")
			if len(parts) != 2 {
				return "", 0, false
			}
			port, err := strconv.Atoi(parts[1])
			if err != nil || port < 0 || port > 65535 {
				return "", 0, false
			}
			return parts[0], port, true
		}

		// IDs already present (e.g. from an earlier LoadConfig call).
		known := make(map[string]bool, len(f.natRules))
		for _, r := range f.natRules {
			known[r.ID] = true
		}

		// Convert the legacy mappings into DNAT rules.
		for src, dst := range f.natTable {
			srcIP, srcPort, okSrc := parseAddr(src)
			dstIP, dstPort, okDst := parseAddr(dst)
			if !okSrc || !okDst {
				log.Printf("Skipping invalid legacy forward rule %q -> %q", src, dst)
				continue
			}
			id := fmt.Sprintf("legacy-%s:%d", srcIP, srcPort)
			if known[id] {
				continue
			}
			known[id] = true
			f.natRules = append(f.natRules, ForwardRule{
				Type:       DNAT,
				DstIP:      srcIP,
				DstPort:    srcPort,
				NewDstIP:   dstIP,
				NewDstPort: dstPort,
				ID:         id,
			})
		}
	}
	return nil
}
// AddForwardRule appends rule to the end of the NAT rule list.
//
// A rule without an ID gets one built from its type, match addresses and the
// current UnixNano timestamp. For compatibility with the legacy lookup table,
// rules are also mirrored into natTable:
//   - DNAT rules as "DstIP:DstPort" -> "NewDstIP:NewDstPort";
//   - untyped rules as "SrcIP:SrcPort" -> "DstIP:DstPort".
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
	if len(rule.ID) == 0 {
		stamp := time.Now().UnixNano()
		rule.ID = fmt.Sprintf("%s-%s:%d-%s:%d-%d",
			rule.Type, rule.SrcIP, rule.SrcPort, rule.DstIP, rule.DstPort, stamp)
	}

	f.natRules = append(f.natRules, rule)

	switch rule.Type {
	case DNAT:
		from := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
		to := fmt.Sprintf("%s:%d", rule.NewDstIP, rule.NewDstPort)
		f.natTable[from] = to
	case "":
		from := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
		to := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
		f.natTable[from] = to
	}

	log.Printf("Added NAT rule: %s, ID=%s", rule.Type, rule.ID)
}
// RemoveForwardRule removes a forwarding rule.
//
// If rule.ID names a stored rule, that rule is removed from natRules. Its
// natTable mirror is removed using the stored rule's own type and fields:
// "DstIP:DstPort" for DNAT, "SrcIP:SrcPort" for untyped rules. Previously
// only "SrcIP:SrcPort" of the argument was deleted, so a DNAT mirror
// survived and findMatchingRule's fallback kept forwarding after removal.
//
// If rule.ID is empty or unknown, only the legacy behavior applies: the
// natTable entry keyed by rule.SrcIP:rule.SrcPort is deleted.
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
	if rule.ID != "" {
		for i := range f.natRules {
			stored := f.natRules[i] // copy before the slice is shifted
			if stored.ID != rule.ID {
				continue
			}
			f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
			switch stored.Type {
			case DNAT:
				delete(f.natTable, fmt.Sprintf("%s:%d", stored.DstIP, stored.DstPort))
			case "":
				delete(f.natTable, fmt.Sprintf("%s:%d", stored.SrcIP, stored.SrcPort))
			}
			log.Printf("Removed NAT rule: %s, ID=%s", stored.Type, stored.ID)
			return
		}
	}

	// Legacy removal by listen address.
	delete(f.natTable, fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort))
}
// RemoveForwardRuleByID removes the first rule whose ID equals ruleID, along
// with the natTable mirror that AddForwardRule created for it. For DNAT rules
// the mirror is keyed "DstIP:DstPort". For untyped (legacy) rules it is keyed
// "SrcIP:SrcPort"; those mirrors were previously left behind and kept matching.
// It reports whether a rule was removed.
func (f *Forwarder) RemoveForwardRuleByID(ruleID string) bool {
	for i := range f.natRules {
		rule := f.natRules[i] // copy before the slice is shifted
		if rule.ID != ruleID {
			continue
		}
		f.natRules = append(f.natRules[:i], f.natRules[i+1:]...)
		switch rule.Type {
		case DNAT:
			delete(f.natTable, fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort))
		case "":
			delete(f.natTable, fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort))
		}
		log.Printf("Removed NAT rule: %s, ID=%s", rule.Type, rule.ID)
		return true
	}
	return false
}
// ListNATRules returns a copy of the NAT rules in match order.
//
// The internal slice used to be returned directly. Callers could then mutate
// the forwarder's rules. Also, a later RemoveForwardRule or
// RemoveForwardRuleByID, which shift elements in place, would rewrite the
// slice the caller was holding.
func (f *Forwarder) ListNATRules() []ForwardRule {
	rules := make([]ForwardRule, len(f.natRules))
	copy(rules, f.natRules)
	return rules
}
// getConnectionKey builds the tracking key for a packet from its IPv4
// addresses, IP protocol number and, for TCP/UDP, its ports. Other transports
// leave both ports at 0. With isReply set, the source and destination halves
// are swapped, yielding the key of the opposite direction.
func (f *Forwarder) getConnectionKey(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, isReply bool) ConnectionKey {
	var sport, dport uint16
	switch seg := transportLayer.(type) {
	case *layers.TCP:
		sport, dport = uint16(seg.SrcPort), uint16(seg.DstPort)
	case *layers.UDP:
		sport, dport = uint16(seg.SrcPort), uint16(seg.DstPort)
	}

	key := ConnectionKey{
		SrcIP:    ipLayer.SrcIP.String(),
		DstIP:    ipLayer.DstIP.String(),
		SrcPort:  sport,
		DstPort:  dport,
		Protocol: uint8(ipLayer.Protocol),
	}
	if isReply {
		key.SrcIP, key.DstIP = key.DstIP, key.SrcIP
		key.SrcPort, key.DstPort = key.DstPort, key.SrcPort
	}
	return key
}
// findMatchingRule returns the first rule in natRules that matches the given
// tuple. An empty rule IP or zero rule port matches anything. The result
// points into natRules.
//
// If no rule matches, the legacy natTable is consulted with key
// "dstIP:dstPort". A hit becomes a temporary DNAT rule, provided the stored
// value is a single-colon "ip:port" with a 16-bit port. Out-of-range or
// negative ports were previously accepted and would wrap when cast to a TCP
// or UDP port. Returns nil when nothing matches.
func (f *Forwarder) findMatchingRule(srcIP string, srcPort int, dstIP string, dstPort int) *ForwardRule {
	// Index rather than range by value: ForwardRule is large and this runs
	// for every untracked packet.
	for i := range f.natRules {
		r := &f.natRules[i]
		if (r.SrcIP == "" || r.SrcIP == srcIP) &&
			(r.SrcPort == 0 || r.SrcPort == srcPort) &&
			(r.DstIP == "" || r.DstIP == dstIP) &&
			(r.DstPort == 0 || r.DstPort == dstPort) {
			return r
		}
	}

	// Legacy fallback. Concatenation yields the same key as "%s:%d" without
	// fmt's per-call boxing.
	forwardAddr, ok := f.natTable[dstIP+":"+strconv.Itoa(dstPort)]
	if !ok {
		return nil
	}
	parts := strings.Split(forwardAddr, ":")
	if len(parts) != 2 {
		return nil
	}
	newDstPort, err := strconv.ParseUint(parts[1], 10, 16)
	if err != nil {
		return nil
	}
	return &ForwardRule{
		Type:       DNAT,
		DstIP:      dstIP,
		DstPort:    dstPort,
		NewDstIP:   parts[0],
		NewDstPort: int(newDstPort),
	}
}
// ForwardPacket applies NAT to a decoded IPv4 packet, rewriting ipLayer and
// transportLayer in place. It does nothing when the forwarder is disabled or
// the transport layer is neither TCP nor UDP. packetData is not used; the
// caller is presumably responsible for re-serializing the modified layers.
//
// The first packet of a flow that matches a rule creates two tracking entries:
//   - the packet's own tuple, mapped to its translated tuple (forward
//     direction);
//   - the translated tuple reversed, i.e. what the peer's replies will carry,
//     mapped back to the original endpoints (reply direction).
//
// Any later packet found under either key is rewritten to the recorded tuple,
// and only that entry's LastSeen is refreshed. Each NATed flow therefore
// occupies two entries in connTrackTable.
//
// The TCP/UDP layer is always re-bound to ipLayer, so checksums reflect
// rewritten addresses when the packet is serialized with checksum computation.
//
// An error is returned, before the packet is modified, when the matching rule
// names a replacement IP that is not a valid IPv4 address, or a replacement
// port outside 0-65535.
func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error {
	if !f.enabled {
		return nil
	}

	// Only TCP and UDP carry ports that can be translated.
	var srcPort, dstPort int
	switch t := transportLayer.(type) {
	case *layers.TCP:
		srcPort, dstPort = int(t.SrcPort), int(t.DstPort)
	case *layers.UDP:
		srcPort, dstPort = int(t.SrcPort), int(t.DstPort)
	default:
		return nil
	}
	srcIP := ipLayer.SrcIP.String()
	dstIP := ipLayer.DstIP.String()
	connKey := f.getConnectionKey(ipLayer, transportLayer, false)

	// Packet of an already-tracked flow, in either direction. Refresh the
	// entry it was found under and re-apply its rewrite. Lookup and refresh
	// share one critical section so expiry cannot interleave.
	f.connTrackMutex.Lock()
	tracked, found := f.connTrackTable[connKey]
	if found {
		tracked.LastSeen = time.Now()
		f.connTrackTable[connKey] = tracked
	}
	f.connTrackMutex.Unlock()
	if found {
		f.applyTrackedTranslation(ipLayer, transportLayer, tracked)
		return nil
	}

	rule := f.findMatchingRule(srcIP, srcPort, dstIP, dstPort)
	if rule == nil {
		return nil
	}

	now := time.Now()
	fwd := ConnectionInfo{
		OriginalSrcIP:     srcIP,
		OriginalDstIP:     dstIP,
		OriginalSrcPort:   uint16(srcPort),
		OriginalDstPort:   uint16(dstPort),
		TranslatedSrcIP:   srcIP,
		TranslatedDstIP:   dstIP,
		TranslatedSrcPort: uint16(srcPort),
		TranslatedDstPort: uint16(dstPort),
		LastSeen:          now,
	}

	// Validate the rule's replacement values before touching the packet, so
	// a bad rule never yields a nil address or a wrapped port number.
	if rule.Type == SNAT || rule.Type == BINAT {
		if rule.NewSrcIP != "" {
			ip := net.ParseIP(rule.NewSrcIP).To4()
			if ip == nil {
				return fmt.Errorf("NAT rule %q: invalid new_src_ip %q", rule.ID, rule.NewSrcIP)
			}
			fwd.TranslatedSrcIP = ip.String()
		}
		if rule.NewSrcPort != 0 {
			if rule.NewSrcPort < 0 || rule.NewSrcPort > 65535 {
				return fmt.Errorf("NAT rule %q: invalid new_src_port %d", rule.ID, rule.NewSrcPort)
			}
			fwd.TranslatedSrcPort = uint16(rule.NewSrcPort)
		}
	}
	if rule.Type == DNAT || rule.Type == BINAT {
		if rule.NewDstIP != "" {
			ip := net.ParseIP(rule.NewDstIP).To4()
			if ip == nil {
				return fmt.Errorf("NAT rule %q: invalid new_dst_ip %q", rule.ID, rule.NewDstIP)
			}
			fwd.TranslatedDstIP = ip.String()
		}
		if rule.NewDstPort != 0 {
			if rule.NewDstPort < 0 || rule.NewDstPort > 65535 {
				return fmt.Errorf("NAT rule %q: invalid new_dst_port %d", rule.ID, rule.NewDstPort)
			}
			fwd.TranslatedDstPort = uint16(rule.NewDstPort)
		}
	}

	// Replies arrive as translated-destination -> translated-source and must
	// be mapped back to original-destination -> original-source.
	reply := ConnectionInfo{
		OriginalSrcIP:     fwd.TranslatedDstIP,
		OriginalDstIP:     fwd.TranslatedSrcIP,
		OriginalSrcPort:   fwd.TranslatedDstPort,
		OriginalDstPort:   fwd.TranslatedSrcPort,
		TranslatedSrcIP:   fwd.OriginalDstIP,
		TranslatedDstIP:   fwd.OriginalSrcIP,
		TranslatedSrcPort: fwd.OriginalDstPort,
		TranslatedDstPort: fwd.OriginalSrcPort,
		LastSeen:          now,
	}
	replyKey := ConnectionKey{
		SrcIP:    reply.OriginalSrcIP,
		DstIP:    reply.OriginalDstIP,
		SrcPort:  reply.OriginalSrcPort,
		DstPort:  reply.OriginalDstPort,
		Protocol: connKey.Protocol,
	}

	f.applyTrackedTranslation(ipLayer, transportLayer, fwd)

	// NOTE(review): rules that map many clients onto the same translated
	// tuple (e.g. SNAT with a fixed NewSrcPort) produce colliding reply keys;
	// the most recent flow wins.
	f.connTrackMutex.Lock()
	f.connTrackTable[connKey] = fwd
	if replyKey != connKey {
		f.connTrackTable[replyKey] = reply
	}
	f.connTrackMutex.Unlock()

	log.Printf("Applied NAT rule: %s, %s:%d -> %s:%d => %s:%d -> %s:%d",
		rule.Type,
		srcIP, srcPort, dstIP, dstPort,
		fwd.TranslatedSrcIP, fwd.TranslatedSrcPort,
		fwd.TranslatedDstIP, fwd.TranslatedDstPort)
	return nil
}

// applyTrackedTranslation rewrites ipLayer and transportLayer in place to the
// Translated* tuple recorded in info; an empty IP or zero port leaves that
// field unchanged. The TCP/UDP layer is always re-bound to ipLayer, so its
// pseudo-header checksum covers any rewritten address when serialized.
func (f *Forwarder) applyTrackedTranslation(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, info ConnectionInfo) {
	if info.TranslatedSrcIP != "" && info.TranslatedSrcIP != ipLayer.SrcIP.String() {
		if ip := net.ParseIP(info.TranslatedSrcIP); ip != nil {
			ipLayer.SrcIP = ip
		}
	}
	if info.TranslatedDstIP != "" && info.TranslatedDstIP != ipLayer.DstIP.String() {
		if ip := net.ParseIP(info.TranslatedDstIP); ip != nil {
			ipLayer.DstIP = ip
		}
	}

	switch t := transportLayer.(type) {
	case *layers.TCP:
		if info.TranslatedSrcPort != 0 {
			t.SrcPort = layers.TCPPort(info.TranslatedSrcPort)
		}
		if info.TranslatedDstPort != 0 {
			t.DstPort = layers.TCPPort(info.TranslatedDstPort)
		}
		// Error ignored as before; the network layer here is always IPv4.
		_ = t.SetNetworkLayerForChecksum(ipLayer)
	case *layers.UDP:
		if info.TranslatedSrcPort != 0 {
			t.SrcPort = layers.UDPPort(info.TranslatedSrcPort)
		}
		if info.TranslatedDstPort != 0 {
			t.DstPort = layers.UDPPort(info.TranslatedDstPort)
		}
		_ = t.SetNetworkLayerForChecksum(ipLayer)
	}
}