"feat(config): 添加规则目录配置并实现从文件加载规则功能"

This commit is contained in:
程广 2025-07-03 18:22:14 +08:00
parent fcbf669779
commit f27002ddff
4 changed files with 77 additions and 11 deletions

View File

@ -8,21 +8,22 @@ import (
// Config 防火墙配置结构
type Config struct {
LogLevel LogLevel `json:"log_level"`
CaptureInterface string `json:"capture_interface"`
ForwardEnabled bool `json:"forward_enabled"`
MaxPacketSize int `json:"max_packet_size"`
ConfigFile string `json:"config_file"`
LogLevel LogLevel `json:"log_level"`
CaptureInterface string `json:"capture_interface"`
ForwardEnabled bool `json:"forward_enabled"`
MaxPacketSize int `json:"max_packet_size"`
ConfigFile string `json:"config_file"`
RuleDir string `json:"rule_dir"`
}
// NewConfig 创建新的配置实例
func NewConfig() *Config {
return &Config{
LogLevel: LogLevelInfo,
LogLevel: LogLevelInfo,
CaptureInterface: "",
ForwardEnabled: false,
MaxPacketSize: 65536,
ConfigFile: "firewall.json",
ForwardEnabled: false,
MaxPacketSize: 65536,
ConfigFile: "firewall.json",
}
}
@ -73,4 +74,4 @@ func (c *Config) Update(newConfig *Config) {
c.CaptureInterface = newConfig.CaptureInterface
c.ForwardEnabled = newConfig.ForwardEnabled
c.MaxPacketSize = newConfig.MaxPacketSize
}
}

View File

@ -1,10 +1,13 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@ -37,6 +40,24 @@ func (f *Forwarder) Stop() {
log.Println("Forwarding service stopped")
}
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
content, err := os.ReadFile(ruleFile)
if err != nil {
return err
}
decoder := json.NewDecoder(strings.NewReader(string(content)))
for decoder.More() {
var rule ForwardRule
err := decoder.Decode(&rule)
if err != nil {
return err
}
f.AddForwardRule(rule)
}
return nil
}
// ForwardRule 定义转发规则结构
type ForwardRule struct {
SrcIP string // 源IP

25
main.go
View File

@ -5,6 +5,7 @@ import (
"log"
"os"
"os/signal"
"path/filepath"
"syscall"
)
@ -78,6 +79,28 @@ func (f *Firewall) loadRules() error {
Enabled: true,
}
f.ruleManager.AddRule(loopbackRule)
// load rules from rules dir
rulesDir := f.config.RuleDir
if !filepath.IsAbs(rulesDir) {
rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir)
}
bRuleFile := filepath.Join(rulesDir, "rules.json")
if _, err := os.Stat(bRuleFile); err == nil {
if err := f.ruleManager.LoadFromFile(bRuleFile); err != nil {
return fmt.Errorf("failed to load rules from file: %v", err)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to load rules from file: %v", err)
}
fRuleFile := filepath.Join(rulesDir, "forward.json")
if _, err := os.Stat(fRuleFile); err == nil {
if err := f.forwarder.LoadRulesFromFile(fRuleFile); err != nil {
return fmt.Errorf("failed to load forward rules from file: %v", err)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to load forward rules from file: %v", err)
}
// 可以从配置文件或数据库加载更多规则
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
@ -118,4 +141,4 @@ func main() {
<-sigChan
log.Println("Firewall stopped successfully")
}
}

21
rule.go
View File

@ -1,8 +1,10 @@
package main
import (
"encoding/json"
"fmt"
"net"
"os"
"strings"
)
@ -45,6 +47,25 @@ type RuleManager struct {
rules []*Rule
}
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
// load rules from file
content, err := os.ReadFile(bRuleFile)
if err != nil {
return err
}
decoder := json.NewDecoder(strings.NewReader(string(content)))
for decoder.More() {
var rule Rule
err := decoder.Decode(&rule)
if err != nil {
return err
}
rm.AddRule(&rule)
}
return nil
}
// NewRuleManager 创建新的规则管理器
func NewRuleManager() *RuleManager {
return &RuleManager{