196 lines
3.9 KiB
Go
196 lines
3.9 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"net"
|
||
"os"
|
||
"strings"
|
||
)
|
||
|
||
// RuleAction is the verdict carried by a Rule. MatchRule only returns the
// matching rule; acting on its Action is left to the caller.
type RuleAction string

// Rule actions, serialized as the JSON strings shown.
const (
	// ActionAllow is the "allow" action.
	ActionAllow = RuleAction("allow")
	// ActionDeny is the "deny" action.
	ActionDeny = RuleAction("deny")
)
|
||
|
||
// Protocol names the network protocol a Rule applies to.
type Protocol string

// Protocols a Rule may specify, serialized as the JSON strings shown.
const (
	ProtocolTCP  = Protocol("tcp")
	ProtocolUDP  = Protocol("udp")
	ProtocolICMP = Protocol("icmp")
	// ProtocolAll, used as a rule's Protocol, makes MatchRule accept any
	// requested protocol.
	ProtocolAll = Protocol("all")
)
|
||
|
||
// Rule 定义防火墙规则结构
|
||
type Rule struct {
|
||
ID string `json:"id"`
|
||
Name string `json:"name"`
|
||
Protocol Protocol `json:"protocol"`
|
||
SrcIP string `json:"src_ip"`
|
||
SrcPort string `json:"src_port"`
|
||
DstIP string `json:"dst_ip"`
|
||
DstPort string `json:"dst_port"`
|
||
Action RuleAction `json:"action"`
|
||
Description string `json:"description"`
|
||
Enabled bool `json:"enabled"`
|
||
}
|
||
|
||
// RuleManager 规则管理器
|
||
type RuleManager struct {
|
||
rules []*Rule
|
||
}
|
||
|
||
// LoadFromFile reads firewall rules from the JSON file bRuleFile and appends
// them to the manager in file order.
//
// The file may hold either a JSON array of rule objects, or a stream of rule
// objects written one after another (e.g. one per line). Rules are appended
// only after the whole file has decoded successfully, so a malformed file
// leaves the manager unchanged. Returned errors wrap the underlying I/O or
// JSON error and can be inspected with errors.Is / errors.As.
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
	content, err := os.ReadFile(bRuleFile)
	if err != nil {
		// os errors already name the path; just add what we were doing.
		return fmt.Errorf("loading rules: %w", err)
	}

	var loaded []Rule
	text := string(content)
	if strings.HasPrefix(strings.TrimSpace(text), "[") {
		// A single JSON array of rules.
		if err := json.Unmarshal(content, &loaded); err != nil {
			return fmt.Errorf("decoding rules from %q: %w", bRuleFile, err)
		}
	} else {
		// A stream of top-level rule objects.
		decoder := json.NewDecoder(strings.NewReader(text))
		for decoder.More() {
			var rule Rule
			if err := decoder.Decode(&rule); err != nil {
				return fmt.Errorf("decoding rule #%d from %q: %w", len(loaded)+1, bRuleFile, err)
			}
			loaded = append(loaded, rule)
		}
	}

	// Commit only after everything decoded, so a bad file is all-or-nothing.
	for i := range loaded {
		rm.AddRule(&loaded[i])
	}
	return nil
}
|
||
|
||
// NewRuleManager 创建新的规则管理器
|
||
func NewRuleManager() *RuleManager {
|
||
return &RuleManager{
|
||
rules: make([]*Rule, 0),
|
||
}
|
||
}
|
||
|
||
// AddRule 添加规则
|
||
func (rm *RuleManager) AddRule(rule *Rule) {
|
||
rm.rules = append(rm.rules, rule)
|
||
}
|
||
|
||
// RemoveRule 移除规则
|
||
func (rm *RuleManager) RemoveRule(ruleID string) bool {
|
||
for i, rule := range rm.rules {
|
||
if rule.ID == ruleID {
|
||
// 从切片中删除元素
|
||
rm.rules = append(rm.rules[:i], rm.rules[i+1:]...)
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// GetRule 获取规则
|
||
func (rm *RuleManager) GetRule(ruleID string) *Rule {
|
||
for _, rule := range rm.rules {
|
||
if rule.ID == ruleID {
|
||
return rule
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ListRules 列出所有规则
|
||
func (rm *RuleManager) ListRules() []*Rule {
|
||
return rm.rules
|
||
}
|
||
|
||
// MatchRule returns the first enabled rule, in list order, whose criteria all
// accept the given packet attributes. It returns nil when no rule matches.
//
// An empty criterion field on a rule accepts anything. Otherwise:
//   - Protocol must be ProtocolAll or equal protocol;
//   - SrcIP/DstIP are tested with matchIP against srcIP.String()/dstIP.String();
//   - SrcPort/DstPort are tested with matchPort.
//
// An empty Protocol is treated as a wildcard, like every other empty field.
// Previously such a rule (e.g. loaded from JSON without a "protocol" key)
// could only match a query whose protocol was also "".
func (rm *RuleManager) MatchRule(srcIP, dstIP net.IP, srcPort, dstPort int, protocol Protocol) *Rule {
	// Format the addresses once instead of once per rule.
	src, dst := srcIP.String(), dstIP.String()

	for _, rule := range rm.rules {
		if !rule.Enabled {
			continue
		}
		if rule.Protocol != "" && rule.Protocol != ProtocolAll && rule.Protocol != protocol {
			continue
		}
		if rule.SrcIP != "" && !matchIP(src, rule.SrcIP) {
			continue
		}
		if rule.DstIP != "" && !matchIP(dst, rule.DstIP) {
			continue
		}
		if rule.SrcPort != "" && !matchPort(srcPort, rule.SrcPort) {
			continue
		}
		if rule.DstPort != "" && !matchPort(dstPort, rule.DstPort) {
			continue
		}
		// Every criterion accepted: first match wins.
		return rule
	}
	return nil
}
|
||
|
||
// matchIP reports whether the textual address ip satisfies pattern.
//
// Supported pattern forms:
//   - "*" matches any address;
//   - a CIDR prefix such as "10.0.0.0/8" or "2001:db8::/32" matches addresses
//     inside it;
//   - a literal address is compared by IP value, so equivalent spellings
//     (e.g. "2001:db8::1" and "2001:0db8:0:0::0001") match;
//   - any pattern whose text equals ip exactly matches.
//
// Anything unparsable matches nothing.
func matchIP(ip, pattern string) bool {
	if pattern == "*" || ip == pattern {
		return true
	}

	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}

	if strings.Contains(pattern, "/") {
		_, network, err := net.ParseCIDR(pattern)
		return err == nil && network.Contains(addr)
	}

	want := net.ParseIP(pattern)
	return want != nil && want.Equal(addr)
}
|
||
|
||
// matchPort reports whether port satisfies pattern.
//
// Supported pattern forms (numbers are parsed with atoi, so only ASCII
// digits are accepted):
//   - "*" matches any port;
//   - "N" matches exactly port N;
//   - "LO-HI" matches LO <= port <= HI; an inverted or malformed range
//     matches nothing;
//   - a comma-separated list of the above, e.g. "22,80,8000-8100", matches
//     if any element does.
//
// Surrounding whitespace on each element is ignored.
func matchPort(port int, pattern string) bool {
	// strings.Split returns []string{pattern} when there is no comma.
	for _, item := range strings.Split(pattern, ",") {
		if matchSinglePort(port, strings.TrimSpace(item)) {
			return true
		}
	}
	return false
}

// matchSinglePort checks one "*", "N" or "LO-HI" element of a port pattern.
func matchSinglePort(port int, pattern string) bool {
	if pattern == "*" {
		return true
	}

	if strings.Contains(pattern, "-") {
		bounds := strings.Split(pattern, "-")
		if len(bounds) != 2 {
			return false
		}
		lo, errLo := atoi(strings.TrimSpace(bounds[0]))
		hi, errHi := atoi(strings.TrimSpace(bounds[1]))
		if errLo != nil || errHi != nil || lo > hi {
			return false
		}
		return port >= lo && port <= hi
	}

	want, err := atoi(pattern)
	return err == nil && port == want
}
|
||
|
||
// atoi parses s as a non-negative base-10 integer made only of ASCII digits.
// Unlike strconv.Atoi, it rejects any sign character.
//
// It returns an error for an empty string, a non-digit character, or a value
// that does not fit in an int. Previously, oversized values wrapped around
// silently and could yield arbitrary (even negative) ports.
func atoi(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("empty string")
	}

	const maxInt = int(^uint(0) >> 1)
	var res int
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid digit: %c", c)
		}
		d := int(c - '0')
		// res*10 + d <= maxInt  <=>  res <= (maxInt-d)/10 for integer res.
		if res > (maxInt-d)/10 {
			return 0, fmt.Errorf("value out of range: %q", s)
		}
		res = res*10 + d
	}
	return res, nil
}
|