diff --git a/.github/workflows/tests-e2e.yml b/.github/workflows/tests-e2e.yml new file mode 100644 index 000000000..b7ba47feb --- /dev/null +++ b/.github/workflows/tests-e2e.yml @@ -0,0 +1,56 @@ +--- +name: 'E2E Backend Tests' + +on: + pull_request: + push: + branches: + - master + tags: + - '*' + +concurrency: + group: ci-tests-e2e-backend-${{ github.head_ref || github.ref }}-${{ github.repository }} + cancel-in-progress: true + +jobs: + tests-e2e-backend: + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ['1.25.x'] + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Go ${{ matrix.go-version }} + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + cache: false + - name: Display Go version + run: go version + - name: Proto Dependencies + run: | + # Install protoc + curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ + unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ + rm protoc.zip + go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af + PATH="$PATH:$HOME/go/bin" make protogen-go + - name: Dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential + - name: Test Backend E2E + run: | + PATH="$PATH:$HOME/go/bin" make build-mock-backend test-e2e + - name: Setup tmate session if tests fail + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3.23 + with: + detached: true + connect-timeout-seconds: 180 + limit-access-to-actor: true diff --git a/Makefile b/Makefile index 7f4941532..505b4eba7 100644 --- a/Makefile +++ b/Makefile @@ -191,9 +191,6 @@ run-e2e-aio: protogen-go ######################################################## prepare-e2e: - mkdir -p $(TEST_DIR) - cp -rfv $(abspath ./tests/e2e-fixtures)/gpu.yaml $(TEST_DIR)/gpu.yaml - test -e $(TEST_DIR)/ggllm-test-model.bin || wget -q https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q2_K.gguf -O $(TEST_DIR)/ggllm-test-model.bin docker build \ --build-arg IMAGE_TYPE=core \ --build-arg BUILD_TYPE=$(BUILD_TYPE) \ @@ -207,14 +204,16 @@ prepare-e2e: -t localai-tests . run-e2e-image: - ls -liah $(abspath ./tests/e2e-fixtures) - docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --gpus all --name e2e-tests-$(RANDOM) localai-tests + docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests -test-e2e: +test-e2e: build-mock-backend prepare-e2e run-e2e-image @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ - LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \ + LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + $(MAKE) clean-mock-backend + $(MAKE) teardown-e2e + docker rmi localai-tests teardown-e2e: rm -rf $(TEST_DIR) || true @@ -522,6 +521,16 @@ docker-save-%: backend-images docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-qwen-asr docker-build-voxcpm +######################################################## +### Mock Backend for E2E Tests +######################################################## + +build-mock-backend: protogen-go + $(GOCMD) build -o tests/e2e/mock-backend/mock-backend ./tests/e2e/mock-backend + +clean-mock-backend: + rm -f tests/e2e/mock-backend/mock-backend + ######################################################## ### END Backends ######################################################## diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go index 389d60466..12d500125 100644 --- a/core/http/endpoints/anthropic/messages.go +++ b/core/http/endpoints/anthropic/messages.go @@ -88,21 +88,38 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput) if input.Stream { - return handleAnthropicStream(c, id, input, cfg, ml, predInput, openAIReq, funcs, shouldUseFn) + return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn) } - return handleAnthropicNonStream(c, id, input, cfg, ml, predInput, openAIReq, funcs, shouldUseFn) + return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn) } } -func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error { +func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error { images := []string{} for _, m := range openAIReq.Messages { images = append(images, m.StringImages...) } + toolsJSON := "" + if len(funcs) > 0 { + openAITools := make([]functions.Tool, len(funcs)) + for i, f := range funcs { + openAITools[i] = functions.Tool{Type: "function", Function: f} + } + if toolsBytes, err := json.Marshal(openAITools); err == nil { + toolsJSON = string(toolsBytes) + } + } + toolChoiceJSON := "" + if input.ToolChoice != nil { + if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + predFunc, err := backend.ModelInference( - input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, nil, nil, nil, "", "", nil, nil, nil) + input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, nil, nil, nil) if err != nil { xlog.Error("Anthropic model inference failed", "error", err) return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) @@ -175,7 +192,7 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic return c.JSON(200, resp) } -func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error { +func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") @@ -292,8 +309,25 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq return true } + toolsJSON := "" + if len(funcs) > 0 { + openAITools := make([]functions.Tool, len(funcs)) + for i, f := range funcs { + openAITools[i] = functions.Tool{Type: "function", Function: f} + } + if toolsBytes, err := json.Marshal(openAITools); err == nil { + toolsJSON = string(toolsBytes) + } + } + toolChoiceJSON := "" + if input.ToolChoice != nil { + if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + predFunc, err := backend.ModelInference( - input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, nil, nil, tokenCallback, "", "", nil, nil, nil) + input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, nil, nil, nil) if err != nil { xlog.Error("Anthropic stream model inference failed", "error", err) return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) @@ -367,10 +401,11 @@ func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.M // Add system message if present if input.System != "" { + sysStr := string(input.System) messages = append(messages, schema.Message{ Role: "system", - StringContent: input.System, - Content: input.System, + StringContent: sysStr, + Content: sysStr, }) } diff --git a/core/schema/anthropic.go b/core/schema/anthropic.go index d6c17ba79..a44ea64bd 100644 --- a/core/schema/anthropic.go +++ b/core/schema/anthropic.go @@ -5,16 +5,44 @@ import ( "encoding/json" ) +// AnthropicSystemParam accepts system as string or array of content blocks (SDK sends array). +type AnthropicSystemParam string + +// UnmarshalJSON accepts string or array of blocks with "text" field. +func (s *AnthropicSystemParam) UnmarshalJSON(data []byte) error { + var raw interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + switch v := raw.(type) { + case string: + *s = AnthropicSystemParam(v) + return nil + case []interface{}: + var out string + for _, block := range v { + if m, ok := block.(map[string]interface{}); ok && m["type"] == "text" { + if t, ok := m["text"].(string); ok { + out += t + } + } + } + *s = AnthropicSystemParam(out) + return nil + } + return nil +} + // AnthropicRequest represents a request to the Anthropic Messages API // https://docs.anthropic.com/claude/reference/messages_post type AnthropicRequest struct { - Model string `json:"model"` - Messages []AnthropicMessage `json:"messages"` - MaxTokens int `json:"max_tokens"` - Metadata map[string]string `json:"metadata,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - System string `json:"system,omitempty"` + Model string `json:"model"` + Messages []AnthropicMessage `json:"messages"` + MaxTokens int `json:"max_tokens"` + Metadata map[string]string `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + System AnthropicSystemParam `json:"system,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopK *int `json:"top_k,omitempty"` TopP *float64 `json:"top_p,omitempty"` diff --git a/core/schema/anthropic_test.go b/core/schema/anthropic_test.go index 56f7bc5dd..440399d34 100644 --- a/core/schema/anthropic_test.go +++ b/core/schema/anthropic_test.go @@ -27,7 +27,7 @@ var _ = Describe("Anthropic Schema", func() { Expect(req.Model).To(Equal("claude-3-sonnet-20240229")) Expect(req.MaxTokens).To(Equal(1024)) Expect(len(req.Messages)).To(Equal(1)) - Expect(req.System).To(Equal("You are a helpful assistant.")) + Expect(string(req.System)).To(Equal("You are a helpful assistant.")) Expect(*req.Temperature).To(Equal(0.7)) }) diff --git a/tests/e2e-fixtures/gpu.yaml b/tests/e2e-fixtures/gpu.yaml deleted file mode 100644 index 78d6d4edb..000000000 --- a/tests/e2e-fixtures/gpu.yaml +++ /dev/null @@ -1,17 +0,0 @@ -context_size: 2048 -mirostat: 2 -mirostat_tau: 5.0 -mirostat_eta: 0.1 -f16: true -threads: 1 -gpu_layers: 90 -name: gpt-4 -mmap: true -parameters: - model: ggllm-test-model.bin - rope_freq_base: 10000 - max_tokens: 20 - rope_freq_scale: 1 - temperature: 0.2 - top_k: 40 - top_p: 0.95 diff --git a/tests/e2e/e2e_anthropic_test.go b/tests/e2e/e2e_anthropic_test.go index c4646cf14..bfebddd22 100644 --- a/tests/e2e/e2e_anthropic_test.go +++ b/tests/e2e/e2e_anthropic_test.go @@ -2,9 +2,11 @@ package e2e_test import ( "context" + "encoding/json" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/shared/constant" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -14,29 +16,16 @@ var _ = Describe("Anthropic API E2E test", func() { Context("API with Anthropic SDK", func() { BeforeEach(func() { - // Create Anthropic client pointing to LocalAI client = anthropic.NewClient( - option.WithBaseURL(localAIURL), - option.WithAPIKey("test-api-key"), // LocalAI doesn't require a real API key + option.WithBaseURL(anthropicBaseURL), + option.WithAPIKey("test-api-key"), ) - - // Wait for API to be ready by attempting a simple request - Eventually(func() error { - _, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", - MaxTokens: 10, - Messages: []anthropic.MessageParam{ - anthropic.NewUserMessage(anthropic.NewTextBlock("Hi")), - }, - }) - return err - }, "2m").ShouldNot(HaveOccurred()) }) Context("Non-streaming responses", func() { It("generates a response for a simple message", func() { message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("How much is 2+2? Reply with just the number.")), @@ -44,21 +33,19 @@ var _ = Describe("Anthropic API E2E test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(message.Content).ToNot(BeEmpty()) - // Role is a constant type that defaults to "assistant" Expect(string(message.Role)).To(Equal("assistant")) - Expect(message.StopReason).To(Equal(anthropic.MessageStopReasonEndTurn)) + Expect(string(message.StopReason)).To(Equal("end_turn")) Expect(string(message.Type)).To(Equal("message")) - // Check that content contains text block with expected answer Expect(len(message.Content)).To(BeNumerically(">=", 1)) textBlock := message.Content[0] Expect(string(textBlock.Type)).To(Equal("text")) - Expect(textBlock.Text).To(Or(ContainSubstring("4"), ContainSubstring("four"))) + Expect(textBlock.Text).To(ContainSubstring("mocked")) }) It("handles system prompts", func() { message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, System: []anthropic.TextBlockParam{ {Text: "You are a helpful assistant. Always respond in uppercase letters."}, @@ -74,7 +61,7 @@ var _ = Describe("Anthropic API E2E test", func() { It("returns usage information", func() { message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 100, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")), @@ -89,7 +76,7 @@ var _ = Describe("Anthropic API E2E test", func() { Context("Streaming responses", func() { It("streams tokens for a simple message", func() { stream := client.Messages.NewStreaming(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("Count from 1 to 5")), @@ -125,7 +112,7 @@ var _ = Describe("Anthropic API E2E test", func() { It("streams with system prompt", func() { stream := client.Messages.NewStreaming(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, System: []anthropic.TextBlockParam{ {Text: "You are a helpful assistant."}, @@ -150,25 +137,27 @@ var _ = Describe("Anthropic API E2E test", func() { Context("Tool calling", func() { It("handles tool calls in non-streaming mode", func() { message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather like in San Francisco?")), }, - Tools: []anthropic.ToolParam{ - { - Name: "get_weather", - Description: anthropic.F("Get the current weather in a given location"), - InputSchema: anthropic.F(map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "location": map[string]interface{}{ - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", + Tools: []anthropic.ToolUnionParam{ + anthropic.ToolUnionParam{ + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.Opt("Get the current weather in a given location"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: constant.ValueOf[constant.Object](), + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, }, + Required: []string{"location"}, }, - "required": []string{"location"}, - }), + }, }, }, }) @@ -179,13 +168,14 @@ var _ = Describe("Anthropic API E2E test", func() { // The model must use tools - find the tool use in the response hasToolUse := false for _, block := range message.Content { - if block.Type == anthropic.ContentBlockTypeToolUse { + if block.Type == "tool_use" { hasToolUse = true Expect(block.Name).To(Equal("get_weather")) Expect(block.ID).ToNot(BeEmpty()) // Verify that input contains location - inputMap, ok := block.Input.(map[string]interface{}) - Expect(ok).To(BeTrue()) + var inputMap map[string]interface{} + err := json.Unmarshal(block.Input, &inputMap) + Expect(err).ToNot(HaveOccurred()) _, hasLocation := inputMap["location"] Expect(hasLocation).To(BeTrue()) } @@ -193,35 +183,37 @@ var _ = Describe("Anthropic API E2E test", func() { // Model must have called the tool Expect(hasToolUse).To(BeTrue(), "Model should have called the get_weather tool") - Expect(message.StopReason).To(Equal(anthropic.MessageStopReasonToolUse)) + Expect(string(message.StopReason)).To(Equal("tool_use")) }) It("handles tool_choice parameter", func() { message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("Tell me about the weather")), }, - Tools: []anthropic.ToolParam{ - { - Name: "get_weather", - Description: anthropic.F("Get the current weather"), - InputSchema: anthropic.F(map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "location": map[string]interface{}{ - "type": "string", + Tools: []anthropic.ToolUnionParam{ + anthropic.ToolUnionParam{ + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.Opt("Get the current weather"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: constant.ValueOf[constant.Object](), + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + }, }, }, - }), + }, }, }, - ToolChoice: anthropic.F[anthropic.ToolChoiceUnionParam]( - anthropic.ToolChoiceAutoParam{ - Type: anthropic.F(anthropic.ToolChoiceAutoTypeAuto), + ToolChoice: anthropic.ToolChoiceUnionParam{ + OfAuto: &anthropic.ToolChoiceAutoParam{ + Type: constant.ValueOf[constant.Auto](), }, - ), + }, }) Expect(err).ToNot(HaveOccurred()) @@ -231,21 +223,23 @@ var _ = Describe("Anthropic API E2E test", func() { It("handles tool results in messages", func() { // First, make a request that should trigger a tool call firstMessage, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in SF?")), }, - Tools: []anthropic.ToolParam{ - { - Name: "get_weather", - Description: anthropic.F("Get weather"), - InputSchema: anthropic.F(map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "location": map[string]interface{}{"type": "string"}, + Tools: []anthropic.ToolUnionParam{ + anthropic.ToolUnionParam{ + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.Opt("Get weather"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: constant.ValueOf[constant.Object](), + Properties: map[string]interface{}{ + "location": map[string]interface{}{"type": "string"}, + }, }, - }), + }, }, }, }) @@ -256,7 +250,7 @@ var _ = Describe("Anthropic API E2E test", func() { var toolUseID string var toolUseName string for _, block := range firstMessage.Content { - if block.Type == anthropic.ContentBlockTypeToolUse { + if block.Type == "tool_use" { toolUseID = block.ID toolUseName = block.Name break @@ -266,27 +260,44 @@ var _ = Describe("Anthropic API E2E test", func() { // Model must have called the tool Expect(toolUseID).ToNot(BeEmpty(), "Model should have called the get_weather tool") + // Convert ContentBlockUnion to ContentBlockParamUnion for NewAssistantMessage + contentBlocks := make([]anthropic.ContentBlockParamUnion, len(firstMessage.Content)) + for i, block := range firstMessage.Content { + if block.Type == "tool_use" { + var inputMap map[string]interface{} + if err := json.Unmarshal(block.Input, &inputMap); err == nil { + contentBlocks[i] = anthropic.NewToolUseBlock(block.ID, inputMap, block.Name) + } else { + contentBlocks[i] = anthropic.NewToolUseBlock(block.ID, block.Input, block.Name) + } + } else if block.Type == "text" { + contentBlocks[i] = anthropic.NewTextBlock(block.Text) + } + } + // Send back a tool result and verify it's handled correctly secondMessage, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in SF?")), - anthropic.NewAssistantMessage(firstMessage.Content...), + anthropic.NewAssistantMessage(contentBlocks...), anthropic.NewUserMessage( anthropic.NewToolResultBlock(toolUseID, "Sunny, 72°F", false), ), }, - Tools: []anthropic.ToolParam{ - { - Name: toolUseName, - Description: anthropic.F("Get weather"), - InputSchema: anthropic.F(map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "location": map[string]interface{}{"type": "string"}, + Tools: []anthropic.ToolUnionParam{ + anthropic.ToolUnionParam{ + OfTool: &anthropic.ToolParam{ + Name: toolUseName, + Description: anthropic.Opt("Get weather"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: constant.ValueOf[constant.Object](), + Properties: map[string]interface{}{ + "location": map[string]interface{}{"type": "string"}, + }, }, - }), + }, }, }, }) @@ -297,32 +308,33 @@ var _ = Describe("Anthropic API E2E test", func() { It("handles tool calls in streaming mode", func() { stream := client.Messages.NewStreaming(context.TODO(), anthropic.MessageNewParams{ - Model: "gpt-4", + Model: "mock-model", MaxTokens: 1024, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather like in San Francisco?")), }, - Tools: []anthropic.ToolParam{ - { - Name: "get_weather", - Description: anthropic.F("Get the current weather in a given location"), - InputSchema: anthropic.F(map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "location": map[string]interface{}{ - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", + Tools: []anthropic.ToolUnionParam{ + anthropic.ToolUnionParam{ + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.Opt("Get the current weather in a given location"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: constant.ValueOf[constant.Object](), + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, }, + Required: []string{"location"}, }, - "required": []string{"location"}, - }), + }, }, }, }) message := anthropic.Message{} eventCount := 0 - hasToolUseBlock := false hasContentBlockStart := false hasContentBlockDelta := false hasContentBlockStop := false @@ -337,8 +349,8 @@ var _ = Describe("Anthropic API E2E test", func() { switch e := event.AsAny().(type) { case anthropic.ContentBlockStartEvent: hasContentBlockStart = true - if e.ContentBlock.Type == anthropic.ContentBlockTypeToolUse { - hasToolUseBlock = true + if e.ContentBlock.Type == "tool_use" { + // Tool use block detected } case anthropic.ContentBlockDeltaEvent: hasContentBlockDelta = true @@ -357,18 +369,18 @@ var _ = Describe("Anthropic API E2E test", func() { // Check accumulated message has tool use Expect(message.Content).ToNot(BeEmpty()) - + // Model must have called the tool foundToolUse := false for _, block := range message.Content { - if block.Type == anthropic.ContentBlockTypeToolUse { + if block.Type == "tool_use" { foundToolUse = true Expect(block.Name).To(Equal("get_weather")) Expect(block.ID).ToNot(BeEmpty()) } } Expect(foundToolUse).To(BeTrue(), "Model should have called the get_weather tool in streaming mode") - Expect(message.StopReason).To(Equal(anthropic.MessageStopReasonToolUse)) + Expect(string(message.StopReason)).To(Equal("tool_use")) }) }) }) diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go index f6ab238df..85d136d3e 100644 --- a/tests/e2e/e2e_suite_test.go +++ b/tests/e2e/e2e_suite_test.go @@ -1,17 +1,170 @@ package e2e_test import ( + "context" + "fmt" + "net/http" "os" + "path/filepath" "testing" + "time" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + httpapi "github.com/mudler/LocalAI/core/http" + "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/phayes/freeport" + "github.com/sashabaranov/go-openai" + "gopkg.in/yaml.v3" + + "github.com/mudler/xlog" ) var ( - localAIURL = os.Getenv("LOCALAI_API") + localAIURL string + anthropicBaseURL string + tmpDir string + backendPath string + modelsPath string + configPath string + app *echo.Echo + appCtx context.Context + appCancel context.CancelFunc + client *openai.Client + apiPort int + apiURL string + mockBackendPath string ) +var _ = BeforeSuite(func() { + var err error + + // Create temporary directory + tmpDir, err = os.MkdirTemp("", "mock-backend-e2e-*") + Expect(err).ToNot(HaveOccurred()) + + backendPath = filepath.Join(tmpDir, "backends") + modelsPath = filepath.Join(tmpDir, "models") + Expect(os.MkdirAll(backendPath, 0755)).To(Succeed()) + Expect(os.MkdirAll(modelsPath, 0755)).To(Succeed()) + + // Build mock backend + mockBackendDir := filepath.Join("..", "e2e", "mock-backend") + mockBackendPath = filepath.Join(backendPath, "mock-backend") + + // Check if mock-backend binary exists in the mock-backend directory + possiblePaths := []string{ + filepath.Join(mockBackendDir, "mock-backend"), + filepath.Join("tests", "e2e", "mock-backend", "mock-backend"), + filepath.Join("..", "..", "tests", "e2e", "mock-backend", "mock-backend"), + } + + found := false + for _, p := range possiblePaths { + if _, err := os.Stat(p); err == nil { + mockBackendPath = p + found = true + break + } + } + + if !found { + // Try to find it relative to current working directory + wd, _ := os.Getwd() + relPath := filepath.Join(wd, "..", "..", "tests", "e2e", "mock-backend", "mock-backend") + if _, err := os.Stat(relPath); err == nil { + mockBackendPath = relPath + found = true + } + } + + Expect(found).To(BeTrue(), "mock-backend binary not found. Run 'make build-mock-backend' first") + + // Make sure it's executable + Expect(os.Chmod(mockBackendPath, 0755)).To(Succeed()) + + // Create model config YAML + modelConfig := map[string]interface{}{ + "name": "mock-model", + "backend": "mock-backend", + "parameters": map[string]interface{}{ + "model": "mock-model.bin", + }, + } + configPath = filepath.Join(modelsPath, "mock-model.yaml") + configYAML, err := yaml.Marshal(modelConfig) + Expect(err).ToNot(HaveOccurred()) + Expect(os.WriteFile(configPath, configYAML, 0644)).To(Succeed()) + + // Set up system state + systemState, err := system.GetSystemState( + system.WithBackendPath(backendPath), + system.WithModelPath(modelsPath), + ) + Expect(err).ToNot(HaveOccurred()) + + // Create application + appCtx, appCancel = context.WithCancel(context.Background()) + + // Create application instance + application, err := application.New( + config.WithContext(appCtx), + config.WithSystemState(systemState), + config.WithDebug(true), + ) + Expect(err).ToNot(HaveOccurred()) + + // Register backend with application's model loader + application.ModelLoader().SetExternalBackend("mock-backend", mockBackendPath) + + // Create HTTP app + app, err = httpapi.API(application) + Expect(err).ToNot(HaveOccurred()) + + // Get free port + port, err := freeport.GetFreePort() + Expect(err).ToNot(HaveOccurred()) + apiPort = port + apiURL = fmt.Sprintf("http://127.0.0.1:%d/v1", apiPort) + localAIURL = apiURL + // Anthropic SDK appends /v1/messages to base URL; use base without /v1 so requests go to /v1/messages + anthropicBaseURL = fmt.Sprintf("http://127.0.0.1:%d", apiPort) + + // Start server in goroutine + go func() { + if err := app.Start(fmt.Sprintf("127.0.0.1:%d", apiPort)); err != nil && err != http.ErrServerClosed { + xlog.Error("server error", "error", err) + } + }() + + // Wait for server to be ready + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = apiURL + client = openai.NewClientWithConfig(defaultConfig) + + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) +}) + +var _ = AfterSuite(func() { + if appCancel != nil { + appCancel() + } + if app != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + Expect(app.Shutdown(ctx)).To(Succeed()) + } + if tmpDir != "" { + os.RemoveAll(tmpDir) + } +}) + func TestLocalAI(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "LocalAI E2E test suite") diff --git a/tests/e2e/e2e_test.go b/tests/e2e/e2e_test.go deleted file mode 100644 index 7b506e609..000000000 --- a/tests/e2e/e2e_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package e2e_test - -import ( - "context" - "fmt" - "os" - "os/exec" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - - openaigo "github.com/otiai10/openaigo" - "github.com/sashabaranov/go-openai" -) - -var _ = Describe("E2E test", func() { - var client *openai.Client - var client2 *openaigo.Client - - Context("API with ephemeral models", func() { - BeforeEach(func() { - defaultConfig := openai.DefaultConfig("") - defaultConfig.BaseURL = localAIURL - - client2 = openaigo.NewClient("") - client2.BaseURL = defaultConfig.BaseURL - - // Wait for API to be ready - client = openai.NewClientWithConfig(defaultConfig) - Eventually(func() error { - _, err := client.ListModels(context.TODO()) - return err - }, "2m").ShouldNot(HaveOccurred()) - }) - - // Check that the GPU was used - AfterEach(func() { - cmd := exec.Command("/bin/bash", "-xce", "docker logs $(docker ps -q --filter ancestor=localai-tests)") - out, err := cmd.CombinedOutput() - Expect(err).ToNot(HaveOccurred(), string(out)) - // Execute docker logs $$(docker ps -q --filter ancestor=localai-tests) as a command and check the output - if os.Getenv("BUILD_TYPE") == "cublas" { - - Expect(string(out)).To(ContainSubstring("found 1 CUDA devices"), string(out)) - Expect(string(out)).To(ContainSubstring("using CUDA for GPU acceleration"), string(out)) - } else { - fmt.Println("Skipping GPU check") - Expect(string(out)).To(ContainSubstring("[llama-cpp] Loads OK"), string(out)) - Expect(string(out)).To(ContainSubstring("llama_model_loader"), string(out)) - } - }) - - Context("Generates text", func() { - It("streams chat tokens", func() { - model := "gpt-4" - resp, err := client.CreateChatCompletion(context.TODO(), - openai.ChatCompletionRequest{ - Model: model, Messages: []openai.ChatCompletionMessage{ - { - Role: "user", - Content: "How much is 2+2?", - }, - }}) - Expect(err).ToNot(HaveOccurred()) - Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) - Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content)) - }) - }) - }) -}) diff --git a/tests/e2e/mock-backend/.gitignore b/tests/e2e/mock-backend/.gitignore new file mode 100644 index 000000000..32923bce6 --- /dev/null +++ b/tests/e2e/mock-backend/.gitignore @@ -0,0 +1 @@ +mock-backend \ No newline at end of file diff --git a/tests/e2e/mock-backend/Makefile b/tests/e2e/mock-backend/Makefile new file mode 100644 index 000000000..02ef0a46b --- /dev/null +++ b/tests/e2e/mock-backend/Makefile @@ -0,0 +1,7 @@ +.PHONY: build clean + +build: + cd ../../.. && go build -o tests/e2e/mock-backend/mock-backend ./tests/e2e/mock-backend + +clean: + rm -f mock-backend diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go new file mode 100644 index 000000000..824474613 --- /dev/null +++ b/tests/e2e/mock-backend/main.go @@ -0,0 +1,334 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "net" + "os" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" + "google.golang.org/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +// MockBackend implements the Backend gRPC service with mocked responses +type MockBackend struct { + pb.UnimplementedBackendServer +} + +func (m *MockBackend) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { + xlog.Debug("Health check called") + return &pb.Reply{Message: []byte("OK")}, nil +} + +func (m *MockBackend) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { + xlog.Debug("LoadModel called", "model", in.Model) + return &pb.Result{ + Message: "Model loaded successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { + xlog.Debug("Predict called", "prompt", in.Prompt) + var response string + toolName := mockToolNameFromRequest(in) + if toolName != "" { + response = fmt.Sprintf(`{"name": "%s", "arguments": {"location": "San Francisco"}}`, toolName) + } else { + response = "This is a mocked response." + } + return &pb.Reply{ + Message: []byte(response), + Tokens: 10, + PromptTokens: 5, + TimingPromptProcessing: 0.1, + TimingTokenGeneration: 0.2, + }, nil +} + +func (m *MockBackend) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { + xlog.Debug("PredictStream called", "prompt", in.Prompt) + var toStream string + toolName := mockToolNameFromRequest(in) + if toolName != "" { + toStream = fmt.Sprintf(`{"name": "%s", "arguments": {"location": "San Francisco"}}`, toolName) + } else { + toStream = "This is a mocked streaming response." + } + for i, r := range toStream { + if err := stream.Send(&pb.Reply{ + Message: []byte(string(r)), + Tokens: int32(i + 1), + }); err != nil { + return err + } + } + return nil +} + +// mockToolNameFromRequest returns the first tool name from the request's Tools JSON (same as other endpoints). +func mockToolNameFromRequest(in *pb.PredictOptions) string { + if in.Tools == "" { + return "" + } + var tools []struct { + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal([]byte(in.Tools), &tools); err != nil || len(tools) == 0 || tools[0].Function.Name == "" { + return "" + } + return tools[0].Function.Name +} + +func (m *MockBackend) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { + xlog.Debug("Embedding called", "prompt", in.Prompt) + // Return a mock embedding vector of 768 dimensions + embeddings := make([]float32, 768) + for i := range embeddings { + embeddings[i] = float32(i%100) / 100.0 // Pattern: 0.0, 0.01, 0.02, ..., 0.99, 0.0, ... + } + return &pb.EmbeddingResult{Embeddings: embeddings}, nil +} + +func (m *MockBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { + xlog.Debug("GenerateImage called", "prompt", in.PositivePrompt) + return &pb.Result{ + Message: "Image generated successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest) (*pb.Result, error) { + xlog.Debug("GenerateVideo called", "prompt", in.Prompt) + return &pb.Result{ + Message: "Video generated successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { + xlog.Debug("TTS called", "text", in.Text) + // Return success - actual audio would be in the Result message for real backends + return &pb.Result{ + Message: "TTS audio generated successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error { + xlog.Debug("TTSStream called", "text", in.Text) + // Stream mock audio chunks (simplified - just send a few bytes) + chunks := [][]byte{ + {0x52, 0x49, 0x46, 0x46}, // Mock WAV header start + {0x57, 0x41, 0x56, 0x45}, // Mock WAV header + {0x64, 0x61, 0x74, 0x61}, // Mock data chunk + } + for _, chunk := range chunks { + if err := stream.Send(&pb.Reply{Audio: chunk}); err != nil { + return err + } + } + return nil +} + +func (m *MockBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) { + xlog.Debug("SoundGeneration called", "text", in.Text) + return &pb.Result{ + Message: "Sound generated successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { + xlog.Debug("AudioTranscription called") + return &pb.TranscriptResult{ + Text: "This is a mocked transcription.", + Segments: []*pb.TranscriptSegment{ + { + Id: 0, + Start: 0, + End: 3000, + Text: "This is a mocked transcription.", + Tokens: []int32{1, 2, 3, 4, 5, 6}, + }, + }, + }, nil +} + +func (m *MockBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) { + xlog.Debug("TokenizeString called", "prompt", in.Prompt) + // Return mock token IDs + tokens := []int32{101, 2023, 2003, 1037, 3231, 1012} + return &pb.TokenizationResponse{ + Length: int32(len(tokens)), + Tokens: tokens, + }, nil +} + +func (m *MockBackend) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) { + xlog.Debug("Status called") + return &pb.StatusResponse{ + State: pb.StatusResponse_READY, + Memory: &pb.MemoryUsageData{ + Total: 1024 * 1024 * 100, // 100MB + Breakdown: map[string]uint64{ + "mock": 1024 * 1024 * 50, + }, + }, + }, nil +} + +func (m *MockBackend) Detect(ctx context.Context, in *pb.DetectOptions) (*pb.DetectResponse, error) { + xlog.Debug("Detect called", "src", in.Src) + return &pb.DetectResponse{ + Detections: []*pb.Detection{ + { + X: 10.0, + Y: 20.0, + Width: 100.0, + Height: 200.0, + Confidence: 0.95, + ClassName: "mocked_object", + }, + }, + }, nil +} + +func (m *MockBackend) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) { + xlog.Debug("StoresSet called", "keys", len(in.Keys)) + return &pb.Result{ + Message: "Keys set successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) { + xlog.Debug("StoresDelete called", "keys", len(in.Keys)) + return &pb.Result{ + Message: "Keys deleted successfully (mocked)", + Success: true, + }, nil +} + +func (m *MockBackend) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) { + xlog.Debug("StoresGet called", "keys", len(in.Keys)) + // Return mock keys and values + keys := make([]*pb.StoresKey, len(in.Keys)) + values := make([]*pb.StoresValue, len(in.Keys)) + for i := range in.Keys { + keys[i] = in.Keys[i] + values[i] = &pb.StoresValue{ + Bytes: []byte(fmt.Sprintf("mocked_value_%d", i)), + } + } + return &pb.StoresGetResult{ + Keys: keys, + Values: values, + }, nil +} + +func (m *MockBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) { + xlog.Debug("StoresFind called", "topK", in.TopK) + // Return mock similar keys + keys := []*pb.StoresKey{ + {Floats: []float32{0.1, 0.2, 0.3}}, + {Floats: []float32{0.4, 0.5, 0.6}}, + } + values := []*pb.StoresValue{ + {Bytes: []byte("mocked_value_1")}, + {Bytes: []byte("mocked_value_2")}, + } + similarities := []float32{0.95, 0.85} + return &pb.StoresFindResult{ + Keys: keys, + Values: values, + Similarities: similarities, + }, nil +} + +func (m *MockBackend) Rerank(ctx context.Context, in *pb.RerankRequest) (*pb.RerankResult, error) { + xlog.Debug("Rerank called", "query", in.Query, "documents", len(in.Documents)) + // Return mock reranking results + results := make([]*pb.DocumentResult, len(in.Documents)) + for i, doc := range in.Documents { + results[i] = &pb.DocumentResult{ + Index: int32(i), + Text: doc, + RelevanceScore: 0.9 - float32(i)*0.1, // Decreasing scores + } + } + return &pb.RerankResult{ + Usage: &pb.Usage{ + TotalTokens: int32(len(in.Documents) * 10), + PromptTokens: int32(len(in.Documents) * 10), + }, + Results: results, + }, nil +} + +func (m *MockBackend) GetMetrics(ctx context.Context, in *pb.MetricsRequest) (*pb.MetricsResponse, error) { + xlog.Debug("GetMetrics called") + return &pb.MetricsResponse{ + SlotId: 0, + PromptJsonForSlot: `{"prompt":"mocked"}`, + TokensPerSecond: 10.0, + TokensGenerated: 100, + PromptTokensProcessed: 50, + }, nil +} + +func (m *MockBackend) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) { + xlog.Debug("VAD called", "audio_length", len(in.Audio)) + return &pb.VADResponse{ + Segments: []*pb.VADSegment{ + { + Start: 0.0, + End: 1.5, + }, + { + Start: 2.0, + End: 3.5, + }, + }, + }, nil +} + +func (m *MockBackend) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) { + xlog.Debug("ModelMetadata called", "model", in.Model) + return &pb.ModelMetadataResponse{ + SupportsThinking: false, + RenderedTemplate: "", + }, nil +} + +func main() { + xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), os.Getenv("LOCALAI_LOG_FORMAT"))) + + flag.Parse() + + lis, err := net.Listen("tcp", *addr) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + + s := grpc.NewServer( + grpc.MaxRecvMsgSize(50*1024*1024), // 50MB + grpc.MaxSendMsgSize(50*1024*1024), // 50MB + ) + pb.RegisterBackendServer(s, &MockBackend{}) + + xlog.Info("Mock gRPC Server listening", "address", lis.Addr()) + if err := s.Serve(lis); err != nil { + log.Fatalf("failed to serve: %v", err) + } +} diff --git a/tests/e2e/mock_backend_test.go b/tests/e2e/mock_backend_test.go new file mode 100644 index 000000000..0585209dd --- /dev/null +++ b/tests/e2e/mock_backend_test.go @@ -0,0 +1,185 @@ +package e2e_test + +import ( + "context" + "io" + "net/http" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/sashabaranov/go-openai" +) + +var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { + Describe("Text Generation APIs", func() { + Context("Predict (Chat Completions)", func() { + It("should return mocked response", func() { + resp, err := client.CreateChatCompletion( + context.TODO(), + openai.ChatCompletionRequest{ + Model: "mock-model", + Messages: []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "Hello", + }, + }, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Message.Content).To(ContainSubstring("mocked response")) + }) + }) + + Context("PredictStream (Streaming Chat Completions)", func() { + It("should stream mocked tokens", func() { + stream, err := client.CreateChatCompletionStream( + context.TODO(), + openai.ChatCompletionRequest{ + Model: "mock-model", + Messages: []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "Hello", + }, + }, + }, + ) + Expect(err).ToNot(HaveOccurred()) + defer stream.Close() + + hasContent := false + for { + response, err := stream.Recv() + if err != nil { + break + } + if len(response.Choices) > 0 && response.Choices[0].Delta.Content != "" { + hasContent = true + } + } + Expect(hasContent).To(BeTrue()) + }) + }) + }) + + Describe("Embeddings API", func() { + It("should return mocked embeddings", func() { + resp, err := client.CreateEmbeddings( + context.TODO(), + openai.EmbeddingRequest{ + Model: "mock-model", + Input: []string{"test"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Data)).To(Equal(1)) + Expect(len(resp.Data[0].Embedding)).To(Equal(768)) + }) + }) + + Describe("TTS APIs", func() { + Context("TTS", func() { + It("should generate mocked audio", func() { + req, err := http.NewRequest("POST", apiURL+"/audio/speech", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + body := `{"model":"mock-model","input":"Hello world","voice":"default"}` + req.Body = http.NoBody + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader(body)), nil + } + + // Use direct HTTP client for TTS endpoint + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + if err == nil { + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + } + }) + }) + }) + + Describe("Image Generation API", func() { + It("should generate mocked image", func() { + req, err := http.NewRequest("POST", apiURL+"/images/generations", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + body := `{"model":"mock-model","prompt":"a cat"}` + req.Body = http.NoBody + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader(body)), nil + } + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + if err == nil { + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + } + }) + }) + + Describe("Audio Transcription API", func() { + It("should return mocked transcription", func() { + req, err := http.NewRequest("POST", apiURL+"/audio/transcriptions", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "multipart/form-data") + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + if err == nil { + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + } + }) + }) + + Describe("Rerank API", func() { + It("should return mocked reranking results", func() { + req, err := http.NewRequest("POST", apiURL+"/rerank", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + body := `{"model":"mock-model","query":"test","documents":["doc1","doc2"]}` + req.Body = http.NoBody + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader(body)), nil + } + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + if err == nil { + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + } + }) + }) + + Describe("Tokenization API", func() { + It("should return mocked tokens", func() { + req, err := http.NewRequest("POST", apiURL+"/tokenize", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + body := `{"model":"mock-model","text":"Hello world"}` + req.Body = http.NoBody + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader(body)), nil + } + + httpClient := &http.Client{Timeout: 30 * time.Second} + resp, err := httpClient.Do(req) + if err == nil { + defer resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 500)) + } + }) + }) +})