// Package main implements a small in-memory firewall rule manager: rules are
// loaded from JSON, kept in insertion order, and matched first-match-wins
// against packet attributes.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net"
	"os"
	"strconv"
	"strings"
)

// RuleAction is the action a rule applies to matching traffic.
type RuleAction string

// Supported rule actions.
const (
	ActionAllow RuleAction = "allow"
	ActionDeny  RuleAction = "deny"
)

// Protocol is a network protocol a rule can match on.
type Protocol string

// Supported protocols. A rule whose protocol is ProtocolAll matches any
// protocol.
const (
	ProtocolTCP  Protocol = "tcp"
	ProtocolUDP  Protocol = "udp"
	ProtocolICMP Protocol = "icmp"
	ProtocolAll  Protocol = "all"
)

// Rule is a single firewall rule.
//
// In MatchRule, empty SrcIP/DstIP/SrcPort/DstPort fields act as wildcards.
// IP fields accept "*", a single address, a CIDR block, or a comma-separated
// list of those. Port fields accept "*", a single port, a "lo-hi" range, or a
// comma-separated list of those. Enabled is false when absent from JSON, and
// disabled rules never match.
type Rule struct {
	ID          string     `json:"id"`
	Name        string     `json:"name"`
	Protocol    Protocol   `json:"protocol"`
	SrcIP       string     `json:"src_ip"`
	SrcPort     string     `json:"src_port"`
	DstIP       string     `json:"dst_ip"`
	DstPort     string     `json:"dst_port"`
	Action      RuleAction `json:"action"`
	Description string     `json:"description"`
	Enabled     bool       `json:"enabled"`
}

// RuleManager holds an ordered list of firewall rules. It performs no
// locking; callers sharing one across goroutines must synchronize access.
type RuleManager struct {
	rules []*Rule
}

// LoadFromFile appends the rules stored in the JSON file at bRuleFile.
//
// The file may contain either a JSON array of rule objects or a stream of
// concatenated rule objects. Rules are appended only after the whole file has
// decoded successfully, so a malformed file leaves the manager unchanged.
// A read error is returned as-is from os.ReadFile; decode errors are wrapped
// with the file name.
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
	content, err := os.ReadFile(bRuleFile)
	if err != nil {
		return err
	}
	rules, err := decodeRules(content)
	if err != nil {
		return fmt.Errorf("parsing rule file %q: %w", bRuleFile, err)
	}
	for _, rule := range rules {
		rm.AddRule(rule)
	}
	return nil
}

// decodeRules parses data as either a JSON array of rules or a stream of
// concatenated rule objects. JSON null array elements are skipped.
// Whitespace-only input yields no rules and no error.
func decodeRules(data []byte) ([]*Rule, error) {
	trimmed := bytes.TrimSpace(data)
	if len(trimmed) > 0 && trimmed[0] == '[' {
		var list []*Rule
		if err := json.Unmarshal(trimmed, &list); err != nil {
			return nil, err
		}
		// Filter nils in place; a nil *Rule would panic during matching.
		rules := list[:0]
		for _, r := range list {
			if r != nil {
				rules = append(rules, r)
			}
		}
		return rules, nil
	}

	var rules []*Rule
	dec := json.NewDecoder(bytes.NewReader(data))
	for dec.More() {
		var rule Rule // fresh variable per iteration, so each pointer is distinct
		if err := dec.Decode(&rule); err != nil {
			return nil, fmt.Errorf("rule %d: %w", len(rules)+1, err)
		}
		rules = append(rules, &rule)
	}
	return rules, nil
}

// NewRuleManager returns an empty RuleManager.
func NewRuleManager() *RuleManager {
	return &RuleManager{
		rules: make([]*Rule, 0),
	}
}

// AddRule appends rule to the end of the list, giving it the lowest priority
// in MatchRule. A nil rule is ignored, since it would otherwise cause a panic
// on the next lookup.
func (rm *RuleManager) AddRule(rule *Rule) {
	if rule == nil {
		return
	}
	rm.rules = append(rm.rules, rule)
}

// RemoveRule removes the first rule whose ID equals ruleID and reports
// whether one was found.
func (rm *RuleManager) RemoveRule(ruleID string) bool {
	for i, rule := range rm.rules {
		if rule.ID != ruleID {
			continue
		}
		last := len(rm.rules) - 1
		copy(rm.rules[i:], rm.rules[i+1:])
		// Clear the vacated slot so the backing array does not keep the
		// removed rule alive.
		rm.rules[last] = nil
		rm.rules = rm.rules[:last]
		return true
	}
	return false
}

