mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-05 12:12:39 -05:00
Correctly handle model dir
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
1
Makefile
1
Makefile
@@ -473,6 +473,7 @@ BACKEND_QWEN_TTS = qwen-tts|python|.|false|true
|
||||
BACKEND_QWEN_ASR = qwen-asr|python|.|false|true
|
||||
BACKEND_VOXCPM = voxcpm|python|.|false|true
|
||||
BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
|
||||
@@ -370,6 +370,7 @@ def _generate_audio_sync(servicer, payload, dst_path):
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.model_path = None
|
||||
self.model_dir = None
|
||||
self.checkpoint_dir = None
|
||||
self.project_root = None
|
||||
self.options = {}
|
||||
@@ -386,20 +387,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
model_name = (request.Model or "").strip()
|
||||
model_file = (getattr(request, "ModelFile", None) or "").strip()
|
||||
|
||||
# Model dir: where we store checkpoints (always under LocalAI models path, never backend dir)
|
||||
if model_path and model_name:
|
||||
self.checkpoint_dir = model_path
|
||||
self.project_root = os.path.dirname(model_path)
|
||||
self.model_path = os.path.join(model_path, model_name)
|
||||
model_dir = os.path.join(model_path, model_name)
|
||||
elif model_file:
|
||||
self.model_path = model_file
|
||||
self.checkpoint_dir = os.path.dirname(model_file)
|
||||
self.project_root = os.path.dirname(self.checkpoint_dir)
|
||||
model_dir = model_file
|
||||
else:
|
||||
self.model_path = model_name or "."
|
||||
self.checkpoint_dir = os.path.dirname(self.model_path) if self.model_path else "."
|
||||
self.project_root = os.path.dirname(self.checkpoint_dir) if self.checkpoint_dir else "."
|
||||
model_dir = os.path.abspath(model_name or ".")
|
||||
self.model_dir = model_dir
|
||||
self.checkpoint_dir = os.path.join(model_dir, "checkpoints")
|
||||
self.project_root = model_dir
|
||||
self.model_path = os.path.join(self.checkpoint_dir, model_name or os.path.basename(model_dir.rstrip("/\\")))
|
||||
|
||||
config_path = model_name or os.path.basename(self.model_path.rstrip("/\\"))
|
||||
config_path = model_name or os.path.basename(model_dir.rstrip("/\\"))
|
||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Auto-download DiT model and VAE if missing (same as upstream)
|
||||
if config_path:
|
||||
@@ -413,6 +414,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print(f"[ace-step] Warning: VAE download failed: {e}", file=sys.stderr)
|
||||
|
||||
self.dit_handler = AceStepHandler()
|
||||
# Patch handler so it uses our model dir instead of site-packages/checkpoints
|
||||
self.dit_handler._get_project_root = lambda: self.project_root
|
||||
device = self.options.get("device", "auto")
|
||||
use_flash = self.options.get("use_flash_attention", True)
|
||||
if isinstance(use_flash, str):
|
||||
|
||||
Reference in New Issue
Block a user