commit 61d8c20d0927541c82e66f4e48acf351175aa5ea Author: kingecg Date: Thu Jul 3 11:23:49 2025 +0800 init code diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..69a5881 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +vendor/ +target/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..f130661 --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# GoFirewall + +基于Go语言开发的高性能防火墙系统,支持自定义规则、流量过滤、日志记录和网络转发功能。 + +## 功能特性 + +- 支持TCP/UDP/ICMP协议过滤 +- 自定义防火墙规则管理 +- 实时网络流量监控 +- 数据包转发(NAT)功能 +- 多级别日志记录 +- 配置文件管理 + +## 安装指南 + +1. 确保已安装Go 1.16+环境 +2. 克隆项目仓库: + ``` + git clone https://github.com/yourusername/gofirewall.git + ``` +3. 安装依赖: + ``` + go mod download + ``` +4. 编译项目: + ``` + go build + ``` + +## 快速开始 + +1. 配置防火墙规则(编辑`firewall.json`) +2. 启动防火墙: + ``` + ./gofirewall + ``` +3. 查看日志: + ``` + tail -f firewall.log + ``` + +## 配置文件 + +配置文件示例(`firewall.json`): +```json +{ + "log_level": "info", + "capture_interface": "eth0", + "forward_enabled": false, + "max_packet_size": 65536 +} +``` + +## 开发指南 + +运行测试: +``` +go test ./... +``` + +## 许可证 + +MIT License \ No newline at end of file diff --git a/USAGE.md b/USAGE.md new file mode 100644 index 0000000..97d9934 --- /dev/null +++ b/USAGE.md @@ -0,0 +1,103 @@ +# GoFirewall 使用文档 + +## 目录 +1. [基本概念](#基本概念) +2. [配置文件](#配置文件) +3. [规则管理](#规则管理) +4. [日志系统](#日志系统) +5. [流量转发](#流量转发) +6. [API参考](#api参考) + +## 基本概念 + +GoFirewall 是一个基于规则的网络防火墙,主要功能包括: + +- **流量过滤**:根据规则允许或阻止网络流量 +- **日志记录**:记录匹配的流量和系统事件 +- **流量转发**:支持NAT规则转发数据包 + +## 配置文件 + +### 配置参数 + +| 参数 | 类型 | 默认值 | 描述 | +|------|------|--------|------| +| log_level | string | "info" | 日志级别 (debug/info/warn/error) | +| capture_interface | string | "" | 监听的网络接口 | +| forward_enabled | bool | false | 是否启用流量转发 | +| max_packet_size | int | 65536 | 最大数据包大小 | + +## 规则管理 + +### 规则格式 +```json +{ + "id": "rule-1", + "name": "Allow SSH", + "protocol": "tcp", + "src_ip": "192.168.1.1", + "src_port": "", + "dst_ip": "", + "dst_port": "22", + "action": "allow", + "description": "Allow SSH access", + "enabled": true +} +``` + +### 规则字段 + +| 字段 | 必填 | 描述 | +|------|------|------| +| protocol | 是 | 协议类型 (tcp/udp/icmp/all) | +| src_ip | 否 | 源IP地址,支持通配符* | +| dst_ip | 否 | 目标IP地址,支持通配符* | +| src_port | 否 | 源端口,支持范围(如8000-9000) | +| dst_port | 否 | 目标端口,支持范围 | +| action | 是 | 动作 (allow/deny) | +| enabled | 是 | 是否启用规则 | + +## 日志系统 + +日志文件默认输出到`firewall.log`,包含以下信息: + +- 匹配的规则ID +- 数据包源/目标信息 +- 执行动作 +- 时间戳 + +## 流量转发 + +### 转发规则示例 +```go +forwarder.AddForwardRule(ForwardRule{ + SrcIP: "192.168.1.100", + SrcPort: 8080, + DstIP: "10.0.0.2", + DstPort: 80, +}) +``` + +## API参考 + +### Firewall 接口 + +```go +type Firewall interface { + Start() error + Stop() + AddRule(rule *Rule) + RemoveRule(ruleID string) bool +} +``` + +### Logger 接口 + +```go +type Logger interface { + Info(v ...interface{}) + Warn(v ...interface{}) + Error(v ...interface{}) + Debug(v ...interface{}) +} +``` \ No newline at end of file diff --git a/capture.go b/capture.go new file mode 100644 index 0000000..7aa343b --- /dev/null +++ b/capture.go @@ -0,0 +1,165 @@ +package main + +import ( + "fmt" + "log" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" +) + +// PacketCapture 数据包捕获器 +type PacketCapture struct { + device string + handle *pcap.Handle + ruleManager *RuleManager + logger *Logger + forwarder *Forwarder + stopChan chan struct{} +} + +// NewPacketCapture 创建新的数据包捕获器 +func NewPacketCapture(device string, rm *RuleManager, logger *Logger, forwarder *Forwarder) *PacketCapture { + return &PacketCapture{ + device: device, + ruleManager: rm, + logger: logger, + forwarder: forwarder, + stopChan: make(chan struct{}), + } +} + +// Start 启动数据包捕获 +func (p *PacketCapture) Start() error { + // 打开网络设备 + handle, err := pcap.OpenLive(p.device, 65536, true, pcap.BlockForever) + if err != nil { + return fmt.Errorf("无法打开网络设备: %v", err) + } + p.handle = handle + + // 设置BPF过滤器,捕获TCP、UDP和ICMP流量 + filter := "tcp or udp or icmp" + if err := handle.SetBPFFilter(filter); err != nil { + return fmt.Errorf("无法设置过滤器: %v", err) + } + + // 开始捕获数据包 + go p.capturePackets() + log.Printf("开始在设备 %s 上捕获数据包", p.device) + return nil +} + +// Stop 停止数据包捕获 +func (p *PacketCapture) Stop() { + close(p.stopChan) + if p.handle != nil { + p.handle.Close() + } + log.Println("停止数据包捕获") +} + +// 捕获并处理数据包 +func (p *PacketCapture) capturePackets() { + packetSource := gopacket.NewPacketSource(p.handle, p.handle.LinkType()) + + for { + select { + case packet := <-packetSource.Packets(): + p.processPacket(packet) + case <-p.stopChan: + return + } + } +} + +// 处理捕获到的数据包 +func (p *PacketCapture) processPacket(packet gopacket.Packet) { + // 解析网络层和传输层 + ipLayer := packet.Layer(layers.LayerTypeIPv4) + transportLayer := packet.Layer(layers.LayerTypeTCP) + if transportLayer == nil { + transportLayer = packet.Layer(layers.LayerTypeUDP) + } + icmpLayer := packet.Layer(layers.LayerTypeICMPv4) + + // 只处理包含IP层和传输层/ICMP层的数据包 + if ipLayer == nil || (transportLayer == nil && icmpLayer == nil) { + return + } + + // 提取IP信息 + ip, _ := ipLayer.(*layers.IPv4) + srcIP := ip.SrcIP + dstIP := ip.DstIP + + // 确定协议类型 + var protocol Protocol + var srcPort, dstPort int + + if tcp, ok := transportLayer.(*layers.TCP); ok { + protocol = ProtocolTCP + srcPort = int(tcp.SrcPort) + dstPort = int(tcp.DstPort) + } else if udp, ok := transportLayer.(*layers.UDP); ok { + protocol = ProtocolUDP + srcPort = int(udp.SrcPort) + dstPort = int(udp.DstPort) + } else if icmpLayer != nil { + protocol = ProtocolICMP + srcPort = 0 + dstPort = 0 + } else { + return + } + + // 查找匹配的规则 + rule := p.ruleManager.MatchRule(srcIP, dstIP, srcPort, dstPort, protocol) + + // 应用规则 + if rule != nil { + p.logger.LogPacket(rule, srcIP.String(), dstIP.String(), srcPort, dstPort, protocol, rule.Action) + + if rule.Action == ActionDeny { + // 阻止数据包(不做任何处理) + return + } + } + + // 如果允许通过且需要转发,则进行转发处理 + if rule == nil || rule.Action == ActionAllow { + if p.forwarder.enabled { + // 克隆数据包以便修改 + buf := gopacket.NewSerializeBuffer() + err := gopacket.SerializePacket(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, packet) + if err != nil { + p.logger.Error("无法序列化数据包: ", err) + return + } + + // 解析克隆的数据包 + newPacket := gopacket.NewPacket(buf.Bytes(), packet.LinkLayer().LayerType(), gopacket.Default) + newIpLayer := newPacket.Layer(layers.LayerTypeIPv4) + newTransportLayer := newPacket.Layer(layers.LayerTypeTCP) + if newTransportLayer == nil { + newTransportLayer = newPacket.Layer(layers.LayerTypeUDP) + } + + if newIpLayer != nil && newTransportLayer != nil { + ip, _ := newIpLayer.(*layers.IPv4) + transport, ok := newTransportLayer.(gopacket.TransportLayer) + if !ok { + p.logger.Error("Invalid transport layer type") + return + } + p.forwarder.ForwardPacket(ip, transport, buf.Bytes()) + + // 发送修改后的数据包 + if err := p.handle.WritePacketData(buf.Bytes()); err != nil { + p.logger.Error("转发数据包失败: ", err) + } + } + } + } +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..33e1631 --- /dev/null +++ b/config.go @@ -0,0 +1,76 @@ +package main + +import ( + "encoding/json" + "log" + "os" +) + +// Config 防火墙配置结构 +type Config struct { + LogLevel LogLevel `json:"log_level"` + CaptureInterface string `json:"capture_interface"` + ForwardEnabled bool `json:"forward_enabled"` + MaxPacketSize int `json:"max_packet_size"` + ConfigFile string `json:"config_file"` +} + +// NewConfig 创建新的配置实例 +func NewConfig() *Config { + return &Config{ + LogLevel: LogLevelInfo, + CaptureInterface: "", + ForwardEnabled: false, + MaxPacketSize: 65536, + ConfigFile: "firewall.json", + } +} + +// Load 从配置文件加载配置 +func (c *Config) Load() error { + // 检查文件是否存在 + if _, err := os.Stat(c.ConfigFile); os.IsNotExist(err) { + // 文件不存在,使用默认配置并保存 + log.Println("Config file not found, creating default config") + return c.Save() + } + + // 读取文件内容 + data, err := os.ReadFile(c.ConfigFile) + if err != nil { + return err + } + + // 解析JSON + if err := json.Unmarshal(data, c); err != nil { + return err + } + + log.Println("Config loaded successfully") + return nil +} + +// Save 将配置保存到文件 +func (c *Config) Save() error { + // 转换为JSON + data, err := json.MarshalIndent(c, "", " ") + if err != nil { + return err + } + + // 写入文件 + if err := os.WriteFile(c.ConfigFile, data, 0644); err != nil { + return err + } + + log.Println("Config saved successfully") + return nil +} + +// Update 更新配置参数 +func (c *Config) Update(newConfig *Config) { + c.LogLevel = newConfig.LogLevel + c.CaptureInterface = newConfig.CaptureInterface + c.ForwardEnabled = newConfig.ForwardEnabled + c.MaxPacketSize = newConfig.MaxPacketSize +} \ No newline at end of file diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..76b5fa3 --- /dev/null +++ b/config_test.go @@ -0,0 +1,104 @@ +package main + +import ( + "os" + "testing" +) + +func TestConfig_LoadSave(t *testing.T) { + // 创建临时配置文件 + configFile := "test_config.json" + defer os.Remove(configFile) + + // 创建新配置 + config := NewConfig() + config.ConfigFile = configFile + config.LogLevel = LogLevelDebug + config.CaptureInterface = "eth0" + config.ForwardEnabled = true + config.MaxPacketSize = 4096 + + // 测试保存配置 + if err := config.Save(); err != nil { + t.Fatalf("Failed to save config: %v", err) + } + + // 创建新配置实例加载保存的配置 + loadedConfig := NewConfig() + loadedConfig.ConfigFile = configFile + if err := loadedConfig.Load(); err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // 验证加载的配置是否与保存的一致 + if loadedConfig.LogLevel != config.LogLevel { + t.Errorf("Expected LogLevel %v, got %v", config.LogLevel, loadedConfig.LogLevel) + } + if loadedConfig.CaptureInterface != config.CaptureInterface { + t.Errorf("Expected CaptureInterface %s, got %s", config.CaptureInterface, loadedConfig.CaptureInterface) + } + if loadedConfig.ForwardEnabled != config.ForwardEnabled { + t.Errorf("Expected ForwardEnabled %v, got %v", config.ForwardEnabled, loadedConfig.ForwardEnabled) + } + if loadedConfig.MaxPacketSize != config.MaxPacketSize { + t.Errorf("Expected MaxPacketSize %d, got %d", config.MaxPacketSize, loadedConfig.MaxPacketSize) + } +} + +func TestConfig_LoadDefault(t *testing.T) { + // 创建临时配置文件 + configFile := "test_default_config.json" + defer os.Remove(configFile) + + // 创建新配置并加载不存在的文件(应该创建默认配置) + config := NewConfig() + config.ConfigFile = configFile + + if err := config.Load(); err != nil { + t.Fatalf("Failed to load default config: %v", err) + } + + // 验证默认值 + if config.LogLevel != LogLevelInfo { + t.Errorf("Expected default LogLevel %v, got %v", LogLevelInfo, config.LogLevel) + } + if config.CaptureInterface != "" { + t.Errorf("Expected empty default CaptureInterface, got %s", config.CaptureInterface) + } + if config.ForwardEnabled != false { + t.Errorf("Expected default ForwardEnabled false, got %v", config.ForwardEnabled) + } + if config.MaxPacketSize != 65536 { + t.Errorf("Expected default MaxPacketSize 65536, got %d", config.MaxPacketSize) + } + + // 验证是否创建了配置文件 + if _, err := os.Stat(configFile); os.IsNotExist(err) { + t.Error("Config file was not created") + } +} + +func TestConfig_Update(t *testing.T) { + config := NewConfig() + newConfig := &Config{ + LogLevel: LogLevelWarn, + CaptureInterface: "wlan0", + ForwardEnabled: true, + MaxPacketSize: 8192, + } + + config.Update(newConfig) + + if config.LogLevel != newConfig.LogLevel { + t.Errorf("Expected LogLevel %v, got %v", newConfig.LogLevel, config.LogLevel) + } + if config.CaptureInterface != newConfig.CaptureInterface { + t.Errorf("Expected CaptureInterface %s, got %s", newConfig.CaptureInterface, config.CaptureInterface) + } + if config.ForwardEnabled != newConfig.ForwardEnabled { + t.Errorf("Expected ForwardEnabled %v, got %v", newConfig.ForwardEnabled, config.ForwardEnabled) + } + if config.MaxPacketSize != newConfig.MaxPacketSize { + t.Errorf("Expected MaxPacketSize %d, got %d", newConfig.MaxPacketSize, config.MaxPacketSize) + } +} \ No newline at end of file diff --git a/forwarder.go b/forwarder.go new file mode 100644 index 0000000..4cc1022 --- /dev/null +++ b/forwarder.go @@ -0,0 +1,125 @@ +package main + +import ( + "fmt" + "log" + "net" + "strconv" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// Forwarder 流量转发器 +type Forwarder struct { + enabled bool + natTable map[string]string // 简单的NAT映射表,key: 原始地址:端口, value: 转发后地址:端口 +} + +// NewForwarder 创建新的流量转发器 +func NewForwarder() *Forwarder { + return &Forwarder{ + enabled: false, + natTable: make(map[string]string), + } +} + +// Start 启动转发服务 +func (f *Forwarder) Start() error { + f.enabled = true + log.Println("Forwarding service started") + return nil +} + +// Stop 停止转发服务 +func (f *Forwarder) Stop() { + f.enabled = false + log.Println("Forwarding service stopped") +} + +// ForwardRule 定义转发规则结构 +type ForwardRule struct { + SrcIP string // 源IP + SrcPort int // 源端口 + DstIP string // 目标IP + DstPort int // 目标端口 +} + +// AddForwardRule 添加转发规则 +func (f *Forwarder) AddForwardRule(rule ForwardRule) { + key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) + value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort) + f.natTable[key] = value +} + +// RemoveForwardRule 移除转发规则 +func (f *Forwarder) RemoveForwardRule(rule ForwardRule) { + key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort) + delete(f.natTable, key) +} + +// ForwardPacket 转发数据包 +func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error { + if !f.enabled { + return nil + } + + // 获取源IP和端口 + srcIP := ipLayer.SrcIP.String() + var srcPort int + + // 根据传输层协议获取端口 + switch t := transportLayer.(type) { + case *layers.TCP: + srcPort = int(t.SrcPort) + // dstPort = int(t.DstPort) + case *layers.UDP: + srcPort = int(t.SrcPort) + // dstPort = int(t.DstPort) + default: + // 不支持的传输层协议 + return nil + } + + // 查找转发规则 + key := fmt.Sprintf("%s:%d", srcIP, srcPort) + if forwardAddr, exists := f.natTable[key]; exists { + // 解析转发目标地址 + addr, port, err := net.SplitHostPort(forwardAddr) + if err != nil { + return err + } + + // 更新IP层目标地址 + newDstIP := net.ParseIP(addr) + if newDstIP == nil { + return fmt.Errorf("invalid forward IP address: %s", addr) + } + ipLayer.DstIP = newDstIP + + // 更新传输层目标端口 + newDstPort, err := strconv.Atoi(port) + if err != nil { + return err + } + + switch t := transportLayer.(type) { + case *layers.TCP: + t.DstPort = layers.TCPPort(newDstPort) + case *layers.UDP: + t.DstPort = layers.UDPPort(newDstPort) + } + + // 重新计算校验和 + switch t := transportLayer.(type) { + case *layers.TCP: + t.SetNetworkLayerForChecksum(ipLayer) + case *layers.UDP: + t.SetNetworkLayerForChecksum(ipLayer) + } + + log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort) + } + + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b2cbcc6 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module git.kingecg.top/kingecg/gofirewall + +go 1.24.4 + +require ( + github.com/google/gopacket v1.1.19 // indirect + golang.org/x/sys v0.0.0-20190412213103-97732733099d // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..aea2a4a --- /dev/null +++ b/go.sum @@ -0,0 +1,15 @@ +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..7a2daec --- /dev/null +++ b/logger.go @@ -0,0 +1,109 @@ +package main + +import ( + "fmt" + "log" + "os" +) + +// LogLevel 定义日志级别 +type LogLevel int + +// 日志级别常量 +const ( + LogLevelInfo LogLevel = iota + LogLevelWarn + LogLevelError + LogLevelDebug +) + +// Logger 日志管理器 +type Logger struct { + file *os.File + infoLog *log.Logger + warnLog *log.Logger + errorLog *log.Logger + debugLog *log.Logger + level LogLevel +} + +// NewLogger 创建新的日志管理器 +func NewLogger() *Logger { + // 打开或创建日志文件,追加模式 + file, err := os.OpenFile("firewall.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + log.Printf("Failed to open log file, using stdout: %v", err) + file = os.Stdout + } + + // 创建不同级别的日志记录器 + infoLog := log.New(file, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) + warnLog := log.New(file, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile) + errorLog := log.New(file, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) + debugLog := log.New(file, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile) + + return &Logger{ + file: file, + infoLog: infoLog, + warnLog: warnLog, + errorLog: errorLog, + debugLog: debugLog, + level: LogLevelInfo, // 默认日志级别为INFO + } +} + +// SetLevel 设置日志级别 +func (l *Logger) SetLevel(level LogLevel) { + l.level = level +} + +// Info 记录INFO级别日志 +func (l *Logger) Info(v ...interface{}) { + if l.level <= LogLevelInfo { + l.infoLog.Println(v...) + } +} + +// Warn 记录WARN级别日志 +func (l *Logger) Warn(v ...interface{}) { + if l.level <= LogLevelWarn { + l.warnLog.Println(v...) + } +} + +// Error 记录ERROR级别日志 +func (l *Logger) Error(v ...interface{}) { + if l.level <= LogLevelError { + l.errorLog.Println(v...) + } +} + +// Debug 记录DEBUG级别日志 +func (l *Logger) Debug(v ...interface{}) { + if l.level <= LogLevelDebug { + l.debugLog.Println(v...) + } +} + +// Close 关闭日志文件 +func (l *Logger) Close() { + if l.file != os.Stdout { + l.file.Close() + } +} + +// LogPacket 记录数据包信息 +func (l *Logger) LogPacket(rule *Rule, srcIP, dstIP string, srcPort, dstPort int, protocol Protocol, action RuleAction) { + logMsg := fmt.Sprintf( + "Packet matched rule %s: %s %s:%d -> %s:%d, action: %s", + rule.ID, + protocol, + srcIP, + srcPort, + dstIP, + dstPort, + action, + ) + + l.Info(logMsg) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..f394bcc --- /dev/null +++ b/main.go @@ -0,0 +1,121 @@ +package main + +import ( + "fmt" + "log" + "os" + "os/signal" + "syscall" +) + +// Firewall 主防火墙结构体 +type Firewall struct { + ruleManager *RuleManager + logger *Logger + config *Config + forwarder *Forwarder + capture *PacketCapture +} + +// NewFirewall 创建新的防火墙实例 +func NewFirewall() *Firewall { + logger := NewLogger() + return &Firewall{ + ruleManager: NewRuleManager(), + logger: logger, + config: NewConfig(), + forwarder: NewForwarder(), + } +} + +// Start 启动防火墙服务 +func (f *Firewall) Start() error { + log.Println("Starting firewall service...") + // 加载配置 + if err := f.config.Load(); err != nil { + return err + } + + // 加载规则 + if err := f.loadRules(); err != nil { + return err + } + + // 启动流量捕获和过滤 + if err := f.startPacketCapture(); err != nil { + return err + } + + // 启动转发服务 + if err := f.forwarder.Start(); err != nil { + return err + } + + return nil +} + +// Stop 停止防火墙服务 +func (f *Firewall) Stop() { + log.Println("Stopping firewall service...") + if f.capture != nil { + f.capture.Stop() + } + f.forwarder.Stop() + f.logger.Close() +} + +// 加载防火墙规则 +func (f *Firewall) loadRules() error { + // 示例规则:允许本地回环地址的所有流量 + loopbackRule := &Rule{ + ID: "rule-1", + Name: "Allow Loopback", + Protocol: ProtocolAll, + SrcIP: "127.0.0.1", + DstIP: "127.0.0.1", + Action: ActionAllow, + Description: "Allow all loopback traffic", + Enabled: true, + } + f.ruleManager.AddRule(loopbackRule) + + // 可以从配置文件或数据库加载更多规则 + f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules") + return nil +} + +// 启动数据包捕获和过滤 +func (f *Firewall) startPacketCapture() error { + if f.config.CaptureInterface == "" { + return fmt.Errorf("capture interface not configured") + } + + f.capture = NewPacketCapture( + f.config.CaptureInterface, + f.ruleManager, + f.logger, + f.forwarder, + ) + + if err := f.capture.Start(); err != nil { + return fmt.Errorf("failed to start packet capture: %v", err) + } + + return nil +} + +func main() { + firewall := NewFirewall() + + if err := firewall.Start(); err != nil { + log.Fatalf("Failed to start firewall: %v", err) + } + defer firewall.Stop() + + // 等待中断信号 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + log.Println("Firewall stopped successfully") +} \ No newline at end of file diff --git a/rule.go b/rule.go new file mode 100644 index 0000000..2fc8871 --- /dev/null +++ b/rule.go @@ -0,0 +1,171 @@ +package main + +import ( + "fmt" + "net" + "strings" +) + +// RuleAction 定义规则动作类型 +type RuleAction string + +// 规则动作常量 +const ( + ActionAllow RuleAction = "allow" + ActionDeny RuleAction = "deny" +) + +// Protocol 定义支持的网络协议 +type Protocol string + +// 协议常量 +const ( + ProtocolTCP Protocol = "tcp" + ProtocolUDP Protocol = "udp" + ProtocolICMP Protocol = "icmp" + ProtocolAll Protocol = "all" +) + +// Rule 定义防火墙规则结构 +type Rule struct { + ID string `json:"id"` + Name string `json:"name"` + Protocol Protocol `json:"protocol"` + SrcIP string `json:"src_ip"` + SrcPort string `json:"src_port"` + DstIP string `json:"dst_ip"` + DstPort string `json:"dst_port"` + Action RuleAction `json:"action"` + Description string `json:"description"` + Enabled bool `json:"enabled"` +} + +// RuleManager 规则管理器 +type RuleManager struct { + rules []*Rule +} + +// NewRuleManager 创建新的规则管理器 +func NewRuleManager() *RuleManager { + return &RuleManager{ + rules: make([]*Rule, 0), + } +} + +// AddRule 添加规则 +func (rm *RuleManager) AddRule(rule *Rule) { + rm.rules = append(rm.rules, rule) +} + +// RemoveRule 移除规则 +func (rm *RuleManager) RemoveRule(ruleID string) bool { + for i, rule := range rm.rules { + if rule.ID == ruleID { + // 从切片中删除元素 + rm.rules = append(rm.rules[:i], rm.rules[i+1:]...) + return true + } + } + return false +} + +// GetRule 获取规则 +func (rm *RuleManager) GetRule(ruleID string) *Rule { + for _, rule := range rm.rules { + if rule.ID == ruleID { + return rule + } + } + return nil +} + +// ListRules 列出所有规则 +func (rm *RuleManager) ListRules() []*Rule { + return rm.rules +} + +// MatchRule 匹配规则 +func (rm *RuleManager) MatchRule(srcIP, dstIP net.IP, srcPort, dstPort int, protocol Protocol) *Rule { + for _, rule := range rm.rules { + if !rule.Enabled { + continue + } + + // 检查协议匹配 + if rule.Protocol != ProtocolAll && rule.Protocol != protocol { + continue + } + + // 检查源IP匹配 + if rule.SrcIP != "" && !matchIP(srcIP.String(), rule.SrcIP) { + continue + } + + // 检查目的IP匹配 + if rule.DstIP != "" && !matchIP(dstIP.String(), rule.DstIP) { + continue + } + + // 检查源端口匹配 + if rule.SrcPort != "" && !matchPort(srcPort, rule.SrcPort) { + continue + } + + // 检查目的端口匹配 + if rule.DstPort != "" && !matchPort(dstPort, rule.DstPort) { + continue + } + + // 找到匹配的规则 + return rule + } + return nil +} + +// 匹配IP地址(支持通配符*) +func matchIP(ip, pattern string) bool { + if pattern == "*" { + return true + } + return ip == pattern +} + +// 匹配端口(支持范围和通配符*) +func matchPort(port int, pattern string) bool { + if pattern == "*" { + return true + } + + // 处理端口范围(如8080-8090) + if strings.Contains(pattern, "-") { + parts := strings.Split(pattern, "-") + if len(parts) != 2 { + return false + } + startPort, err1 := atoi(parts[0]) + endPort, err2 := atoi(parts[1]) + if err1 != nil || err2 != nil || startPort > endPort { + return false + } + return port >= startPort && port <= endPort + } + + // 处理单个端口 + p, err := atoi(pattern) + if err != nil { + return false + } + return port == p +} + +// 字符串转整数辅助函数 +func atoi(s string) (int, error) { + var res int + for _, c := range s { + if c < '0' || c > '9' { + return 0, fmt.Errorf("invalid digit: %c", c) + } + res = res*10 + int(c-'0') + } + return res, nil +} diff --git a/rule_test.go b/rule_test.go new file mode 100644 index 0000000..20e79ff --- /dev/null +++ b/rule_test.go @@ -0,0 +1,287 @@ +package main + +import ( + "net" + "testing" +) + +func TestRuleManager_AddRemoveRule(t *testing.T) { + rm := NewRuleManager() + rule := &Rule{ + ID: "test-1", + Action: ActionAllow, + } + + // 测试添加规则 + rm.AddRule(rule) + if len(rm.ListRules()) != 1 { + t.Error("Expected 1 rule after add, got", len(rm.ListRules())) + } + + // 测试获取规则 + found := rm.GetRule("test-1") + if found == nil { + t.Error("Expected to find rule test-1") + } + + // 测试删除规则 + removed := rm.RemoveRule("test-1") + if !removed { + t.Error("Expected to remove rule test-1") + } + if len(rm.ListRules()) != 0 { + t.Error("Expected 0 rules after remove, got", len(rm.ListRules())) + } +} + +func TestRuleManager_MatchRule(t *testing.T) { + rm := NewRuleManager() + // 添加测试规则 + rm.AddRule(&Rule{ + ID: "tcp-8080", + Protocol: ProtocolTCP, + DstPort: "8080", + Action: ActionAllow, + Enabled: true, + }) + rm.AddRule(&Rule{ + ID: "udp-53", + Protocol: ProtocolUDP, + DstPort: "53", + Action: ActionDeny, + Enabled: true, + }) + rm.AddRule(&Rule{ + ID: "icmp-all", + Protocol: ProtocolICMP, + Action: ActionAllow, + Enabled: true, + }) + rm.AddRule(&Rule{ + ID: "range-ports", + Protocol: ProtocolTCP, + DstPort: "8000-9000", + Action: ActionAllow, + Enabled: true, + }) + // 禁用规则 + rm.AddRule(&Rule{ + ID: "disabled-rule", + Protocol: ProtocolTCP, + DstPort: "80", + Action: ActionAllow, + Enabled: false, + }) + + // 测试用例 + tests := []struct { + name string + srcIP string + dstIP string + srcPort int + dstPort int + protocol Protocol + expected string + }{{ + name: "match tcp 8080", + srcIP: "192.168.1.1", + dstIP: "192.168.1.2", + srcPort: 12345, + dstPort: 8080, + protocol: ProtocolTCP, + expected: "tcp-8080", + }, { + name: "match udp 53", + srcIP: "10.0.0.1", + dstIP: "10.0.0.2", + srcPort: 54321, + dstPort: 53, + protocol: ProtocolUDP, + expected: "udp-53", + }, { + name: "match icmp", + srcIP: "172.16.0.1", + dstIP: "172.16.0.2", + srcPort: 0, + dstPort: 0, + protocol: ProtocolICMP, + expected: "icmp-all", + }, { + name: "match port range", + srcIP: "192.168.1.1", + dstIP: "192.168.1.2", + srcPort: 12345, + dstPort: 8500, + protocol: ProtocolTCP, + expected: "range-ports", + }, { + name: "no match", + srcIP: "192.168.1.1", + dstIP: "192.168.1.2", + srcPort: 12345, + dstPort: 80, + protocol: ProtocolTCP, + expected: "", + }, { + name: "disabled rule", + srcIP: "192.168.1.1", + dstIP: "192.168.1.2", + srcPort: 12345, + dstPort: 80, + protocol: ProtocolTCP, + expected: "", + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := net.ParseIP(tt.srcIP) + dst := net.ParseIP(tt.dstIP) + if src == nil || dst == nil { + t.Fatal("Invalid IP address in test case") + } + + rule := rm.MatchRule(src, dst, tt.srcPort, tt.dstPort, tt.protocol) + var ruleID string + if rule != nil { + ruleID = rule.ID + } + + if ruleID != tt.expected { + t.Errorf("Expected rule %s, got %s", tt.expected, ruleID) + } + }) + } +} + +func TestAtoi(t *testing.T) { + tests := []struct { + name string + input string + expected int + err bool + }{{ + name: "valid number", + input: "12345", + expected: 12345, + err: false, + }, { + name: "single digit", + input: "5", + expected: 5, + err: false, + }, { + name: "invalid character", + input: "12a3", + expected: 0, + err: true, + }, { + name: "empty string", + input: "", + expected: 0, + err: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := atoi(tt.input) + if (err != nil) != tt.err { + t.Errorf("Expected error %v, got %v", tt.err, err) + return + } + if result != tt.expected { + t.Errorf("Expected %d, got %d", tt.expected, result) + } + }) + } +} + +func TestMatchPort(t *testing.T) { + tests := []struct { + name string + port int + pattern string + expected bool + }{{ + name: "wildcard match", + port: 1234, + pattern: "*", + expected: true, + }, { + name: "exact match", + port: 8080, + pattern: "8080", + expected: true, + }, { + name: "range match lower", + port: 8000, + pattern: "8000-9000", + expected: true, + }, { + name: "range match upper", + port: 9000, + pattern: "8000-9000", + expected: true, + }, { + name: "range match middle", + port: 8500, + pattern: "8000-9000", + expected: true, + }, { + name: "range no match", + port: 7999, + pattern: "8000-9000", + expected: false, + }, { + name: "invalid pattern", + port: 80, + pattern: "abc", + expected: false, + }, { + name: "invalid range", + port: 80, + pattern: "9000-8000", + expected: false, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchPort(tt.port, tt.pattern) + if result != tt.expected { + t.Errorf("For port %d and pattern '%s', expected %v, got %v", tt.port, tt.pattern, tt.expected, result) + } + }) + } +} + +func TestMatchIP(t *testing.T) { + tests := []struct { + name string + ip string + pattern string + expected bool + }{{ + name: "wildcard match", + ip: "192.168.1.1", + pattern: "*", + expected: true, + }, { + name: "exact match", + ip: "192.168.1.1", + pattern: "192.168.1.1", + expected: true, + }, { + name: "no match", + ip: "192.168.1.1", + pattern: "10.0.0.1", + expected: false, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchIP(tt.ip, tt.pattern) + if result != tt.expected { + t.Errorf("For IP %s and pattern '%s', expected %v, got %v", tt.ip, tt.pattern, tt.expected, result) + } + }) + } +} \ No newline at end of file diff --git a/task.md b/task.md new file mode 100644 index 0000000..69c561f --- /dev/null +++ b/task.md @@ -0,0 +1,15 @@ +实现一个防火墙程序,功能: +- 可以添加自定义防火墙规则 +- 可以根据规则过滤网络流量 +- 可以记录防火墙日志 +- 可以配置防火墙参数 +- 可以配置网络流量转发 + +编码规范: +- 采用分层架构,实现防火墙规则、流量过滤、日志记录、配置管理等功能模块 +- 采用面向对象设计,每个模块封装成一个类 +- 采用模块化设计,每个模块负责一个具体的功能 +- 采用异常处理机制,保证程序稳定性 +- 采用注释说明,提高代码可读性 +- 采用单元测试,保证每个模块功能的正确性 +- 单个代码文件和函数不能过长 \ No newline at end of file