"feat(config): 添加规则目录配置并实现从文件加载规则功能"

This commit is contained in:
程广 2025-07-03 18:22:14 +08:00
parent fcbf669779
commit f27002ddff
4 changed files with 77 additions and 11 deletions

View File

@ -13,6 +13,7 @@ type Config struct {
ForwardEnabled bool `json:"forward_enabled"` ForwardEnabled bool `json:"forward_enabled"`
MaxPacketSize int `json:"max_packet_size"` MaxPacketSize int `json:"max_packet_size"`
ConfigFile string `json:"config_file"` ConfigFile string `json:"config_file"`
RuleDir string `json:"rule_dir"`
} }
// NewConfig 创建新的配置实例 // NewConfig 创建新的配置实例

View File

@ -1,10 +1,13 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"log" "log"
"net" "net"
"os"
"strconv" "strconv"
"strings"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@ -37,6 +40,24 @@ func (f *Forwarder) Stop() {
log.Println("Forwarding service stopped") log.Println("Forwarding service stopped")
} }
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
content, err := os.ReadFile(ruleFile)
if err != nil {
return err
}
decoder := json.NewDecoder(strings.NewReader(string(content)))
for decoder.More() {
var rule ForwardRule
err := decoder.Decode(&rule)
if err != nil {
return err
}
f.AddForwardRule(rule)
}
return nil
}
// ForwardRule 定义转发规则结构 // ForwardRule 定义转发规则结构
type ForwardRule struct { type ForwardRule struct {
SrcIP string // 源IP SrcIP string // 源IP

23
main.go
View File

@ -5,6 +5,7 @@ import (
"log" "log"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"syscall" "syscall"
) )
@ -78,6 +79,28 @@ func (f *Firewall) loadRules() error {
Enabled: true, Enabled: true,
} }
f.ruleManager.AddRule(loopbackRule) f.ruleManager.AddRule(loopbackRule)
// load rules from rules dir
rulesDir := f.config.RuleDir
if !filepath.IsAbs(rulesDir) {
rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir)
}
bRuleFile := filepath.Join(rulesDir, "rules.json")
if _, err := os.Stat(bRuleFile); err == nil {
if err := f.ruleManager.LoadFromFile(bRuleFile); err != nil {
return fmt.Errorf("failed to load rules from file: %v", err)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to load rules from file: %v", err)
}
fRuleFile := filepath.Join(rulesDir, "forward.json")
if _, err := os.Stat(fRuleFile); err == nil {
if err := f.forwarder.LoadRulesFromFile(fRuleFile); err != nil {
return fmt.Errorf("failed to load forward rules from file: %v", err)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to load forward rules from file: %v", err)
}
// 可以从配置文件或数据库加载更多规则 // 可以从配置文件或数据库加载更多规则
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules") f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")

21
rule.go
View File

@ -1,8 +1,10 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"net" "net"
"os"
"strings" "strings"
) )
@ -45,6 +47,25 @@ type RuleManager struct {
rules []*Rule rules []*Rule
} }
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
// load rules from file
content, err := os.ReadFile(bRuleFile)
if err != nil {
return err
}
decoder := json.NewDecoder(strings.NewReader(string(content)))
for decoder.More() {
var rule Rule
err := decoder.Decode(&rule)
if err != nil {
return err
}
rm.AddRule(&rule)
}
return nil
}
// NewRuleManager 创建新的规则管理器 // NewRuleManager 创建新的规则管理器
func NewRuleManager() *RuleManager { func NewRuleManager() *RuleManager {
return &RuleManager{ return &RuleManager{