287 lines
5.5 KiB
Go
287 lines
5.5 KiB
Go
package main
|
|
|
|
import (
|
|
"net"
|
|
"testing"
|
|
)
|
|
|
|
func TestRuleManager_AddRemoveRule(t *testing.T) {
|
|
rm := NewRuleManager()
|
|
rule := &Rule{
|
|
ID: "test-1",
|
|
Action: ActionAllow,
|
|
}
|
|
|
|
// 测试添加规则
|
|
rm.AddRule(rule)
|
|
if len(rm.ListRules()) != 1 {
|
|
t.Error("Expected 1 rule after add, got", len(rm.ListRules()))
|
|
}
|
|
|
|
// 测试获取规则
|
|
found := rm.GetRule("test-1")
|
|
if found == nil {
|
|
t.Error("Expected to find rule test-1")
|
|
}
|
|
|
|
// 测试删除规则
|
|
removed := rm.RemoveRule("test-1")
|
|
if !removed {
|
|
t.Error("Expected to remove rule test-1")
|
|
}
|
|
if len(rm.ListRules()) != 0 {
|
|
t.Error("Expected 0 rules after remove, got", len(rm.ListRules()))
|
|
}
|
|
}
|
|
|
|
func TestRuleManager_MatchRule(t *testing.T) {
|
|
rm := NewRuleManager()
|
|
// 添加测试规则
|
|
rm.AddRule(&Rule{
|
|
ID: "tcp-8080",
|
|
Protocol: ProtocolTCP,
|
|
DstPort: "8080",
|
|
Action: ActionAllow,
|
|
Enabled: true,
|
|
})
|
|
rm.AddRule(&Rule{
|
|
ID: "udp-53",
|
|
Protocol: ProtocolUDP,
|
|
DstPort: "53",
|
|
Action: ActionDeny,
|
|
Enabled: true,
|
|
})
|
|
rm.AddRule(&Rule{
|
|
ID: "icmp-all",
|
|
Protocol: ProtocolICMP,
|
|
Action: ActionAllow,
|
|
Enabled: true,
|
|
})
|
|
rm.AddRule(&Rule{
|
|
ID: "range-ports",
|
|
Protocol: ProtocolTCP,
|
|
DstPort: "8000-9000",
|
|
Action: ActionAllow,
|
|
Enabled: true,
|
|
})
|
|
// 禁用规则
|
|
rm.AddRule(&Rule{
|
|
ID: "disabled-rule",
|
|
Protocol: ProtocolTCP,
|
|
DstPort: "80",
|
|
Action: ActionAllow,
|
|
Enabled: false,
|
|
})
|
|
|
|
// 测试用例
|
|
tests := []struct {
|
|
name string
|
|
srcIP string
|
|
dstIP string
|
|
srcPort int
|
|
dstPort int
|
|
protocol Protocol
|
|
expected string
|
|
}{{
|
|
name: "match tcp 8080",
|
|
srcIP: "192.168.1.1",
|
|
dstIP: "192.168.1.2",
|
|
srcPort: 12345,
|
|
dstPort: 8080,
|
|
protocol: ProtocolTCP,
|
|
expected: "tcp-8080",
|
|
}, {
|
|
name: "match udp 53",
|
|
srcIP: "10.0.0.1",
|
|
dstIP: "10.0.0.2",
|
|
srcPort: 54321,
|
|
dstPort: 53,
|
|
protocol: ProtocolUDP,
|
|
expected: "udp-53",
|
|
}, {
|
|
name: "match icmp",
|
|
srcIP: "172.16.0.1",
|
|
dstIP: "172.16.0.2",
|
|
srcPort: 0,
|
|
dstPort: 0,
|
|
protocol: ProtocolICMP,
|
|
expected: "icmp-all",
|
|
}, {
|
|
name: "match port range",
|
|
srcIP: "192.168.1.1",
|
|
dstIP: "192.168.1.2",
|
|
srcPort: 12345,
|
|
dstPort: 8500,
|
|
protocol: ProtocolTCP,
|
|
expected: "range-ports",
|
|
}, {
|
|
name: "no match",
|
|
srcIP: "192.168.1.1",
|
|
dstIP: "192.168.1.2",
|
|
srcPort: 12345,
|
|
dstPort: 80,
|
|
protocol: ProtocolTCP,
|
|
expected: "",
|
|
}, {
|
|
name: "disabled rule",
|
|
srcIP: "192.168.1.1",
|
|
dstIP: "192.168.1.2",
|
|
srcPort: 12345,
|
|
dstPort: 80,
|
|
protocol: ProtocolTCP,
|
|
expected: "",
|
|
}}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
src := net.ParseIP(tt.srcIP)
|
|
dst := net.ParseIP(tt.dstIP)
|
|
if src == nil || dst == nil {
|
|
t.Fatal("Invalid IP address in test case")
|
|
}
|
|
|
|
rule := rm.MatchRule(src, dst, tt.srcPort, tt.dstPort, tt.protocol)
|
|
var ruleID string
|
|
if rule != nil {
|
|
ruleID = rule.ID
|
|
}
|
|
|
|
if ruleID != tt.expected {
|
|
t.Errorf("Expected rule %s, got %s", tt.expected, ruleID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAtoi(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected int
|
|
err bool
|
|
}{{
|
|
name: "valid number",
|
|
input: "12345",
|
|
expected: 12345,
|
|
err: false,
|
|
}, {
|
|
name: "single digit",
|
|
input: "5",
|
|
expected: 5,
|
|
err: false,
|
|
}, {
|
|
name: "invalid character",
|
|
input: "12a3",
|
|
expected: 0,
|
|
err: true,
|
|
}, {
|
|
name: "empty string",
|
|
input: "",
|
|
expected: 0,
|
|
err: true,
|
|
}}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := atoi(tt.input)
|
|
if (err != nil) != tt.err {
|
|
t.Errorf("Expected error %v, got %v", tt.err, err)
|
|
return
|
|
}
|
|
if result != tt.expected {
|
|
t.Errorf("Expected %d, got %d", tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMatchPort(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
port int
|
|
pattern string
|
|
expected bool
|
|
}{{
|
|
name: "wildcard match",
|
|
port: 1234,
|
|
pattern: "*",
|
|
expected: true,
|
|
}, {
|
|
name: "exact match",
|
|
port: 8080,
|
|
pattern: "8080",
|
|
expected: true,
|
|
}, {
|
|
name: "range match lower",
|
|
port: 8000,
|
|
pattern: "8000-9000",
|
|
expected: true,
|
|
}, {
|
|
name: "range match upper",
|
|
port: 9000,
|
|
pattern: "8000-9000",
|
|
expected: true,
|
|
}, {
|
|
name: "range match middle",
|
|
port: 8500,
|
|
pattern: "8000-9000",
|
|
expected: true,
|
|
}, {
|
|
name: "range no match",
|
|
port: 7999,
|
|
pattern: "8000-9000",
|
|
expected: false,
|
|
}, {
|
|
name: "invalid pattern",
|
|
port: 80,
|
|
pattern: "abc",
|
|
expected: false,
|
|
}, {
|
|
name: "invalid range",
|
|
port: 80,
|
|
pattern: "9000-8000",
|
|
expected: false,
|
|
}}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := matchPort(tt.port, tt.pattern)
|
|
if result != tt.expected {
|
|
t.Errorf("For port %d and pattern '%s', expected %v, got %v", tt.port, tt.pattern, tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMatchIP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
pattern string
|
|
expected bool
|
|
}{{
|
|
name: "wildcard match",
|
|
ip: "192.168.1.1",
|
|
pattern: "*",
|
|
expected: true,
|
|
}, {
|
|
name: "exact match",
|
|
ip: "192.168.1.1",
|
|
pattern: "192.168.1.1",
|
|
expected: true,
|
|
}, {
|
|
name: "no match",
|
|
ip: "192.168.1.1",
|
|
pattern: "10.0.0.1",
|
|
expected: false,
|
|
}}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := matchIP(tt.ip, tt.pattern)
|
|
if result != tt.expected {
|
|
t.Errorf("For IP %s and pattern '%s', expected %v, got %v", tt.ip, tt.pattern, tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
} |