// GetRule returns the first rule whose ID equals ruleID, or nil if none does.
func (rm *RuleManager) GetRule(ruleID string) *Rule {
	for _, rule := range rm.rules {
		if rule.ID == ruleID {
			return rule
		}
	}
	return nil
}

// ListRules returns the rules in priority order.
//
// The returned slice is a copy, so reordering or overwriting its elements does
// not affect the manager. The *Rule values themselves are shared.
func (rm *RuleManager) ListRules() []*Rule {
	if rm.rules == nil {
		return nil
	}
	out := make([]*Rule, len(rm.rules))
	copy(out, rm.rules)
	return out
}

// MatchRule returns the first enabled rule, in insertion order, that matches
// the given packet attributes, or nil if none does.
//
// The protocol comparison is case-insensitive, and a rule protocol of "all"
// matches any protocol. Empty IP and port fields in a rule act as wildcards.
func (rm *RuleManager) MatchRule(srcIP, dstIP net.IP, srcPort, dstPort int, protocol Protocol) *Rule {
	// The string forms do not depend on the rule; compute them once.
	src, dst := srcIP.String(), dstIP.String()
	for _, rule := range rm.rules {
		switch {
		case !rule.Enabled:
		case !matchProtocol(protocol, rule.Protocol):
		case rule.SrcIP != "" && !matchIP(src, rule.SrcIP):
		case rule.DstIP != "" && !matchIP(dst, rule.DstIP):
		case rule.SrcPort != "" && !matchPort(srcPort, rule.SrcPort):
		case rule.DstPort != "" && !matchPort(dstPort, rule.DstPort):
		default:
			return rule
		}
	}
	return nil
}

// matchProtocol reports whether a packet using protocol got satisfies a rule
// whose protocol is want, ignoring case.
func matchProtocol(got, want Protocol) bool {
	return strings.EqualFold(string(want), string(ProtocolAll)) ||
		strings.EqualFold(string(want), string(got))
}

// matchIP reports whether ip (in net.IP String form) matches pattern.
//
// pattern is a comma-separated list whose entries may be "*", a CIDR block
// such as "10.0.0.0/8", or a single address. Addresses are compared as IPs,
// so equivalent spellings such as "::ffff:192.168.1.1" also match. Entries
// that parse as neither fall back to exact string comparison.
func matchIP(ip, pattern string) bool {
	parsed := net.ParseIP(ip)
	for _, entry := range strings.Split(pattern, ",") {
		if matchIPEntry(ip, parsed, strings.TrimSpace(entry)) {
			return true
		}
	}
	return false
}

// matchIPEntry matches one pattern entry against raw, the textual IP, and
// ip, its parsed form, which may be nil.
func matchIPEntry(raw string, ip net.IP, entry string) bool {
	if entry == "*" {
		return true
	}
	if strings.Contains(entry, "/") {
		if _, network, err := net.ParseCIDR(entry); err == nil {
			return ip != nil && network.Contains(ip)
		}
	}
	if want := net.ParseIP(entry); want != nil && ip != nil {
		return want.Equal(ip)
	}
	return raw == entry
}

// matchPort reports whether port matches pattern.
//
// pattern is a comma-separated list whose entries may be "*", a single port,
// or an inclusive "lo-hi" range (for example "8080-8090"). Spaces around
// entries and range bounds are ignored; malformed entries never match.
func matchPort(port int, pattern string) bool {
	for _, entry := range strings.Split(pattern, ",") {
		if matchPortEntry(port, strings.TrimSpace(entry)) {
			return true
		}
	}
	return false
}

// matchPortEntry matches port against a single, already-trimmed pattern entry.
func matchPortEntry(port int, entry string) bool {
	if entry == "*" {
		return true
	}
	if strings.Contains(entry, "-") {
		bounds := strings.Split(entry, "-")
		if len(bounds) != 2 {
			return false
		}
		lo, errLo := atoi(strings.TrimSpace(bounds[0]))
		hi, errHi := atoi(strings.TrimSpace(bounds[1]))
		if errLo != nil || errHi != nil || lo > hi {
			return false
		}
		return port >= lo && port <= hi
	}
	p, err := atoi(entry)
	return err == nil && port == p
}

// atoi parses s as a non-negative decimal integer made only of ASCII digits.
// Unlike strconv.Atoi it rejects sign prefixes. Empty input and values that
// overflow int are reported as errors.
func atoi(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("empty number")
	}
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid digit: %c", c)
		}
	}
	// Only range errors are possible here; strconv detects overflow.
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, err
	}
	return n, nil
}