gofirewall/rule.go

196 lines
3.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"encoding/json"
"fmt"
"net"
"os"
"strings"
)
// RuleAction is the action value stored on a Rule ("allow" or "deny").
type RuleAction string

// Protocol names the network protocol a Rule applies to.
type Protocol string

// Recognised rule actions.
const (
	ActionAllow RuleAction = "allow"
	ActionDeny  RuleAction = "deny"
)

// Recognised protocols. A rule whose Protocol is ProtocolAll is accepted by
// MatchRule regardless of the packet's protocol.
const (
	ProtocolTCP  Protocol = "tcp"
	ProtocolUDP  Protocol = "udp"
	ProtocolICMP Protocol = "icmp"
	ProtocolAll  Protocol = "all"
)

// Rule is a single firewall rule. LoadFromFile decodes rules from JSON using
// the field tags below.
//
// MatchRule skips the check for any empty SrcIP/DstIP/SrcPort/DstPort field,
// so an empty field accepts every value; a rule with Enabled == false is never
// returned by MatchRule.
type Rule struct {
	ID          string     `json:"id"`          // looked up by GetRule/RemoveRule
	Name        string     `json:"name"`        // not used for matching
	Protocol    Protocol   `json:"protocol"`    // ProtocolAll accepts any protocol
	SrcIP       string     `json:"src_ip"`      // source address pattern (see matchIP)
	SrcPort     string     `json:"src_port"`    // source port pattern (see matchPort)
	DstIP       string     `json:"dst_ip"`      // destination address pattern (see matchIP)
	DstPort     string     `json:"dst_port"`    // destination port pattern (see matchPort)
	Action      RuleAction `json:"action"`      // action associated with the rule
	Description string     `json:"description"` // not used for matching
	Enabled     bool       `json:"enabled"`     // disabled rules are skipped by MatchRule
}

// RuleManager holds the firewall rules in an ordered slice; MatchRule scans
// it front to back and returns the first match.
//
// NOTE(review): there is no internal locking — presumably callers serialise
// access; confirm before sharing a RuleManager across goroutines.
type RuleManager struct {
	rules []*Rule // evaluation order; AddRule appends to the end
}
// LoadFromFile reads firewall rules from the JSON file bRuleFile and appends
// them, in file order, to the manager.
//
// The file may contain either a single JSON array of rule objects or a stream
// of rule objects (concatenated or one per line). The rules are added only
// after the whole file has been decoded, so a malformed file leaves the
// manager unchanged. Errors from reading the file are returned as-is; decoding
// errors are wrapped (with %w) together with the file name.
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
	content, err := os.ReadFile(bRuleFile)
	if err != nil {
		return err
	}
	loaded, err := parseRulesJSON(string(content))
	if err != nil {
		return fmt.Errorf("loading rules from %q: %w", bRuleFile, err)
	}
	for _, rule := range loaded {
		rm.AddRule(rule)
	}
	return nil
}

