mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-17 13:10:23 -04:00
* feat(concurrency-groups): per-model exclusive groups for backend loading Adds `concurrency_groups: [...]` to model YAML configs. Two models that share a group cannot be loaded concurrently on the same node — loading one evicts the others, reusing the existing pinned/busy/retry policy from LRU eviction. Layered design: - Watchdog (pkg/model): per-node correctness floor — on every Load(), evict any loaded model that shares a group with the requested one. Pinned skips surface NeedMore so the loader retries (and ultimately logs a clear warning), instead of silently allowing the rule to be violated. - Distributed scheduler (core/services/nodes): soft anti-affinity hint — scheduleNewModel prefers nodes that don't already host a same-group model, falling back to eviction only if every candidate has a conflict. Composes with NodeSelector at the same point in the candidate pipeline. Per-node, not cluster-wide: VRAM is a node-local resource, and two heavy models running on different nodes is fine. The ConfigLoader is wired into SmartRouter via a small ConcurrencyConflictResolver interface so the nodes package keeps a narrow surface on core/config. Refactors the inner LRU eviction body into a shared collectEvictionsLocked helper and the loader retry loop into retryEnforce(fn, maxRetries, interval), so both LRU and group enforcement share busy/pinned/retry semantics. Closes #9659. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(watchdog): sync pinned + concurrency_groups at startup The startup-time watchdog setup lives in initializeWatchdog (startup.go), not in startWatchdog (watchdog.go). The latter is only invoked from the runtime-settings RestartWatchdog path. As a result, neither SyncPinnedModelsToWatchdog nor SyncModelGroupsToWatchdog ran at boot, so `pinned: true` and `concurrency_groups: [...]` only became effective after a settings-driven watchdog restart. 
Fix by adding both sync calls to initializeWatchdog. Confirmed end-to-end: loading model A in group "heavy", then C with no group (coexists), then B in group "heavy" now correctly evicts A and leaves [B, C]. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(test): satisfy errcheck on new os.Remove in concurrency_groups spec CI lint runs new-from-merge-base, so the pre-existing `defer os.Remove(tmp.Name())` lines are grandfathered into the baseline, but the one introduced by the concurrency_groups YAML round-trip test is held to errcheck. Wrap the remove in a closure that discards the error. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
317 lines
9.0 KiB
Go
317 lines
9.0 KiB
Go
package config
|
|
|
|
import (
	"io"
	"net/http"
	"os"
	"time"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)
