mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-08 16:57:08 -04:00
* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED is set, and connect workers with scoped credentials from the register response. Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs. Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode. Assisted-by: Grok:grok grok-build Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(distributed): correct NATS JWT scoping and harden client auth The JWT-auth path added in 46467cc7 had several gaps that fail silently under LOCALAI_NATS_REQUIRE_AUTH: - Agent-worker minted JWTs did not allow the subjects the agent worker actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop), so MCP-CI jobs and backend-stop session cleanup were silently dropped. Scope the agent permission set to those subjects. - NATS subscription permission violations were swallowed (Subscribe returned a live-but-dead subscription). Confirm subscriptions with a server round-trip so a denial surfaces synchronously, and log async permission errors. - The backend worker connected anonymously when given a JWT without its paired seed; reject the unpaired credential instead. - The documented service-user permissions in nats-auth-setup.sh omitted prefixcache.>, which the frontend publishes and subscribes; add it. Also: add a credential-provider hook to the messaging client (consumed by the follow-up credential-lifecycle change), drop the always-nil error from NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct), and gofmt the feature's files. Tests: an agent-JWT e2e spec that connects to the enforcing NATS server and exercises every subscription the agent worker makes, plus permission allow-list coverage unit tests. 
Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(distributed): acquire and auto-refresh worker NATS credentials Workers fetched NATS credentials once at startup, which broke two cases under JWT auth: a worker that registered while still pending admin approval never received a minted JWT (it connected unauthenticated and gave up), and a long-running worker's 24h JWT expired with no way to renew it. Introduce workerregistry.NATSCredentialManager, built on idempotent re-registration (the frontend preserves the node row and mints a fresh JWT each call): - Acquire re-registers through admin approval until the node is approved and credentials are minted (or returns the first success when auth is not required, preserving anonymous-NATS behavior). - RefreshLoop re-registers before the JWT expires (~75% of its lifetime), updating the credentials served to the connection. - Both are bounded (default 100 attempts / consecutive failures) and return an error on exhaustion, so an unapprovable or unrenewable worker exits non-zero and surfaces the problem instead of hanging or drifting toward an expired credential. The messaging client gains WithUserJWTProvider, fetching credentials on each (re)connect so the connection transparently adopts a refreshed JWT when the server expires the old one. RegisterFull exposes the approval status and full response; Register delegates to it. Both the backend worker and the agent worker are wired to this: explicit env credentials are used as-is, minted credentials are acquired-with-wait and refreshed, and a permanent refresh failure shuts the worker down so it restarts and re-acquires. Tests cover Acquire (wait-through-pending, bounded give-up, context cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry exit) and jwtExpiry decoding. Docs updated in distributed-mode.md. 
Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> --------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
1086 lines
42 KiB
Go
1086 lines
42 KiB
Go
package localai
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/xlog"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/gallery"
|
|
"github.com/mudler/LocalAI/core/http/auth"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
|
"github.com/mudler/LocalAI/core/services/nodes"
|
|
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
|
"github.com/mudler/LocalAI/pkg/httpclient"
|
|
"github.com/mudler/LocalAI/pkg/natsauth"
|
|
)
|
|
|
|
// nodeError builds a schema.ErrorResponse for node endpoints.
|
|
func nodeError(code int, message string) schema.ErrorResponse {
|
|
return schema.ErrorResponse{
|
|
Error: &schema.APIError{
|
|
Code: code,
|
|
Message: message,
|
|
Type: "node_error",
|
|
},
|
|
}
|
|
}
|
|
|
|
// ListNodesEndpoint returns all registered backend nodes.
|
|
func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeList, err := registry.ListWithExtras(ctx)
|
|
if err != nil {
|
|
xlog.Error("Failed to list nodes", "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes"))
|
|
}
|
|
return c.JSON(http.StatusOK, nodeList)
|
|
}
|
|
}
|
|
|
|
// GetNodeEndpoint returns a single node by ID.
|
|
func GetNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
id := c.Param("id")
|
|
node, err := registry.Get(ctx, id)
|
|
if err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
return c.JSON(http.StatusOK, node)
|
|
}
|
|
}
|
|
|
|
// RegisterNodeRequest is the JSON body a worker sends to register itself, and
// again whenever it re-registers.
type RegisterNodeRequest struct {
	// Name identifies the node. It is required and limited to 255 bytes.
	Name string `json:"name"`
	// NodeType is "backend" (the default when empty) or "agent".
	NodeType string `json:"node_type,omitempty"`
	// Address is the worker's gRPC address. It is required for backend workers
	// and limited to 512 bytes.
	Address string `json:"address"`
	// HTTPAddress is the worker's HTTP endpoint, used by the backend-log
	// proxies. It is limited to 512 bytes.
	HTTPAddress string `json:"http_address,omitempty"`
	// Token is the shared registration token. The frontend checks it against
	// its configured token and stores only its SHA-256 hex digest.
	Token string `json:"token,omitempty"`

	// Resource inventory reported by the worker.
	TotalVRAM     uint64 `json:"total_vram,omitempty"`
	AvailableVRAM uint64 `json:"available_vram,omitempty"`
	TotalRAM      uint64 `json:"total_ram,omitempty"`
	AvailableRAM  uint64 `json:"available_ram,omitempty"`
	GPUVendor     string `json:"gpu_vendor,omitempty"`

	// Labels are merged into the node's existing label set on each
	// registration instead of replacing it.
	Labels map[string]string `json:"labels,omitempty"`

	// MaxReplicasPerModel caps how many replicas of any single model may run on
	// this node. Older workers omit the field, and the endpoint coerces values
	// below 1 to 1 to keep the historical single-replica behavior.
	MaxReplicasPerModel int `json:"max_replicas_per_model,omitempty"`
}
|
|
|
|
// RegisterNodeEndpoint registers (or idempotently re-registers) a worker node.
//
// expectedToken is the registration token configured on the frontend. When it
// is empty, token validation is disabled. autoApprove controls whether new
// nodes go straight to "healthy" or wait for admin approval.
//
// For agent workers that are not pending, a dedicated API key is provisioned
// in authDB, replacing any previous one. When natsCfg can mint worker
// credentials and the node is not pending, a per-node NATS user JWT and seed
// are added to the response. On success the handler replies 201 with the
// node's id, name, type, status and creation time.
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
	return func(c echo.Context) error {
		var req RegisterNodeRequest
		if err := c.Bind(&req); err != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}

		nodeType, status, msg := validateRegisterNodeRequest(&req, expectedToken)
		if status != 0 {
			return c.JSON(status, nodeError(status, msg))
		}

		// Only a digest of the worker's token is persisted.
		var tokenHash string
		if req.Token != "" {
			h := sha256.Sum256([]byte(req.Token))
			tokenHash = hex.EncodeToString(h[:])
		}

		// Coerce 0 → 1 for backward compat with workers that don't send the field.
		// GORM's `default:1` only fires for a missing column; once Go zero-values
		// reach the struct field they're written as 0 unless explicitly set here.
		maxReplicasPerModel := max(req.MaxReplicasPerModel, 1)

		node := &nodes.BackendNode{
			Name:                req.Name,
			NodeType:            nodeType,
			Address:             req.Address,
			HTTPAddress:         req.HTTPAddress,
			TokenHash:           tokenHash,
			TotalVRAM:           req.TotalVRAM,
			AvailableVRAM:       req.AvailableVRAM,
			TotalRAM:            req.TotalRAM,
			AvailableRAM:        req.AvailableRAM,
			GPUVendor:           req.GPUVendor,
			MaxReplicasPerModel: maxReplicasPerModel,
		}

		ctx := c.Request().Context()
		if err := registry.Register(ctx, node, autoApprove); err != nil {
			xlog.Error("Failed to register node", "name", req.Name, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node"))
		}

		// Merge worker-supplied labels into the node's existing label set,
		// then apply auto-labels. Using SetNodeLabels here would have
		// delete-then-recreate semantics — every worker re-register would
		// wipe any UI-added label (since PR #9583 the worker always sends
		// the auto-mirror `node.replica-slots`, which made the latent bug
		// from PR #9186 trigger universally instead of only for workers
		// with `--node-labels` set).
		for k, v := range req.Labels {
			if err := registry.SetNodeLabel(ctx, node.ID, k, v); err != nil {
				xlog.Warn("Failed to set node label", "node", node.ID, "key", k, "error", err)
			}
		}
		registry.ApplyAutoLabels(ctx, node.ID, node)

		response := map[string]any{
			"id":         node.ID,
			"name":       node.Name,
			"node_type":  node.NodeType,
			"status":     node.Status,
			"created_at": node.CreatedAt,
		}

		// Provision an API key for agent workers that are approved (not pending).
		// On re-registration of a previously approved node, revoke old + provision new.
		if nodeType == nodes.NodeTypeAgent && authDB != nil && node.Status != nodes.StatusPending {
			var apiToken string
			// Use a transaction so that if provisioning fails after revoking old creds,
			// the old credentials are not lost.
			txErr := authDB.Transaction(func(tx *gorm.DB) error {
				if node.AuthUserID != "" {
					if err := tx.Exec("DELETE FROM users WHERE id = ?", node.AuthUserID).Error; err != nil {
						return fmt.Errorf("revoking old credentials: %w", err)
					}
					node.AuthUserID = ""
					node.APIKeyID = ""
				}
				plaintext, err := provisionAgentWorkerKey(ctx, tx, registry, node, hmacSecret)
				if err != nil {
					return err
				}
				apiToken = plaintext
				return nil
			})
			// Hand the key out only once the transaction has committed. Otherwise
			// a failed commit would still return a plaintext key that was rolled
			// back and can never authenticate.
			if txErr != nil {
				xlog.Warn("Failed to auto-provision API key for agent worker", "node", node.Name, "error", txErr)
			} else {
				response["api_token"] = apiToken
			}
		}

		attachNatsJWT(response, node, natsCfg)

		return c.JSON(http.StatusCreated, response)
	}
}

