mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-11 02:07:27 -04:00
Propagate terminal write errors from the chat prompt and explicitly ignore stream close errors during cleanup. Update chat tests to assert response writer errors so errcheck passes without hiding failed writes. Tests: - go test -count=1 ./core/cli/chat - go test -count=1 ./core/cli Assisted-by: Codex:GPT-5 Signed-off-by: Ching Kao <0980124jim@gmail.com>
115 lines
2.9 KiB
Go
115 lines
2.9 KiB
Go
package chat
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"strings"
|
|
|
|
openai "github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
type chatClient interface {
|
|
ListModels(ctx context.Context) ([]string, error)
|
|
StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error)
|
|
}
|
|
|
|
type localAIChatClient struct {
|
|
client *openai.Client
|
|
}
|
|
|
|
func newLocalAIChatClient(baseURL string, apiKey string) *localAIChatClient {
|
|
cfg := openai.DefaultConfig(apiKey)
|
|
cfg.BaseURL = baseURL
|
|
return &localAIChatClient{client: openai.NewClientWithConfig(cfg)}
|
|
}
|
|
|
|
func (c *localAIChatClient) ListModels(ctx context.Context) ([]string, error) {
|
|
resp, err := c.client.ListModels(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
models := make([]string, 0, len(resp.Models))
|
|
for _, model := range resp.Models {
|
|
if model.ID != "" {
|
|
models = append(models, model.ID)
|
|
}
|
|
}
|
|
sort.Strings(models)
|
|
return models, nil
|
|
}
|
|
|
|
func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messages []chatMessage, out io.Writer) (string, error) {
|
|
stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
|
Model: model,
|
|
Messages: openAIChatMessages(messages),
|
|
})
|
|
if err != nil {
|
|
return "", friendlyChatError(err, model)
|
|
}
|
|
defer func() {
|
|
_ = stream.Close()
|
|
}()
|
|
|
|
var answer strings.Builder
|
|
for {
|
|
resp, err := stream.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return answer.String(), friendlyChatError(err, model)
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
token := resp.Choices[0].Delta.Content
|
|
if token == "" {
|
|
continue
|
|
}
|
|
answer.WriteString(token)
|
|
if _, err := fmt.Fprint(out, token); err != nil {
|
|
return answer.String(), err
|
|
}
|
|
}
|
|
|
|
return answer.String(), nil
|
|
}
|
|
|
|
func openAIChatMessages(messages []chatMessage) []openai.ChatCompletionMessage {
|
|
converted := make([]openai.ChatCompletionMessage, len(messages))
|
|
for i, message := range messages {
|
|
converted[i] = openai.ChatCompletionMessage{
|
|
Role: message.Role,
|
|
Content: message.Content,
|
|
}
|
|
}
|
|
return converted
|
|
}
|
|
|
|
func friendlyChatError(err error, model string) error {
|
|
var apiErr *openai.APIError
|
|
if errors.As(err, &apiErr) {
|
|
switch apiErr.HTTPStatusCode {
|
|
case 404:
|
|
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
|
case 403:
|
|
return fmt.Errorf("model %q is disabled. Enable it from LocalAI settings or choose another model with `/model <name>`", model)
|
|
}
|
|
if apiErr.Message != "" {
|
|
return errors.New(apiErr.Message)
|
|
}
|
|
}
|
|
|
|
msg := err.Error()
|
|
if strings.Contains(msg, "model") && strings.Contains(msg, "not found") {
|
|
return fmt.Errorf("model %q is not available. Run `local-ai models list`, install a model with `local-ai models install <model>`, or switch with `/model <name>`", model)
|
|
}
|
|
|
|
return err
|
|
}
|