init code
This commit is contained in:
commit
61d8c20d09
|
|
@ -0,0 +1,2 @@
|
|||
vendor/
|
||||
target/
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
# GoFirewall
|
||||
|
||||
基于Go语言开发的高性能防火墙系统,支持自定义规则、流量过滤、日志记录和网络转发功能。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持TCP/UDP/ICMP协议过滤
|
||||
- 自定义防火墙规则管理
|
||||
- 实时网络流量监控
|
||||
- 数据包转发(NAT)功能
|
||||
- 多级别日志记录
|
||||
- 配置文件管理
|
||||
|
||||
## 安装指南
|
||||
|
||||
1. 确保已安装Go 1.16+环境
|
||||
2. 克隆项目仓库:
|
||||
```
|
||||
git clone https://github.com/yourusername/gofirewall.git
|
||||
```
|
||||
3. 安装依赖:
|
||||
```
|
||||
go mod download
|
||||
```
|
||||
4. 编译项目:
|
||||
```
|
||||
go build
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
1. 配置防火墙规则(编辑`firewall.json`)
|
||||
2. 启动防火墙:
|
||||
```
|
||||
./gofirewall
|
||||
```
|
||||
3. 查看日志:
|
||||
```
|
||||
tail -f firewall.log
|
||||
```
|
||||
|
||||
## 配置文件
|
||||
|
||||
配置文件示例(`firewall.json`):
|
||||
```json
|
||||
{
|
||||
"log_level": "info",
|
||||
"capture_interface": "eth0",
|
||||
"forward_enabled": false,
|
||||
"max_packet_size": 65536
|
||||
}
|
||||
```
|
||||
|
||||
## 开发指南
|
||||
|
||||
运行测试:
|
||||
```
|
||||
go test ./...
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
# GoFirewall 使用文档
|
||||
|
||||
## 目录
|
||||
1. [基本概念](#基本概念)
|
||||
2. [配置文件](#配置文件)
|
||||
3. [规则管理](#规则管理)
|
||||
4. [日志系统](#日志系统)
|
||||
5. [流量转发](#流量转发)
|
||||
6. [API参考](#api参考)
|
||||
|
||||
## 基本概念
|
||||
|
||||
GoFirewall 是一个基于规则的网络防火墙,主要功能包括:
|
||||
|
||||
- **流量过滤**:根据规则允许或阻止网络流量
|
||||
- **日志记录**:记录匹配的流量和系统事件
|
||||
- **流量转发**:支持NAT规则转发数据包
|
||||
|
||||
## 配置文件
|
||||
|
||||
### 配置参数
|
||||
|
||||
| 参数 | 类型 | 默认值 | 描述 |
|
||||
|------|------|--------|------|
|
||||
| log_level | string | "info" | 日志级别 (debug/info/warn/error) |
|
||||
| capture_interface | string | "" | 监听的网络接口 |
|
||||
| forward_enabled | bool | false | 是否启用流量转发 |
|
||||
| max_packet_size | int | 65536 | 最大数据包大小 |
|
||||
|
||||
## 规则管理
|
||||
|
||||
### 规则格式
|
||||
```json
|
||||
{
|
||||
"id": "rule-1",
|
||||
"name": "Allow SSH",
|
||||
"protocol": "tcp",
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": "",
|
||||
"dst_ip": "",
|
||||
"dst_port": "22",
|
||||
"action": "allow",
|
||||
"description": "Allow SSH access",
|
||||
"enabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### 规则字段
|
||||
|
||||
| 字段 | 必填 | 描述 |
|
||||
|------|------|------|
|
||||
| protocol | 是 | 协议类型 (tcp/udp/icmp/all) |
|
||||
| src_ip | 否 | 源IP地址,支持通配符* |
|
||||
| dst_ip | 否 | 目标IP地址,支持通配符* |
|
||||
| src_port | 否 | 源端口,支持范围(如8000-9000) |
|
||||
| dst_port | 否 | 目标端口,支持范围 |
|
||||
| action | 是 | 动作 (allow/deny) |
|
||||
| enabled | 是 | 是否启用规则 |
|
||||
|
||||
## 日志系统
|
||||
|
||||
日志文件默认输出到`firewall.log`,包含以下信息:
|
||||
|
||||
- 匹配的规则ID
|
||||
- 数据包源/目标信息
|
||||
- 执行动作
|
||||
- 时间戳
|
||||
|
||||
## 流量转发
|
||||
|
||||
### 转发规则示例
|
||||
```go
|
||||
forwarder.AddForwardRule(ForwardRule{
|
||||
SrcIP: "192.168.1.100",
|
||||
SrcPort: 8080,
|
||||
DstIP: "10.0.0.2",
|
||||
DstPort: 80,
|
||||
})
|
||||
```
|
||||
|
||||
## API参考
|
||||
|
||||
### Firewall 接口
|
||||
|
||||
```go
|
||||
type Firewall interface {
|
||||
Start() error
|
||||
Stop()
|
||||
AddRule(rule *Rule)
|
||||
RemoveRule(ruleID string) bool
|
||||
}
|
||||
```
|
||||
|
||||
### Logger 接口
|
||||
|
||||
```go
|
||||
type Logger interface {
|
||||
Info(v ...interface{})
|
||||
Warn(v ...interface{})
|
||||
Error(v ...interface{})
|
||||
Debug(v ...interface{})
|
||||
}
|
||||
```
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/gopacket/pcap"
|
||||
)
|
||||
|
||||
// PacketCapture 数据包捕获器
|
||||
type PacketCapture struct {
|
||||
device string
|
||||
handle *pcap.Handle
|
||||
ruleManager *RuleManager
|
||||
logger *Logger
|
||||
forwarder *Forwarder
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewPacketCapture 创建新的数据包捕获器
|
||||
func NewPacketCapture(device string, rm *RuleManager, logger *Logger, forwarder *Forwarder) *PacketCapture {
|
||||
return &PacketCapture{
|
||||
device: device,
|
||||
ruleManager: rm,
|
||||
logger: logger,
|
||||
forwarder: forwarder,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动数据包捕获
|
||||
func (p *PacketCapture) Start() error {
|
||||
// 打开网络设备
|
||||
handle, err := pcap.OpenLive(p.device, 65536, true, pcap.BlockForever)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法打开网络设备: %v", err)
|
||||
}
|
||||
p.handle = handle
|
||||
|
||||
// 设置BPF过滤器,捕获TCP、UDP和ICMP流量
|
||||
filter := "tcp or udp or icmp"
|
||||
if err := handle.SetBPFFilter(filter); err != nil {
|
||||
return fmt.Errorf("无法设置过滤器: %v", err)
|
||||
}
|
||||
|
||||
// 开始捕获数据包
|
||||
go p.capturePackets()
|
||||
log.Printf("开始在设备 %s 上捕获数据包", p.device)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止数据包捕获
|
||||
func (p *PacketCapture) Stop() {
|
||||
close(p.stopChan)
|
||||
if p.handle != nil {
|
||||
p.handle.Close()
|
||||
}
|
||||
log.Println("停止数据包捕获")
|
||||
}
|
||||
|
||||
// 捕获并处理数据包
|
||||
func (p *PacketCapture) capturePackets() {
|
||||
packetSource := gopacket.NewPacketSource(p.handle, p.handle.LinkType())
|
||||
|
||||
for {
|
||||
select {
|
||||
case packet := <-packetSource.Packets():
|
||||
p.processPacket(packet)
|
||||
case <-p.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理捕获到的数据包
|
||||
func (p *PacketCapture) processPacket(packet gopacket.Packet) {
|
||||
// 解析网络层和传输层
|
||||
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
||||
transportLayer := packet.Layer(layers.LayerTypeTCP)
|
||||
if transportLayer == nil {
|
||||
transportLayer = packet.Layer(layers.LayerTypeUDP)
|
||||
}
|
||||
icmpLayer := packet.Layer(layers.LayerTypeICMPv4)
|
||||
|
||||
// 只处理包含IP层和传输层/ICMP层的数据包
|
||||
if ipLayer == nil || (transportLayer == nil && icmpLayer == nil) {
|
||||
return
|
||||
}
|
||||
|
||||
// 提取IP信息
|
||||
ip, _ := ipLayer.(*layers.IPv4)
|
||||
srcIP := ip.SrcIP
|
||||
dstIP := ip.DstIP
|
||||
|
||||
// 确定协议类型
|
||||
var protocol Protocol
|
||||
var srcPort, dstPort int
|
||||
|
||||
if tcp, ok := transportLayer.(*layers.TCP); ok {
|
||||
protocol = ProtocolTCP
|
||||
srcPort = int(tcp.SrcPort)
|
||||
dstPort = int(tcp.DstPort)
|
||||
} else if udp, ok := transportLayer.(*layers.UDP); ok {
|
||||
protocol = ProtocolUDP
|
||||
srcPort = int(udp.SrcPort)
|
||||
dstPort = int(udp.DstPort)
|
||||
} else if icmpLayer != nil {
|
||||
protocol = ProtocolICMP
|
||||
srcPort = 0
|
||||
dstPort = 0
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
// 查找匹配的规则
|
||||
rule := p.ruleManager.MatchRule(srcIP, dstIP, srcPort, dstPort, protocol)
|
||||
|
||||
// 应用规则
|
||||
if rule != nil {
|
||||
p.logger.LogPacket(rule, srcIP.String(), dstIP.String(), srcPort, dstPort, protocol, rule.Action)
|
||||
|
||||
if rule.Action == ActionDeny {
|
||||
// 阻止数据包(不做任何处理)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果允许通过且需要转发,则进行转发处理
|
||||
if rule == nil || rule.Action == ActionAllow {
|
||||
if p.forwarder.enabled {
|
||||
// 克隆数据包以便修改
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
err := gopacket.SerializePacket(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, packet)
|
||||
if err != nil {
|
||||
p.logger.Error("无法序列化数据包: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析克隆的数据包
|
||||
newPacket := gopacket.NewPacket(buf.Bytes(), packet.LinkLayer().LayerType(), gopacket.Default)
|
||||
newIpLayer := newPacket.Layer(layers.LayerTypeIPv4)
|
||||
newTransportLayer := newPacket.Layer(layers.LayerTypeTCP)
|
||||
if newTransportLayer == nil {
|
||||
newTransportLayer = newPacket.Layer(layers.LayerTypeUDP)
|
||||
}
|
||||
|
||||
if newIpLayer != nil && newTransportLayer != nil {
|
||||
ip, _ := newIpLayer.(*layers.IPv4)
|
||||
transport, ok := newTransportLayer.(gopacket.TransportLayer)
|
||||
if !ok {
|
||||
p.logger.Error("Invalid transport layer type")
|
||||
return
|
||||
}
|
||||
p.forwarder.ForwardPacket(ip, transport, buf.Bytes())
|
||||
|
||||
// 发送修改后的数据包
|
||||
if err := p.handle.WritePacketData(buf.Bytes()); err != nil {
|
||||
p.logger.Error("转发数据包失败: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config 防火墙配置结构
|
||||
type Config struct {
|
||||
LogLevel LogLevel `json:"log_level"`
|
||||
CaptureInterface string `json:"capture_interface"`
|
||||
ForwardEnabled bool `json:"forward_enabled"`
|
||||
MaxPacketSize int `json:"max_packet_size"`
|
||||
ConfigFile string `json:"config_file"`
|
||||
}
|
||||
|
||||
// NewConfig 创建新的配置实例
|
||||
func NewConfig() *Config {
|
||||
return &Config{
|
||||
LogLevel: LogLevelInfo,
|
||||
CaptureInterface: "",
|
||||
ForwardEnabled: false,
|
||||
MaxPacketSize: 65536,
|
||||
ConfigFile: "firewall.json",
|
||||
}
|
||||
}
|
||||
|
||||
// Load 从配置文件加载配置
|
||||
func (c *Config) Load() error {
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(c.ConfigFile); os.IsNotExist(err) {
|
||||
// 文件不存在,使用默认配置并保存
|
||||
log.Println("Config file not found, creating default config")
|
||||
return c.Save()
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
data, err := os.ReadFile(c.ConfigFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
if err := json.Unmarshal(data, c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("Config loaded successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save 将配置保存到文件
|
||||
func (c *Config) Save() error {
|
||||
// 转换为JSON
|
||||
data, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(c.ConfigFile, data, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("Config saved successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update 更新配置参数
|
||||
func (c *Config) Update(newConfig *Config) {
|
||||
c.LogLevel = newConfig.LogLevel
|
||||
c.CaptureInterface = newConfig.CaptureInterface
|
||||
c.ForwardEnabled = newConfig.ForwardEnabled
|
||||
c.MaxPacketSize = newConfig.MaxPacketSize
|
||||
}
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfig_LoadSave(t *testing.T) {
|
||||
// 创建临时配置文件
|
||||
configFile := "test_config.json"
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// 创建新配置
|
||||
config := NewConfig()
|
||||
config.ConfigFile = configFile
|
||||
config.LogLevel = LogLevelDebug
|
||||
config.CaptureInterface = "eth0"
|
||||
config.ForwardEnabled = true
|
||||
config.MaxPacketSize = 4096
|
||||
|
||||
// 测试保存配置
|
||||
if err := config.Save(); err != nil {
|
||||
t.Fatalf("Failed to save config: %v", err)
|
||||
}
|
||||
|
||||
// 创建新配置实例加载保存的配置
|
||||
loadedConfig := NewConfig()
|
||||
loadedConfig.ConfigFile = configFile
|
||||
if err := loadedConfig.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// 验证加载的配置是否与保存的一致
|
||||
if loadedConfig.LogLevel != config.LogLevel {
|
||||
t.Errorf("Expected LogLevel %v, got %v", config.LogLevel, loadedConfig.LogLevel)
|
||||
}
|
||||
if loadedConfig.CaptureInterface != config.CaptureInterface {
|
||||
t.Errorf("Expected CaptureInterface %s, got %s", config.CaptureInterface, loadedConfig.CaptureInterface)
|
||||
}
|
||||
if loadedConfig.ForwardEnabled != config.ForwardEnabled {
|
||||
t.Errorf("Expected ForwardEnabled %v, got %v", config.ForwardEnabled, loadedConfig.ForwardEnabled)
|
||||
}
|
||||
if loadedConfig.MaxPacketSize != config.MaxPacketSize {
|
||||
t.Errorf("Expected MaxPacketSize %d, got %d", config.MaxPacketSize, loadedConfig.MaxPacketSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_LoadDefault(t *testing.T) {
|
||||
// 创建临时配置文件
|
||||
configFile := "test_default_config.json"
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// 创建新配置并加载不存在的文件(应该创建默认配置)
|
||||
config := NewConfig()
|
||||
config.ConfigFile = configFile
|
||||
|
||||
if err := config.Load(); err != nil {
|
||||
t.Fatalf("Failed to load default config: %v", err)
|
||||
}
|
||||
|
||||
// 验证默认值
|
||||
if config.LogLevel != LogLevelInfo {
|
||||
t.Errorf("Expected default LogLevel %v, got %v", LogLevelInfo, config.LogLevel)
|
||||
}
|
||||
if config.CaptureInterface != "" {
|
||||
t.Errorf("Expected empty default CaptureInterface, got %s", config.CaptureInterface)
|
||||
}
|
||||
if config.ForwardEnabled != false {
|
||||
t.Errorf("Expected default ForwardEnabled false, got %v", config.ForwardEnabled)
|
||||
}
|
||||
if config.MaxPacketSize != 65536 {
|
||||
t.Errorf("Expected default MaxPacketSize 65536, got %d", config.MaxPacketSize)
|
||||
}
|
||||
|
||||
// 验证是否创建了配置文件
|
||||
if _, err := os.Stat(configFile); os.IsNotExist(err) {
|
||||
t.Error("Config file was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Update(t *testing.T) {
|
||||
config := NewConfig()
|
||||
newConfig := &Config{
|
||||
LogLevel: LogLevelWarn,
|
||||
CaptureInterface: "wlan0",
|
||||
ForwardEnabled: true,
|
||||
MaxPacketSize: 8192,
|
||||
}
|
||||
|
||||
config.Update(newConfig)
|
||||
|
||||
if config.LogLevel != newConfig.LogLevel {
|
||||
t.Errorf("Expected LogLevel %v, got %v", newConfig.LogLevel, config.LogLevel)
|
||||
}
|
||||
if config.CaptureInterface != newConfig.CaptureInterface {
|
||||
t.Errorf("Expected CaptureInterface %s, got %s", newConfig.CaptureInterface, config.CaptureInterface)
|
||||
}
|
||||
if config.ForwardEnabled != newConfig.ForwardEnabled {
|
||||
t.Errorf("Expected ForwardEnabled %v, got %v", newConfig.ForwardEnabled, config.ForwardEnabled)
|
||||
}
|
||||
if config.MaxPacketSize != newConfig.MaxPacketSize {
|
||||
t.Errorf("Expected MaxPacketSize %d, got %d", newConfig.MaxPacketSize, config.MaxPacketSize)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// Forwarder 流量转发器
|
||||
type Forwarder struct {
|
||||
enabled bool
|
||||
natTable map[string]string // 简单的NAT映射表,key: 原始地址:端口, value: 转发后地址:端口
|
||||
}
|
||||
|
||||
// NewForwarder 创建新的流量转发器
|
||||
func NewForwarder() *Forwarder {
|
||||
return &Forwarder{
|
||||
enabled: false,
|
||||
natTable: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动转发服务
|
||||
func (f *Forwarder) Start() error {
|
||||
f.enabled = true
|
||||
log.Println("Forwarding service started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止转发服务
|
||||
func (f *Forwarder) Stop() {
|
||||
f.enabled = false
|
||||
log.Println("Forwarding service stopped")
|
||||
}
|
||||
|
||||
// ForwardRule 定义转发规则结构
|
||||
type ForwardRule struct {
|
||||
SrcIP string // 源IP
|
||||
SrcPort int // 源端口
|
||||
DstIP string // 目标IP
|
||||
DstPort int // 目标端口
|
||||
}
|
||||
|
||||
// AddForwardRule 添加转发规则
|
||||
func (f *Forwarder) AddForwardRule(rule ForwardRule) {
|
||||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||||
value := fmt.Sprintf("%s:%d", rule.DstIP, rule.DstPort)
|
||||
f.natTable[key] = value
|
||||
}
|
||||
|
||||
// RemoveForwardRule 移除转发规则
|
||||
func (f *Forwarder) RemoveForwardRule(rule ForwardRule) {
|
||||
key := fmt.Sprintf("%s:%d", rule.SrcIP, rule.SrcPort)
|
||||
delete(f.natTable, key)
|
||||
}
|
||||
|
||||
// ForwardPacket 转发数据包
|
||||
func (f *Forwarder) ForwardPacket(ipLayer *layers.IPv4, transportLayer gopacket.TransportLayer, packetData []byte) error {
|
||||
if !f.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取源IP和端口
|
||||
srcIP := ipLayer.SrcIP.String()
|
||||
var srcPort int
|
||||
|
||||
// 根据传输层协议获取端口
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
srcPort = int(t.SrcPort)
|
||||
// dstPort = int(t.DstPort)
|
||||
case *layers.UDP:
|
||||
srcPort = int(t.SrcPort)
|
||||
// dstPort = int(t.DstPort)
|
||||
default:
|
||||
// 不支持的传输层协议
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查找转发规则
|
||||
key := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||||
if forwardAddr, exists := f.natTable[key]; exists {
|
||||
// 解析转发目标地址
|
||||
addr, port, err := net.SplitHostPort(forwardAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新IP层目标地址
|
||||
newDstIP := net.ParseIP(addr)
|
||||
if newDstIP == nil {
|
||||
return fmt.Errorf("invalid forward IP address: %s", addr)
|
||||
}
|
||||
ipLayer.DstIP = newDstIP
|
||||
|
||||
// 更新传输层目标端口
|
||||
newDstPort, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.DstPort = layers.TCPPort(newDstPort)
|
||||
case *layers.UDP:
|
||||
t.DstPort = layers.UDPPort(newDstPort)
|
||||
}
|
||||
|
||||
// 重新计算校验和
|
||||
switch t := transportLayer.(type) {
|
||||
case *layers.TCP:
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
case *layers.UDP:
|
||||
t.SetNetworkLayerForChecksum(ipLayer)
|
||||
}
|
||||
|
||||
log.Printf("Forwarding packet: %s:%d -> %s:%d", srcIP, srcPort, addr, newDstPort)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
module git.kingecg.top/kingecg/gofirewall
|
||||
|
||||
go 1.24.4
|
||||
|
||||
require (
|
||||
github.com/google/gopacket v1.1.19 // indirect
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d // indirect
|
||||
)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// LogLevel 定义日志级别
|
||||
type LogLevel int
|
||||
|
||||
// 日志级别常量
|
||||
const (
|
||||
LogLevelInfo LogLevel = iota
|
||||
LogLevelWarn
|
||||
LogLevelError
|
||||
LogLevelDebug
|
||||
)
|
||||
|
||||
// Logger 日志管理器
|
||||
type Logger struct {
|
||||
file *os.File
|
||||
infoLog *log.Logger
|
||||
warnLog *log.Logger
|
||||
errorLog *log.Logger
|
||||
debugLog *log.Logger
|
||||
level LogLevel
|
||||
}
|
||||
|
||||
// NewLogger 创建新的日志管理器
|
||||
func NewLogger() *Logger {
|
||||
// 打开或创建日志文件,追加模式
|
||||
file, err := os.OpenFile("firewall.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open log file, using stdout: %v", err)
|
||||
file = os.Stdout
|
||||
}
|
||||
|
||||
// 创建不同级别的日志记录器
|
||||
infoLog := log.New(file, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
warnLog := log.New(file, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
errorLog := log.New(file, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
debugLog := log.New(file, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
return &Logger{
|
||||
file: file,
|
||||
infoLog: infoLog,
|
||||
warnLog: warnLog,
|
||||
errorLog: errorLog,
|
||||
debugLog: debugLog,
|
||||
level: LogLevelInfo, // 默认日志级别为INFO
|
||||
}
|
||||
}
|
||||
|
||||
// SetLevel 设置日志级别
|
||||
func (l *Logger) SetLevel(level LogLevel) {
|
||||
l.level = level
|
||||
}
|
||||
|
||||
// Info 记录INFO级别日志
|
||||
func (l *Logger) Info(v ...interface{}) {
|
||||
if l.level <= LogLevelInfo {
|
||||
l.infoLog.Println(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn 记录WARN级别日志
|
||||
func (l *Logger) Warn(v ...interface{}) {
|
||||
if l.level <= LogLevelWarn {
|
||||
l.warnLog.Println(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error 记录ERROR级别日志
|
||||
func (l *Logger) Error(v ...interface{}) {
|
||||
if l.level <= LogLevelError {
|
||||
l.errorLog.Println(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug 记录DEBUG级别日志
|
||||
func (l *Logger) Debug(v ...interface{}) {
|
||||
if l.level <= LogLevelDebug {
|
||||
l.debugLog.Println(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭日志文件
|
||||
func (l *Logger) Close() {
|
||||
if l.file != os.Stdout {
|
||||
l.file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// LogPacket 记录数据包信息
|
||||
func (l *Logger) LogPacket(rule *Rule, srcIP, dstIP string, srcPort, dstPort int, protocol Protocol, action RuleAction) {
|
||||
logMsg := fmt.Sprintf(
|
||||
"Packet matched rule %s: %s %s:%d -> %s:%d, action: %s",
|
||||
rule.ID,
|
||||
protocol,
|
||||
srcIP,
|
||||
srcPort,
|
||||
dstIP,
|
||||
dstPort,
|
||||
action,
|
||||
)
|
||||
|
||||
l.Info(logMsg)
|
||||
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Firewall 主防火墙结构体
|
||||
type Firewall struct {
|
||||
ruleManager *RuleManager
|
||||
logger *Logger
|
||||
config *Config
|
||||
forwarder *Forwarder
|
||||
capture *PacketCapture
|
||||
}
|
||||
|
||||
// NewFirewall 创建新的防火墙实例
|
||||
func NewFirewall() *Firewall {
|
||||
logger := NewLogger()
|
||||
return &Firewall{
|
||||
ruleManager: NewRuleManager(),
|
||||
logger: logger,
|
||||
config: NewConfig(),
|
||||
forwarder: NewForwarder(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动防火墙服务
|
||||
func (f *Firewall) Start() error {
|
||||
log.Println("Starting firewall service...")
|
||||
// 加载配置
|
||||
if err := f.config.Load(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 加载规则
|
||||
if err := f.loadRules(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 启动流量捕获和过滤
|
||||
if err := f.startPacketCapture(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 启动转发服务
|
||||
if err := f.forwarder.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止防火墙服务
|
||||
func (f *Firewall) Stop() {
|
||||
log.Println("Stopping firewall service...")
|
||||
if f.capture != nil {
|
||||
f.capture.Stop()
|
||||
}
|
||||
f.forwarder.Stop()
|
||||
f.logger.Close()
|
||||
}
|
||||
|
||||
// 加载防火墙规则
|
||||
func (f *Firewall) loadRules() error {
|
||||
// 示例规则:允许本地回环地址的所有流量
|
||||
loopbackRule := &Rule{
|
||||
ID: "rule-1",
|
||||
Name: "Allow Loopback",
|
||||
Protocol: ProtocolAll,
|
||||
SrcIP: "127.0.0.1",
|
||||
DstIP: "127.0.0.1",
|
||||
Action: ActionAllow,
|
||||
Description: "Allow all loopback traffic",
|
||||
Enabled: true,
|
||||
}
|
||||
f.ruleManager.AddRule(loopbackRule)
|
||||
|
||||
// 可以从配置文件或数据库加载更多规则
|
||||
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 启动数据包捕获和过滤
|
||||
func (f *Firewall) startPacketCapture() error {
|
||||
if f.config.CaptureInterface == "" {
|
||||
return fmt.Errorf("capture interface not configured")
|
||||
}
|
||||
|
||||
f.capture = NewPacketCapture(
|
||||
f.config.CaptureInterface,
|
||||
f.ruleManager,
|
||||
f.logger,
|
||||
f.forwarder,
|
||||
)
|
||||
|
||||
if err := f.capture.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start packet capture: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
firewall := NewFirewall()
|
||||
|
||||
if err := firewall.Start(); err != nil {
|
||||
log.Fatalf("Failed to start firewall: %v", err)
|
||||
}
|
||||
defer firewall.Stop()
|
||||
|
||||
// 等待中断信号
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigChan
|
||||
|
||||
log.Println("Firewall stopped successfully")
|
||||
}
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RuleAction 定义规则动作类型
|
||||
type RuleAction string
|
||||
|
||||
// 规则动作常量
|
||||
const (
|
||||
ActionAllow RuleAction = "allow"
|
||||
ActionDeny RuleAction = "deny"
|
||||
)
|
||||
|
||||
// Protocol 定义支持的网络协议
|
||||
type Protocol string
|
||||
|
||||
// 协议常量
|
||||
const (
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
ProtocolUDP Protocol = "udp"
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
ProtocolAll Protocol = "all"
|
||||
)
|
||||
|
||||
// Rule 定义防火墙规则结构
|
||||
type Rule struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Protocol Protocol `json:"protocol"`
|
||||
SrcIP string `json:"src_ip"`
|
||||
SrcPort string `json:"src_port"`
|
||||
DstIP string `json:"dst_ip"`
|
||||
DstPort string `json:"dst_port"`
|
||||
Action RuleAction `json:"action"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// RuleManager 规则管理器
|
||||
type RuleManager struct {
|
||||
rules []*Rule
|
||||
}
|
||||
|
||||
// NewRuleManager 创建新的规则管理器
|
||||
func NewRuleManager() *RuleManager {
|
||||
return &RuleManager{
|
||||
rules: make([]*Rule, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule 添加规则
|
||||
func (rm *RuleManager) AddRule(rule *Rule) {
|
||||
rm.rules = append(rm.rules, rule)
|
||||
}
|
||||
|
||||
// RemoveRule 移除规则
|
||||
func (rm *RuleManager) RemoveRule(ruleID string) bool {
|
||||
for i, rule := range rm.rules {
|
||||
if rule.ID == ruleID {
|
||||
// 从切片中删除元素
|
||||
rm.rules = append(rm.rules[:i], rm.rules[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRule 获取规则
|
||||
func (rm *RuleManager) GetRule(ruleID string) *Rule {
|
||||
for _, rule := range rm.rules {
|
||||
if rule.ID == ruleID {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRules 列出所有规则
|
||||
func (rm *RuleManager) ListRules() []*Rule {
|
||||
return rm.rules
|
||||
}
|
||||
|
||||
// MatchRule 匹配规则
|
||||
func (rm *RuleManager) MatchRule(srcIP, dstIP net.IP, srcPort, dstPort int, protocol Protocol) *Rule {
|
||||
for _, rule := range rm.rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查协议匹配
|
||||
if rule.Protocol != ProtocolAll && rule.Protocol != protocol {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查源IP匹配
|
||||
if rule.SrcIP != "" && !matchIP(srcIP.String(), rule.SrcIP) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查目的IP匹配
|
||||
if rule.DstIP != "" && !matchIP(dstIP.String(), rule.DstIP) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查源端口匹配
|
||||
if rule.SrcPort != "" && !matchPort(srcPort, rule.SrcPort) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查目的端口匹配
|
||||
if rule.DstPort != "" && !matchPort(dstPort, rule.DstPort) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 找到匹配的规则
|
||||
return rule
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 匹配IP地址(支持通配符*)
|
||||
func matchIP(ip, pattern string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
return ip == pattern
|
||||
}
|
||||
|
||||
// 匹配端口(支持范围和通配符*)
|
||||
func matchPort(port int, pattern string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 处理端口范围(如8080-8090)
|
||||
if strings.Contains(pattern, "-") {
|
||||
parts := strings.Split(pattern, "-")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
startPort, err1 := atoi(parts[0])
|
||||
endPort, err2 := atoi(parts[1])
|
||||
if err1 != nil || err2 != nil || startPort > endPort {
|
||||
return false
|
||||
}
|
||||
return port >= startPort && port <= endPort
|
||||
}
|
||||
|
||||
// 处理单个端口
|
||||
p, err := atoi(pattern)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return port == p
|
||||
}
|
||||
|
||||
// 字符串转整数辅助函数
|
||||
func atoi(s string) (int, error) {
|
||||
var res int
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, fmt.Errorf("invalid digit: %c", c)
|
||||
}
|
||||
res = res*10 + int(c-'0')
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,287 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRuleManager_AddRemoveRule(t *testing.T) {
|
||||
rm := NewRuleManager()
|
||||
rule := &Rule{
|
||||
ID: "test-1",
|
||||
Action: ActionAllow,
|
||||
}
|
||||
|
||||
// 测试添加规则
|
||||
rm.AddRule(rule)
|
||||
if len(rm.ListRules()) != 1 {
|
||||
t.Error("Expected 1 rule after add, got", len(rm.ListRules()))
|
||||
}
|
||||
|
||||
// 测试获取规则
|
||||
found := rm.GetRule("test-1")
|
||||
if found == nil {
|
||||
t.Error("Expected to find rule test-1")
|
||||
}
|
||||
|
||||
// 测试删除规则
|
||||
removed := rm.RemoveRule("test-1")
|
||||
if !removed {
|
||||
t.Error("Expected to remove rule test-1")
|
||||
}
|
||||
if len(rm.ListRules()) != 0 {
|
||||
t.Error("Expected 0 rules after remove, got", len(rm.ListRules()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleManager_MatchRule(t *testing.T) {
|
||||
rm := NewRuleManager()
|
||||
// 添加测试规则
|
||||
rm.AddRule(&Rule{
|
||||
ID: "tcp-8080",
|
||||
Protocol: ProtocolTCP,
|
||||
DstPort: "8080",
|
||||
Action: ActionAllow,
|
||||
Enabled: true,
|
||||
})
|
||||
rm.AddRule(&Rule{
|
||||
ID: "udp-53",
|
||||
Protocol: ProtocolUDP,
|
||||
DstPort: "53",
|
||||
Action: ActionDeny,
|
||||
Enabled: true,
|
||||
})
|
||||
rm.AddRule(&Rule{
|
||||
ID: "icmp-all",
|
||||
Protocol: ProtocolICMP,
|
||||
Action: ActionAllow,
|
||||
Enabled: true,
|
||||
})
|
||||
rm.AddRule(&Rule{
|
||||
ID: "range-ports",
|
||||
Protocol: ProtocolTCP,
|
||||
DstPort: "8000-9000",
|
||||
Action: ActionAllow,
|
||||
Enabled: true,
|
||||
})
|
||||
// 禁用规则
|
||||
rm.AddRule(&Rule{
|
||||
ID: "disabled-rule",
|
||||
Protocol: ProtocolTCP,
|
||||
DstPort: "80",
|
||||
Action: ActionAllow,
|
||||
Enabled: false,
|
||||
})
|
||||
|
||||
// 测试用例
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP string
|
||||
dstIP string
|
||||
srcPort int
|
||||
dstPort int
|
||||
protocol Protocol
|
||||
expected string
|
||||
}{{
|
||||
name: "match tcp 8080",
|
||||
srcIP: "192.168.1.1",
|
||||
dstIP: "192.168.1.2",
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
protocol: ProtocolTCP,
|
||||
expected: "tcp-8080",
|
||||
}, {
|
||||
name: "match udp 53",
|
||||
srcIP: "10.0.0.1",
|
||||
dstIP: "10.0.0.2",
|
||||
srcPort: 54321,
|
||||
dstPort: 53,
|
||||
protocol: ProtocolUDP,
|
||||
expected: "udp-53",
|
||||
}, {
|
||||
name: "match icmp",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "172.16.0.2",
|
||||
srcPort: 0,
|
||||
dstPort: 0,
|
||||
protocol: ProtocolICMP,
|
||||
expected: "icmp-all",
|
||||
}, {
|
||||
name: "match port range",
|
||||
srcIP: "192.168.1.1",
|
||||
dstIP: "192.168.1.2",
|
||||
srcPort: 12345,
|
||||
dstPort: 8500,
|
||||
protocol: ProtocolTCP,
|
||||
expected: "range-ports",
|
||||
}, {
|
||||
name: "no match",
|
||||
srcIP: "192.168.1.1",
|
||||
dstIP: "192.168.1.2",
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
protocol: ProtocolTCP,
|
||||
expected: "",
|
||||
}, {
|
||||
name: "disabled rule",
|
||||
srcIP: "192.168.1.1",
|
||||
dstIP: "192.168.1.2",
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
protocol: ProtocolTCP,
|
||||
expected: "",
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
src := net.ParseIP(tt.srcIP)
|
||||
dst := net.ParseIP(tt.dstIP)
|
||||
if src == nil || dst == nil {
|
||||
t.Fatal("Invalid IP address in test case")
|
||||
}
|
||||
|
||||
rule := rm.MatchRule(src, dst, tt.srcPort, tt.dstPort, tt.protocol)
|
||||
var ruleID string
|
||||
if rule != nil {
|
||||
ruleID = rule.ID
|
||||
}
|
||||
|
||||
if ruleID != tt.expected {
|
||||
t.Errorf("Expected rule %s, got %s", tt.expected, ruleID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtoi(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
err bool
|
||||
}{{
|
||||
name: "valid number",
|
||||
input: "12345",
|
||||
expected: 12345,
|
||||
err: false,
|
||||
}, {
|
||||
name: "single digit",
|
||||
input: "5",
|
||||
expected: 5,
|
||||
err: false,
|
||||
}, {
|
||||
name: "invalid character",
|
||||
input: "12a3",
|
||||
expected: 0,
|
||||
err: true,
|
||||
}, {
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: 0,
|
||||
err: true,
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := atoi(tt.input)
|
||||
if (err != nil) != tt.err {
|
||||
t.Errorf("Expected error %v, got %v", tt.err, err)
|
||||
return
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port int
|
||||
pattern string
|
||||
expected bool
|
||||
}{{
|
||||
name: "wildcard match",
|
||||
port: 1234,
|
||||
pattern: "*",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "exact match",
|
||||
port: 8080,
|
||||
pattern: "8080",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "range match lower",
|
||||
port: 8000,
|
||||
pattern: "8000-9000",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "range match upper",
|
||||
port: 9000,
|
||||
pattern: "8000-9000",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "range match middle",
|
||||
port: 8500,
|
||||
pattern: "8000-9000",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "range no match",
|
||||
port: 7999,
|
||||
pattern: "8000-9000",
|
||||
expected: false,
|
||||
}, {
|
||||
name: "invalid pattern",
|
||||
port: 80,
|
||||
pattern: "abc",
|
||||
expected: false,
|
||||
}, {
|
||||
name: "invalid range",
|
||||
port: 80,
|
||||
pattern: "9000-8000",
|
||||
expected: false,
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchPort(tt.port, tt.pattern)
|
||||
if result != tt.expected {
|
||||
t.Errorf("For port %d and pattern '%s', expected %v, got %v", tt.port, tt.pattern, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
pattern string
|
||||
expected bool
|
||||
}{{
|
||||
name: "wildcard match",
|
||||
ip: "192.168.1.1",
|
||||
pattern: "*",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "exact match",
|
||||
ip: "192.168.1.1",
|
||||
pattern: "192.168.1.1",
|
||||
expected: true,
|
||||
}, {
|
||||
name: "no match",
|
||||
ip: "192.168.1.1",
|
||||
pattern: "10.0.0.1",
|
||||
expected: false,
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchIP(tt.ip, tt.pattern)
|
||||
if result != tt.expected {
|
||||
t.Errorf("For IP %s and pattern '%s', expected %v, got %v", tt.ip, tt.pattern, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue