package main

import (
	"context"
	"fmt"
	"os"
	"strings"
	"sync/atomic"

	"git.kingecg.top/kingecg/goaiagent/model"
	"git.kingecg.top/kingecg/goaiagent/provider"

	openai "github.com/sashabaranov/go-openai" // OpenAI-compatible SDK
)

// ==============================
// Multi-vendor API support: core design
// ==============================

// ==============================
// AI Agent core implementation
// ==============================

// AIAgent drives a chat-completion provider, tracks token usage, and
// truncates over-long conversations before sending them.
type AIAgent struct {
	provider   provider.Provider
	tokenCount TokenCounter
	maxTokens  int // per-request completion token limit passed to the provider
}

// TokenCounter accumulates token usage. All fields are updated atomically,
// so one counter may be shared across concurrent queries.
type TokenCounter struct {
	InputTokens  int64
	OutputTokens int64
	TotalTokens  int64
}

// Add records the token usage of one request/response pair atomically.
func (tc *TokenCounter) Add(input, output int) {
	atomic.AddInt64(&tc.InputTokens, int64(input))
	atomic.AddInt64(&tc.OutputTokens, int64(output))
	atomic.AddInt64(&tc.TotalTokens, int64(input+output))
}

// Stats returns a human-readable usage summary.
// FIX: the counters are written with atomic adds, so they must also be
// read atomically — plain field reads here were a data race under -race.
func (tc *TokenCounter) Stats() string {
	return fmt.Sprintf("Tokens: 输入=%d, 输出=%d, 总计=%d",
		atomic.LoadInt64(&tc.InputTokens),
		atomic.LoadInt64(&tc.OutputTokens),
		atomic.LoadInt64(&tc.TotalTokens))
}

// NewAIAgent builds an agent on top of the given provider.
// maxTokens caps the completion size requested from the model.
func NewAIAgent(provider provider.Provider, maxTokens int) *AIAgent {
	return &AIAgent{
		provider:  provider,
		maxTokens: maxTokens,
	}
}

// Query sends a single user prompt (with a fixed system message) to the
// provider and returns the model's reply, updating the token statistics.
func (a *AIAgent) Query(ctx context.Context, prompt string) (string, error) {
	messages := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleSystem, Content: "你是一个AI助手"},
		{Role: openai.ChatMessageRoleUser, Content: prompt},
	}

	// Trim the conversation if it exceeds the context budget.
	messages = a.truncateConversation(messages)

	// Build the vendor-neutral request.
	req := model.ChatRequest{
		Model:     a.provider.GetModelName(),
		Messages:  messages,
		MaxTokens: a.maxTokens,
	}

	// Call the vendor API.
	resp, err := a.provider.CreateChatCompletion(ctx, req)
	if err != nil {
		return "", err
	}

	// Update token statistics.
	a.tokenCount.Add(resp.InputTokens, resp.OutputTokens)
	return resp.Content, nil
}

// truncateConversation enforces a rough context budget: it keeps the system
// message plus the most recent turns, then, if the result is still too
// large, removes or shortens messages from the end backwards.
//
// NOTE(review): the overflow is measured in provider tokens but the
// shortening step drops whitespace-separated words via strings.Fields, so
// the trim is approximate — confirm whether provider.CountTokens counts
// words or model tokens.
func (a *AIAgent) truncateConversation(messages []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
	const maxContextTokens = 3000
	const keepLastMessages = 4 // keep the last 4 turns

	currentTokens := 0
	for _, msg := range messages {
		currentTokens += a.provider.CountTokens(msg.Content)
	}
	// Nothing to do if we are within budget.
	if currentTokens <= maxContextTokens {
		return messages
	}

	// Always preserve the system message when present.
	var newMessages []openai.ChatCompletionMessage
	hasSystem := len(messages) > 0 && messages[0].Role == openai.ChatMessageRoleSystem
	if hasSystem {
		newMessages = append(newMessages, messages[0])
	}

	// Append the most recent turns.
	startIndex := len(messages) - keepLastMessages
	if startIndex < 0 {
		startIndex = 0
	}
	// FIX: if the tail window reaches back to index 0 and the system message
	// was already kept, skip it so it is not duplicated.
	if hasSystem && startIndex == 0 {
		startIndex = 1
	}
	newMessages = append(newMessages, messages[startIndex:]...)

	// Second pass: measure again and trim from the newest message backwards.
	total := 0
	for _, msg := range newMessages {
		total += a.provider.CountTokens(msg.Content)
	}
	overflow := total - maxContextTokens
	// FIX: the loop now stops as soon as overflow is satisfied. The original
	// kept iterating with overflow <= 0, which made
	// tokens[:len(tokens)-overflow] slice past the end (or below zero when
	// overflow exceeded the word count) and panic.
	for i := len(newMessages) - 1; i >= 0 && overflow > 0; i-- {
		if i == 0 && newMessages[i].Role == openai.ChatMessageRoleSystem {
			continue // never truncate the system message
		}
		contentTokens := a.provider.CountTokens(newMessages[i].Content)
		if contentTokens > overflow {
			// Shorten this message just enough to fit.
			words := strings.Fields(newMessages[i].Content)
			keep := len(words) - overflow
			if keep < 0 {
				keep = 0 // guard: token count may exceed word count
			}
			newMessages[i].Content = strings.Join(words[:keep], " ") + "...[截断]"
			break
		}
		// The whole message must go; keep a placeholder.
		newMessages[i].Content = "[内容过长已移除]"
		overflow -= contentTokens
	}
	return newMessages
}

// GetTokenStats reports the accumulated token usage as a display string.
func (a *AIAgent) GetTokenStats() string {
	return a.tokenCount.Stats()
}

// ==============================
// Usage example
// ==============================

func main() {
	// FIX(security): the DeepSeek API key was hardcoded in source (a leaked
	// credential). Read it from the environment instead; rotate the old key.
	apiKey := os.Getenv("DEEPSEEK_API_KEY")
	if apiKey == "" {
		fmt.Println("错误: 环境变量 DEEPSEEK_API_KEY 未设置")
		return
	}

	// DeepSeek Reasoner model (128K context).
	deepseekAgent := NewAIAgent(
		provider.NewDeepSeekProvider(apiKey, "deepseek-reasoner"),
		2048,
	)

	// OpenAI-compatible model:
	// openaiAgent := NewAIAgent(
	//     NewOpenAiProvider("your_openai_api_key", "gpt-4"),
	//     2048,
	// )

	response, err := deepseekAgent.Query(
		context.Background(),
		"请用中文解释量子纠缠现象,并说明其在量子计算中的作用",
	)
	if err != nil {
		fmt.Println("错误:", err)
		return
	}

	fmt.Println("回复:", response)
	fmt.Println(deepseekAgent.GetTokenStats())
}