// goaiagent/provider/stream.go
//
// 308 lines
// 7.3 KiB
// Go

package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/sashabaranov/go-openai"
)
// ==============================
// 流式读取器核心实现
// ==============================
// StreamEvent is one chunk of a streamed chat-completion response, decoded
// from the JSON payload of a single server-sent event.
type StreamEvent struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		Index        int    `json:"index"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}

// StreamReader decodes a streaming (SSE) API response into StreamEvents.
// A background goroutine parses the body; callers consume events with Recv
// and must call Close when done.
type StreamReader struct {
	scanner       *bufio.Scanner
	response      *http.Response
	currentEvent  *StreamEvent // not referenced by the code in this file
	events        chan StreamEvent
	errors        chan error // at most one terminal error
	done          chan struct{}
	closeOnce     sync.Once
	mu            sync.Mutex // guards totalTokens and lastEventTime
	totalTokens   int
	startTime     time.Time
	lastEventTime time.Time
}

// NewStreamReader starts reading response.Body as a server-sent event stream
// in a background goroutine. The goroutine closes the body when it ends.
func NewStreamReader(response *http.Response) *StreamReader {
	now := time.Now()
	sr := &StreamReader{
		response:      response,
		scanner:       bufio.NewScanner(response.Body),
		events:        make(chan StreamEvent, 100),
		errors:        make(chan error, 1),
		done:          make(chan struct{}),
		startTime:     now,
		lastEventTime: now,
	}
	// With the default Scanner limit (64 KiB), one large event would end the
	// stream with bufio.ErrTooLong. Allow events of up to 1 MiB.
	const initialBufSize, maxEventSize = 64 * 1024, 1024 * 1024
	sr.scanner.Buffer(make([]byte, 0, initialBufSize), maxEventSize)
	sr.scanner.Split(sr.splitSSE) // one token per SSE event block
	go sr.processStream()
	return sr
}

// splitSSE is a bufio.SplitFunc that yields one SSE event block per token.
// Blocks are separated by a blank line, written as "\n\n" or "\r\n\r\n".
// Whichever separator occurs first in the buffer wins. At EOF, any remaining
// bytes form the final block.
func (sr *StreamReader) splitSSE(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	end, sepLen := -1, 0
	if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
		end, sepLen = i, 2
	}
	if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 && (end < 0 || i < end) {
		end, sepLen = i, 4
	}
	if end >= 0 {
		return end + sepLen, data[:end], nil
	}
	if atEOF {
		return len(data), data, nil
	}
	return 0, nil, nil // need more data
}

// dataPayload extracts the payload of one SSE event block. It takes the values
// of the block's "data:" lines, removes one optional leading space from each,
// and joins them with "\n". Comment lines (starting with ':') and the "event:",
// "id:" and "retry:" fields are ignored. Any other non-empty line is kept
// verbatim, so a bare JSON payload with no "data:" prefix is still accepted.
func (sr *StreamReader) dataPayload(block []byte) []byte {
	var parts [][]byte
	for _, line := range bytes.Split(block, []byte("\n")) {
		line = bytes.TrimSuffix(line, []byte("\r"))
		switch {
		case len(line) == 0 || line[0] == ':':
			// Blank line, or a comment such as ": keep-alive".
		case bytes.HasPrefix(line, []byte("data:")):
			value := bytes.TrimPrefix(line, []byte("data:"))
			parts = append(parts, bytes.TrimPrefix(value, []byte(" ")))
		case bytes.HasPrefix(line, []byte("event:")),
			bytes.HasPrefix(line, []byte("id:")),
			bytes.HasPrefix(line, []byte("retry:")):
			// Metadata fields carry no payload.
		default:
			parts = append(parts, line)
		}
	}
	return bytes.Join(parts, []byte("\n"))
}

// processStream runs in its own goroutine. It decodes each event block into a
// StreamEvent and forwards it on sr.events until one of the following happens:
// a "[DONE]" payload arrives, the body ends, decoding fails, or Close is called.
func (sr *StreamReader) processStream() {
	// Deferred calls run in reverse order: the body is closed first, then
	// errors, then events. Recv relies on errors being closed before events.
	defer close(sr.events)
	defer close(sr.errors)
	defer sr.response.Body.Close()

	for sr.scanner.Scan() {
		sr.mu.Lock()
		sr.lastEventTime = time.Now()
		sr.mu.Unlock()

		payload := sr.dataPayload(sr.scanner.Bytes())
		if len(payload) == 0 {
			continue // empty or comment-only block (e.g. keep-alive)
		}
		// Only a payload that is exactly [DONE] terminates the stream. Content
		// that merely contains that text must not.
		if string(bytes.TrimSpace(payload)) == "[DONE]" {
			return
		}

		var event StreamEvent
		if err := json.Unmarshal(payload, &event); err != nil {
			sr.errors <- fmt.Errorf("解析事件失败: %w\n原始数据: %s", err, string(payload))
			return
		}

		// Some chunks carry no choices (e.g. usage-only chunks), so guard the index.
		if len(event.Choices) > 0 {
			if content := event.Choices[0].Delta.Content; content != "" {
				sr.mu.Lock()
				sr.totalTokens += len(strings.Fields(content))
				sr.mu.Unlock()
			}
		}

		// If the consumer has called Close, stop instead of blocking forever
		// on a full channel.
		select {
		case sr.events <- event:
		case <-sr.done:
			return
		}
	}
	if err := sr.scanner.Err(); err != nil {
		sr.errors <- fmt.Errorf("扫描错误: %w", err)
	}
}

// Recv blocks until the next event is available and returns it.
//
// A terminal parse or read error is returned once, and only after every event
// that preceded it has been delivered. Recv returns io.EOF when the stream has
// ended normally, after such an error has been reported, or after Close.
func (sr *StreamReader) Recv() (StreamEvent, error) {
	select {
	case event, ok := <-sr.events:
		if ok {
			return event, nil
		}
		// events is closed only after errors is closed (see processStream), so
		// this receive cannot block. It yields the terminal error if one was
		// sent, or nil otherwise.
		if err := <-sr.errors; err != nil {
			return StreamEvent{}, err
		}
		return StreamEvent{}, io.EOF
	case <-sr.done:
		return StreamEvent{}, io.EOF
	}
}

// Close stops the stream and closes the response body. It is safe to call
// more than once. Only the first call closes the body and can return its error.
func (sr *StreamReader) Close() error {
	var err error
	sr.closeOnce.Do(func() {
		close(sr.done)
		err = sr.response.Body.Close()
	})
	return err
}

// Stats returns a one-line summary with three parts: the token estimate, the
// time elapsed since the reader was created, and the time since the last
// event block was read. The "Tokens" figure counts whitespace-separated words
// in choice 0's content, so it is only a rough estimate and undercounts text
// without spaces, such as CJK.
func (sr *StreamReader) Stats() string {
	sr.mu.Lock()
	defer sr.mu.Unlock()
	duration := time.Since(sr.startTime)
	return fmt.Sprintf(
		"Tokens: %d | 持续时间: %s | 最后事件: %s前",
		sr.totalTokens,
		duration.Round(time.Second),
		time.Since(sr.lastEventTime).Round(time.Second),
	)
}
// ==============================
// DeepSeek流式API客户端
// ==============================
// DeepSeekStreamClient sends streaming chat-completion requests over HTTP.
type DeepSeekStreamClient struct {
apiKey string // sent as "Authorization: Bearer <apiKey>"
endpoint string // URL that CreateStream POSTs the request to
}
// NewDeepSeekStreamClient returns a client that authenticates with apiKey and
// sends requests to DeepSeek's chat-completions endpoint.
func NewDeepSeekStreamClient(apiKey string) *DeepSeekStreamClient {
	const defaultEndpoint = "https://api.deepseek.com/v1/chat/completions"
	client := new(DeepSeekStreamClient)
	client.apiKey = apiKey
	client.endpoint = defaultEndpoint
	return client
}
// CreateStream sends a streaming chat-completion request and returns a
// StreamReader over the server-sent events in the response.
//
// ctx governs the whole request, including reading the stream, and cancelling
// it aborts the transfer. The caller must Close the returned reader. A
// non-200 response is returned as an error that includes the status and up
// to 64 KiB of the response body.
func (c *DeepSeekStreamClient) CreateStream(
	ctx context.Context,
	messages []openai.ChatCompletionMessage,
	model string,
	maxTokens int,
) (*StreamReader, error) {
	requestBody := struct {
		Model     string                         `json:"model"`
		Messages  []openai.ChatCompletionMessage `json:"messages"`
		MaxTokens int                            `json:"max_tokens"`
		Stream    bool                           `json:"stream"`
	}{
		Model:     model,
		Messages:  messages,
		MaxTokens: maxTokens,
		Stream:    true,
	}
	body, err := json.Marshal(requestBody)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}
	// NewRequestWithContext fails on a nil ctx or an invalid endpoint. The
	// original code ignored that error and then dereferenced a nil *Request.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("Connection", "keep-alive")

	// No Client.Timeout on purpose: it would cut off a long-running stream.
	// Cancellation comes from ctx instead. A zero Client still shares
	// http.DefaultTransport's connection pool.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Best effort: if reading the body fails, the status alone is still reported.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		return nil, fmt.Errorf("API错误: %s, 响应: %s", resp.Status, string(errBody))
	}
	return NewStreamReader(resp), nil
}
// ==============================
// 使用示例
// ==============================
// func main() {
// apiKey := "your_deepseek_api_key" // 替换为你的API密钥
// client := NewDeepSeekStreamClient(apiKey)
// // 创建对话消息
// messages := []openai.ChatCompletionMessage{
// {
// Role: openai.ChatMessageRoleSystem,
// Content: "你是一个乐于助人的AI助手",
// },
// {
// Role: openai.ChatMessageRoleUser,
// Content: "请用中文详细解释量子计算的基本原理,以及它为什么比传统计算机更有优势?",
// },
// }
// // 创建流式请求
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
// defer cancel()
// stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
// if err != nil {
// fmt.Println("创建流失败:", err)
// return
// }
// defer stream.Close()
// fmt.Println("流式响应开始:")
// fmt.Println("====================")
// var fullResponse strings.Builder
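// lastPrint := time.Now()
// printed := 0 // number of bytes of fullResponse already written to stdout
// const updateInterval = 100 * time.Millisecond
// for {
// event, err := stream.Recv()
// if errors.Is(err, io.EOF) {
// fmt.Print(fullResponse.String()[printed:]) // flush content received since the last print
// fmt.Println("\n\n====================")
// fmt.Println("流式响应结束")
// break
// }
// if err != nil {
// fmt.Println("\n接收错误:", err)
// break
// }
// if len(event.Choices) > 0 {
// content := event.Choices[0].Delta.Content
// if content != "" {
// fullResponse.WriteString(content)
// // Throttled output: at most once per updateInterval, print everything received
// // since the previous print (printing only `content` would drop the chunks in between).
// if time.Since(lastPrint) > updateInterval {
// fmt.Print(fullResponse.String()[printed:])
// printed = fullResponse.Len()
// lastPrint = time.Now()
// }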
// }
// }
// }
// fmt.Println("\n完整响应:")
// fmt.Println(fullResponse.String())
// fmt.Println("\n统计信息:", stream.Stats())
// }