308 lines
7.3 KiB
Go
308 lines
7.3 KiB
Go
package provider
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
// ==============================
|
|
// 流式读取器核心实现
|
|
// ==============================
|
|
|
|
// StreamEvent is one decoded chunk of an OpenAI-compatible chat-completions
// stream, i.e. the JSON object carried by a single SSE "data:" payload.
// JSON keys not listed here are ignored by encoding/json.
type StreamEvent struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	// Choices may be empty for some chunks (NOTE(review): e.g. provider
	// metadata/usage chunks — confirm against the API); check len before
	// indexing.
	Choices []struct {
		Delta struct {
			// Content is the incremental answer text.
			Content string `json:"content"`
			// ReasoningContent is the incremental reasoning ("thinking")
			// text that DeepSeek reasoning models such as
			// "deepseek-reasoner" stream separately from Content.
			// It stays empty for chunks/models that do not send it.
			ReasoningContent string `json:"reasoning_content"`
		} `json:"delta"`
		Index int `json:"index"`
		// FinishReason is "" on chunks that carry no finish reason
		// (a JSON null leaves the string at its zero value).
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}
|
|
|
|
// StreamReader 处理流式API响应
|
|
type StreamReader struct {
|
|
scanner *bufio.Scanner
|
|
response *http.Response
|
|
currentEvent *StreamEvent
|
|
events chan StreamEvent
|
|
errors chan error
|
|
done chan struct{}
|
|
closeOnce sync.Once
|
|
mu sync.Mutex
|
|
totalTokens int
|
|
startTime time.Time
|
|
lastEventTime time.Time
|
|
}
|
|
|
|
// NewStreamReader 创建新的流式读取器
|
|
func NewStreamReader(response *http.Response) *StreamReader {
|
|
sr := &StreamReader{
|
|
response: response,
|
|
scanner: bufio.NewScanner(response.Body),
|
|
events: make(chan StreamEvent, 100),
|
|
errors: make(chan error, 1),
|
|
done: make(chan struct{}),
|
|
startTime: time.Now(),
|
|
lastEventTime: time.Now(),
|
|
}
|
|
|
|
sr.scanner.Split(sr.splitSSE) // 使用自定义分隔函数
|
|
|
|
go sr.processStream()
|
|
return sr
|
|
}
|
|
|
|
// splitSSE 自定义扫描函数用于处理SSE格式
|
|
func (sr *StreamReader) splitSSE(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
if atEOF && len(data) == 0 {
|
|
return 0, nil, nil
|
|
}
|
|
|
|
// 查找事件分隔符: \n\n 或 \r\n\r\n
|
|
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
|
|
return i + 2, data[0:i], nil
|
|
}
|
|
if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
|
|
return i + 4, data[0:i], nil
|
|
}
|
|
|
|
// 如果到达文件末尾,返回剩余数据
|
|
if atEOF {
|
|
return len(data), data, nil
|
|
}
|
|
|
|
// 请求更多数据
|
|
return 0, nil, nil
|
|
}
|
|
|
|
// processStream 处理流数据
|
|
func (sr *StreamReader) processStream() {
|
|
defer close(sr.events)
|
|
defer close(sr.errors)
|
|
defer sr.response.Body.Close()
|
|
|
|
for sr.scanner.Scan() {
|
|
eventData := sr.scanner.Bytes()
|
|
sr.lastEventTime = time.Now()
|
|
|
|
// 跳过空行和注释
|
|
if len(eventData) == 0 || bytes.HasPrefix(eventData, []byte(":")) {
|
|
continue
|
|
}
|
|
|
|
// 检查是否为 [DONE] 事件
|
|
if bytes.Contains(eventData, []byte("[DONE]")) {
|
|
return
|
|
}
|
|
|
|
// 提取事件数据部分
|
|
prefix := []byte("data: ")
|
|
if bytes.HasPrefix(eventData, prefix) {
|
|
eventData = bytes.TrimPrefix(eventData, prefix)
|
|
}
|
|
|
|
var event StreamEvent
|
|
if err := json.Unmarshal(eventData, &event); err != nil {
|
|
sr.errors <- fmt.Errorf("解析事件失败: %w\n原始数据: %s", err, string(eventData))
|
|
return
|
|
}
|
|
|
|
// 更新Token计数
|
|
if content := event.Choices[0].Delta.Content; content != "" {
|
|
sr.mu.Lock()
|
|
sr.totalTokens += len(strings.Fields(content))
|
|
sr.mu.Unlock()
|
|
}
|
|
|
|
// 发送事件
|
|
sr.events <- event
|
|
}
|
|
|
|
if err := sr.scanner.Err(); err != nil {
|
|
sr.errors <- fmt.Errorf("扫描错误: %w", err)
|
|
}
|
|
}
|
|
|
|
// Recv 接收下一个事件
|
|
func (sr *StreamReader) Recv() (StreamEvent, error) {
|
|
select {
|
|
case event, ok := <-sr.events:
|
|
if !ok {
|
|
return StreamEvent{}, io.EOF
|
|
}
|
|
return event, nil
|
|
case err := <-sr.errors:
|
|
return StreamEvent{}, err
|
|
case <-sr.done:
|
|
return StreamEvent{}, io.EOF
|
|
}
|
|
}
|
|
|
|
// Close 关闭流
|
|
func (sr *StreamReader) Close() error {
|
|
sr.closeOnce.Do(func() {
|
|
close(sr.done)
|
|
sr.response.Body.Close()
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Stats 获取流统计信息
|
|
func (sr *StreamReader) Stats() string {
|
|
sr.mu.Lock()
|
|
defer sr.mu.Unlock()
|
|
|
|
duration := time.Since(sr.startTime)
|
|
return fmt.Sprintf(
|
|
"Tokens: %d | 持续时间: %s | 最后事件: %s前",
|
|
sr.totalTokens,
|
|
duration.Round(time.Second),
|
|
time.Since(sr.lastEventTime).Round(time.Second),
|
|
)
|
|
}
|
|
|
|
// ==============================
|
|
// DeepSeek流式API客户端
|
|
// ==============================
|
|
|
|
// DeepSeekStreamClient issues streaming chat-completion requests over HTTP.
type DeepSeekStreamClient struct {
	// apiKey is sent as a Bearer token in the Authorization header;
	// endpoint is the URL the chat-completions request is POSTed to.
	apiKey, endpoint string
}
|
|
|
|
func NewDeepSeekStreamClient(apiKey string) *DeepSeekStreamClient {
|
|
return &DeepSeekStreamClient{
|
|
apiKey: apiKey,
|
|
endpoint: "https://api.deepseek.com/v1/chat/completions",
|
|
}
|
|
}
|
|
|
|
func (c *DeepSeekStreamClient) CreateStream(
|
|
ctx context.Context,
|
|
messages []openai.ChatCompletionMessage,
|
|
model string,
|
|
maxTokens int,
|
|
) (*StreamReader, error) {
|
|
// 构造请求体
|
|
requestBody := struct {
|
|
Model string `json:"model"`
|
|
Messages []openai.ChatCompletionMessage `json:"messages"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Stream bool `json:"stream"`
|
|
}{
|
|
Model: model,
|
|
Messages: messages,
|
|
MaxTokens: maxTokens,
|
|
Stream: true,
|
|
}
|
|
|
|
body, _ := json.Marshal(requestBody)
|
|
req, _ := http.NewRequestWithContext(ctx, "POST", c.endpoint, bytes.NewReader(body))
|
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
req.Header.Set("Cache-Control", "no-cache")
|
|
req.Header.Set("Connection", "keep-alive")
|
|
|
|
// 发送请求
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 检查响应状态
|
|
if resp.StatusCode != http.StatusOK {
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("API错误: %s, 响应: %s", resp.Status, string(body))
|
|
}
|
|
|
|
return NewStreamReader(resp), nil
|
|
}
|
|
|
|
// ==============================
|
|
// 使用示例
|
|
// ==============================
|
|
|
|
// func main() {
|
|
// apiKey := "your_deepseek_api_key" // 替换为你的API密钥
|
|
// client := NewDeepSeekStreamClient(apiKey)
|
|
|
|
// // 创建对话消息
|
|
// messages := []openai.ChatCompletionMessage{
|
|
// {
|
|
// Role: openai.ChatMessageRoleSystem,
|
|
// Content: "你是一个乐于助人的AI助手",
|
|
// },
|
|
// {
|
|
// Role: openai.ChatMessageRoleUser,
|
|
// Content: "请用中文详细解释量子计算的基本原理,以及它为什么比传统计算机更有优势?",
|
|
// },
|
|
// }
|
|
|
|
// // 创建流式请求
|
|
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
|
// defer cancel()
|
|
|
|
// stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
|
|
// if err != nil {
|
|
// fmt.Println("创建流失败:", err)
|
|
// return
|
|
// }
|
|
// defer stream.Close()
|
|
|
|
// fmt.Println("流式响应开始:")
|
|
// fmt.Println("====================")
|
|
|
|
// var fullResponse strings.Builder
|
|
// lastPrint := time.Now()
|
|
// const updateInterval = 100 * time.Millisecond
|
|
|
|
// for {
|
|
// event, err := stream.Recv()
|
|
// if errors.Is(err, io.EOF) {
|
|
// fmt.Println("\n\n====================")
|
|
// fmt.Println("流式响应结束")
|
|
// break
|
|
// }
|
|
|
|
// if err != nil {
|
|
// fmt.Println("\n接收错误:", err)
|
|
// break
|
|
// }
|
|
|
|
// if len(event.Choices) > 0 {
|
|
// content := event.Choices[0].Delta.Content
|
|
// if content != "" {
|
|
// fullResponse.WriteString(content)
|
|
|
|
// // 流式输出控制:定期打印新内容
|
|
// if time.Since(lastPrint) > updateInterval {
|
|
// fmt.Print(content)
|
|
// lastPrint = time.Now()
|
|
// }
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// fmt.Println("\n完整响应:")
|
|
// fmt.Println(fullResponse.String())
|
|
// fmt.Println("\n统计信息:", stream.Stats())
|
|
// }
|