// parseRulesJSON decodes data as either a JSON array of rules or a stream of
// rule objects. null entries inside an array are dropped.
func parseRulesJSON(data string) ([]*Rule, error) {
	// JSON whitespace is exactly space, tab, CR and LF.
	data = strings.TrimLeft(data, " \t\r\n")

	if strings.HasPrefix(data, "[") {
		var list []*Rule
		if err := json.Unmarshal([]byte(data), &list); err != nil {
			return nil, err
		}
		rules := make([]*Rule, 0, len(list))
		for _, r := range list {
			if r != nil {
				rules = append(rules, r)
			}
		}
		return rules, nil
	}

	var rules []*Rule
	dec := json.NewDecoder(strings.NewReader(data))
	for dec.More() {
		rule := new(Rule)
		if err := dec.Decode(rule); err != nil {
			return nil, err
		}
		rules = append(rules, rule)
	}
	// At top level More() returns false not only at EOF but also when it sees
	// a stray '}' or ']'; any non-whitespace left over means the input was
	// malformed, not fully consumed.
	if rest := strings.TrimSpace(data[dec.InputOffset():]); rest != "" {
		return nil, fmt.Errorf("unexpected character %q after %d rule(s)", rest[0], len(rules))
	}
	return rules, nil
}
// NewRuleManager returns a RuleManager holding no rules. Its rule slice is
// empty but non-nil.
func NewRuleManager() *RuleManager {
	rm := new(RuleManager)
	rm.rules = []*Rule{}
	return rm
}
// AddRule appends rule to the end of the rule list, so MatchRule evaluates it
// after every rule already present.
//
// A nil rule is ignored. Storing it would make MatchRule and GetRule
// dereference a nil pointer the next time they scan the list.
func (rm *RuleManager) AddRule(rule *Rule) {
	if rule == nil {
		return
	}
	rm.rules = append(rm.rules, rule)
}
// RemoveRule deletes the first rule whose ID equals ruleID and reports whether
// such a rule was found. Any later rules with the same ID are left in place,
// and the relative order of the remaining rules is preserved.
func (rm *RuleManager) RemoveRule(ruleID string) bool {
	for i, rule := range rm.rules {
		if rule.ID != ruleID {
			continue
		}
		// Build a fresh slice instead of shifting elements in place. Slices
		// handed out earlier (ListRules has returned rm.rules directly) share
		// the old backing array; an in-place shift would make a caller ranging
		// over such a slice skip one rule and see another twice. Dropping the
		// old array also releases the removed rule.
		remaining := make([]*Rule, 0, len(rm.rules)-1)
		remaining = append(remaining, rm.rules[:i]...)
		remaining = append(remaining, rm.rules[i+1:]...)
		rm.rules = remaining
		return true
	}
	return false
}
// GetRule returns the first rule whose ID equals ruleID, or nil if there is
// none. The returned pointer is the manager's own rule, not a copy.
func (rm *RuleManager) GetRule(ruleID string) *Rule {
	for i := range rm.rules {
		if r := rm.rules[i]; r.ID == ruleID {
			return r
		}
	}
	return nil
}
// ListRules returns the rules in evaluation order.
//
// The returned slice is a copy. Callers may reorder it, append to it or
// truncate it without corrupting the manager's list, and later AddRule or
// RemoveRule calls do not change it. The *Rule values are shared: modifying a
// returned rule modifies the rule the manager evaluates. A manager whose list
// is nil returns nil.
func (rm *RuleManager) ListRules() []*Rule {
	if rm.rules == nil {
		return nil
	}
	out := make([]*Rule, len(rm.rules))
	copy(out, rm.rules)
	return out
}
// MatchRule returns the first enabled rule, in list order, that accepts the
// given packet attributes, or nil when no rule does.
//
// A rule accepts the packet when:
//   - its Protocol is ProtocolAll or equals protocol;
//   - SrcIP/DstIP are empty or satisfied per matchIP (on the IPs' String() form);
//   - SrcPort/DstPort are empty or satisfied per matchPort.
func (rm *RuleManager) MatchRule(srcIP, dstIP net.IP, srcPort, dstPort int, protocol Protocol) *Rule {
	src, dst := srcIP.String(), dstIP.String()
	for _, r := range rm.rules {
		// Each case below is a reason to reject r; a rejected rule falls out of
		// the switch and the loop moves on to the next rule.
		switch {
		case !r.Enabled:
		case r.Protocol != ProtocolAll && r.Protocol != protocol:
		case r.SrcIP != "" && !matchIP(src, r.SrcIP):
		case r.DstIP != "" && !matchIP(dst, r.DstIP):
		case r.SrcPort != "" && !matchPort(srcPort, r.SrcPort):
		case r.DstPort != "" && !matchPort(dstPort, r.DstPort):
		default:
			// Every check passed: the first matching rule wins.
			return r
		}
	}
	return nil
}
// matchIP reports whether the textual address ip satisfies pattern.
//
// pattern may be:
//   - "*", which matches any address;
//   - a CIDR block such as "192.168.1.0/24" or "2001:db8::/32", which matches
//     every address inside the block;
//   - a single IP address, compared by value so that equivalent spellings
//     (e.g. "2001:DB8::1" and "2001:db8:0::1") match each other.
//
// When ip or pattern does not parse as an address (or pattern is not a valid
// CIDR block), the two strings are compared exactly.
func matchIP(ip, pattern string) bool {
	if pattern == "*" {
		return true
	}
	addr := net.ParseIP(ip)
	if strings.Contains(pattern, "/") {
		if _, block, err := net.ParseCIDR(pattern); err == nil {
			return addr != nil && block.Contains(addr)
		}
		// Not a valid CIDR block: fall through to the literal comparison.
	}
	if want := net.ParseIP(pattern); addr != nil && want != nil {
		return addr.Equal(want)
	}
	return ip == pattern
}
// matchPort reports whether port satisfies pattern, which is one of:
//   - "*": any port;
//   - "lo-hi": an inclusive range such as "8080-8090" (lo must not exceed hi);
//   - a single decimal port such as "443".
//
// Any pattern that atoi cannot parse, or that contains more than one '-',
// matches nothing.
func matchPort(port int, pattern string) bool {
	if pattern == "*" {
		return true
	}

	dash := strings.IndexByte(pattern, '-')
	if dash < 0 {
		want, err := atoi(pattern)
		return err == nil && port == want
	}

	loText, hiText := pattern[:dash], pattern[dash+1:]
	if strings.IndexByte(hiText, '-') >= 0 {
		return false // more than two range parts, e.g. "1-2-3"
	}
	lo, errLo := atoi(loText)
	hi, errHi := atoi(hiText)
	if errLo != nil || errHi != nil || lo > hi {
		return false
	}
	return lo <= port && port <= hi
}
// atoi parses s as a non-negative base-10 integer made only of ASCII digits.
// Unlike strconv.Atoi it accepts no sign.
//
// It returns an error for an empty string, for any non-digit character, and
// for a value that does not fit in an int. Overflow is reported rather than
// silently wrapping, so an oversized number in a port pattern can never alias
// a real port.
func atoi(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("empty string")
	}
	const maxInt = int(^uint(0) >> 1)
	n := 0
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid digit: %c", c)
		}
		d := int(c - '0')
		// n*10 + d <= maxInt  <=>  n <= (maxInt-d)/10 for integer n.
		if n > (maxInt-d)/10 {
			return 0, fmt.Errorf("value out of range: %s", s)
		}
		n = n*10 + d
	}
	return n, nil
}