147 lines
3.2 KiB
Go
147 lines
3.2 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/google/gopacket"
|
||
"github.com/google/gopacket/layers"
|
||
)
|
||
|
||
// Forwarder rewrites the destination of TCP/UDP packets according to a
// simple NAT table. Build one with NewForwarder: the zero value has a nil
// natTable, and AddForwardRule would panic writing into it.
type Forwarder struct {
	// enabled gates ForwardPacket; it is toggled by Start and Stop.
	enabled bool

	// natTable maps a source "ip:port" string to the "ip:port" string that
	// traffic from that source is redirected to.
	natTable map[string]string
}
|
||
|
||
// NewForwarder 创建新的流量转发器
|
||
func NewForwarder() *Forwarder {
|
||
return &Forwarder{
|
||
enabled: false,
|
||
natTable: make(map[string]string),
|
||
}
|
||
}
|
||
|
||
// Start 启动转发服务
|
||
func (f *Forwarder) Start() error {
|
||
f.enabled = true
|
||
log.Println("Forwarding service started")
|
||
return nil
|
||
}
|
||
|
||
// Stop 停止转发服务
|
||
func (f *Forwarder) Stop() {
|
||
f.enabled = false
|
||
log.Println("Forwarding service stopped")
|
||
}
|
||
|
||
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
||
content, err := os.ReadFile(ruleFile)
|
||
if err != nil {
|
||
|
||
return err
|
||
}
|
||
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||
for decoder.More() {
|
||
var rule ForwardRule
|
||
err := decoder.Decode(&rule)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
f.AddForwardRule(rule)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ForwardRule describes one redirection: packets whose source is
// SrcIP:SrcPort get their destination rewritten to DstIP:DstPort.
// There are no struct tags, so JSON rule files use the field names as keys.
type ForwardRule struct {
	SrcIP   string // source IP address to match
	SrcPort int    // source port to match
	DstIP   string // replacement destination IP address
	DstPort int    // replacement destination port
}
|
||
|
||
// AddForwardRule 添加转发规则
|
||
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
|
||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||
f.natTable[key] = value
|
||
}
|
||
|
||
// RemoveForwardRule 移除转发规则
|
||
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
|
||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||
delete(f.natTable, key)
|
||
}
|
||
|
||
// ForwardPacket 转发数据包
|
||
func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error {
|
||
if !f.enabled {
|
||
return nil
|
||
}
|
||
|
||
// 获取源IP和端口
|
||
srcIP := ipLayer.SrcIP.String()
|
||
var srcPort int
|
||
|
||
// 根据传输层协议获取端口
|
||
switch t := transportLayer.(type) {
|
||
case *layers.TCP:
|
||
srcPort = int(t.SrcPort)
|
||
// dstPort = int(t.DstPort)
|
||
case *layers.UDP:
|
||
srcPort = int(t.SrcPort)
|
||
// dstPort = int(t.DstPort)
|
||
default:
|
||
// 不支持的传输层协议
|
||
return nil
|
||
}
|
||
|
||
// 查找转发规则
|
||
key := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||
if forwardAddr, exists := f.natTable[key]; exists {
|
||
// 解析转发目标地址
|
||
addr, port, err := net.SplitHostPort(forwardAddr)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 更新IP层目标地址
|
||
newDstIP := net.ParseIP(addr)
|
||
if newDstIP == nil {
|
||
return fmt.Errorf("invalid forward IP address: %s", addr)
|
||
}
|
||
ipLayer.DstIP = newDstIP
|
||
|
||
// 更新传输层目标端口
|
||
newDstPort, err := strconv.Atoi(port)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
switch t := transportLayer.(type) {
|
||
case *layers.TCP:
|
||
t.DstPort = layers.TCPPort(newDstPort)
|
||
case *layers.UDP:
|
||
t.DstPort = layers.UDPPort(newDstPort)
|
||
}
|
||
|
||
// 重新计算校验和
|
||
switch t := transportLayer.(type) {
|
||
case *layers.TCP:
|
||
t.SetNetworkLayerForChecksum(ipLayer)
|
||
case *layers.UDP:
|
||
t.SetNetworkLayerForChecksum(ipLayer)
|
||
}
|
||
|
||
log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort)
|
||
}
|
||
|
||
return nil
|
||
}
|