// Package main demonstrates a multi-vendor AI agent with token accounting.
package main
import (
	"context"
	"fmt"
	"os"
	"strings"
	"sync/atomic"

	"git.kingecg.top/kingecg/goaiagent/model"
	"git.kingecg.top/kingecg/goaiagent/provider"
	openai "github.com/sashabaranov/go-openai" // OpenAI-compatible SDK
)

// ==============================
// Multi-vendor API support — core design
// ==============================

// ==============================
// AI Agent — core implementation
// ==============================
type AIAgent struct {
|
||
provider provider.Provider
|
||
tokenCount TokenCounter
|
||
maxTokens int
|
||
}
type TokenCounter struct {
|
||
InputTokens int64
|
||
OutputTokens int64
|
||
TotalTokens int64
|
||
}
func (tc *TokenCounter) Add(input, output int) {
|
||
atomic.AddInt64(&tc.InputTokens, int64(input))
|
||
atomic.AddInt64(&tc.OutputTokens, int64(output))
|
||
atomic.AddInt64(&tc.TotalTokens, int64(input+output))
|
||
}
func (tc *TokenCounter) Stats() string {
|
||
return fmt.Sprintf("Tokens: 输入=%d, 输出=%d, 总计=%d",
|
||
tc.InputTokens, tc.OutputTokens, tc.TotalTokens)
|
||
}
func NewAIAgent(provider provider.Provider, maxTokens int) *AIAgent {
|
||
return &AIAgent{
|
||
provider: provider,
|
||
maxTokens: maxTokens,
|
||
}
|
||
}
func (a *AIAgent) Query(ctx context.Context, prompt string) (string, error) {
|
||
messages := []openai.ChatCompletionMessage{
|
||
{Role: openai.ChatMessageRoleSystem, Content: "你是一个AI助手"},
|
||
{Role: openai.ChatMessageRoleUser, Content: prompt},
|
||
}
|
||
|
||
// 检查并截断过长对话
|
||
messages = a.truncateConversation(messages)
|
||
|
||
// 创建统一请求
|
||
req := model.ChatRequest{
|
||
Model: a.provider.GetModelName(),
|
||
Messages: messages,
|
||
MaxTokens: a.maxTokens,
|
||
}
|
||
|
||
// 调用厂商API
|
||
resp, err := a.provider.CreateChatCompletion(ctx, req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 更新Token统计
|
||
a.tokenCount.Add(resp.InputTokens, resp.OutputTokens)
|
||
|
||
return resp.Content, nil
|
||
}
// 截断策略:保留系统消息+最新对话
|
||
func (a *AIAgent) truncateConversation(messages []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
|
||
const maxContextTokens = 3000
|
||
const keepLastMessages = 4 // 保留最后4轮对话
|
||
|
||
currentTokens := 0
|
||
for _, msg := range messages {
|
||
currentTokens += a.provider.CountTokens(msg.Content)
|
||
}
|
||
|
||
// 无需截断
|
||
if currentTokens <= maxContextTokens {
|
||
return messages
|
||
}
|
||
|
||
// 优先保留系统消息
|
||
var newMessages []openai.ChatCompletionMessage
|
||
if len(messages) > 0 && messages[0].Role == openai.ChatMessageRoleSystem {
|
||
newMessages = append(newMessages, messages[0])
|
||
}
|
||
|
||
// 添加最近的对话
|
||
startIndex := len(messages) - keepLastMessages
|
||
if startIndex < 0 {
|
||
startIndex = 0
|
||
}
|
||
newMessages = append(newMessages, messages[startIndex:]...)
|
||
|
||
// 二次检查并逐条截断
|
||
total := 0
|
||
for _, msg := range newMessages {
|
||
total += a.provider.CountTokens(msg.Content)
|
||
}
|
||
|
||
overflow := total - maxContextTokens
|
||
if overflow > 0 {
|
||
for i := len(newMessages) - 1; i >= 0; i-- {
|
||
if i == 0 && newMessages[i].Role == openai.ChatMessageRoleSystem {
|
||
continue // 不截断系统消息
|
||
}
|
||
|
||
contentTokens := a.provider.CountTokens(newMessages[i].Content)
|
||
if contentTokens > overflow {
|
||
// 截断内容
|
||
tokens := strings.Fields(newMessages[i].Content)
|
||
keepTokens := tokens[:len(tokens)-overflow]
|
||
newMessages[i].Content = strings.Join(keepTokens, " ") + "...[截断]"
|
||
break
|
||
} else {
|
||
newMessages[i].Content = "[内容过长已移除]"
|
||
overflow -= contentTokens
|
||
}
|
||
}
|
||
}
|
||
|
||
return newMessages
|
||
}
func (a *AIAgent) GetTokenStats() string {
|
||
return a.tokenCount.Stats()
|
||
}

// ==============================
// Usage example
// ==============================
func main() {
|
||
// 使用DeepSeek的Reasoner模型(128K上下文):cite[5]
|
||
deepseekAgent := NewAIAgent(
|
||
provider.NewDeepSeekProvider("sk-51d647a7e6324f8eb98880c768427223", "deepseek-reasoner"), // 替换为你的API Key:cite[7]
|
||
2048,
|
||
)
|
||
|
||
// 使用OpenAI兼容模型
|
||
// openaiAgent := NewAIAgent(
|
||
// NewOpenAiProvider("your_openai_api_key", "gpt-4"),
|
||
// 2048,
|
||
// )
|
||
|
||
response, err := deepseekAgent.Query(
|
||
context.Background(),
|
||
"请用中文解释量子纠缠现象,并说明其在量子计算中的作用",
|
||
)
|
||
if err != nil {
|
||
fmt.Println("错误:", err)
|
||
return
|
||
}
|
||
|
||
fmt.Println("回复:", response)
|
||
fmt.Println(deepseekAgent.GetTokenStats())
|
||
}