"feat(config): 添加规则目录配置并实现从文件加载规则功能"
This commit is contained in:
parent
fcbf669779
commit
f27002ddff
|
|
@ -13,6 +13,7 @@ type Config struct {
|
|||
ForwardEnabled bool `json:"forward_enabled"`
|
||||
MaxPacketSize int `json:"max_packet_size"`
|
||||
ConfigFile string `json:"config_file"`
|
||||
RuleDir string `json:"rule_dir"`
|
||||
}
|
||||
|
||||
// NewConfig 创建新的配置实例
|
||||
|
|
|
|||
21
forwarder.go
21
forwarder.go
|
|
@ -1,10 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
|
@ -37,6 +40,24 @@ func (f *Forwarder) Stop() {
|
|||
log.Println("Forwarding service stopped")
|
||||
}
|
||||
|
||||
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
||||
content, err := os.ReadFile(ruleFile)
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
}
|
||||
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||||
for decoder.More() {
|
||||
var rule ForwardRule
|
||||
err := decoder.Decode(&rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.AddForwardRule(rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardRule 定义转发规则结构
|
||||
type ForwardRule struct {
|
||||
SrcIP string // 源IP
|
||||
|
|
|
|||
23
main.go
23
main.go
|
|
@ -5,6 +5,7 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
|
|
@ -78,6 +79,28 @@ func (f *Firewall) loadRules() error {
|
|||
Enabled: true,
|
||||
}
|
||||
f.ruleManager.AddRule(loopbackRule)
|
||||
// load rules from rules dir
|
||||
rulesDir := f.config.RuleDir
|
||||
if !filepath.IsAbs(rulesDir) {
|
||||
rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir)
|
||||
}
|
||||
bRuleFile := filepath.Join(rulesDir, "rules.json")
|
||||
if _, err := os.Stat(bRuleFile); err == nil {
|
||||
if err := f.ruleManager.LoadFromFile(bRuleFile); err != nil {
|
||||
return fmt.Errorf("failed to load rules from file: %v", err)
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to load rules from file: %v", err)
|
||||
}
|
||||
fRuleFile := filepath.Join(rulesDir, "forward.json")
|
||||
|
||||
if _, err := os.Stat(fRuleFile); err == nil {
|
||||
if err := f.forwarder.LoadRulesFromFile(fRuleFile); err != nil {
|
||||
return fmt.Errorf("failed to load forward rules from file: %v", err)
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to load forward rules from file: %v", err)
|
||||
}
|
||||
|
||||
// 可以从配置文件或数据库加载更多规则
|
||||
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
|
||||
|
|
|
|||
21
rule.go
21
rule.go
|
|
@ -1,8 +1,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
|
@ -45,6 +47,25 @@ type RuleManager struct {
|
|||
rules []*Rule
|
||||
}
|
||||
|
||||
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
|
||||
// load rules from file
|
||||
content, err := os.ReadFile(bRuleFile)
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
}
|
||||
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||||
for decoder.More() {
|
||||
var rule Rule
|
||||
err := decoder.Decode(&rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rm.AddRule(&rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewRuleManager 创建新的规则管理器
|
||||
func NewRuleManager() *RuleManager {
|
||||
return &RuleManager{
|
||||
|
|
|
|||
Loading…
Reference in New Issue