"feat(config): 添加规则目录配置并实现从文件加载规则功能"
This commit is contained in:
parent
fcbf669779
commit
f27002ddff
19
config.go
19
config.go
|
|
@ -8,21 +8,22 @@ import (
|
||||||
|
|
||||||
// Config 防火墙配置结构
|
// Config 防火墙配置结构
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LogLevel LogLevel `json:"log_level"`
|
LogLevel LogLevel `json:"log_level"`
|
||||||
CaptureInterface string `json:"capture_interface"`
|
CaptureInterface string `json:"capture_interface"`
|
||||||
ForwardEnabled bool `json:"forward_enabled"`
|
ForwardEnabled bool `json:"forward_enabled"`
|
||||||
MaxPacketSize int `json:"max_packet_size"`
|
MaxPacketSize int `json:"max_packet_size"`
|
||||||
ConfigFile string `json:"config_file"`
|
ConfigFile string `json:"config_file"`
|
||||||
|
RuleDir string `json:"rule_dir"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfig 创建新的配置实例
|
// NewConfig 创建新的配置实例
|
||||||
func NewConfig() *Config {
|
func NewConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
LogLevel: LogLevelInfo,
|
LogLevel: LogLevelInfo,
|
||||||
CaptureInterface: "",
|
CaptureInterface: "",
|
||||||
ForwardEnabled: false,
|
ForwardEnabled: false,
|
||||||
MaxPacketSize: 65536,
|
MaxPacketSize: 65536,
|
||||||
ConfigFile: "firewall.json",
|
ConfigFile: "firewall.json",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
21
forwarder.go
21
forwarder.go
|
|
@ -1,10 +1,13 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
|
@ -37,6 +40,24 @@ func (f *Forwarder) Stop() {
|
||||||
log.Println("Forwarding service stopped")
|
log.Println("Forwarding service stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) LoadRulesFromFile(ruleFile string) error {
|
||||||
|
content, err := os.ReadFile(ruleFile)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||||||
|
for decoder.More() {
|
||||||
|
var rule ForwardRule
|
||||||
|
err := decoder.Decode(&rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.AddForwardRule(rule)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ForwardRule 定义转发规则结构
|
// ForwardRule 定义转发规则结构
|
||||||
type ForwardRule struct {
|
type ForwardRule struct {
|
||||||
SrcIP string // 源IP
|
SrcIP string // 源IP
|
||||||
|
|
|
||||||
23
main.go
23
main.go
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -78,6 +79,28 @@ func (f *Firewall) loadRules() error {
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
f.ruleManager.AddRule(loopbackRule)
|
f.ruleManager.AddRule(loopbackRule)
|
||||||
|
// load rules from rules dir
|
||||||
|
rulesDir := f.config.RuleDir
|
||||||
|
if !filepath.IsAbs(rulesDir) {
|
||||||
|
rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir)
|
||||||
|
}
|
||||||
|
bRuleFile := filepath.Join(rulesDir, "rules.json")
|
||||||
|
if _, err := os.Stat(bRuleFile); err == nil {
|
||||||
|
if err := f.ruleManager.LoadFromFile(bRuleFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to load rules from file: %v", err)
|
||||||
|
}
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to load rules from file: %v", err)
|
||||||
|
}
|
||||||
|
fRuleFile := filepath.Join(rulesDir, "forward.json")
|
||||||
|
|
||||||
|
if _, err := os.Stat(fRuleFile); err == nil {
|
||||||
|
if err := f.forwarder.LoadRulesFromFile(fRuleFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to load forward rules from file: %v", err)
|
||||||
|
}
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to load forward rules from file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 可以从配置文件或数据库加载更多规则
|
// 可以从配置文件或数据库加载更多规则
|
||||||
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
|
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
|
||||||
|
|
|
||||||
21
rule.go
21
rule.go
|
|
@ -1,8 +1,10 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -45,6 +47,25 @@ type RuleManager struct {
|
||||||
rules []*Rule
|
rules []*Rule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rm *RuleManager) LoadFromFile(bRuleFile string) error {
|
||||||
|
// load rules from file
|
||||||
|
content, err := os.ReadFile(bRuleFile)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decoder := json.NewDecoder(strings.NewReader(string(content)))
|
||||||
|
for decoder.More() {
|
||||||
|
var rule Rule
|
||||||
|
err := decoder.Decode(&rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rm.AddRule(&rule)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewRuleManager 创建新的规则管理器
|
// NewRuleManager 创建新的规则管理器
|
||||||
func NewRuleManager() *RuleManager {
|
func NewRuleManager() *RuleManager {
|
||||||
return &RuleManager{
|
return &RuleManager{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue