mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-28 10:27:30 -04:00
fix(nodes): return labels in single-node GET so the detail view shows them The node detail view (/app/nodes/:id) reads `node.labels` to render a node's existing labels, but the single-node GET endpoint returned a bare BackendNode whose Labels live in a separate table - so the list was always empty and operators could only add labels, never see what was already set (#10527). The same response also lacked in_flight_count and model_count. Add NodeRegistry.GetWithExtras, mirroring the existing List vs ListWithExtras split: bare Get stays cheap for the routing hot paths and existence checks, while the detail endpoint uses the enriched variant to attach the labels map and live counts. No frontend change is needed - the UI already renders existing labels once the data is present. Closes #10527 Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
1115 lines
44 KiB
Go
1115 lines
44 KiB
Go
package localai
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/xlog"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/gallery"
|
|
"github.com/mudler/LocalAI/core/http/auth"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
|
"github.com/mudler/LocalAI/core/services/nodes"
|
|
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
|
"github.com/mudler/LocalAI/pkg/httpclient"
|
|
"github.com/mudler/LocalAI/pkg/natsauth"
|
|
)
|
|
|
|
// nodeError builds a schema.ErrorResponse for node endpoints.
|
|
func nodeError(code int, message string) schema.ErrorResponse {
|
|
return schema.ErrorResponse{
|
|
Error: &schema.APIError{
|
|
Code: code,
|
|
Message: message,
|
|
Type: "node_error",
|
|
},
|
|
}
|
|
}
|
|
|
|
// ListNodesEndpoint returns all registered backend nodes.
|
|
func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeList, err := registry.ListWithExtras(ctx)
|
|
if err != nil {
|
|
xlog.Error("Failed to list nodes", "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes"))
|
|
}
|
|
return c.JSON(http.StatusOK, nodeList)
|
|
}
|
|
}
|
|
|
|
// GetNodeEndpoint serves the single node identified by the ":id" path
// parameter. Any lookup error is reported as 404 "node not found".
func GetNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		// The enriched lookup is deliberate. A bare Get returns a BackendNode
		// whose labels live in a separate table, which would leave the detail
		// view's label list empty. GetWithExtras also attaches the loaded-model
		// count and the in-flight total.
		detail, err := registry.GetWithExtras(c.Request().Context(), nodeID)
		if err != nil {
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}
		return c.JSON(http.StatusOK, detail)
	}
}
|
|
|
|
// RegisterNodeRequest is the request body for registering a new worker node.
// RegisterNodeEndpoint binds it from the request and validates it.
type RegisterNodeRequest struct {
	// Name identifies the node; required and at most 255 bytes.
	Name     string `json:"name"`
	NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
	// Address is required for backend workers and may be at most 512 bytes.
	Address string `json:"address"`
	// HTTPAddress is the worker's HTTP host:port, used when proxying backend
	// logs; at most 512 bytes.
	HTTPAddress string `json:"http_address,omitempty"`
	// Token is the registration token. It is checked against the frontend's
	// configured token and persisted only as a hex-encoded SHA-256 hash.
	Token string `json:"token,omitempty"`
	// Resource figures reported by the worker at registration time; copied
	// verbatim onto the stored node.
	TotalVRAM     uint64 `json:"total_vram,omitempty"`
	AvailableVRAM uint64 `json:"available_vram,omitempty"`
	TotalRAM      uint64 `json:"total_ram,omitempty"`
	AvailableRAM  uint64 `json:"available_ram,omitempty"`
	GPUVendor     string `json:"gpu_vendor,omitempty"`
	// GPUComputeCapability is the worker GPU's compute capability ("major.minor",
	// e.g. "12.1" for GB10). Used by the router for per-arch option tuning.
	GPUComputeCapability string `json:"gpu_compute_capability,omitempty"`
	// Labels are merged key by key into the node's existing label set; they do
	// not replace labels added elsewhere (e.g. from the UI).
	Labels map[string]string `json:"labels,omitempty"`
	// MaxReplicasPerModel is the per-node cap on replicas of any single model.
	// Workers older than this field omit it; we coerce 0 → 1 below to preserve
	// historical single-replica behavior.
	MaxReplicasPerModel int `json:"max_replicas_per_model,omitempty"`
}
|
|
|
|
// RegisterNodeEndpoint registers (or re-registers) a worker node.
//
// expectedToken is the registration token configured on the frontend; an
// empty value disables token checking. autoApprove is passed to
// registry.Register and controls whether new nodes go directly to "healthy"
// or require admin approval. When authDB is non-nil, approved (non-pending)
// agent workers get an inference API key generated with hmacSecret. natsCfg
// decides whether a per-node NATS JWT is added to the response.
//
// On success it responds 201 with id, name, node_type, status and created_at.
// It also includes "api_token" when a key was provisioned and committed, and
// NATS credentials when they were minted.
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
	return func(c echo.Context) error {
		var req RegisterNodeRequest
		if err := c.Bind(&req); err != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}

		// Validate registration token if one is configured on the frontend.
		// Both values are hashed first so the constant-time comparison always
		// sees equal-length inputs.
		if expectedToken != "" {
			if req.Token == "" {
				return c.JSON(http.StatusUnauthorized, nodeError(http.StatusUnauthorized, "registration token required"))
			}
			expectedHash := sha256.Sum256([]byte(expectedToken))
			providedHash := sha256.Sum256([]byte(req.Token))
			if subtle.ConstantTimeCompare(expectedHash[:], providedHash[:]) != 1 {
				return c.JSON(http.StatusUnauthorized, nodeError(http.StatusUnauthorized, "invalid registration token"))
			}
		}

		// Determine node type; an empty value defaults to a backend worker.
		nodeType := req.NodeType
		if nodeType == "" {
			nodeType = nodes.NodeTypeBackend
		}
		if nodeType != nodes.NodeTypeBackend && nodeType != nodes.NodeTypeAgent {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest,
				fmt.Sprintf("invalid node_type %q; must be %q or %q", nodeType, nodes.NodeTypeBackend, nodes.NodeTypeAgent)))
		}

		if req.Name == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "name is required"))
		}
		// Backend workers require address; agent workers don't serve gRPC.
		if nodeType == nodes.NodeTypeBackend && req.Address == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "address is required for backend workers"))
		}
		if len(req.Name) > 255 {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "name exceeds 255 characters"))
		}
		if len(req.Address) > 512 {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "address exceeds 512 characters"))
		}
		if len(req.HTTPAddress) > 512 {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "http_address exceeds 512 characters"))
		}

		// Hash the token for storage (if provided); the plaintext is never persisted.
		var tokenHash string
		if req.Token != "" {
			h := sha256.Sum256([]byte(req.Token))
			tokenHash = hex.EncodeToString(h[:])
		}

		// Coerce 0 → 1 for backward compat with workers that don't send the field.
		// GORM's `default:1` only fires for a missing column; once Go zero-values
		// reach the struct field they're written as 0 unless explicitly set here.
		maxReplicasPerModel := req.MaxReplicasPerModel
		if maxReplicasPerModel < 1 {
			maxReplicasPerModel = 1
		}

		node := &nodes.BackendNode{
			Name:                 req.Name,
			NodeType:             nodeType,
			Address:              req.Address,
			HTTPAddress:          req.HTTPAddress,
			TokenHash:            tokenHash,
			TotalVRAM:            req.TotalVRAM,
			AvailableVRAM:        req.AvailableVRAM,
			TotalRAM:             req.TotalRAM,
			AvailableRAM:         req.AvailableRAM,
			GPUVendor:            req.GPUVendor,
			GPUComputeCapability: req.GPUComputeCapability,
			MaxReplicasPerModel:  maxReplicasPerModel,
		}

		ctx := c.Request().Context()
		if err := registry.Register(ctx, node, autoApprove); err != nil {
			xlog.Error("Failed to register node", "name", req.Name, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node"))
		}

		// Merge worker-supplied labels into the node's existing label set,
		// then apply auto-labels. Using SetNodeLabels here would have
		// delete-then-recreate semantics — every worker re-register would
		// wipe any UI-added label (since PR #9583 the worker always sends
		// the auto-mirror `node.replica-slots`, which made the latent bug
		// from PR #9186 trigger universally instead of only for workers
		// with `--node-labels` set).
		for k, v := range req.Labels {
			if err := registry.SetNodeLabel(ctx, node.ID, k, v); err != nil {
				xlog.Warn("Failed to set node label", "node", node.ID, "key", k, "error", err)
			}
		}
		registry.ApplyAutoLabels(ctx, node.ID, node)

		response := map[string]any{
			"id":         node.ID,
			"name":       node.Name,
			"node_type":  node.NodeType,
			"status":     node.Status,
			"created_at": node.CreatedAt,
		}

		// Provision API key for agent workers that are approved (not pending).
		// On re-registration of a previously approved node, revoke old + provision new.
		if nodeType == nodes.NodeTypeAgent && authDB != nil && node.Status != nodes.StatusPending {
			// Use a transaction so that if provisioning fails after revoking old creds,
			// the old credentials are not lost.
			var apiToken string
			txErr := authDB.Transaction(func(tx *gorm.DB) error {
				if node.AuthUserID != "" {
					if err := tx.Exec("DELETE FROM users WHERE id = ?", node.AuthUserID).Error; err != nil {
						return fmt.Errorf("revoking old credentials: %w", err)
					}
					node.AuthUserID = ""
					node.APIKeyID = ""
				}
				plaintext, err := provisionAgentWorkerKey(ctx, tx, registry, node, hmacSecret)
				if err != nil {
					return err
				}
				apiToken = plaintext
				return nil
			})
			// Publish the key only after the transaction has committed. If the
			// commit itself fails, the key row is rolled back, and returning it
			// would hand the worker a credential that does not exist.
			if txErr != nil {
				xlog.Warn("Failed to auto-provision API key for agent worker", "node", node.Name, "error", txErr)
			} else {
				response["api_token"] = apiToken
			}
		}

		attachNatsJWT(response, node, natsCfg)

		return c.JSON(http.StatusCreated, response)
	}
}
|
|
|
|
// ApproveNodeEndpoint approves the pending node identified by ":id", which
// sets its status to healthy. For agent workers that have no auth user yet,
// it also provisions an API key so they can call the inference API. When
// NATS minting is enabled, a per-node NATS JWT is attached as well.
//
// If approval succeeds but the node cannot be re-read afterwards, a bare
// {"message": "node approved"} is returned instead of the full response.
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		nodeID := c.Param("id")

		if err := registry.ApproveNode(ctx, nodeID); err != nil {
			xlog.Error("Failed to approve node", "id", nodeID, "error", err)
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "failed to approve node"))
		}

		approved, err := registry.Get(ctx, nodeID)
		if err != nil {
			// Approval already went through; only the follow-up read failed.
			return c.JSON(http.StatusOK, map[string]string{"message": "node approved"})
		}

		body := map[string]any{
			"id":        approved.ID,
			"name":      approved.Name,
			"node_type": approved.NodeType,
			"status":    approved.Status,
			"message":   "node approved",
		}

		needsKey := approved.NodeType == nodes.NodeTypeAgent && authDB != nil && approved.AuthUserID == ""
		if needsKey {
			plaintext, provErr := provisionAgentWorkerKey(ctx, authDB, registry, approved, hmacSecret)
			if provErr == nil {
				body["api_token"] = plaintext
			} else {
				xlog.Warn("Failed to provision API key on approval", "node", approved.Name, "error", provErr)
			}
		}

		attachNatsJWT(body, approved, natsCfg)
		return c.JSON(http.StatusOK, body)
	}
}
|
|
|
|
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
|
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
|
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
|
return
|
|
}
|
|
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
|
if err != nil {
|
|
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
|
return
|
|
}
|
|
response["nats_jwt"] = jwt
|
|
response["nats_user_seed"] = seed
|
|
}
|
|
|
|
// provisionAgentWorkerKey creates a dedicated "agent-worker:<name>" user and
// API key for an agent worker node. It records the new user/key IDs on node,
// both in memory and via registry.UpdateAuthRefs, and grants the user the
// collections feature. authDB may be a transaction handle (RegisterNodeEndpoint
// passes one).
//
// Returns the plaintext API key on success. If creating the user or the key
// fails, an error is returned. In the key-failure case the just-created user
// row is removed on a best-effort basis so no key-less worker user is left
// behind. Failing to update the node's auth refs or to grant the collections
// permission is only logged.
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
	workerUser := &auth.User{
		ID:        uuid.New().String(),
		Name:      "agent-worker:" + node.Name,
		Provider:  auth.ProviderAgentWorker,
		Subject:   node.ID,
		Role:      "user",
		Status:    "active",
		CreatedAt: time.Now(),
	}
	if err := authDB.Create(workerUser).Error; err != nil {
		return "", fmt.Errorf("creating agent worker user: %w", err)
	}

	plaintext, apiKey, err := auth.CreateAPIKey(authDB, workerUser.ID, "agent-worker:"+node.Name, "user", hmacSecret, nil)
	if err != nil {
		// Without this cleanup, a non-transactional caller (the approval path)
		// would keep a user that has no key and that the node never references.
		// Inside a transaction the caller's rollback covers it anyway.
		if delErr := authDB.Exec("DELETE FROM users WHERE id = ?", workerUser.ID).Error; delErr != nil {
			xlog.Warn("Failed to remove agent worker user after API key error", "node", node.Name, "user", workerUser.ID, "error", delErr)
		}
		return "", fmt.Errorf("creating API key: %w", err)
	}

	node.AuthUserID = workerUser.ID
	node.APIKeyID = apiKey.ID
	if err := registry.UpdateAuthRefs(ctx, node.ID, workerUser.ID, apiKey.ID); err != nil {
		xlog.Warn("Failed to update auth refs on node", "node", node.Name, "error", err)
	}

	// Grant collections feature so the worker can store/retrieve KB data on behalf of users.
	perm := &auth.UserPermission{
		ID:          uuid.New().String(),
		UserID:      workerUser.ID,
		Permissions: auth.PermissionMap{auth.FeatureCollections: true},
	}
	if err := authDB.Create(perm).Error; err != nil {
		xlog.Warn("Failed to grant collections permission to agent worker", "node", node.Name, "error", err)
	}

	xlog.Info("Provisioned API key for agent worker", "node", node.Name, "user", workerUser.ID)
	return plaintext, nil
}
|
|
|
|
// DeregisterNodeEndpoint permanently removes the node identified by ":id"
// (admin use). A registry failure is logged and reported as 500.
func DeregisterNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		err := registry.Deregister(c.Request().Context(), nodeID)
		if err == nil {
			return c.JSON(http.StatusOK, map[string]string{"message": "node deregistered"})
		}
		xlog.Error("Failed to deregister node", "id", nodeID, "error", err)
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to deregister node"))
	}
}
|
|
|
|
// DeactivateNodeEndpoint marks a node as offline without deleting it.
|
|
// Used by workers on graceful shutdown to preserve approval status across restarts.
|
|
func DeactivateNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
id := c.Param("id")
|
|
if err := registry.MarkOffline(ctx, id); err != nil {
|
|
xlog.Error("Failed to deactivate node", "id", id, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to deactivate node"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]string{"message": "node set to offline"})
|
|
}
|
|
}
|
|
|
|
// HeartbeatEndpoint records a heartbeat for the node identified by ":id".
// The request body is optional. When the bound HeartbeatUpdate carries any of
// AvailableVRAM, TotalVRAM, AvailableRAM or GPUVendor, it is forwarded to the
// registry; otherwise a nil update is passed. Any registry error is logged
// and reported as 404.
func HeartbeatEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")

		// Best-effort decode: an empty or unparsable body is not an error here.
		var update nodes.HeartbeatUpdate
		_ = c.Bind(&update)

		hasResourceInfo := update.AvailableVRAM != nil ||
			update.TotalVRAM != nil ||
			update.AvailableRAM != nil ||
			update.GPUVendor != ""

		var payload *nodes.HeartbeatUpdate
		if hasResourceInfo {
			payload = &update
		}

		if err := registry.Heartbeat(c.Request().Context(), nodeID, payload); err != nil {
			xlog.Warn("Heartbeat failed for node", "id", nodeID, "error", err)
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}
		return c.JSON(http.StatusOK, map[string]string{"message": "heartbeat received"})
	}
}
|
|
|
|
// GetNodeModelsEndpoint lists the models the registry has recorded for the
// node identified by ":id". A registry failure is logged and reported as 500.
func GetNodeModelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		loaded, err := registry.GetNodeModels(c.Request().Context(), nodeID)
		if err == nil {
			return c.JSON(http.StatusOK, loaded)
		}
		xlog.Error("Failed to get node models", "id", nodeID, "error", err)
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get node models"))
	}
}
|
|
|
|
// ListAllNodeModelsEndpoint returns all loaded models across all healthy nodes.
|
|
// @Summary List all loaded models cluster-wide
|
|
// @Tags Nodes
|
|
// @Success 200 {array} nodes.NodeModel
|
|
// @Router /api/nodes/models [get]
|
|
func ListAllNodeModelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
models, err := registry.ListAllLoadedModels(ctx)
|
|
if err != nil {
|
|
xlog.Error("Failed to list all node models", "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list node models"))
|
|
}
|
|
return c.JSON(http.StatusOK, models)
|
|
}
|
|
}
|
|
|
|
// DrainNodeEndpoint puts the node identified by ":id" into draining status so
// it receives no new requests. A registry failure is logged and reported as 500.
func DrainNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		err := registry.MarkDraining(c.Request().Context(), nodeID)
		if err == nil {
			return c.JSON(http.StatusOK, map[string]string{"message": "node set to draining"})
		}
		xlog.Error("Failed to drain node", "id", nodeID, "error", err)
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to drain node"))
	}
}
|
|
|
|
// ResumeNodeEndpoint sets a draining node back to healthy status.
|
|
func ResumeNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
id := c.Param("id")
|
|
if err := registry.MarkHealthy(ctx, id); err != nil {
|
|
xlog.Error("Failed to resume node", "id", id, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to resume node"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]string{"message": "node resumed"})
|
|
}
|
|
}
|
|
|
|
// InstallBackendOnNodeEndpoint starts an asynchronous backend install on one
// worker node and immediately answers 202 with a job ID.
//
// The request names either a gallery backend ("backend") or a direct URI
// install ("uri" + "name" + optional "alias"). A gallery backend is resolved
// against the configured BackendGalleries, or against a caller-supplied
// "backend_galleries" JSON override. The URI form has the same shape as the
// standalone /api/backends/install-external path, just scoped to one node.
//
// A ManagementOp carrying TargetNodeID is pushed onto the gallery service's
// backend channel. The gallery service worker drives the install via
// DistributedBackendManager.InstallBackend, which honors TargetNodeID to
// scope the fan-out to that node. The UI polls /api/backends/job/:uid for
// progress, mirroring /api/backends/install/:id.
//
// The first (legacy unloader) argument is retained for signature symmetry
// with DeleteBackendOnNodeEndpoint / ListBackendsOnNodeEndpoint. It is unused
// because the async path goes through galleryService.
func InstallBackendOnNodeEndpoint(_ nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		if galleryService == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "gallery service not configured"))
		}
		nodeID := c.Param("id")

		var body struct {
			Backend          string `json:"backend"`
			BackendGalleries string `json:"backend_galleries,omitempty"`
			URI              string `json:"uri,omitempty"`
			Name             string `json:"name,omitempty"`
			Alias            string `json:"alias,omitempty"`
		}
		if err := c.Bind(&body); err != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}
		if body.Backend == "" && body.URI == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name or uri required"))
		}

		jobUUID, err := uuid.NewUUID()
		if err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to generate job id"))
		}
		jobID := jobUUID.String()

		// Track the job in opcache. Gallery installs are keyed by the backend
		// slug; URI installs by Name, or the URI when no Name was given. The key
		// is node-scoped so concurrent installs of the same backend on different
		// nodes do not overwrite each other.
		var backendKey string
		switch {
		case body.Backend != "":
			backendKey = body.Backend
		case body.Name != "":
			backendKey = body.Name
		default:
			backendKey = body.URI
		}
		opcache.SetBackend(galleryop.NodeScopedKey(nodeID, backendKey), jobID)

		// Optional galleries override (mirrors the standalone install path so an
		// admin can point at a private gallery). Malformed or empty overrides
		// fall back to the configured galleries.
		galleries := appConfig.BackendGalleries
		if raw := body.BackendGalleries; raw != "" {
			var override []config.Gallery
			switch err := json.Unmarshal([]byte(raw), &override); {
			case err != nil:
				xlog.Warn("Ignoring malformed backend_galleries override; falling back to configured galleries", "error", err, "nodeID", nodeID)
			case len(override) > 0:
				galleries = override
			}
		}

		// The op gets its own cancellable context detached from the request,
		// since the install continues after this handler returns.
		opCtx, cancel := context.WithCancel(context.Background())
		op := galleryop.ManagementOp[gallery.GalleryBackend, any]{
			ID:                 jobID,
			GalleryElementName: body.Backend,
			Galleries:          galleries,
			TargetNodeID:       nodeID,
			ExternalURI:        body.URI,
			ExternalName:       body.Name,
			ExternalAlias:      body.Alias,
			Context:            opCtx,
			CancelFunc:         cancel,
		}
		galleryService.StoreCancellation(jobID, cancel)
		// Enqueue from a separate goroutine so the handler never waits on the channel.
		go func() {
			galleryService.BackendGalleryChannel <- op
		}()

		xlog.Info("Node-scoped backend install dispatched", "node", nodeID, "backend", body.Backend, "uri", body.URI, "jobID", jobID)
		return c.JSON(http.StatusAccepted, map[string]string{
			"jobID":     jobID,
			"statusUrl": "/api/backends/job/" + jobID,
			"message":   "backend installation started",
		})
	}
}
|
|
|
|
// DeleteBackendOnNodeEndpoint removes an installed backend from worker node
// ":id" by sending a delete command through the NATS command sender.
// It responds 503 when NATS is not configured and 400 when the body names no
// "backend". It responds 500 when the command errors or the worker reports
// failure; both cases are logged.
func DeleteBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")

		var body struct {
			Backend string `json:"backend"`
		}
		if bindErr := c.Bind(&body); bindErr != nil || body.Backend == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name required"))
		}

		reply, err := unloader.DeleteBackend(nodeID, body.Backend)
		switch {
		case err != nil:
			xlog.Error("Failed to delete backend on node", "node", nodeID, "backend", body.Backend, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete backend on node"))
		case !reply.Success:
			xlog.Error("Backend delete failed on node", "node", nodeID, "backend", body.Backend, "error", reply.Error)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "backend deletion failed"))
		}
		return c.JSON(http.StatusOK, map[string]string{"message": "backend deleted"})
	}
}
|
|
|
|
// ListBackendsOnNodeEndpoint asks worker node ":id" over NATS for its
// installed backends and returns reply.Backends. It responds 503 when NATS is
// not configured. It responds 500 (logged) when the request errors or the
// worker reports an error.
func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")

		reply, err := unloader.ListBackends(nodeID)
		switch {
		case err != nil:
			xlog.Error("Failed to list backends on node", "node", nodeID, "error", err)
		case reply.Error != "":
			xlog.Error("List backends failed on node", "node", nodeID, "error", reply.Error)
		default:
			return c.JSON(http.StatusOK, reply.Backends)
		}
		return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list backends on node"))
	}
}
|
|
|
|
// UnloadModelOnNodeEndpoint unloads "model_name" from worker node ":id"
// (gRPC Free via NATS). It then stops the backend process for that model and
// removes every replica of the model on the node from the registry. A failure
// in either NATS step is logged and reported as 500; the registry cleanup
// runs only when both succeed.
func UnloadModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")

		var body struct {
			ModelName string `json:"model_name"`
		}
		if bindErr := c.Bind(&body); bindErr != nil || body.ModelName == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name required"))
		}
		model := body.ModelName

		if err := unloader.UnloadModelOnNode(nodeID, model); err != nil {
			xlog.Error("Failed to unload model on node", "node", nodeID, "model", model, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to unload model on node"))
		}
		// Unloading alone leaves the backend process running; stop it too.
		if err := unloader.StopBackend(nodeID, model); err != nil {
			xlog.Error("Failed to stop backend after model unload", "node", nodeID, "model", model, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "model unloaded but backend stop failed"))
		}

		registry.RemoveAllNodeModelReplicas(c.Request().Context(), nodeID, model)
		return c.JSON(http.StatusOK, map[string]string{"message": "model unloaded"})
	}
}
|
|
|
|
// DeleteModelOnNodeEndpoint removes "model_name" from worker node ":id" via
// NATS. It unloads the model (a failure here is fatal, logged, and reported
// as 500) and then stops its backend process (non-fatal: the process may not
// be running). Finally it drops every replica of the model on that node from
// the registry.
//
// NOTE(review): despite the name, this handler sends no explicit
// file-deletion command — only unload/stop plus the registry cleanup. Confirm
// whether the worker removes model files as part of those commands.
func DeleteModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		if unloader == nil {
			return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
		}
		nodeID := c.Param("id")
		var req struct {
			ModelName string `json:"model_name"`
		}
		if err := c.Bind(&req); err != nil || req.ModelName == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name required"))
		}
		// Unload model first if loaded. Log the cause so the generic 500 can be
		// diagnosed, as UnloadModelOnNodeEndpoint does for the same failure.
		if err := unloader.UnloadModelOnNode(nodeID, req.ModelName); err != nil {
			xlog.Error("Failed to unload model before deletion", "node", nodeID, "model", req.ModelName, "error", err)
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to unload model before deletion"))
		}
		if err := unloader.StopBackend(nodeID, req.ModelName); err != nil {
			// Non-fatal — backend process may not be running
			xlog.Warn("StopBackend failed during model deletion (non-fatal)", "node", nodeID, "model", req.ModelName, "error", err)
		}
		registry.RemoveAllNodeModelReplicas(c.Request().Context(), nodeID, req.ModelName)
		return c.JSON(http.StatusOK, map[string]string{"message": "model deleted from node"})
	}
}
|
|
|
|
// NodeBackendLogsListEndpoint proxies a request to a worker node's /v1/backend-logs
// endpoint to list model IDs that have backend logs.
func NodeBackendLogsListEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc {
	return func(c echo.Context) error {
		return relayWorkerBackendLogs(c, registry, c.Param("id"), "/v1/backend-logs", registrationToken)
	}
}

