// gofirewall/main.go
//
// 145 lines
// 3.2 KiB
// Go

package main
import (
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"syscall"
)
// Firewall is the main firewall service. It owns the rule manager, logger,
// configuration, forwarder and, once started, the packet capture.
type Firewall struct {
// ruleManager holds the firewall rules; it is handed to the packet capture.
ruleManager *RuleManager
logger *Logger
config *Config
// forwarder is started by Start and is also handed to the packet capture.
forwarder *Forwarder
// capture is nil until startPacketCapture creates it.
capture *PacketCapture
}
// NewFirewall returns a Firewall wired with a fresh logger, rule manager,
// configuration and forwarder. The packet capture is not created here; Start
// sets it up later.
func NewFirewall() *Firewall {
	fw := new(Firewall)
	fw.logger = NewLogger()
	fw.ruleManager = NewRuleManager()
	fw.config = NewConfig()
	fw.forwarder = NewForwarder()
	return fw
}
// Start brings the firewall service up. It does four things in order:
//  1. loads the configuration,
//  2. installs the firewall and forwarding rules,
//  3. starts packet capture (which also does the filtering),
//  4. starts the forwarder.
//
// The first failing step's error is returned. If the forwarder fails to start,
// the packet capture started in step 3 is stopped again. This keeps Start from
// leaving a half-running service behind.
func (f *Firewall) Start() error {
	log.Println("Starting firewall service...")

	if err := f.config.Load(); err != nil {
		return fmt.Errorf("failed to load config: %w", err)
	}

	if err := f.loadRules(); err != nil {
		return err
	}

	if err := f.startPacketCapture(); err != nil {
		return err
	}

	if err := f.forwarder.Start(); err != nil {
		// startPacketCapture succeeded, so f.capture is non-nil and running.
		// Clear it after stopping so a later Stop does not stop it twice.
		f.capture.Stop()
		f.capture = nil
		return fmt.Errorf("failed to start forwarder: %w", err)
	}

	return nil
}
// Stop shuts the firewall service down. It stops the packet capture first (if
// it was ever created), then the forwarder, and closes the logger last.
func (f *Firewall) Stop() {
	log.Println("Stopping firewall service...")

	if c := f.capture; c != nil {
		c.Stop()
	}

	f.forwarder.Stop()
	f.logger.Close()
}
// loadRules installs the built-in loopback rule, then loads two optional rule
// files from the configured rules directory:
//   - rules.json   - firewall rules, via RuleManager.LoadFromFile
//   - forward.json - forwarding rules, via Forwarder.LoadRulesFromFile
//
// A relative RuleDir is resolved against the directory of the config file.
// A missing rule file is skipped. Any other stat or load error aborts loading
// and is returned wrapped, so callers can still inspect it with errors.Is/As.
func (f *Firewall) loadRules() error {
	// Built-in rule: allow all traffic from 127.0.0.1 to 127.0.0.1.
	// NOTE(review): only the single address 127.0.0.1 is listed; confirm whether
	// the rest of 127.0.0.0/8 (or ::1) should be covered as well.
	loopbackRule := &Rule{
		ID:          "rule-1",
		Name:        "Allow Loopback",
		Protocol:    ProtocolAll,
		SrcIP:       "127.0.0.1",
		DstIP:       "127.0.0.1",
		Action:      ActionAllow,
		Description: "Allow all loopback traffic",
		Enabled:     true,
	}
	f.ruleManager.AddRule(loopbackRule)

	rulesDir := f.config.RuleDir
	if !filepath.IsAbs(rulesDir) {
		rulesDir = filepath.Join(filepath.Dir(f.config.ConfigFile), rulesDir)
	}

	if err := loadOptionalRuleFile(filepath.Join(rulesDir, "rules.json"), f.ruleManager.LoadFromFile); err != nil {
		return fmt.Errorf("failed to load rules from file: %w", err)
	}
	if err := loadOptionalRuleFile(filepath.Join(rulesDir, "forward.json"), f.forwarder.LoadRulesFromFile); err != nil {
		return fmt.Errorf("failed to load forward rules from file: %w", err)
	}

	// More rules could be loaded from a config file or a database here.
	f.logger.Info("Loaded ", len(f.ruleManager.ListRules()), " firewall rules")
	return nil
}

// loadOptionalRuleFile calls load(path) only if path exists.
// A missing file is not an error and returns nil. Other os.Stat failures are
// returned unchanged; os.PathError already carries the path. A load failure is
// returned wrapped with the path.
func loadOptionalRuleFile(path string, load func(string) error) error {
	if _, err := os.Stat(path); err != nil {
		if os.IsNotExist(err) {
			return nil
		}
		return err
	}
	if err := load(path); err != nil {
		return fmt.Errorf("%s: %w", path, err)
	}
	return nil
}
// startPacketCapture creates a PacketCapture on the configured capture
// interface, stores it in f.capture, and starts it. The capture is wired to
// the rule manager, logger and forwarder.
//
// It returns an error if no capture interface is configured. It also returns
// an error, wrapped with %w, if the capture fails to start.
func (f *Firewall) startPacketCapture() error {
	iface := f.config.CaptureInterface
	if iface == "" {
		return fmt.Errorf("capture interface not configured")
	}

	f.capture = NewPacketCapture(iface, f.ruleManager, f.logger, f.forwarder)
	if err := f.capture.Start(); err != nil {
		return fmt.Errorf("failed to start packet capture on %q: %w", iface, err)
	}
	return nil
}
// main starts the firewall and runs until SIGINT or SIGTERM arrives. It then
// shuts the firewall down.
func main() {
	// Subscribe to shutdown signals before starting. A signal that arrives
	// during startup is then queued and handled gracefully. Otherwise the
	// default handler would kill the process mid-setup.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	firewall := NewFirewall()
	if err := firewall.Start(); err != nil {
		log.Fatalf("Failed to start firewall: %v", err)
	}

	sig := <-sigChan
	// Restore default signal handling, so a second Ctrl-C terminates the
	// process if shutdown hangs.
	signal.Stop(sigChan)
	log.Printf("Received signal %v, shutting down", sig)

	// Stop explicitly rather than via defer. That way the success message is
	// logged only after shutdown has actually finished.
	firewall.Stop()
	log.Println("Firewall stopped successfully")
}