package provider

import (
	"context"
	"errors"
	"strings"

	"git.kingecg.top/kingecg/goaiagent/model"

	"github.com/sashabaranov/go-openai"
)

// ==============================
// OpenAI-compatible implementation
// ==============================

// OpenAiProvider implements the provider interface on top of the OpenAI API.
type OpenAiProvider struct {
	client *openai.Client
	model  string
}

// NewOpenAiProvider creates a provider for the given API key and model name.
func NewOpenAiProvider(apiKey, model string) *OpenAiProvider {
	config := openai.DefaultConfig(apiKey)
	return &OpenAiProvider{
		client: openai.NewClientWithConfig(config),
		model:  model,
	}
}

// CreateChatCompletion converts the unified request into the OpenAI SDK format,
// sends it, and maps the response back into a model.ChatResponse.
func (o *OpenAiProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
	// Convert the unified request into the OpenAI SDK format.
	messages := make([]openai.ChatCompletionMessage, len(req.Messages))
	for i, m := range req.Messages {
		messages[i] = openai.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.Content,
		}
	}

	resp, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
		Model:     o.model,
		Messages:  messages,
		MaxTokens: req.MaxTokens,
	})
	if err != nil {
		return model.ChatResponse{}, err
	}

	// Guard against an empty choice list so indexing below cannot panic.
	if len(resp.Choices) == 0 {
		return model.ChatResponse{}, errors.New("openai provider: response contained no choices")
	}

	return model.ChatResponse{
		Content:      resp.Choices[0].Message.Content,
		InputTokens:  resp.Usage.PromptTokens,
		OutputTokens: resp.Usage.CompletionTokens,
	}, nil
}

// CountTokens returns a rough, whitespace-based token estimate. It is not a
// real BPE tokenizer, so treat the result as an approximation only.
func (o *OpenAiProvider) CountTokens(text string) int {
	return len(strings.Fields(text))
}

// GetModelName returns the configured model name.
func (o *OpenAiProvider) GetModelName() string {
	return o.model
}
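// Example usage (a minimal sketch from a caller's point of view, not part of
// the original file). It assumes model.ChatRequest exposes the Messages and
// MaxTokens fields used above and that each message carries Role and Content;
// the element type name model.Message is an assumption. The openai constants
// come from github.com/sashabaranov/go-openai.
//
//	p := NewOpenAiProvider(os.Getenv("OPENAI_API_KEY"), openai.GPT3Dot5Turbo)
//	resp, err := p.CreateChatCompletion(context.Background(), model.ChatRequest{
//		Messages: []model.Message{
//			{Role: openai.ChatMessageRoleUser, Content: "Hello"},
//		},
//		MaxTokens: 256,
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(resp.Content, resp.InputTokens, resp.OutputTokens)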