// NodeBackendLogsLinesEndpoint proxies a request to a worker node's
// /v1/backend-logs/{modelId} endpoint to get buffered log lines. The model ID
// is path-escaped before it is forwarded.
func NodeBackendLogsLinesEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc {
	return func(c echo.Context) error {
		nodeID := c.Param("id")
		modelID := c.Param("modelId")
		path := "/v1/backend-logs/" + url.PathEscape(modelID)
		return relayWorkerBackendLogs(c, registry, nodeID, path, registrationToken)
	}
}

// relayWorkerBackendLogs looks up nodeID and forwards a request for path to
// the node's HTTP address via proxyHTTPToWorker. It relays the worker's status
// code and body back to the client with a JSON content type.
//
// It responds 404 when the node is unknown, and 502 when the node has no HTTP
// address or cannot be reached. A failure while streaming the body is only
// logged, because the status line has already been written by then.
func relayWorkerBackendLogs(c echo.Context, registry *nodes.NodeRegistry, nodeID, path, registrationToken string) error {
	node, err := registry.Get(c.Request().Context(), nodeID)
	if err != nil {
		return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
	}

	if node.HTTPAddress == "" {
		return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, "node has no HTTP address"))
	}

	resp, err := proxyHTTPToWorker(node.HTTPAddress, path, registrationToken)
	if err != nil {
		return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, fmt.Sprintf("failed to reach worker: %v", err)))
	}
	defer resp.Body.Close()

	c.Response().Header().Set("Content-Type", "application/json")
	c.Response().WriteHeader(resp.StatusCode)
	if _, err := io.Copy(c.Response(), resp.Body); err != nil {
		xlog.Warn("Failed to relay backend logs from worker", "node", nodeID, "path", path, "error", err)
	}
	return nil
}
|
|
|
|
// NodeBackendLogsWSEndpoint proxies a WebSocket connection to a worker node's
// /v1/backend-logs/{modelId}/ws endpoint for real-time log streaming.
//
// The browser connection is upgraded only if its Origin header is absent or
// its host matches the request host. If registrationToken is set, it is sent
// to the worker as a Bearer token. Frames are relayed in both directions
// until either side closes.
func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken string) echo.HandlerFunc {
	browserUpgrader := websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			origin := r.Header.Get("Origin")
			if origin == "" {
				return true // no origin header = same-origin or non-browser
			}
			// Parse origin URL and compare host with request host
			u, err := url.Parse(origin)
			if err != nil {
				return false
			}
			return u.Host == r.Host
		},
	}

	return func(c echo.Context) error {
		ctx := c.Request().Context()
		nodeID := c.Param("id")
		modelID := c.Param("modelId")

		node, err := registry.Get(ctx, nodeID)
		if err != nil {
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}

		// Same guard as the HTTP log endpoints. Without an HTTP address the dial
		// below would target "ws:///v1/...", which cannot succeed. Failing
		// before the upgrade gives the caller a proper HTTP 502.
		if node.HTTPAddress == "" {
			return c.JSON(http.StatusBadGateway, nodeError(http.StatusBadGateway, "node has no HTTP address"))
		}

		// Upgrade browser connection
		browserWS, err := browserUpgrader.Upgrade(c.Response(), c.Request(), nil)
		if err != nil {
			return err
		}

		// Dial the worker WebSocket. DialContext ties the dial to the request
		// context in addition to the handshake timeout.
		workerURL := fmt.Sprintf("ws://%s/v1/backend-logs/%s/ws", node.HTTPAddress, url.PathEscape(modelID))
		workerHeaders := http.Header{}
		if registrationToken != "" {
			workerHeaders.Set("Authorization", "Bearer "+registrationToken)
		}

		workerDialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
		workerWS, _, err := workerDialer.DialContext(ctx, workerURL, workerHeaders)
		if err != nil {
			xlog.Warn("Failed to connect to worker backend-logs websocket", "node", nodeID, "model", modelID, "error", err)
			browserWS.WriteMessage(websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "failed to connect to worker"))
			browserWS.Close()
			return nil
		}

		// Use sync.OnceFunc wrappers to avoid double-close and ensure each
		// goroutine can safely close the *other* connection to unblock
		// its peer's ReadMessage call.
		done := make(chan struct{})
		closeWorker := sync.OnceFunc(func() { workerWS.Close() })
		closeBrowser := sync.OnceFunc(func() { browserWS.Close() })

		// Worker → Browser
		go func() {
			defer close(done)
			defer closeBrowser() // unblock Browser→Worker goroutine
			for {
				msgType, msg, err := workerWS.ReadMessage()
				if err != nil {
					return
				}
				if err := browserWS.WriteMessage(msgType, msg); err != nil {
					return
				}
			}
		}()

		// Browser → Worker (mainly for close detection)
		go func() {
			defer closeWorker() // unblock Worker→Browser goroutine
			for {
				msgType, msg, err := browserWS.ReadMessage()
				if err != nil {
					return
				}
				if err := workerWS.WriteMessage(msgType, msg); err != nil {
					return
				}
			}
		}()

		<-done
		closeWorker()
		closeBrowser()
		return nil
	}
}
|
|
|
|
// GetNodeLabelsEndpoint returns labels for a specific node.
|
|
func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
labels, err := registry.GetNodeLabels(ctx, nodeID)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels"))
|
|
}
|
|
// Convert to map for cleaner API response
|
|
result := make(map[string]string)
|
|
for _, l := range labels {
|
|
result[l.Key] = l.Value
|
|
}
|
|
return c.JSON(http.StatusOK, result)
|
|
}
|
|
}
|
|
|
|
// UpdateMaxReplicasPerModelRequest is the JSON payload accepted by
// UpdateMaxReplicasPerModelEndpoint.
type UpdateMaxReplicasPerModelRequest struct {
	// Value is the per-model replica cap to apply on the node. The endpoint
	// rejects anything below 1.
	Value int `json:"value"`
}
|
|
|
|
// UpdateMaxReplicasPerModelEndpoint sets the per-node cap on how many replicas
|
|
// of any one model can be loaded concurrently. The corresponding
|
|
// `node.replica-slots` auto-label is refreshed so existing AND-selectors keep
|
|
// matching, and any unsatisfiable scheduling cooldowns are cleared so the
|
|
// reconciler retries on the next tick.
|
|
//
|
|
// This is a transient admin override — a worker re-registration restores the
|
|
// value the worker was started with (--max-replicas-per-model). For permanent
|
|
// fleet changes, change the worker flag.
|
|
//
|
|
// @Summary Update a node's max replicas per model
|
|
// @Tags Nodes
|
|
// @Param id path string true "Node ID"
|
|
// @Param request body UpdateMaxReplicasPerModelRequest true "New value"
|
|
// @Success 200 {object} map[string]int
|
|
// @Failure 400 {object} map[string]any "value must be >= 1"
|
|
// @Failure 404 {object} map[string]any "node not found"
|
|
// @Router /api/nodes/{id}/max-replicas-per-model [put]
|
|
func UpdateMaxReplicasPerModelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
var req UpdateMaxReplicasPerModelRequest
|
|
if err := c.Bind(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
if req.Value < 1 {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "value must be >= 1"))
|
|
}
|
|
if err := registry.UpdateMaxReplicasPerModel(ctx, nodeID, req.Value); err != nil {
|
|
xlog.Error("Failed to update max_replicas_per_model", "node", nodeID, "value", req.Value, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to update max replicas per model"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]int{"max_replicas_per_model": req.Value})
|
|
}
|
|
}
|
|
|
|
// ResetMaxReplicasPerModelEndpoint clears the admin override on a node, so
|
|
// the next worker re-registration is allowed to update the value from its
|
|
// CLI flag again. The current value is left in place until the worker calls
|
|
// register.
|
|
//
|
|
// @Summary Reset a node's max replicas per model to the worker default
|
|
// @Tags Nodes
|
|
// @Param id path string true "Node ID"
|
|
// @Success 200 {object} map[string]bool
|
|
// @Failure 404 {object} map[string]any "node not found"
|
|
// @Router /api/nodes/{id}/max-replicas-per-model [delete]
|
|
func ResetMaxReplicasPerModelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
if err := registry.ResetMaxReplicasPerModel(ctx, nodeID); err != nil {
|
|
xlog.Error("Failed to reset max_replicas_per_model override", "node", nodeID, "error", err)
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to reset override"))
|
|
}
|
|
return c.JSON(http.StatusOK, map[string]bool{"reset": true})
|
|
}
|
|
}
|
|
|
|
// SetNodeLabelsEndpoint replaces all labels for a node.
|
|
func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
var labels map[string]string
|
|
if err := c.Bind(&labels); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
if err := registry.SetNodeLabels(ctx, nodeID, labels); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set labels"))
|
|
}
|
|
return c.JSON(http.StatusOK, labels)
|
|
}
|
|
}
|
|
|
|
// MergeNodeLabelsEndpoint adds/updates labels without removing existing ones.
|
|
func MergeNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
if _, err := registry.Get(ctx, nodeID); err != nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
|
}
|
|
var labels map[string]string
|
|
if err := c.Bind(&labels); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
for k, v := range labels {
|
|
if err := registry.SetNodeLabel(ctx, nodeID, k, v); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to merge labels"))
|
|
}
|
|
}
|
|
// Return updated labels
|
|
updated, _ := registry.GetNodeLabels(ctx, nodeID)
|
|
result := make(map[string]string)
|
|
for _, l := range updated {
|
|
result[l.Key] = l.Value
|
|
}
|
|
return c.JSON(http.StatusOK, result)
|
|
}
|
|
}
|
|
|
|
// DeleteNodeLabelEndpoint removes a single label from a node.
|
|
func DeleteNodeLabelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
nodeID := c.Param("id")
|
|
key := c.Param("key")
|
|
if key == "" {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "label key is required"))
|
|
}
|
|
if err := registry.RemoveNodeLabel(ctx, nodeID, key); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to remove label"))
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
}
|
|
|
|
// ListSchedulingEndpoint returns all model scheduling configs.
|
|
func ListSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
configs, err := registry.ListModelSchedulings(ctx)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list scheduling configs"))
|
|
}
|
|
return c.JSON(http.StatusOK, configs)
|
|
}
|
|
}
|
|
|
|
// GetSchedulingEndpoint returns the scheduling config for a specific model.
|
|
func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
modelName := c.Param("model")
|
|
config, err := registry.GetModelScheduling(ctx, modelName)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get scheduling config"))
|
|
}
|
|
if config == nil {
|
|
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "no scheduling config for model"))
|
|
}
|
|
return c.JSON(http.StatusOK, config)
|
|
}
|
|
}
|
|
|
|
// SetSchedulingRequest is the JSON payload for creating or updating a model's
// scheduling config.
//
// The first group of fields uses full-replace (PUT) semantics. The four
// prefix-cache fields are pointers so that "not sent" (nil) can be told apart
// from an explicit zero. SetSchedulingEndpoint keeps the stored value for any
// prefix-cache field left nil instead of resetting it.
type SetSchedulingRequest struct {
	// Full-replace fields.
	ModelName    string            `json:"model_name"`
	NodeSelector map[string]string `json:"node_selector,omitempty"`
	MinReplicas  int               `json:"min_replicas"`
	MaxReplicas  int               `json:"max_replicas"`
	SpreadAll    bool              `json:"spread_all,omitempty"`

	// Prefix-cache routing fields; nil means "keep the stored value".
	RoutePolicy         *string  `json:"route_policy,omitempty"`
	BalanceAbsThreshold *int     `json:"balance_abs_threshold,omitempty"`
	BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
	MinPrefixMatch      *float64 `json:"min_prefix_match,omitempty"`
}
|
|
|
|
// validateSchedulingRequest enforces the invariants of a scheduling config.
|
|
// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the
|
|
// single source of truth), and are checked against the RESOLVED values passed
|
|
// in (provided-or-preserved), so validation only rejects bad values the caller
|
|
// actually supplied. It returns nil when valid, or an error with a user-facing
|
|
// message describing the first violation.
|
|
func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
|
|
if req.ModelName == "" {
|
|
return errors.New("model_name is required")
|
|
}
|
|
if req.SpreadAll && (req.MinReplicas != 0 || req.MaxReplicas != 0) {
|
|
return errors.New("spread_all and min_replicas/max_replicas are mutually exclusive")
|
|
}
|
|
if req.MinReplicas < 0 {
|
|
return errors.New("min_replicas must be >= 0")
|
|
}
|
|
if req.MaxReplicas < 0 {
|
|
return errors.New("max_replicas must be >= 0")
|
|
}
|
|
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
|
return errors.New("min_replicas must be <= max_replicas")
|
|
}
|
|
if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
|
//
|
|
// The registry upsert full-replaces all columns, so a request that omits the
|
|
// prefix-cache fields would otherwise wipe a model's previously-configured
|
|
// routing settings. To avoid that footgun the four prefix-cache fields are
|
|
// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
|
|
// existing config's value (or the zero default when no config exists yet). The
|
|
// non-prefix fields keep their full-replace PUT behavior.
|
|
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
var req SetSchedulingRequest
|
|
if err := c.Bind(&req); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
|
}
|
|
|
|
// Fetch the existing config (may be nil) so omitted prefix-cache fields
|
|
// can fall back to the stored value rather than resetting to zero.
|
|
var existing *nodes.ModelSchedulingConfig
|
|
if req.ModelName != "" {
|
|
var err error
|
|
existing, err = registry.GetModelScheduling(ctx, req.ModelName)
|
|
if err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
|
|
}
|
|
}
|
|
|
|
// Resolve each prefix-cache field: provided pointer wins, otherwise keep
|
|
// the existing value (zero/default when there is no existing config).
|
|
routePolicy := ""
|
|
absThr := 0
|
|
relThr := 0.0
|
|
minMatch := 0.0
|
|
if existing != nil {
|
|
routePolicy = existing.RoutePolicy
|
|
absThr = existing.BalanceAbsThreshold
|
|
relThr = existing.BalanceRelThreshold
|
|
minMatch = existing.MinPrefixMatch
|
|
}
|
|
if req.RoutePolicy != nil {
|
|
routePolicy = *req.RoutePolicy
|
|
}
|
|
if req.BalanceAbsThreshold != nil {
|
|
absThr = *req.BalanceAbsThreshold
|
|
}
|
|
if req.BalanceRelThreshold != nil {
|
|
relThr = *req.BalanceRelThreshold
|
|
}
|
|
if req.MinPrefixMatch != nil {
|
|
minMatch = *req.MinPrefixMatch
|
|
}
|
|
|
|
if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
|
|
}
|
|
|
|
// Serialize node selector to JSON
|
|
var selectorJSON string
|
|
if len(req.NodeSelector) > 0 {
|
|
b, err := json.Marshal(req.NodeSelector)
|
|
if err != nil {
|
|
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid node_selector"))
|
|
}
|
|
selectorJSON = string(b)
|
|
}
|
|
|
|
config := &nodes.ModelSchedulingConfig{
|
|
ModelName: req.ModelName,
|
|
NodeSelector: selectorJSON,
|
|
MinReplicas: req.MinReplicas,
|
|
MaxReplicas: req.MaxReplicas,
|
|
SpreadAll: req.SpreadAll,
|
|
RoutePolicy: routePolicy,
|
|
BalanceAbsThreshold: absThr,
|
|
BalanceRelThreshold: relThr,
|
|
MinPrefixMatch: minMatch,
|
|
}
|
|
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
|
}
|
|
return c.JSON(http.StatusOK, config)
|
|
}
|
|
}
|
|
|
|
// DeleteSchedulingEndpoint removes a model scheduling config.
|
|
func DeleteSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
modelName := c.Param("model")
|
|
if err := registry.DeleteModelScheduling(ctx, modelName); err != nil {
|
|
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete scheduling config"))
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
}
|
|
|
|
// proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth.
|
|
func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) {
|
|
reqURL := fmt.Sprintf("http://%s%s", httpAddress, path)
|
|
req, err := http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
client := httpclient.NewWithTimeout(15 * time.Second)
|
|
return client.Do(req)
|
|
}
|