feat(goaiagent): implement multi-vendor LLM API compatibility and streaming reads

- Add DeepSeek and OpenAI-compatible provider implementations
- Add a streaming reader and the related processing logic
- Implement unified request and response formats
- Improve the conversation truncation strategy and token counting
kingecg 2025-06-28 15:07:05 +08:00
commit 8a732bf623
9 changed files with 679 additions and 0 deletions

.gitignore

@@ -0,0 +1 @@
vendor/

go.mod

@@ -0,0 +1,5 @@
module git.kingecg.top/kingecg/goaiagent

go 1.23.1

require github.com/sashabaranov/go-openai v1.40.3

go.sum

@@ -0,0 +1,2 @@
github.com/sashabaranov/go-openai v1.40.3 h1:PkOw0SK34wrvYVOuXF1HZzuTBRh992qRZHil4kG3eYE=
github.com/sashabaranov/go-openai v1.40.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=

main.go

@@ -0,0 +1,170 @@
package main
import (
"context"
"fmt"
"strings"
"sync/atomic"
"git.kingecg.top/kingecg/goaiagent/model"
"git.kingecg.top/kingecg/goaiagent/provider"
openai "github.com/sashabaranov/go-openai" // 兼容OpenAI的SDK:cite[4]
)
// ==============================
// AI Agent core implementation
// ==============================
type AIAgent struct {
provider provider.Provider
tokenCount TokenCounter
maxTokens int
}
type TokenCounter struct {
InputTokens int64
OutputTokens int64
TotalTokens int64
}
func (tc *TokenCounter) Add(input, output int) {
atomic.AddInt64(&tc.InputTokens, int64(input))
atomic.AddInt64(&tc.OutputTokens, int64(output))
atomic.AddInt64(&tc.TotalTokens, int64(input+output))
}
func (tc *TokenCounter) Stats() string {
	return fmt.Sprintf("Tokens: input=%d, output=%d, total=%d",
		atomic.LoadInt64(&tc.InputTokens),
		atomic.LoadInt64(&tc.OutputTokens),
		atomic.LoadInt64(&tc.TotalTokens))
}
func NewAIAgent(provider provider.Provider, maxTokens int) *AIAgent {
return &AIAgent{
provider: provider,
maxTokens: maxTokens,
}
}
func (a *AIAgent) Query(ctx context.Context, prompt string) (string, error) {
messages := []openai.ChatCompletionMessage{
	{Role: openai.ChatMessageRoleSystem, Content: "You are an AI assistant"},
{Role: openai.ChatMessageRoleUser, Content: prompt},
}
	// Truncate the conversation if it is too long
messages = a.truncateConversation(messages)
	// Build the unified request
req := model.ChatRequest{
Model: a.provider.GetModelName(),
Messages: messages,
MaxTokens: a.maxTokens,
}
	// Call the vendor API
resp, err := a.provider.CreateChatCompletion(ctx, req)
if err != nil {
return "", err
}
	// Update token statistics
a.tokenCount.Add(resp.InputTokens, resp.OutputTokens)
return resp.Content, nil
}
// Truncation strategy: keep the system message plus the most recent messages
func (a *AIAgent) truncateConversation(messages []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
	const maxContextTokens = 3000
	const keepLastMessages = 4 // keep the last 4 messages
currentTokens := 0
for _, msg := range messages {
currentTokens += a.provider.CountTokens(msg.Content)
}
	// No truncation needed
if currentTokens <= maxContextTokens {
return messages
}
	// Keep the system message first, if present
	var newMessages []openai.ChatCompletionMessage
	hasSystem := len(messages) > 0 && messages[0].Role == openai.ChatMessageRoleSystem
	if hasSystem {
		newMessages = append(newMessages, messages[0])
	}
	// Append the most recent messages, skipping the system message we already kept
	// (otherwise a short conversation would duplicate it)
	startIndex := len(messages) - keepLastMessages
	if hasSystem && startIndex < 1 {
		startIndex = 1
	} else if startIndex < 0 {
		startIndex = 0
	}
	newMessages = append(newMessages, messages[startIndex:]...)
	// Second pass: recount and trim individual messages until we fit
total := 0
for _, msg := range newMessages {
total += a.provider.CountTokens(msg.Content)
}
	overflow := total - maxContextTokens
	if overflow > 0 {
		for i := len(newMessages) - 1; i >= 0; i-- {
			if i == 0 && newMessages[i].Role == openai.ChatMessageRoleSystem {
				continue // never truncate the system message
			}
			contentTokens := a.provider.CountTokens(newMessages[i].Content)
			if contentTokens > overflow {
				// Trim the tail of this message (approximate: one token per word)
				words := strings.Fields(newMessages[i].Content)
				if keep := len(words) - overflow; keep > 0 {
					newMessages[i].Content = strings.Join(words[:keep], " ") + "...[truncated]"
				}
				break
			} else {
				newMessages[i].Content = "[content removed: too long]"
				overflow -= contentTokens
			}
		}
	}
return newMessages
}
func (a *AIAgent) GetTokenStats() string {
return a.tokenCount.Stats()
}
// ==============================
// Usage example
// ==============================
func main() {
	// Use DeepSeek's Reasoner model (128K context)
	deepseekAgent := NewAIAgent(
		provider.NewDeepSeekProvider("your_deepseek_api_key", "deepseek-reasoner"), // replace with your API key
		2048,
	)
	// Use an OpenAI-compatible model
	// openaiAgent := NewAIAgent(
	// 	provider.NewOpenAiProvider("your_openai_api_key", "gpt-4"),
	// 	2048,
	// )
	response, err := deepseekAgent.Query(
		context.Background(),
		"Please explain quantum entanglement in Chinese, and describe its role in quantum computing",
	)
	if err != nil {
		fmt.Println("Error:", err)
		return
	}
	fmt.Println("Reply:", response)
	fmt.Println(deepseekAgent.GetTokenStats())
}
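The truncation strategy is the subtlest part of main.go, so here is a minimal sketch (not part of this commit) that exercises it against a stub provider. The stubProvider type and demoTruncation function are hypothetical names; the whitespace token counting mirrors the providers' simplified CountTokens.

// truncation_demo.go (hypothetical file, same package as main.go)
package main

import (
	"context"
	"fmt"
	"strings"

	"git.kingecg.top/kingecg/goaiagent/model"
	"git.kingecg.top/kingecg/goaiagent/provider"
	openai "github.com/sashabaranov/go-openai"
)

// stubProvider implements provider.Provider with just enough behavior
// to exercise truncation: whitespace token counting and no-op API calls.
type stubProvider struct{}

func (stubProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
	return model.ChatResponse{}, nil
}
func (stubProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*provider.StreamReader, error) {
	return nil, nil
}
func (stubProvider) CountTokens(text string) int { return len(strings.Fields(text)) }
func (stubProvider) GetModelName() string        { return "stub" }

func demoTruncation() {
	agent := NewAIAgent(stubProvider{}, 2048)
	long := strings.Repeat("word ", 2000) // ~2000 "tokens" per message
	messages := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleSystem, Content: "You are an AI assistant"},
		{Role: openai.ChatMessageRoleUser, Content: long},
		{Role: openai.ChatMessageRoleAssistant, Content: long},
		{Role: openai.ChatMessageRoleUser, Content: "final question"},
	}
	truncated := agent.truncateConversation(messages)
	// Expect the system message plus the most recent messages, trimmed to fit.
	fmt.Println("messages kept:", len(truncated))
}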

model/model.go

@@ -0,0 +1,17 @@
package model
import "github.com/sashabaranov/go-openai"
// ChatRequest is the unified request format
type ChatRequest struct {
Model string
Messages []openai.ChatCompletionMessage
MaxTokens int
}
// ChatResponse is the unified response format
type ChatResponse struct {
Content string
InputTokens int
OutputTokens int
}

provider/deepseek.go

@@ -0,0 +1,103 @@
package provider
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.kingecg.top/kingecg/goaiagent/model"
"github.com/sashabaranov/go-openai"
)
// ==============================
// DeepSeek API implementation
// ==============================
type DeepSeekProvider struct {
apiKey string
	model    string // deepseek-chat or deepseek-reasoner
endpoint string
}
func NewDeepSeekProvider(apiKey, model string) *DeepSeekProvider {
return &DeepSeekProvider{
apiKey: apiKey,
model: model,
endpoint: "https://api.deepseek.com/v1/chat/completions", // DeepSeek官方API地址:cite[7]
}
}
func (d *DeepSeekProvider) GetProviderName() string {
return "DeepSeek"
}
func (d *DeepSeekProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
	// Build the DeepSeek-specific request body
deepseekReq := struct {
Model string `json:"model"`
Messages []openai.ChatCompletionMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
}{
Model: d.model,
Messages: req.Messages,
MaxTokens: req.MaxTokens,
}
	reqBody, err := json.Marshal(deepseekReq)
	if err != nil {
		return model.ChatResponse{}, err
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, d.endpoint, bytes.NewReader(reqBody))
	if err != nil {
		return model.ChatResponse{}, err
	}
	httpReq.Header.Set("Authorization", "Bearer "+d.apiKey)
	httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return model.ChatResponse{}, err
}
defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return model.ChatResponse{}, err
	}
	if resp.StatusCode != http.StatusOK {
		return model.ChatResponse{}, fmt.Errorf("DeepSeek API error: %s, response: %s", resp.Status, string(body))
	}
	// Parse DeepSeek's response format
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal(body, &result); err != nil {
return model.ChatResponse{}, err
}
if len(result.Choices) == 0 {
return model.ChatResponse{}, fmt.Errorf("empty response from DeepSeek")
}
return model.ChatResponse{
Content: result.Choices[0].Message.Content,
InputTokens: result.Usage.PromptTokens,
OutputTokens: result.Usage.CompletionTokens,
}, nil
}
func (d *DeepSeekProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error) {
	streamClient := NewDeepSeekStreamClient(d.apiKey)
	return streamClient.CreateStream(ctx, req.Messages, d.model, req.MaxTokens)
}
// Simplified token counter; production code should use a real tokenizer such as tiktoken
func (d *DeepSeekProvider) CountTokens(text string) int {
	return len(strings.Fields(text)) // count whitespace-separated words
}
}
func (d *DeepSeekProvider) GetModelName() string {
return d.model
}
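The CountTokens comment above is worth taking seriously: whitespace word counts drift badly from real BPE token usage, and for Chinese text, which has no spaces, strings.Fields barely counts anything. A minimal sketch of a BPE-based counter, assuming the third-party github.com/pkoukk/tiktoken-go package (not a dependency of this commit) and the cl100k_base encoding, which may not match DeepSeek's actual tokenizer:

package provider

import "github.com/pkoukk/tiktoken-go"

// countTokensBPE is a hypothetical drop-in for CountTokens that uses a real
// BPE tokenizer instead of whitespace splitting.
func countTokensBPE(text string) (int, error) {
	enc, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		return 0, err
	}
	// Encode returns the token IDs; their count is the token usage.
	return len(enc.Encode(text, nil, nil)), nil
}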

provider/interface.go

@@ -0,0 +1,15 @@
package provider
import (
"context"
"git.kingecg.top/kingecg/goaiagent/model"
)
// Provider defines the unified LLM interface
type Provider interface {
CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error)
CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error)
CountTokens(text string) int
GetModelName() string
}
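One low-cost guard worth adding next to this interface is a compile-time assertion, the standard Go idiom for catching incomplete implementations at build time rather than at the call site. A sketch (not in this commit); note that as committed, only DeepSeekProvider satisfies Provider, since OpenAiProvider lacks CreateChatCompletionStream:

// Compile-time checks that the concrete providers satisfy Provider.
var _ Provider = (*DeepSeekProvider)(nil)

// This line would not compile until OpenAiProvider gains a
// CreateChatCompletionStream method (see the sketch after provider/openai.go):
// var _ Provider = (*OpenAiProvider)(nil)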

provider/openai.go

@@ -0,0 +1,59 @@
package provider
import (
	"context"
	"fmt"
	"strings"

	"git.kingecg.top/kingecg/goaiagent/model"
	"github.com/sashabaranov/go-openai"
)
// ==============================
// OpenAI-compatible implementation
// ==============================
type OpenAiProvider struct {
client *openai.Client
model string
}
func NewOpenAiProvider(apiKey, model string) *OpenAiProvider {
config := openai.DefaultConfig(apiKey)
return &OpenAiProvider{
client: openai.NewClientWithConfig(config),
model: model,
}
}
func (o *OpenAiProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
	// Convert the unified request into the OpenAI SDK's format
messages := make([]openai.ChatCompletionMessage, len(req.Messages))
for i, m := range req.Messages {
messages[i] = openai.ChatCompletionMessage{
Role: m.Role,
Content: m.Content,
}
}
resp, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: o.model,
Messages: messages,
MaxTokens: req.MaxTokens,
})
	if err != nil {
		return model.ChatResponse{}, err
	}
	if len(resp.Choices) == 0 {
		return model.ChatResponse{}, fmt.Errorf("empty response from OpenAI")
	}
	return model.ChatResponse{
		Content:      resp.Choices[0].Message.Content,
		InputTokens:  resp.Usage.PromptTokens,
		OutputTokens: resp.Usage.CompletionTokens,
	}, nil
}
func (o *OpenAiProvider) CountTokens(text string) int {
return len(strings.Fields(text))
}
func (o *OpenAiProvider) GetModelName() string {
return o.model
}
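As committed, OpenAiProvider cannot actually be passed to NewAIAgent: it lacks CreateChatCompletionStream, so it does not satisfy the Provider interface, and the commented-out NewOpenAiProvider call in main.go would fail to compile if enabled. A minimal sketch of a stub that closes the gap; real OpenAI streaming could later reuse the SSE machinery in provider/stream.go:

// Sketch (not part of this commit): satisfies the Provider interface with an
// explicit error until OpenAI streaming is wired up to StreamReader.
func (o *OpenAiProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error) {
	return nil, fmt.Errorf("streaming not implemented for the OpenAI provider")
}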

provider/stream.go

@@ -0,0 +1,307 @@
package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/sashabaranov/go-openai"
)
// ==============================
// Stream reader: core implementation
// ==============================
// StreamEvent represents a single streaming event
type StreamEvent struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
// StreamReader consumes a streaming API response
type StreamReader struct {
scanner *bufio.Scanner
response *http.Response
currentEvent *StreamEvent
events chan StreamEvent
errors chan error
done chan struct{}
closeOnce sync.Once
mu sync.Mutex
totalTokens int
startTime time.Time
lastEventTime time.Time
}
// NewStreamReader creates a new stream reader
func NewStreamReader(response *http.Response) *StreamReader {
sr := &StreamReader{
response: response,
scanner: bufio.NewScanner(response.Body),
events: make(chan StreamEvent, 100),
errors: make(chan error, 1),
done: make(chan struct{}),
startTime: time.Now(),
lastEventTime: time.Now(),
}
	sr.scanner.Split(sr.splitSSE) // use the custom SSE split function
go sr.processStream()
return sr
}
// splitSSE is a custom bufio.SplitFunc that frames Server-Sent Events
func (sr *StreamReader) splitSSE(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
	// Look for the event delimiter: \n\n or \r\n\r\n
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
return i + 2, data[0:i], nil
}
if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
return i + 4, data[0:i], nil
}
	// At EOF, return any remaining data
	if atEOF {
		return len(data), data, nil
	}
	// Otherwise request more data
	return 0, nil, nil
}
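// For reference, a typical frame that splitSSE yields looks like this
// (illustrative values, not captured output):
//
//	data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1719558000,"model":"deepseek-reasoner","choices":[{"delta":{"content":"Hello"},"index":0,"finish_reason":""}]}
//
// Frames are separated by a blank line, and the stream ends with a final
// "data: [DONE]" frame, which processStream treats as end of stream.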
// processStream pumps parsed events into the events channel
func (sr *StreamReader) processStream() {
defer close(sr.events)
defer close(sr.errors)
defer sr.response.Body.Close()
for sr.scanner.Scan() {
eventData := sr.scanner.Bytes()
sr.lastEventTime = time.Now()
		// Skip blank lines and SSE comment lines
if len(eventData) == 0 || bytes.HasPrefix(eventData, []byte(":")) {
continue
}
		// Stop on the [DONE] sentinel
if bytes.Contains(eventData, []byte("[DONE]")) {
return
}
		// Strip the "data: " prefix if present
prefix := []byte("data: ")
if bytes.HasPrefix(eventData, prefix) {
eventData = bytes.TrimPrefix(eventData, prefix)
}
var event StreamEvent
if err := json.Unmarshal(eventData, &event); err != nil {
sr.errors <- fmt.Errorf("解析事件失败: %w\n原始数据: %s", err, string(eventData))
return
}
		// Update the token count (guard against frames with no choices)
		if len(event.Choices) > 0 {
			if content := event.Choices[0].Delta.Content; content != "" {
				sr.mu.Lock()
				sr.totalTokens += len(strings.Fields(content))
				sr.mu.Unlock()
			}
		}
		// Deliver the event, bailing out if the reader has been closed
		select {
		case sr.events <- event:
		case <-sr.done:
			return
		}
}
if err := sr.scanner.Err(); err != nil {
sr.errors <- fmt.Errorf("扫描错误: %w", err)
}
}
// Recv returns the next event, or io.EOF when the stream ends
func (sr *StreamReader) Recv() (StreamEvent, error) {
select {
case event, ok := <-sr.events:
if !ok {
return StreamEvent{}, io.EOF
}
return event, nil
case err := <-sr.errors:
return StreamEvent{}, err
case <-sr.done:
return StreamEvent{}, io.EOF
}
}
// Close shuts down the stream and releases the HTTP body
func (sr *StreamReader) Close() error {
sr.closeOnce.Do(func() {
close(sr.done)
sr.response.Body.Close()
})
return nil
}
// Stats reports stream statistics
func (sr *StreamReader) Stats() string {
sr.mu.Lock()
defer sr.mu.Unlock()
duration := time.Since(sr.startTime)
	return fmt.Sprintf(
		"Tokens: %d | duration: %s | last event: %s ago",
		sr.totalTokens,
		duration.Round(time.Second),
		time.Since(sr.lastEventTime).Round(time.Second),
	)
}
// ==============================
// DeepSeek streaming API client
// ==============================
type DeepSeekStreamClient struct {
apiKey string
endpoint string
}
func NewDeepSeekStreamClient(apiKey string) *DeepSeekStreamClient {
return &DeepSeekStreamClient{
apiKey: apiKey,
endpoint: "https://api.deepseek.com/v1/chat/completions",
}
}
func (c *DeepSeekStreamClient) CreateStream(
ctx context.Context,
messages []openai.ChatCompletionMessage,
model string,
maxTokens int,
) (*StreamReader, error) {
	// Build the request body
requestBody := struct {
Model string `json:"model"`
Messages []openai.ChatCompletionMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
}{
Model: model,
Messages: messages,
MaxTokens: maxTokens,
Stream: true,
}
	body, err := json.Marshal(requestBody)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
req.Header.Set("Authorization", "Bearer "+c.apiKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
	// Send the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
	// Check the response status
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API error: %s, response: %s", resp.Status, string(body))
	}
return NewStreamReader(resp), nil
}
// ==============================
// Usage example
// ==============================
// func main() {
// 	apiKey := "your_deepseek_api_key" // replace with your API key
// 	client := NewDeepSeekStreamClient(apiKey)
// 	// Build the conversation
// 	messages := []openai.ChatCompletionMessage{
// 		{
// 			Role:    openai.ChatMessageRoleSystem,
// 			Content: "You are a helpful AI assistant",
// 		},
// 		{
// 			Role:    openai.ChatMessageRoleUser,
// 			Content: "Please explain, in Chinese, the basic principles of quantum computing and why it outperforms classical computers",
// 		},
// 	}
// 	// Create the streaming request
// 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
// 	defer cancel()
// 	stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
// 	if err != nil {
// 		fmt.Println("failed to create stream:", err)
// 		return
// 	}
// 	defer stream.Close()
// 	fmt.Println("Stream started:")
// 	fmt.Println("====================")
// 	var fullResponse strings.Builder
// 	lastPrint := time.Now()
// 	const updateInterval = 100 * time.Millisecond
// 	for {
// 		event, err := stream.Recv()
// 		if errors.Is(err, io.EOF) {
// 			fmt.Println("\n\n====================")
// 			fmt.Println("Stream finished")
// 			break
// 		}
// 		if err != nil {
// 			fmt.Println("\nreceive error:", err)
// 			break
// 		}
// 		if len(event.Choices) > 0 {
// 			content := event.Choices[0].Delta.Content
// 			if content != "" {
// 				fullResponse.WriteString(content)
// 				// Throttle printing: chunks arriving within the interval are
// 				// skipped here but still kept in fullResponse
// 				if time.Since(lastPrint) > updateInterval {
// 					fmt.Print(content)
// 					lastPrint = time.Now()
// 				}
// 			}
// 		}
// 	}
// 	fmt.Println("\nFull response:")
// 	fmt.Println(fullResponse.String())
// 	fmt.Println("\nStats:", stream.Stats())
// }