goaiagent/provider/deepseek.go

104 lines
2.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package provider
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.kingecg.top/kingecg/goaiagent/model"
"github.com/sashabaranov/go-openai"
)
// DeepSeekProvider is the chat provider implementation backed by the
// DeepSeek HTTP API.
type DeepSeekProvider struct {
	// apiKey is the credential sent in the Authorization header of
	// chat-completion requests.
	apiKey string
	// model is the DeepSeek model identifier, e.g. deepseek-chat or
	// deepseek-reasoner.
	model string
	// endpoint is the URL that chat-completion requests are posted to.
	endpoint string
}
// NewDeepSeekProvider returns a provider that uses apiKey as its credential
// and requests completions from the given model. The endpoint is always set
// to the official DeepSeek chat-completions URL.
func NewDeepSeekProvider(apiKey, model string) *DeepSeekProvider {
	p := new(DeepSeekProvider)
	p.apiKey = apiKey
	p.model = model
	p.endpoint = "https://api.deepseek.com/v1/chat/completions"
	return p
}
// GetProviderName returns the display name of this provider, which is always
// "DeepSeek".
func (d *DeepSeekProvider) GetProviderName() string {
	const providerName = "DeepSeek"
	return providerName
}
// CreateChatCompletion posts req to the provider's chat-completions endpoint
// and returns the content of the first choice together with the prompt and
// completion token counts reported by the server.
//
// The HTTP request is bound to ctx, so cancelling ctx or exceeding its
// deadline aborts the call. An error is returned when the request cannot be
// encoded, built or sent, when the response body cannot be read, when the
// server answers with a non-2xx status (the response body is included in the
// error text), when the body is not valid JSON, or when it has no choices.
// On error the returned ChatResponse is the zero value.
func (d *DeepSeekProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
	// Earlier versions ignored ctx entirely, so keep tolerating nil here
	// instead of failing inside http.NewRequestWithContext.
	if ctx == nil {
		ctx = context.Background()
	}

	// Request body for the DeepSeek API. max_tokens is omitted when it is
	// zero so that the server-side default applies rather than an explicit 0.
	payload := struct {
		Model     string                         `json:"model"`
		Messages  []openai.ChatCompletionMessage `json:"messages"`
		MaxTokens int                            `json:"max_tokens,omitempty"`
	}{
		Model:     d.model,
		Messages:  req.Messages,
		MaxTokens: req.MaxTokens,
	}

	reqBody, err := json.Marshal(payload)
	if err != nil {
		return model.ChatResponse{}, fmt.Errorf("deepseek: encoding request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, d.endpoint, bytes.NewReader(reqBody))
	if err != nil {
		return model.ChatResponse{}, fmt.Errorf("deepseek: building request: %w", err)
	}
	httpReq.Header.Set("Authorization", "Bearer "+d.apiKey)
	httpReq.Header.Set("Content-Type", "application/json")

	// The shared default client reuses connections across calls; the
	// per-call time limit comes from ctx.
	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return model.ChatResponse{}, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return model.ChatResponse{}, fmt.Errorf("deepseek: reading response: %w", err)
	}

	// Report API failures directly; otherwise an error payload would decode
	// into zero choices and be reported as an "empty response".
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return model.ChatResponse{}, fmt.Errorf("deepseek: unexpected status %d: %s", resp.StatusCode, bytes.TrimSpace(body))
	}

	// Only the fields this provider consumes are decoded.
	var result struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return model.ChatResponse{}, fmt.Errorf("deepseek: decoding response: %w", err)
	}
	if len(result.Choices) == 0 {
		return model.ChatResponse{}, fmt.Errorf("empty response from DeepSeek")
	}

	return model.ChatResponse{
		Content:      result.Choices[0].Message.Content,
		InputTokens:  result.Usage.PromptTokens,
		OutputTokens: result.Usage.CompletionTokens,
	}, nil
}
// CreateChatCompletionStream starts a streaming chat completion. It creates
// a stream client from the provider's API key and passes it the request's
// messages and token limit along with the configured model. The reader and
// error produced by CreateStream are returned unchanged.
func (d *DeepSeekProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error) {
	client := NewDeepSeekStreamClient(d.apiKey)
	stream, err := client.CreateStream(ctx, req.Messages, d.model, req.MaxTokens)
	return stream, err
}
// CountTokens returns a rough token estimate for text: the number of
// whitespace-separated fields it contains. This is a simplified counter;
// production code should use a real tokenizer such as tiktoken.
func (d *DeepSeekProvider) CountTokens(text string) int {
	words := strings.Fields(text)
	return len(words)
}
// GetModelName returns the model identifier stored on the provider.
func (d *DeepSeekProvider) GetModelName() string {
	modelName := d.model
	return modelName
}