diff --git a/config.go b/config.go index 33e1631..2a4faf3 100644 --- a/config.go +++ b/config.go @@ -8,21 +8,22 @@ import ( // Config 防火墙配置结构 type Config struct { - LogLevel LogLevel `json:"log_level"` - CaptureInterface string `json:"capture_interface"` - ForwardEnabled bool `json:"forward_enabled"` - MaxPacketSize int `json:"max_packet_size"` - ConfigFile string `json:"config_file"` + LogLevel LogLevel `json:"log_level"` + CaptureInterface string `json:"capture_interface"` + ForwardEnabled bool `json:"forward_enabled"` + MaxPacketSize int `json:"max_packet_size"` + ConfigFile string `json:"config_file"` + RuleDir string `json:"rule_dir"` } // NewConfig 创建新的配置实例 func NewConfig() *Config { return &Config{ - LogLevel: LogLevelInfo, + LogLevel: LogLevelInfo, CaptureInterface: "", - ForwardEnabled: false, - MaxPacketSize: 65536, - ConfigFile: "firewall.json", + ForwardEnabled: false, + MaxPacketSize: 65536, + ConfigFile: "firewall.json", } } @@ -73,4 +74,4 @@ func (c *Config) Update(newConfig *Config) { c.CaptureInterface = newConfig.CaptureInterface c.ForwardEnabled = newConfig.ForwardEnabled c.MaxPacketSize = newConfig.MaxPacketSize -} \ No newline at end of file +} diff --git a/forwarder.go b/forwarder.go index 4cc1022..a9f6705 100644 --- a/forwarder.go +++ b/forwarder.go @@ -1,10 +1,13 @@ package main import ( + "encoding/json" "fmt" "log" "net" + "os" "strconv" + "strings" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -37,6 +40,24 @@ func (f *Forwarder) Stop() { log.Println("Forwarding service stopped") } +func (f *Forwarder) LoadRulesFromFile(ruleFile string) error { + content, err := os.ReadFile(ruleFile) + if err != nil { + + return err + } + decoder := json.NewDecoder(strings.NewReader(string(content))) + for decoder.More() { + var rule ForwardRule + err := decoder.Decode(&rule) + if err != nil { + return err + } + f.AddForwardRule(rule) + } + return nil +} + // ForwardRule 定义转发规则结构 type ForwardRule struct { SrcIP string // 源IP diff --git a/main.go b/main.go index f394bcc..21ad29c 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "log" "os" "os/signal" + "path/filepath" "syscall" ) @@ -78,6 +79,28 @@ func (f *Firewall) loadRules() error { Enabled: true, } f.ruleManager.AddRule(loopbackRule) + // load rules from rules dir + rulesDir := f.config.RuleDir + if !filepath.IsAbs(rulesDir) { + rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir) + } + bRuleFile := filepath.Join(rulesDir, "rules.json") + if _, err := os.Stat(bRuleFile); err == nil { + if err := f.ruleManager.LoadFromFile(bRuleFile); err != nil { + return fmt.Errorf("failed to load rules from file: %v", err) + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to load rules from file: %v", err) + } + fRuleFile := filepath.Join(rulesDir, "forward.json") + + if _, err := os.Stat(fRuleFile); err == nil { + if err := f.forwarder.LoadRulesFromFile(fRuleFile); err != nil { + return fmt.Errorf("failed to load forward rules from file: %v", err) + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to load forward rules from file: %v", err) + } // 可以从配置文件或数据库加载更多规则 f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules") @@ -118,4 +141,4 @@ func main() { <-sigChan log.Println("Firewall stopped successfully") -} \ No newline at end of file +} diff --git a/rule.go b/rule.go index 2fc8871..cdaf06c 100644 --- a/rule.go +++ b/rule.go @@ -1,8 +1,10 @@ package main import ( + "encoding/json" "fmt" "net" + "os" "strings" ) @@ -45,6 +47,25 @@ type RuleManager struct { rules []*Rule } +func (rm *RuleManager) LoadFromFile(bRuleFile string) error { + // load rules from file + content, err := os.ReadFile(bRuleFile) + if err != nil { + + return err + } + decoder := json.NewDecoder(strings.NewReader(string(content))) + for decoder.More() { + var rule Rule + err := decoder.Decode(&rule) + if err != nil { + return err + } + rm.AddRule(&rule) + } + return nil +} + // NewRuleManager 创建新的规则管理器 func NewRuleManager() *RuleManager { return &RuleManager{