Files
LocalAI/core/http/endpoints/mcp/tools.go
2025-12-21 19:33:13 +01:00

121 lines
3.3 KiB
Go

package mcp
import (
"context"
"net/http"
"os"
"os/exec"
"sync"
"time"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/signals"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/mudler/xlog"
)
// sessionCache memoizes MCP client sessions keyed by the name passed to
// SessionsFromMCPConfig. mu must be held while reading or writing cache.
type sessionCache struct {
	mu    sync.Mutex
	cache map[string][]*mcp.ClientSession
}

// cache is the package-wide session cache consulted by SessionsFromMCPConfig.
var cache = sessionCache{cache: map[string][]*mcp.ClientSession{}}

// client is the single MCP client every session is opened from. It is built
// with Implementation Name "LocalAI" and Version "v1.0.0" and no extra options.
var client = mcp.NewClient(&mcp.Implementation{
	Name:    "LocalAI",
	Version: "v1.0.0",
}, nil)
// SessionsFromMCPConfig returns the MCP client sessions associated with name.
// The first time a name is seen, it connects to every server listed in remote
// (streamable HTTP) and stdio (spawned subprocesses).
//
// Results are cached per name under cache.mu. Once at least one session has
// been established for name, later calls return the cached slice without
// looking at remote or stdio again. Servers that fail to connect are logged
// and skipped, so the result may hold fewer sessions than configured servers.
// The returned error is always nil in the current implementation.
//
// If no server connects, nothing is cached (a later call retries), the
// per-call context is cancelled right away and no shutdown handler is
// registered. Otherwise a graceful-termination handler is registered that
// closes the sessions opened by this call and then cancels their context.
//
// The cache lock is held for the whole connection phase, so concurrent
// callers (for any name) block until it completes.
func SessionsFromMCPConfig(
	name string,
	remote config.MCPGenericConfig[config.MCPRemoteServers],
	stdio config.MCPGenericConfig[config.MCPSTDIOServers],
) ([]*mcp.ClientSession, error) {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	if sessions, exists := cache.cache[name]; exists {
		return sessions, nil
	}

	allSessions := []*mcp.ClientSession{}
	// ctx is handed to every client.Connect call below. It is cancelled from
	// the termination handler, or immediately when no session was opened.
	ctx, cancel := context.WithCancel(context.Background())

	// Connect to every server whose tools the agent will be exposed to.
	for _, server := range remote.Servers {
		xlog.Debug("[MCP remote server] Configuration", "server", server)
		// Each remote server gets its own HTTP client so that its bearer
		// token is injected into every request sent to it.
		httpClient := &http.Client{
			Timeout:   360 * time.Second,
			Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
		}
		transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
		mcpSession, err := client.Connect(ctx, transport, nil)
		if err != nil {
			xlog.Error("Failed to connect to MCP server", "error", err, "url", server.URL)
			continue
		}
		xlog.Debug("[MCP remote server] Connected to MCP server", "url", server.URL)
		allSessions = append(allSessions, mcpSession)
	}

	for _, server := range stdio.Servers {
		xlog.Debug("[MCP stdio server] Configuration", "server", server)
		command := exec.Command(server.Command, server.Args...)
		// Start from the parent's environment and append server.Env after it.
		// os/exec uses the last value for a duplicated key, so the configured
		// entries take precedence.
		command.Env = os.Environ()
		for key, value := range server.Env {
			command.Env = append(command.Env, key+"="+value)
		}
		transport := &mcp.CommandTransport{Command: command}
		mcpSession, err := client.Connect(ctx, transport, nil)
		if err != nil {
			xlog.Error("Failed to start MCP server", "error", err, "command", command)
			continue
		}
		xlog.Debug("[MCP stdio server] Connected to MCP server", "command", command)
		allSessions = append(allSessions, mcpSession)
	}

	if len(allSessions) == 0 {
		// Nothing to cache or clean up. Release the context now rather than
		// registering yet another termination handler on every retry.
		cancel()
		return allSessions, nil
	}

	cache.cache[name] = allSessions
	signals.RegisterGracefulTerminationHandler(func() {
		for _, session := range allSessions {
			if err := session.Close(); err != nil {
				// Best effort during shutdown: record the failure and keep
				// closing the remaining sessions.
				xlog.Debug("[MCP] Failed to close MCP session", "error", err)
			}
		}
		cancel()
	})
	return allSessions, nil
}
// bearerTokenRoundTripper is an http.RoundTripper that adds an
// "Authorization: Bearer <token>" header to outgoing requests and then
// delegates to base. An empty token disables the header injection.
type bearerTokenRoundTripper struct {
	token string            // bearer token; "" means no header is added
	base  http.RoundTripper // transport that sends the request; nil means http.DefaultTransport
}

// RoundTrip implements the http.RoundTripper interface.
//
// The http.RoundTripper contract says implementations must not modify the
// request they are given. So the header is set on a clone of req and the
// clone is forwarded instead. Request.Clone copies the body only shallowly,
// so both requests share it.
func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	base := rt.base
	if base == nil {
		base = http.DefaultTransport
	}
	if rt.token == "" {
		return base.RoundTrip(req)
	}

	authReq := req.Clone(req.Context())
	if authReq.Header == nil {
		// Header.Set would panic on a nil map.
		authReq.Header = make(http.Header)
	}
	authReq.Header.Set("Authorization", "Bearer "+rt.token)
	return base.RoundTrip(authReq)
}

// newBearerTokenRoundTripper returns a round tripper that injects token as a
// bearer Authorization header before delegating to base. A nil base falls
// back to http.DefaultTransport. An empty token forwards requests unchanged.
func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
	if base == nil {
		base = http.DefaultTransport
	}
	return &bearerTokenRoundTripper{
		token: token,
		base:  base,
	}
}