mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 05:36:49 -04:00
903 lines
34 KiB
Go
903 lines
34 KiB
Go
package nodes
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/mudler/LocalAI/core/services/advisorylock"
|
|
"github.com/mudler/xlog"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// BackendNode represents a remote worker node.
// Workers are generic — they don't have a fixed backend type.
// The SmartRouter dynamically installs backends via NATS backend.install events.
//
// The struct is persisted through GORM: the gorm tags define column names,
// sizes, indexes and defaults, and the json tags define the serialized shape.
// Fields tagged json:"-" are never serialized.
type BackendNode struct {
	ID            string    `gorm:"primaryKey;size:36" json:"id"`
	Name          string    `gorm:"uniqueIndex;size:255" json:"name"`
	NodeType      string    `gorm:"size:32;default:backend" json:"node_type"`    // backend, agent (see NodeType* constants)
	Address       string    `gorm:"size:255" json:"address"`                     // host:port for gRPC
	HTTPAddress   string    `gorm:"size:255" json:"http_address"`                // host:port for HTTP file transfer
	Status        string    `gorm:"size:32;default:registering" json:"status"`   // registering (column default), healthy, unhealthy, draining, pending, offline
	TokenHash     string    `gorm:"size:64" json:"-"`                            // SHA-256 of registration token
	TotalVRAM     uint64    `gorm:"column:total_vram" json:"total_vram"`         // Total GPU VRAM in bytes
	AvailableVRAM uint64    `gorm:"column:available_vram" json:"available_vram"` // Available GPU VRAM in bytes
	TotalRAM      uint64    `gorm:"column:total_ram" json:"total_ram"`           // Total system RAM in bytes (fallback when no GPU)
	AvailableRAM  uint64    `gorm:"column:available_ram" json:"available_ram"`   // Available system RAM in bytes
	GPUVendor     string    `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
	APIKeyID      string    `gorm:"size:36" json:"-"`                            // auto-provisioned API key ID (for cleanup)
	AuthUserID    string    `gorm:"size:36" json:"-"`                            // auto-provisioned user ID (for cleanup)
	LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
	CreatedAt     time.Time `json:"created_at"`
	UpdatedAt     time.Time `json:"updated_at"`
}
|
|
|
|
const (
	// Values stored in BackendNode.NodeType.
	NodeTypeBackend = "backend"
	NodeTypeAgent   = "agent"

	// Values stored in BackendNode.Status.
	StatusHealthy   = "healthy"
	StatusPending   = "pending"
	StatusOffline   = "offline"
	StatusDraining  = "draining"
	StatusUnhealthy = "unhealthy"

	// Column names (must match gorm:"column:" tags on BackendNode)
	ColAvailableVRAM = "available_vram"
	ColTotalVRAM     = "total_vram"
	ColAvailableRAM  = "available_ram"
	ColGPUVendor     = "gpu_vendor"
	ColLastHeartbeat = "last_heartbeat"
)
|
|
|
|
// NodeModel tracks which models are loaded on which nodes.
// Rows are looked up by the (NodeID, ModelName) pair throughout this file.
type NodeModel struct {
	ID        string    `gorm:"primaryKey;size:36" json:"id"`
	NodeID    string    `gorm:"index;size:36" json:"node_id"`
	ModelName string    `gorm:"index;size:255" json:"model_name"`
	Address   string    `gorm:"size:255" json:"address"`           // gRPC address for this model's backend process
	State     string    `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
	InFlight  int       `json:"in_flight"`                         // number of active requests
	LastUsed  time.Time `json:"last_used"`
	LoadingBy string    `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
// NodeLabel is a key-value label on a node (like K8s labels).
// NodeID and Key share the composite unique index idx_node_label, so a node
// holds at most one value per key.
type NodeLabel struct {
	ID     string `gorm:"primaryKey;size:36" json:"id"`
	NodeID string `gorm:"uniqueIndex:idx_node_label;size:36" json:"node_id"`
	Key    string `gorm:"uniqueIndex:idx_node_label;size:128" json:"key"`
	Value  string `gorm:"size:255" json:"value"`
}
|
|
|
|
// ModelSchedulingConfig defines how a model should be scheduled across the cluster.
// All fields are optional:
//   - NodeSelector only → constrain nodes, single replica
//   - MinReplicas/MaxReplicas only → auto-scale on any node
//   - Both → auto-scale on matching nodes
//   - Neither → no-op (default behavior)
//
// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0.
//
// ModelName carries a unique index, so there is at most one config per model.
type ModelSchedulingConfig struct {
	ID           string    `gorm:"primaryKey;size:36" json:"id"`
	ModelName    string    `gorm:"uniqueIndex;size:255" json:"model_name"`
	NodeSelector string    `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
	MinReplicas  int       `gorm:"default:0" json:"min_replicas"`
	MaxReplicas  int       `gorm:"default:0" json:"max_replicas"`
	CreatedAt    time.Time `json:"created_at"`
	UpdatedAt    time.Time `json:"updated_at"`
}
|
|
|
|
// NodeWithExtras extends BackendNode with computed fields for list views.
// BackendNode is embedded, so its fields are serialized at the top level
// alongside model_count and labels.
type NodeWithExtras struct {
	BackendNode
	ModelCount int               `json:"model_count"`
	Labels     map[string]string `json:"labels,omitempty"`
}
|
|
|
|
// NodeRegistry manages backend node registration and lookup in PostgreSQL.
// Every method derives a context-bound session from db via WithContext.
type NodeRegistry struct {
	db *gorm.DB // shared GORM handle supplied to NewNodeRegistry
}
|
|
|
|
// NewNodeRegistry creates a NodeRegistry and auto-migrates the schema.
|
|
// Uses a PostgreSQL advisory lock to prevent concurrent migration races
|
|
// when multiple instances (frontend + workers) start at the same time.
|
|
func NewNodeRegistry(db *gorm.DB) (*NodeRegistry, error) {
|
|
if err := advisorylock.WithLockCtx(context.Background(), db, advisorylock.KeySchemaMigrate, func() error {
|
|
return db.AutoMigrate(&BackendNode{}, &NodeModel{}, &NodeLabel{}, &ModelSchedulingConfig{})
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("migrating node tables: %w", err)
|
|
}
|
|
return &NodeRegistry{db: db}, nil
|
|
}
|
|
|
|
// Register adds or updates a backend node.
|
|
// If autoApprove is true, the node goes directly to "healthy" status.
|
|
// If false, new nodes start in "pending" status and must be approved by an admin.
|
|
// On re-registration (same name), previously approved nodes return to "healthy";
|
|
// nodes that were never approved stay in "pending".
|
|
func (r *NodeRegistry) Register(ctx context.Context, node *BackendNode, autoApprove bool) error {
|
|
node.LastHeartbeat = time.Now()
|
|
|
|
// Try to find existing node by name
|
|
var existing BackendNode
|
|
err := r.db.WithContext(ctx).Where("name = ?", node.Name).First(&existing).Error
|
|
if err == nil {
|
|
// Re-registration (node restart): preserve ID, respect approval history
|
|
node.ID = existing.ID
|
|
if autoApprove || existing.Status != StatusPending {
|
|
// Auto-approve enabled, or node was previously approved — restore healthy
|
|
node.Status = StatusHealthy
|
|
} else {
|
|
// Node was never approved — keep pending
|
|
node.Status = StatusPending
|
|
}
|
|
if err := r.db.WithContext(ctx).Model(&existing).Updates(node).Error; err != nil {
|
|
return fmt.Errorf("updating node %s: %w", node.Name, err)
|
|
}
|
|
// Preserve auth references from existing record.
|
|
// GORM Updates(struct) skips zero-value fields, so the DB retains
|
|
// the old auth_user_id/api_key_id but the caller's struct is empty.
|
|
// Copy them back so the caller can revoke old credentials on re-registration.
|
|
if node.AuthUserID == "" {
|
|
node.AuthUserID = existing.AuthUserID
|
|
}
|
|
if node.APIKeyID == "" {
|
|
node.APIKeyID = existing.APIKeyID
|
|
}
|
|
// Clear stale model records — the node restarted and has nothing loaded
|
|
if err := r.db.WithContext(ctx).Where("node_id = ?", existing.ID).Delete(&NodeModel{}).Error; err != nil {
|
|
xlog.Warn("Failed to clear stale model records on re-register", "node", node.Name, "error", err)
|
|
}
|
|
} else if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Create new node
|
|
if node.ID == "" {
|
|
node.ID = uuid.New().String()
|
|
}
|
|
if autoApprove {
|
|
node.Status = StatusHealthy
|
|
} else {
|
|
node.Status = StatusPending
|
|
}
|
|
if err := r.db.WithContext(ctx).Create(node).Error; err != nil {
|
|
return fmt.Errorf("creating node %s: %w", node.Name, err)
|
|
}
|
|
} else {
|
|
return fmt.Errorf("looking up node %s: %w", node.Name, err)
|
|
}
|
|
|
|
xlog.Info("Node registered", "name", node.Name, "address", node.Address, "status", node.Status)
|
|
return nil
|
|
}
|
|
|
|
// UpdateAuthRefs stores the auto-provisioned user and API key IDs on a node.
|
|
func (r *NodeRegistry) UpdateAuthRefs(ctx context.Context, nodeID, authUserID, apiKeyID string) error {
|
|
return r.db.WithContext(ctx).Model(&BackendNode{}).Where("id = ?", nodeID).Updates(map[string]any{
|
|
"auth_user_id": authUserID,
|
|
"api_key_id": apiKeyID,
|
|
}).Error
|
|
}
|
|
|
|
// ApproveNode sets a pending node's status to healthy.
|
|
func (r *NodeRegistry) ApproveNode(ctx context.Context, nodeID string) error {
|
|
result := r.db.WithContext(ctx).Model(&BackendNode{}).
|
|
Where("id = ? AND status = ?", nodeID, StatusPending).
|
|
Update("status", StatusHealthy)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("approving node %s: %w", nodeID, result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("node %s not found or not in pending status", nodeID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// setStatus updates a node's status column in the database.
|
|
func (r *NodeRegistry) setStatus(ctx context.Context, nodeID, status string) error {
|
|
result := r.db.WithContext(ctx).Model(&BackendNode{}).
|
|
Where("id = ?", nodeID).Update("status", status)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("setting node %s to %s: %w", nodeID, status, result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("node %s not found", nodeID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MarkOffline sets a node to offline status and clears its model records.
|
|
// Used on graceful shutdown — preserves the node row so re-registration
|
|
// can restore the previous approval status.
|
|
func (r *NodeRegistry) MarkOffline(ctx context.Context, nodeID string) error {
|
|
if err := r.setStatus(ctx, nodeID, StatusOffline); err != nil {
|
|
return err
|
|
}
|
|
// Clear model records — node is shutting down
|
|
if err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error; err != nil {
|
|
xlog.Warn("Failed to clear model records on offline", "node", nodeID, "error", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// FindNodeWithVRAM returns healthy nodes with at least minBytes available VRAM,
|
|
// ordered idle-first then least-loaded.
|
|
func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*BackendNode, error) {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
loadedModels := db.Model(&NodeModel{}).
|
|
Select("node_id").
|
|
Where("state = ?", "loaded").
|
|
Group("node_id")
|
|
|
|
subquery := db.Model(&NodeModel{}).
|
|
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
|
Group("node_id")
|
|
|
|
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
|
|
var node BackendNode
|
|
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, minBytes, loadedModels).
|
|
Order("available_vram DESC").
|
|
First(&node).Error
|
|
if err == nil {
|
|
return &node, nil
|
|
}
|
|
|
|
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
|
|
err = db.Where("status = ? AND node_type = ? AND available_vram >= ?", StatusHealthy, NodeTypeBackend, minBytes).
|
|
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
|
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
|
First(&node).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no healthy nodes with %d bytes available VRAM: %w", minBytes, err)
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// Deregister removes a backend node, its model associations, and any auto-provisioned auth credentials.
|
|
func (r *NodeRegistry) Deregister(ctx context.Context, nodeID string) error {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
var node BackendNode
|
|
if err := db.Where("id = ?", nodeID).First(&node).Error; err != nil {
|
|
return fmt.Errorf("node %s not found: %w", nodeID, err)
|
|
}
|
|
|
|
return db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error; err != nil {
|
|
return fmt.Errorf("deleting node models for %s: %w", nodeID, err)
|
|
}
|
|
if err := tx.Where("id = ?", nodeID).Delete(&BackendNode{}).Error; err != nil {
|
|
return fmt.Errorf("deleting node %s: %w", nodeID, err)
|
|
}
|
|
// Clean up auto-provisioned auth user (cascades to API keys via FK)
|
|
if node.AuthUserID != "" {
|
|
if err := tx.Exec("DELETE FROM users WHERE id = ?", node.AuthUserID).Error; err != nil {
|
|
xlog.Warn("Failed to clean up agent worker user", "node", node.Name, "userID", node.AuthUserID, "error", err)
|
|
// non-fatal: don't rollback the whole deregistration for auth cleanup
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// HeartbeatUpdate contains optional fields to update on heartbeat.
// Pointer fields are applied by Heartbeat only when non-nil, and GPUVendor
// only when non-empty, so absent values leave the stored columns untouched.
type HeartbeatUpdate struct {
	AvailableVRAM *uint64 `json:"available_vram,omitempty"`
	TotalVRAM     *uint64 `json:"total_vram,omitempty"`
	AvailableRAM  *uint64 `json:"available_ram,omitempty"`
	GPUVendor     string  `json:"gpu_vendor,omitempty"`
}
|
|
|
|
// Heartbeat updates the heartbeat timestamp and status for a node.
|
|
// Nodes in "pending" or "offline" status stay in their current status —
|
|
// they must be approved or re-register respectively.
|
|
func (r *NodeRegistry) Heartbeat(ctx context.Context, nodeID string, update *HeartbeatUpdate) error {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
updates := map[string]any{
|
|
ColLastHeartbeat: time.Now(),
|
|
}
|
|
|
|
if update != nil {
|
|
if update.AvailableVRAM != nil {
|
|
updates[ColAvailableVRAM] = *update.AvailableVRAM
|
|
}
|
|
if update.TotalVRAM != nil {
|
|
updates[ColTotalVRAM] = *update.TotalVRAM
|
|
}
|
|
if update.AvailableRAM != nil {
|
|
updates[ColAvailableRAM] = *update.AvailableRAM
|
|
}
|
|
if update.GPUVendor != "" {
|
|
updates[ColGPUVendor] = update.GPUVendor
|
|
}
|
|
}
|
|
|
|
// Only update all fields (including status promotion) for active nodes.
|
|
// Pending and offline nodes must go through approval or re-registration.
|
|
result := db.Model(&BackendNode{}).
|
|
Where("id = ? AND status NOT IN ?", nodeID, []string{StatusPending, StatusOffline}).
|
|
Updates(updates)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("heartbeat for %s: %w", nodeID, result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
// May be pending or offline — still update heartbeat timestamp
|
|
result = db.Model(&BackendNode{}).Where("id = ?", nodeID).Update(ColLastHeartbeat, time.Now())
|
|
if result.Error != nil {
|
|
return fmt.Errorf("heartbeat for %s: %w", nodeID, result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("node %s not found", nodeID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List returns all registered nodes.
|
|
func (r *NodeRegistry) List(ctx context.Context) ([]BackendNode, error) {
|
|
var nodes []BackendNode
|
|
if err := r.db.WithContext(ctx).Order("name").Find(&nodes).Error; err != nil {
|
|
return nil, fmt.Errorf("listing nodes: %w", err)
|
|
}
|
|
return nodes, nil
|
|
}
|
|
|
|
// Get returns a single node by ID.
|
|
func (r *NodeRegistry) Get(ctx context.Context, nodeID string) (*BackendNode, error) {
|
|
var node BackendNode
|
|
if err := r.db.WithContext(ctx).First(&node, "id = ?", nodeID).Error; err != nil {
|
|
return nil, fmt.Errorf("getting node %s: %w", nodeID, err)
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// GetByName returns a single node by name.
|
|
func (r *NodeRegistry) GetByName(ctx context.Context, name string) (*BackendNode, error) {
|
|
var node BackendNode
|
|
if err := r.db.WithContext(ctx).First(&node, "name = ?", name).Error; err != nil {
|
|
return nil, fmt.Errorf("getting node by name %s: %w", name, err)
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// MarkUnhealthy sets a node status to unhealthy.
|
|
func (r *NodeRegistry) MarkUnhealthy(ctx context.Context, nodeID string) error {
|
|
return r.setStatus(ctx, nodeID, StatusUnhealthy)
|
|
}
|
|
|
|
// MarkHealthy sets a node status to healthy.
|
|
func (r *NodeRegistry) MarkHealthy(ctx context.Context, nodeID string) error {
|
|
return r.setStatus(ctx, nodeID, StatusHealthy)
|
|
}
|
|
|
|
// MarkDraining sets a node status to draining (no new requests).
|
|
func (r *NodeRegistry) MarkDraining(ctx context.Context, nodeID string) error {
|
|
return r.setStatus(ctx, nodeID, StatusDraining)
|
|
}
|
|
|
|
// FindStaleNodes returns nodes that haven't sent a heartbeat within the given threshold.
|
|
// Excludes unhealthy, offline, and pending nodes since they're not actively participating.
|
|
func (r *NodeRegistry) FindStaleNodes(ctx context.Context, threshold time.Duration) ([]BackendNode, error) {
|
|
var nodes []BackendNode
|
|
cutoff := time.Now().Add(-threshold)
|
|
if err := r.db.WithContext(ctx).Where("last_heartbeat < ? AND status NOT IN ?", cutoff,
|
|
[]string{StatusUnhealthy, StatusOffline, StatusPending}).
|
|
Find(&nodes).Error; err != nil {
|
|
return nil, fmt.Errorf("finding stale nodes: %w", err)
|
|
}
|
|
return nodes, nil
|
|
}
|
|
|
|
// --- NodeModel operations ---
|
|
|
|
// SetNodeModel records that a model is loaded on a node.
|
|
func (r *NodeRegistry) SetNodeModel(ctx context.Context, nodeID, modelName, state, address string, initialInFlight int) error {
|
|
now := time.Now()
|
|
// Use Attrs for creation-only fields (ID) and Assign for update-only fields.
|
|
// Attrs is applied only when creating a new record. Assign is applied on
|
|
// both create and update. This prevents overwriting the primary key on
|
|
// subsequent calls for the same node+model.
|
|
var nm NodeModel
|
|
result := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).
|
|
Attrs(NodeModel{ID: uuid.New().String(), NodeID: nodeID, ModelName: modelName}).
|
|
Assign(map[string]any{"address": address, "state": state, "last_used": now, "in_flight": initialInFlight}).
|
|
FirstOrCreate(&nm)
|
|
return result.Error
|
|
}
|
|
|
|
// RemoveNodeModel removes a model association from a node.
|
|
func (r *NodeRegistry) RemoveNodeModel(ctx context.Context, nodeID, modelName string) error {
|
|
return r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).
|
|
Delete(&NodeModel{}).Error
|
|
}
|
|
|
|
// FindNodesWithModel returns nodes that have the given model loaded.
|
|
func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error) {
|
|
var nodes []BackendNode
|
|
if err := r.db.WithContext(ctx).Joins("JOIN node_models ON node_models.node_id = backend_nodes.id").
|
|
Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?",
|
|
modelName, "loaded", StatusHealthy).
|
|
Order("node_models.in_flight ASC").
|
|
Find(&nodes).Error; err != nil {
|
|
return nil, fmt.Errorf("finding nodes with model %s: %w", modelName, err)
|
|
}
|
|
return nodes, nil
|
|
}
|
|
|
|
// FindAndLockNodeWithModel atomically finds the least-loaded node with the given
|
|
// model loaded and increments its in-flight counter within a single transaction.
|
|
// The SELECT FOR UPDATE row lock prevents concurrent eviction from removing the
|
|
// NodeModel row between the find and increment operations.
|
|
func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string) (*BackendNode, *NodeModel, error) {
|
|
var nm NodeModel
|
|
var node BackendNode
|
|
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// Order by in_flight ASC (least busy replica), then by available_vram DESC
|
|
// (prefer nodes with more free VRAM to spread load across the cluster).
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
|
Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded").
|
|
Order("node_models.in_flight ASC, backend_nodes.available_vram DESC").
|
|
First(&nm).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := tx.Model(&nm).Updates(map[string]any{
|
|
"in_flight": gorm.Expr("in_flight + 1"),
|
|
"last_used": time.Now(),
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := tx.Where("id = ? AND status = ?", nm.NodeID, StatusHealthy).
|
|
First(&node).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return &node, &nm, nil
|
|
}
|
|
|
|
// TouchNodeModel updates the last_used timestamp for LRU tracking.
|
|
func (r *NodeRegistry) TouchNodeModel(ctx context.Context, nodeID, modelName string) {
|
|
r.db.WithContext(ctx).Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", nodeID, modelName).
|
|
Update("last_used", time.Now())
|
|
}
|
|
|
|
// GetNodeModel returns the NodeModel record for a specific node+model combination.
|
|
func (r *NodeRegistry) GetNodeModel(ctx context.Context, nodeID, modelName string) (*NodeModel, error) {
|
|
var nm NodeModel
|
|
err := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).First(&nm).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &nm, nil
|
|
}
|
|
|
|
// FindLeastLoadedNode returns the healthy node with the fewest in-flight requests.
|
|
func (r *NodeRegistry) FindLeastLoadedNode(ctx context.Context) (*BackendNode, error) {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
var node BackendNode
|
|
query := db.Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
|
|
// Order by total in-flight across all models on the node
|
|
subquery := db.Model(&NodeModel{}).
|
|
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
|
Group("node_id")
|
|
|
|
err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
|
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
|
First(&node).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("finding least loaded node: %w", err)
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// FindIdleNode returns a healthy node with zero in-flight requests and zero loaded models.
|
|
// Used by the scheduler to prefer truly idle nodes for new backend assignments.
|
|
func (r *NodeRegistry) FindIdleNode(ctx context.Context) (*BackendNode, error) {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
var node BackendNode
|
|
loadedModels := db.Model(&NodeModel{}).
|
|
Select("node_id").
|
|
Where("state = ?", "loaded").
|
|
Group("node_id")
|
|
err := db.Where("status = ? AND node_type = ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, loadedModels).
|
|
Order("available_vram DESC").
|
|
First(&node).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// IncrementInFlight atomically increments the in-flight counter for a model on a node.
|
|
func (r *NodeRegistry) IncrementInFlight(ctx context.Context, nodeID, modelName string) error {
|
|
result := r.db.WithContext(ctx).Model(&NodeModel{}).
|
|
Where("node_id = ? AND model_name = ?", nodeID, modelName).
|
|
Updates(map[string]any{
|
|
"in_flight": gorm.Expr("in_flight + 1"),
|
|
"last_used": time.Now(),
|
|
})
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("node model %s/%s not found", nodeID, modelName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DecrementInFlight atomically decrements the in-flight counter for a model on a node.
|
|
func (r *NodeRegistry) DecrementInFlight(ctx context.Context, nodeID, modelName string) error {
|
|
result := r.db.WithContext(ctx).Model(&NodeModel{}).
|
|
Where("node_id = ? AND model_name = ? AND in_flight > 0", nodeID, modelName).
|
|
UpdateColumn("in_flight", gorm.Expr("in_flight - 1"))
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
xlog.Warn("DecrementInFlight: no matching row or already zero", "node", nodeID, "model", modelName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetNodeModels returns all models loaded on a given node.
|
|
func (r *NodeRegistry) GetNodeModels(ctx context.Context, nodeID string) ([]NodeModel, error) {
|
|
var models []NodeModel
|
|
if err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Find(&models).Error; err != nil {
|
|
return nil, fmt.Errorf("getting models for node %s: %w", nodeID, err)
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
// ListAllLoadedModels returns all models that are loaded on healthy nodes.
|
|
// Used by DistributedModelStore.Range() to discover models not in local cache.
|
|
func (r *NodeRegistry) ListAllLoadedModels(ctx context.Context) ([]NodeModel, error) {
|
|
var models []NodeModel
|
|
err := r.db.WithContext(ctx).Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
|
Where("node_models.state = ? AND backend_nodes.status = ?", "loaded", StatusHealthy).
|
|
Find(&models).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("listing all loaded models: %w", err)
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
// FindNodeForModel returns the first healthy node that has the given model loaded.
|
|
// Returns the node and true if found, nil and false otherwise.
|
|
func (r *NodeRegistry) FindNodeForModel(ctx context.Context, modelName string) (*BackendNode, bool) {
|
|
nodes, err := r.FindNodesWithModel(ctx, modelName)
|
|
if err != nil || len(nodes) == 0 {
|
|
return nil, false
|
|
}
|
|
return &nodes[0], true
|
|
}
|
|
|
|
// FindLRUModel returns the least-recently-used model on a node.
|
|
func (r *NodeRegistry) FindLRUModel(ctx context.Context, nodeID string) (*NodeModel, error) {
|
|
var nm NodeModel
|
|
err := r.db.WithContext(ctx).Where("node_id = ? AND state = ? AND in_flight = 0", nodeID, "loaded").
|
|
Order("last_used ASC").First(&nm).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("finding LRU model on node %s: %w", nodeID, err)
|
|
}
|
|
return &nm, nil
|
|
}
|
|
|
|
// FindGlobalLRUModelWithZeroInFlight returns the least-recently-used model
|
|
// across all healthy backend nodes that has zero in-flight requests.
|
|
// Used by the router for preemptive eviction when no node has free VRAM.
|
|
func (r *NodeRegistry) FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (*NodeModel, error) {
|
|
var nm NodeModel
|
|
err := r.db.WithContext(ctx).Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
|
Where("node_models.state = ? AND node_models.in_flight = 0 AND backend_nodes.status = ? AND backend_nodes.node_type = ?",
|
|
"loaded", StatusHealthy, NodeTypeBackend).
|
|
Order("node_models.last_used ASC").
|
|
First(&nm).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no evictable model found: %w", err)
|
|
}
|
|
return &nm, nil
|
|
}
|
|
|
|
// --- NodeLabel operations ---
|
|
|
|
// SetNodeLabel upserts a single label on a node.
|
|
func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value string) error {
|
|
label := NodeLabel{
|
|
ID: uuid.New().String(),
|
|
NodeID: nodeID,
|
|
Key: key,
|
|
Value: value,
|
|
}
|
|
return r.db.WithContext(ctx).
|
|
Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "node_id"}, {Name: "key"}},
|
|
DoUpdates: clause.AssignmentColumns([]string{"value"}),
|
|
}).
|
|
Create(&label).Error
|
|
}
|
|
|
|
// SetNodeLabels replaces all labels for a node with the given map.
|
|
func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels map[string]string) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Where("node_id = ?", nodeID).Delete(&NodeLabel{}).Error; err != nil {
|
|
return err
|
|
}
|
|
for k, v := range labels {
|
|
label := NodeLabel{ID: uuid.New().String(), NodeID: nodeID, Key: k, Value: v}
|
|
if err := tx.Create(&label).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// RemoveNodeLabel removes a single label from a node.
|
|
func (r *NodeRegistry) RemoveNodeLabel(ctx context.Context, nodeID, key string) error {
|
|
return r.db.WithContext(ctx).Where("node_id = ? AND key = ?", nodeID, key).Delete(&NodeLabel{}).Error
|
|
}
|
|
|
|
// GetNodeLabels returns all labels for a node.
|
|
func (r *NodeRegistry) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) {
|
|
var labels []NodeLabel
|
|
err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Find(&labels).Error
|
|
return labels, err
|
|
}
|
|
|
|
// GetAllNodeLabelsMap returns all labels grouped by node ID.
|
|
func (r *NodeRegistry) GetAllNodeLabelsMap(ctx context.Context) (map[string]map[string]string, error) {
|
|
var labels []NodeLabel
|
|
if err := r.db.WithContext(ctx).Find(&labels).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
result := make(map[string]map[string]string)
|
|
for _, l := range labels {
|
|
if result[l.NodeID] == nil {
|
|
result[l.NodeID] = make(map[string]string)
|
|
}
|
|
result[l.NodeID][l.Key] = l.Value
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// --- Selector-based queries ---
|
|
|
|
// FindNodesBySelector returns healthy backend nodes matching ALL key-value pairs in the selector.
|
|
func (r *NodeRegistry) FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error) {
|
|
if len(selector) == 0 {
|
|
// Empty selector matches all healthy backend nodes
|
|
var nodes []BackendNode
|
|
err := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend).Find(&nodes).Error
|
|
return nodes, err
|
|
}
|
|
|
|
db := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
|
|
for k, v := range selector {
|
|
db = db.Where("EXISTS (SELECT 1 FROM node_labels WHERE node_labels.node_id = backend_nodes.id AND node_labels.key = ? AND node_labels.value = ?)", k, v)
|
|
}
|
|
|
|
var nodes []BackendNode
|
|
err := db.Find(&nodes).Error
|
|
return nodes, err
|
|
}
|
|
|
|
// FindNodeWithVRAMFromSet is like FindNodeWithVRAM but restricted to the given node IDs.
|
|
func (r *NodeRegistry) FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error) {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
loadedModels := db.Model(&NodeModel{}).
|
|
Select("node_id").
|
|
Where("state = ?", "loaded").
|
|
Group("node_id")
|
|
|
|
subquery := db.Model(&NodeModel{}).
|
|
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
|
Group("node_id")
|
|
|
|
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
|
|
var node BackendNode
|
|
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs).
|
|
Order("available_vram DESC").
|
|
First(&node).Error
|
|
if err == nil {
|
|
return &node, nil
|
|
}
|
|
|
|
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
|
|
err = db.Where("status = ? AND node_type = ? AND available_vram >= ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, minBytes, nodeIDs).
|
|
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
|
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
|
First(&node).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no healthy nodes in set with %d bytes available VRAM: %w", minBytes, err)
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// FindIdleNodeFromSet is like FindIdleNode but restricted to the given node IDs.
|
|
func (r *NodeRegistry) FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
var node BackendNode
|
|
loadedModels := db.Model(&NodeModel{}).
|
|
Select("node_id").
|
|
Where("state = ?", "loaded").
|
|
Group("node_id")
|
|
err := db.Where("status = ? AND node_type = ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, loadedModels, nodeIDs).
|
|
Order("available_vram DESC").
|
|
First(&node).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// FindLeastLoadedNodeFromSet is like FindLeastLoadedNode but restricted to the given node IDs.
|
|
func (r *NodeRegistry) FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
|
|
db := r.db.WithContext(ctx)
|
|
|
|
var node BackendNode
|
|
query := db.Where("status = ? AND node_type = ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, nodeIDs)
|
|
// Order by total in-flight across all models on the node
|
|
subquery := db.Model(&NodeModel{}).
|
|
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
|
Group("node_id")
|
|
|
|
err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
|
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
|
First(&node).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("finding least loaded node in set: %w", err)
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// --- ModelSchedulingConfig operations ---
|
|
|
|
// SetModelScheduling creates or updates a scheduling config for a model.
|
|
func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSchedulingConfig) error {
|
|
if config.ID == "" {
|
|
config.ID = uuid.New().String()
|
|
}
|
|
return r.db.WithContext(ctx).
|
|
Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "model_name"}},
|
|
DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}),
|
|
}).
|
|
Create(config).Error
|
|
}
|
|
|
|
// GetModelScheduling returns the scheduling config for a model, or nil if none exists.
|
|
func (r *NodeRegistry) GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) {
|
|
var config ModelSchedulingConfig
|
|
err := r.db.WithContext(ctx).Where("model_name = ?", modelName).First(&config).Error
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &config, nil
|
|
}
|
|
|
|
// ListModelSchedulings returns all scheduling configs.
|
|
func (r *NodeRegistry) ListModelSchedulings(ctx context.Context) ([]ModelSchedulingConfig, error) {
|
|
var configs []ModelSchedulingConfig
|
|
err := r.db.WithContext(ctx).Order("model_name ASC").Find(&configs).Error
|
|
return configs, err
|
|
}
|
|
|
|
// ListAutoScalingConfigs returns scheduling configs where auto-scaling is enabled.
|
|
func (r *NodeRegistry) ListAutoScalingConfigs(ctx context.Context) ([]ModelSchedulingConfig, error) {
|
|
var configs []ModelSchedulingConfig
|
|
err := r.db.WithContext(ctx).Where("min_replicas > 0 OR max_replicas > 0").Find(&configs).Error
|
|
return configs, err
|
|
}
|
|
|
|
// DeleteModelScheduling removes a scheduling config by model name.
|
|
func (r *NodeRegistry) DeleteModelScheduling(ctx context.Context, modelName string) error {
|
|
return r.db.WithContext(ctx).Where("model_name = ?", modelName).Delete(&ModelSchedulingConfig{}).Error
|
|
}
|
|
|
|
// CountLoadedReplicas returns the number of loaded replicas for a model.
|
|
func (r *NodeRegistry) CountLoadedReplicas(ctx context.Context, modelName string) (int64, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&NodeModel{}).Where("model_name = ? AND state = ?", modelName, "loaded").Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
// --- Composite queries ---
|
|
|
|
// ListWithExtras returns all nodes with model counts and labels.
|
|
func (r *NodeRegistry) ListWithExtras(ctx context.Context) ([]NodeWithExtras, error) {
|
|
// Get all nodes
|
|
var nodes []BackendNode
|
|
if err := r.db.WithContext(ctx).Order("name ASC").Find(&nodes).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get model counts per node
|
|
type modelCount struct {
|
|
NodeID string
|
|
Count int
|
|
}
|
|
var counts []modelCount
|
|
if err := r.db.WithContext(ctx).Model(&NodeModel{}).
|
|
Select("node_id, COUNT(*) as count").
|
|
Where("state = ?", "loaded").
|
|
Group("node_id").
|
|
Find(&counts).Error; err != nil {
|
|
xlog.Warn("ListWithExtras: failed to get model counts", "error", err)
|
|
}
|
|
|
|
countMap := make(map[string]int)
|
|
for _, c := range counts {
|
|
countMap[c.NodeID] = c.Count
|
|
}
|
|
|
|
// Get all labels
|
|
labelsMap, err := r.GetAllNodeLabelsMap(ctx)
|
|
if err != nil {
|
|
xlog.Warn("ListWithExtras: failed to get labels", "error", err)
|
|
}
|
|
|
|
// Build result
|
|
result := make([]NodeWithExtras, len(nodes))
|
|
for i, n := range nodes {
|
|
result[i] = NodeWithExtras{
|
|
BackendNode: n,
|
|
ModelCount: countMap[n.ID],
|
|
Labels: labelsMap[n.ID],
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// ApplyAutoLabels sets automatic labels based on node hardware info.
|
|
func (r *NodeRegistry) ApplyAutoLabels(ctx context.Context, nodeID string, node *BackendNode) {
|
|
if node.GPUVendor != "" {
|
|
_ = r.SetNodeLabel(ctx, nodeID, "gpu.vendor", node.GPUVendor)
|
|
}
|
|
if node.TotalVRAM > 0 {
|
|
gb := node.TotalVRAM / (1024 * 1024 * 1024)
|
|
var bucket string
|
|
switch {
|
|
case gb >= 80:
|
|
bucket = "80GB+"
|
|
case gb >= 48:
|
|
bucket = "48GB"
|
|
case gb >= 24:
|
|
bucket = "24GB"
|
|
case gb >= 16:
|
|
bucket = "16GB"
|
|
case gb >= 8:
|
|
bucket = "8GB"
|
|
default:
|
|
bucket = fmt.Sprintf("%dGB", gb)
|
|
}
|
|
_ = r.SetNodeLabel(ctx, nodeID, "gpu.vram", bucket)
|
|
}
|
|
if node.Name != "" {
|
|
_ = r.SetNodeLabel(ctx, nodeID, "node.name", node.Name)
|
|
}
|
|
}
|