feat(goaiagent): 实现多厂商大模型API兼容和流式读取功能
- 新增 DeepSeek 和 OpenAI 兼容实现 - 添加流式读取器和相关处理逻辑 - 实现统一的请求和响应格式 - 优化对话截断策略和 Token 计数功能
This commit is contained in:
commit
8a732bf623
|
@ -0,0 +1 @@
|
||||||
|
vendor/
|
|
@ -0,0 +1,5 @@
|
||||||
|
module git.kingecg.top/kingecg/goaiagent
|
||||||
|
|
||||||
|
go 1.23.1
|
||||||
|
|
||||||
|
require github.com/sashabaranov/go-openai v1.40.3
|
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/sashabaranov/go-openai v1.40.3 h1:PkOw0SK34wrvYVOuXF1HZzuTBRh992qRZHil4kG3eYE=
|
||||||
|
github.com/sashabaranov/go-openai v1.40.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
|
@ -0,0 +1,170 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"git.kingecg.top/kingecg/goaiagent/model"
|
||||||
|
"git.kingecg.top/kingecg/goaiagent/provider"
|
||||||
|
openai "github.com/sashabaranov/go-openai" // 兼容OpenAI的SDK:cite[4]
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// 多厂商API支持核心设计
|
||||||
|
// ==============================
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// AI Agent 核心实现
|
||||||
|
// ==============================
|
||||||
|
type AIAgent struct {
|
||||||
|
provider provider.Provider
|
||||||
|
tokenCount TokenCounter
|
||||||
|
maxTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenCounter struct {
|
||||||
|
InputTokens int64
|
||||||
|
OutputTokens int64
|
||||||
|
TotalTokens int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TokenCounter) Add(input, output int) {
|
||||||
|
atomic.AddInt64(&tc.InputTokens, int64(input))
|
||||||
|
atomic.AddInt64(&tc.OutputTokens, int64(output))
|
||||||
|
atomic.AddInt64(&tc.TotalTokens, int64(input+output))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TokenCounter) Stats() string {
|
||||||
|
return fmt.Sprintf("Tokens: 输入=%d, 输出=%d, 总计=%d",
|
||||||
|
tc.InputTokens, tc.OutputTokens, tc.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAIAgent(provider provider.Provider, maxTokens int) *AIAgent {
|
||||||
|
return &AIAgent{
|
||||||
|
provider: provider,
|
||||||
|
maxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AIAgent) Query(ctx context.Context, prompt string) (string, error) {
|
||||||
|
messages := []openai.ChatCompletionMessage{
|
||||||
|
{Role: openai.ChatMessageRoleSystem, Content: "你是一个AI助手"},
|
||||||
|
{Role: openai.ChatMessageRoleUser, Content: prompt},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查并截断过长对话
|
||||||
|
messages = a.truncateConversation(messages)
|
||||||
|
|
||||||
|
// 创建统一请求
|
||||||
|
req := model.ChatRequest{
|
||||||
|
Model: a.provider.GetModelName(),
|
||||||
|
Messages: messages,
|
||||||
|
MaxTokens: a.maxTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用厂商API
|
||||||
|
resp, err := a.provider.CreateChatCompletion(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新Token统计
|
||||||
|
a.tokenCount.Add(resp.InputTokens, resp.OutputTokens)
|
||||||
|
|
||||||
|
return resp.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 截断策略:保留系统消息+最新对话
|
||||||
|
func (a *AIAgent) truncateConversation(messages []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
|
||||||
|
const maxContextTokens = 3000
|
||||||
|
const keepLastMessages = 4 // 保留最后4轮对话
|
||||||
|
|
||||||
|
currentTokens := 0
|
||||||
|
for _, msg := range messages {
|
||||||
|
currentTokens += a.provider.CountTokens(msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 无需截断
|
||||||
|
if currentTokens <= maxContextTokens {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先保留系统消息
|
||||||
|
var newMessages []openai.ChatCompletionMessage
|
||||||
|
if len(messages) > 0 && messages[0].Role == openai.ChatMessageRoleSystem {
|
||||||
|
newMessages = append(newMessages, messages[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加最近的对话
|
||||||
|
startIndex := len(messages) - keepLastMessages
|
||||||
|
if startIndex < 0 {
|
||||||
|
startIndex = 0
|
||||||
|
}
|
||||||
|
newMessages = append(newMessages, messages[startIndex:]...)
|
||||||
|
|
||||||
|
// 二次检查并逐条截断
|
||||||
|
total := 0
|
||||||
|
for _, msg := range newMessages {
|
||||||
|
total += a.provider.CountTokens(msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
overflow := total - maxContextTokens
|
||||||
|
if overflow > 0 {
|
||||||
|
for i := len(newMessages) - 1; i >= 0; i-- {
|
||||||
|
if i == 0 && newMessages[i].Role == openai.ChatMessageRoleSystem {
|
||||||
|
continue // 不截断系统消息
|
||||||
|
}
|
||||||
|
|
||||||
|
contentTokens := a.provider.CountTokens(newMessages[i].Content)
|
||||||
|
if contentTokens > overflow {
|
||||||
|
// 截断内容
|
||||||
|
tokens := strings.Fields(newMessages[i].Content)
|
||||||
|
keepTokens := tokens[:len(tokens)-overflow]
|
||||||
|
newMessages[i].Content = strings.Join(keepTokens, " ") + "...[截断]"
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
newMessages[i].Content = "[内容过长已移除]"
|
||||||
|
overflow -= contentTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AIAgent) GetTokenStats() string {
|
||||||
|
return a.tokenCount.Stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// 使用示例
|
||||||
|
// ==============================
|
||||||
|
func main() {
|
||||||
|
// 使用DeepSeek的Reasoner模型(128K上下文):cite[5]
|
||||||
|
deepseekAgent := NewAIAgent(
|
||||||
|
provider.NewDeepSeekProvider("sk-51d647a7e6324f8eb98880c768427223", "deepseek-reasoner"), // 替换为你的API Key:cite[7]
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 使用OpenAI兼容模型
|
||||||
|
// openaiAgent := NewAIAgent(
|
||||||
|
// NewOpenAiProvider("your_openai_api_key", "gpt-4"),
|
||||||
|
// 2048,
|
||||||
|
// )
|
||||||
|
|
||||||
|
response, err := deepseekAgent.Query(
|
||||||
|
context.Background(),
|
||||||
|
"请用中文解释量子纠缠现象,并说明其在量子计算中的作用",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("错误:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("回复:", response)
|
||||||
|
fmt.Println(deepseekAgent.GetTokenStats())
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
import "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
|
// ChatRequest 统一请求格式
|
||||||
|
type ChatRequest struct {
|
||||||
|
Model string
|
||||||
|
Messages []openai.ChatCompletionMessage
|
||||||
|
MaxTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatResponse 统一响应格式
|
||||||
|
type ChatResponse struct {
|
||||||
|
Content string
|
||||||
|
InputTokens int
|
||||||
|
OutputTokens int
|
||||||
|
}
|
|
@ -0,0 +1,103 @@
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.kingecg.top/kingecg/goaiagent/model"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// DeepSeek API 实现:cite[1]:cite[7]
|
||||||
|
// ==============================
|
||||||
|
type DeepSeekProvider struct {
|
||||||
|
apiKey string
|
||||||
|
model string // deepseek-chat 或 deepseek-reasoner:cite[8]
|
||||||
|
endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeepSeekProvider(apiKey, model string) *DeepSeekProvider {
|
||||||
|
return &DeepSeekProvider{
|
||||||
|
apiKey: apiKey,
|
||||||
|
model: model,
|
||||||
|
endpoint: "https://api.deepseek.com/v1/chat/completions", // DeepSeek官方API地址:cite[7]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (d *DeepSeekProvider) GetProviderName() string {
|
||||||
|
return "DeepSeek"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DeepSeekProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
|
||||||
|
// 构造DeepSeek专属请求体:cite[1]
|
||||||
|
deepseekReq := struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []openai.ChatCompletionMessage `json:"messages"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
}{
|
||||||
|
Model: d.model,
|
||||||
|
Messages: req.Messages,
|
||||||
|
MaxTokens: req.MaxTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, _ := json.Marshal(deepseekReq)
|
||||||
|
httpReq, _ := http.NewRequest("POST", d.endpoint, bytes.NewReader(reqBody))
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+d.apiKey)
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return model.ChatResponse{}, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
// 解析DeepSeek特有响应格式
|
||||||
|
var result struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return model.ChatResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Choices) == 0 {
|
||||||
|
return model.ChatResponse{}, fmt.Errorf("empty response from DeepSeek")
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.ChatResponse{
|
||||||
|
Content: result.Choices[0].Message.Content,
|
||||||
|
InputTokens: result.Usage.PromptTokens,
|
||||||
|
OutputTokens: result.Usage.CompletionTokens,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DeepSeekProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error) {
|
||||||
|
streamClient := NewDeepSeekStreamClient(d.apiKey)
|
||||||
|
return streamClient.CreateStream(ctx, req.Messages, d.model, req.MaxTokens)
|
||||||
|
// return nil, fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简化的Token计数器(实际生产应使用tiktoken)
|
||||||
|
func (d *DeepSeekProvider) CountTokens(text string) int {
|
||||||
|
return len(strings.Fields(text)) // 按空格分词计数
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DeepSeekProvider) GetModelName() string {
|
||||||
|
return d.model
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.kingecg.top/kingecg/goaiagent/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider 定义统一的大模型接口
|
||||||
|
type Provider interface {
|
||||||
|
CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error)
|
||||||
|
CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error)
|
||||||
|
CountTokens(text string) int
|
||||||
|
GetModelName() string
|
||||||
|
}
|
|
@ -0,0 +1,59 @@
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.kingecg.top/kingecg/goaiagent/model"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// OpenAI 兼容实现
|
||||||
|
// ==============================
|
||||||
|
type OpenAiProvider struct {
|
||||||
|
client *openai.Client
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenAiProvider(apiKey, model string) *OpenAiProvider {
|
||||||
|
config := openai.DefaultConfig(apiKey)
|
||||||
|
return &OpenAiProvider{
|
||||||
|
client: openai.NewClientWithConfig(config),
|
||||||
|
model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenAiProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
|
||||||
|
// 转换统一请求为OpenAI SDK格式
|
||||||
|
messages := make([]openai.ChatCompletionMessage, len(req.Messages))
|
||||||
|
for i, m := range req.Messages {
|
||||||
|
messages[i] = openai.ChatCompletionMessage{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
||||||
|
Model: o.model,
|
||||||
|
Messages: messages,
|
||||||
|
MaxTokens: req.MaxTokens,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return model.ChatResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.ChatResponse{
|
||||||
|
Content: resp.Choices[0].Message.Content,
|
||||||
|
InputTokens: resp.Usage.PromptTokens,
|
||||||
|
OutputTokens: resp.Usage.CompletionTokens,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenAiProvider) CountTokens(text string) int {
|
||||||
|
return len(strings.Fields(text))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenAiProvider) GetModelName() string {
|
||||||
|
return o.model
|
||||||
|
}
|
|
@ -0,0 +1,307 @@
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// 流式读取器核心实现
|
||||||
|
// ==============================
|
||||||
|
|
||||||
|
// StreamEvent 表示单个流式事件
|
||||||
|
type StreamEvent struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []struct {
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamReader 处理流式API响应
|
||||||
|
type StreamReader struct {
|
||||||
|
scanner *bufio.Scanner
|
||||||
|
response *http.Response
|
||||||
|
currentEvent *StreamEvent
|
||||||
|
events chan StreamEvent
|
||||||
|
errors chan error
|
||||||
|
done chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
mu sync.Mutex
|
||||||
|
totalTokens int
|
||||||
|
startTime time.Time
|
||||||
|
lastEventTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStreamReader 创建新的流式读取器
|
||||||
|
func NewStreamReader(response *http.Response) *StreamReader {
|
||||||
|
sr := &StreamReader{
|
||||||
|
response: response,
|
||||||
|
scanner: bufio.NewScanner(response.Body),
|
||||||
|
events: make(chan StreamEvent, 100),
|
||||||
|
errors: make(chan error, 1),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
startTime: time.Now(),
|
||||||
|
lastEventTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
sr.scanner.Split(sr.splitSSE) // 使用自定义分隔函数
|
||||||
|
|
||||||
|
go sr.processStream()
|
||||||
|
return sr
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitSSE 自定义扫描函数用于处理SSE格式
|
||||||
|
func (sr *StreamReader) splitSSE(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找事件分隔符: \n\n 或 \r\n\r\n
|
||||||
|
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
|
||||||
|
return i + 2, data[0:i], nil
|
||||||
|
}
|
||||||
|
if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
|
||||||
|
return i + 4, data[0:i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果到达文件末尾,返回剩余数据
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求更多数据
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processStream 处理流数据
|
||||||
|
func (sr *StreamReader) processStream() {
|
||||||
|
defer close(sr.events)
|
||||||
|
defer close(sr.errors)
|
||||||
|
defer sr.response.Body.Close()
|
||||||
|
|
||||||
|
for sr.scanner.Scan() {
|
||||||
|
eventData := sr.scanner.Bytes()
|
||||||
|
sr.lastEventTime = time.Now()
|
||||||
|
|
||||||
|
// 跳过空行和注释
|
||||||
|
if len(eventData) == 0 || bytes.HasPrefix(eventData, []byte(":")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为 [DONE] 事件
|
||||||
|
if bytes.Contains(eventData, []byte("[DONE]")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取事件数据部分
|
||||||
|
prefix := []byte("data: ")
|
||||||
|
if bytes.HasPrefix(eventData, prefix) {
|
||||||
|
eventData = bytes.TrimPrefix(eventData, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
var event StreamEvent
|
||||||
|
if err := json.Unmarshal(eventData, &event); err != nil {
|
||||||
|
sr.errors <- fmt.Errorf("解析事件失败: %w\n原始数据: %s", err, string(eventData))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新Token计数
|
||||||
|
if content := event.Choices[0].Delta.Content; content != "" {
|
||||||
|
sr.mu.Lock()
|
||||||
|
sr.totalTokens += len(strings.Fields(content))
|
||||||
|
sr.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送事件
|
||||||
|
sr.events <- event
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sr.scanner.Err(); err != nil {
|
||||||
|
sr.errors <- fmt.Errorf("扫描错误: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recv 接收下一个事件
|
||||||
|
func (sr *StreamReader) Recv() (StreamEvent, error) {
|
||||||
|
select {
|
||||||
|
case event, ok := <-sr.events:
|
||||||
|
if !ok {
|
||||||
|
return StreamEvent{}, io.EOF
|
||||||
|
}
|
||||||
|
return event, nil
|
||||||
|
case err := <-sr.errors:
|
||||||
|
return StreamEvent{}, err
|
||||||
|
case <-sr.done:
|
||||||
|
return StreamEvent{}, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭流
|
||||||
|
func (sr *StreamReader) Close() error {
|
||||||
|
sr.closeOnce.Do(func() {
|
||||||
|
close(sr.done)
|
||||||
|
sr.response.Body.Close()
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats 获取流统计信息
|
||||||
|
func (sr *StreamReader) Stats() string {
|
||||||
|
sr.mu.Lock()
|
||||||
|
defer sr.mu.Unlock()
|
||||||
|
|
||||||
|
duration := time.Since(sr.startTime)
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"Tokens: %d | 持续时间: %s | 最后事件: %s前",
|
||||||
|
sr.totalTokens,
|
||||||
|
duration.Round(time.Second),
|
||||||
|
time.Since(sr.lastEventTime).Round(time.Second),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// DeepSeek流式API客户端
|
||||||
|
// ==============================
|
||||||
|
|
||||||
|
type DeepSeekStreamClient struct {
|
||||||
|
apiKey string
|
||||||
|
endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeepSeekStreamClient(apiKey string) *DeepSeekStreamClient {
|
||||||
|
return &DeepSeekStreamClient{
|
||||||
|
apiKey: apiKey,
|
||||||
|
endpoint: "https://api.deepseek.com/v1/chat/completions",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DeepSeekStreamClient) CreateStream(
|
||||||
|
ctx context.Context,
|
||||||
|
messages []openai.ChatCompletionMessage,
|
||||||
|
model string,
|
||||||
|
maxTokens int,
|
||||||
|
) (*StreamReader, error) {
|
||||||
|
// 构造请求体
|
||||||
|
requestBody := struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []openai.ChatCompletionMessage `json:"messages"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}{
|
||||||
|
Model: model,
|
||||||
|
Messages: messages,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(requestBody)
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, "POST", c.endpoint, bytes.NewReader(body))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
req.Header.Set("Cache-Control", "no-cache")
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API错误: %s, 响应: %s", resp.Status, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewStreamReader(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==============================
|
||||||
|
// 使用示例
|
||||||
|
// ==============================
|
||||||
|
|
||||||
|
// func main() {
|
||||||
|
// apiKey := "your_deepseek_api_key" // 替换为你的API密钥
|
||||||
|
// client := NewDeepSeekStreamClient(apiKey)
|
||||||
|
|
||||||
|
// // 创建对话消息
|
||||||
|
// messages := []openai.ChatCompletionMessage{
|
||||||
|
// {
|
||||||
|
// Role: openai.ChatMessageRoleSystem,
|
||||||
|
// Content: "你是一个乐于助人的AI助手",
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// Role: openai.ChatMessageRoleUser,
|
||||||
|
// Content: "请用中文详细解释量子计算的基本原理,以及它为什么比传统计算机更有优势?",
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // 创建流式请求
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
// defer cancel()
|
||||||
|
|
||||||
|
// stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
|
||||||
|
// if err != nil {
|
||||||
|
// fmt.Println("创建流失败:", err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// defer stream.Close()
|
||||||
|
|
||||||
|
// fmt.Println("流式响应开始:")
|
||||||
|
// fmt.Println("====================")
|
||||||
|
|
||||||
|
// var fullResponse strings.Builder
|
||||||
|
// lastPrint := time.Now()
|
||||||
|
// const updateInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
|
// for {
|
||||||
|
// event, err := stream.Recv()
|
||||||
|
// if errors.Is(err, io.EOF) {
|
||||||
|
// fmt.Println("\n\n====================")
|
||||||
|
// fmt.Println("流式响应结束")
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if err != nil {
|
||||||
|
// fmt.Println("\n接收错误:", err)
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if len(event.Choices) > 0 {
|
||||||
|
// content := event.Choices[0].Delta.Content
|
||||||
|
// if content != "" {
|
||||||
|
// fullResponse.WriteString(content)
|
||||||
|
|
||||||
|
// // 流式输出控制:定期打印新内容
|
||||||
|
// if time.Since(lastPrint) > updateInterval {
|
||||||
|
// fmt.Print(content)
|
||||||
|
// lastPrint = time.Now()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// fmt.Println("\n完整响应:")
|
||||||
|
// fmt.Println(fullResponse.String())
|
||||||
|
// fmt.Println("\n统计信息:", stream.Stats())
|
||||||
|
// }
|
Loading…
Reference in New Issue