package main import ( "net" "testing" ) func TestRuleManager_AddRemoveRule(t *testing.T) { rm := NewRuleManager() rule := &Rule{ ID: "test-1", Action: ActionAllow, } // 测试添加规则 rm.AddRule(rule) if len(rm.ListRules()) != 1 { t.Error("Expected 1 rule after add, got", len(rm.ListRules())) } // 测试获取规则 found := rm.GetRule("test-1") if found == nil { t.Error("Expected to find rule test-1") } // 测试删除规则 removed := rm.RemoveRule("test-1") if !removed { t.Error("Expected to remove rule test-1") } if len(rm.ListRules()) != 0 { t.Error("Expected 0 rules after remove, got", len(rm.ListRules())) } } func TestRuleManager_MatchRule(t *testing.T) { rm := NewRuleManager() // 添加测试规则 rm.AddRule(&Rule{ ID: "tcp-8080", Protocol: ProtocolTCP, DstPort: "8080", Action: ActionAllow, Enabled: true, }) rm.AddRule(&Rule{ ID: "udp-53", Protocol: ProtocolUDP, DstPort: "53", Action: ActionDeny, Enabled: true, }) rm.AddRule(&Rule{ ID: "icmp-all", Protocol: ProtocolICMP, Action: ActionAllow, Enabled: true, }) rm.AddRule(&Rule{ ID: "range-ports", Protocol: ProtocolTCP, DstPort: "8000-9000", Action: ActionAllow, Enabled: true, }) // 禁用规则 rm.AddRule(&Rule{ ID: "disabled-rule", Protocol: ProtocolTCP, DstPort: "80", Action: ActionAllow, Enabled: false, }) // 测试用例 tests := []struct { name string srcIP string dstIP string srcPort int dstPort int protocol Protocol expected string }{{ name: "match tcp 8080", srcIP: "192.168.1.1", dstIP: "192.168.1.2", srcPort: 12345, dstPort: 8080, protocol: ProtocolTCP, expected: "tcp-8080", }, { name: "match udp 53", srcIP: "10.0.0.1", dstIP: "10.0.0.2", srcPort: 54321, dstPort: 53, protocol: ProtocolUDP, expected: "udp-53", }, { name: "match icmp", srcIP: "172.16.0.1", dstIP: "172.16.0.2", srcPort: 0, dstPort: 0, protocol: ProtocolICMP, expected: "icmp-all", }, { name: "match port range", srcIP: "192.168.1.1", dstIP: "192.168.1.2", srcPort: 12345, dstPort: 8500, protocol: ProtocolTCP, expected: "range-ports", }, { name: "no match", srcIP: "192.168.1.1", dstIP: "192.168.1.2", srcPort: 12345, dstPort: 80, protocol: ProtocolTCP, expected: "", }, { name: "disabled rule", srcIP: "192.168.1.1", dstIP: "192.168.1.2", srcPort: 12345, dstPort: 80, protocol: ProtocolTCP, expected: "", }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { src := net.ParseIP(tt.srcIP) dst := net.ParseIP(tt.dstIP) if src == nil || dst == nil { t.Fatal("Invalid IP address in test case") } rule := rm.MatchRule(src, dst, tt.srcPort, tt.dstPort, tt.protocol) var ruleID string if rule != nil { ruleID = rule.ID } if ruleID != tt.expected { t.Errorf("Expected rule %s, got %s", tt.expected, ruleID) } }) } } func TestAtoi(t *testing.T) { tests := []struct { name string input string expected int err bool }{{ name: "valid number", input: "12345", expected: 12345, err: false, }, { name: "single digit", input: "5", expected: 5, err: false, }, { name: "invalid character", input: "12a3", expected: 0, err: true, }, { name: "empty string", input: "", expected: 0, err: true, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := atoi(tt.input) if (err != nil) != tt.err { t.Errorf("Expected error %v, got %v", tt.err, err) return } if result != tt.expected { t.Errorf("Expected %d, got %d", tt.expected, result) } }) } } func TestMatchPort(t *testing.T) { tests := []struct { name string port int pattern string expected bool }{{ name: "wildcard match", port: 1234, pattern: "*", expected: true, }, { name: "exact match", port: 8080, pattern: "8080", expected: true, }, { name: "range match lower", port: 8000, pattern: "8000-9000", expected: true, }, { name: "range match upper", port: 9000, pattern: "8000-9000", expected: true, }, { name: "range match middle", port: 8500, pattern: "8000-9000", expected: true, }, { name: "range no match", port: 7999, pattern: "8000-9000", expected: false, }, { name: "invalid pattern", port: 80, pattern: "abc", expected: false, }, { name: "invalid range", port: 80, pattern: "9000-8000", expected: false, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := matchPort(tt.port, tt.pattern) if result != tt.expected { t.Errorf("For port %d and pattern '%s', expected %v, got %v", tt.port, tt.pattern, tt.expected, result) } }) } } func TestMatchIP(t *testing.T) { tests := []struct { name string ip string pattern string expected bool }{{ name: "wildcard match", ip: "192.168.1.1", pattern: "*", expected: true, }, { name: "exact match", ip: "192.168.1.1", pattern: "192.168.1.1", expected: true, }, { name: "no match", ip: "192.168.1.1", pattern: "10.0.0.1", expected: false, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := matchIP(tt.ip, tt.pattern) if result != tt.expected { t.Errorf("For IP %s and pattern '%s', expected %v, got %v", tt.ip, tt.pattern, tt.expected, result) } }) } }