208 lines
4.8 KiB
Go
208 lines
4.8 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.kingecg.top/kingecg/gohttpd/model"
|
|
)
|
|
|
|
func Test_Route_Match(t *testing.T) {
|
|
// 创建测试路由
|
|
route := &Route{
|
|
Method: "GET",
|
|
Path: "/test",
|
|
}
|
|
|
|
// 测试方法匹配
|
|
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
|
if !route.Match(req) {
|
|
t.Error("Expected GET method to match")
|
|
}
|
|
|
|
// 测试路径匹配
|
|
req, _ = http.NewRequest("GET", "http://example.com/test/123", nil)
|
|
if !route.Match(req) {
|
|
t.Error("Expected path to match with prefix")
|
|
}
|
|
|
|
// 测试方法不匹配
|
|
req, _ = http.NewRequest("POST", "http://example.com/test", nil)
|
|
if route.Match(req) {
|
|
t.Error("Expected POST method to not match")
|
|
}
|
|
|
|
// 测试路径不匹配
|
|
req, _ = http.NewRequest("GET", "http://example.com/other", nil)
|
|
if route.Match(req) {
|
|
t.Error("Expected different path to not match")
|
|
}
|
|
}
|
|
|
|
func Test_RestMux_HandleFunc(t *testing.T) {
|
|
// 创建RestMux实例
|
|
mux := NewRestMux("/api")
|
|
|
|
// 测试GET路由注册
|
|
called := false
|
|
mux.HandleFunc("GET", "/users", func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
})
|
|
|
|
// 测试路由是否正确添加
|
|
if len(mux.routes) != 1 {
|
|
t.Errorf("Expected 1 route, got %d", len(mux.routes))
|
|
}
|
|
|
|
// 测试路由匹配
|
|
req, _ := http.NewRequest("GET", "http://example.com/api/users", nil)
|
|
found := false
|
|
for _, route := range mux.routes {
|
|
// 需要更新上下文路径来匹配路由
|
|
ctx := context.WithValue(req.Context(), RequestCtxKey("data"), map[string]interface{}{})
|
|
req = req.WithContext(ctx)
|
|
|
|
if route.Match(req) {
|
|
found = true
|
|
// 执行处理函数
|
|
route.ServeHTTP(nil, req)
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Error("Expected route to be found")
|
|
}
|
|
|
|
if !called {
|
|
t.Error("Expected handler function to be called")
|
|
}
|
|
}
|
|
|
|
func Test_ServerMux_Handle(t *testing.T) {
|
|
// 创建ServerMux实例
|
|
s := &ServerMux{
|
|
handlers: make(map[string]http.Handler),
|
|
paths: []string{},
|
|
}
|
|
|
|
// 创建测试处理程序
|
|
called := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
})
|
|
|
|
// 测试基本路由注册
|
|
directives := []string{"test_directive"}
|
|
s.Handle("/test", handler, directives)
|
|
|
|
// 测试路由是否正确添加
|
|
if len(s.handlers) != 1 {
|
|
t.Errorf("Expected 1 handler, got %d", len(s.handlers))
|
|
}
|
|
|
|
// 测试路径排序
|
|
if len(s.paths) != 1 || s.paths[0] != "/test" {
|
|
t.Errorf("Expected paths to contain '/test'")
|
|
}
|
|
|
|
// 测试路由匹配
|
|
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
|
ctx := context.WithValue(req.Context(), RequestCtxKey("data"), map[string]interface{}{})
|
|
req = req.WithContext(ctx)
|
|
|
|
// 调用内部serveHTTP方法进行测试
|
|
// 创建ResponseRecorder来捕获响应
|
|
w := &testResponseWriter{}
|
|
s.serveHTTP(w, req)
|
|
|
|
if !called {
|
|
t.Error("Expected handler function to be called")
|
|
}
|
|
}
|
|
|
|
// testResponseWriter is a minimal in-memory http.ResponseWriter used by the
// tests in this file to capture what a handler writes.
type testResponseWriter struct {
	statusCode int         // status from the first WriteHeader (explicit or implicit); 0 if none yet
	body       []byte      // concatenation of every slice passed to Write
	header     http.Header // created lazily; the same map is returned by every Header call
}

// Compile-time check that testResponseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*testResponseWriter)(nil)

// Header returns the writer's header map. The same map is returned on every
// call so that headers a handler sets are retained, as the
// http.ResponseWriter contract requires.
func (w *testResponseWriter) Header() http.Header {
	if w.header == nil {
		w.header = make(http.Header)
	}
	return w.header
}

// Write appends b to the captured body and always reports success. As in
// net/http, it triggers an implicit WriteHeader(http.StatusOK) if no status
// has been written yet.
func (w *testResponseWriter) Write(b []byte) (int, error) {
	if w.statusCode == 0 {
		w.WriteHeader(http.StatusOK)
	}
	w.body = append(w.body, b...)
	return len(b), nil
}

// WriteHeader records statusCode. Only the first call takes effect; later
// calls are ignored, mirroring net/http where a second WriteHeader is
// superfluous.
func (w *testResponseWriter) WriteHeader(statusCode int) {
	if w.statusCode != 0 {
		return
	}
	w.statusCode = statusCode
}
|
|
|
|
func Test_NewServeMux(t *testing.T) {
|
|
// 创建测试配置
|
|
config := &model.HttpServerConfig{
|
|
Name: "test_server",
|
|
Port: 8080,
|
|
Host: "localhost",
|
|
Directives: []string{
|
|
"Record-Access test_server",
|
|
},
|
|
AllowIPs: []string{"192.168.1.0/24"},
|
|
DenyIPs: []string{"10.0.0.0/8"},
|
|
Paths: []model.HttpPath{
|
|
{
|
|
Path: "/",
|
|
Root: "/var/www/html",
|
|
Default: "index.html",
|
|
},
|
|
{
|
|
Path: "/test",
|
|
Root: "/home/kingecg/code/gohttp",
|
|
Default: "index.html",
|
|
},
|
|
},
|
|
}
|
|
|
|
// 调用NewServeMux
|
|
s := NewServeMux(config)
|
|
|
|
// 验证基本配置
|
|
if s == nil {
|
|
t.Error("Expected ServerMux instance, got nil")
|
|
}
|
|
|
|
// // 验证指令处理程序
|
|
// if s.directiveHandlers.Len() == 0 {
|
|
// t.Error("Expected directive handlers to be registered")
|
|
// }
|
|
|
|
// 验证路径处理程序
|
|
if len(s.handlers) != 2 || !strings.Contains(s.paths[0], "/test") {
|
|
t.Error("Expected handlers to be registered for /test path")
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "http://localhost/test", nil)
|
|
ctx := context.WithValue(req.Context(), RequestCtxKey("data"), map[string]interface{}{})
|
|
req = req.WithContext(ctx)
|
|
|
|
// 调用内部serveHTTP方法进行测试
|
|
// 创建ResponseRecorder来捕获响应
|
|
w := &testResponseWriter{}
|
|
s.serveHTTP(w, req)
|
|
// 验证IP访问控制中间件是否添加
|
|
// ipAccessFound := false
|
|
// for _, m := range s.directiveHandlers. {
|
|
// if fmt.Sprintf("%v", m) == "IPAccessControl" {
|
|
// ipAccessFound = true
|
|
// break
|
|
// }
|
|
// }
|
|
// if !ipAccessFound {
|
|
// t.Error("Expected IPAccessControl middleware to be added")
|
|
// }
|
|
}
|