init code

This commit is contained in:
kingecg 2025-07-03 11:23:49 +08:00
commit 61d8c20d09
14 changed files with 1364 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
vendor/
target/

63
README.md Normal file
View File

@ -0,0 +1,63 @@
# GoFirewall
基于Go语言开发的高性能防火墙系统支持自定义规则、流量过滤、日志记录和网络转发功能。
## 功能特性
- 支持TCP/UDP/ICMP协议过滤
- 自定义防火墙规则管理
- 实时网络流量监控
- 数据包转发(NAT)功能
- 多级别日志记录
- 配置文件管理
## 安装指南
1. 确保已安装Go 1.16+环境
2. 克隆项目仓库:
```
git clone https://github.com/yourusername/gofirewall.git
```
3. 安装依赖:
```
go mod download
```
4. 编译项目:
```
go build
```
## 快速开始
1. 配置防火墙规则(编辑`firewall.json`
2. 启动防火墙:
```
./gofirewall
```
3. 查看日志:
```
tail -f firewall.log
```
## 配置文件
配置文件示例(`firewall.json`)
```json
{
"log_level": "info",
"capture_interface": "eth0",
"forward_enabled": false,
"max_packet_size": 65536
}
```
## 开发指南
运行测试:
```
go test ./...
```
## 许可证
MIT License

103
USAGE.md Normal file
View File

@ -0,0 +1,103 @@
# GoFirewall 使用文档
## 目录
1. [基本概念](#基本概念)
2. [配置文件](#配置文件)
3. [规则管理](#规则管理)
4. [日志系统](#日志系统)
5. [流量转发](#流量转发)
6. [API参考](#api参考)
## 基本概念
GoFirewall 是一个基于规则的网络防火墙,主要功能包括:
- **流量过滤**:根据规则允许或阻止网络流量
- **日志记录**:记录匹配的流量和系统事件
- **流量转发**支持NAT规则转发数据包
## 配置文件
### 配置参数
| 参数 | 类型 | 默认值 | 描述 |
|------|------|--------|------|
| log_level | string | "info" | 日志级别 (debug/info/warn/error) |
| capture_interface | string | "" | 监听的网络接口 |
| forward_enabled | bool | false | 是否启用流量转发 |
| max_packet_size | int | 65536 | 最大数据包大小 |
## 规则管理
### 规则格式
```json
{
"id": "rule-1",
"name": "Allow SSH",
"protocol": "tcp",
"src_ip": "192.168.1.1",
"src_port": "",
"dst_ip": "",
"dst_port": "22",
"action": "allow",
"description": "Allow SSH access",
"enabled": true
}
```
### 规则字段
| 字段 | 必填 | 描述 |
|------|------|------|
| protocol | 是 | 协议类型 (tcp/udp/icmp/all) |
| src_ip | 否 | 源IP地址支持通配符* |
| dst_ip | 否 | 目标IP地址支持通配符* |
| src_port | 否 | 源端口,支持范围(如8000-9000) |
| dst_port | 否 | 目标端口,支持范围 |
| action | 是 | 动作 (allow/deny) |
| enabled | 是 | 是否启用规则 |
## 日志系统
日志文件默认输出到`firewall.log`,包含以下信息:
- 匹配的规则ID
- 数据包源/目标信息
- 执行动作
- 时间戳
## 流量转发
### 转发规则示例
```go
forwarder.AddForwardRule(ForwardRule{
SrcIP: "192.168.1.100",
SrcPort: 8080,
DstIP: "10.0.0.2",
DstPort: 80,
})
```
## API参考
### Firewall 接口
```go
type Firewall interface {
Start() error
Stop()
AddRule(rule *Rule)
RemoveRule(ruleID string) bool
}
```
### Logger 接口
```go
type Logger interface {
Info(v ...interface{})
Warn(v ...interface{})
Error(v ...interface{})
Debug(v ...interface{})
}
```

165
capture.go Normal file
View File

@ -0,0 +1,165 @@
package main
import (
"fmt"
"log"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcap"
)
// PacketCapture 数据包捕获器
type PacketCapture struct {
device string
handle *pcap.Handle
ruleManager *RuleManager
logger *Logger
forwarder *Forwarder
stopChan chan struct{}
}
// NewPacketCapture 创建新的数据包捕获器
func NewPacketCapture(device string, rm *RuleManager, logger *Logger, forwarder *Forwarder) *PacketCapture {
return &PacketCapture{
device: device,
ruleManager: rm,
logger: logger,
forwarder: forwarder,
stopChan: make(chan struct{}),
}
}
// Start 启动数据包捕获
func (p *PacketCapture) Start() error {
// 打开网络设备
handle, err := pcap.OpenLive(p.device, 65536, true, pcap.BlockForever)
if err != nil {
return fmt.Errorf("无法打开网络设备: %v", err)
}
p.handle = handle
// 设置BPF过滤器捕获TCP、UDP和ICMP流量
filter := "tcp or udp or icmp"
if err := handle.SetBPFFilter(filter); err != nil {
return fmt.Errorf("无法设置过滤器: %v", err)
}
// 开始捕获数据包
go p.capturePackets()
log.Printf("开始在设备 %s 上捕获数据包", p.device)
return nil
}
// Stop 停止数据包捕获
func (p *PacketCapture) Stop() {
close(p.stopChan)
if p.handle != nil {
p.handle.Close()
}
log.Println("停止数据包捕获")
}
// 捕获并处理数据包
func (p *PacketCapture) capturePackets() {
packetSource := gopacket.NewPacketSource(p.handle, p.handle.LinkType())
for {
select {
case packet := <-packetSource.Packets():
p.processPacket(packet)
case <-p.stopChan:
return
}
}
}
// 处理捕获到的数据包
func (p *PacketCapture) processPacket(packet gopacket.Packet) {
// 解析网络层和传输层
ipLayer := packet.Layer(layers.LayerTypeIPv4)
transportLayer := packet.Layer(layers.LayerTypeTCP)
if transportLayer == nil {
transportLayer = packet.Layer(layers.LayerTypeUDP)
}
icmpLayer := packet.Layer(layers.LayerTypeICMPv4)
// 只处理包含IP层和传输层/ICMP层的数据包
if ipLayer == nil || (transportLayer == nil && icmpLayer == nil) {
return
}
// 提取IP信息
ip, _ := ipLayer.(*layers.IPv4)
srcIP := ip.SrcIP
dstIP := ip.DstIP
// 确定协议类型
var protocol Protocol
var srcPort, dstPort int
if tcp, ok := transportLayer.(*layers.TCP); ok {
protocol = ProtocolTCP
srcPort = int(tcp.SrcPort)
dstPort = int(tcp.DstPort)
} else if udp, ok := transportLayer.(*layers.UDP); ok {
protocol = ProtocolUDP
srcPort = int(udp.SrcPort)
dstPort = int(udp.DstPort)
} else if icmpLayer != nil {
protocol = ProtocolICMP
srcPort = 0
dstPort = 0
} else {
return
}
// 查找匹配的规则
rule := p.ruleManager.MatchRule(srcIP, dstIP, srcPort, dstPort, protocol)
// 应用规则
if rule != nil {
p.logger.LogPacket(rule, srcIP.String(), dstIP.String(), srcPort, dstPort, protocol, rule.Action)
if rule.Action == ActionDeny {
// 阻止数据包(不做任何处理)
return
}
}
// 如果允许通过且需要转发,则进行转发处理
if rule == nil || rule.Action == ActionAllow {
if p.forwarder.enabled {
// 克隆数据包以便修改
buf := gopacket.NewSerializeBuffer()
err := gopacket.SerializePacket(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, packet)
if err != nil {
p.logger.Error("无法序列化数据包: ", err)
return
}
// 解析克隆的数据包
newPacket := gopacket.NewPacket(buf.Bytes(), packet.LinkLayer().LayerType(), gopacket.Default)
newIpLayer := newPacket.Layer(layers.LayerTypeIPv4)
newTransportLayer := newPacket.Layer(layers.LayerTypeTCP)
if newTransportLayer == nil {
newTransportLayer = newPacket.Layer(layers.LayerTypeUDP)
}
if newIpLayer != nil && newTransportLayer != nil {
ip, _ := newIpLayer.(*layers.IPv4)
transport, ok := newTransportLayer.(gopacket.TransportLayer)
if !ok {
p.logger.Error("Invalid transport layer type")
return
}
p.forwarder.ForwardPacket(ip, transport, buf.Bytes())
// 发送修改后的数据包
if err := p.handle.WritePacketData(buf.Bytes()); err != nil {
p.logger.Error("转发数据包失败: ", err)
}
}
}
}
}

76
config.go Normal file
View File

@ -0,0 +1,76 @@
package main
import (
"encoding/json"
"log"
"os"
)
// Config 防火墙配置结构
type Config struct {
LogLevel LogLevel `json:"log_level"`
CaptureInterface string `json:"capture_interface"`
ForwardEnabled bool `json:"forward_enabled"`
MaxPacketSize int `json:"max_packet_size"`
ConfigFile string `json:"config_file"`
}
// NewConfig 创建新的配置实例
func NewConfig() *Config {
return &Config{
LogLevel: LogLevelInfo,
CaptureInterface: "",
ForwardEnabled: false,
MaxPacketSize: 65536,
ConfigFile: "firewall.json",
}
}
// Load 从配置文件加载配置
func (c *Config) Load() error {
// 检查文件是否存在
if _, err := os.Stat(c.ConfigFile); os.IsNotExist(err) {
// 文件不存在,使用默认配置并保存
log.Println("Config file not found, creating default config")
return c.Save()
}
// 读取文件内容
data, err := os.ReadFile(c.ConfigFile)
if err != nil {
return err
}
// 解析JSON
if err := json.Unmarshal(data, c); err != nil {
return err
}
log.Println("Config loaded successfully")
return nil
}
// Save 将配置保存到文件
func (c *Config) Save() error {
// 转换为JSON
data, err := json.MarshalIndent(c, "", " ")
if err != nil {
return err
}
// 写入文件
if err := os.WriteFile(c.ConfigFile, data, 0644); err != nil {
return err
}
log.Println("Config saved successfully")
return nil
}
// Update 更新配置参数
func (c *Config) Update(newConfig *Config) {
c.LogLevel = newConfig.LogLevel
c.CaptureInterface = newConfig.CaptureInterface
c.ForwardEnabled = newConfig.ForwardEnabled
c.MaxPacketSize = newConfig.MaxPacketSize
}

104
config_test.go Normal file
View File

@ -0,0 +1,104 @@
package main
import (
"os"
"testing"
)
func TestConfig_LoadSave(t *testing.T) {
// 创建临时配置文件
configFile := "test_config.json"
defer os.Remove(configFile)
// 创建新配置
config := NewConfig()
config.ConfigFile = configFile
config.LogLevel = LogLevelDebug
config.CaptureInterface = "eth0"
config.ForwardEnabled = true
config.MaxPacketSize = 4096
// 测试保存配置
if err := config.Save(); err != nil {
t.Fatalf("Failed to save config: %v", err)
}
// 创建新配置实例加载保存的配置
loadedConfig := NewConfig()
loadedConfig.ConfigFile = configFile
if err := loadedConfig.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// 验证加载的配置是否与保存的一致
if loadedConfig.LogLevel != config.LogLevel {
t.Errorf("Expected LogLevel %v, got %v", config.LogLevel, loadedConfig.LogLevel)
}
if loadedConfig.CaptureInterface != config.CaptureInterface {
t.Errorf("Expected CaptureInterface %s, got %s", config.CaptureInterface, loadedConfig.CaptureInterface)
}
if loadedConfig.ForwardEnabled != config.ForwardEnabled {
t.Errorf("Expected ForwardEnabled %v, got %v", config.ForwardEnabled, loadedConfig.ForwardEnabled)
}
if loadedConfig.MaxPacketSize != config.MaxPacketSize {
t.Errorf("Expected MaxPacketSize %d, got %d", config.MaxPacketSize, loadedConfig.MaxPacketSize)
}
}
func TestConfig_LoadDefault(t *testing.T) {
// 创建临时配置文件
configFile := "test_default_config.json"
defer os.Remove(configFile)
// 创建新配置并加载不存在的文件(应该创建默认配置)
config := NewConfig()
config.ConfigFile = configFile
if err := config.Load(); err != nil {
t.Fatalf("Failed to load default config: %v", err)
}
// 验证默认值
if config.LogLevel != LogLevelInfo {
t.Errorf("Expected default LogLevel %v, got %v", LogLevelInfo, config.LogLevel)
}
if config.CaptureInterface != "" {
t.Errorf("Expected empty default CaptureInterface, got %s", config.CaptureInterface)
}
if config.ForwardEnabled != false {
t.Errorf("Expected default ForwardEnabled false, got %v", config.ForwardEnabled)
}
if config.MaxPacketSize != 65536 {
t.Errorf("Expected default MaxPacketSize 65536, got %d", config.MaxPacketSize)
}
// 验证是否创建了配置文件
if _, err := os.Stat(configFile); os.IsNotExist(err) {
t.Error("Config file was not created")
}
}
func TestConfig_Update(t *testing.T) {
config := NewConfig()
newConfig := &Config{
LogLevel: LogLevelWarn,
CaptureInterface: "wlan0",
ForwardEnabled: true,
MaxPacketSize: 8192,
}
config.Update(newConfig)
if config.LogLevel != newConfig.LogLevel {
t.Errorf("Expected LogLevel %v, got %v", newConfig.LogLevel, config.LogLevel)
}
if config.CaptureInterface != newConfig.CaptureInterface {
t.Errorf("Expected CaptureInterface %s, got %s", newConfig.CaptureInterface, config.CaptureInterface)
}
if config.ForwardEnabled != newConfig.ForwardEnabled {
t.Errorf("Expected ForwardEnabled %v, got %v", newConfig.ForwardEnabled, config.ForwardEnabled)
}
if config.MaxPacketSize != newConfig.MaxPacketSize {
t.Errorf("Expected MaxPacketSize %d, got %d", newConfig.MaxPacketSize, config.MaxPacketSize)
}
}

125
forwarder.go Normal file
View File

@ -0,0 +1,125 @@
package main
import (
"fmt"
"log"
"net"
"strconv"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Forwarder 流量转发器
type Forwarder struct {
enabled bool
natTable map[string]string // 简单的NAT映射表key: 原始地址:端口, value: 转发后地址:端口
}
// NewForwarder 创建新的流量转发器
func NewForwarder() *Forwarder {
return &Forwarder{
enabled: false,
natTable: make(map[string]string),
}
}
// Start 启动转发服务
func (f *Forwarder) Start() error {
f.enabled = true
log.Println("Forwarding service started")
return nil
}
// Stop 停止转发服务
func (f *Forwarder) Stop() {
f.enabled = false
log.Println("Forwarding service stopped")
}
// ForwardRule 定义转发规则结构
type ForwardRule struct {
SrcIP string // 源IP
SrcPort int // 源端口
DstIP string // 目标IP
DstPort int // 目标端口
}
// AddForwardRule 添加转发规则
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
f.natTable[key] = value
}
// RemoveForwardRule 移除转发规则
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
delete(f.natTable, key)
}
// ForwardPacket 转发数据包
func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error {
if !f.enabled {
return nil
}
// 获取源IP和端口
srcIP := ipLayer.SrcIP.String()
var srcPort int
// 根据传输层协议获取端口
switch t := transportLayer.(type) {
case *layers.TCP:
srcPort = int(t.SrcPort)
// dstPort = int(t.DstPort)
case *layers.UDP:
srcPort = int(t.SrcPort)
// dstPort = int(t.DstPort)
default:
// 不支持的传输层协议
return nil
}
// 查找转发规则
key := fmt.Sprintf("%s:%d", srcIP, srcPort)
if forwardAddr, exists := f.natTable[key]; exists {
// 解析转发目标地址
addr, port, err := net.SplitHostPort(forwardAddr)
if err != nil {
return err
}
// 更新IP层目标地址
newDstIP := net.ParseIP(addr)
if newDstIP == nil {
return fmt.Errorf("invalid forward IP address: %s", addr)
}
ipLayer.DstIP = newDstIP
// 更新传输层目标端口
newDstPort, err := strconv.Atoi(port)
if err != nil {
return err
}
switch t := transportLayer.(type) {
case *layers.TCP:
t.DstPort = layers.TCPPort(newDstPort)
case *layers.UDP:
t.DstPort = layers.UDPPort(newDstPort)
}
// 重新计算校验和
switch t := transportLayer.(type) {
case *layers.TCP:
t.SetNetworkLayerForChecksum(ipLayer)
case *layers.UDP:
t.SetNetworkLayerForChecksum(ipLayer)
}
log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort)
}
return nil
}

8
go.mod Normal file
View File

@ -0,0 +1,8 @@
module git.kingecg.top/kingecg/gofirewall
go 1.24.4
require (
github.com/google/gopacket v1.1.19 // indirect
golang.org/x/sys v0.0.0-20190412213103-97732733099d // indirect
)

15
go.sum Normal file
View File

@ -0,0 +1,15 @@
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

109
logger.go Normal file
View File

@ -0,0 +1,109 @@
package main
import (
"fmt"
"log"
"os"
)
// LogLevel 定义日志级别
type LogLevel int
// 日志级别常量
const (
LogLevelInfo LogLevel = iota
LogLevelWarn
LogLevelError
LogLevelDebug
)
// Logger 日志管理器
type Logger struct {
file *os.File
infoLog *log.Logger
warnLog *log.Logger
errorLog *log.Logger
debugLog *log.Logger
level LogLevel
}
// NewLogger 创建新的日志管理器
func NewLogger() *Logger {
// 打开或创建日志文件,追加模式
file, err := os.OpenFile("firewall.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Printf("Failed to open log file, using stdout: %v", err)
file = os.Stdout
}
// 创建不同级别的日志记录器
infoLog := log.New(file, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile)
warnLog := log.New(file, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile)
errorLog := log.New(file, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
debugLog := log.New(file, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile)
return &Logger{
file: file,
infoLog: infoLog,
warnLog: warnLog,
errorLog: errorLog,
debugLog: debugLog,
level: LogLevelInfo, // 默认日志级别为INFO
}
}
// SetLevel 设置日志级别
func (l *Logger) SetLevel(level LogLevel) {
l.level = level
}
// Info 记录INFO级别日志
func (l *Logger) Info(v ...interface{}) {
if l.level <= LogLevelInfo {
l.infoLog.Println(v...)
}
}
// Warn 记录WARN级别日志
func (l *Logger) Warn(v ...interface{}) {
if l.level <= LogLevelWarn {
l.warnLog.Println(v...)
}
}
// Error 记录ERROR级别日志
func (l *Logger) Error(v ...interface{}) {
if l.level <= LogLevelError {
l.errorLog.Println(v...)
}
}
// Debug 记录DEBUG级别日志
func (l *Logger) Debug(v ...interface{}) {
if l.level <= LogLevelDebug {
l.debugLog.Println(v...)
}
}
// Close 关闭日志文件
func (l *Logger) Close() {
if l.file != os.Stdout {
l.file.Close()
}
}
// LogPacket 记录数据包信息
func (l *Logger) LogPacket(rule *Rule, srcIP, dstIP string, srcPort, dstPort int, protocol Protocol, action RuleAction) {
logMsg := fmt.Sprintf(
"Packet matched rule %s: %s %s:%d -> %s:%d, action: %s",
rule.ID,
protocol,
srcIP,
srcPort,
dstIP,
dstPort,
action,
)
l.Info(logMsg)
}

121
main.go Normal file
View File

@ -0,0 +1,121 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
)
// Firewall 主防火墙结构体
type Firewall struct {
ruleManager *RuleManager
logger *Logger
config *Config
forwarder *Forwarder
capture *PacketCapture
}
// NewFirewall 创建新的防火墙实例
func NewFirewall() *Firewall {
logger := NewLogger()
return &Firewall{
ruleManager: NewRuleManager(),
logger: logger,
config: NewConfig(),
forwarder: NewForwarder(),
}
}
// Start 启动防火墙服务
func (f *Firewall) Start() error {
log.Println("Starting firewall service...")
// 加载配置
if err := f.config.Load(); err != nil {
return err
}
// 加载规则
if err := f.loadRules(); err != nil {
return err
}
// 启动流量捕获和过滤
if err := f.startPacketCapture(); err != nil {
return err
}
// 启动转发服务
if err := f.forwarder.Start(); err != nil {
return err
}
return nil
}
// Stop 停止防火墙服务
func (f *Firewall) Stop() {
log.Println("Stopping firewall service...")
if f.capture != nil {
f.capture.Stop()
}
f.forwarder.Stop()
f.logger.Close()
}
// 加载防火墙规则
func (f *Firewall) loadRules() error {
// 示例规则:允许本地回环地址的所有流量
loopbackRule := &Rule{
ID: "rule-1",
Name: "Allow Loopback",
Protocol: ProtocolAll,
SrcIP: "127.0.0.1",
DstIP: "127.0.0.1",
Action: ActionAllow,
Description: "Allow all loopback traffic",
Enabled: true,
}
f.ruleManager.AddRule(loopbackRule)
// 可以从配置文件或数据库加载更多规则
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
return nil
}
// 启动数据包捕获和过滤
func (f *Firewall) startPacketCapture() error {
if f.config.CaptureInterface == "" {
return fmt.Errorf("capture interface not configured")
}
f.capture = NewPacketCapture(
f.config.CaptureInterface,
f.ruleManager,
f.logger,
f.forwarder,
)
if err := f.capture.Start(); err != nil {
return fmt.Errorf("failed to start packet capture: %v", err)
}
return nil
}
func main() {
firewall := NewFirewall()
if err := firewall.Start(); err != nil {
log.Fatalf("Failed to start firewall: %v", err)
}
defer firewall.Stop()
// 等待中断信号
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
log.Println("Firewall stopped successfully")
}

171
rule.go Normal file
View File

@ -0,0 +1,171 @@
package main
import (
"fmt"
"net"
"strings"
)
// RuleAction 定义规则动作类型
type RuleAction string
// 规则动作常量
const (
ActionAllow RuleAction = "allow"
ActionDeny RuleAction = "deny"
)
// Protocol 定义支持的网络协议
type Protocol string
// 协议常量
const (
ProtocolTCP Protocol = "tcp"
ProtocolUDP Protocol = "udp"
ProtocolICMP Protocol = "icmp"
ProtocolAll Protocol = "all"
)
// Rule 定义防火墙规则结构
type Rule struct {
ID string `json:"id"`
Name string `json:"name"`
Protocol Protocol `json:"protocol"`
SrcIP string `json:"src_ip"`
SrcPort string `json:"src_port"`
DstIP string `json:"dst_ip"`
DstPort string `json:"dst_port"`
Action RuleAction `json:"action"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
}
// RuleManager 规则管理器
type RuleManager struct {
rules []*Rule
}
// NewRuleManager 创建新的规则管理器
func NewRuleManager() *RuleManager {
return &RuleManager{
rules: make([]*Rule, 0),
}
}
// AddRule 添加规则
func (rm *RuleManager) AddRule(rule *Rule) {
rm.rules = append(rm.rules, rule)
}
// RemoveRule 移除规则
func (rm *RuleManager) RemoveRule(ruleID string) bool {
for i, rule := range rm.rules {
if rule.ID == ruleID {
// 从切片中删除元素
rm.rules = append(rm.rules[:i], rm.rules[i+1:]...)
return true
}
}
return false
}
// GetRule 获取规则
func (rm *RuleManager) GetRule(ruleID string) *Rule {
for _, rule := range rm.rules {
if rule.ID == ruleID {
return rule
}
}
return nil
}
// ListRules 列出所有规则
func (rm *RuleManager) ListRules() []*Rule {
return rm.rules
}
// MatchRule 匹配规则
func (rm *RuleManager) MatchRule(srcIP, dstIP net.IP, srcPort, dstPort int, protocol Protocol) *Rule {
for _, rule := range rm.rules {
if !rule.Enabled {
continue
}
// 检查协议匹配
if rule.Protocol != ProtocolAll && rule.Protocol != protocol {
continue
}
// 检查源IP匹配
if rule.SrcIP != "" && !matchIP(srcIP.String(), rule.SrcIP) {
continue
}
// 检查目的IP匹配
if rule.DstIP != "" && !matchIP(dstIP.String(), rule.DstIP) {
continue
}
// 检查源端口匹配
if rule.SrcPort != "" && !matchPort(srcPort, rule.SrcPort) {
continue
}
// 检查目的端口匹配
if rule.DstPort != "" && !matchPort(dstPort, rule.DstPort) {
continue
}
// 找到匹配的规则
return rule
}
return nil
}
// 匹配IP地址支持通配符*
func matchIP(ip, pattern string) bool {
if pattern == "*" {
return true
}
return ip == pattern
}
// 匹配端口(支持范围和通配符*
func matchPort(port int, pattern string) bool {
if pattern == "*" {
return true
}
// 处理端口范围如8080-8090
if strings.Contains(pattern, "-") {
parts := strings.Split(pattern, "-")
if len(parts) != 2 {
return false
}
startPort, err1 := atoi(parts[0])
endPort, err2 := atoi(parts[1])
if err1 != nil || err2 != nil || startPort > endPort {
return false
}
return port >= startPort && port <= endPort
}
// 处理单个端口
p, err := atoi(pattern)
if err != nil {
return false
}
return port == p
}
// 字符串转整数辅助函数
func atoi(s string) (int, error) {
var res int
for _, c := range s {
if c < '0' || c > '9' {
return 0, fmt.Errorf("invalid digit: %c", c)
}
res = res*10 + int(c-'0')
}
return res, nil
}

287
rule_test.go Normal file
View File

@ -0,0 +1,287 @@
package main
import (
"net"
"testing"
)
func TestRuleManager_AddRemoveRule(t *testing.T) {
rm := NewRuleManager()
rule := &Rule{
ID: "test-1",
Action: ActionAllow,
}
// 测试添加规则
rm.AddRule(rule)
if len(rm.ListRules()) != 1 {
t.Error("Expected 1 rule after add, got", len(rm.ListRules()))
}
// 测试获取规则
found := rm.GetRule("test-1")
if found == nil {
t.Error("Expected to find rule test-1")
}
// 测试删除规则
removed := rm.RemoveRule("test-1")
if !removed {
t.Error("Expected to remove rule test-1")
}
if len(rm.ListRules()) != 0 {
t.Error("Expected 0 rules after remove, got", len(rm.ListRules()))
}
}
func TestRuleManager_MatchRule(t *testing.T) {
rm := NewRuleManager()
// 添加测试规则
rm.AddRule(&Rule{
ID: "tcp-8080",
Protocol: ProtocolTCP,
DstPort: "8080",
Action: ActionAllow,
Enabled: true,
})
rm.AddRule(&Rule{
ID: "udp-53",
Protocol: ProtocolUDP,
DstPort: "53",
Action: ActionDeny,
Enabled: true,
})
rm.AddRule(&Rule{
ID: "icmp-all",
Protocol: ProtocolICMP,
Action: ActionAllow,
Enabled: true,
})
rm.AddRule(&Rule{
ID: "range-ports",
Protocol: ProtocolTCP,
DstPort: "8000-9000",
Action: ActionAllow,
Enabled: true,
})
// 禁用规则
rm.AddRule(&Rule{
ID: "disabled-rule",
Protocol: ProtocolTCP,
DstPort: "80",
Action: ActionAllow,
Enabled: false,
})
// 测试用例
tests := []struct {
name string
srcIP string
dstIP string
srcPort int
dstPort int
protocol Protocol
expected string
}{{
name: "match tcp 8080",
srcIP: "192.168.1.1",
dstIP: "192.168.1.2",
srcPort: 12345,
dstPort: 8080,
protocol: ProtocolTCP,
expected: "tcp-8080",
}, {
name: "match udp 53",
srcIP: "10.0.0.1",
dstIP: "10.0.0.2",
srcPort: 54321,
dstPort: 53,
protocol: ProtocolUDP,
expected: "udp-53",
}, {
name: "match icmp",
srcIP: "172.16.0.1",
dstIP: "172.16.0.2",
srcPort: 0,
dstPort: 0,
protocol: ProtocolICMP,
expected: "icmp-all",
}, {
name: "match port range",
srcIP: "192.168.1.1",
dstIP: "192.168.1.2",
srcPort: 12345,
dstPort: 8500,
protocol: ProtocolTCP,
expected: "range-ports",
}, {
name: "no match",
srcIP: "192.168.1.1",
dstIP: "192.168.1.2",
srcPort: 12345,
dstPort: 80,
protocol: ProtocolTCP,
expected: "",
}, {
name: "disabled rule",
srcIP: "192.168.1.1",
dstIP: "192.168.1.2",
srcPort: 12345,
dstPort: 80,
protocol: ProtocolTCP,
expected: "",
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
src := net.ParseIP(tt.srcIP)
dst := net.ParseIP(tt.dstIP)
if src == nil || dst == nil {
t.Fatal("Invalid IP address in test case")
}
rule := rm.MatchRule(src, dst, tt.srcPort, tt.dstPort, tt.protocol)
var ruleID string
if rule != nil {
ruleID = rule.ID
}
if ruleID != tt.expected {
t.Errorf("Expected rule %s, got %s", tt.expected, ruleID)
}
})
}
}
func TestAtoi(t *testing.T) {
tests := []struct {
name string
input string
expected int
err bool
}{{
name: "valid number",
input: "12345",
expected: 12345,
err: false,
}, {
name: "single digit",
input: "5",
expected: 5,
err: false,
}, {
name: "invalid character",
input: "12a3",
expected: 0,
err: true,
}, {
name: "empty string",
input: "",
expected: 0,
err: true,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := atoi(tt.input)
if (err != nil) != tt.err {
t.Errorf("Expected error %v, got %v", tt.err, err)
return
}
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
func TestMatchPort(t *testing.T) {
tests := []struct {
name string
port int
pattern string
expected bool
}{{
name: "wildcard match",
port: 1234,
pattern: "*",
expected: true,
}, {
name: "exact match",
port: 8080,
pattern: "8080",
expected: true,
}, {
name: "range match lower",
port: 8000,
pattern: "8000-9000",
expected: true,
}, {
name: "range match upper",
port: 9000,
pattern: "8000-9000",
expected: true,
}, {
name: "range match middle",
port: 8500,
pattern: "8000-9000",
expected: true,
}, {
name: "range no match",
port: 7999,
pattern: "8000-9000",
expected: false,
}, {
name: "invalid pattern",
port: 80,
pattern: "abc",
expected: false,
}, {
name: "invalid range",
port: 80,
pattern: "9000-8000",
expected: false,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchPort(tt.port, tt.pattern)
if result != tt.expected {
t.Errorf("For port %d and pattern '%s', expected %v, got %v", tt.port, tt.pattern, tt.expected, result)
}
})
}
}
func TestMatchIP(t *testing.T) {
tests := []struct {
name string
ip string
pattern string
expected bool
}{{
name: "wildcard match",
ip: "192.168.1.1",
pattern: "*",
expected: true,
}, {
name: "exact match",
ip: "192.168.1.1",
pattern: "192.168.1.1",
expected: true,
}, {
name: "no match",
ip: "192.168.1.1",
pattern: "10.0.0.1",
expected: false,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchIP(tt.ip, tt.pattern)
if result != tt.expected {
t.Errorf("For IP %s and pattern '%s', expected %v, got %v", tt.ip, tt.pattern, tt.expected, result)
}
})
}
}

15
task.md Normal file
View File

@ -0,0 +1,15 @@
实现一个防火墙程序,功能:
- 可以添加自定义防火墙规则
- 可以根据规则过滤网络流量
- 可以记录防火墙日志
- 可以配置防火墙参数
- 可以配置网络流量转发
编码规范:
- 采用分层架构,实现防火墙规则、流量过滤、日志记录、配置管理等功能模块
- 采用面向对象设计,每个模块封装成一个类
- 采用模块化设计,每个模块负责一个具体的功能
- 采用异常处理机制,保证程序稳定性
- 采用注释说明,提高代码可读性
- 采用单元测试,保证每个模块功能的正确性
- 单个代码文件和函数不能过长