From 8a732bf623b03ce3ac4d50d9f67bdce509260ae8 Mon Sep 17 00:00:00 2001
From: kingecg
Date: Sat, 28 Jun 2025 15:07:05 +0800
Subject: [PATCH] feat(goaiagent): add multi-vendor LLM API compatibility and
 streaming reads
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add DeepSeek and OpenAI-compatible provider implementations
- Add a streaming (SSE) reader and its processing logic
- Introduce a unified request and response format shared by all providers
- Add a conversation-truncation strategy and token counting

---
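Note: all vendor calls go through the Provider interface defined in
provider/interface.go below, so the agent can be exercised without network
access. A minimal sketch of a test double (hypothetical, not part of this
patch; imports elided; assumes only the interface as declared below):

	// sketch: a test double satisfying provider.Provider
	type fakeProvider struct{}

	func (fakeProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
		return model.ChatResponse{Content: "stub reply", InputTokens: 3, OutputTokens: 2}, nil
	}

	func (fakeProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*provider.StreamReader, error) {
		return nil, errors.New("fakeProvider does not stream")
	}

	func (fakeProvider) CountTokens(text string) int { return len(strings.Fields(text)) }

	func (fakeProvider) GetModelName() string { return "fake-model" }
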
 .gitignore            |   1 +
 go.mod                |   5 +
 go.sum                |   2 +
 main.go               | 170 +++++++++++++++++++++++
 model/model.go        |  17 +++
 provider/deepseek.go  | 103 ++++++++++++++
 provider/interface.go |  15 +++
 provider/openai.go    |  59 ++++++++
 provider/stream.go    | 307 ++++++++++++++++++++++++++++++++++++++++++
 9 files changed, 679 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 go.mod
 create mode 100644 go.sum
 create mode 100644 main.go
 create mode 100644 model/model.go
 create mode 100644 provider/deepseek.go
 create mode 100644 provider/interface.go
 create mode 100644 provider/openai.go
 create mode 100644 provider/stream.go

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a725465
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+vendor/
\ No newline at end of file

diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..90ddd2c
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module git.kingecg.top/kingecg/goaiagent
+
+go 1.23.1
+
+require github.com/sashabaranov/go-openai v1.40.3

diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..c34d4f1
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/sashabaranov/go-openai v1.40.3 h1:PkOw0SK34wrvYVOuXF1HZzuTBRh992qRZHil4kG3eYE=
+github.com/sashabaranov/go-openai v1.40.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=

diff --git a/main.go b/main.go
new file mode 100644
index 0000000..9cc1118
--- /dev/null
+++ b/main.go
@@ -0,0 +1,170 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync/atomic"
+
+	"git.kingecg.top/kingecg/goaiagent/model"
+	"git.kingecg.top/kingecg/goaiagent/provider"
+	openai "github.com/sashabaranov/go-openai" // OpenAI-compatible SDK
+)
+
+// ==============================
+// AI Agent core implementation
+// ==============================
+type AIAgent struct {
+	provider   provider.Provider
+	tokenCount TokenCounter
+	maxTokens  int
+}
+
+type TokenCounter struct {
+	InputTokens  int64
+	OutputTokens int64
+	TotalTokens  int64
+}
+
+func (tc *TokenCounter) Add(input, output int) {
+	atomic.AddInt64(&tc.InputTokens, int64(input))
+	atomic.AddInt64(&tc.OutputTokens, int64(output))
+	atomic.AddInt64(&tc.TotalTokens, int64(input+output))
+}
+
+func (tc *TokenCounter) Stats() string {
+	// Use atomic loads so Stats is safe to call concurrently with Add.
+	return fmt.Sprintf("Tokens: input=%d, output=%d, total=%d",
+		atomic.LoadInt64(&tc.InputTokens),
+		atomic.LoadInt64(&tc.OutputTokens),
+		atomic.LoadInt64(&tc.TotalTokens))
+}
+
+func NewAIAgent(provider provider.Provider, maxTokens int) *AIAgent {
+	return &AIAgent{
+		provider:  provider,
+		maxTokens: maxTokens,
+	}
+}
+
+func (a *AIAgent) Query(ctx context.Context, prompt string) (string, error) {
+	messages := []openai.ChatCompletionMessage{
+		{Role: openai.ChatMessageRoleSystem, Content: "You are an AI assistant"},
+		{Role: openai.ChatMessageRoleUser, Content: prompt},
+	}
+
+	// Truncate the conversation if it has grown too long.
+	messages = a.truncateConversation(messages)
+
+	// Build the unified request.
+	req := model.ChatRequest{
+		Model:     a.provider.GetModelName(),
+		Messages:  messages,
+		MaxTokens: a.maxTokens,
+	}
+
+	// Call the vendor API.
+	resp, err := a.provider.CreateChatCompletion(ctx, req)
+	if err != nil {
+		return "", err
+	}
+
+	// Update token statistics.
+	a.tokenCount.Add(resp.InputTokens, resp.OutputTokens)
+
+	return resp.Content, nil
+}
+
+// Truncation strategy: keep the system message plus the most recent messages.
+func (a *AIAgent) truncateConversation(messages []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
+	const maxContextTokens = 3000
+	const keepLastMessages = 4 // keep the last 4 messages
+
+	currentTokens := 0
+	for _, msg := range messages {
+		currentTokens += a.provider.CountTokens(msg.Content)
+	}
+
+	// Nothing to truncate.
+	if currentTokens <= maxContextTokens {
+		return messages
+	}
+
+	// Always keep the system message if there is one.
+	var newMessages []openai.ChatCompletionMessage
+	hasSystem := len(messages) > 0 && messages[0].Role == openai.ChatMessageRoleSystem
+	if hasSystem {
+		newMessages = append(newMessages, messages[0])
+	}
+
+	// Append the most recent messages, without duplicating the system message.
+	startIndex := len(messages) - keepLastMessages
+	if startIndex < 0 {
+		startIndex = 0
+	}
+	if hasSystem && startIndex < 1 {
+		startIndex = 1
+	}
+	newMessages = append(newMessages, messages[startIndex:]...)
+
+	// Second pass: trim individual messages until the total fits.
+	total := 0
+	for _, msg := range newMessages {
+		total += a.provider.CountTokens(msg.Content)
+	}
+
+	overflow := total - maxContextTokens
+	if overflow > 0 {
+		// Trim the oldest kept messages first so the newest turns survive.
+		start := 0
+		if hasSystem {
+			start = 1 // never trim the system message
+		}
+		for i := start; i < len(newMessages) && overflow > 0; i++ {
+			contentTokens := a.provider.CountTokens(newMessages[i].Content)
+			if contentTokens > overflow {
+				// Drop just enough trailing words from this message.
+				words := strings.Fields(newMessages[i].Content)
+				newMessages[i].Content = strings.Join(words[:len(words)-overflow], " ") + "...[truncated]"
+				break
+			}
+			newMessages[i].Content = "[content removed: too long]"
+			overflow -= contentTokens
+		}
+	}
+
+	return newMessages
+}
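+
+// A worked example of the truncation above (illustrative numbers, assuming
+// the word-count token approximation): with maxContextTokens = 3000 and a
+// history of one 100-token system message plus six 600-token turns
+// (3700 tokens total), the first pass keeps the system message and the last
+// four turns (2500 tokens), which already fits, so the second word-trimming
+// pass does nothing. Only when the kept messages still exceed the budget are
+// the oldest of them emptied or trimmed word by word.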
+
+func (a *AIAgent) GetTokenStats() string {
+	return a.tokenCount.Stats()
+}
+
+// ==============================
+// Usage example
+// ==============================
+func main() {
+	// Use DeepSeek's reasoner model (128K context).
+	deepseekAgent := NewAIAgent(
+		provider.NewDeepSeekProvider("your_deepseek_api_key", "deepseek-reasoner"), // replace with your API key
+		2048,
+	)
+
+	// Or an OpenAI-compatible model:
+	// openaiAgent := NewAIAgent(
+	// 	provider.NewOpenAiProvider("your_openai_api_key", "gpt-4"),
+	// 	2048,
+	// )
+
+	response, err := deepseekAgent.Query(
+		context.Background(),
+		"Please explain quantum entanglement and its role in quantum computing",
+	)
+	if err != nil {
+		fmt.Println("Error:", err)
+		return
+	}
+
+	fmt.Println("Reply:", response)
+	fmt.Println(deepseekAgent.GetTokenStats())
+}

diff --git a/model/model.go b/model/model.go
new file mode 100644
index 0000000..b473be5
--- /dev/null
+++ b/model/model.go
@@ -0,0 +1,17 @@
+package model
+
+import "github.com/sashabaranov/go-openai"
+
+// ChatRequest is the unified request format shared by all providers.
+type ChatRequest struct {
+	Model     string
+	Messages  []openai.ChatCompletionMessage
+	MaxTokens int
+}
+
+// ChatResponse is the unified response format shared by all providers.
+type ChatResponse struct {
+	Content      string
+	InputTokens  int
+	OutputTokens int
+}
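+
+// Illustrative only (values are made up): a unified request as an agent
+// builds it, and the fields a provider fills in on the way back, regardless
+// of which vendor served the call:
+//
+//	req := ChatRequest{
+//		Model:     "deepseek-reasoner",
+//		Messages:  []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "hi"}},
+//		MaxTokens: 256,
+//	}
+//	// after the call: ChatResponse{Content: "...", InputTokens: 12, OutputTokens: 48}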

diff --git a/provider/deepseek.go b/provider/deepseek.go
new file mode 100644
index 0000000..be75d91
--- /dev/null
+++ b/provider/deepseek.go
@@ -0,0 +1,103 @@
+package provider
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
+	"git.kingecg.top/kingecg/goaiagent/model"
+	"github.com/sashabaranov/go-openai"
+)
+
+// ==============================
+// DeepSeek API implementation
+// ==============================
+type DeepSeekProvider struct {
+	apiKey   string
+	model    string // "deepseek-chat" or "deepseek-reasoner"
+	endpoint string
+}
+
+func NewDeepSeekProvider(apiKey, model string) *DeepSeekProvider {
+	return &DeepSeekProvider{
+		apiKey:   apiKey,
+		model:    model,
+		endpoint: "https://api.deepseek.com/v1/chat/completions", // official DeepSeek API endpoint
+	}
+}
+
+func (d *DeepSeekProvider) GetProviderName() string {
+	return "DeepSeek"
+}
+
+func (d *DeepSeekProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
+	// Build the DeepSeek request body (OpenAI-compatible JSON).
+	deepseekReq := struct {
+		Model     string                         `json:"model"`
+		Messages  []openai.ChatCompletionMessage `json:"messages"`
+		MaxTokens int                            `json:"max_tokens"`
+	}{
+		Model:     d.model,
+		Messages:  req.Messages,
+		MaxTokens: req.MaxTokens,
+	}
+
+	reqBody, err := json.Marshal(deepseekReq)
+	if err != nil {
+		return model.ChatResponse{}, err
+	}
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, d.endpoint, bytes.NewReader(reqBody))
+	if err != nil {
+		return model.ChatResponse{}, err
+	}
+	httpReq.Header.Set("Authorization", "Bearer "+d.apiKey)
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{}
+	resp, err := client.Do(httpReq)
+	if err != nil {
+		return model.ChatResponse{}, err
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return model.ChatResponse{}, err
+	}
+	if resp.StatusCode != http.StatusOK {
+		return model.ChatResponse{}, fmt.Errorf("DeepSeek API error: %s, response: %s", resp.Status, string(body))
+	}
+
+	// Parse the DeepSeek response format.
+	var result struct {
+		Choices []struct {
+			Message struct {
+				Content string `json:"content"`
+			} `json:"message"`
+		} `json:"choices"`
+		Usage struct {
+			PromptTokens     int `json:"prompt_tokens"`
+			CompletionTokens int `json:"completion_tokens"`
+		} `json:"usage"`
+	}
+
+	if err := json.Unmarshal(body, &result); err != nil {
+		return model.ChatResponse{}, err
+	}
+
+	if len(result.Choices) == 0 {
+		return model.ChatResponse{}, fmt.Errorf("empty response from DeepSeek")
+	}
+
+	return model.ChatResponse{
+		Content:      result.Choices[0].Message.Content,
+		InputTokens:  result.Usage.PromptTokens,
+		OutputTokens: result.Usage.CompletionTokens,
+	}, nil
+}
+
+func (d *DeepSeekProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error) {
+	streamClient := NewDeepSeekStreamClient(d.apiKey)
+	return streamClient.CreateStream(ctx, req.Messages, d.model, req.MaxTokens)
+}
+
+// Simplified token counter; production code should use a real tokenizer such as tiktoken.
+func (d *DeepSeekProvider) CountTokens(text string) int {
+	return len(strings.Fields(text)) // count whitespace-separated words
+}
+
+func (d *DeepSeekProvider) GetModelName() string {
+	return d.model
+}
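+
+// Note: len(strings.Fields(text)) sees unsegmented CJK text as a single
+// "word", so the count above will be far too low for Chinese prompts. A
+// rough, dependency-free alternative (a sketch, not a real tokenizer; it
+// would need the "unicode" import) counts each Han character separately:
+//
+// func countTokensApprox(text string) int {
+// 	count := 0
+// 	inWord := false
+// 	for _, r := range text {
+// 		switch {
+// 		case unicode.Is(unicode.Han, r):
+// 			count++ // one token per CJK character
+// 			inWord = false
+// 		case unicode.IsSpace(r):
+// 			inWord = false
+// 		default:
+// 			if !inWord {
+// 				count++ // first rune of a space-delimited word
+// 				inWord = true
+// 			}
+// 		}
+// 	}
+// 	return count
+// }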

diff --git a/provider/interface.go b/provider/interface.go
new file mode 100644
index 0000000..c481047
--- /dev/null
+++ b/provider/interface.go
@@ -0,0 +1,15 @@
+package provider
+
+import (
+	"context"
+
+	"git.kingecg.top/kingecg/goaiagent/model"
+)
+
+// Provider is the unified interface implemented by every model vendor.
+type Provider interface {
+	CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error)
+	CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error)
+	CountTokens(text string) int
+	GetModelName() string
+}

diff --git a/provider/openai.go b/provider/openai.go
new file mode 100644
index 0000000..efcacdf
--- /dev/null
+++ b/provider/openai.go
@@ -0,0 +1,59 @@
+package provider
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"git.kingecg.top/kingecg/goaiagent/model"
+	"github.com/sashabaranov/go-openai"
+)
+
+// ==============================
+// OpenAI-compatible implementation
+// ==============================
+type OpenAiProvider struct {
+	client *openai.Client
+	model  string
+}
+
+func NewOpenAiProvider(apiKey, model string) *OpenAiProvider {
+	config := openai.DefaultConfig(apiKey)
+	return &OpenAiProvider{
+		client: openai.NewClientWithConfig(config),
+		model:  model,
+	}
+}
+
+func (o *OpenAiProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
+	// The unified request already uses the OpenAI SDK message type, so it maps directly.
+	resp, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
+		Model:     o.model,
+		Messages:  req.Messages,
+		MaxTokens: req.MaxTokens,
+	})
+	if err != nil {
+		return model.ChatResponse{}, err
+	}
+	if len(resp.Choices) == 0 {
+		return model.ChatResponse{}, fmt.Errorf("empty response from OpenAI")
+	}
+
+	return model.ChatResponse{
+		Content:      resp.Choices[0].Message.Content,
+		InputTokens:  resp.Usage.PromptTokens,
+		OutputTokens: resp.Usage.CompletionTokens,
+	}, nil
+}
+
+// Simplified token counter; see the note in deepseek.go.
+func (o *OpenAiProvider) CountTokens(text string) int {
+	return len(strings.Fields(text))
+}
+
+func (o *OpenAiProvider) GetModelName() string {
+	return o.model
+}

diff --git a/provider/stream.go b/provider/stream.go
new file mode 100644
index 0000000..4e6c6c8
--- /dev/null
+++ b/provider/stream.go
@@ -0,0 +1,307 @@
+package provider
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/sashabaranov/go-openai"
+)
+
+// ==============================
+// Streaming reader core implementation
+// ==============================
+
+// StreamEvent is a single server-sent event from the streaming API.
+type StreamEvent struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	Model   string `json:"model"`
+	Choices []struct {
+		Delta struct {
+			Content string `json:"content"`
+		} `json:"delta"`
+		Index        int    `json:"index"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
+// StreamReader consumes a streaming API response.
+type StreamReader struct {
+	scanner       *bufio.Scanner
+	response      *http.Response
+	events        chan StreamEvent
+	errors        chan error
+	done          chan struct{}
+	closeOnce     sync.Once
+	mu            sync.Mutex // guards totalTokens and lastEventTime
+	totalTokens   int
+	startTime     time.Time
+	lastEventTime time.Time
+}
+
+// NewStreamReader wraps an HTTP response whose body is an SSE stream.
+func NewStreamReader(response *http.Response) *StreamReader {
+	sr := &StreamReader{
+		response:      response,
+		scanner:       bufio.NewScanner(response.Body),
+		events:        make(chan StreamEvent, 100),
+		errors:        make(chan error, 1),
+		done:          make(chan struct{}),
+		startTime:     time.Now(),
+		lastEventTime: time.Now(),
+	}
+
+	sr.scanner.Split(sr.splitSSE) // split the body into SSE events
+
+	go sr.processStream()
+	return sr
+}
+
+// splitSSE is a bufio.SplitFunc that yields one SSE event at a time.
+func (sr *StreamReader) splitSSE(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	if atEOF && len(data) == 0 {
+		return 0, nil, nil
+	}
+
+	// Events are separated by a blank line: \n\n or \r\n\r\n.
+	if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
+		return i + 2, data[0:i], nil
+	}
+	if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
+		return i + 4, data[0:i], nil
+	}
+
+	// At EOF, return whatever remains.
+	if atEOF {
+		return len(data), data, nil
+	}
+
+	// Request more data.
+	return 0, nil, nil
+}
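+
+// For reference, the wire format splitSSE consumes looks like the following
+// (abbreviated; the JSON mirrors the StreamEvent struct above). Each event
+// is a "data: " line terminated by a blank line, and the stream ends with a
+// [DONE] sentinel:
+//
+//	data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},"index":0}]}
+//
+//	data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":" world"},"index":0,"finish_reason":"stop"}]}
+//
+//	data: [DONE]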
+
+// processStream reads events off the wire and fans them out.
+func (sr *StreamReader) processStream() {
+	defer close(sr.events)
+	defer close(sr.errors)
+	defer sr.response.Body.Close()
+
+	for sr.scanner.Scan() {
+		eventData := sr.scanner.Bytes()
+		sr.mu.Lock()
+		sr.lastEventTime = time.Now()
+		sr.mu.Unlock()
+
+		// Skip empty lines and SSE comments.
+		if len(eventData) == 0 || bytes.HasPrefix(eventData, []byte(":")) {
+			continue
+		}
+
+		// Stop on the [DONE] sentinel.
+		if bytes.Contains(eventData, []byte("[DONE]")) {
+			return
+		}
+
+		// Strip the "data: " prefix.
+		prefix := []byte("data: ")
+		if bytes.HasPrefix(eventData, prefix) {
+			eventData = bytes.TrimPrefix(eventData, prefix)
+		}
+
+		var event StreamEvent
+		if err := json.Unmarshal(eventData, &event); err != nil {
+			sr.errors <- fmt.Errorf("failed to parse event: %w\nraw data: %s", err, string(eventData))
+			return
+		}
+
+		// Update the token count; guard against chunks with no choices.
+		if len(event.Choices) > 0 {
+			if content := event.Choices[0].Delta.Content; content != "" {
+				sr.mu.Lock()
+				sr.totalTokens += len(strings.Fields(content))
+				sr.mu.Unlock()
+			}
+		}
+
+		// Deliver the event, unless the reader has been closed.
+		select {
+		case sr.events <- event:
+		case <-sr.done:
+			return
+		}
+	}
+
+	if err := sr.scanner.Err(); err != nil {
+		sr.errors <- fmt.Errorf("scan error: %w", err)
+	}
+}
+
+// Recv returns the next event, or io.EOF when the stream ends.
+func (sr *StreamReader) Recv() (StreamEvent, error) {
+	select {
+	case event, ok := <-sr.events:
+		if !ok {
+			return StreamEvent{}, io.EOF
+		}
+		return event, nil
+	case err := <-sr.errors:
+		return StreamEvent{}, err
+	case <-sr.done:
+		return StreamEvent{}, io.EOF
+	}
+}
+
+// Close stops the stream and releases the underlying connection.
+func (sr *StreamReader) Close() error {
+	sr.closeOnce.Do(func() {
+		close(sr.done)
+		sr.response.Body.Close()
+	})
+	return nil
+}
+
+// Stats reports stream statistics.
+func (sr *StreamReader) Stats() string {
+	sr.mu.Lock()
+	defer sr.mu.Unlock()
+
+	duration := time.Since(sr.startTime)
+	return fmt.Sprintf(
+		"Tokens: %d | duration: %s | last event: %s ago",
+		sr.totalTokens,
+		duration.Round(time.Second),
+		time.Since(sr.lastEventTime).Round(time.Second),
+	)
+}
+
+// ==============================
+// DeepSeek streaming API client
+// ==============================
+
+type DeepSeekStreamClient struct {
+	apiKey   string
+	endpoint string
+}
+
+func NewDeepSeekStreamClient(apiKey string) *DeepSeekStreamClient {
+	return &DeepSeekStreamClient{
+		apiKey:   apiKey,
+		endpoint: "https://api.deepseek.com/v1/chat/completions",
+	}
+}
+
+func (c *DeepSeekStreamClient) CreateStream(
+	ctx context.Context,
+	messages []openai.ChatCompletionMessage,
+	model string,
+	maxTokens int,
+) (*StreamReader, error) {
+	// Build the request body.
+	requestBody := struct {
+		Model     string                         `json:"model"`
+		Messages  []openai.ChatCompletionMessage `json:"messages"`
+		MaxTokens int                            `json:"max_tokens"`
+		Stream    bool                           `json:"stream"`
+	}{
+		Model:     model,
+		Messages:  messages,
+		MaxTokens: maxTokens,
+		Stream:    true,
+	}
+
+	body, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+c.apiKey)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "text/event-stream")
+	req.Header.Set("Cache-Control", "no-cache")
+	req.Header.Set("Connection", "keep-alive")
+
+	// Send the request.
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+
+	// Check the response status.
+	if resp.StatusCode != http.StatusOK {
+		defer resp.Body.Close()
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("API error: %s, response: %s", resp.Status, string(body))
+	}
+
+	return NewStreamReader(resp), nil
+}
+
+// ==============================
+// Usage example
+// ==============================
+
+// func main() {
+// 	apiKey := "your_deepseek_api_key" // replace with your API key
+// 	client := NewDeepSeekStreamClient(apiKey)
+
+// 	// Build the conversation.
+// 	messages := []openai.ChatCompletionMessage{
+// 		{
+// 			Role:    openai.ChatMessageRoleSystem,
+// 			Content: "You are a helpful AI assistant",
+// 		},
+// 		{
+// 			Role:    openai.ChatMessageRoleUser,
+// 			Content: "Please explain the basic principles of quantum computing and why it can outperform classical computers",
+// 		},
+// 	}
+
+// 	// Create the streaming request.
+// 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+// 	defer cancel()
+
+// 	stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
+// 	if err != nil {
+// 		fmt.Println("failed to create stream:", err)
+// 		return
+// 	}
+// 	defer stream.Close()
+
+// 	fmt.Println("Streaming response started:")
+// 	fmt.Println("====================")
+
+// 	var fullResponse strings.Builder
+
+// 	for {
+// 		event, err := stream.Recv()
+// 		if errors.Is(err, io.EOF) { // requires the "errors" import
+// 			fmt.Println("\n\n====================")
+// 			fmt.Println("Streaming response finished")
+// 			break
+// 		}
+
+// 		if err != nil {
+// 			fmt.Println("\nreceive error:", err)
+// 			break
+// 		}
+
+// 		if len(event.Choices) > 0 {
+// 			content := event.Choices[0].Delta.Content
+// 			if content != "" {
+// 				fullResponse.WriteString(content)
+// 				fmt.Print(content) // print each chunk as it arrives
+// 			}
+// 		}
+// 	}
+
+// 	fmt.Println("\nFull response:")
+// 	fmt.Println(fullResponse.String())
+// 	fmt.Println("\nStats:", stream.Stats())
+// }
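+
+// The same stream can also be consumed through the Provider interface added
+// in this patch (sketch only, same assumptions as the example above; would
+// additionally need the "errors" and goaiagent "model" imports):
+//
+// func streamViaProvider(ctx context.Context, p Provider, prompt string) error {
+// 	stream, err := p.CreateChatCompletionStream(ctx, model.ChatRequest{
+// 		Model:     p.GetModelName(),
+// 		Messages:  []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: prompt}},
+// 		MaxTokens: 1000,
+// 	})
+// 	if err != nil {
+// 		return err
+// 	}
+// 	defer stream.Close()
+// 	for {
+// 		event, err := stream.Recv()
+// 		if errors.Is(err, io.EOF) {
+// 			return nil
+// 		}
+// 		if err != nil {
+// 			return err
+// 		}
+// 		if len(event.Choices) > 0 {
+// 			fmt.Print(event.Choices[0].Delta.Content)
+// 		}
+// 	}
+// }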