diff --git a/go.mod b/go.mod index 12c81a5..9bb706c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module git.kingecg.top/kingecg/gohttpd -go 1.23.0 +go 1.23 -toolchain go1.23.1 +// toolchain go1.23.1 require ( git.kingecg.top/kingecg/cmux v1.0.1 diff --git a/handler/file.go b/handler/file.go index ebadc14..893c844 100644 --- a/handler/file.go +++ b/handler/file.go @@ -33,12 +33,12 @@ func (f FileHandler) Open(name string) (http.File, error) { rPath := filepath.Join(f.Root, strings.TrimPrefix(relatedPath, "/")) l.Debug("access:", rPath) - // if rPath == f.Root { - // if f.Default == "" { - // return nil, errors.New("not permit list dir") - // } - // rPath = filepath.Join(rPath, f.Default) - // } + if rPath == f.Root { + if f.Default == "" { + return nil, errors.New("not permit list dir") + } + rPath = filepath.Join(rPath, f.Default) + } fInfo, _, err := FileExists(rPath) if err != nil { diff --git a/server/server.go b/server/server.go index 61899ef..9f87d84 100644 --- a/server/server.go +++ b/server/server.go @@ -360,6 +360,7 @@ func NewServeMux(c *model.HttpServerConfig) *ServerMux { } if fhandler != nil { directives := httpPath.Directives + l.Info(fmt.Sprintf("Register path %s", httpPath.Path)) s.Handle(httpPath.Path, fhandler, directives) } // s.Handle(httpPath.Path, fhandler) diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..2b7e37f --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,207 @@ +package server + +import ( + "context" + "net/http" + "strings" + "testing" + + "git.kingecg.top/kingecg/gohttpd/model" +) + +func Test_Route_Match(t *testing.T) { + // 创建测试路由 + route := &Route{ + Method: "GET", + Path: "/test", + } + + // 测试方法匹配 + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + if !route.Match(req) { + t.Error("Expected GET method to match") + } + + // 测试路径匹配 + req, _ = http.NewRequest("GET", "http://example.com/test/123", nil) + if !route.Match(req) { + t.Error("Expected path to match with prefix") + } + + // 测试方法不匹配 + req, _ = http.NewRequest("POST", "http://example.com/test", nil) + if route.Match(req) { + t.Error("Expected POST method to not match") + } + + // 测试路径不匹配 + req, _ = http.NewRequest("GET", "http://example.com/other", nil) + if route.Match(req) { + t.Error("Expected different path to not match") + } +} + +func Test_RestMux_HandleFunc(t *testing.T) { + // 创建RestMux实例 + mux := NewRestMux("/api") + + // 测试GET路由注册 + called := false + mux.HandleFunc("GET", "/users", func(w http.ResponseWriter, r *http.Request) { + called = true + }) + + // 测试路由是否正确添加 + if len(mux.routes) != 1 { + t.Errorf("Expected 1 route, got %d", len(mux.routes)) + } + + // 测试路由匹配 + req, _ := http.NewRequest("GET", "http://example.com/api/users", nil) + found := false + for _, route := range mux.routes { + // 需要更新上下文路径来匹配路由 + ctx := context.WithValue(req.Context(), RequestCtxKey("data"), map[string]interface{}{}) + req = req.WithContext(ctx) + + if route.Match(req) { + found = true + // 执行处理函数 + route.ServeHTTP(nil, req) + break + } + } + + if !found { + t.Error("Expected route to be found") + } + + if !called { + t.Error("Expected handler function to be called") + } +} + +func Test_ServerMux_Handle(t *testing.T) { + // 创建ServerMux实例 + s := &ServerMux{ + handlers: make(map[string]http.Handler), + paths: []string{}, + } + + // 创建测试处理程序 + called := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + }) + + // 测试基本路由注册 + directives := []string{"test_directive"} + s.Handle("/test", handler, directives) + + // 测试路由是否正确添加 + if len(s.handlers) != 1 { + t.Errorf("Expected 1 handler, got %d", len(s.handlers)) + } + + // 测试路径排序 + if len(s.paths) != 1 || s.paths[0] != "/test" { + t.Errorf("Expected paths to contain '/test'") + } + + // 测试路由匹配 + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + ctx := context.WithValue(req.Context(), RequestCtxKey("data"), map[string]interface{}{}) + req = req.WithContext(ctx) + + // 调用内部serveHTTP方法进行测试 + // 创建ResponseRecorder来捕获响应 + w := &testResponseWriter{} + s.serveHTTP(w, req) + + if !called { + t.Error("Expected handler function to be called") + } +} + +// 自定义ResponseWriter实现用于测试 +type testResponseWriter struct { + statusCode int + body []byte +} + +func (w *testResponseWriter) Header() http.Header { + return http.Header{} +} + +func (w *testResponseWriter) Write(b []byte) (int, error) { + w.body = append(w.body, b...) + return len(b), nil +} + +func (w *testResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} + +func Test_NewServeMux(t *testing.T) { + // 创建测试配置 + config := &model.HttpServerConfig{ + Name: "test_server", + Port: 8080, + Host: "localhost", + Directives: []string{ + "Record-Access test_server", + }, + AllowIPs: []string{"192.168.1.0/24"}, + DenyIPs: []string{"10.0.0.0/8"}, + Paths: []model.HttpPath{ + { + Path: "/", + Root: "/var/www/html", + Default: "index.html", + }, + { + Path: "/test", + Root: "/home/kingecg/code/gohttp", + Default: "index.html", + }, + }, + } + + // 调用NewServeMux + s := NewServeMux(config) + + // 验证基本配置 + if s == nil { + t.Error("Expected ServerMux instance, got nil") + } + + // // 验证指令处理程序 + // if s.directiveHandlers.Len() == 0 { + // t.Error("Expected directive handlers to be registered") + // } + + // 验证路径处理程序 + if len(s.handlers) != 2 || !strings.Contains(s.paths[0], "/test") { + t.Error("Expected handlers to be registered for /test path") + } + + req, _ := http.NewRequest("GET", "http://localhost/test", nil) + ctx := context.WithValue(req.Context(), RequestCtxKey("data"), map[string]interface{}{}) + req = req.WithContext(ctx) + + // 调用内部serveHTTP方法进行测试 + // 创建ResponseRecorder来捕获响应 + w := &testResponseWriter{} + s.serveHTTP(w, req) + // 验证IP访问控制中间件是否添加 + // ipAccessFound := false + // for _, m := range s.directiveHandlers. { + // if fmt.Sprintf("%v", m) == "IPAccessControl" { + // ipAccessFound = true + // break + // } + // } + // if !ipAccessFound { + // t.Error("Expected IPAccessControl middleware to be added") + // } +}