// goaiagent/main.go
package main
import (
	"context"
	"fmt"
	"os"
	"strings"
	"sync/atomic"

	"git.kingecg.top/kingecg/goaiagent/model"
	"git.kingecg.top/kingecg/goaiagent/provider"
	openai "github.com/sashabaranov/go-openai" // OpenAI-compatible SDK
)
// ==============================
// 多厂商API支持核心设计
// ==============================
// ==============================
// AI Agent 核心实现
// ==============================
// AIAgent is a chat client bound to a single pluggable model provider.
// It accumulates token usage across queries and forwards a fixed
// per-request completion cap to the provider.
type AIAgent struct {
	provider   provider.Provider // backend that executes chat completions
	tokenCount TokenCounter      // cumulative usage, updated after each Query
	maxTokens  int               // MaxTokens value sent with every request
}
// TokenCounter accumulates token usage across requests. Counters are
// updated atomically by Add and read atomically by Stats, so one
// TokenCounter may be shared across goroutines.
type TokenCounter struct {
	InputTokens  int64 // prompt-side tokens consumed
	OutputTokens int64 // completion-side tokens produced
	TotalTokens  int64 // InputTokens + OutputTokens, kept for cheap reads
}

// Add records one request's token usage. Safe for concurrent use.
func (tc *TokenCounter) Add(input, output int) {
	atomic.AddInt64(&tc.InputTokens, int64(input))
	atomic.AddInt64(&tc.OutputTokens, int64(output))
	atomic.AddInt64(&tc.TotalTokens, int64(input+output))
}

// Stats returns a human-readable summary of the counters.
// Fix: the fields were previously read with plain loads while Add
// writes them atomically — a data race under `go test -race`; they
// are now read with atomic.LoadInt64.
func (tc *TokenCounter) Stats() string {
	return fmt.Sprintf("Tokens: 输入=%d, 输出=%d, 总计=%d",
		atomic.LoadInt64(&tc.InputTokens),
		atomic.LoadInt64(&tc.OutputTokens),
		atomic.LoadInt64(&tc.TotalTokens))
}
// NewAIAgent constructs an agent backed by the given provider.
// maxTokens is forwarded as the completion cap on every request.
// The embedded TokenCounter starts at zero and needs no initialization.
func NewAIAgent(provider provider.Provider, maxTokens int) *AIAgent {
	agent := new(AIAgent)
	agent.provider = provider
	agent.maxTokens = maxTokens
	return agent
}
// Query sends prompt to the configured provider as a single-turn chat
// (fixed system preamble + user message) and returns the completion
// text. Token usage reported by the provider is folded into the
// agent's counters on success.
func (a *AIAgent) Query(ctx context.Context, prompt string) (string, error) {
	messages := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleSystem, Content: "你是一个AI助手"},
		{Role: openai.ChatMessageRoleUser, Content: prompt},
	}

	// Trim the conversation if it would exceed the context budget.
	messages = a.truncateConversation(messages)

	// Build the provider-neutral request.
	req := model.ChatRequest{
		Model:     a.provider.GetModelName(),
		Messages:  messages,
		MaxTokens: a.maxTokens,
	}

	resp, err := a.provider.CreateChatCompletion(ctx, req)
	if err != nil {
		// Wrap with %w so callers can still inspect the provider
		// error via errors.Is/As while seeing which model failed.
		return "", fmt.Errorf("chat completion (model %s): %w", a.provider.GetModelName(), err)
	}

	// Record usage only for successful calls.
	a.tokenCount.Add(resp.InputTokens, resp.OutputTokens)
	return resp.Content, nil
}
// truncateConversation enforces an approximate context budget: it keeps
// the leading system message (if any) plus the most recent
// keepLastMessages entries, then shortens message bodies, newest first,
// until the provider-reported token total fits under maxContextTokens.
//
// NOTE(review): the shortening step mixes two units — overflow is
// measured in provider tokens, but content is cut by whitespace-
// separated words — so the result is approximate. Behavior is kept,
// with two defects fixed (see inline comments).
func (a *AIAgent) truncateConversation(messages []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
	const maxContextTokens = 3000
	const keepLastMessages = 4 // keep the most recent 4 messages

	currentTokens := 0
	for _, msg := range messages {
		currentTokens += a.provider.CountTokens(msg.Content)
	}
	// Fast path: everything already fits.
	if currentTokens <= maxContextTokens {
		return messages
	}

	// Always retain the system message at index 0, if present.
	var newMessages []openai.ChatCompletionMessage
	if len(messages) > 0 && messages[0].Role == openai.ChatMessageRoleSystem {
		newMessages = append(newMessages, messages[0])
	}

	// Retain the tail of the conversation.
	// Fix: the tail must start after the retained system message,
	// otherwise a short-but-heavy conversation duplicated messages[0].
	startIndex := len(messages) - keepLastMessages
	if startIndex < len(newMessages) {
		startIndex = len(newMessages)
	}
	newMessages = append(newMessages, messages[startIndex:]...)

	// Second pass: if still over budget, shorten from the newest back.
	total := 0
	for _, msg := range newMessages {
		total += a.provider.CountTokens(msg.Content)
	}
	overflow := total - maxContextTokens
	if overflow > 0 {
		for i := len(newMessages) - 1; i >= 0; i-- {
			if i == 0 && newMessages[i].Role == openai.ChatMessageRoleSystem {
				continue // never truncate the system message
			}
			contentTokens := a.provider.CountTokens(newMessages[i].Content)
			if contentTokens > overflow {
				words := strings.Fields(newMessages[i].Content)
				// Fix: overflow is in provider tokens, not words, so
				// len(words)-overflow could go negative and panic on
				// the slice expression; clamp it at zero.
				keep := len(words) - overflow
				if keep < 0 {
					keep = 0
				}
				newMessages[i].Content = strings.Join(words[:keep], " ") + "...[截断]"
				break
			}
			// Whole message is cheaper than the overflow: drop its body.
			newMessages[i].Content = "[内容过长已移除]"
			overflow -= contentTokens
		}
	}
	return newMessages
}
// GetTokenStats reports the agent's cumulative token usage as a
// formatted string; it delegates to TokenCounter.Stats.
// (Name kept for compatibility despite Go's no-Get convention.)
func (a *AIAgent) GetTokenStats() string {
	summary := a.tokenCount.Stats()
	return summary
}
// ==============================
// 使用示例
// ==============================
// main wires up a DeepSeek-backed agent and runs one demo query.
func main() {
	// Fix: the API key was hard-coded in source (a leaked secret —
	// it should be revoked). Read it from the environment instead.
	apiKey := os.Getenv("DEEPSEEK_API_KEY")
	if apiKey == "" {
		fmt.Println("错误: 请设置 DEEPSEEK_API_KEY 环境变量")
		return
	}

	// deepseek-reasoner: DeepSeek's reasoning model (128K context).
	deepseekAgent := NewAIAgent(
		provider.NewDeepSeekProvider(apiKey, "deepseek-reasoner"),
		2048,
	)

	// Example for an OpenAI-compatible model:
	// openaiAgent := NewAIAgent(
	//     NewOpenAiProvider("your_openai_api_key", "gpt-4"),
	//     2048,
	// )

	response, err := deepseekAgent.Query(
		context.Background(),
		"请用中文解释量子纠缠现象,并说明其在量子计算中的作用",
	)
	if err != nil {
		fmt.Println("错误:", err)
		return
	}
	fmt.Println("回复:", response)
	fmt.Println(deepseekAgent.GetTokenStats())
}