145 lines
3.2 KiB
Go
145 lines
3.2 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"syscall"
|
|
)
|
|
|
|
// Firewall 主防火墙结构体
|
|
type Firewall struct {
|
|
ruleManager *RuleManager
|
|
logger *Logger
|
|
config *Config
|
|
forwarder *Forwarder
|
|
capture *PacketCapture
|
|
}
|
|
|
|
// NewFirewall 创建新的防火墙实例
|
|
func NewFirewall() *Firewall {
|
|
logger := NewLogger()
|
|
return &Firewall{
|
|
ruleManager: NewRuleManager(),
|
|
logger: logger,
|
|
config: NewConfig(),
|
|
forwarder: NewForwarder(),
|
|
}
|
|
}
|
|
|
|
// Start 启动防火墙服务
|
|
func (f *Firewall) Start() error {
|
|
log.Println("Starting firewall service...")
|
|
// 加载配置
|
|
if err := f.config.Load(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 加载规则
|
|
if err := f.loadRules(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 启动流量捕获和过滤
|
|
if err := f.startPacketCapture(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 启动转发服务
|
|
if err := f.forwarder.Start(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop 停止防火墙服务
|
|
func (f *Firewall) Stop() {
|
|
log.Println("Stopping firewall service...")
|
|
if f.capture != nil {
|
|
f.capture.Stop()
|
|
}
|
|
f.forwarder.Stop()
|
|
f.logger.Close()
|
|
}
|
|
|
|
// 加载防火墙规则
|
|
func (f *Firewall) loadRules() error {
|
|
// 示例规则:允许本地回环地址的所有流量
|
|
loopbackRule := &Rule{
|
|
ID: "rule-1",
|
|
Name: "Allow Loopback",
|
|
Protocol: ProtocolAll,
|
|
SrcIP: "127.0.0.1",
|
|
DstIP: "127.0.0.1",
|
|
Action: ActionAllow,
|
|
Description: "Allow all loopback traffic",
|
|
Enabled: true,
|
|
}
|
|
f.ruleManager.AddRule(loopbackRule)
|
|
// load rules from rules dir
|
|
rulesDir := f.config.RuleDir
|
|
if !filepath.IsAbs(rulesDir) {
|
|
rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir)
|
|
}
|
|
bRuleFile := filepath.Join(rulesDir, "rules.json")
|
|
if _, err := os.Stat(bRuleFile); err == nil {
|
|
if err := f.ruleManager.LoadFromFile(bRuleFile); err != nil {
|
|
return fmt.Errorf("failed to load rules from file: %v", err)
|
|
}
|
|
} else if !os.IsNotExist(err) {
|
|
return fmt.Errorf("failed to load rules from file: %v", err)
|
|
}
|
|
fRuleFile := filepath.Join(rulesDir, "forward.json")
|
|
|
|
if _, err := os.Stat(fRuleFile); err == nil {
|
|
if err := f.forwarder.LoadRulesFromFile(fRuleFile); err != nil {
|
|
return fmt.Errorf("failed to load forward rules from file: %v", err)
|
|
}
|
|
} else if !os.IsNotExist(err) {
|
|
return fmt.Errorf("failed to load forward rules from file: %v", err)
|
|
}
|
|
|
|
// 可以从配置文件或数据库加载更多规则
|
|
f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
|
|
return nil
|
|
}
|
|
|
|
// 启动数据包捕获和过滤
|
|
func (f *Firewall) startPacketCapture() error {
|
|
if f.config.CaptureInterface == "" {
|
|
return fmt.Errorf("capture interface not configured")
|
|
}
|
|
|
|
f.capture = NewPacketCapture(
|
|
f.config.CaptureInterface,
|
|
f.ruleManager,
|
|
f.logger,
|
|
f.forwarder,
|
|
)
|
|
|
|
if err := f.capture.Start(); err != nil {
|
|
return fmt.Errorf("failed to start packet capture: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
firewall := NewFirewall()
|
|
|
|
if err := firewall.Start(); err != nil {
|
|
log.Fatalf("Failed to start firewall: %v", err)
|
|
}
|
|
defer firewall.Stop()
|
|
|
|
// 等待中断信号
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sigChan
|
|
|
|
log.Println("Firewall stopped successfully")
|
|
}
|