Compare commits

...

4 Commits

Author SHA1 Message Date
ParthSareen
b91c1f6749 update tests 2025-10-03 14:49:49 -07:00
ParthSareen
4f45f39bc6 remove auth for tests 2025-10-03 14:14:28 -07:00
ParthSareen
03e1d64aac add tests 2025-10-01 13:26:52 -07:00
ParthSareen
f88174c55d routes/client: add web search and fetch 2025-10-01 13:08:57 -07:00
5 changed files with 405 additions and 4 deletions

View File

@@ -414,6 +414,24 @@ func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse,
return &resp, nil
}
// WebSearch performs a web search via the Ollama server.
func (c *Client) WebSearch(ctx context.Context, req *WebSearchRequest) (*WebSearchResponse, error) {
var resp WebSearchResponse
if err := c.do(ctx, http.MethodPost, "/api/web_search", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// WebFetch fetches the contents of a web page via the Ollama server.
func (c *Client) WebFetch(ctx context.Context, req *WebFetchRequest) (*WebFetchResponse, error) {
var resp WebFetchResponse
if err := c.do(ctx, http.MethodPost, "/api/web_fetch", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Embeddings generates an embedding from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse

View File

@@ -262,3 +262,135 @@ func TestClientDo(t *testing.T) {
})
}
}
func TestClientWebSearch(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if !strings.HasSuffix(r.URL.Path, "/api/web_search") {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var req WebSearchRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("failed to decode request: %v", err)
}
if req.Query != "what is ollama" {
t.Fatalf("unexpected query: %s", req.Query)
}
if req.MaxResults != 3 {
t.Fatalf("unexpected max_results: %d", req.MaxResults)
}
resp := WebSearchResponse{
Results: []WebSearchResult{{
Title: "Ollama",
URL: "https://ollama.com",
Content: "Cloud models are now available...",
}},
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
u, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("parse server URL: %v", err)
}
client := NewClient(u, ts.Client())
resp, err := client.WebSearch(t.Context(), &WebSearchRequest{Query: "what is ollama", MaxResults: 3})
if err != nil {
t.Fatalf("WebSearch returned error: %v", err)
}
if len(resp.Results) != 1 {
t.Fatalf("expected 1 result, got %d", len(resp.Results))
}
if resp.Results[0].Title != "Ollama" {
t.Fatalf("unexpected title: %s", resp.Results[0].Title)
}
}
func TestClientWebSearchUnauthorized(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_ = json.NewEncoder(w).Encode(map[string]string{
"signin_url": "https://ollama.com/connect",
})
}))
defer ts.Close()
u, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("parse server URL: %v", err)
}
client := NewClient(u, ts.Client())
_, err = client.WebSearch(t.Context(), &WebSearchRequest{Query: "what is ollama"})
if err == nil {
t.Fatal("expected error, got nil")
}
if _, ok := err.(AuthorizationError); !ok {
t.Fatalf("expected AuthorizationError, got %T", err)
}
}
func TestClientWebFetch(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if !strings.HasSuffix(r.URL.Path, "/api/web_fetch") {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var req WebFetchRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("failed to decode request: %v", err)
}
if req.URL != "https://ollama.com" {
t.Fatalf("unexpected url: %s", req.URL)
}
resp := WebFetchResponse{
Title: "Ollama",
Content: "Cloud models are now available...",
Links: []string{"https://ollama.com/models"},
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
u, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("parse server URL: %v", err)
}
client := NewClient(u, ts.Client())
resp, err := client.WebFetch(t.Context(), &WebFetchRequest{URL: "https://ollama.com"})
if err != nil {
t.Fatalf("WebFetch returned error: %v", err)
}
if resp.Title != "Ollama" {
t.Fatalf("unexpected title: %s", resp.Title)
}
if len(resp.Links) != 1 || resp.Links[0] != "https://ollama.com/models" {
t.Fatalf("unexpected links: %v", resp.Links)
}
}

View File

@@ -453,6 +453,40 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
// WebSearchRequest is the request passed to [Client.WebSearch].
type WebSearchRequest struct {
// Query is the search query string.
Query string `json:"query"`
// MaxResults is the optional maximum number of results to return (default 5, max 10).
MaxResults int `json:"max_results,omitempty"`
}
// WebSearchResult represents a single web search result.
type WebSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
// WebSearchResponse is the response from [Client.WebSearch].
type WebSearchResponse struct {
Results []WebSearchResult `json:"results"`
}
// WebFetchRequest is the request passed to [Client.WebFetch].
type WebFetchRequest struct {
// URL is the address of the page to fetch.
URL string `json:"url"`
}
// WebFetchResponse is the response from [Client.WebFetch].
type WebFetchResponse struct {
Title string `json:"title"`
Content string `json:"content"`
Links []string `json:"links"`
}
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
// Model is the model name to create.