|
|
|
|
var _ = Describe("Test cases for config related functions", func() {
|
|
Context("Test Read configuration functions", func() {
|
|
It("Test Validate", func() {
|
|
tmp, err := os.CreateTemp("", "config.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer os.Remove(tmp.Name())
|
|
_, err = tmp.WriteString(
|
|
`backend: "../foo-bar"
|
|
name: "foo"
|
|
parameters:
|
|
model: "foo-bar"
|
|
known_usecases:
|
|
- chat
|
|
- COMPLETION
|
|
`)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
configs, err := readModelConfigsFromFile(tmp.Name())
|
|
config := configs[0]
|
|
Expect(err).To(BeNil())
|
|
Expect(config).ToNot(BeNil())
|
|
valid, err := config.Validate()
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(valid).To(BeFalse())
|
|
Expect(config.KnownUsecases).ToNot(BeNil())
|
|
})
|
|
It("Test Validate", func() {
|
|
tmp, err := os.CreateTemp("", "config.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer os.Remove(tmp.Name())
|
|
_, err = tmp.WriteString(
|
|
`name: bar-baz
|
|
backend: "foo-bar"
|
|
parameters:
|
|
model: "foo-bar"`)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
configs, err := readModelConfigsFromFile(tmp.Name())
|
|
config := configs[0]
|
|
Expect(err).To(BeNil())
|
|
Expect(config).ToNot(BeNil())
|
|
// two configs in config.yaml
|
|
Expect(config.Name).To(Equal("bar-baz"))
|
|
valid, err := config.Validate()
|
|
Expect(err).To(BeNil())
|
|
Expect(valid).To(BeTrue())
|
|
|
|
// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
|
|
httpClient := http.Client{}
|
|
resp, err := httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer resp.Body.Close()
|
|
tmp, err = os.CreateTemp("", "config.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer os.Remove(tmp.Name())
|
|
_, err = io.Copy(tmp, resp.Body)
|
|
Expect(err).To(BeNil())
|
|
configs, err = readModelConfigsFromFile(tmp.Name())
|
|
config = configs[0]
|
|
Expect(err).To(BeNil())
|
|
Expect(config).ToNot(BeNil())
|
|
// two configs in config.yaml
|
|
Expect(config.Name).To(Equal("hermes-2-pro-mistral"))
|
|
valid, err = config.Validate()
|
|
Expect(err).To(BeNil())
|
|
Expect(valid).To(BeTrue())
|
|
})
|
|
})
|
|
It("Properly handles backend usecase matching", func() {
|
|
|
|
a := ModelConfig{
|
|
Name: "a",
|
|
}
|
|
Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
|
|
|
|
b := ModelConfig{
|
|
Name: "b",
|
|
Backend: "stablediffusion",
|
|
}
|
|
Expect(b.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
|
|
Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
|
|
|
c := ModelConfig{
|
|
Name: "c",
|
|
Backend: "llama-cpp",
|
|
TemplateConfig: TemplateConfig{
|
|
Chat: "chat",
|
|
},
|
|
}
|
|
Expect(c.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(c.HasUsecases(FLAG_IMAGE)).To(BeFalse())
|
|
Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
|
|
Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
|
|
|
d := ModelConfig{
|
|
Name: "d",
|
|
Backend: "llama-cpp",
|
|
TemplateConfig: TemplateConfig{
|
|
Chat: "chat",
|
|
Completion: "completion",
|
|
},
|
|
}
|
|
Expect(d.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(d.HasUsecases(FLAG_IMAGE)).To(BeFalse())
|
|
Expect(d.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
|
Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
|
|
|
trueValue := true
|
|
e := ModelConfig{
|
|
Name: "e",
|
|
Backend: "llama-cpp",
|
|
TemplateConfig: TemplateConfig{
|
|
Completion: "completion",
|
|
},
|
|
Embeddings: &trueValue,
|
|
}
|
|
|
|
Expect(e.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(e.HasUsecases(FLAG_IMAGE)).To(BeFalse())
|
|
Expect(e.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
|
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
|
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
|
|
|
|
f := ModelConfig{
|
|
Name: "f",
|
|
Backend: "piper",
|
|
}
|
|
Expect(f.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
|
|
Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
|
|
|
g := ModelConfig{
|
|
Name: "g",
|
|
Backend: "whisper",
|
|
}
|
|
Expect(g.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
|
|
Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
|
|
|
|
h := ModelConfig{
|
|
Name: "h",
|
|
Backend: "transformers-musicgen",
|
|
}
|
|
Expect(h.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(h.HasUsecases(FLAG_TRANSCRIPT)).To(BeFalse())
|
|
Expect(h.HasUsecases(FLAG_TTS)).To(BeTrue())
|
|
Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
|
|
|
|
knownUsecases := FLAG_CHAT | FLAG_COMPLETION
|
|
i := ModelConfig{
|
|
Name: "i",
|
|
Backend: "whisper",
|
|
// Earlier test checks parsing, this just needs to set final values
|
|
KnownUsecases: &knownUsecases,
|
|
}
|
|
Expect(i.HasUsecases(FLAG_ANY)).To(BeTrue())
|
|
Expect(i.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
|
|
Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
|
|
Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
|
Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
|
})
|
|
It("Test Validate with invalid MCP config", func() {
|
|
tmp, err := os.CreateTemp("", "config.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer os.Remove(tmp.Name())
|
|
_, err = tmp.WriteString(
|
|
`name: test-mcp
|
|
backend: "llama-cpp"
|
|
mcp:
|
|
stdio: |
|
|
{
|
|
"mcpServers": {
|
|
"ddg": {
|
|
"command": "/docker/docker",
|
|
"args": ["run", "-i"]
|
|
}
|
|
"weather": {
|
|
"command": "/docker/docker",
|
|
"args": ["run", "-i"]
|
|
}
|
|
}
|
|
}`)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
configs, err := readModelConfigsFromFile(tmp.Name())
|
|
config := configs[0]
|
|
Expect(err).To(BeNil())
|
|
Expect(config).ToNot(BeNil())
|
|
valid, err := config.Validate()
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(valid).To(BeFalse())
|
|
Expect(err.Error()).To(ContainSubstring("invalid MCP configuration"))
|
|
})
|
|
It("Test Validate with valid MCP config", func() {
|
|
tmp, err := os.CreateTemp("", "config.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer os.Remove(tmp.Name())
|
|
_, err = tmp.WriteString(
|
|
`name: test-mcp-valid
|
|
backend: "llama-cpp"
|
|
mcp:
|
|
stdio: |
|
|
{
|
|
"mcpServers": {
|
|
"ddg": {
|
|
"command": "/docker/docker",
|
|
"args": ["run", "-i"]
|
|
},
|
|
"weather": {
|
|
"command": "/docker/docker",
|
|
"args": ["run", "-i"]
|
|
}
|
|
}
|
|
}`)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
configs, err := readModelConfigsFromFile(tmp.Name())
|
|
config := configs[0]
|
|
Expect(err).To(BeNil())
|
|
Expect(config).ToNot(BeNil())
|
|
valid, err := config.Validate()
|
|
Expect(err).To(BeNil())
|
|
Expect(valid).To(BeTrue())
|
|
})
|
|
It("Test Validate rejects unmarshalable engine_args", func() {
|
|
// chan values cannot be JSON-marshalled. A valid YAML config could
|
|
// not produce one, but a Go caller stuffing a bad value would, and
|
|
// silently dropping it would change runtime behaviour.
|
|
cfg := &ModelConfig{
|
|
Backend: "vllm",
|
|
LLMConfig: LLMConfig{
|
|
EngineArgs: map[string]any{
|
|
"speculative_config": make(chan int),
|
|
},
|
|
},
|
|
}
|
|
valid, err := cfg.Validate()
|
|
Expect(valid).To(BeFalse())
|
|
Expect(err).ToNot(BeNil())
|
|
Expect(err.Error()).To(ContainSubstring("engine_args is not JSON-serialisable"))
|
|
})
|
|
It("Test Validate accepts well-formed engine_args", func() {
|
|
cfg := &ModelConfig{
|
|
Backend: "vllm",
|
|
LLMConfig: LLMConfig{
|
|
EngineArgs: map[string]any{
|
|
"data_parallel_size": 8,
|
|
"speculative_config": map[string]any{
|
|
"method": "ngram",
|
|
"num_speculative_tokens": 4,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
valid, err := cfg.Validate()
|
|
Expect(err).To(BeNil())
|
|
Expect(valid).To(BeTrue())
|
|
})
|
|
Context("ConcurrencyGroups", func() {
|
|
It("returns nil when no groups are configured", func() {
|
|
cfg := &ModelConfig{Name: "no-groups"}
|
|
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
|
|
})
|
|
It("returns nil when all entries are blank", func() {
|
|
cfg := &ModelConfig{
|
|
Name: "blanks",
|
|
ConcurrencyGroups: []string{"", " ", "\t"},
|
|
}
|
|
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
|
|
})
|
|
It("trims whitespace, drops empty entries, and dedupes", func() {
|
|
cfg := &ModelConfig{
|
|
Name: "messy",
|
|
ConcurrencyGroups: []string{" vram-heavy ", "", "vram-heavy", "vision", " vision "},
|
|
}
|
|
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "vision"}))
|
|
})
|
|
It("returns a defensive copy", func() {
|
|
cfg := &ModelConfig{
|
|
Name: "copy",
|
|
ConcurrencyGroups: []string{"heavy"},
|
|
}
|
|
got := cfg.GetConcurrencyGroups()
|
|
got[0] = "tampered"
|
|
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"heavy"}))
|
|
})
|
|
It("parses concurrency_groups from YAML", func() {
|
|
tmp, err := os.CreateTemp("", "concgroups.yaml")
|
|
Expect(err).To(BeNil())
|
|
defer func() { _ = os.Remove(tmp.Name()) }()
|
|
_, err = tmp.WriteString(
|
|
`name: heavy-a
|
|
backend: llama-cpp
|
|
parameters:
|
|
model: heavy-a.gguf
|
|
concurrency_groups:
|
|
- vram-heavy
|
|
- "120b"
|
|
`)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
configs, err := readModelConfigsFromFile(tmp.Name())
|
|
Expect(err).To(BeNil())
|
|
Expect(configs).To(HaveLen(1))
|
|
Expect(configs[0].ConcurrencyGroups).To(Equal([]string{"vram-heavy", "120b"}))
|
|
Expect(configs[0].GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "120b"}))
|
|
})
|
|
})
|
|
})
|