mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-10 17:56:49 -04:00
fix(cli): handle chat output errors (#10229)
Propagate terminal write errors from the chat prompt and explicitly ignore stream close errors during cleanup. Update chat tests to assert response writer errors so errcheck passes without hiding failed writes. Tests: - go test -count=1 ./core/cli/chat - go test -count=1 ./core/cli Assisted-by: Codex:GPT-5 Signed-off-by: Ching Kao <0980124jim@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -19,7 +20,7 @@ var _ = Describe("Run chat", func() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v1/models" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
|
||||
writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,9 +41,9 @@ var _ = Describe("Run chat", func() {
|
||||
Expect(body.Messages[0].Content).To(Equal("hello"))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -135,14 +136,14 @@ func chatTestServer(models []string, onChat func(model string)) *httptest.Server
|
||||
switch r.URL.Path {
|
||||
case "/v1/models":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"object":"list","data":[`)
|
||||
writeResponse(w, `{"object":"list","data":[`)
|
||||
for i, model := range models {
|
||||
if i > 0 {
|
||||
fmt.Fprint(w, ",")
|
||||
writeResponse(w, ",")
|
||||
}
|
||||
fmt.Fprintf(w, `{"id":%q,"object":"model"}`, model)
|
||||
writeResponsef(w, `{"id":%q,"object":"model"}`, model)
|
||||
}
|
||||
fmt.Fprint(w, `]}`)
|
||||
writeResponse(w, `]}`)
|
||||
case "/v1/chat/completions":
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
@@ -152,10 +153,20 @@ func chatTestServer(models []string, onChat func(model string)) *httptest.Server
|
||||
onChat(body.Model)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n")
|
||||
writeResponse(w, "data: [DONE]\n\n")
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func writeResponse(w io.Writer, text string) {
|
||||
_, err := fmt.Fprint(w, text)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
func writeResponsef(w io.Writer, format string, args ...any) {
|
||||
_, err := fmt.Fprintf(w, format, args...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
@@ -50,7 +50,9 @@ func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messag
|
||||
if err != nil {
|
||||
return "", friendlyChatError(err, model)
|
||||
}
|
||||
defer stream.Close()
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
var answer strings.Builder
|
||||
for {
|
||||
|
||||
@@ -12,11 +12,17 @@ func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, ou
|
||||
scanner := bufio.NewScanner(in)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
|
||||
|
||||
fmt.Fprintf(out, "LocalAI chat (%s)\n", session.CurrentModel())
|
||||
fmt.Fprintln(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.")
|
||||
if err := writeChat(out, "LocalAI chat (%s)\n", session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeChat(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
fmt.Fprint(out, "\n> ")
|
||||
if err := writeChat(out, "\n> "); err != nil {
|
||||
return err
|
||||
}
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
@@ -26,45 +32,62 @@ func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, ou
|
||||
case "":
|
||||
continue
|
||||
case "/bye", "/exit", "/quit":
|
||||
fmt.Fprintln(out, "bye")
|
||||
return nil
|
||||
return writeChat(out, "bye\n")
|
||||
case "/clear":
|
||||
session.Clear()
|
||||
fmt.Fprintln(out, "conversation cleared")
|
||||
if err := writeChat(out, "conversation cleared\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
case "/models":
|
||||
printChatModels(out, session.Models(), session.CurrentModel())
|
||||
if err := printChatModels(out, session.Models(), session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if nextModel, ok := strings.CutPrefix(prompt, "/model "); ok {
|
||||
nextModel = strings.TrimSpace(nextModel)
|
||||
if nextModel == "" {
|
||||
fmt.Fprintln(out, "usage: /model <name>")
|
||||
if err := writeChat(out, "usage: /model <name>\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := session.SwitchModel(nextModel); err != nil {
|
||||
fmt.Fprintln(out, err)
|
||||
if writeErr := writeChat(out, "%s\n", err); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(out, "switched to %s; conversation cleared\n", session.CurrentModel())
|
||||
if err := writeChat(out, "switched to %s; conversation cleared\n", session.CurrentModel()); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprint(out, "assistant: ")
|
||||
if err := writeChat(out, "assistant: "); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.Send(ctx, prompt, out); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(out)
|
||||
if err := writeChat(out, "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func printChatModels(out io.Writer, models []string, current string) {
|
||||
func printChatModels(out io.Writer, models []string, current string) error {
|
||||
if len(models) == 0 {
|
||||
fmt.Fprintln(out, "no models installed")
|
||||
return
|
||||
return writeChat(out, "no models installed\n")
|
||||
}
|
||||
fmt.Fprint(out, formatChatModelList(models, current))
|
||||
return writeChat(out, "%s", formatChatModelList(models, current))
|
||||
}
|
||||
|
||||
func writeChat(out io.Writer, format string, args ...any) error {
|
||||
_, err := fmt.Fprintf(out, format, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user