mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-01 19:09:15 -05:00
121 lines
3.3 KiB
Go
121 lines
3.3 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/pkg/signals"
|
|
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
"github.com/mudler/xlog"
|
|
)
|
|
|
|
type sessionCache struct {
|
|
mu sync.Mutex
|
|
cache map[string][]*mcp.ClientSession
|
|
}
|
|
|
|
var (
|
|
cache = sessionCache{
|
|
cache: make(map[string][]*mcp.ClientSession),
|
|
}
|
|
|
|
client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)
|
|
)
|
|
|
|
func SessionsFromMCPConfig(
|
|
name string,
|
|
remote config.MCPGenericConfig[config.MCPRemoteServers],
|
|
stdio config.MCPGenericConfig[config.MCPSTDIOServers],
|
|
) ([]*mcp.ClientSession, error) {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
|
|
sessions, exists := cache.cache[name]
|
|
if exists {
|
|
return sessions, nil
|
|
}
|
|
|
|
allSessions := []*mcp.ClientSession{}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Get the list of all the tools that the Agent will be esposed to
|
|
for _, server := range remote.Servers {
|
|
xlog.Debug("[MCP remote server] Configuration", "server", server)
|
|
// Create HTTP client with custom roundtripper for bearer token injection
|
|
httpClient := &http.Client{
|
|
Timeout: 360 * time.Second,
|
|
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
|
}
|
|
|
|
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
|
|
mcpSession, err := client.Connect(ctx, transport, nil)
|
|
if err != nil {
|
|
xlog.Error("Failed to connect to MCP server", "error", err, "url", server.URL)
|
|
continue
|
|
}
|
|
xlog.Debug("[MCP remote server] Connected to MCP server", "url", server.URL)
|
|
cache.cache[name] = append(cache.cache[name], mcpSession)
|
|
allSessions = append(allSessions, mcpSession)
|
|
}
|
|
|
|
for _, server := range stdio.Servers {
|
|
xlog.Debug("[MCP stdio server] Configuration", "server", server)
|
|
command := exec.Command(server.Command, server.Args...)
|
|
command.Env = os.Environ()
|
|
for key, value := range server.Env {
|
|
command.Env = append(command.Env, key+"="+value)
|
|
}
|
|
transport := &mcp.CommandTransport{Command: command}
|
|
mcpSession, err := client.Connect(ctx, transport, nil)
|
|
if err != nil {
|
|
xlog.Error("Failed to start MCP server", "error", err, "command", command)
|
|
continue
|
|
}
|
|
xlog.Debug("[MCP stdio server] Connected to MCP server", "command", command)
|
|
cache.cache[name] = append(cache.cache[name], mcpSession)
|
|
allSessions = append(allSessions, mcpSession)
|
|
}
|
|
|
|
signals.RegisterGracefulTerminationHandler(func() {
|
|
for _, session := range allSessions {
|
|
session.Close()
|
|
}
|
|
cancel()
|
|
})
|
|
|
|
return allSessions, nil
|
|
}
|
|
|
|
// bearerTokenRoundTripper is an http.RoundTripper that adds an
// "Authorization: Bearer <token>" header to outgoing requests before
// delegating to base. An empty token disables the injection.
type bearerTokenRoundTripper struct {
	token string
	base  http.RoundTripper
}

// Compile-time check that the type satisfies http.RoundTripper.
var _ http.RoundTripper = (*bearerTokenRoundTripper)(nil)

// RoundTrip implements the http.RoundTripper interface.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the header is set on a clone and req itself is left untouched. When token
// is empty, req is forwarded to base unchanged.
func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if rt.token == "" {
		return rt.base.RoundTrip(req)
	}
	authed := req.Clone(req.Context())
	// Clone keeps a nil Header as nil; writing to a nil map would panic.
	if authed.Header == nil {
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", "Bearer "+rt.token)
	return rt.base.RoundTrip(authed)
}

// newBearerTokenRoundTripper returns a RoundTripper that injects token as a
// bearer Authorization header and delegates to base. A nil base falls back
// to http.DefaultTransport.
func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
	if base == nil {
		base = http.DefaultTransport
	}
	return &bearerTokenRoundTripper{
		token: token,
		base:  base,
	}
}
|