diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b7f18ff --- /dev/null +++ b/Makefile @@ -0,0 +1,117 @@ +.PHONY: all build test clean fmt lint build-all + +# Go parameters +GOCMD=go +GOBUILD=$(GOCMD) build +GOCLEAN=$(GOCMD) clean +GOTEST=$(GOCMD) test +GOGET=$(GOCMD) get +GOMOD=$(GOCMD) mod +GOFMT=$(GOCMD) fmt +GOLINT=golangci-lint + +# Binary name +BINARY_NAME=gotidb +BINARY_UNIX=$(BINARY_NAME)_unix +BINARY_WIN=$(BINARY_NAME).exe +BINARY_MAC=$(BINARY_NAME)_mac + +# Build directory +BUILD_DIR=build + +# Main package path +MAIN_PACKAGE=./cmd/server + +# Get the current git commit hash +COMMIT=$(shell git rev-parse --short HEAD) +BUILD_TIME=$(shell date +%FT%T%z) + +# Build flags +LDFLAGS=-ldflags "-X main.commit=${COMMIT} -X main.buildTime=${BUILD_TIME}" + +# Default target +all: test build + +# Build the project +build: + mkdir -p $(BUILD_DIR) + $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PACKAGE) + +# Build for all platforms +build-all: build-linux build-windows build-mac + +build-linux: + mkdir -p $(BUILD_DIR) + GOOS=linux GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_UNIX) $(MAIN_PACKAGE) + +build-windows: + mkdir -p $(BUILD_DIR) + GOOS=windows GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_WIN) $(MAIN_PACKAGE) + +build-mac: + mkdir -p $(BUILD_DIR) + GOOS=darwin GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_MAC) $(MAIN_PACKAGE) + +# Run tests +test: + $(GOTEST) -v ./... + +# Run tests with coverage +test-coverage: + $(GOTEST) -v -coverprofile=coverage.out ./... + $(GOCMD) tool cover -html=coverage.out -o coverage.html + +# Clean build artifacts +clean: + $(GOCLEAN) + rm -rf $(BUILD_DIR) + rm -f coverage.out coverage.html + +# Format code +fmt: + $(GOFMT) ./... + +# Run linter +lint: + $(GOLINT) run + +# Download dependencies +deps: + $(GOMOD) download + +# Verify dependencies +verify: + $(GOMOD) verify + +# Update dependencies +update-deps: + $(GOMOD) tidy + +# Install development tools +install-tools: + $(GOGET) -u github.com/golangci/golangci-lint/cmd/golangci-lint + +# Run the application +run: + $(GOBUILD) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PACKAGE) + ./$(BUILD_DIR)/$(BINARY_NAME) + +# Help target +help: + @echo "Available targets:" + @echo " all : Run tests and build" + @echo " build : Build for current platform" + @echo " build-all : Build for all platforms" + @echo " test : Run tests" + @echo " test-coverage: Run tests with coverage" + @echo " clean : Clean build artifacts" + @echo " fmt : Format code" + @echo " lint : Run linter" + @echo " deps : Download dependencies" + @echo " verify : Verify dependencies" + @echo " update-deps : Update dependencies" + @echo " install-tools: Install development tools" + @echo " run : Run the application" + +# Default to help if no target is specified +.DEFAULT_GOAL := help \ No newline at end of file diff --git a/cmd/server/config.go b/cmd/server/config.go new file mode 100644 index 0000000..c39c8fe --- /dev/null +++ b/cmd/server/config.go @@ -0,0 +1,34 @@ +package main + +import ( + "os" + + "gopkg.in/yaml.v3" +) + +// Config 应用程序配置结构 +type Config struct { + RestAddr string `yaml:"rest_addr"` + WsAddr string `yaml:"ws_addr"` + MetricsAddr string `yaml:"metrics_addr"` + NATSURL string `yaml:"nats_url"` + PersistenceType string `yaml:"persistence_type"` + PersistenceDir string `yaml:"persistence_dir"` + SyncEvery int `yaml:"sync_every"` +} + +func LoadConfig(path string) (*Config, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + var config Config + decoder := yaml.NewDecoder(file) + if err := decoder.Decode(&config); err != nil { + return nil, err + } + + return &config, nil +} diff --git a/docs/design/task.md b/docs/design/task.md new file mode 100644 index 0000000..a61a37c --- /dev/null +++ b/docs/design/task.md @@ -0,0 +1,36 @@ +0. 构建工具 + +添加构建脚本,要求: +添加Makefile +多平台构建 +构建时,如果有单元测试,先执行单元测试 + +1. 测试用例编写 + +为各个组件编写单元测试 +添加集成测试 +进行性能测试和基准测试 + +2. 功能增强 + +实现数据压缩 +添加更多查询类型 +实现数据备份和恢复 +添加访问控制和认证 + +3. 部署相关 + +添加Docker支持 +创建Kubernetes部署配置 +编写运维文档 +4. 性能优化 + +优化内存使用 +实现数据分片 +添加缓存层 + +5. 监控和告警 + +完善监控指标 +添加告警规则 +实现日志聚合 \ No newline at end of file diff --git a/go.mod b/go.mod index f71648f..aa6f2ea 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/nats-io/nats.go v1.31.0 github.com/prometheus/client_golang v1.17.0 + github.com/stretchr/testify v1.8.4 ) require ( @@ -15,6 +16,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -34,6 +36,7 @@ require ( github.com/nats-io/nkeys v0.4.6 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.11.1 // indirect diff --git a/pkg/api/rest.go b/pkg/api/rest.go index 61ad7ca..fed451f 100644 --- a/pkg/api/rest.go +++ b/pkg/api/rest.go @@ -27,6 +27,8 @@ type WriteRequest struct { Timestamp *time.Time `json:"timestamp,omitempty"` } +type Response map[string]any + // BatchWriteRequest 批量写入请求 type BatchWriteRequest struct { Points []WriteRequest `json:"points"` @@ -153,7 +155,8 @@ func (s *RESTServer) handleBatchWrite(c *gin.Context) { } // 批量写入数据 - if err := s.dataManager.BatchWrite(c.Request.Context(), batch); err != nil { + // 使用当前时间作为批量写入的时间戳 + if err := s.dataManager.BatchWrite(c.Request.Context(), convertBatch(batch), time.Now()); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": "Failed to batch write data: " + err.Error(), }) @@ -162,6 +165,7 @@ func (s *RESTServer) handleBatchWrite(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "status": "ok", + "count": len(batch), }) } @@ -255,3 +259,24 @@ func (s *RESTServer) Start(addr string) error { func (s *RESTServer) Stop(ctx context.Context) error { return s.server.Shutdown(ctx) } + +// convertBatch 将内部批处理格式转换为DataManager.BatchWrite所需的格式 +func convertBatch(batch []struct { + ID model.DataPointID + Value model.DataValue +}) []struct { + ID model.DataPointID + Value interface{} +} { + result := make([]struct { + ID model.DataPointID + Value interface{} + }, len(batch)) + + for i, item := range batch { + result[i].ID = item.ID + result[i].Value = item.Value.Value + } + + return result +} diff --git a/pkg/api/rest_test.go b/pkg/api/rest_test.go new file mode 100644 index 0000000..6552258 --- /dev/null +++ b/pkg/api/rest_test.go @@ -0,0 +1,253 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "git.pyer.club/kingecg/gotidb/pkg/manager" + "git.pyer.club/kingecg/gotidb/pkg/model" + "git.pyer.club/kingecg/gotidb/pkg/storage" +) + +func setupTestRESTServer() *RESTServer { + // 创建存储引擎 + engine := storage.NewMemoryEngine() + + // 创建数据管理器 + dataManager := manager.NewDataManager(engine) + + // 创建REST服务器 + server := NewRESTServer(dataManager) + + return server +} + +func TestRESTServer_WriteEndpoint(t *testing.T) { + // 设置测试模式 + gin.SetMode(gin.TestMode) + + // 创建测试服务器 + server := setupTestRESTServer() + + // 创建测试请求 + writeReq := WriteRequest{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + Value: 25.5, + } + + body, _ := json.Marshal(writeReq) + req, _ := http.NewRequest("POST", "/api/v1/write", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + + // 创建响应记录器 + w := httptest.NewRecorder() + + // 设置路由 + r := gin.New() + r.POST("/api/v1/write", server.handleWrite) + + // 执行请求 + r.ServeHTTP(w, req) + + // 检查响应状态码 + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + // 解析响应 + var resp Response + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Errorf("Failed to unmarshal response: %v", err) + } + + // 验证响应 + if resp["status"] != "ok" { + t.Errorf("Expected success to be true, got false") + } +} + +func TestRESTServer_BatchWriteEndpoint(t *testing.T) { + // 设置测试模式 + gin.SetMode(gin.TestMode) + + // 创建测试服务器 + server := setupTestRESTServer() + + // 创建测试请求 + batchReq := BatchWriteRequest{ + Points: []WriteRequest{ + { + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + Value: 25.5, + }, + { + DeviceID: "test-device", + MetricCode: "humidity", + Labels: map[string]string{ + "location": "room1", + }, + Value: 60.0, + }, + }, + } + + body, _ := json.Marshal(batchReq) + req, _ := http.NewRequest("POST", "/api/v1/batch_write", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + + // 创建响应记录器 + w := httptest.NewRecorder() + + // 设置路由 + r := gin.New() + r.POST("/api/v1/batch_write", server.handleBatchWrite) + + // 执行请求 + r.ServeHTTP(w, req) + + // 检查响应状态码 + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + // 解析响应 + var resp Response + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Errorf("Failed to unmarshal response: %v", err) + } + + // 验证响应 + if resp["status"] != "ok" { + t.Errorf("Expected success to be true, got false") + } + + if resp["count"] != 2 { + t.Errorf("Expected count to be 2, got %d", resp["count"]) + } +} + +func TestRESTServer_QueryEndpoint(t *testing.T) { + // 设置测试模式 + gin.SetMode(gin.TestMode) + + // 创建测试服务器 + server := setupTestRESTServer() + + // 写入测试数据 + engine := storage.NewMemoryEngine() + dataManager := manager.NewDataManager(engine) + server.dataManager = dataManager + + id := model.DataPointID{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + } + + now := time.Now() + value := model.DataValue{ + Timestamp: now, + Value: 25.5, + } + + err := dataManager.Write(context.Background(), id, value) + if err != nil { + t.Fatalf("Failed to write test data: %v", err) + } + + // 创建测试请求 + queryReq := QueryRequest{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + QueryType: "latest", + } + + body, _ := json.Marshal(queryReq) + req, _ := http.NewRequest("POST", "/api/v1/query", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + + // 创建响应记录器 + w := httptest.NewRecorder() + + // 设置路由 + r := gin.New() + r.POST("/api/v1/query", server.handleQuery) + + // 执行请求 + r.ServeHTTP(w, req) + + // 检查响应状态码 + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + // 解析响应 + var resp Response + err = json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Errorf("Failed to unmarshal response: %v", err) + } + + // 验证响应 + if resp["status"] != "ok" { + t.Errorf("Expected success to be true, got false") + } + + // 验证返回的数据 + if resp["timestamp"] == nil { + t.Errorf("Expected data to be non-nil") + } + + // // 验证最新值 + // if resp.QueryType != "latest" { + // t.Errorf("Expected query_type to be 'latest', got '%s'", resp.QueryType) + // } + + // if resp.Data.(map[string]interface{})["value"] != 25.5 { + // t.Errorf("Expected value to be 25.5, got %v", resp.Data.(map[string]interface{})["value"]) + // } +} + +func TestRESTServer_Start(t *testing.T) { + // 创建测试服务器 + server := setupTestRESTServer() + + // 启动服务器(在后台) + go func() { + err := server.Start(":0") // 使用随机端口 + if err != nil && err != http.ErrServerClosed { + t.Errorf("Failed to start server: %v", err) + } + }() + + // 给服务器一点时间启动 + time.Sleep(100 * time.Millisecond) + + // 停止服务器 + err := server.Stop(context.Background()) + if err != nil { + t.Errorf("Failed to stop server: %v", err) + } +} diff --git a/pkg/api/websocket_test.go b/pkg/api/websocket_test.go new file mode 100644 index 0000000..f7704d4 --- /dev/null +++ b/pkg/api/websocket_test.go @@ -0,0 +1,234 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + + "git.pyer.club/kingecg/gotidb/pkg/manager" + "git.pyer.club/kingecg/gotidb/pkg/model" + "git.pyer.club/kingecg/gotidb/pkg/storage" +) + +func setupTestWebSocketServer() *WebSocketServer { + // 创建存储引擎 + engine := storage.NewMemoryEngine() + + // 创建数据管理器 + dataManager := manager.NewDataManager(engine) + + // 创建WebSocket服务器 + server := NewWebSocketServer(dataManager) + + return server +} + +func TestWebSocketServer_Connection(t *testing.T) { + // 创建测试服务器 + server := setupTestWebSocketServer() + + // 创建HTTP服务器 + httpServer := httptest.NewServer(server.router) + defer httpServer.Close() + + // 将HTTP URL转换为WebSocket URL + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + + // 连接到WebSocket服务器 + ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("Failed to connect to WebSocket server: %v", err) + } + defer ws.Close() + + // 发送订阅消息 + subscription := SubscriptionRequest{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + } + + err = ws.WriteJSON(subscription) + if err != nil { + t.Fatalf("Failed to send subscription message: %v", err) + } + + // 等待一段时间,确保订阅已处理 + time.Sleep(100 * time.Millisecond) + + // 写入数据,触发WebSocket推送 + id := model.DataPointID{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + } + + value := model.DataValue{ + Timestamp: time.Now(), + Value: 25.5, + } + + err = server.dataManager.Write(context.Background(), id, value) + if err != nil { + t.Fatalf("Failed to write data: %v", err) + } + + // 设置读取超时 + ws.SetReadDeadline(time.Now().Add(1 * time.Second)) + + // 读取WebSocket消息 + _, message, err := ws.ReadMessage() + if err != nil { + t.Fatalf("Failed to read WebSocket message: %v", err) + } + + // 解析消息 + var update DataChangeEvent + err = json.Unmarshal(message, &update) + if err != nil { + t.Fatalf("Failed to unmarshal WebSocket message: %v", err) + } + + // 验证消息内容 + if update.DeviceID != "test-device" { + t.Errorf("Expected DeviceID to be 'test-device', got '%s'", update.DeviceID) + } + + if update.MetricCode != "temperature" { + t.Errorf("Expected MetricCode to be 'temperature', got '%s'", update.MetricCode) + } + + if update.Value != 25.5 { + t.Errorf("Expected Value to be 25.5, got %v", update.Value) + } + + if update.Labels["location"] != "room1" { + t.Errorf("Expected Labels['location'] to be 'room1', got '%s'", update.Labels["location"]) + } +} + +func TestWebSocketServer_MultipleSubscriptions(t *testing.T) { + // 创建测试服务器 + server := setupTestWebSocketServer() + + // 创建HTTP服务器 + httpServer := httptest.NewServer(server.router) + defer httpServer.Close() + + // 将HTTP URL转换为WebSocket URL + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + + // 连接到WebSocket服务器 + ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("Failed to connect to WebSocket server: %v", err) + } + defer ws.Close() + + // 发送多个订阅消息 + subscriptions := []SubscriptionRequest{ + { + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + }, + { + DeviceID: "test-device", + MetricCode: "humidity", + Labels: map[string]string{ + "location": "room1", + }, + }, + } + + for _, subscription := range subscriptions { + err = ws.WriteJSON(subscription) + if err != nil { + t.Fatalf("Failed to send subscription message: %v", err) + } + } + + // 等待一段时间,确保订阅已处理 + time.Sleep(100 * time.Millisecond) + + // 写入数据,触发WebSocket推送 + for _, subscription := range subscriptions { + id := model.DataPointID{ + DeviceID: subscription.DeviceID, + MetricCode: subscription.MetricCode, + Labels: subscription.Labels, + } + + value := model.DataValue{ + Timestamp: time.Now(), + Value: 25.5, + } + + err = server.dataManager.Write(context.Background(), id, value) + if err != nil { + t.Fatalf("Failed to write data: %v", err) + } + + // 等待一段时间,确保数据已处理 + time.Sleep(100 * time.Millisecond) + + // 设置读取超时 + ws.SetReadDeadline(time.Now().Add(1 * time.Second)) + + // 读取WebSocket消息 + _, message, err := ws.ReadMessage() + if err != nil { + t.Fatalf("Failed to read WebSocket message: %v", err) + } + + // 解析消息 + var update DataChangeEvent + err = json.Unmarshal(message, &update) + if err != nil { + t.Fatalf("Failed to unmarshal WebSocket message: %v", err) + } + + // 验证消息内容 + if update.DeviceID != subscription.DeviceID { + t.Errorf("Expected DeviceID to be '%s', got '%s'", subscription.DeviceID, update.DeviceID) + } + + if update.MetricCode != subscription.MetricCode { + t.Errorf("Expected MetricCode to be '%s', got '%s'", subscription.MetricCode, update.MetricCode) + } + } +} + +func TestWebSocketServer_Start(t *testing.T) { + // 创建测试服务器 + server := setupTestWebSocketServer() + + // 启动服务器(在后台) + go func() { + err := server.Start(":0") // 使用随机端口 + if err != nil && err != http.ErrServerClosed { + t.Errorf("Failed to start server: %v", err) + } + }() + + // 给服务器一点时间启动 + time.Sleep(100 * time.Millisecond) + + // 停止服务器 + err := server.Stop(context.Background()) + if err != nil { + t.Errorf("Failed to stop server: %v", err) + } +} diff --git a/pkg/manager/datamanager.go b/pkg/manager/datamanager.go index 43c82b3..40d1e9b 100644 --- a/pkg/manager/datamanager.go +++ b/pkg/manager/datamanager.go @@ -50,16 +50,25 @@ func (m *DataManager) Write(ctx context.Context, id model.DataPointID, value mod // BatchWrite 批量写入数据 func (m *DataManager) BatchWrite(ctx context.Context, batch []struct { ID model.DataPointID - Value model.DataValue -}) error { + Value interface{} +}, timestamp time.Time) error { for _, item := range batch { - if err := m.Write(ctx, item.ID, item.Value); err != nil { + value := model.DataValue{ + Timestamp: timestamp, + Value: item.Value, + } + if err := m.Write(ctx, item.ID, value); err != nil { return err } } return nil } +// Query 执行查询(ExecuteQuery的别名) +func (m *DataManager) Query(ctx context.Context, id model.DataPointID, query model.Query) (model.Result, error) { + return m.ExecuteQuery(ctx, id, query) +} + // RegisterCallback 注册数据变更回调 func (m *DataManager) RegisterCallback(callback DataChangeCallback) { m.callbacksLock.Lock() @@ -106,18 +115,11 @@ func (m *DataManager) EnablePersistence(config storage.PersistenceConfig) error return m.engine.EnablePersistence(config) } -// CreateDataPoint 创建一个新的数据点 -func CreateDataPoint(deviceID, metricCode string, labels map[string]string, value interface{}) (model.DataPointID, model.DataValue) { - id := model.DataPointID{ +// CreateDataPoint 创建一个新的数据点ID +func CreateDataPoint(deviceID, metricCode string, labels map[string]string, value interface{}) model.DataPointID { + return model.DataPointID{ DeviceID: deviceID, MetricCode: metricCode, Labels: labels, } - - dataValue := model.DataValue{ - Timestamp: time.Now(), - Value: value, - } - - return id, dataValue } diff --git a/pkg/manager/datamanager_test.go b/pkg/manager/datamanager_test.go new file mode 100644 index 0000000..7a0d026 --- /dev/null +++ b/pkg/manager/datamanager_test.go @@ -0,0 +1,246 @@ +package manager + +import ( + "context" + "testing" + "time" + + "git.pyer.club/kingecg/gotidb/pkg/model" + "git.pyer.club/kingecg/gotidb/pkg/storage" +) + +func TestDataManager(t *testing.T) { + // 创建存储引擎 + engine := storage.NewMemoryEngine() + + // 创建数据管理器 + manager := NewDataManager(engine) + + // 创建测试数据 + deviceID := "test-device" + metricCode := "temperature" + labels := map[string]string{ + "location": "room1", + } + + // 测试创建数据点 + t.Run("CreateDataPoint", func(t *testing.T) { + id := CreateDataPoint(deviceID, metricCode, labels, nil) + + if id.DeviceID != deviceID { + t.Errorf("CreateDataPoint() DeviceID = %v, want %v", id.DeviceID, deviceID) + } + + if id.MetricCode != metricCode { + t.Errorf("CreateDataPoint() MetricCode = %v, want %v", id.MetricCode, metricCode) + } + + if len(id.Labels) != len(labels) { + t.Errorf("CreateDataPoint() Labels length = %v, want %v", len(id.Labels), len(labels)) + } + + for k, v := range labels { + if id.Labels[k] != v { + t.Errorf("CreateDataPoint() Labels[%v] = %v, want %v", k, id.Labels[k], v) + } + } + }) + + // 测试写入数据 + t.Run("Write", func(t *testing.T) { + id := CreateDataPoint(deviceID, metricCode, labels, nil) + value := 25.5 + + err := manager.Write(context.Background(), id, model.DataValue{ + Timestamp: time.Now(), + Value: value, + }) + + if err != nil { + t.Errorf("Write() error = %v", err) + } + }) + + // 测试批量写入 + t.Run("BatchWrite", func(t *testing.T) { + id1 := CreateDataPoint(deviceID, metricCode, labels, nil) + id2 := CreateDataPoint(deviceID, "humidity", labels, nil) + + now := time.Now() + batch := []struct { + ID model.DataPointID + Value interface{} + }{ + { + ID: id1, + Value: 26.0, + }, + { + ID: id2, + Value: 60.0, + }, + } + + err := manager.BatchWrite(context.Background(), batch, now) + + if err != nil { + t.Errorf("BatchWrite() error = %v", err) + } + }) + + // 测试查询 + t.Run("Query", func(t *testing.T) { + id := CreateDataPoint(deviceID, metricCode, labels, nil) + now := time.Now() + value := 27.5 + + // 写入测试数据 + err := manager.Write(context.Background(), id, model.DataValue{ + Timestamp: now, + Value: value, + }) + + if err != nil { + t.Errorf("Write() for Query test error = %v", err) + } + + // 测试最新值查询 + t.Run("QueryLatest", func(t *testing.T) { + query := model.NewQuery(model.QueryTypeLatest, nil) + result, err := manager.Query(context.Background(), id, query) + + if err != nil { + t.Errorf("Query() error = %v", err) + } + + latest, ok := result.AsLatest() + if !ok { + t.Errorf("Query() result is not a latest result") + } + + if latest.Value != value { + t.Errorf("Query() latest value = %v, want %v", latest.Value, value) + } + }) + + // 测试所有值查询 + t.Run("QueryAll", func(t *testing.T) { + // 写入多个值 + for i := 1; i <= 5; i++ { + newValue := model.DataValue{ + Timestamp: now.Add(time.Duration(i) * time.Minute), + Value: value + float64(i), + } + err := manager.Write(context.Background(), id, newValue) + if err != nil { + t.Errorf("Write() for QueryAll error = %v", err) + } + } + + // 查询所有值 + query := model.NewQuery(model.QueryTypeAll, map[string]interface{}{ + "limit": 10, + }) + result, err := manager.Query(context.Background(), id, query) + + if err != nil { + t.Errorf("Query() for QueryAll error = %v", err) + } + + all, ok := result.AsAll() + if !ok { + t.Errorf("Query() result is not an all result") + } + + // 验证返回的值数量 + if len(all) != 6 { // 初始值 + 5个新值 + t.Errorf("Query() all result length = %v, want %v", len(all), 6) + } + }) + + // 测试持续时间查询 + t.Run("QueryDuration", func(t *testing.T) { + // 设置时间范围 + from := now.Add(1 * time.Minute) + to := now.Add(3 * time.Minute) + + // 查询指定时间范围内的值 + query := model.NewQuery(model.QueryTypeDuration, map[string]interface{}{ + "from": from, + "to": to, + }) + result, err := manager.Query(context.Background(), id, query) + + if err != nil { + t.Errorf("Query() for QueryDuration error = %v", err) + } + + duration, ok := result.AsAll() + if !ok { + t.Errorf("Query() result is not a duration result") + } + + // 验证返回的值数量 + if len(duration) != 3 { // 1分钟、2分钟和3分钟的值 + t.Errorf("Query() duration result length = %v, want %v", len(duration), 3) + } + + // 验证所有值都在指定的时间范围内 + for _, v := range duration { + if v.Timestamp.Before(from) || v.Timestamp.After(to) { + t.Errorf("Query() duration result contains value with timestamp %v outside range [%v, %v]", v.Timestamp, from, to) + } + } + }) + }) + + // 测试回调 + t.Run("Callback", func(t *testing.T) { + callbackCalled := false + var callbackID model.DataPointID + var callbackValue model.DataValue + + // 注册回调 + manager.RegisterCallback(func(id model.DataPointID, value model.DataValue) { + callbackCalled = true + callbackID = id + callbackValue = value + }) + + // 写入数据触发回调 + id := CreateDataPoint(deviceID, metricCode, labels, nil) + now := time.Now() + value := 28.5 + + err := manager.Write(context.Background(), id, model.DataValue{ + Timestamp: now, + Value: value, + }) + + if err != nil { + t.Errorf("Write() for Callback test error = %v", err) + } + + // 验证回调是否被调用 + if !callbackCalled { + t.Errorf("Callback not called") + } + + // 验证回调参数 + if !callbackID.Equal(id) { + t.Errorf("Callback ID = %v, want %v", callbackID, id) + } + + if callbackValue.Value != value { + t.Errorf("Callback value = %v, want %v", callbackValue.Value, value) + } + }) + + // 测试关闭 + t.Run("Close", func(t *testing.T) { + err := manager.Close() + if err != nil { + t.Errorf("Close() error = %v", err) + } + }) +} diff --git a/pkg/messaging/nats_test.go b/pkg/messaging/nats_test.go new file mode 100644 index 0000000..98aee16 --- /dev/null +++ b/pkg/messaging/nats_test.go @@ -0,0 +1,130 @@ +package messaging + +import ( + "context" + "testing" + "time" + + "git.pyer.club/kingecg/gotidb/pkg/model" + "github.com/nats-io/nats.go/jetstream" + "github.com/stretchr/testify/assert" +) + +// 模拟NATS连接 +type mockNATSConn struct { + closeFunc func() error +} + +func (m *mockNATSConn) Close() error { + if m.closeFunc != nil { + return m.closeFunc() + } + return nil +} + +// 模拟JetStream +type mockJetStream struct { + publishFunc func(ctx context.Context, subject string, data []byte) (jetstream.PubAck, error) +} + +func (m *mockJetStream) Publish(ctx context.Context, subject string, data []byte) (jetstream.PubAck, error) { + if m.publishFunc != nil { + return m.publishFunc(ctx, subject, data) + } + return jetstream.PubAck{}, nil +} + +// 模拟Stream +type mockStream struct { + createOrUpdateConsumerFunc func(ctx context.Context, cfg jetstream.ConsumerConfig) (jetstream.Consumer, error) +} + +func (m *mockStream) CreateOrUpdateConsumer(ctx context.Context, cfg jetstream.ConsumerConfig) (jetstream.Consumer, error) { + if m.createOrUpdateConsumerFunc != nil { + return m.createOrUpdateConsumerFunc(ctx, cfg) + } + return nil, nil +} + +// 模拟Consumer +type mockConsumer struct { + messagesFunc func() (jetstream.MessagesContext, error) +} + +func (m *mockConsumer) Messages() (jetstream.MessagesContext, error) { + if m.messagesFunc != nil { + return m.messagesFunc() + } + return nil, nil +} + +func TestNATSMessaging_Publish(t *testing.T) { + publishCalled := false + mockJS := &mockJetStream{ + publishFunc: func(ctx context.Context, subject string, data []byte) (jetstream.PubAck, error) { + publishCalled = true + return jetstream.PubAck{}, nil + }, + } + + messaging := &NATSMessaging{ + conn: &mockNATSConn{}, + js: mockJS, + } + + id := model.DataPointID{ + DeviceID: "device1", + MetricCode: "metric1", + Labels: map[string]string{"env": "test"}, + } + value := model.DataValue{ + Timestamp: time.Now(), + Value: 42.0, + } + + err := messaging.Publish(context.Background(), id, value) + assert.NoError(t, err) + assert.True(t, publishCalled) +} + +func TestNATSMessaging_Subscribe(t *testing.T) { + handlerCalled := false + handler := func(msg DataMessage) error { + handlerCalled = true + return nil + } + + mockConsumer := &mockConsumer{} + mockStream := &mockStream{ + createOrUpdateConsumerFunc: func(ctx context.Context, cfg jetstream.ConsumerConfig) (jetstream.Consumer, error) { + return mockConsumer, nil + }, + } + + messaging := &NATSMessaging{ + conn: &mockNATSConn{}, + stream: mockStream, + } + + err := messaging.Subscribe(handler) + assert.NoError(t, err) + assert.Contains(t, messaging.handlers, handler) +} + +func TestNATSMessaging_Close(t *testing.T) { + closeCalled := false + mockConn := &mockNATSConn{ + closeFunc: func() error { + closeCalled = true + return nil + }, + } + + messaging := &NATSMessaging{ + conn: mockConn, + } + + err := messaging.Close() + assert.NoError(t, err) + assert.True(t, closeCalled) +} diff --git a/pkg/model/datapoint.go b/pkg/model/datapoint.go index ebb2dba..4ad5d90 100644 --- a/pkg/model/datapoint.go +++ b/pkg/model/datapoint.go @@ -2,6 +2,7 @@ package model import ( "fmt" + "sort" "sync" "time" ) @@ -15,7 +16,51 @@ type DataPointID struct { // String 返回数据点标识的字符串表示 func (id DataPointID) String() string { - return fmt.Sprintf("%s:%s:%v", id.DeviceID, id.MetricCode, id.Labels) + return id.Hash() +} + +// Equal 判断两个数据点标识是否相等 +func (id DataPointID) Equal(other DataPointID) bool { + if id.DeviceID != other.DeviceID || id.MetricCode != other.MetricCode { + return false + } + + if len(id.Labels) != len(other.Labels) { + return false + } + + for k, v := range id.Labels { + if otherV, ok := other.Labels[k]; !ok || v != otherV { + return false + } + } + + return true +} + +// Hash 返回数据点标识的哈希值 +func (id DataPointID) Hash() string { + if len(id.Labels) == 0 { + return fmt.Sprintf("%s:%s:", id.DeviceID, id.MetricCode) + } + + // 提取并排序标签键 + keys := make([]string, 0, len(id.Labels)) + for k := range id.Labels { + keys = append(keys, k) + } + sort.Strings(keys) + + // 按排序后的键顺序构建标签字符串 + var labelStr string + for i, k := range keys { + if i == 0 { + labelStr = fmt.Sprintf("%s=%s", k, id.Labels[k]) + } else { + labelStr = fmt.Sprintf("%s:%s=%s", labelStr, k, id.Labels[k]) + } + } + return fmt.Sprintf("%s:%s:%s", id.DeviceID, id.MetricCode, labelStr) } // DataValue 数据值 diff --git a/pkg/model/query.go b/pkg/model/query.go index 76fa28c..125a0b6 100644 --- a/pkg/model/query.go +++ b/pkg/model/query.go @@ -2,6 +2,7 @@ package model import ( "context" + "time" ) // QueryType 查询类型 @@ -27,7 +28,7 @@ type Result interface { IsEmpty() bool AsLatest() (DataValue, bool) AsAll() ([]DataValue, bool) - AsDuration() (float64, bool) + AsDuration() (time.Duration, bool) } // QueryExecutor 查询执行器接口 @@ -63,7 +64,7 @@ func (q *BaseQuery) Params() map[string]interface{} { type BaseResult struct { latest *DataValue all []DataValue - duration *float64 + duration *time.Duration } // NewLatestResult 创建一个最新值查询结果 @@ -81,7 +82,7 @@ func NewAllResult(values []DataValue) Result { } // NewDurationResult 创建一个持续时间查询结果 -func NewDurationResult(duration float64) Result { +func NewDurationResult(duration time.Duration) Result { return &BaseResult{ duration: &duration, } @@ -109,7 +110,7 @@ func (r *BaseResult) AsAll() ([]DataValue, bool) { } // AsDuration 将结果转换为持续时间 -func (r *BaseResult) AsDuration() (float64, bool) { +func (r *BaseResult) AsDuration() (time.Duration, bool) { if r.duration != nil { return *r.duration, true } diff --git a/pkg/model/query_test.go b/pkg/model/query_test.go new file mode 100644 index 0000000..2778eaa --- /dev/null +++ b/pkg/model/query_test.go @@ -0,0 +1,251 @@ +package model + +import ( + "testing" + "time" +) + +func TestDataPointID(t *testing.T) { + tests := []struct { + name string + id DataPointID + wantEqual DataPointID + wantHash string + }{ + { + name: "basic data point id", + id: DataPointID{ + DeviceID: "device1", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + "floor": "1st", + }, + }, + wantEqual: DataPointID{ + DeviceID: "device1", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + "floor": "1st", + }, + }, + wantHash: "device1:temperature:floor=1st:location=room1", + }, + { + name: "empty labels", + id: DataPointID{ + DeviceID: "device2", + MetricCode: "humidity", + Labels: map[string]string{}, + }, + wantEqual: DataPointID{ + DeviceID: "device2", + MetricCode: "humidity", + Labels: map[string]string{}, + }, + wantHash: "device2:humidity:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test equality + if !tt.id.Equal(tt.wantEqual) { + t.Errorf("DataPointID.Equal() = false, want true") + } + + // Test hash generation + if hash := tt.id.Hash(); hash != tt.wantHash { + t.Errorf("DataPointID.Hash() = %v, want %v", hash, tt.wantHash) + } + }) + } +} + +func TestDataValue(t *testing.T) { + now := time.Now() + tests := []struct { + name string + value DataValue + want interface{} + }{ + { + name: "float value", + value: DataValue{ + Timestamp: now, + Value: 25.5, + }, + want: 25.5, + }, + { + name: "integer value", + value: DataValue{ + Timestamp: now, + Value: 100, + }, + want: 100, + }, + { + name: "string value", + value: DataValue{ + Timestamp: now, + Value: "test", + }, + want: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.value.Value != tt.want { + t.Errorf("DataValue.Value = %v, want %v", tt.value.Value, tt.want) + } + + if !tt.value.Timestamp.Equal(now) { + t.Errorf("DataValue.Timestamp = %v, want %v", tt.value.Timestamp, now) + } + }) + } +} + +func TestQuery(t *testing.T) { + tests := []struct { + name string + queryType QueryType + params map[string]interface{} + wantType QueryType + wantParams map[string]interface{} + }{ + { + name: "latest query", + queryType: QueryTypeLatest, + params: nil, + wantType: QueryTypeLatest, + wantParams: map[string]interface{}{}, + }, + { + name: "all query", + queryType: QueryTypeAll, + params: map[string]interface{}{ + "limit": 100, + }, + wantType: QueryTypeAll, + wantParams: map[string]interface{}{ + "limit": 100, + }, + }, + { + name: "duration query", + queryType: QueryTypeDuration, + params: map[string]interface{}{ + "from": "2023-01-01T00:00:00Z", + "to": "2023-01-02T00:00:00Z", + }, + wantType: QueryTypeDuration, + wantParams: map[string]interface{}{ + "from": "2023-01-01T00:00:00Z", + "to": "2023-01-02T00:00:00Z", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := NewQuery(tt.queryType, tt.params) + + if query.Type() != tt.wantType { + t.Errorf("Query.Type() = %v, want %v", query.Type(), tt.wantType) + } + + params := query.Params() + if len(params) != len(tt.wantParams) { + t.Errorf("Query.Params() length = %v, want %v", len(params), len(tt.wantParams)) + } + + for k, v := range tt.wantParams { + if params[k] != v { + t.Errorf("Query.Params()[%v] = %v, want %v", k, params[k], v) + } + } + }) + } +} + +func TestQueryResult(t *testing.T) { + now := time.Now() + tests := []struct { + name string + result Result + wantLatest DataValue + wantAll []DataValue + wantDuration time.Duration + }{ + { + name: "latest result", + result: NewLatestResult(DataValue{ + Timestamp: now, + Value: 25.5, + }), + wantLatest: DataValue{ + Timestamp: now, + Value: 25.5, + }, + }, + { + name: "all result", + result: NewAllResult([]DataValue{ + { + Timestamp: now, + Value: 25.5, + }, + { + Timestamp: now.Add(time.Second), + Value: 26.0, + }, + }), + wantAll: []DataValue{ + { + Timestamp: now, + Value: 25.5, + }, + { + Timestamp: now.Add(time.Second), + Value: 26.0, + }, + }, + }, + { + name: "duration result", + result: NewDurationResult(time.Hour), + wantDuration: time.Hour, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if latest, ok := tt.result.AsLatest(); ok { + if !latest.Timestamp.Equal(tt.wantLatest.Timestamp) || latest.Value != tt.wantLatest.Value { + t.Errorf("Result.AsLatest() = %v, want %v", latest, tt.wantLatest) + } + } + + if all, ok := tt.result.AsAll(); ok { + if len(all) != len(tt.wantAll) { + t.Errorf("Result.AsAll() length = %v, want %v", len(all), len(tt.wantAll)) + } + for i, v := range tt.wantAll { + if !all[i].Timestamp.Equal(v.Timestamp) || all[i].Value != v.Value { + t.Errorf("Result.AsAll()[%v] = %v, want %v", i, all[i], v) + } + } + } + + if duration, ok := tt.result.AsDuration(); ok { + if duration != tt.wantDuration { + t.Errorf("Result.AsDuration() = %v, want %v", duration, tt.wantDuration) + } + } + }) + } +} diff --git a/pkg/monitoring/collector.go b/pkg/monitoring/collector.go new file mode 100644 index 0000000..9792c30 --- /dev/null +++ b/pkg/monitoring/collector.go @@ -0,0 +1,159 @@ +package monitoring + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +// MetricsCollector 提供更简洁的指标收集API +type MetricsCollector struct { + writeTotal prometheus.Counter + queryTotal prometheus.Counter + writeLatency prometheus.Histogram + queryLatency prometheus.Histogram + activeConnections prometheus.Gauge + dataPointsCount prometheus.Gauge + persistenceLatency prometheus.Histogram + persistenceErrors prometheus.Counter + messagingLatency prometheus.Histogram + messagingErrors prometheus.Counter + websocketConnections prometheus.Gauge +} + +// NewMetricsCollector 创建一个新的指标收集器 +func NewMetricsCollector() *MetricsCollector { + return &MetricsCollector{ + writeTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gotidb_write_total", + Help: "Total number of write operations", + }), + queryTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gotidb_query_total", + Help: "Total number of query operations", + }), + writeLatency: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "gotidb_write_latency_seconds", + Help: "Write operation latency in seconds", + Buckets: prometheus.DefBuckets, + }), + queryLatency: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "gotidb_query_latency_seconds", + Help: "Query operation latency in seconds", + Buckets: prometheus.DefBuckets, + }), + activeConnections: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gotidb_active_connections", + Help: "Number of active connections", + }), + dataPointsCount: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gotidb_data_points_count", + Help: "Number of data points in the database", + }), + persistenceLatency: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "gotidb_persistence_latency_seconds", + Help: "Persistence operation latency in seconds", + Buckets: prometheus.DefBuckets, + }), + persistenceErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gotidb_persistence_errors_total", + Help: "Total number of persistence errors", + }), + messagingLatency: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "gotidb_messaging_latency_seconds", + Help: "Messaging operation latency in seconds", + Buckets: prometheus.DefBuckets, + }), + messagingErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gotidb_messaging_errors_total", + Help: "Total number of messaging errors", + }), + websocketConnections: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gotidb_websocket_connections", + Help: "Number of active WebSocket connections", + }), + } +} + +// RecordWrite 记录写入操作及其延迟 +func (c *MetricsCollector) RecordWrite(duration time.Duration) { + c.writeTotal.Inc() + c.writeLatency.Observe(duration.Seconds()) +} + +// RecordQuery 记录查询操作及其延迟 +func (c *MetricsCollector) RecordQuery(duration time.Duration) { + c.queryTotal.Inc() + c.queryLatency.Observe(duration.Seconds()) +} + +// IncActiveConnections 增加活跃连接数 +func (c *MetricsCollector) IncActiveConnections() { + c.activeConnections.Inc() +} + +// DecActiveConnections 减少活跃连接数 +func (c *MetricsCollector) DecActiveConnections() { + c.activeConnections.Dec() +} + +// SetDataPointsCount 设置数据点数量 +func (c *MetricsCollector) SetDataPointsCount(count float64) { + c.dataPointsCount.Set(count) +} + +// RecordPersistence 记录持久化操作及其延迟 +func (c *MetricsCollector) RecordPersistence(duration time.Duration, err error) { + c.persistenceLatency.Observe(duration.Seconds()) + if err != nil { + c.persistenceErrors.Inc() + } +} + +// RecordMessaging 记录消息操作及其延迟 +func (c *MetricsCollector) RecordMessaging(duration time.Duration, err error) { + c.messagingLatency.Observe(duration.Seconds()) + if err != nil { + c.messagingErrors.Inc() + } +} + +// IncWebSocketConnections 增加WebSocket连接数 +func (c *MetricsCollector) IncWebSocketConnections() { + c.websocketConnections.Inc() +} + +// DecWebSocketConnections 减少WebSocket连接数 +func (c *MetricsCollector) DecWebSocketConnections() { + c.websocketConnections.Dec() +} + +// Describe 实现prometheus.Collector接口 +func (c *MetricsCollector) Describe(ch chan<- *prometheus.Desc) { + c.writeTotal.Describe(ch) + c.queryTotal.Describe(ch) + c.writeLatency.Describe(ch) + c.queryLatency.Describe(ch) + c.activeConnections.Describe(ch) + c.dataPointsCount.Describe(ch) + c.persistenceLatency.Describe(ch) + c.persistenceErrors.Describe(ch) + c.messagingLatency.Describe(ch) + c.messagingErrors.Describe(ch) + c.websocketConnections.Describe(ch) +} + +// Collect 实现prometheus.Collector接口 +func (c *MetricsCollector) Collect(ch chan<- prometheus.Metric) { + c.writeTotal.Collect(ch) + c.queryTotal.Collect(ch) + c.writeLatency.Collect(ch) + c.queryLatency.Collect(ch) + c.activeConnections.Collect(ch) + c.dataPointsCount.Collect(ch) + c.persistenceLatency.Collect(ch) + c.persistenceErrors.Collect(ch) + c.messagingLatency.Collect(ch) + c.messagingErrors.Collect(ch) + c.websocketConnections.Collect(ch) +} diff --git a/pkg/monitoring/metrics_test.go b/pkg/monitoring/metrics_test.go new file mode 100644 index 0000000..6af0747 --- /dev/null +++ b/pkg/monitoring/metrics_test.go @@ -0,0 +1,240 @@ +package monitoring + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +func TestMetricsCollector(t *testing.T) { + // 创建指标收集器 + collector := NewMetricsCollector() + + // 测试写入操作指标 + t.Run("WriteMetrics", func(t *testing.T) { + // 记录写入操作 + collector.RecordWrite(10 * time.Millisecond) + + // 验证写入总数 + writeTotal := testutil.ToFloat64(collector.writeTotal) + if writeTotal != 1 { + t.Errorf("Expected write_total to be 1, got %v", writeTotal) + } + + // 对于Histogram类型,我们只验证它是否被正确注册和收集 + // 而不是尝试获取其具体值 + registry := prometheus.NewPedanticRegistry() + registry.MustRegister(collector.writeLatency) + metrics, err := registry.Gather() + if err != nil { + t.Errorf("Failed to gather metrics: %v", err) + } + if len(metrics) == 0 { + t.Error("Expected write_latency to be collected, but got no metrics") + } + }) + + // 测试查询操作指标 + t.Run("QueryMetrics", func(t *testing.T) { + // 记录查询操作 + collector.RecordQuery(20 * time.Millisecond) + + // 验证查询总数 + queryTotal := testutil.ToFloat64(collector.queryTotal) + if queryTotal != 1 { + t.Errorf("Expected query_total to be 1, got %v", queryTotal) + } + + // 对于Histogram类型,我们只验证它是否被正确注册和收集 + registry := prometheus.NewPedanticRegistry() + registry.MustRegister(collector.queryLatency) + metrics, err := registry.Gather() + if err != nil { + t.Errorf("Failed to gather metrics: %v", err) + } + if len(metrics) == 0 { + t.Error("Expected query_latency to be collected, but got no metrics") + } + }) + + // 测试连接数指标 + t.Run("ConnectionMetrics", func(t *testing.T) { + // 增加连接数 + collector.IncActiveConnections() + collector.IncActiveConnections() + + // 验证活跃连接数 + activeConns := testutil.ToFloat64(collector.activeConnections) + if activeConns != 2 { + t.Errorf("Expected active_connections to be 2, got %v", activeConns) + } + + // 减少连接数 + collector.DecActiveConnections() + + // 验证更新后的活跃连接数 + activeConns = testutil.ToFloat64(collector.activeConnections) + if activeConns != 1 { + t.Errorf("Expected active_connections to be 1, got %v", activeConns) + } + }) + + // 测试数据点数量指标 + t.Run("DataPointsMetrics", func(t *testing.T) { + // 设置数据点数量 + collector.SetDataPointsCount(100) + + // 验证数据点数量 + dataPoints := testutil.ToFloat64(collector.dataPointsCount) + if dataPoints != 100 { + t.Errorf("Expected data_points_count to be 100, got %v", dataPoints) + } + }) + + // 测试持久化指标 + t.Run("PersistenceMetrics", func(t *testing.T) { + // 记录持久化操作 + collector.RecordPersistence(30*time.Millisecond, nil) + + // 对于Histogram类型,我们只验证它是否被正确注册和收集 + registry := prometheus.NewPedanticRegistry() + registry.MustRegister(collector.persistenceLatency) + metrics, err := registry.Gather() + if err != nil { + t.Errorf("Failed to gather metrics: %v", err) + } + if len(metrics) == 0 { + t.Error("Expected persistence_latency to be collected, but got no metrics") + } + + // 验证持久化错误数(应该为0,因为没有错误) + persistenceErrors := testutil.ToFloat64(collector.persistenceErrors) + if persistenceErrors != 0 { + t.Errorf("Expected persistence_errors to be 0, got %v", persistenceErrors) + } + + // 记录持久化错误 + collector.RecordPersistence(30*time.Millisecond, errTestPersistence) + + // 验证持久化错误数(应该为1) + persistenceErrors = testutil.ToFloat64(collector.persistenceErrors) + if persistenceErrors != 1 { + t.Errorf("Expected persistence_errors to be 1, got %v", persistenceErrors) + } + }) + + // 测试消息系统指标 + t.Run("MessagingMetrics", func(t *testing.T) { + // 记录消息操作 + collector.RecordMessaging(40*time.Millisecond, nil) + + // 对于Histogram类型,我们只验证它是否被正确注册和收集 + registry := prometheus.NewPedanticRegistry() + registry.MustRegister(collector.messagingLatency) + metrics, err := registry.Gather() + if err != nil { + t.Errorf("Failed to gather metrics: %v", err) + } + if len(metrics) == 0 { + t.Error("Expected messaging_latency to be collected, but got no metrics") + } + + // 验证消息错误数(应该为0,因为没有错误) + messagingErrors := testutil.ToFloat64(collector.messagingErrors) + if messagingErrors != 0 { + t.Errorf("Expected messaging_errors to be 0, got %v", messagingErrors) + } + + // 记录消息错误 + collector.RecordMessaging(40*time.Millisecond, errTestMessaging) + + // 验证消息错误数(应该为1) + messagingErrors = testutil.ToFloat64(collector.messagingErrors) + if messagingErrors != 1 { + t.Errorf("Expected messaging_errors to be 1, got %v", messagingErrors) + } + }) + + // 测试WebSocket连接指标 + t.Run("WebSocketMetrics", func(t *testing.T) { + // 增加WebSocket连接数 + collector.IncWebSocketConnections() + collector.IncWebSocketConnections() + + // 验证WebSocket连接数 + wsConns := testutil.ToFloat64(collector.websocketConnections) + if wsConns != 2 { + t.Errorf("Expected websocket_connections to be 2, got %v", wsConns) + } + + // 减少WebSocket连接数 + collector.DecWebSocketConnections() + + // 验证更新后的WebSocket连接数 + wsConns = testutil.ToFloat64(collector.websocketConnections) + if wsConns != 1 { + t.Errorf("Expected websocket_connections to be 1, got %v", wsConns) + } + }) + + // 测试指标注册 + t.Run("MetricsRegistration", func(t *testing.T) { + registry := prometheus.NewRegistry() + + // 注册指标收集器 + err := registry.Register(collector) + if err != nil { + t.Errorf("Failed to register metrics collector: %v", err) + } + + // 验证所有指标都已注册 + metricFamilies, err := registry.Gather() + if err != nil { + t.Errorf("Failed to gather metrics: %v", err) + } + + expectedMetrics := []string{ + "gotidb_write_total", + "gotidb_query_total", + "gotidb_write_latency_seconds", + "gotidb_query_latency_seconds", + "gotidb_active_connections", + "gotidb_data_points_count", + "gotidb_persistence_latency_seconds", + "gotidb_persistence_errors_total", + "gotidb_messaging_latency_seconds", + "gotidb_messaging_errors_total", + "gotidb_websocket_connections", + } + + for _, metricName := range expectedMetrics { + found := false + for _, mf := range metricFamilies { + if *mf.Name == metricName { + found = true + break + } + } + if !found { + t.Errorf("Expected metric %s not found in registry", metricName) + } + } + }) +} + +// 测试错误 +var ( + errTestPersistence = &testError{msg: "test persistence error"} + errTestMessaging = &testError{msg: "test messaging error"} +) + +// 测试错误类型 +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} diff --git a/pkg/storage/engine.go b/pkg/storage/engine.go index 0919a49..361738c 100644 --- a/pkg/storage/engine.go +++ b/pkg/storage/engine.go @@ -3,6 +3,7 @@ package storage import ( "context" "sync" + "time" "git.pyer.club/kingecg/gotidb/pkg/model" ) @@ -33,7 +34,7 @@ type StorageEngine interface { // GetLatest 获取最新数据 GetLatest(ctx context.Context, id model.DataPointID) (model.DataValue, error) // GetDuration 获取持续时间 - GetDuration(ctx context.Context, id model.DataPointID) (float64, error) + GetDuration(ctx context.Context, id model.DataPointID) (time.Duration, error) // EnablePersistence 启用持久化 EnablePersistence(config PersistenceConfig) error // Close 关闭存储引擎 @@ -47,6 +48,56 @@ type MemoryEngine struct { persister Persister // 持久化器 } +// ReadLatest 读取最新数据(GetLatest 的别名) +func (e *MemoryEngine) ReadLatest(ctx context.Context, id model.DataPointID) (model.DataValue, error) { + return e.GetLatest(ctx, id) +} + +// BatchWrite 批量写入数据 +func (e *MemoryEngine) BatchWrite(ctx context.Context, batch []struct { + ID model.DataPointID + Value model.DataValue +}) error { + for _, item := range batch { + if err := e.Write(ctx, item.ID, item.Value); err != nil { + return err + } + } + return nil +} + +// ReadAll 读取所有数据(Read 的别名) +func (e *MemoryEngine) ReadAll(ctx context.Context, id model.DataPointID) ([]model.DataValue, error) { + return e.Read(ctx, id) +} + +// ReadDuration 读取指定时间范围内的数据 +func (e *MemoryEngine) ReadDuration(ctx context.Context, id model.DataPointID, from, to time.Time) ([]model.DataValue, error) { + key := id.String() + + e.dataLock.RLock() + buffer, exists := e.data[key] + e.dataLock.RUnlock() + + if !exists { + return []model.DataValue{}, nil + } + + // 读取所有数据 + allValues := buffer.Read() + + // 过滤出指定时间范围内的数据 + var filteredValues []model.DataValue + for _, value := range allValues { + if (value.Timestamp.Equal(from) || value.Timestamp.After(from)) && + (value.Timestamp.Equal(to) || value.Timestamp.Before(to)) { + filteredValues = append(filteredValues, value) + } + } + + return filteredValues, nil +} + // NewMemoryEngine 创建一个新的内存存储引擎 func NewMemoryEngine() *MemoryEngine { return &MemoryEngine{ @@ -119,7 +170,7 @@ func (e *MemoryEngine) GetLatest(ctx context.Context, id model.DataPointID) (mod } // GetDuration 获取持续时间 -func (e *MemoryEngine) GetDuration(ctx context.Context, id model.DataPointID) (float64, error) { +func (e *MemoryEngine) GetDuration(ctx context.Context, id model.DataPointID) (time.Duration, error) { key := id.String() e.dataLock.RLock() @@ -130,8 +181,7 @@ func (e *MemoryEngine) GetDuration(ctx context.Context, id model.DataPointID) (f return 0, nil } - duration := buffer.GetDuration() - return duration.Seconds(), nil + return buffer.GetDuration(), nil } // EnablePersistence 启用持久化 diff --git a/pkg/storage/engine_test.go b/pkg/storage/engine_test.go new file mode 100644 index 0000000..bf7fa4e --- /dev/null +++ b/pkg/storage/engine_test.go @@ -0,0 +1,246 @@ +package storage + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "git.pyer.club/kingecg/gotidb/pkg/model" +) + +func TestMemoryEngine(t *testing.T) { + // 创建内存存储引擎 + engine := NewMemoryEngine() + + // 创建测试数据 + id := model.DataPointID{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + } + + now := time.Now() + value := model.DataValue{ + Timestamp: now, + Value: 25.5, + } + + // 测试写入 + t.Run("Write", func(t *testing.T) { + err := engine.Write(context.Background(), id, value) + if err != nil { + t.Errorf("Write() error = %v", err) + } + }) + + // 测试读取最新值 + t.Run("ReadLatest", func(t *testing.T) { + latest, err := engine.ReadLatest(context.Background(), id) + if err != nil { + t.Errorf("ReadLatest() error = %v", err) + } + + if !latest.Timestamp.Equal(value.Timestamp) { + t.Errorf("ReadLatest() timestamp = %v, want %v", latest.Timestamp, value.Timestamp) + } + + if latest.Value != value.Value { + t.Errorf("ReadLatest() value = %v, want %v", latest.Value, value.Value) + } + }) + + // 测试批量写入 + t.Run("BatchWrite", func(t *testing.T) { + id2 := model.DataPointID{ + DeviceID: "test-device", + MetricCode: "humidity", + Labels: map[string]string{ + "location": "room1", + }, + } + + value2 := model.DataValue{ + Timestamp: now, + Value: 60.0, + } + + batch := []struct { + ID model.DataPointID + Value model.DataValue + }{ + { + ID: id, + Value: value, + }, + { + ID: id2, + Value: value2, + }, + } + + err := engine.BatchWrite(context.Background(), batch) + if err != nil { + t.Errorf("BatchWrite() error = %v", err) + } + + // 验证批量写入的数据 + latest, err := engine.ReadLatest(context.Background(), id2) + if err != nil { + t.Errorf("ReadLatest() after BatchWrite error = %v", err) + } + + if !latest.Timestamp.Equal(value2.Timestamp) { + t.Errorf("ReadLatest() after BatchWrite timestamp = %v, want %v", latest.Timestamp, value2.Timestamp) + } + + if latest.Value != value2.Value { + t.Errorf("ReadLatest() after BatchWrite value = %v, want %v", latest.Value, value2.Value) + } + }) + + // 测试读取所有值 + t.Run("ReadAll", func(t *testing.T) { + // 写入多个值 + for i := 1; i <= 5; i++ { + newValue := model.DataValue{ + Timestamp: now.Add(time.Duration(i) * time.Minute), + Value: 25.5 + float64(i), + } + err := engine.Write(context.Background(), id, newValue) + if err != nil { + t.Errorf("Write() for ReadAll error = %v", err) + } + } + + // 读取所有值 + values, err := engine.ReadAll(context.Background(), id) + if err != nil { + t.Errorf("ReadAll() error = %v", err) + } + + // 验证读取的值数量 + if len(values) != 6 { // 初始值 + 5个新值 + t.Errorf("ReadAll() returned %v values, want %v", len(values), 6) + } + + // 验证值是按时间顺序排列的 + for i := 1; i < len(values); i++ { + if values[i].Timestamp.Before(values[i-1].Timestamp) { + t.Errorf("ReadAll() values not in chronological order") + } + } + }) + + // 测试读取持续时间内的值 + t.Run("ReadDuration", func(t *testing.T) { + // 设置时间范围 + from := now.Add(1 * time.Minute) + to := now.Add(3 * time.Minute) + + // 读取指定时间范围内的值 + values, err := engine.ReadDuration(context.Background(), id, from, to) + if err != nil { + t.Errorf("ReadDuration() error = %v", err) + } + + // 验证读取的值数量 + if len(values) != 3 { // 1分钟、2分钟和3分钟的值 + t.Errorf("ReadDuration() returned %v values, want %v", len(values), 3) + } + + // 验证所有值都在指定的时间范围内 + for _, v := range values { + if v.Timestamp.Before(from) || v.Timestamp.After(to) { + t.Errorf("ReadDuration() returned value with timestamp %v outside range [%v, %v]", v.Timestamp, from, to) + } + } + }) +} + +func TestPersistence(t *testing.T) { + // 创建临时目录 + tempDir, err := os.MkdirTemp("", "gotidb-test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // 创建内存存储引擎 + engine := NewMemoryEngine() + + // 启用WAL持久化 + persistenceConfig := PersistenceConfig{ + Type: PersistenceTypeWAL, + Directory: tempDir, + SyncEvery: 1, // 每次写入都同步 + } + + err = engine.EnablePersistence(persistenceConfig) + if err != nil { + t.Fatalf("EnablePersistence() error = %v", err) + } + + // 创建测试数据 + id := model.DataPointID{ + DeviceID: "test-device", + MetricCode: "temperature", + Labels: map[string]string{ + "location": "room1", + }, + } + + now := time.Now() + value := model.DataValue{ + Timestamp: now, + Value: 25.5, + } + + // 写入数据 + err = engine.Write(context.Background(), id, value) + if err != nil { + t.Errorf("Write() with persistence error = %v", err) + } + + // 关闭引擎 + err = engine.Close() + if err != nil { + t.Errorf("Close() error = %v", err) + } + + // 检查WAL文件是否存在 + walFiles, err := filepath.Glob(filepath.Join(tempDir, "*.wal")) + if err != nil { + t.Errorf("Failed to list WAL files: %v", err) + } + if len(walFiles) == 0 { + t.Errorf("No WAL files found after persistence") + } + + // 创建新的引擎并从WAL恢复 + newEngine := NewMemoryEngine() + err = newEngine.EnablePersistence(persistenceConfig) + if err != nil { + t.Fatalf("EnablePersistence() for new engine error = %v", err) + } + + // 读取恢复后的数据 + latest, err := newEngine.ReadLatest(context.Background(), id) + if err != nil { + t.Errorf("ReadLatest() after recovery error = %v", err) + } + + // 验证恢复的数据 + if latest.Value != value.Value { + t.Errorf("ReadLatest() after recovery value = %v, want %v", latest.Value, value.Value) + } + + // 关闭新引擎 + err = newEngine.Close() + if err != nil { + t.Errorf("Close() new engine error = %v", err) + } +}