mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-05 15:26:14 -04:00
* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED is set, and connect workers with scoped credentials from the register response. Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs. Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode. Assisted-by: Grok:grok grok-build Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(distributed): correct NATS JWT scoping and harden client auth The JWT-auth path added in 46467cc7 had several gaps that fail silently under LOCALAI_NATS_REQUIRE_AUTH: - Agent-worker minted JWTs did not allow the subjects the agent worker actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop), so MCP-CI jobs and backend-stop session cleanup were silently dropped. Scope the agent permission set to those subjects. - NATS subscription permission violations were swallowed (Subscribe returned a live-but-dead subscription). Confirm subscriptions with a server round-trip so a denial surfaces synchronously, and log async permission errors. - The backend worker connected anonymously when given a JWT without its paired seed; reject the unpaired credential instead. - The documented service-user permissions in nats-auth-setup.sh omitted prefixcache.>, which the frontend publishes and subscribes; add it. Also: add a credential-provider hook to the messaging client (consumed by the follow-up credential-lifecycle change), drop the always-nil error from NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct), and gofmt the feature's files. Tests: an agent-JWT e2e spec that connects to the enforcing NATS server and exercises every subscription the agent worker makes, plus permission allow-list coverage unit tests. 
Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(distributed): acquire and auto-refresh worker NATS credentials Workers fetched NATS credentials once at startup, which broke two cases under JWT auth: a worker that registered while still pending admin approval never received a minted JWT (it connected unauthenticated and gave up), and a long-running worker's 24h JWT expired with no way to renew it. Introduce workerregistry.NATSCredentialManager, built on idempotent re-registration (the frontend preserves the node row and mints a fresh JWT each call): - Acquire re-registers through admin approval until the node is approved and credentials are minted (or returns the first success when auth is not required, preserving anonymous-NATS behavior). - RefreshLoop re-registers before the JWT expires (~75% of its lifetime), updating the credentials served to the connection. - Both are bounded (default 100 attempts / consecutive failures) and return an error on exhaustion, so an unapprovable or unrenewable worker exits non-zero and surfaces the problem instead of hanging or drifting toward an expired credential. The messaging client gains WithUserJWTProvider, fetching credentials on each (re)connect so the connection transparently adopts a refreshed JWT when the server expires the old one. RegisterFull exposes the approval status and full response; Register delegates to it. Both the backend worker and the agent worker are wired to this: explicit env credentials are used as-is, minted credentials are acquired-with-wait and refreshed, and a permanent refresh failure shuts the worker down so it restarts and re-acquires. Tests cover Acquire (wait-through-pending, bounded give-up, context cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry exit) and jwtExpiry decoding. Docs updated in distributed-mode.md. 
Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> --------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
522 lines
20 KiB
Go
522 lines
20 KiB
Go
package cli
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
|
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
|
"github.com/mudler/LocalAI/core/config"
|
|
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
|
"github.com/mudler/LocalAI/core/services/agents"
|
|
"github.com/mudler/LocalAI/core/services/jobs"
|
|
mcpRemote "github.com/mudler/LocalAI/core/services/mcp"
|
|
"github.com/mudler/LocalAI/core/services/messaging"
|
|
"github.com/mudler/LocalAI/pkg/sanitize"
|
|
"github.com/mudler/cogito"
|
|
"github.com/mudler/cogito/clients"
|
|
"github.com/mudler/xlog"
|
|
)
|
|
|
|
// AgentWorkerCMD is the CLI command that runs a dedicated agent worker
// process for distributed mode.
//
// The worker registers with the frontend, subscribes to the NATS agent
// execution queue, and executes agent chats using cogito. It is a pure
// executor: the full agent config and skills arrive in the NATS job payload,
// so the worker does not need direct database access.
//
// Usage:
//
//	localai agent-worker --nats-url nats://... --register-to http://localai:8080
type AgentWorkerCMD struct {
	// NATS connection (required).
	NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`

	// Frontend registration (RegisterTo is required).
	RegisterTo        string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
	NodeName          string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`

	// Access to the LocalAI API used for inference.
	APIURL   string `env:"LOCALAI_API_URL" help:"LocalAI API URL for inference (auto-derived from RegisterTo if not set)" group:"api"`
	APIToken string `env:"LOCALAI_API_TOKEN" help:"API token for LocalAI inference (auto-provisioned during registration if not set)" group:"api"`

	// NATS subject and queue group for agent execution jobs.
	Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
	Queue   string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`

	// NATS authentication (JWT + seed pairs) and TLS material.
	NatsJWT         string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
	NatsUserSeed    string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
	NatsServiceJWT  string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
	NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
	NatsRequireAuth bool   `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
	NatsTLSCA       string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
	NatsTLSCert     string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
	NatsTLSKey      string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`

	// Timeouts, given as Go duration strings.
	MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
}
|
|
|
|
func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
|
xlog.Info("Starting agent worker", "nats", sanitize.URL(cmd.NatsURL), "register_to", cmd.RegisterTo)
|
|
|
|
// Resolve API URL
|
|
apiURL := cmp.Or(cmd.APIURL, strings.TrimRight(cmd.RegisterTo, "/"))
|
|
|
|
// Register with frontend
|
|
regClient := &workerregistry.RegistrationClient{
|
|
FrontendURL: cmd.RegisterTo,
|
|
RegistrationToken: cmd.RegistrationToken,
|
|
}
|
|
|
|
nodeName := cmd.NodeName
|
|
if nodeName == "" {
|
|
hostname, _ := os.Hostname()
|
|
nodeName = "agent-" + hostname
|
|
}
|
|
registrationBody := map[string]any{
|
|
"name": nodeName,
|
|
"node_type": "agent",
|
|
}
|
|
if cmd.RegistrationToken != "" {
|
|
registrationBody["token"] = cmd.RegistrationToken
|
|
}
|
|
|
|
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
|
// other background goroutines.
|
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
|
defer shutdownCancel()
|
|
|
|
// Acquire credentials via (re)registration. When the bus requires auth and no
|
|
// static fallback is configured, wait through admin approval until the
|
|
// frontend mints credentials rather than starting unauthenticated.
|
|
credMgr := workerregistry.NewNATSCredentialManager(
|
|
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
|
return regClient.RegisterFull(ctx, registrationBody)
|
|
},
|
|
cmd.NatsRequireAuth && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
|
|
)
|
|
res, err := credMgr.Acquire(shutdownCtx)
|
|
if err != nil {
|
|
return fmt.Errorf("registration failed: %w", err)
|
|
}
|
|
nodeID := res.ID
|
|
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
|
|
|
// Use provisioned API token if none was set
|
|
if cmd.APIToken == "" {
|
|
cmd.APIToken = res.APIToken
|
|
}
|
|
|
|
// Start heartbeat
|
|
heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval)
|
|
if err != nil && cmd.HeartbeatInterval != "" {
|
|
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
|
}
|
|
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
|
|
|
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
|
|
|
// Resolve NATS credentials with precedence: explicit env override, then
|
|
// frontend-minted (auto-refreshed before expiry), then service fallback.
|
|
// Each static source must supply JWT and seed together.
|
|
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
|
|
var natsOpts []messaging.Option
|
|
switch {
|
|
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
|
|
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
|
|
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
|
|
}
|
|
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
|
|
case credMgr.HasCredentials():
|
|
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
|
go func() {
|
|
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
|
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
|
|
shutdownCancel()
|
|
}
|
|
}()
|
|
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
|
|
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
|
|
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
|
|
}
|
|
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
|
|
case cmd.NatsRequireAuth:
|
|
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
|
|
}
|
|
if natsTLS.Enabled() {
|
|
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
|
|
}
|
|
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
|
|
if err != nil {
|
|
return fmt.Errorf("connecting to NATS: %w", err)
|
|
}
|
|
defer natsClient.Close()
|
|
|
|
// Create event bridge for publishing results back via NATS
|
|
eventBridge := agents.NewEventBridge(natsClient, nil, "agent-worker-"+nodeID)
|
|
|
|
// Start cancel listener
|
|
cancelSub, err := eventBridge.StartCancelListener()
|
|
if err != nil {
|
|
xlog.Warn("Failed to start cancel listener", "error", err)
|
|
} else {
|
|
defer cancelSub.Unsubscribe()
|
|
}
|
|
|
|
// Create and start the NATS dispatcher.
|
|
// No ConfigProvider or SkillStore needed — config and skills arrive in the job payload.
|
|
dispatcher := agents.NewNATSDispatcher(
|
|
natsClient,
|
|
eventBridge,
|
|
nil, // no ConfigProvider: config comes in the enriched NATS payload
|
|
apiURL, cmd.APIToken,
|
|
cmd.Subject, cmd.Queue,
|
|
0, // no concurrency limit (CLI worker)
|
|
)
|
|
|
|
if err := dispatcher.Start(shutdownCtx); err != nil {
|
|
return fmt.Errorf("starting dispatcher: %w", err)
|
|
}
|
|
|
|
// Subscribe to MCP tool execution requests (load-balanced across workers).
|
|
// The frontend routes model-level MCP tool calls here via NATS request-reply.
|
|
if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
|
|
handleMCPToolRequest(data, reply)
|
|
}); err != nil {
|
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPToolExecute, err)
|
|
}
|
|
|
|
// Subscribe to MCP discovery requests (load-balanced across workers).
|
|
if _, err := natsClient.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func(data []byte, reply func([]byte)) {
|
|
handleMCPDiscoveryRequest(data, reply)
|
|
}); err != nil {
|
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPDiscovery, err)
|
|
}
|
|
|
|
// Subscribe to MCP CI job execution (load-balanced across agent workers).
|
|
// In distributed mode, MCP CI jobs are routed here because the frontend
|
|
// cannot create MCP sessions (e.g., stdio servers using docker).
|
|
mcpCIJobTimeout, err := time.ParseDuration(cmd.MCPCIJobTimeout)
|
|
if err != nil && cmd.MCPCIJobTimeout != "" {
|
|
xlog.Warn("invalid MCP CI job timeout, using default 10m", "input", cmd.MCPCIJobTimeout, "error", err)
|
|
}
|
|
mcpCIJobTimeout = cmp.Or(mcpCIJobTimeout, config.DefaultMCPCIJobTimeout)
|
|
|
|
if _, err := natsClient.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func(data []byte) {
|
|
handleMCPCIJob(shutdownCtx, data, apiURL, cmd.APIToken, natsClient, mcpCIJobTimeout)
|
|
}); err != nil {
|
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectMCPCIJobsNew, err)
|
|
}
|
|
|
|
// Subscribe to backend stop events to clean up cached MCP sessions.
|
|
// In the main application this is done via ml.OnModelUnload, but the agent
|
|
// worker has no model loader — we listen for the NATS stop event instead.
|
|
if _, err := natsClient.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func(data []byte) {
|
|
var req struct {
|
|
Backend string `json:"backend"`
|
|
}
|
|
if json.Unmarshal(data, &req) == nil && req.Backend != "" {
|
|
mcpTools.CloseMCPSessions(req.Backend)
|
|
}
|
|
}); err != nil {
|
|
return fmt.Errorf("subscribing to %s: %w", messaging.SubjectNodeBackendStop(nodeID), err)
|
|
}
|
|
|
|
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
|
|
|
// Wait for an OS signal or an internal fatal condition (e.g. NATS
|
|
// credentials became unrenewable), so the worker restarts and re-acquires
|
|
// rather than lingering unable to serve.
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
var runErr error
|
|
select {
|
|
case <-sigCh:
|
|
case <-shutdownCtx.Done():
|
|
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
|
|
xlog.Error("Internal shutdown requested", "error", runErr)
|
|
}
|
|
|
|
xlog.Info("Shutting down agent worker")
|
|
shutdownCancel() // stop heartbeat loop immediately
|
|
dispatcher.Stop()
|
|
mcpTools.CloseAllMCPSessions()
|
|
regClient.GracefulDeregister(nodeID)
|
|
return runErr
|
|
}
|
|
|
|
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
|
// The worker creates/caches MCP sessions from the serialized config and executes the tool.
|
|
func handleMCPToolRequest(data []byte, reply func([]byte)) {
|
|
var req mcpRemote.MCPToolRequest
|
|
if err := json.Unmarshal(data, &req); err != nil {
|
|
sendMCPToolReply(reply, "", fmt.Sprintf("unmarshal error: %v", err))
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPToolTimeout)
|
|
defer cancel()
|
|
|
|
// Create/cache named MCP sessions from the provided config
|
|
namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil)
|
|
if err != nil {
|
|
sendMCPToolReply(reply, "", fmt.Sprintf("session error: %v", err))
|
|
return
|
|
}
|
|
|
|
// Discover tools to find the right session
|
|
tools, err := mcpTools.DiscoverMCPTools(ctx, namedSessions)
|
|
if err != nil {
|
|
sendMCPToolReply(reply, "", fmt.Sprintf("discovery error: %v", err))
|
|
return
|
|
}
|
|
|
|
// Execute the tool
|
|
argsJSON, _ := json.Marshal(req.Arguments)
|
|
result, err := mcpTools.ExecuteMCPToolCall(ctx, tools, req.ToolName, string(argsJSON))
|
|
if err != nil {
|
|
sendMCPToolReply(reply, "", err.Error())
|
|
return
|
|
}
|
|
|
|
sendMCPToolReply(reply, result, "")
|
|
}
|
|
|
|
func sendMCPToolReply(reply func([]byte), result, errMsg string) {
|
|
resp := mcpRemote.MCPToolResponse{Result: result, Error: errMsg}
|
|
data, _ := json.Marshal(resp)
|
|
reply(data)
|
|
}
|
|
|
|
// handleMCPDiscoveryRequest handles a NATS request-reply for MCP tool/prompt/resource discovery.
|
|
func handleMCPDiscoveryRequest(data []byte, reply func([]byte)) {
|
|
var req mcpRemote.MCPDiscoveryRequest
|
|
if err := json.Unmarshal(data, &req); err != nil {
|
|
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("unmarshal error: %v", err))
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), config.DefaultMCPDiscoveryTimeout)
|
|
defer cancel()
|
|
|
|
// Create/cache named MCP sessions
|
|
namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(req.ModelName, req.RemoteServers, req.StdioServers, nil)
|
|
if err != nil {
|
|
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("session error: %v", err))
|
|
return
|
|
}
|
|
|
|
// List servers with their tools/prompts/resources
|
|
serverInfos, err := mcpTools.ListMCPServers(ctx, namedSessions)
|
|
if err != nil {
|
|
sendMCPDiscoveryReply(reply, nil, nil, fmt.Sprintf("list error: %v", err))
|
|
return
|
|
}
|
|
|
|
// Also get tool function schemas for the frontend
|
|
tools, _ := mcpTools.DiscoverMCPTools(ctx, namedSessions)
|
|
var toolDefs []mcpRemote.MCPToolDef
|
|
for _, t := range tools {
|
|
toolDefs = append(toolDefs, mcpRemote.MCPToolDef{
|
|
ServerName: t.ServerName,
|
|
ToolName: t.ToolName,
|
|
Function: t.Function,
|
|
})
|
|
}
|
|
|
|
// Convert server infos
|
|
var servers []mcpRemote.MCPServerInfo
|
|
for _, s := range serverInfos {
|
|
servers = append(servers, mcpRemote.MCPServerInfo{
|
|
Name: s.Name,
|
|
Type: s.Type,
|
|
Tools: s.Tools,
|
|
Prompts: s.Prompts,
|
|
Resources: s.Resources,
|
|
})
|
|
}
|
|
|
|
sendMCPDiscoveryReply(reply, servers, toolDefs, "")
|
|
}
|
|
|
|
func sendMCPDiscoveryReply(reply func([]byte), servers []mcpRemote.MCPServerInfo, tools []mcpRemote.MCPToolDef, errMsg string) {
|
|
resp := mcpRemote.MCPDiscoveryResponse{Servers: servers, Tools: tools, Error: errMsg}
|
|
data, _ := json.Marshal(resp)
|
|
reply(data)
|
|
}
|
|
|
|
// handleMCPCIJob processes an MCP CI job on the agent worker.
|
|
// The agent worker can create MCP sessions (has docker) and call the LocalAI API for inference.
|
|
func handleMCPCIJob(shutdownCtx context.Context, data []byte, apiURL, apiToken string, natsClient messaging.MessagingClient, jobTimeout time.Duration) {
|
|
var evt jobs.JobEvent
|
|
if err := json.Unmarshal(data, &evt); err != nil {
|
|
xlog.Error("Failed to unmarshal job event", "error", err)
|
|
return
|
|
}
|
|
|
|
job := evt.Job
|
|
task := evt.Task
|
|
if job == nil || task == nil {
|
|
xlog.Error("MCP CI job missing enriched data", "jobID", evt.JobID)
|
|
publishJobResult(natsClient, evt.JobID, "failed", "", "job or task data missing from NATS event")
|
|
return
|
|
}
|
|
|
|
modelCfg := evt.ModelConfig
|
|
if modelCfg == nil {
|
|
publishJobResult(natsClient, evt.JobID, "failed", "", "model config missing from job event")
|
|
return
|
|
}
|
|
|
|
xlog.Info("Processing MCP CI job", "jobID", evt.JobID, "taskID", evt.TaskID, "model", task.Model)
|
|
|
|
// Publish running status
|
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
|
JobID: evt.JobID, Status: "running", Message: "Job started on agent worker",
|
|
})
|
|
|
|
// Parse MCP config
|
|
if modelCfg.MCP.Servers == "" && modelCfg.MCP.Stdio == "" {
|
|
publishJobResult(natsClient, evt.JobID, "failed", "", "no MCP servers configured for model")
|
|
return
|
|
}
|
|
|
|
remote, stdio, err := modelCfg.MCP.MCPConfigFromYAML()
|
|
if err != nil {
|
|
publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("failed to parse MCP config: %v", err))
|
|
return
|
|
}
|
|
|
|
// Create MCP sessions locally (agent worker has docker)
|
|
sessions, err := mcpTools.SessionsFromMCPConfig(modelCfg.Name, remote, stdio)
|
|
if err != nil || len(sessions) == 0 {
|
|
errMsg := "no working MCP servers found"
|
|
if err != nil {
|
|
errMsg = fmt.Sprintf("failed to create MCP sessions: %v", err)
|
|
}
|
|
publishJobResult(natsClient, evt.JobID, "failed", "", errMsg)
|
|
return
|
|
}
|
|
|
|
// Build prompt from template
|
|
prompt := task.Prompt
|
|
if task.CronParametersJSON != "" {
|
|
var params map[string]string
|
|
if err := json.Unmarshal([]byte(task.CronParametersJSON), ¶ms); err != nil {
|
|
xlog.Warn("Failed to unmarshal parameters", "error", err)
|
|
}
|
|
for k, v := range params {
|
|
prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
|
|
}
|
|
}
|
|
if job.ParametersJSON != "" {
|
|
var params map[string]string
|
|
if err := json.Unmarshal([]byte(job.ParametersJSON), ¶ms); err != nil {
|
|
xlog.Warn("Failed to unmarshal parameters", "error", err)
|
|
}
|
|
for k, v := range params {
|
|
prompt = strings.ReplaceAll(prompt, "{{."+k+"}}", v)
|
|
}
|
|
}
|
|
|
|
// Create LLM client pointing back to the frontend API
|
|
llm := clients.NewLocalAILLM(task.Model, apiToken, apiURL)
|
|
|
|
// Build cogito options
|
|
ctx, cancel := context.WithTimeout(shutdownCtx, jobTimeout)
|
|
defer cancel()
|
|
|
|
// Update job status to running in DB
|
|
publishJobStatus(natsClient, evt.JobID, "running", "")
|
|
|
|
// Buffer stream tokens and flush as complete blocks
|
|
var reasoningBuf, contentBuf strings.Builder
|
|
var lastStreamType cogito.StreamEventType
|
|
|
|
flushStreamBuf := func() {
|
|
if reasoningBuf.Len() > 0 {
|
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
|
JobID: evt.JobID, TraceType: "reasoning", TraceContent: reasoningBuf.String(),
|
|
})
|
|
reasoningBuf.Reset()
|
|
}
|
|
if contentBuf.Len() > 0 {
|
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
|
JobID: evt.JobID, TraceType: "content", TraceContent: contentBuf.String(),
|
|
})
|
|
contentBuf.Reset()
|
|
}
|
|
}
|
|
|
|
cogitoOpts := modelCfg.BuildCogitoOptions()
|
|
cogitoOpts = append(cogitoOpts,
|
|
cogito.WithContext(ctx),
|
|
cogito.WithMCPs(sessions...),
|
|
cogito.WithStatusCallback(func(status string) {
|
|
flushStreamBuf()
|
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
|
JobID: evt.JobID, TraceType: "status", TraceContent: status,
|
|
})
|
|
}),
|
|
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
|
|
flushStreamBuf()
|
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
|
JobID: evt.JobID, TraceType: "tool_result", TraceContent: fmt.Sprintf("%s: %s", t.Name, t.Result),
|
|
})
|
|
}),
|
|
cogito.WithStreamCallback(func(ev cogito.StreamEvent) {
|
|
// Flush if stream type changed (e.g., reasoning → content)
|
|
if ev.Type != lastStreamType {
|
|
flushStreamBuf()
|
|
lastStreamType = ev.Type
|
|
}
|
|
switch ev.Type {
|
|
case cogito.StreamEventReasoning:
|
|
reasoningBuf.WriteString(ev.Content)
|
|
case cogito.StreamEventContent:
|
|
contentBuf.WriteString(ev.Content)
|
|
case cogito.StreamEventToolCall:
|
|
natsClient.Publish(messaging.SubjectJobProgress(evt.JobID), jobs.ProgressEvent{
|
|
JobID: evt.JobID, TraceType: "tool_call", TraceContent: fmt.Sprintf("%s(%s)", ev.ToolName, ev.ToolArgs),
|
|
})
|
|
}
|
|
}),
|
|
)
|
|
|
|
// Execute via cogito
|
|
fragment := cogito.NewEmptyFragment()
|
|
fragment = fragment.AddMessage("user", prompt)
|
|
|
|
f, err := cogito.ExecuteTools(llm, fragment, cogitoOpts...)
|
|
flushStreamBuf() // flush any remaining buffered tokens
|
|
|
|
if err != nil {
|
|
publishJobResult(natsClient, evt.JobID, "failed", "", fmt.Sprintf("cogito execution failed: %v", err))
|
|
return
|
|
}
|
|
|
|
result := ""
|
|
if msg := f.LastMessage(); msg != nil {
|
|
result = msg.Content
|
|
}
|
|
publishJobResult(natsClient, evt.JobID, "completed", result, "")
|
|
xlog.Info("MCP CI job completed", "jobID", evt.JobID, "resultLen", len(result))
|
|
}
|
|
|
|
func publishJobStatus(nc messaging.MessagingClient, jobID, status, message string) {
|
|
jobs.PublishJobProgress(nc, jobID, status, message)
|
|
}
|
|
|
|
func publishJobResult(nc messaging.MessagingClient, jobID, status, result, errMsg string) {
|
|
jobs.PublishJobResult(nc, jobID, status, result, errMsg)
|
|
}
|