gofirewall/rule_test.go

287 lines
5.5 KiB
Go

package main
import (
"net"
"testing"
)
// TestRuleManager_AddRemoveRule exercises the basic lifecycle of a single rule.
// It calls AddRule, then GetRule, then RemoveRule, and checks the length
// reported by ListRules after the add and after the remove.
func TestRuleManager_AddRemoveRule(t *testing.T) {
	const ruleID = "test-1"
	mgr := NewRuleManager()

	// After adding one rule, the manager should report exactly one.
	mgr.AddRule(&Rule{ID: ruleID, Action: ActionAllow})
	if count := len(mgr.ListRules()); count != 1 {
		t.Error("Expected 1 rule after add, got", count)
	}

	// The rule that was just added should be retrievable by its ID.
	if mgr.GetRule(ruleID) == nil {
		t.Error("Expected to find rule test-1")
	}

	// Removing it should report success and leave the manager empty.
	if ok := mgr.RemoveRule(ruleID); !ok {
		t.Error("Expected to remove rule test-1")
	}
	if count := len(mgr.ListRules()); count != 0 {
		t.Error("Expected 0 rules after remove, got", count)
	}
}
// TestRuleManager_MatchRule checks MatchRule against a fixed rule set.
// The set covers exact ports, a port range, a protocol-only (ICMP) rule and a
// disabled rule. Each case sends one packet description through MatchRule and
// compares the returned rule's ID ("" meaning no rule matched).
//
// NOTE(review): the "match tcp 8080" case expects "tcp-8080" even though port
// 8080 also falls inside "range-ports" (8000-9000). This presumably relies on
// MatchRule returning the first matching rule in insertion order — confirm
// against the RuleManager implementation.
func TestRuleManager_MatchRule(t *testing.T) {
	rm := NewRuleManager()

	// Register the rules under test. Insertion order matters for the
	// overlapping tcp-8080 / range-ports rules (see note above).
	rm.AddRule(&Rule{
		ID:       "tcp-8080",
		Protocol: ProtocolTCP,
		DstPort:  "8080",
		Action:   ActionAllow,
		Enabled:  true,
	})
	rm.AddRule(&Rule{
		ID:       "udp-53",
		Protocol: ProtocolUDP,
		DstPort:  "53",
		Action:   ActionDeny,
		Enabled:  true,
	})
	rm.AddRule(&Rule{
		ID:       "icmp-all",
		Protocol: ProtocolICMP,
		Action:   ActionAllow,
		Enabled:  true,
	})
	rm.AddRule(&Rule{
		ID:       "range-ports",
		Protocol: ProtocolTCP,
		DstPort:  "8000-9000",
		Action:   ActionAllow,
		Enabled:  true,
	})
	// A disabled rule: it must not be returned even when its port matches.
	rm.AddRule(&Rule{
		ID:       "disabled-rule",
		Protocol: ProtocolTCP,
		DstPort:  "80",
		Action:   ActionAllow,
		Enabled:  false,
	})

	tests := []struct {
		name     string
		srcIP    string
		dstIP    string
		srcPort  int
		dstPort  int
		protocol Protocol
		expected string
	}{{
		name:     "match tcp 8080",
		srcIP:    "192.168.1.1",
		dstIP:    "192.168.1.2",
		srcPort:  12345,
		dstPort:  8080,
		protocol: ProtocolTCP,
		expected: "tcp-8080",
	}, {
		name:     "match udp 53",
		srcIP:    "10.0.0.1",
		dstIP:    "10.0.0.2",
		srcPort:  54321,
		dstPort:  53,
		protocol: ProtocolUDP,
		expected: "udp-53",
	}, {
		name:     "match icmp",
		srcIP:    "172.16.0.1",
		dstIP:    "172.16.0.2",
		srcPort:  0,
		dstPort:  0,
		protocol: ProtocolICMP,
		expected: "icmp-all",
	}, {
		name:     "match port range",
		srcIP:    "192.168.1.1",
		dstIP:    "192.168.1.2",
		srcPort:  12345,
		dstPort:  8500,
		protocol: ProtocolTCP,
		expected: "range-ports",
	}, {
		// TCP/443 is covered by no rule at all, enabled or disabled, so this
		// exercises the genuine "nothing matches" path. Port 80 would only
		// repeat the disabled-rule case below.
		name:     "no match",
		srcIP:    "192.168.1.1",
		dstIP:    "192.168.1.2",
		srcPort:  12345,
		dstPort:  443,
		protocol: ProtocolTCP,
		expected: "",
	}, {
		// TCP/80 is covered only by the disabled rule, so nothing may match.
		name:     "disabled rule",
		srcIP:    "192.168.1.1",
		dstIP:    "192.168.1.2",
		srcPort:  12345,
		dstPort:  80,
		protocol: ProtocolTCP,
		expected: "",
	}}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			src := net.ParseIP(tt.srcIP)
			dst := net.ParseIP(tt.dstIP)
			if src == nil || dst == nil {
				t.Fatalf("Invalid IP address in test case: src=%q dst=%q", tt.srcIP, tt.dstIP)
			}

			rule := rm.MatchRule(src, dst, tt.srcPort, tt.dstPort, tt.protocol)
			var ruleID string
			if rule != nil {
				ruleID = rule.ID
			}
			// %q makes an empty ID (no match) visible in the failure output.
			if ruleID != tt.expected {
				t.Errorf("Expected rule %q, got %q", tt.expected, ruleID)
			}
		})
	}
}
// TestAtoi verifies that atoi parses plain decimal strings. It also checks that
// atoi reports an error and returns 0 for input containing a non-digit and for
// the empty string.
func TestAtoi(t *testing.T) {
	// atoiCase fields: subtest name, input text, wanted value, whether an
	// error is wanted.
	type atoiCase struct {
		name     string
		input    string
		expected int
		err      bool
	}
	cases := []atoiCase{
		{"valid number", "12345", 12345, false},
		{"single digit", "5", 5, false},
		{"invalid character", "12a3", 0, true},
		{"empty string", "", 0, true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, gotErr := atoi(tc.input)
			// Only compare values once the error expectation holds.
			if hasErr := gotErr != nil; hasErr != tc.err {
				t.Errorf("Expected error %v, got %v", tc.err, gotErr)
				return
			}
			if got != tc.expected {
				t.Errorf("Expected %d, got %d", tc.expected, got)
			}
		})
	}
}
// TestMatchPort checks matchPort for the pattern forms it is given here:
// "*" wildcard, an exact port, an inclusive "lo-hi" range (both bounds and an
// interior value), a port just outside a range, a non-numeric pattern, and a
// reversed range.
func TestMatchPort(t *testing.T) {
	// check runs one named subtest asserting matchPort(port, pattern) == want.
	check := func(name string, port int, pattern string, want bool) {
		t.Helper()
		t.Run(name, func(t *testing.T) {
			if got := matchPort(port, pattern); got != want {
				t.Errorf("For port %d and pattern '%s', expected %v, got %v", port, pattern, want, got)
			}
		})
	}

	// Wildcard and exact patterns.
	check("wildcard match", 1234, "*", true)
	check("exact match", 8080, "8080", true)

	// Ranges are inclusive at both ends.
	check("range match lower", 8000, "8000-9000", true)
	check("range match upper", 9000, "8000-9000", true)
	check("range match middle", 8500, "8000-9000", true)
	check("range no match", 7999, "8000-9000", false)

	// Malformed patterns never match.
	check("invalid pattern", 80, "abc", false)
	check("invalid range", 80, "9000-8000", false)
}
// TestMatchIP checks matchIP against a "*" wildcard, an identical address,
// and a different address.
func TestMatchIP(t *testing.T) {
	cases := [...]struct {
		name, ip, pattern string
		want              bool
	}{
		{name: "wildcard match", ip: "192.168.1.1", pattern: "*", want: true},
		{name: "exact match", ip: "192.168.1.1", pattern: "192.168.1.1", want: true},
		{name: "no match", ip: "192.168.1.1", pattern: "10.0.0.1", want: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := matchIP(tc.ip, tc.pattern)
			if got == tc.want {
				return
			}
			t.Errorf("For IP %s and pattern '%s', expected %v, got %v", tc.ip, tc.pattern, tc.want, got)
		})
	}
}