diff --git a/core/cli/chat/chat_test.go b/core/cli/chat/chat_test.go index a5c9a1f3c..7399c3802 100644 --- a/core/cli/chat/chat_test.go +++ b/core/cli/chat/chat_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -19,7 +20,7 @@ var _ = Describe("Run chat", func() { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/models" { w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`) + writeResponse(w, `{"object":"list","data":[{"id":"test-model","object":"model"}]}`) return } @@ -40,9 +41,9 @@ var _ = Describe("Run chat", func() { Expect(body.Messages[0].Content).To(Equal("hello")) w.Header().Set("Content-Type", "text/event-stream") - fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n") - fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n") - fmt.Fprint(w, "data: [DONE]\n\n") + writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n") + writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n") + writeResponse(w, "data: [DONE]\n\n") })) defer server.Close() @@ -135,14 +136,14 @@ func chatTestServer(models []string, onChat func(model string)) *httptest.Server switch r.URL.Path { case "/v1/models": w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"object":"list","data":[`) + writeResponse(w, `{"object":"list","data":[`) for i, model := range models { if i > 0 { - fmt.Fprint(w, ",") + writeResponse(w, ",") } - fmt.Fprintf(w, `{"id":%q,"object":"model"}`, model) + writeResponsef(w, `{"id":%q,"object":"model"}`, model) } - fmt.Fprint(w, `]}`) + writeResponse(w, `]}`) case "/v1/chat/completions": var body struct { Model string `json:"model"` @@ -152,10 +153,20 @@ func chatTestServer(models []string, onChat func(model string)) *httptest.Server onChat(body.Model) } w.Header().Set("Content-Type", "text/event-stream") - fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n") - fmt.Fprint(w, "data: [DONE]\n\n") + writeResponse(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n\n") + writeResponse(w, "data: [DONE]\n\n") default: w.WriteHeader(http.StatusNotFound) } })) } + +func writeResponse(w io.Writer, text string) { + _, err := fmt.Fprint(w, text) + Expect(err).ToNot(HaveOccurred()) +} + +func writeResponsef(w io.Writer, format string, args ...any) { + _, err := fmt.Fprintf(w, format, args...) + Expect(err).ToNot(HaveOccurred()) +} diff --git a/core/cli/chat/client.go b/core/cli/chat/client.go index 93910de7e..407845d0b 100644 --- a/core/cli/chat/client.go +++ b/core/cli/chat/client.go @@ -50,7 +50,9 @@ func (c *localAIChatClient) StreamChat(ctx context.Context, model string, messag if err != nil { return "", friendlyChatError(err, model) } - defer stream.Close() + defer func() { + _ = stream.Close() + }() var answer strings.Builder for { diff --git a/core/cli/chat/terminal.go b/core/cli/chat/terminal.go index c9aebdbe3..8d76e1e6f 100644 --- a/core/cli/chat/terminal.go +++ b/core/cli/chat/terminal.go @@ -12,11 +12,17 @@ func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, ou scanner := bufio.NewScanner(in) scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024) - fmt.Fprintf(out, "LocalAI chat (%s)\n", session.CurrentModel()) - fmt.Fprintln(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.") + if err := writeChat(out, "LocalAI chat (%s)\n", session.CurrentModel()); err != nil { + return err + } + if err := writeChat(out, "Type /exit to quit, /clear to reset the conversation, /models to list models.\n"); err != nil { + return err + } for { - fmt.Fprint(out, "\n> ") + if err := writeChat(out, "\n> "); err != nil { + return err + } if !scanner.Scan() { break } @@ -26,45 +32,62 @@ func runTerminalChat(ctx context.Context, session *chatSession, in io.Reader, ou case "": continue case "/bye", "/exit", "/quit": - fmt.Fprintln(out, "bye") - return nil + return writeChat(out, "bye\n") case "/clear": session.Clear() - fmt.Fprintln(out, "conversation cleared") + if err := writeChat(out, "conversation cleared\n"); err != nil { + return err + } continue case "/models": - printChatModels(out, session.Models(), session.CurrentModel()) + if err := printChatModels(out, session.Models(), session.CurrentModel()); err != nil { + return err + } continue } if nextModel, ok := strings.CutPrefix(prompt, "/model "); ok { nextModel = strings.TrimSpace(nextModel) if nextModel == "" { - fmt.Fprintln(out, "usage: /model ") + if err := writeChat(out, "usage: /model \n"); err != nil { + return err + } continue } if err := session.SwitchModel(nextModel); err != nil { - fmt.Fprintln(out, err) + if writeErr := writeChat(out, "%s\n", err); writeErr != nil { + return writeErr + } continue } - fmt.Fprintf(out, "switched to %s; conversation cleared\n", session.CurrentModel()) + if err := writeChat(out, "switched to %s; conversation cleared\n", session.CurrentModel()); err != nil { + return err + } continue } - fmt.Fprint(out, "assistant: ") + if err := writeChat(out, "assistant: "); err != nil { + return err + } if err := session.Send(ctx, prompt, out); err != nil { return err } - fmt.Fprintln(out) + if err := writeChat(out, "\n"); err != nil { + return err + } } return scanner.Err() } -func printChatModels(out io.Writer, models []string, current string) { +func printChatModels(out io.Writer, models []string, current string) error { if len(models) == 0 { - fmt.Fprintln(out, "no models installed") - return + return writeChat(out, "no models installed\n") } - fmt.Fprint(out, formatChatModelList(models, current)) + return writeChat(out, "%s", formatChatModelList(models, current)) +} + +func writeChat(out io.Writer, format string, args ...any) error { + _, err := fmt.Fprintf(out, format, args...) + return err }