Correctly handle model dir

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-02-04 20:04:51 +01:00
parent e1c6b0c2d5
commit 3b27ed6fba
2 changed files with 14 additions and 10 deletions

View File

@@ -370,6 +370,7 @@ def _generate_audio_sync(servicer, payload, dst_path):
class BackendServicer(backend_pb2_grpc.BackendServicer):
def __init__(self):
self.model_path = None
self.model_dir = None
self.checkpoint_dir = None
self.project_root = None
self.options = {}
@@ -386,20 +387,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model_name = (request.Model or "").strip()
model_file = (getattr(request, "ModelFile", None) or "").strip()
# Model dir: where we store checkpoints (always under LocalAI models path, never backend dir)
if model_path and model_name:
self.checkpoint_dir = model_path
self.project_root = os.path.dirname(model_path)
self.model_path = os.path.join(model_path, model_name)
model_dir = os.path.join(model_path, model_name)
elif model_file:
self.model_path = model_file
self.checkpoint_dir = os.path.dirname(model_file)
self.project_root = os.path.dirname(self.checkpoint_dir)
model_dir = model_file
else:
self.model_path = model_name or "."
self.checkpoint_dir = os.path.dirname(self.model_path) if self.model_path else "."
self.project_root = os.path.dirname(self.checkpoint_dir) if self.checkpoint_dir else "."
model_dir = os.path.abspath(model_name or ".")
self.model_dir = model_dir
self.checkpoint_dir = os.path.join(model_dir, "checkpoints")
self.project_root = model_dir
self.model_path = os.path.join(self.checkpoint_dir, model_name or os.path.basename(model_dir.rstrip("/\\")))
config_path = model_name or os.path.basename(self.model_path.rstrip("/\\"))
config_path = model_name or os.path.basename(model_dir.rstrip("/\\"))
os.makedirs(self.checkpoint_dir, exist_ok=True)
# Auto-download DiT model and VAE if missing (same as upstream)
if config_path:
@@ -413,6 +414,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
print(f"[ace-step] Warning: VAE download failed: {e}", file=sys.stderr)
self.dit_handler = AceStepHandler()
# Patch handler so it uses our model dir instead of site-packages/checkpoints
self.dit_handler._get_project_root = lambda: self.project_root
device = self.options.get("device", "auto")
use_flash = self.options.get("use_flash_attention", True)
if isinstance(use_flash, str):