// validateRegisterNodeRequest checks the registration token and the request
// fields in the order RegisterNodeEndpoint has always applied them. It returns
// the resolved node type (backend when unset). When the request is rejected it
// also returns a non-zero HTTP status and the client-facing message; a zero
// status means the request is valid.
func validateRegisterNodeRequest(req *RegisterNodeRequest, expectedToken string) (nodeType string, status int, message string) {
	// Validate the registration token if one is configured on the frontend.
	if expectedToken != "" {
		if req.Token == "" {
			return "", http.StatusUnauthorized, "registration token required"
		}
		// Compare fixed-length digests in constant time so neither the token's
		// contents nor its length leak through timing.
		expectedHash := sha256.Sum256([]byte(expectedToken))
		providedHash := sha256.Sum256([]byte(req.Token))
		if subtle.ConstantTimeCompare(expectedHash[:], providedHash[:]) != 1 {
			return "", http.StatusUnauthorized, "invalid registration token"
		}
	}

	nodeType = req.NodeType
	if nodeType == "" {
		nodeType = nodes.NodeTypeBackend
	}
	if nodeType != nodes.NodeTypeBackend && nodeType != nodes.NodeTypeAgent {
		return "", http.StatusBadRequest,
			fmt.Sprintf("invalid node_type %q; must be %q or %q", nodeType, nodes.NodeTypeBackend, nodes.NodeTypeAgent)
	}

	switch {
	case req.Name == "":
		return "", http.StatusBadRequest, "name is required"
	// Backend workers serve gRPC and so need an address; agent workers don't.
	case nodeType == nodes.NodeTypeBackend && req.Address == "":
		return "", http.StatusBadRequest, "address is required for backend workers"
	case len(req.Name) > 255:
		return "", http.StatusBadRequest, "name exceeds 255 characters"
	case len(req.Address) > 512:
		return "", http.StatusBadRequest, "address exceeds 512 characters"
	case len(req.HTTPAddress) > 512:
		return "", http.StatusBadRequest, "http_address exceeds 512 characters"
	}
	return nodeType, 0, ""
}
|
|
|
|
// ApproveNodeEndpoint approves the pending node named by ":id".
//
// After approval the node is re-read. If that lookup fails, a bare "node
// approved" message is returned. Otherwise the response carries the node's
// id, name, type and status, plus two optional credentials:
//   - an API key, for agent workers that have no auth user yet;
//   - a NATS user JWT and seed, when natsCfg can mint them.
//
// A failure to provision the API key is logged but does not fail the request.
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		nodeID := c.Param("id")

		if approveErr := registry.ApproveNode(ctx, nodeID); approveErr != nil {
			xlog.Error("Failed to approve node", "id", nodeID, "error", approveErr)
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "failed to approve node"))
		}

		approved, getErr := registry.Get(ctx, nodeID)
		if getErr != nil {
			// The approval stuck but the row can't be re-read, so acknowledge it
			// without node details or credentials.
			return c.JSON(http.StatusOK, map[string]string{"message": "node approved"})
		}

		body := map[string]any{
			"message":   "node approved",
			"status":    approved.Status,
			"node_type": approved.NodeType,
			"name":      approved.Name,
			"id":        approved.ID,
		}

		needsAPIKey := approved.NodeType == nodes.NodeTypeAgent && authDB != nil && approved.AuthUserID == ""
		if needsAPIKey {
			token, provErr := provisionAgentWorkerKey(ctx, authDB, registry, approved, hmacSecret)
			switch {
			case provErr != nil:
				xlog.Warn("Failed to provision API key on approval", "node", approved.Name, "error", provErr)
			default:
				body["api_token"] = token
			}
		}

		attachNatsJWT(body, approved, natsCfg)
		return c.JSON(http.StatusOK, body)
	}
}
|
|
|
|
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
|
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
|
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
|
return
|
|
}
|
|
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
|
if err != nil {
|
|
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
|
return
|
|
}
|
|
response["nats_jwt"] = jwt
|
|
response["nats_user_seed"] = seed
|
|
}
|
|
|
|
// provisionAgentWorkerKey creates a dedicated user and an API key for an agent
// worker node and returns the plaintext key.
//
// The user has role "user" and provider ProviderAgentWorker, with the node ID
// as its subject. On success the node's AuthUserID and APIKeyID are updated
// both in memory and through registry.UpdateAuthRefs (a failure there is only
// logged). The user is also granted the collections feature, and a failure to
// grant it is likewise only logged.
//
// authDB may be an open transaction. RegisterNodeEndpoint passes one, and
// ApproveNodeEndpoint passes the plain handle.
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
	displayName := "agent-worker:" + node.Name

	workerUser := &auth.User{
		ID:        uuid.New().String(),
		Name:      displayName,
		Provider:  auth.ProviderAgentWorker,
		Subject:   node.ID,
		Role:      "user",
		Status:    "active",
		CreatedAt: time.Now(),
	}
	if err := authDB.Create(workerUser).Error; err != nil {
		return "", fmt.Errorf("creating agent worker user: %w", err)
	}

	plaintext, apiKey, err := auth.CreateAPIKey(authDB, workerUser.ID, displayName, "user", hmacSecret, nil)
	if err != nil {
		// Best-effort cleanup. Without it, the user row is orphaned whenever
		// authDB is not a transaction that will roll back. Worse, node.AuthUserID
		// stays empty, so every retried approval would leave another orphan.
		if cleanupErr := authDB.Exec("DELETE FROM users WHERE id = ?", workerUser.ID).Error; cleanupErr != nil {
			xlog.Warn("Failed to remove agent worker user after API key creation failed", "node", node.Name, "user", workerUser.ID, "error", cleanupErr)
		}
		return "", fmt.Errorf("creating API key: %w", err)
	}

	node.AuthUserID = workerUser.ID
	node.APIKeyID = apiKey.ID
	if err := registry.UpdateAuthRefs(ctx, node.ID, workerUser.ID, apiKey.ID); err != nil {
		xlog.Warn("Failed to update auth refs on node", "node", node.Name, "error", err)
	}

	// Grant collections feature so the worker can store/retrieve KB data on behalf of users.
	perm := &auth.UserPermission{
		ID:          uuid.New().String(),
		UserID:      workerUser.ID,
		Permissions: auth.PermissionMap{auth.FeatureCollections: true},
	}
	if err := authDB.Create(perm).Error; err != nil {
		xlog.Warn("Failed to grant collections permission to agent worker", "node", node.Name, "error", err)
	}

	xlog.Info("Provisioned API key for agent worker", "node", node.Name, "user", workerUser.ID)
	return plaintext, nil
}
|
|
|
|
// DeregisterNodeEndpoint permanently removes the node named by ":id" from the
// registry. It is intended for admins. A registry failure is logged and
// reported as a 500.
func DeregisterNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		err := registry.Deregister(c.Request().Context(), nodeID)
		if err == nil {
			return c.JSON(http.StatusOK, map[string]string{"message": "node deregistered"})
		}
		xlog.Error("Failed to deregister node", "id", nodeID, "error", err)
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to deregister node"))
	}
}
|
|
|
|
// DeactivateNodeEndpoint marks a node as offline without deleting it.
|
|
// Used by workers on graceful shutdown to preserve approval status across restarts.
|
|
func DeactivateNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
id := c.Param("id")
|
|
if err := registry.MarkOffline(ctx, id); err != nil {
|
|
xlog.Error("Failed to deactivate node", "id", id, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to deactivate node"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]string{"message": "node set to offline"})
|
|
}
|
|
}
|
|
|
|
// HeartbeatEndpoint records a heartbeat for the node named by ":id".
//
// The request body is optional. If it carries any of AvailableVRAM, TotalVRAM,
// AvailableRAM or a GPU vendor, the update is forwarded to the registry;
// otherwise the heartbeat is recorded with a nil update. A registry failure is
// reported as a 404.
func HeartbeatEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")

		// Bind errors, including an empty body, are deliberately ignored and
		// simply mean "no resource update".
		var hb nodes.HeartbeatUpdate
		_ = c.Bind(&hb)

		hasResourceInfo := hb.AvailableVRAM != nil || hb.TotalVRAM != nil || hb.AvailableRAM != nil || hb.GPUVendor != ""
		var payload *nodes.HeartbeatUpdate
		if hasResourceInfo {
			payload = &hb
		}

		if err := registry.Heartbeat(c.Request().Context(), nodeID, payload); err != nil {
			xlog.Warn("Heartbeat failed for node", "id", nodeID, "error", err)
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}
		return c.JSON(http.StatusOK, map[string]string{"message": "heartbeat received"})
	}
}
|
|
|
|
// GetNodeModelsEndpoint returns the models the registry records as loaded on
// the node named by ":id". A registry failure is logged and reported as a 500.
func GetNodeModelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		loaded, err := registry.GetNodeModels(c.Request().Context(), nodeID)
		if err == nil {
			return c.JSON(http.StatusOK, loaded)
		}
		xlog.Error("Failed to get node models", "id", nodeID, "error", err)
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get node models"))
	}
}
|
|
|
|
// DrainNodeEndpoint sets a node to draining status (no new requests).
|
|
func DrainNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
id := c.Param("id")
|
|
if err := registry.MarkDraining(ctx, id); err != nil {
|
|
xlog.Error("Failed to drain node", "id", id, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to drain node"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]string{"message": "node set to draining"})
|
|
}
|
|
}
|
|
|
|
// ResumeNodeEndpoint returns a draining node, named by ":id", to the healthy
// state via registry.MarkHealthy. A registry failure is logged and reported
// as a 500.
func ResumeNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		err := registry.MarkHealthy(c.Request().Context(), nodeID)
		if err == nil {
			return c.JSON(http.StatusOK, map[string]string{"message": "node resumed"})
		}
		xlog.Error("Failed to resume node", "id", nodeID, "error", err)
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to resume node"))
	}
}
|
|
|
|
// InstallBackendOnNodeEndpoint starts an asynchronous backend install on the
// worker node named by ":id".
//
// The handler does not install anything itself. It enqueues a ManagementOp on
// the gallery service's BackendGalleryChannel, with TargetNodeID set so the
// install is scoped to that node, and immediately replies 202 with a jobID.
// Clients poll /api/backends/job/:uid for progress, exactly as with
// /api/backends/install/:id.
//
// The body names either a gallery backend ("backend", resolved against the
// configured or overridden galleries) or a direct URI install ("uri" plus
// optional "name" and "alias"), mirroring /api/backends/install-external.
//
// The first (NodeCommandSender) argument is unused. It is kept so the
// signature matches DeleteBackendOnNodeEndpoint and ListBackendsOnNodeEndpoint.
func InstallBackendOnNodeEndpoint(_ nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		if galleryService == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "gallery service not configured"))
		}
		nodeID := c.Param("id")

		var body struct {
			Backend          string `json:"backend"`
			BackendGalleries string `json:"backend_galleries,omitempty"`
			URI              string `json:"uri,omitempty"`
			Name             string `json:"name,omitempty"`
			Alias            string `json:"alias,omitempty"`
		}
		if bindErr := c.Bind(&body); bindErr != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}
		if body.Backend == "" && body.URI == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name or uri required"))
		}

		jobUUID, uuidErr := uuid.NewUUID()
		if uuidErr != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to generate job id"))
		}
		jobID := jobUUID.String()

		// Pick the opcache key in this order: the gallery slug, then the URI
		// install's Name, then the URI itself. The key is node-scoped so that
		// concurrent installs of the same backend on different nodes don't
		// overwrite each other.
		var backendKey string
		switch {
		case body.Backend != "":
			backendKey = body.Backend
		case body.Name != "":
			backendKey = body.Name
		default:
			backendKey = body.URI
		}
		opcache.SetBackend(galleryop.NodeScopedKey(nodeID, backendKey), jobID)

		// A caller may point the install at a private gallery. A malformed or
		// empty override falls back to the configured galleries, matching the
		// standalone install path.
		galleries := appConfig.BackendGalleries
		if body.BackendGalleries != "" {
			var override []config.Gallery
			switch parseErr := json.Unmarshal([]byte(body.BackendGalleries), &override); {
			case parseErr != nil:
				xlog.Warn("Ignoring malformed backend_galleries override; falling back to configured galleries", "error", parseErr, "nodeID", nodeID)
			case len(override) > 0:
				galleries = override
			}
		}

		// The op outlives this request, so its context derives from Background
		// and is cancelled only via the stored cancel func.
		opCtx, cancel := context.WithCancel(context.Background())
		op := galleryop.ManagementOp[gallery.GalleryBackend, any]{
			ID:                 jobID,
			GalleryElementName: body.Backend,
			Galleries:          galleries,
			TargetNodeID:       nodeID,
			ExternalURI:        body.URI,
			ExternalName:       body.Name,
			ExternalAlias:      body.Alias,
			Context:            opCtx,
			CancelFunc:         cancel,
		}
		galleryService.StoreCancellation(jobID, cancel)
		// Send from a goroutine so the HTTP response doesn't wait for the
		// gallery worker to receive from its channel.
		go func() { galleryService.BackendGalleryChannel <- op }()

		xlog.Info("Node-scoped backend install dispatched", "node", nodeID, "backend", body.Backend, "uri", body.URI, "jobID", jobID)
		return c.JSON(http.StatusAccepted, map[string]string{
			"jobID":     jobID,
			"statusUrl": "/api/backends/job/" + jobID,
			"message":   "backend installation started",
		})
	}
}
|
|
|
|
// DeleteBackendOnNodeEndpoint asks the worker named by ":id" to delete an
// installed backend. The command is sent through the NodeCommandSender, which
// uses NATS.
//
// It replies 503 when no sender is configured and 400 when "backend" is
// missing. It replies 500 when the command fails to send or the worker reports
// a failure (the worker's error is logged, not returned).
func DeleteBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")

		var payload struct {
			Backend string `json:"backend"`
		}
		bindErr := c.Bind(&payload)
		if bindErr != nil || payload.Backend == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name required"))
		}
		backend := payload.Backend

		reply, sendErr := unloader.DeleteBackend(nodeID, backend)
		switch {
		case sendErr != nil:
			xlog.Error("Failed to delete backend on node", "node", nodeID, "backend", backend, "error", sendErr)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete backend on node"))
		case !reply.Success:
			xlog.Error("Backend delete failed on node", "node", nodeID, "backend", backend, "error", reply.Error)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "backend deletion failed"))
		}
		return c.JSON(http.StatusOK, map[string]string{"message": "backend deleted"})
	}
}
|
|
|
|
// ListBackendsOnNodeEndpoint asks the worker named by ":id" for its installed
// backends, using the NodeCommandSender (NATS), and returns the list as JSON.
//
// Both a send failure and an error reported by the worker are logged and
// surface as the same 500 response.
func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")

		reply, sendErr := unloader.ListBackends(nodeID)
		if sendErr == nil && reply.Error == "" {
			return c.JSON(http.StatusOK, reply.Backends)
		}

		if sendErr != nil {
			xlog.Error("Failed to list backends on node", "node", nodeID, "error", sendErr)
		} else {
			xlog.Error("List backends failed on node", "node", nodeID, "error", reply.Error)
		}
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list backends on node"))
	}
}
|
|
|
|
// UnloadModelOnNodeEndpoint unloads "model_name" from the worker named by
// ":id". Commands go through the NodeCommandSender (NATS).
//
// It first frees the model on the worker (gRPC Free), then stops the backend
// process, then drops every replica record of the model on that node from the
// registry. Failure of either worker command aborts with a 500 before the
// registry is touched.
func UnloadModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")

		var payload struct {
			ModelName string `json:"model_name"`
		}
		if bindErr := c.Bind(&payload); bindErr != nil || payload.ModelName == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name required"))
		}
		model := payload.ModelName

		if unloadErr := unloader.UnloadModelOnNode(nodeID, model); unloadErr != nil {
			xlog.Error("Failed to unload model on node", "node", nodeID, "model", model, "error", unloadErr)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to unload model on node"))
		}
		if stopErr := unloader.StopBackend(nodeID, model); stopErr != nil {
			xlog.Error("Failed to stop backend after model unload", "node", nodeID, "model", model, "error", stopErr)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "model unloaded but backend stop failed"))
		}

		// The worker side is torn down, so forget every replica of the model on this node.
		registry.RemoveAllNodeModelReplicas(c.Request().Context(), nodeID, model)
		return c.JSON(http.StatusOK, map[string]string{"message": "model unloaded"})
	}
}
|
|
|
|
// DeleteModelOnNodeEndpoint removes "model_name" from the worker named by
// ":id". Commands go through the NodeCommandSender (NATS).
//
// It unloads the model (failure is fatal and reported as a 500), stops its
// backend process (best-effort, since the process may not be running), and
// drops every replica record of the model on that node from the registry.
//
// NOTE(review): despite the endpoint's name, this handler sends only
// unload/stop commands and never asks the worker to remove model files.
// Confirm whether file deletion happens elsewhere.
func DeleteModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")
		var req struct {
			ModelName string `json:"model_name"`
		}
		if err := c.Bind(&req); err != nil || req.ModelName == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name required"))
		}

		// Unload the model first if it is loaded. Log the cause so the failure
		// isn't silent, consistent with the other node endpoints.
		if err := unloader.UnloadModelOnNode(nodeID, req.ModelName); err != nil {
			xlog.Error("Failed to unload model before deletion", "node", nodeID, "model", req.ModelName, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to unload model before deletion"))
		}
		if err := unloader.StopBackend(nodeID, req.ModelName); err != nil {
			// Non-fatal — backend process may not be running
			xlog.Warn("StopBackend failed during model deletion (non-fatal)", "node", nodeID, "model", req.ModelName, "error", err)
		}
		registry.RemoveAllNodeModelReplicas(c.Request().Context(), nodeID, req.ModelName)
		return c.JSON(http.StatusOK, map[string]string{"message": "model deleted from node"})
	}
}
|
|
|
|
// NodeBackendLogsListEndpoint proxies a request to the /v1/backend-logs
// endpoint of the worker named by ":id", which lists the model IDs that have
// backend logs. The worker's status code and body are relayed verbatim with a
// JSON content type.
//
// It replies 404 when the node is unknown, and 502 when the node has no HTTP
// address or cannot be reached. registrationToken is passed to
// proxyHTTPToWorker so the worker can authenticate the call.
func NodeBackendLogsListEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		nodeID := c.Param("id")
		node, err := registry.Get(ctx, nodeID)
		if err != nil {
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}

		if node.HTTPAddress == "" {
			return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, "node has no HTTP address"))
		}

		resp, err := proxyHTTPToWorker(node.HTTPAddress, "/v1/backend-logs", registrationToken)
		if err != nil {
			return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, fmt.Sprintf("failed to reach worker: %v", err)))
		}
		defer resp.Body.Close()

		c.Response().Header().Set("Content-Type", "application/json")
		c.Response().WriteHeader(resp.StatusCode)
		// The status line is already committed, so a relay failure (worker or
		// client dropping the connection) can only be logged, not reported.
		if _, copyErr := io.Copy(c.Response(), resp.Body); copyErr != nil {
			xlog.Warn("Failed to relay backend log list from worker", "node", nodeID, "error", copyErr)
		}
		return nil
	}
}
|
|
|
|
// NodeBackendLogsLinesEndpoint proxies a request to the
// /v1/backend-logs/{modelId} endpoint of the worker named by ":id", returning
// that model's buffered log lines. The model ID is path-escaped before being
// forwarded. The worker's status code and body are relayed verbatim with a
// JSON content type.
//
// It replies 404 when the node is unknown, and 502 when the node has no HTTP
// address or cannot be reached.
func NodeBackendLogsLinesEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		nodeID := c.Param("id")
		modelID := c.Param("modelId")

		node, err := registry.Get(ctx, nodeID)
		if err != nil {
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}

		if node.HTTPAddress == "" {
			return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, "node has no HTTP address"))
		}

		path := "/v1/backend-logs/" + url.PathEscape(modelID)
		resp, err := proxyHTTPToWorker(node.HTTPAddress, path, registrationToken)
		if err != nil {
			return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, fmt.Sprintf("failed to reach worker: %v", err)))
		}
		defer resp.Body.Close()

		c.Response().Header().Set("Content-Type", "application/json")
		c.Response().WriteHeader(resp.StatusCode)
		// The status line is already committed, so a relay failure (worker or
		// client dropping the connection) can only be logged, not reported.
		if _, copyErr := io.Copy(c.Response(), resp.Body); copyErr != nil {
			xlog.Warn("Failed to relay backend log lines from worker", "node", nodeID, "model", modelID, "error", copyErr)
		}
		return nil
	}
}
|
|
|
|
// NodeBackendLogsWSEndpoint proxies a WebSocket connection to a worker node's
|
|
// /v1/backend-logs/{modelId}/ws endpoint for real-time log streaming.
|
|
func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc {
|
|
browserUpgrader := websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // no origin header = same-origin or non-browser
|
|
}
|
|
// Parse origin URL and compare host with request host
|
|
u, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return u.Host == r.Host
|
|
},
|
|
}
|
|
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
modelID := c.Param("modelId")
|
|
|
|
node, err := registry.Get(ctx, nodeID)
|
|
if err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
|
|
// Upgrade browser connection
|
|
browserWS, err := browserUpgrader.Upgrade(c.Response(), c.Request(), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Dial the worker WebSocket
|
|
workerURL := fmt.Sprintf("ws://%s/v1/backend-logs/%s/ws", node.HTTPAddress, url.PathEscape(modelID))
|
|
workerHeaders := http.Header{}
|
|
if registrationToken != "" {
|
|
workerHeaders.Set("Authorization", "Bearer "+registrationToken)
|
|
}
|
|
|
|
workerDialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
|
|
workerWS, _, err := workerDialer.Dial(workerURL, workerHeaders)
|
|
if err != nil {
|
|
browserWS.WriteMessage(websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "failed to connect to worker"))
|
|
browserWS.Close()
|
|
return nil
|
|
}
|
|
|
|
// Use sync.OnceFunc wrappers to avoid double-close and ensure each
|
|
// goroutine can safely close the *other* connection to unblock
|
|
// its peer's ReadMessage call.
|
|
done := make(chan struct{})
|
|
closeWorker := sync.OnceFunc(func() { workerWS.Close() })
|
|
closeBrowser := sync.OnceFunc(func() { browserWS.Close() })
|
|
|
|
// Worker → Browser
|
|
go func() {
|
|
defer close(done)
|
|
defer closeBrowser() // unblock Browser→Worker goroutine
|
|
for {
|
|
msgType, msg, err := workerWS.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if err := browserWS.WriteMessage(msgType, msg); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Browser → Worker (mainly for close detection)
|
|
go func() {
|
|
defer closeWorker() // unblock Worker→Browser goroutine
|
|
for {
|
|
msgType, msg, err := browserWS.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if err := workerWS.WriteMessage(msgType, msg); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
<-done
|
|
closeWorker()
|
|
closeBrowser()
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// GetNodeLabelsEndpoint returns labels for a specific node.
|
|
func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
labels, err := registry.GetNodeLabels(ctx, nodeID)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels"))
|
|
}
|
|
// Convert to map for cleaner API response
|
|
result := make(map[string]string)
|
|
for _, l := range labels {
|
|
result[l.Key] = l.Value
|
|
}
|
|
return c.JSON(http.StatusOK, result)
|
|
}
|
|
}
|
|
|
|
// UpdateMaxReplicasPerModelRequest is the JSON body accepted by the per-node
// replica cap endpoint, for example {"value": 2}.
type UpdateMaxReplicasPerModelRequest struct {
	// Value is the requested per-model replica cap for the node; the endpoint
	// rejects anything below 1.
	Value int `json:"value"`
}
|
|
|
|
// UpdateMaxReplicasPerModelEndpoint sets the per-node cap on how many replicas
|
|
// of any one model can be loaded concurrently. The corresponding
|
|
// `node.replica-slots` auto-label is refreshed so existing AND-selectors keep
|
|
// matching, and any unsatisfiable scheduling cooldowns are cleared so the
|
|
// reconciler retries on the next tick.
|
|
//
|
|
// This is a transient admin override — a worker re-registration restores the
|
|
// value the worker was started with (--max-replicas-per-model). For permanent
|
|
// fleet changes, change the worker flag.
|
|
//
|
|
// @Summary Update a node's max replicas per model
|
|
// @Tags Nodes
|
|
// @Param id path string true "Node ID"
|
|
// @Param request body UpdateMaxReplicasPerModelRequest true "New value"
|
|
// @Success 200 {object} map[string]int
|
|
// @Failure 400 {object} map[string]any "value must be >= 1"
|
|
// @Failure 404 {object} map[string]any "node not found"
|
|
// @Router /api/nodes/{id}/max-replicas-per-model [put]
|
|
func UpdateMaxReplicasPerModelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
var req UpdateMaxReplicasPerModelRequest
|
|
if err := c.Bind(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
if req.Value < 1 {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "value must be >= 1"))
|
|
}
|
|
if err := registry.UpdateMaxReplicasPerModel(ctx, nodeID, req.Value); err != nil {
|
|
xlog.Error("Failed to update max_replicas_per_model", "node", nodeID, "value", req.Value, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to update max replicas per model"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]int{"max_replicas_per_model": req.Value})
|
|
}
|
|
}
|
|
|
|
// ResetMaxReplicasPerModelEndpoint clears the admin override on a node, so
|
|
// the next worker re-registration is allowed to update the value from its
|
|
// CLI flag again. The current value is left in place until the worker calls
|
|
// register.
|
|
//
|
|
// @Summary Reset a node's max replicas per model to the worker default
|
|
// @Tags Nodes
|
|
// @Param id path string true "Node ID"
|
|
// @Success 200 {object} map[string]bool
|
|
// @Failure 404 {object} map[string]any "node not found"
|
|
// @Router /api/nodes/{id}/max-replicas-per-model [delete]
|
|
func ResetMaxReplicasPerModelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
if err := registry.ResetMaxReplicasPerModel(ctx, nodeID); err != nil {
|
|
xlog.Error("Failed to reset max_replicas_per_model override", "node", nodeID, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to reset override"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]bool{"reset": true})
|
|
}
|
|
}
|
|
|
|
// SetNodeLabelsEndpoint replaces all labels for a node.
|
|
func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
var labels map[string]string
|
|
if err := c.Bind(&labels); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
if err := registry.SetNodeLabels(ctx, nodeID, labels); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set labels"))
|
|
}
|
|
return c.JSON(http.StatusOK, labels)
|
|
}
|
|
}
|
|
|
|
// MergeNodeLabelsEndpoint adds/updates labels without removing existing ones.
|
|
func MergeNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
var labels map[string]string
|
|
if err := c.Bind(&labels); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
for k, v := range labels {
|
|
if err := registry.SetNodeLabel(ctx, nodeID, k, v); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to merge labels"))
|
|
}
|
|
}
|
|
// Return updated labels
|
|
updated, _ := registry.GetNodeLabels(ctx, nodeID)
|
|
result := make(map[string]string)
|
|
for _, l := range updated {
|
|
result[l.Key] = l.Value
|
|
}
|
|
return c.JSON(http.StatusOK, result)
|
|
}
|
|
}
|
|
|
|
// DeleteNodeLabelEndpoint removes a single label from a node.
|
|
func DeleteNodeLabelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
key := c.Param("key")
|
|
if key == "" {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "label key is required"))
|
|
}
|
|
if err := registry.RemoveNodeLabel(ctx, nodeID, key); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to remove label"))
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
}
|
|
|
|
// ListSchedulingEndpoint returns all model scheduling configs.
|
|
func ListSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
configs, err := registry.ListModelSchedulings(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list scheduling configs"))
|
|
}
|
|
return c.JSON(http.StatusOK, configs)
|
|
}
|
|
}
|
|
|
|
// GetSchedulingEndpoint returns the scheduling config for a specific model.
|
|
func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
modelName := c.Param("model")
|
|
config, err := registry.GetModelScheduling(ctx, modelName)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get scheduling config"))
|
|
}
|
|
if config == nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "no scheduling config for model"))
|
|
}
|
|
return c.JSON(http.StatusOK, config)
|
|
}
|
|
}
|
|
|
|
// SetSchedulingRequest is the JSON body for creating or updating a model's
// scheduling config.
//
// RoutePolicy, BalanceAbsThreshold, BalanceRelThreshold and MinPrefixMatch are
// the prefix-cache settings. They are pointers so that an omitted field (nil)
// can be told apart from an explicit zero. On update, a nil field keeps the
// model's previously stored value rather than resetting it. The remaining
// fields (ModelName, NodeSelector, MinReplicas, MaxReplicas) are always taken
// from the request as-is.
type SetSchedulingRequest struct {
	ModelName    string            `json:"model_name"`
	NodeSelector map[string]string `json:"node_selector,omitempty"`
	MinReplicas  int               `json:"min_replicas"`
	MaxReplicas  int               `json:"max_replicas"`

	// Prefix-cache settings; nil means "keep the stored value".
	RoutePolicy         *string  `json:"route_policy,omitempty"`
	BalanceAbsThreshold *int     `json:"balance_abs_threshold,omitempty"`
	BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
	MinPrefixMatch      *float64 `json:"min_prefix_match,omitempty"`
}
|
|
|
|
// validateSchedulingRequest enforces the invariants of a scheduling config.
|
|
// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the
|
|
// single source of truth), and are checked against the RESOLVED values passed
|
|
// in (provided-or-preserved), so validation only rejects bad values the caller
|
|
// actually supplied. It returns nil when valid, or an error with a user-facing
|
|
// message describing the first violation.
|
|
func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
|
|
if req.ModelName == "" {
|
|
return errors.New("model_name is required")
|
|
}
|
|
if req.MinReplicas < 0 {
|
|
return errors.New("min_replicas must be >= 0")
|
|
}
|
|
if req.MaxReplicas < 0 {
|
|
return errors.New("max_replicas must be >= 0")
|
|
}
|
|
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
|
return errors.New("min_replicas must be <= max_replicas")
|
|
}
|
|
if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
|
//
|
|
// The registry upsert full-replaces all columns, so a request that omits the
|
|
// prefix-cache fields would otherwise wipe a model's previously-configured
|
|
// routing settings. To avoid that footgun the four prefix-cache fields are
|
|
// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
|
|
// existing config's value (or the zero default when no config exists yet). The
|
|
// non-prefix fields keep their full-replace PUT behavior.
|
|
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
var req SetSchedulingRequest
|
|
if err := c.Bind(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
|
|
// Fetch the existing config (may be nil) so omitted prefix-cache fields
|
|
// can fall back to the stored value rather than resetting to zero.
|
|
var existing *nodes.ModelSchedulingConfig
|
|
if req.ModelName != "" {
|
|
var err error
|
|
existing, err = registry.GetModelScheduling(ctx, req.ModelName)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
|
|
}
|
|
}
|
|
|
|
// Resolve each prefix-cache field: provided pointer wins, otherwise keep
|
|
// the existing value (zero/default when there is no existing config).
|
|
routePolicy := ""
|
|
absThr := 0
|
|
relThr := 0.0
|
|
minMatch := 0.0
|
|
if existing != nil {
|
|
routePolicy = existing.RoutePolicy
|
|
absThr = existing.BalanceAbsThreshold
|
|
relThr = existing.BalanceRelThreshold
|
|
minMatch = existing.MinPrefixMatch
|
|
}
|
|
if req.RoutePolicy != nil {
|
|
routePolicy = *req.RoutePolicy
|
|
}
|
|
if req.BalanceAbsThreshold != nil {
|
|
absThr = *req.BalanceAbsThreshold
|
|
}
|
|
if req.BalanceRelThreshold != nil {
|
|
relThr = *req.BalanceRelThreshold
|
|
}
|
|
if req.MinPrefixMatch != nil {
|
|
minMatch = *req.MinPrefixMatch
|
|
}
|
|
|
|
if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
|
|
}
|
|
|
|
// Serialize node selector to JSON
|
|
var selectorJSON string
|
|
if len(req.NodeSelector) > 0 {
|
|
b, err := json.Marshal(req.NodeSelector)
|
|
if err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid node_selector"))
|
|
}
|
|
selectorJSON = string(b)
|
|
}
|
|
|
|
config := &nodes.ModelSchedulingConfig{
|
|
ModelName: req.ModelName,
|
|
NodeSelector: selectorJSON,
|
|
MinReplicas: req.MinReplicas,
|
|
MaxReplicas: req.MaxReplicas,
|
|
RoutePolicy: routePolicy,
|
|
BalanceAbsThreshold: absThr,
|
|
BalanceRelThreshold: relThr,
|
|
MinPrefixMatch: minMatch,
|
|
}
|
|
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
|
}
|
|
return c.JSON(http.StatusOK, config)
|
|
}
|
|
}
|
|
|
|
// DeleteSchedulingEndpoint removes a model scheduling config.
|
|
func DeleteSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
modelName := c.Param("model")
|
|
if err := registry.DeleteModelScheduling(ctx, modelName); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete scheduling config"))
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
}
|
|
|
|
// proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth.
|
|
func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) {
|
|
reqURL := fmt.Sprintf("http://%s%s", httpAddress, path)
|
|
req, err := http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
client := httpclient.NewWithTimeout(15 * time.Second)
|
|
return client.Do(req)
|
|
}
|