gofirewall/forwarder.go

148 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Forwarder rewrites packet destinations according to a NAT-style rule table.
// Rules live in natTable and are only applied by ForwardPacket while enabled
// is true.
type Forwarder struct {
enabled bool // set by Start, cleared by Stop; ForwardPacket does nothing while false
natTable map[string]string // simple NAT mapping table; key: "original address:port", value: "forwarded address:port"
}
// NewForwarder returns a Forwarder that starts out disabled and owns an
// empty, ready-to-write NAT table.
func NewForwarder() *Forwarder {
	fwd := new(Forwarder) // enabled keeps its zero value, false
	fwd.natTable = make(map[string]string)
	return fwd
}
// Start switches the forwarder on, so that ForwardPacket begins applying
// rules, and logs the state change. It always returns nil.
func (f *Forwarder) Start() error {
	const startedMsg = "Forwarding service started"

	f.enabled = true
	log.Println(startedMsg)
	return nil
}
// Stop switches the forwarder off, so that ForwardPacket becomes a no-op,
// and logs the state change. The rule table is left as it is.
func (f *Forwarder) Stop() {
	const stoppedMsg = "Forwarding service stopped"

	f.enabled = false
	log.Println(stoppedMsg)
}
// LoadRulesFromFile reads forwarding rules from ruleFile and installs them
// in the NAT table through AddForwardRule.
//
// The file must hold zero or more JSON objects written one after another
// (a JSON stream); each object is decoded into a ForwardRule. A single
// top-level JSON array is not accepted, because every value is decoded
// directly into a ForwardRule.
//
// Loading is all-or-nothing: every rule is decoded before any of them is
// installed, so a malformed file leaves the current table untouched.
//
// The returned error wraps the underlying os or encoding/json error, so
// errors.Is and errors.As still see the original cause.
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
	content, err := os.ReadFile(ruleFile)
	if err != nil {
		return fmt.Errorf("loading forward rules: %w", err)
	}

	decoder := json.NewDecoder(strings.NewReader(string(content)))

	// Decode everything first so that a bad entry halfway through the file
	// cannot leave the table partially updated.
	var rules []ForwardRule
	for decoder.More() {
		var rule ForwardRule
		if err := decoder.Decode(&rule); err != nil {
			return fmt.Errorf("decoding forward rule #%d in %q: %w", len(rules)+1, ruleFile, err)
		}
		rules = append(rules, rule)
	}

	for _, rule := range rules {
		f.AddForwardRule(rule)
	}
	return nil
}
// ForwardRule describes one forwarding rule. AddForwardRule stores it in the
// NAT table under the key "SrcIP:SrcPort" with the value "DstIP:DstPort".
//
// NOTE: ForwardPacket looks rules up with the key ":<packet destination
// port>", so as the code stands only rules with an empty SrcIP can match, and
// SrcPort is compared against the packet's destination port.
type ForwardRule struct {
SrcIP string // source IP
SrcPort int // source port
DstIP string // destination IP
DstPort int // destination port
}
// AddForwardRule installs rule in the NAT table, mapping "SrcIP:SrcPort" to
// "DstIP:DstPort". A rule already stored under the same key is overwritten.
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
	// endpoint renders an address/port pair in the table's "ip:port" form.
	endpoint := func(ip string, port int) string {
		return fmt.Sprintf("%s:%d", ip, port)
	}

	f.natTable[endpoint(rule.SrcIP, rule.SrcPort)] = endpoint(rule.DstIP, rule.DstPort)
}
// RemoveForwardRule drops the NAT table entry stored under rule's
// "SrcIP:SrcPort" key. Only the source side of rule is consulted; removing a
// rule that is not present does nothing.
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
	delete(f.natTable, fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort))
}
// ForwardPacket rewrites, in place, the destination of an IPv4 TCP or UDP
// packet when a forwarding rule matches the packet's destination port.
//
// Parameters:
//   - ipLayer: the packet's IPv4 layer; its DstIP is replaced on a match.
//   - transportLayer: the packet's transport layer. Only *layers.TCP and
//     *layers.UDP are handled; any other type is ignored.
//   - packetData: not used by this function.
//
// Rules are looked up under the key ":<destination port>", i.e. this is a
// plain port mapping: only rules stored with an empty SrcIP whose SrcPort
// equals the packet's destination port can match.
//
// Only the decoded layers are modified; nothing is serialized or sent here.
// The transport layer is told which network layer wraps it so that its
// checksum can be computed when the caller serializes the packet.
//
// It returns nil when the forwarder is disabled, the transport protocol is
// unsupported, no rule matches, or the rewrite succeeded. It returns an error
// when ipLayer is nil or when the stored target is not a valid "IPv4:port"
// pair (port in 0-65535); in those cases neither layer has been modified.
func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error {
	if !f.enabled {
		return nil
	}
	if ipLayer == nil {
		return fmt.Errorf("cannot forward packet: nil IPv4 layer")
	}

	// Read the ports from whichever supported transport protocol this is.
	var srcPort, dstPort int
	switch t := transportLayer.(type) {
	case *layers.TCP:
		srcPort, dstPort = int(t.SrcPort), int(t.DstPort)
	case *layers.UDP:
		srcPort, dstPort = int(t.SrcPort), int(t.DstPort)
	default:
		// Unsupported transport protocol: leave the packet alone.
		return nil
	}

	// Rules are matched on the destination port only (a port mapping).
	forwardAddr, exists := f.natTable[":"+strconv.Itoa(dstPort)]
	if !exists {
		return nil
	}

	// Validate the whole target before touching the packet, so a bad rule
	// never leaves the layers half-rewritten.
	addr, port, err := net.SplitHostPort(forwardAddr)
	if err != nil {
		return fmt.Errorf("invalid forward address %q: %w", forwardAddr, err)
	}
	parsedIP := net.ParseIP(addr)
	if parsedIP == nil {
		return fmt.Errorf("invalid forward IP address: %s", addr)
	}
	newDstIP := parsedIP.To4()
	if newDstIP == nil {
		// An IPv6 target cannot be written into an IPv4 header.
		return fmt.Errorf("forward IP address is not IPv4: %s", addr)
	}
	// bitSize 16 rejects anything outside 0-65535 instead of silently
	// truncating it when converting to the layer's port type.
	newDstPort, err := strconv.ParseUint(port, 10, 16)
	if err != nil {
		return fmt.Errorf("invalid forward port %q: %w", port, err)
	}

	// Apply the rewrite to both layers.
	ipLayer.DstIP = newDstIP
	switch t := transportLayer.(type) {
	case *layers.TCP:
		t.DstPort = layers.TCPPort(newDstPort)
		err = t.SetNetworkLayerForChecksum(ipLayer)
	case *layers.UDP:
		t.DstPort = layers.UDPPort(newDstPort)
		err = t.SetNetworkLayerForChecksum(ipLayer)
	}
	if err != nil {
		return fmt.Errorf("preparing transport checksum: %w", err)
	}

	log.Printf("Forwarding packet: %s:%d -> %s:%d", ipLayer.SrcIP.String(), srcPort, addr, newDstPort)
	return nil
}