View File

@@ -76,9 +76,20 @@ var lowVRAMThreshold uint64 = 20 * format.GibiByte
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
lowVRAM bool
addr net.Addr
sched *Scheduler
lowVRAM bool
cloudBaseURL *url.URL
}
func (s *Server) webServiceBase() *url.URL {
defaultWebServiceURL := url.URL{Scheme: "https", Host: "ollama.com"}
if s != nil && s.cloudBaseURL != nil {
u := *s.cloudBaseURL
return &u
}
u := defaultWebServiceURL
return &u
}
func init() {
@@ -767,6 +778,105 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func (s *Server) WebSearchHandler(c *gin.Context) {
var req api.WebSearchRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Query = strings.TrimSpace(req.Query)
if req.Query == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "query is required"})
return
}
if req.MaxResults != 0 && (req.MaxResults < 1 || req.MaxResults > 10) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "max_results must be between 1 and 10"})
return
}
webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
resp, err := webServiceClient.WebSearch(c.Request.Context(), &req)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
if resp == nil {
resp = &api.WebSearchResponse{}
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) WebFetchHandler(c *gin.Context) {
var req api.WebFetchRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
if req.URL == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "url is required"})
return
}
webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
resp, err := webServiceClient.WebFetch(c.Request.Context(), &req)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
if resp == nil {
resp = &api.WebFetchResponse{}
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) PullHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
@@ -1447,6 +1557,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/web_search", s.WebSearchHandler)
r.POST("/api/web_fetch", s.WebFetchHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)

View File

@@ -13,6 +13,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
@@ -92,6 +93,9 @@ func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
func TestRoutes(t *testing.T) {
// Disable authentication for tests to avoid issues with missing private keys
t.Setenv("OLLAMA_AUTH", "false")
type testCase struct {
Name string
Method string
@@ -139,6 +143,11 @@ func TestRoutes(t *testing.T) {
}
}
var (
searchRequests []api.WebSearchRequest
fetchRequests []api.WebFetchRequest
)
testCases := []testCase{
{
Name: "Version Handler",
@@ -455,6 +464,69 @@ func TestRoutes(t *testing.T) {
}
},
},
{
Name: "Web Search Handler",
Method: http.MethodPost,
Path: "/api/web_search",
Setup: func(t *testing.T, req *http.Request) {
searchRequests = nil
payload := api.WebSearchRequest{Query: "cats", MaxResults: 2}
data, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
req.Body = io.NopCloser(bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *http.Response) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.StatusCode)
}
var out api.WebSearchResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if len(out.Results) != 1 || out.Results[0].Title != "Result" {
t.Fatalf("unexpected response: %+v", out)
}
if len(searchRequests) != 1 {
t.Fatalf("expected 1 forwarded request, got %d", len(searchRequests))
}
if searchRequests[0].Query != "cats" || searchRequests[0].MaxResults != 2 {
t.Fatalf("unexpected forwarded request: %+v", searchRequests[0])
}
},
},
{
Name: "Web Fetch Handler",
Method: http.MethodPost,
Path: "/api/web_fetch",
Setup: func(t *testing.T, req *http.Request) {
fetchRequests = nil
payload := api.WebFetchRequest{URL: "https://example.com"}
data, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
req.Body = io.NopCloser(bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *http.Response) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.StatusCode)
}
var out api.WebFetchResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if out.Title != "Example" || len(out.Links) != 1 {
t.Fatalf("unexpected response: %+v", out)
}
if len(fetchRequests) != 1 || fetchRequests[0].URL != "https://example.com" {
t.Fatalf("unexpected forwarded request: %+v", fetchRequests)
}
},
},
{
Name: "openai retrieve model handler",
Setup: func(t *testing.T, req *http.Request) {
@@ -513,7 +585,40 @@ func TestRoutes(t *testing.T) {
HTTPClient: panicOnRoundTrip,
}
s := &Server{}
remoteSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/web_search":
var req api.WebSearchRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
searchRequests = append(searchRequests, req)
resp := api.WebSearchResponse{Results: []api.WebSearchResult{{Title: "Result", URL: "https://example.com", Content: "snippet"}}}
_ = json.NewEncoder(w).Encode(resp)
case "/api/web_fetch":
var req api.WebFetchRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
fetchRequests = append(fetchRequests, req)
resp := api.WebFetchResponse{Title: "Example", Content: "content", Links: []string{"https://example.com"}}
_ = json.NewEncoder(w).Encode(resp)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer remoteSrv.Close()
remoteURL, err := url.Parse(remoteSrv.URL)
if err != nil {
t.Fatalf("parse remote server URL: %v", err)
}
s := &Server{
cloudBaseURL: remoteURL,
}
router, err := s.GenerateRoutes(rc)
if err != nil {
t.Fatalf("failed to generate routes: %v", err)