mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
Compare commits
36 Commits
fix/apt-mi
...
ci/layered
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d42a16c20 | ||
|
|
9c1f8b344c | ||
|
|
a3b7c3a819 | ||
|
|
4e154b59e5 | ||
|
|
969005b2a1 | ||
|
|
6d56bf98fe | ||
|
|
a8d7d37a3c | ||
|
|
06a1524155 | ||
|
|
70cf8ac546 | ||
|
|
7fab5e3d21 | ||
|
|
af83518532 | ||
|
|
a315c321c1 | ||
|
|
75fba9e03f | ||
|
|
16b2d4c807 | ||
|
|
8e43842175 | ||
|
|
503904d311 | ||
|
|
d5ce823b83 | ||
|
|
c9141098b6 | ||
|
|
1caab1de10 | ||
|
|
e86ade54a6 | ||
|
|
1634eece6b | ||
|
|
b88ddce0f3 | ||
|
|
bbcaebc1ef | ||
|
|
22ae415695 | ||
|
|
3a0164670e | ||
|
|
a91b05907c | ||
|
|
4ef45bbccd | ||
|
|
b224a3d931 | ||
|
|
bb033b16a9 | ||
|
|
de83b72bb7 | ||
|
|
1aeb4d7e73 | ||
|
|
a271c72931 | ||
|
|
ade5fd4b97 | ||
|
|
170d55c67d | ||
|
|
28b4857bd6 | ||
|
|
5503be1fb3 |
@@ -5,12 +5,16 @@ Container builds — both the root LocalAI image (`Dockerfile`) and the per-back
|
||||
## Cache layout
|
||||
|
||||
- **Cache registry**: `quay.io/go-skynet/ci-cache`
|
||||
- **One tag per matrix entry**, derived from the existing `tag-suffix`:
|
||||
- Backend builds (`backend_build.yml`): `cache<tag-suffix>`
|
||||
- **Tag prefixes**:
|
||||
- Backend builds (`backend_build.yml`) buildkit cache: `cache<tag-suffix>`
|
||||
- e.g. `cache-gpu-nvidia-cuda-12-llama-cpp`, `cache-cpu-vllm`, `cache-nvidia-l4t-cuda-13-arm64-vllm`
|
||||
- Root image builds (`image_build.yml`): `cache-localai<tag-suffix>`
|
||||
- Root image builds (`image_build.yml`) buildkit cache: `cache-localai<tag-suffix>`
|
||||
- e.g. `cache-localai-gpu-nvidia-cuda-12`, `cache-localai-gpu-vulkan`
|
||||
- Each tag stores a multi-arch BuildKit cache manifest (`mode=max`), so every intermediate stage is re-usable, not just the final image.
|
||||
- Layered base builds (`base_images.yml`) buildkit cache: `base-<stem>`
|
||||
- e.g. `base-python-cpu-2404`, `base-cpp-cublas-2404-cuda13.0`
|
||||
- Layered base **images** (the OCI manifests consumers FROM): `base-image-<stem>[-pr<N>]`
|
||||
- e.g. `base-image-python-cpu-2404`, `base-image-cpp-cublas-2404-cuda13.0-pr9672`
|
||||
- The cache tags store multi-arch BuildKit cache manifests (`mode=max`); the `base-image-*` tags store ordinary OCI image manifests.
|
||||
|
||||
## Read/write semantics
|
||||
|
||||
@@ -101,6 +105,170 @@ For ccache, the workflow exports `CMAKE_ARGS=… -DCMAKE_C_COMPILER_LAUNCHER=cca
|
||||
|
||||
GitHub Actions caches are limited to 10 GB per repo. Steady-state worst case: ~800 MB Go cache + ~2 GB brew Cellar + up to 2 GB ccache + ~1.5 GB × 5 python backends. If the cap is hit, prefer collapsing the per-backend Python keys into a shared `pyenv-darwin-shared-<week>` key (accepts more cross-backend churn for a smaller footprint) before reducing other caches.
|
||||
|
||||
## Layered base images (`ci-cache:base-image-*`)
|
||||
|
||||
The registry-backed BuildKit cache deduplicates **within** a matrix entry's
|
||||
cache tag, but each matrix entry has its own tag — so the same `apt-get`,
|
||||
GPU SDK install, and language toolchain bootstrap runs into N different
|
||||
cache tags across the backend matrix. The layered base images factor that
|
||||
shared work out of the per-backend builds.
|
||||
|
||||
They live in the same `quay.io/go-skynet/ci-cache` repo as the buildkit
|
||||
caches, under a distinct `base-image-` tag prefix so the OCI image
|
||||
manifests coexist with `base-<stem>` (the cache for building the base),
|
||||
`cache<tag-suffix>` (per-backend caches), and `cache-localai<tag-suffix>`
|
||||
(root image caches). Reusing `ci-cache` means no new quay repo or robot
|
||||
grant is needed — the same credentials that write the cache also write
|
||||
the image.
|
||||
|
||||
### How it fits together
|
||||
|
||||
```
|
||||
.github/backend-matrix.yaml # raw matrix data (linux + darwin)
|
||||
│
|
||||
▼
|
||||
backend.yml / backend_pr.yml
|
||||
├── derive-bases / generate-matrix
|
||||
│ scripts/changed-backends.js
|
||||
│ reads .github/backend-matrix.yaml
|
||||
│ (PR mode also reads changed files)
|
||||
│ emits:
|
||||
│ - matrix (annotated with base-image-prebuilt)
|
||||
│ - matrix-darwin
|
||||
│ - bases-matrix (deduplicated by tag-stem)
|
||||
│
|
||||
├── build-bases (matrix: bases-matrix)
|
||||
│ uses base_images.yml
|
||||
│ FROM .docker/bases/Dockerfile.<lang>
|
||||
│ pushes quay.io/go-skynet/ci-cache:base-image-<stem>[-pr<N>]
|
||||
│
|
||||
└── backend-jobs (matrix: matrix; needs build-bases)
|
||||
uses backend_build.yml
|
||||
FROM ${BASE_IMAGE_PREBUILT}
|
||||
i.e. quay.io/go-skynet/ci-cache:base-image-<stem>[-pr<N>]
|
||||
only the backend source COPY + `make` remain.
|
||||
```
|
||||
|
||||
The base image is **always** built before backends consume it, in the same
|
||||
workflow run. There is no cross-workflow dependency, no chicken-and-egg
|
||||
on first push, and no manual matrix to keep in sync — adding a backend
|
||||
matrix entry is just an edit to `.github/backend-matrix.yaml`.
|
||||
|
||||
### Tag scheme
|
||||
|
||||
`<stem>` is computed by `tagStem()` in `scripts/changed-backends.js` from
|
||||
the (lang, build-type, ubuntu, cuda, base-image) tuple. Arch is
|
||||
intentionally NOT in the stem — bases are built multi-arch when any
|
||||
consumer needs multi-arch, and single-arch otherwise (the `platforms`
|
||||
field on each base entry is the union of its consumers' platforms).
|
||||
|
||||
| Build-type | Stem template |
|
||||
|---|---|
|
||||
| `''` (CPU) | `<lang>-cpu-<ubuntu>[-<base-image-slug>]` |
|
||||
| `cublas` / `l4t` | `<lang>-<build-type>-<ubuntu>-cuda<major>.<minor>[-<base-image-slug>]` |
|
||||
| anything else (vulkan, hipblas, intel, sycl_*) | `<lang>-<build-type>-<ubuntu>[-<base-image-slug>]` |
|
||||
|
||||
The base-image slug is empty for the default `ubuntu:24.04` and a short
|
||||
parseable suffix otherwise (`jetpack-r36.4.0`, `rocm-7.2.1`,
|
||||
`oneapi-2025.3.2`, etc.).
|
||||
|
||||
| Event | Pushed tag (in `quay.io/go-skynet/ci-cache`) |
|
||||
|---|---|
|
||||
| `push` (master/tag) | `:base-image-<stem>` |
|
||||
| `pull_request` | `:base-image-<stem>-pr<PR_NUMBER>` |
|
||||
|
||||
The buildkit cache for the base build itself lives at
|
||||
`quay.io/go-skynet/ci-cache:base-<stem>` (`mode=max,ignore-error=true`),
|
||||
parallel to the per-matrix-entry caches. The `base-` (cache) and
|
||||
`base-image-` (image) prefixes never collide.
|
||||
|
||||
The script also runs a collision check across consumers of each stem: if
|
||||
two consumers map to the same stem but disagree on `base-image` or
|
||||
`skip-drivers` (and skip-drivers is meaningful for that build-type), the
|
||||
script fails loudly. Resolve by encoding the differing input in
|
||||
`tagStem()` rather than letting the dedup silently pick a winner.
|
||||
|
||||
### PR testability
|
||||
|
||||
PRs run the same pipeline as master: derive bases → build bases (tagged
|
||||
`-pr<N>`) → run filtered backend matrix consuming those `-pr<N>` tags.
|
||||
End-to-end validation always lives within the PR.
|
||||
|
||||
For PRs that only change `.docker/bases/Dockerfile.<lang>` (no backend
|
||||
source touched), `changed-backends.js` adds one canary backend matrix
|
||||
entry per (lang × build-type × arch × cuda × ubuntu) tuple to the filtered
|
||||
matrix so each base flavour gets exercised.
|
||||
|
||||
### Existing language tiers
|
||||
|
||||
| Tier (lang) | Recipe | Consumer Dockerfile(s) | Distinct stems |
|
||||
|---|---|---|---|
|
||||
| `python` | `.docker/bases/Dockerfile.python` | `backend/Dockerfile.python` | 9 |
|
||||
| `golang` | `.docker/bases/Dockerfile.golang` | `backend/Dockerfile.golang` | 8 |
|
||||
| `cpp` | `.docker/bases/Dockerfile.cpp` (apt + GPU + protoc + cmake + GRPC) | `backend/Dockerfile.{llama-cpp,ik-llama-cpp,turboquant}` | 8 |
|
||||
| `rust` | `.docker/bases/Dockerfile.rust` | `backend/Dockerfile.rust` | 1 |
|
||||
|
||||
The C++ trio share a single `cpp` base because they only differ in their
|
||||
per-backend `make` targets. `langOf()` in `scripts/changed-backends.js`
|
||||
remaps `Dockerfile.{llama-cpp,ik-llama-cpp,turboquant}` → `cpp` so dedup
|
||||
works across the trio. If a future C++ consumer needs a *different* base
|
||||
(e.g. without GRPC, or with a different protoc version), give it its own
|
||||
`Dockerfile.<newlang>` recipe and remove it from the cpp remap.
|
||||
|
||||
### Adding a new (accel × arch × cuda × lang) flavour
|
||||
|
||||
Just add the matrix entry to `.github/backend-matrix.yaml` for the new
|
||||
flavour. The bases matrix and the per-entry `base-image-prebuilt` are
|
||||
derived automatically by `scripts/changed-backends.js`. Nothing else to
|
||||
change.
|
||||
|
||||
### Adding a new language tier
|
||||
|
||||
1. Create `.docker/bases/Dockerfile.<lang>` mirroring an existing tier
|
||||
(apt + accel install + lang-specific toolchain).
|
||||
2. Slim `backend/Dockerfile.<lang>` to `FROM ${BASE_IMAGE_PREBUILT}` plus
|
||||
the per-backend source COPY + build (no inline accel install).
|
||||
3. Add the new recipe to `baseTriggerFiles` in
|
||||
`scripts/changed-backends.js` so PRs touching it fan out to canaries.
|
||||
4. Add `<lang>: (item) => item.dockerfile.endsWith("<lang>")` to
|
||||
`langTriggerSelector` in the same file.
|
||||
5. Add a `LOCAL_BASE_<LANG>_TAG`, a `docker-build-<lang>-base` target,
|
||||
and a clause in `local-base-tag` / `local-base-target` in `Makefile`.
|
||||
|
||||
The `langsWithBase` set in `scripts/changed-backends.js` is auto-detected
|
||||
from the `.docker/bases/` directory at script startup, so step 1 alone is
|
||||
enough for the script to start emitting bases (and annotating matrix
|
||||
entries with `base-image-prebuilt`) for that lang. Steps 3–5 plug it
|
||||
into the canary fan-out and the local-build path.
|
||||
|
||||
### Why not just rely on `mode=max` cache?
|
||||
|
||||
`mode=max` deduplicates at the layer level, but each matrix entry has its
|
||||
own cache tag (`cache<tag-suffix>`). A change that invalidates the GPU SDK
|
||||
layer in one backend does not invalidate it in any other; each entry pays
|
||||
the full cost on its next rebuild. The shared base image is built once per
|
||||
(accel × arch × cuda × lang), then pulled by every backend that consumes
|
||||
it — that's the actual cross-matrix dedup.
|
||||
|
||||
### Local builds
|
||||
|
||||
All `backend/Dockerfile.{python,golang,cpp,rust}` consumers require
|
||||
`BASE_IMAGE_PREBUILT` (no inline fallback). The Makefile wires the right
|
||||
`docker-build-<lang>-base` as a prerequisite for each backend's
|
||||
`docker-build-<backend>` target, so:
|
||||
|
||||
```bash
|
||||
# Build any backend; the matching base is built first if needed.
|
||||
make docker-build-vllm BUILD_TYPE=cublas CUDA_MAJOR_VERSION=12 CUDA_MINOR_VERSION=8
|
||||
make docker-build-llama-cpp BUILD_TYPE=cublas CUDA_MAJOR_VERSION=13 CUDA_MINOR_VERSION=0
|
||||
make docker-build-rerankers # golang
|
||||
make docker-build-kokoros # rust
|
||||
```
|
||||
|
||||
Or build a base directly: `make docker-build-{python,golang,cpp,rust}-base
|
||||
BUILD_TYPE=...`. Or pull a pre-built one from quay if it exists for your
|
||||
target tuple.
|
||||
|
||||
## Touching the cache pipeline
|
||||
|
||||
When changing `image_build.yml`, `backend_build.yml`, or any of the `backend/Dockerfile.*` files:
|
||||
@@ -109,3 +277,4 @@ When changing `image_build.yml`, `backend_build.yml`, or any of the `backend/Doc
|
||||
2. **Keep `tag-suffix` unique per matrix entry** — it's the cache namespace. Two matrix entries sharing a tag-suffix would clobber each other's cache.
|
||||
3. **Keep `cache-to` gated on `github.event_name != 'pull_request'`** — PRs must not write.
|
||||
4. **Keep `ignore-error=true` on `cache-to`** — quay registry hiccups must not fail builds.
|
||||
5. **`tagStem()` in `scripts/changed-backends.js` is the single source of truth for base image tags.** The matrix entries are annotated with `base-image-prebuilt` in the same script run; backend-jobs reads the value as-is. There's no parallel YAML expression to keep in sync. Adding a new dimension to the stem (e.g. a slug for a new base-image variant) is a script change only.
|
||||
|
||||
259
.docker/bases/Dockerfile.cpp
Normal file
259
.docker/bases/Dockerfile.cpp
Normal file
@@ -0,0 +1,259 @@
|
||||
# Shared C++ + accelerator base image for the llama-cpp / ik-llama-cpp /
|
||||
# turboquant trio. They differ only in their Makefile targets at build
|
||||
# time; the apt + GPU SDK + protoc + cmake + GRPC install is identical.
|
||||
#
|
||||
# Built once per (build-type, arch, ubuntu-version, cuda-version) combination
|
||||
# by .github/workflows/base_images.yml and pushed to
|
||||
# quay.io/go-skynet/ci-cache:base-image-<tag-stem>[-pr<N>]. Consumed by
|
||||
# backend/Dockerfile.{llama-cpp,ik-llama-cpp,turboquant} via the
|
||||
# BASE_IMAGE_PREBUILT build-arg. See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
|
||||
FROM ${BASE_IMAGE} AS grpc
|
||||
|
||||
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||
ARG GRPC_VERSION=v1.65.0
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
# CUDA Toolkit 13.x compatibility: CMake 3.31.9+ fixes toolchain detection/arch table issues
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
build-essential curl libssl-dev \
|
||||
git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Build GRPC into /opt/grpc so we can copy it into the final base without
|
||||
# pulling in the full source tree. Mirrors the original two-stage layout in
|
||||
# Dockerfile.llama-cpp; absorbing it here means consumers no longer pay the
|
||||
# GRPC compile cost.
|
||||
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
|
||||
mkdir -p /build/grpc/cmake/build && \
|
||||
cd /build/grpc/cmake/build && \
|
||||
sed -i "216i\ TESTONLY" "../../third_party/abseil-cpp/absl/container/CMakeLists.txt" && \
|
||||
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
|
||||
make && \
|
||||
make install && \
|
||||
rm -rf /build
|
||||
|
||||
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/mudler/LocalAI"
|
||||
LABEL org.opencontainers.image.description="LocalAI C++ (llama-cpp/ik-llama-cpp/turboquant) base image"
|
||||
LABEL org.localai.base.lang="cpp"
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache git \
|
||||
ca-certificates \
|
||||
make \
|
||||
pkg-config libcurl4-openssl-dev \
|
||||
curl unzip \
|
||||
libssl-dev wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
ldconfig && \
|
||||
echo "rocBLAS library data architectures:" && \
|
||||
(ls /opt/rocm*/lib/rocblas/library/Kernels* 2>/dev/null || ls /opt/rocm*/lib64/rocblas/library/Kernels* 2>/dev/null) | grep -oP 'gfx[0-9a-z+-]+' | sort -u || \
|
||||
echo "WARNING: No rocBLAS kernel data found" \
|
||||
; fi
|
||||
|
||||
# Install protoc (the version in 22.04 is too old, and grpc's bundled protoc
|
||||
# would pull in a newer absl that breaks stablediffusion).
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
COPY --from=grpc /opt/grpc /usr/local
|
||||
206
.docker/bases/Dockerfile.golang
Normal file
206
.docker/bases/Dockerfile.golang
Normal file
@@ -0,0 +1,206 @@
|
||||
# Shared Go + accelerator base image.
|
||||
#
|
||||
# Built once per (build-type, arch, ubuntu-version, cuda-version) combination
|
||||
# by .github/workflows/base_images.yml and pushed to
|
||||
# quay.io/go-skynet/ci-cache:base-image-<tag-stem>[-pr<N>]. Consumed by
|
||||
# backend/Dockerfile.golang via the BASE_IMAGE_PREBUILT build-arg.
|
||||
#
|
||||
# Mirrors the GPU stack stanzas in Dockerfile.python; the language-specific
|
||||
# tail at the bottom installs Go + grpc tooling. See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/mudler/LocalAI"
|
||||
LABEL org.opencontainers.image.description="LocalAI Go+accelerator base image"
|
||||
LABEL org.localai.base.lang="golang"
|
||||
|
||||
# gcc-14 is the default on noble (ubuntu:24.04) but absent from jammy
|
||||
# (the L4T jetpack r36.4.0 base). LocalVQE needs it; the other Go backends
|
||||
# compile with the default gcc shipped via build-essential. Try gcc-14
|
||||
# from the configured repos and fall back gracefully when it's missing.
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git ccache \
|
||||
ca-certificates \
|
||||
make cmake wget libopenblas-dev \
|
||||
curl unzip \
|
||||
libssl-dev && \
|
||||
if apt-cache show gcc-14 >/dev/null 2>&1 && apt-cache show g++-14 >/dev/null 2>&1; then \
|
||||
apt-get install -y --no-install-recommends gcc-14 g++-14 && \
|
||||
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 \
|
||||
--slave /usr/bin/g++ g++ /usr/bin/g++-14 \
|
||||
--slave /usr/bin/gcov gcov /usr/bin/gcov-14; \
|
||||
fi && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
# Install Go
|
||||
RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
|
||||
ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin:/usr/local/bin
|
||||
|
||||
# Install grpc compilers
|
||||
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 && \
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
|
||||
# Install protoc (the version in 22.04 is too old, and grpc's bundled protoc
|
||||
# would pull in a newer absl that breaks stablediffusion).
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
209
.docker/bases/Dockerfile.python
Normal file
209
.docker/bases/Dockerfile.python
Normal file
@@ -0,0 +1,209 @@
|
||||
# Shared Python + accelerator base image.
|
||||
#
|
||||
# Built once per (build-type, arch, ubuntu-version, cuda-version) combination
|
||||
# by .github/workflows/base_images.yml and pushed to
|
||||
# quay.io/go-skynet/ci-cache:base-image-<tag-stem>[-pr<N>]. Consumed by
|
||||
# backend/Dockerfile.python via the BASE_IMAGE_PREBUILT build-arg.
|
||||
# See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/mudler/LocalAI"
|
||||
LABEL org.opencontainers.image.description="LocalAI Python+accelerator base image"
|
||||
LABEL org.localai.base.lang="python"
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache \
|
||||
ca-certificates \
|
||||
espeak-ng \
|
||||
curl \
|
||||
libssl-dev \
|
||||
git wget \
|
||||
git-lfs \
|
||||
unzip clang \
|
||||
upx-ucl \
|
||||
curl python3-pip \
|
||||
python-is-python3 \
|
||||
python3-dev llvm \
|
||||
libnuma1 libgomp1 \
|
||||
python3-venv make cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN <<EOT bash
|
||||
if [ "${UBUNTU_VERSION}" = "2404" ]; then
|
||||
pip install --break-system-packages --user --upgrade pip
|
||||
else
|
||||
pip install --upgrade pip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
|
||||
ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
|
||||
; fi
|
||||
|
||||
# Install uv as a system package
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | UV_INSTALL_DIR=/usr/bin sh
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
# Increase timeout for uv installs behind slow networks
|
||||
ENV UV_HTTP_TIMEOUT=180
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
|
||||
# Install grpcio-tools (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${UBUNTU_VERSION}" = "2404" ]; then
|
||||
pip install --break-system-packages --user grpcio-tools==1.71.0 grpcio==1.71.0
|
||||
else
|
||||
pip install grpcio-tools==1.71.0 grpcio==1.71.0
|
||||
fi
|
||||
EOT
|
||||
47
.docker/bases/Dockerfile.rust
Normal file
47
.docker/bases/Dockerfile.rust
Normal file
@@ -0,0 +1,47 @@
|
||||
# Shared Rust base image for the kokoros backend.
|
||||
#
|
||||
# Built once per (ubuntu-version) by .github/workflows/base_images.yml and
|
||||
# pushed to quay.io/go-skynet/ci-cache:base-image-<tag-stem>[-pr<N>]. The
|
||||
# current rust matrix is CPU-only, so this base skips the GPU SDK stanzas;
|
||||
# if a future rust backend needs cublas/rocm/etc., promote this recipe to
|
||||
# mirror Dockerfile.python's GPU stack. See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/mudler/LocalAI"
|
||||
LABEL org.opencontainers.image.description="LocalAI Rust base image"
|
||||
LABEL org.localai.base.lang="rust"
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git ccache \
|
||||
ca-certificates \
|
||||
make cmake wget \
|
||||
curl unzip \
|
||||
clang \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
espeak-ng libespeak-ng-dev \
|
||||
libsonic-dev libpcaudio-dev \
|
||||
libopus-dev \
|
||||
protobuf-compiler && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
13
.github/actions/configure-apt-mirror/action.yml
vendored
13
.github/actions/configure-apt-mirror/action.yml
vendored
@@ -28,11 +28,20 @@ inputs:
|
||||
self-hosted-mirror:
|
||||
description: 'archive/security mirror URL for self-hosted runners (empty = upstream)'
|
||||
required: false
|
||||
default: 'https://mirrors.edge.kernel.org'
|
||||
# HTTP, not HTTPS: the bare ubuntu:24.04 builder image doesn't ship
|
||||
# ca-certificates, so the very first apt-get update over TLS would
|
||||
# fail with "No system certificates available" before it can install
|
||||
# anything. apt validates package integrity via GPG signatures, so
|
||||
# plain HTTP is safe for the archive itself.
|
||||
default: 'http://mirrors.edge.kernel.org'
|
||||
self-hosted-ports-mirror:
|
||||
description: 'ports.ubuntu.com mirror URL for self-hosted runners (empty = upstream)'
|
||||
required: false
|
||||
default: 'https://mirrors.edge.kernel.org'
|
||||
# mirrors.edge.kernel.org does NOT carry /ubuntu-ports/ — only the
|
||||
# main /ubuntu/ archive — so arm64 builds 404 there. Leave ports
|
||||
# upstream by default. The original DDoS was on archive.ubuntu.com
|
||||
# so ports.ubuntu.com remains the path of least surprise.
|
||||
default: ''
|
||||
|
||||
outputs:
|
||||
effective-mirror:
|
||||
|
||||
3164
.github/backend-matrix.yaml
vendored
Normal file
3164
.github/backend-matrix.yaml
vendored
Normal file
File diff suppressed because it is too large
Load Diff
3237
.github/workflows/backend.yml
vendored
3237
.github/workflows/backend.yml
vendored
File diff suppressed because it is too large
Load Diff
12
.github/workflows/backend_build.yml
vendored
12
.github/workflows/backend_build.yml
vendored
@@ -63,6 +63,16 @@ on:
|
||||
required: false
|
||||
default: ''
|
||||
type: string
|
||||
base-image-prebuilt:
|
||||
description: |
|
||||
Optional reference to a prebuilt accel/lang base image
|
||||
(quay.io/go-skynet/ci-cache:base-image-<stem>[-pr<N>]). When
|
||||
set, the backend Dockerfile FROMs this image instead of running
|
||||
an inline bootstrap. See .github/workflows/base_images.yml and
|
||||
.agents/ci-caching.md.
|
||||
required: false
|
||||
default: ''
|
||||
type: string
|
||||
secrets:
|
||||
dockerUsername:
|
||||
required: false
|
||||
@@ -228,6 +238,7 @@ jobs:
|
||||
APT_MIRROR=${{ steps.apt_mirror.outputs.effective-mirror }}
|
||||
APT_PORTS_MIRROR=${{ steps.apt_mirror.outputs.effective-ports-mirror }}
|
||||
DEPS_REFRESH=${{ steps.deps_refresh.outputs.key }}
|
||||
BASE_IMAGE_PREBUILT=${{ inputs.base-image-prebuilt }}
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.dockerfile }}
|
||||
cache-from: type=registry,ref=quay.io/go-skynet/ci-cache:cache${{ inputs.tag-suffix }}
|
||||
@@ -254,6 +265,7 @@ jobs:
|
||||
APT_MIRROR=${{ steps.apt_mirror.outputs.effective-mirror }}
|
||||
APT_PORTS_MIRROR=${{ steps.apt_mirror.outputs.effective-ports-mirror }}
|
||||
DEPS_REFRESH=${{ steps.deps_refresh.outputs.key }}
|
||||
BASE_IMAGE_PREBUILT=${{ inputs.base-image-prebuilt }}
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.dockerfile }}
|
||||
cache-from: type=registry,ref=quay.io/go-skynet/ci-cache:cache${{ inputs.tag-suffix }}
|
||||
|
||||
38
.github/workflows/backend_pr.yml
vendored
38
.github/workflows/backend_pr.yml
vendored
@@ -13,8 +13,10 @@ jobs:
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
matrix-darwin: ${{ steps.set-matrix.outputs.matrix-darwin }}
|
||||
bases-matrix: ${{ steps.set-matrix.outputs.bases-matrix }}
|
||||
has-backends: ${{ steps.set-matrix.outputs.has-backends }}
|
||||
has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }}
|
||||
has-bases: ${{ steps.set-matrix.outputs.has-bases }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -27,7 +29,8 @@ jobs:
|
||||
bun add js-yaml
|
||||
bun add @octokit/core
|
||||
|
||||
# filters the matrix in backend.yml
|
||||
# Filters the matrix from backend.yml against this PR's changed files
|
||||
# AND derives the deduplicated bases-matrix consumed by build-bases.
|
||||
- name: Filter matrix for changed backends
|
||||
id: set-matrix
|
||||
env:
|
||||
@@ -35,10 +38,34 @@ jobs:
|
||||
GITHUB_EVENT_PATH: ${{ github.event_path }}
|
||||
run: bun run scripts/changed-backends.js
|
||||
|
||||
backend-jobs:
|
||||
build-bases:
|
||||
needs: generate-matrix
|
||||
if: needs.generate-matrix.outputs.has-bases == 'true'
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix: ${{ fromJSON(needs.generate-matrix.outputs.bases-matrix) }}
|
||||
uses: ./.github/workflows/base_images.yml
|
||||
with:
|
||||
lang: ${{ matrix.lang }}
|
||||
base-image: ${{ matrix.base-image }}
|
||||
build-type: ${{ matrix.build-type }}
|
||||
cuda-major-version: ${{ matrix.cuda-major-version }}
|
||||
cuda-minor-version: ${{ matrix.cuda-minor-version }}
|
||||
ubuntu-version: ${{ matrix.ubuntu-version }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
runs-on: ${{ matrix.runs-on }}
|
||||
tag-stem: ${{ matrix.tag-stem }}
|
||||
skip-drivers: ${{ matrix.skip-drivers }}
|
||||
secrets:
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
|
||||
backend-jobs:
|
||||
needs: [generate-matrix, build-bases]
|
||||
uses: ./.github/workflows/backend_build.yml
|
||||
if: needs.generate-matrix.outputs.has-backends == 'true'
|
||||
if: |
|
||||
always() && needs.generate-matrix.outputs.has-backends == 'true' &&
|
||||
(needs.build-bases.result == 'success' || needs.build-bases.result == 'skipped')
|
||||
with:
|
||||
tag-latest: ${{ matrix.tag-latest }}
|
||||
tag-suffix: ${{ matrix.tag-suffix }}
|
||||
@@ -54,12 +81,17 @@ jobs:
|
||||
context: ${{ matrix.context }}
|
||||
ubuntu-version: ${{ matrix.ubuntu-version }}
|
||||
amdgpu-targets: ${{ matrix.amdgpu-targets || 'gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201' }}
|
||||
# The script annotates each filtered Python entry with the prebuilt
|
||||
# base ref it should consume; non-Python entries get '' and run their
|
||||
# own inline bootstrap.
|
||||
base-image-prebuilt: ${{ matrix.base-image-prebuilt || '' }}
|
||||
secrets:
|
||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}
|
||||
|
||||
backend-jobs-darwin:
|
||||
needs: generate-matrix
|
||||
uses: ./.github/workflows/backend_build_darwin.yml
|
||||
|
||||
152
.github/workflows/base_images.yml
vendored
Normal file
152
.github/workflows/base_images.yml
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
---
|
||||
name: 'build base image (reusable)'
|
||||
|
||||
# Builds and pushes one (lang, accel, arch, ubuntu, cuda) base image flavour
|
||||
# to quay.io/go-skynet/ci-cache:base-image-<stem>[-pr<N>]. Consumed by
|
||||
# backend builds via the BASE_IMAGE_PREBUILT build-arg. PR builds tag with
|
||||
# `-pr${PR_NUMBER}` so the same PR's backend matrix can opt-in to the
|
||||
# freshly-built base; master builds overwrite the unsuffixed tag for
|
||||
# downstream consumption. The image lives in the same ci-cache repo as the
|
||||
# buildkit cache (under a `base-image-` prefix that doesn't collide with
|
||||
# the `base-<stem>` cache prefix), so no separate quay repo + grant is
|
||||
# needed. See .agents/ci-caching.md for the full tagging scheme.
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
lang:
|
||||
description: 'Language toolchain (matches .docker/bases/Dockerfile.<lang>)'
|
||||
required: true
|
||||
type: string
|
||||
base-image:
|
||||
description: 'Upstream base image (ubuntu:24.04, rocm/dev-ubuntu-24.04:..., etc.)'
|
||||
required: true
|
||||
type: string
|
||||
build-type:
|
||||
description: 'BUILD_TYPE: empty for CPU, cublas, hipblas, vulkan, l4t, ...'
|
||||
default: ''
|
||||
type: string
|
||||
cuda-major-version:
|
||||
description: 'CUDA major version (only meaningful for cublas/l4t)'
|
||||
default: '12'
|
||||
type: string
|
||||
cuda-minor-version:
|
||||
description: 'CUDA minor version'
|
||||
default: '9'
|
||||
type: string
|
||||
ubuntu-version:
|
||||
description: 'Ubuntu version code (2204, 2404)'
|
||||
default: '2404'
|
||||
type: string
|
||||
platforms:
|
||||
description: 'Single platform per call (linux/amd64 or linux/arm64)'
|
||||
required: true
|
||||
type: string
|
||||
runs-on:
|
||||
description: 'Runner label'
|
||||
required: true
|
||||
type: string
|
||||
tag-stem:
|
||||
description: 'Stable portion of the image tag (e.g. python-cpu-amd64-2404)'
|
||||
required: true
|
||||
type: string
|
||||
skip-drivers:
|
||||
description: 'Pass-through to the base Dockerfile'
|
||||
default: 'false'
|
||||
type: string
|
||||
secrets:
|
||||
quayUsername:
|
||||
required: false
|
||||
quayPassword:
|
||||
required: false
|
||||
outputs:
|
||||
image-ref:
|
||||
description: 'Full image reference of the built base'
|
||||
value: ${{ jobs.base-build.outputs.image-ref }}
|
||||
|
||||
jobs:
|
||||
base-build:
|
||||
runs-on: ${{ inputs.runs-on }}
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
outputs:
|
||||
image-ref: ${{ steps.compute_ref.outputs.ref }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Configure apt mirror on runner
|
||||
id: apt_mirror
|
||||
uses: ./.github/actions/configure-apt-mirror
|
||||
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
if: inputs.runs-on == 'ubuntu-latest'
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: true
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
|
||||
- name: Compute image ref
|
||||
id: compute_ref
|
||||
run: |
|
||||
stem='${{ inputs.tag-stem }}'
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
tag="${stem}-pr${{ github.event.number }}"
|
||||
else
|
||||
tag="${stem}"
|
||||
fi
|
||||
echo "tag=${tag}" >> "$GITHUB_OUTPUT"
|
||||
# Published into the existing ci-cache repo (the CI robot already
|
||||
# has write access there) under a distinct `base-image-` prefix so
|
||||
# the OCI image tags coexist with the buildkit cache tags
|
||||
# (`base-<stem>`, `cache<tag-suffix>`, `cache-localai<tag-suffix>`).
|
||||
echo "ref=quay.io/go-skynet/ci-cache:base-image-${tag}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@master
|
||||
with:
|
||||
platforms: all
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@master
|
||||
|
||||
- name: Login to Quay.io
|
||||
if: ${{ env.quay_username != '' }}
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: quay.io
|
||||
username: ${{ secrets.quayUsername }}
|
||||
password: ${{ secrets.quayPassword }}
|
||||
|
||||
- name: Build and push base image
|
||||
uses: docker/build-push-action@v7
|
||||
with:
|
||||
builder: ${{ steps.buildx.outputs.name }}
|
||||
context: .
|
||||
file: ./.docker/bases/Dockerfile.${{ inputs.lang }}
|
||||
build-args: |
|
||||
BUILD_TYPE=${{ inputs.build-type }}
|
||||
CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
|
||||
CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
|
||||
BASE_IMAGE=${{ inputs.base-image }}
|
||||
UBUNTU_VERSION=${{ inputs.ubuntu-version }}
|
||||
SKIP_DRIVERS=${{ inputs.skip-drivers }}
|
||||
APT_MIRROR=${{ steps.apt_mirror.outputs.effective-mirror }}
|
||||
APT_PORTS_MIRROR=${{ steps.apt_mirror.outputs.effective-ports-mirror }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
# Push on PRs as well (if creds present) so the PR's backend matrix
|
||||
# can opt-in to the freshly-built base via -pr${N} tag.
|
||||
push: ${{ env.quay_username != '' }}
|
||||
tags: ${{ steps.compute_ref.outputs.ref }}
|
||||
cache-from: type=registry,ref=quay.io/go-skynet/ci-cache:base-${{ inputs.tag-stem }}
|
||||
cache-to: type=registry,ref=quay.io/go-skynet/ci-cache:base-${{ inputs.tag-stem }},mode=max,ignore-error=true
|
||||
|
||||
- name: job summary
|
||||
run: |
|
||||
echo "Built base image: ${{ steps.compute_ref.outputs.ref }}" >> "$GITHUB_STEP_SUMMARY"
|
||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -50,6 +50,10 @@ jobs:
|
||||
variable: "QWEN3TTS_CPP_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/qwen3-tts-cpp/Makefile"
|
||||
- repository: "mudler/vibevoice.cpp"
|
||||
variable: "VIBEVOICE_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/vibevoice-cpp/Makefile"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
# Full history so golangci-lint's new-from-merge-base can reach
|
||||
# origin/master and compute the diff against it.
|
||||
|
||||
21
.github/workflows/test-extra.yml
vendored
21
.github/workflows/test-extra.yml
vendored
@@ -37,6 +37,7 @@ jobs:
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
||||
vibevoice-cpp: ${{ steps.detect.outputs.vibevoice-cpp }}
|
||||
localvqe: ${{ steps.detect.outputs.localvqe }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||
insightface: ${{ steps.detect.outputs.insightface }}
|
||||
@@ -884,6 +885,26 @@ jobs:
|
||||
- name: Build vibevoice-cpp backend image and run ASR gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-vibevoice-cpp-transcription
|
||||
# End-to-end audio transform via the e2e-backends gRPC harness. The
|
||||
# LocalVQE GGUF is small (~5 MB) and the model is real-time on CPU, so
|
||||
# the default ubuntu-latest pool is plenty.
|
||||
tests-localvqe-grpc-transform:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.localvqe == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build localvqe backend image and run audio_transform gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-localvqe-transform
|
||||
tests-voxtral:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.voxtral == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
|
||||
131
Makefile
131
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/tinygrad backends/sherpa-onnx
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -232,6 +232,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
LOCALAI_IMAGE_TAG=tests \
|
||||
LOCALAI_VLLM_BACKEND_DIR=$(abspath ./local-backends/vllm) \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='VLLMMultinode' -v -r ./tests/e2e/distributed
|
||||
|
||||
########################################################
|
||||
## E2E tests
|
||||
########################################################
|
||||
@@ -319,7 +333,7 @@ local-backends:
|
||||
|
||||
extract-backend-%: docker-build-% local-backends
|
||||
@echo "Extracting backend $*..."
|
||||
@CID=$$(docker create local-ai-backend:$*) && \
|
||||
@CID=$$(docker create --entrypoint=/run.sh local-ai-backend:$*) && \
|
||||
rm -rf local-backends/$* && mkdir -p local-backends/$* && \
|
||||
docker cp $$CID:/ - | tar -xf - -C local-backends/$* && \
|
||||
docker rm $$CID > /dev/null
|
||||
@@ -594,6 +608,14 @@ test-extra-backend-vllm: docker-build-vllm
|
||||
BACKEND_TEST_OPTIONS=tool_parser:hermes \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## vllm multi-node data-parallel smoke test. Runs LocalAI head + a
|
||||
## `local-ai p2p-worker vllm` follower in docker compose against
|
||||
## Qwen2.5-0.5B with data_parallel_size=2. Requires 2 NVIDIA GPUs and
|
||||
## nvidia-container-runtime on the host — vLLM v1's DP coordinator is
|
||||
## not viable on CPU so this cannot run in CI without GPU.
|
||||
test-extra-backend-vllm-multinode:
|
||||
./tests/e2e/vllm-multinode/smoke.sh
|
||||
|
||||
## tinygrad mirrors the vllm target (same model, same caps, same parser) so
|
||||
## the two backends are directly comparable. The LLM path covers Predict,
|
||||
## streaming and native tool-call extraction. Companion targets below cover
|
||||
@@ -874,6 +896,16 @@ test-extra-backend-vibevoice-cpp-transcription: docker-build-vibevoice-cpp
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## LocalVQE audio transform (joint AEC + noise suppression + dereverb).
|
||||
## Exercises the audio_transform capability end-to-end: batch transform
|
||||
## of a real WAV fixture and bidi streaming of synthetic silent frames.
|
||||
test-extra-backend-localvqe-transform: docker-build-localvqe
|
||||
BACKEND_IMAGE=local-ai-backend:localvqe \
|
||||
BACKEND_TEST_MODEL_URL='https://huggingface.co/LocalAI-io/LocalVQE/resolve/main/localvqe-v1-1.3M-f32.gguf#localvqe-v1-1.3M-f32.gguf' \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,audio_transform \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## sglang mirrors the vllm setup: HuggingFace model id, same tiny Qwen,
|
||||
## tool-call extraction via sglang's native qwen parser. CPU builds use
|
||||
## sglang's upstream pyproject_cpu.toml recipe (see backend/python/sglang/install.sh).
|
||||
@@ -1017,6 +1049,7 @@ BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
BACKEND_VIBEVOICE_CPP = vibevoice-cpp|golang|.|false|true
|
||||
BACKEND_LOCALVQE = localvqe|golang|.|false|true
|
||||
BACKEND_OPUS = opus|golang|.|false|true
|
||||
BACKEND_SHERPA_ONNX = sherpa-onnx|golang|.|false|true
|
||||
|
||||
@@ -1061,6 +1094,90 @@ BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||
# C++ backends (Go wrapper with purego)
|
||||
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
|
||||
|
||||
# Tag stem for the local prebuilt base images. Mirrors tagStem() in
|
||||
# scripts/changed-backends.js and the inline expression in
|
||||
# .github/workflows/backend.yml, so a `make docker-build-X` produces the
|
||||
# same FROM ref shape that CI uses.
|
||||
LOCAL_BASE_BUILD_TYPE := $(or $(BUILD_TYPE),cpu)
|
||||
LOCAL_BASE_UBUNTU_VERSION := $(or $(UBUNTU_VERSION),2404)
|
||||
LOCAL_BASE_CUDA_SUFFIX := $(if $(filter cublas l4t,$(BUILD_TYPE)),-cuda$(CUDA_MAJOR_VERSION).$(CUDA_MINOR_VERSION))
|
||||
LOCAL_BASE_PYTHON_TAG := localai-base:python-$(LOCAL_BASE_BUILD_TYPE)-$(LOCAL_BASE_UBUNTU_VERSION)$(LOCAL_BASE_CUDA_SUFFIX)
|
||||
LOCAL_BASE_GOLANG_TAG := localai-base:golang-$(LOCAL_BASE_BUILD_TYPE)-$(LOCAL_BASE_UBUNTU_VERSION)$(LOCAL_BASE_CUDA_SUFFIX)
|
||||
LOCAL_BASE_CPP_TAG := localai-base:cpp-$(LOCAL_BASE_BUILD_TYPE)-$(LOCAL_BASE_UBUNTU_VERSION)$(LOCAL_BASE_CUDA_SUFFIX)
|
||||
LOCAL_BASE_RUST_TAG := localai-base:rust-$(LOCAL_BASE_BUILD_TYPE)-$(LOCAL_BASE_UBUNTU_VERSION)
|
||||
|
||||
# Per-(lang) base image build targets. Each backend's docker-build-X target
|
||||
# depends on the matching base via generate-docker-build-target below.
|
||||
# PHONY so docker handles its own layer caching.
|
||||
.PHONY: docker-build-python-base docker-build-golang-base docker-build-cpp-base docker-build-rust-base
|
||||
|
||||
docker-build-python-base:
|
||||
docker build \
|
||||
--build-arg BUILD_TYPE=$(BUILD_TYPE) \
|
||||
--build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \
|
||||
--build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
|
||||
--build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
|
||||
--build-arg UBUNTU_VERSION=$(LOCAL_BASE_UBUNTU_VERSION) \
|
||||
--build-arg APT_MIRROR=$(APT_MIRROR) \
|
||||
--build-arg APT_PORTS_MIRROR=$(APT_PORTS_MIRROR) \
|
||||
$(if $(SKIP_DRIVERS),--build-arg SKIP_DRIVERS=$(SKIP_DRIVERS)) \
|
||||
-t $(LOCAL_BASE_PYTHON_TAG) \
|
||||
-f .docker/bases/Dockerfile.python \
|
||||
.
|
||||
|
||||
docker-build-golang-base:
|
||||
docker build \
|
||||
--build-arg BUILD_TYPE=$(BUILD_TYPE) \
|
||||
--build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \
|
||||
--build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
|
||||
--build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
|
||||
--build-arg UBUNTU_VERSION=$(LOCAL_BASE_UBUNTU_VERSION) \
|
||||
--build-arg APT_MIRROR=$(APT_MIRROR) \
|
||||
--build-arg APT_PORTS_MIRROR=$(APT_PORTS_MIRROR) \
|
||||
$(if $(SKIP_DRIVERS),--build-arg SKIP_DRIVERS=$(SKIP_DRIVERS)) \
|
||||
-t $(LOCAL_BASE_GOLANG_TAG) \
|
||||
-f .docker/bases/Dockerfile.golang \
|
||||
.
|
||||
|
||||
docker-build-cpp-base:
|
||||
docker build \
|
||||
--build-arg BUILD_TYPE=$(BUILD_TYPE) \
|
||||
--build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \
|
||||
--build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
|
||||
--build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
|
||||
--build-arg UBUNTU_VERSION=$(LOCAL_BASE_UBUNTU_VERSION) \
|
||||
--build-arg APT_MIRROR=$(APT_MIRROR) \
|
||||
--build-arg APT_PORTS_MIRROR=$(APT_PORTS_MIRROR) \
|
||||
$(if $(SKIP_DRIVERS),--build-arg SKIP_DRIVERS=$(SKIP_DRIVERS)) \
|
||||
-t $(LOCAL_BASE_CPP_TAG) \
|
||||
-f .docker/bases/Dockerfile.cpp \
|
||||
.
|
||||
|
||||
docker-build-rust-base:
|
||||
docker build \
|
||||
--build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \
|
||||
--build-arg UBUNTU_VERSION=$(LOCAL_BASE_UBUNTU_VERSION) \
|
||||
--build-arg APT_MIRROR=$(APT_MIRROR) \
|
||||
--build-arg APT_PORTS_MIRROR=$(APT_PORTS_MIRROR) \
|
||||
-t $(LOCAL_BASE_RUST_TAG) \
|
||||
-f .docker/bases/Dockerfile.rust \
|
||||
.
|
||||
|
||||
# Map a consumer dockerfile-type to the base-image tag it should consume.
|
||||
# Mirrors langOf() in scripts/changed-backends.js: the C++ trio
|
||||
# (llama-cpp/ik-llama-cpp/turboquant) all consume the shared cpp base.
|
||||
local-base-tag = $(strip \
|
||||
$(if $(filter python,$(1)),$(LOCAL_BASE_PYTHON_TAG), \
|
||||
$(if $(filter golang,$(1)),$(LOCAL_BASE_GOLANG_TAG), \
|
||||
$(if $(filter llama-cpp ik-llama-cpp turboquant,$(1)),$(LOCAL_BASE_CPP_TAG), \
|
||||
$(if $(filter rust,$(1)),$(LOCAL_BASE_RUST_TAG))))))
|
||||
|
||||
local-base-target = $(strip \
|
||||
$(if $(filter python,$(1)),docker-build-python-base, \
|
||||
$(if $(filter golang,$(1)),docker-build-golang-base, \
|
||||
$(if $(filter llama-cpp ik-llama-cpp turboquant,$(1)),docker-build-cpp-base, \
|
||||
$(if $(filter rust,$(1)),docker-build-rust-base)))))
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
define docker-build-backend
|
||||
@@ -1073,15 +1190,18 @@ define docker-build-backend
|
||||
--build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
|
||||
--build-arg APT_MIRROR=$(APT_MIRROR) \
|
||||
--build-arg APT_PORTS_MIRROR=$(APT_PORTS_MIRROR) \
|
||||
$(if $(call local-base-tag,$(2)),--build-arg BASE_IMAGE_PREBUILT=$(call local-base-tag,$(2))) \
|
||||
$(if $(FROM_SOURCE),--build-arg FROM_SOURCE=$(FROM_SOURCE)) \
|
||||
$(if $(AMDGPU_TARGETS),--build-arg AMDGPU_TARGETS=$(AMDGPU_TARGETS)) \
|
||||
$(if $(filter true,$(5)),--build-arg BACKEND=$(1)) \
|
||||
-t local-ai-backend:$(1) -f backend/Dockerfile.$(2) $(3)
|
||||
endef
|
||||
|
||||
# Generate docker-build targets from backend definitions
|
||||
# Generate docker-build targets from backend definitions. Each consumer
|
||||
# gets the matching layered base as a prerequisite so the FROM in the
|
||||
# slimmed Dockerfile resolves locally. The map lives in local-base-target.
|
||||
define generate-docker-build-target
|
||||
docker-build-$(word 1,$(subst |, ,$(1))):
|
||||
docker-build-$(word 1,$(subst |, ,$(1))): $(call local-base-target,$(word 2,$(subst |, ,$(1))))
|
||||
$$(call docker-build-backend,$(word 1,$(subst |, ,$(1))),$(word 2,$(subst |, ,$(1))),$(word 3,$(subst |, ,$(1))),$(word 4,$(subst |, ,$(1))),$(word 5,$(subst |, ,$(1))))
|
||||
endef
|
||||
|
||||
@@ -1127,6 +1247,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCALVQE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
@@ -1141,7 +1262,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
@@ -1,198 +1,37 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
# Builds a single Go backend on top of the shared
|
||||
# .docker/bases/Dockerfile.golang base. The base bakes in apt + GPU SDK +
|
||||
# Go toolchain + protoc + grpc tooling, so this stage only carries the
|
||||
# per-backend opus-dev install + COPY + `make build`.
|
||||
#
|
||||
# CI orchestration (.github/workflows/backend.yml + backend_pr.yml) builds
|
||||
# the right base flavour automatically via scripts/changed-backends.js
|
||||
# and passes BASE_IMAGE_PREBUILT here. For local builds, run:
|
||||
# make backend-image-base LANG=golang BUILD_TYPE=<...>
|
||||
# make backend-image BACKEND=<...> BUILD_TYPE=<...>
|
||||
# See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE_PREBUILT
|
||||
|
||||
FROM ${BASE_IMAGE_PREBUILT} AS builder
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG BACKEND=rerankers
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG AMDGPU_TARGETS
|
||||
ENV AMDGPU_TARGETS=${AMDGPU_TARGETS}
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git ccache \
|
||||
ca-certificates \
|
||||
make cmake wget libopenblas-dev \
|
||||
curl unzip \
|
||||
libssl-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
# Install Go
|
||||
RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
|
||||
ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin:/usr/local/bin
|
||||
|
||||
# Install grpc compilers
|
||||
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 && \
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
RUN echo "TARGETARCH: $TARGETARCH"
|
||||
|
||||
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
|
||||
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
|
||||
# here so that we can generate the grpc code for the stablediffusion build
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# opus-dev is only needed for the opus backend; install on demand to keep
|
||||
# every other golang backend's base image lean.
|
||||
RUN if [ "${BACKEND}" = "opus" ]; then \
|
||||
apt-get update && apt-get install -y --no-install-recommends libopus-dev pkg-config && \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
apt-get update && apt-get install -y --no-install-recommends libopus-dev pkg-config && \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
|
||||
@@ -1,261 +1,25 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
# Builds the ik-llama-cpp backend on top of the shared
|
||||
# .docker/bases/Dockerfile.cpp base (shared with llama-cpp/turboquant).
|
||||
# See backend/Dockerfile.llama-cpp for the rationale; this consumer differs
|
||||
# only in the make targets at the end.
|
||||
|
||||
ARG BASE_IMAGE_PREBUILT
|
||||
|
||||
# The grpc target does one thing, it builds and installs GRPC. This is in it's own layer so that it can be effectively cached by CI.
|
||||
# You probably don't need to change anything here, and if you do, make sure that CI is adjusted so that the cache continues to work.
|
||||
FROM ${GRPC_BASE_IMAGE} AS grpc
|
||||
FROM ${BASE_IMAGE_PREBUILT} AS builder
|
||||
|
||||
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
|
||||
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||
ARG GRPC_VERSION=v1.65.0
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
# CUDA Toolkit 13.x compatibility: CMake 3.31.9+ fixes toolchain detection/arch table issues
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
build-essential curl libssl-dev \
|
||||
git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
|
||||
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
|
||||
# and running make install in the target container
|
||||
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
|
||||
mkdir -p /build/grpc/cmake/build && \
|
||||
cd /build/grpc/cmake/build && \
|
||||
sed -i "216i\ TESTONLY" "../../third_party/abseil-cpp/absl/container/CMakeLists.txt" && \
|
||||
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
|
||||
make && \
|
||||
make install && \
|
||||
rm -rf /build
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
# We can target specific CUDA ARCHITECTURES like --build-arg CUDA_DOCKER_ARCH='75;86;89;120'
|
||||
ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BACKEND=ik-llama-cpp
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache git \
|
||||
ca-certificates \
|
||||
make \
|
||||
pkg-config libcurl4-openssl-dev \
|
||||
curl unzip \
|
||||
libssl-dev wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
RUN echo "TARGETARCH: $TARGETARCH"
|
||||
|
||||
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
|
||||
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
|
||||
# here so that we can generate the grpc code for the stablediffusion build
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
COPY --from=grpc /opt/grpc /usr/local
|
||||
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
@@ -281,12 +45,10 @@ fi
|
||||
EOT
|
||||
|
||||
|
||||
# Copy libraries using a script to handle architecture differences
|
||||
RUN make -BC /LocalAI/backend/cpp/ik-llama-cpp package
|
||||
|
||||
|
||||
FROM scratch
|
||||
|
||||
|
||||
# Copy all available binaries (the build process only creates the appropriate ones for the target architecture)
|
||||
COPY --from=builder /LocalAI/backend/cpp/ik-llama-cpp/package/. ./
|
||||
|
||||
@@ -1,64 +1,15 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
# Builds the llama-cpp backend on top of the shared
|
||||
# .docker/bases/Dockerfile.cpp base. The base bakes in apt + GPU SDK +
|
||||
# protoc + cmake + GRPC, so this stage only carries the COPY + `make`
|
||||
# invocations and the final scratch-stage package.
|
||||
#
|
||||
# CI orchestration (.github/workflows/backend.yml + backend_pr.yml) passes
|
||||
# BASE_IMAGE_PREBUILT. See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE_PREBUILT
|
||||
|
||||
# The grpc target does one thing, it builds and installs GRPC. This is in it's own layer so that it can be effectively cached by CI.
|
||||
# You probably don't need to change anything here, and if you do, make sure that CI is adjusted so that the cache continues to work.
|
||||
FROM ${GRPC_BASE_IMAGE} AS grpc
|
||||
FROM ${BASE_IMAGE_PREBUILT} AS builder
|
||||
|
||||
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
|
||||
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||
ARG GRPC_VERSION=v1.65.0
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
# CUDA Toolkit 13.x compatibility: CMake 3.31.9+ fixes toolchain detection/arch table issues
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
build-essential curl libssl-dev \
|
||||
git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
|
||||
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
|
||||
# and running make install in the target container
|
||||
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
|
||||
mkdir -p /build/grpc/cmake/build && \
|
||||
cd /build/grpc/cmake/build && \
|
||||
sed -i "216i\ TESTONLY" "../../third_party/abseil-cpp/absl/container/CMakeLists.txt" && \
|
||||
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
|
||||
make && \
|
||||
make install && \
|
||||
rm -rf /build
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
# We can target specific CUDA ARCHITECTURES like --build-arg CUDA_DOCKER_ARCH='75;86;89;120'
|
||||
ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
@@ -66,202 +17,15 @@ ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG AMDGPU_TARGETS
|
||||
ENV AMDGPU_TARGETS=${AMDGPU_TARGETS}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BACKEND=llama-cpp
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache git \
|
||||
ca-certificates \
|
||||
make \
|
||||
pkg-config libcurl4-openssl-dev \
|
||||
curl unzip \
|
||||
libssl-dev wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig && \
|
||||
# Log which GPU architectures have rocBLAS kernel support
|
||||
echo "rocBLAS library data architectures:" && \
|
||||
(ls /opt/rocm*/lib/rocblas/library/Kernels* 2>/dev/null || ls /opt/rocm*/lib64/rocblas/library/Kernels* 2>/dev/null) | grep -oP 'gfx[0-9a-z+-]+' | sort -u || \
|
||||
echo "WARNING: No rocBLAS kernel data found" \
|
||||
; fi
|
||||
|
||||
RUN echo "TARGETARCH: $TARGETARCH"
|
||||
|
||||
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
|
||||
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
|
||||
# here so that we can generate the grpc code for the stablediffusion build
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
COPY --from=grpc /opt/grpc /usr/local
|
||||
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
|
||||
@@ -1,202 +1,26 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
# Builds a single Python backend on top of the shared
|
||||
# .docker/bases/Dockerfile.python base. The base bakes in apt-update + GPU
|
||||
# SDK install + python toolchain (uv, pip, rustup, grpcio-tools), so this
|
||||
# stage only carries the per-backend source COPY + `make`.
|
||||
#
|
||||
# CI orchestration (.github/workflows/backend.yml + backend_pr.yml) builds
|
||||
# the right base flavour automatically via scripts/derive-build-matrix.js
|
||||
# and passes BASE_IMAGE_PREBUILT here. For local builds, run:
|
||||
# make backend-image-base BUILD_TYPE=<...> # build the base
|
||||
# make backend-image BACKEND=<...> BUILD_TYPE=<...>
|
||||
# See .agents/ci-caching.md.
|
||||
|
||||
ARG BASE_IMAGE_PREBUILT
|
||||
|
||||
FROM ${BASE_IMAGE_PREBUILT} AS builder
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG BACKEND=rerankers
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache \
|
||||
ca-certificates \
|
||||
espeak-ng \
|
||||
curl \
|
||||
libssl-dev \
|
||||
git wget \
|
||||
git-lfs \
|
||||
unzip clang \
|
||||
upx-ucl \
|
||||
curl python3-pip \
|
||||
python-is-python3 \
|
||||
python3-dev llvm \
|
||||
libnuma1 libgomp1 \
|
||||
python3-venv make cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN <<EOT bash
|
||||
if [ "${UBUNTU_VERSION}" = "2404" ]; then
|
||||
pip install --break-system-packages --user --upgrade pip
|
||||
else
|
||||
pip install --upgrade pip
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
|
||||
ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
|
||||
; fi
|
||||
|
||||
# Install uv as a system package
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | UV_INSTALL_DIR=/usr/bin sh
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
# Increase timeout for uv installs behind slow networks
|
||||
ENV UV_HTTP_TIMEOUT=180
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
|
||||
# Install grpcio-tools (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${UBUNTU_VERSION}" = "2404" ]; then
|
||||
pip install --break-system-packages --user grpcio-tools==1.71.0 grpcio==1.71.0
|
||||
else
|
||||
pip install grpcio-tools==1.71.0 grpcio==1.71.0
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
COPY backend/python/${BACKEND} /${BACKEND}
|
||||
COPY backend/backend.proto /${BACKEND}/backend.proto
|
||||
|
||||
@@ -1,37 +1,15 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
# Builds a single Rust backend on top of the shared
|
||||
# .docker/bases/Dockerfile.rust base. The base bakes in apt + Rust +
|
||||
# protobuf-compiler + audio dev libs (espeak/sonic/pcaudio/opus), so this
|
||||
# stage only carries the per-backend COPY + `make build`.
|
||||
#
|
||||
# CI orchestration (.github/workflows/backend.yml + backend_pr.yml) passes
|
||||
# BASE_IMAGE_PREBUILT. See .agents/ci-caching.md.
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG BASE_IMAGE_PREBUILT
|
||||
|
||||
FROM ${BASE_IMAGE_PREBUILT} AS builder
|
||||
ARG BACKEND=kokoros
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git ccache \
|
||||
ca-certificates \
|
||||
make cmake wget \
|
||||
curl unzip \
|
||||
clang \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
espeak-ng libespeak-ng-dev \
|
||||
libsonic-dev libpcaudio-dev \
|
||||
libopus-dev \
|
||||
protobuf-compiler && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
|
||||
@@ -1,265 +1,25 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
|
||||
ARG APT_MIRROR=""
|
||||
ARG APT_PORTS_MIRROR=""
|
||||
# Builds the turboquant backend on top of the shared
|
||||
# .docker/bases/Dockerfile.cpp base (shared with llama-cpp/ik-llama-cpp).
|
||||
# See backend/Dockerfile.llama-cpp for the rationale; this consumer differs
|
||||
# only in the make targets at the end.
|
||||
|
||||
ARG BASE_IMAGE_PREBUILT
|
||||
|
||||
# The grpc target does one thing, it builds and installs GRPC. This is in it's own layer so that it can be effectively cached by CI.
|
||||
# You probably don't need to change anything here, and if you do, make sure that CI is adjusted so that the cache continues to work.
|
||||
FROM ${GRPC_BASE_IMAGE} AS grpc
|
||||
FROM ${BASE_IMAGE_PREBUILT} AS builder
|
||||
|
||||
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
|
||||
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||
ARG GRPC_VERSION=v1.65.0
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
# CUDA Toolkit 13.x compatibility: CMake 3.31.9+ fixes toolchain detection/arch table issues
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
build-essential curl libssl-dev \
|
||||
git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
|
||||
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
|
||||
# and running make install in the target container
|
||||
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
|
||||
mkdir -p /build/grpc/cmake/build && \
|
||||
cd /build/grpc/cmake/build && \
|
||||
sed -i "216i\ TESTONLY" "../../third_party/abseil-cpp/absl/container/CMakeLists.txt" && \
|
||||
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
|
||||
make && \
|
||||
make install && \
|
||||
rm -rf /build
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
# We can target specific CUDA ARCHITECTURES like --build-arg CUDA_DOCKER_ARCH='75;86;89;120'
|
||||
ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BACKEND=turboquant
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
ARG APT_MIRROR
|
||||
ARG APT_PORTS_MIRROR
|
||||
|
||||
RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
|
||||
APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache git \
|
||||
ca-certificates \
|
||||
make \
|
||||
pkg-config libcurl4-openssl-dev \
|
||||
curl unzip \
|
||||
libssl-dev wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig && \
|
||||
# Log which GPU architectures have rocBLAS kernel support
|
||||
echo "rocBLAS library data architectures:" && \
|
||||
(ls /opt/rocm*/lib/rocblas/library/Kernels* 2>/dev/null || ls /opt/rocm*/lib64/rocblas/library/Kernels* 2>/dev/null) | grep -oP 'gfx[0-9a-z+-]+' | sort -u || \
|
||||
echo "WARNING: No rocBLAS kernel data found" \
|
||||
; fi
|
||||
|
||||
RUN echo "TARGETARCH: $TARGETARCH"
|
||||
|
||||
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
|
||||
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
|
||||
# here so that we can generate the grpc code for the stablediffusion build
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
COPY --from=grpc /opt/grpc /usr/local
|
||||
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
|
||||
@@ -41,9 +41,14 @@ service Backend {
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
|
||||
rpc Diarize(DiarizeRequest) returns (DiarizeResponse) {}
|
||||
|
||||
rpc AudioEncode(AudioEncodeRequest) returns (AudioEncodeResult) {}
|
||||
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
||||
|
||||
rpc AudioTransform(AudioTransformRequest) returns (AudioTransformResult) {}
|
||||
rpc AudioTransformStream(stream AudioTransformFrameRequest) returns (stream AudioTransformFrameResponse) {}
|
||||
|
||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||
|
||||
// Fine-tuning RPCs
|
||||
@@ -350,6 +355,12 @@ message TranscriptStreamResponse {
|
||||
TranscriptResult final_result = 2;
|
||||
}
|
||||
|
||||
message TranscriptWord {
|
||||
int64 start = 1;
|
||||
int64 end = 2;
|
||||
string text = 3;
|
||||
}
|
||||
|
||||
message TranscriptSegment {
|
||||
int32 id = 1;
|
||||
int64 start = 2;
|
||||
@@ -357,6 +368,7 @@ message TranscriptSegment {
|
||||
string text = 4;
|
||||
repeated int32 tokens = 5;
|
||||
string speaker = 6;
|
||||
repeated TranscriptWord words = 7;
|
||||
}
|
||||
|
||||
message GenerateImageRequest {
|
||||
@@ -413,6 +425,43 @@ message VADResponse {
|
||||
repeated VADSegment segments = 1;
|
||||
}
|
||||
|
||||
// --- Speaker diarization messages ---
|
||||
//
|
||||
// Pure speaker diarization: "who spoke when". Returns time-stamped segments
|
||||
// labelled with cluster IDs (the same string for the same speaker across
|
||||
// segments). Some backends (e.g. vibevoice.cpp) produce diarization as a
|
||||
// by-product of ASR and may also fill in `text` per segment; backends with a
|
||||
// dedicated diarization pipeline (e.g. sherpa-onnx pyannote) leave `text`
|
||||
// empty and emit only the segmentation.
|
||||
|
||||
message DiarizeRequest {
|
||||
string dst = 1; // path to audio file (HTTP layer materialises uploads to a temp file)
|
||||
uint32 threads = 2;
|
||||
string language = 3; // optional; only meaningful for transcription-bundling backends
|
||||
int32 num_speakers = 4; // exact speaker count if known (>0 forces); 0 = auto
|
||||
int32 min_speakers = 5; // hint when auto-detecting; 0 = unset
|
||||
int32 max_speakers = 6; // hint when auto-detecting; 0 = unset
|
||||
float clustering_threshold = 7; // distance threshold when num_speakers unknown; 0 = backend default
|
||||
float min_duration_on = 8; // discard segments shorter than this (seconds); 0 = backend default
|
||||
float min_duration_off = 9; // merge gaps shorter than this (seconds); 0 = backend default
|
||||
bool include_text = 10; // when the backend can emit per-segment transcript for free, ask it to populate `text`
|
||||
}
|
||||
|
||||
message DiarizeSegment {
|
||||
int32 id = 1;
|
||||
float start = 2; // seconds
|
||||
float end = 3; // seconds
|
||||
string speaker = 4; // backend-emitted speaker label (e.g. "0", "SPEAKER_00")
|
||||
string text = 5; // optional per-segment transcript (empty unless include_text and supported)
|
||||
}
|
||||
|
||||
message DiarizeResponse {
|
||||
repeated DiarizeSegment segments = 1;
|
||||
int32 num_speakers = 2; // count of distinct speaker labels in `segments`
|
||||
float duration = 3; // total audio duration in seconds (0 if unknown)
|
||||
string language = 4; // optional, when the backend bundles transcription
|
||||
}
|
||||
|
||||
message SoundGenerationRequest {
|
||||
string text = 1;
|
||||
string model = 2;
|
||||
@@ -669,6 +718,56 @@ message AudioDecodeResult {
|
||||
int32 samples_per_frame = 3;
|
||||
}
|
||||
|
||||
// Generic audio transform: an audio-in, audio-out operation, optionally
|
||||
// conditioned on a second reference signal. Concrete transforms include
|
||||
// AEC + noise suppression + dereverberation (LocalVQE), voice conversion
|
||||
// (reference = target speaker), pitch shifting, etc.
|
||||
message AudioTransformRequest {
|
||||
string audio_path = 1; // required, primary input file path
|
||||
string reference_path = 2; // optional auxiliary; empty => zero-fill
|
||||
string dst = 3; // required, output file path
|
||||
map<string, string> params = 4; // backend-specific tuning
|
||||
}
|
||||
|
||||
message AudioTransformResult {
|
||||
string dst = 1;
|
||||
int32 sample_rate = 2;
|
||||
int32 samples = 3;
|
||||
bool reference_provided = 4;
|
||||
}
|
||||
|
||||
// Bidirectional streaming audio transform. The first message MUST carry a
|
||||
// Config; subsequent messages carry Frames. A second Config mid-stream
|
||||
// resets streaming state before the next frame.
|
||||
message AudioTransformFrameRequest {
|
||||
oneof payload {
|
||||
AudioTransformStreamConfig config = 1;
|
||||
AudioTransformFrame frame = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message AudioTransformStreamConfig {
|
||||
enum SampleFormat {
|
||||
F32_LE = 0;
|
||||
S16_LE = 1;
|
||||
}
|
||||
SampleFormat sample_format = 1;
|
||||
int32 sample_rate = 2; // 0 => backend default
|
||||
int32 frame_samples = 3; // 0 => backend default
|
||||
map<string, string> params = 4;
|
||||
bool reset = 5; // reset streaming state before next frame
|
||||
}
|
||||
|
||||
message AudioTransformFrame {
|
||||
bytes audio_pcm = 1; // frame_samples samples in stream's format
|
||||
bytes reference_pcm = 2; // empty => zero-fill (silent reference)
|
||||
}
|
||||
|
||||
message AudioTransformFrameResponse {
|
||||
bytes pcm = 1;
|
||||
int64 frame_index = 2;
|
||||
}
|
||||
|
||||
message ModelMetadataResponse {
|
||||
bool supports_thinking = 1;
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=a8aecbf15933295af96504f9a693998322185b5c
|
||||
IK_LLAMA_VERSION?=8b56d813a9ed04fa7b7fe2588fddd845cf64eccb
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=beb42fffa45eded44804a1fd4916146222371581
|
||||
LLAMA_VERSION?=bbeb89d76c41bc250f16e4a6fefcc9b530d6e3f3
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=11a241d0db78a68e0a5b99fe6f36de6683100f6a
|
||||
TURBOQUANT_VERSION?=69d8e4be47243e83b3d0d71e932bc7aa61c644dc
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
7
backend/go/localvqe/.gitignore
vendored
Normal file
7
backend/go/localvqe/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
liblocalvqe.so*
|
||||
libggml*.so*
|
||||
localvqe
|
||||
.localvqe-build.stamp
|
||||
98
backend/go/localvqe/Makefile
Normal file
98
backend/go/localvqe/Makefile
Normal file
@@ -0,0 +1,98 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# LocalVQE upstream version pin. Bump to a specific commit when picking up
|
||||
# a new release; `main` works for development but is not reproducible.
|
||||
LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
|
||||
LOCALVQE_VERSION?=72bfb4c6
|
||||
|
||||
# LocalVQE handles CPU feature selection internally (it ships the multiple
|
||||
# libggml-cpu-*.so variants and its loader picks the best one at runtime
|
||||
# via GGML_BACKEND_DL), so we build a single liblocalvqe.so + the per-CPU
|
||||
# ggml shared libs and let it sort itself out. No need for a wrapper
|
||||
# MODULE library or per-AVX backend variants here.
|
||||
|
||||
CMAKE_ARGS+=-DLOCALVQE_BUILD_SHARED=ON
|
||||
CMAKE_ARGS+=-DGGML_BUILD_TESTS=OFF
|
||||
CMAKE_ARGS+=-DGGML_BUILD_EXAMPLES=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
|
||||
# fall through to the default CPU build — Vulkan is already as fast as the
|
||||
# specialised GPU paths would be on this 1.3 M-parameter model.
|
||||
ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
endif
|
||||
|
||||
# --- Sources ---
|
||||
|
||||
sources/LocalVQE:
|
||||
mkdir -p sources/LocalVQE
|
||||
cd sources/LocalVQE && \
|
||||
git init && \
|
||||
git remote add origin $(LOCALVQE_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(LOCALVQE_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# --- Native build ---
|
||||
#
|
||||
# Drives cmake directly against the upstream LocalVQE/ggml CMakeLists.
|
||||
# Produces liblocalvqe.so plus the per-CPU libggml-cpu-*.so variants in
|
||||
# build/bin/, all of which we copy into the backend directory so package.sh
|
||||
# can pick them up. The `liblocalvqe.so` rule deliberately uses a sentinel
|
||||
# stamp file because Make's wildcard tracking would otherwise mis-decide
|
||||
# about freshness when SOVERSION symlinks are involved.
|
||||
|
||||
LIB_SENTINEL=.localvqe-build.stamp
|
||||
|
||||
$(LIB_SENTINEL): sources/LocalVQE
|
||||
mkdir -p build && \
|
||||
cd build && \
|
||||
cmake ../sources/LocalVQE/ggml $(CMAKE_ARGS) -DCMAKE_BUILD_TYPE=Release && \
|
||||
cmake --build . --config Release -j$(JOBS)
|
||||
# Upstream's CPU build sets GGML_BACKEND_DL=ON + GGML_CPU_ALL_VARIANTS=ON,
|
||||
# which produces multiple libggml-cpu-*.so files (SSE4.2 / AVX2 / AVX-512)
|
||||
# that the loader picks at runtime. We must build every target — the
|
||||
# default `--target localvqe_shared` drops these. CMAKE_LIBRARY_OUTPUT_DIRECTORY
|
||||
# routes all of them into build/bin; copy them out next to the binary.
|
||||
cp -P build/bin/liblocalvqe.so* . 2>/dev/null || cp -P build/liblocalvqe.so* .
|
||||
cp -P build/bin/libggml*.so* . 2>/dev/null || true
|
||||
touch $(LIB_SENTINEL)
|
||||
|
||||
liblocalvqe.so: $(LIB_SENTINEL)
|
||||
|
||||
# --- Go binary + packaging ---
|
||||
|
||||
localvqe: main.go golocalvqe.go $(LIB_SENTINEL)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o localvqe ./
|
||||
|
||||
package: localvqe
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf liblocalvqe.so* libggml*.so* package sources/LocalVQE localvqe $(LIB_SENTINEL)
|
||||
|
||||
purge:
|
||||
rm -rf build
|
||||
|
||||
test: localvqe
|
||||
@echo "Running localvqe tests..."
|
||||
bash test.sh
|
||||
@echo "localvqe tests completed."
|
||||
|
||||
all: localvqe package
|
||||
|
||||
.PHONY: build package clean purge test all
|
||||
610
backend/go/localvqe/golocalvqe.go
Normal file
610
backend/go/localvqe/golocalvqe.go
Normal file
@@ -0,0 +1,610 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// localvqeSampleRate is the only sample rate currently supported by the
|
||||
// upstream LocalVQE model. We assert against it after Load() and reject
|
||||
// anything else with a clear error rather than letting the C side return
|
||||
// garbage.
|
||||
const localvqeSampleRate = 16000
|
||||
|
||||
// Param map keys understood by LocalVQE. Keep these strings in sync with
|
||||
// schema.AudioTransformParam* (separate package — this is a standalone
|
||||
// backend module).
|
||||
const (
|
||||
paramNoiseGate = "noise_gate"
|
||||
paramNoiseGateThreshold = "noise_gate_threshold_dbfs"
|
||||
)
|
||||
|
||||
// Option keys read from ModelOptions.Options[] at Load() time. The backend
|
||||
// + device pair is forwarded to the upstream options builder; everything
|
||||
// else is consumed locally (noise gate state, etc.).
|
||||
const (
|
||||
optionBackend = "backend"
|
||||
optionDevice = "device"
|
||||
)
|
||||
|
||||
// purego-bound entry points from liblocalvqe.
|
||||
//
|
||||
// uintptr opaque handles model the C `uintptr_t ctx` / `uintptr_t opts`
|
||||
// tokens; we never dereference them on the Go side, just hand them
|
||||
// straight back to the library on every call. Construction always goes
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
// state is per-context, so we serialize calls through SingleThread —
|
||||
// concurrent streams would corrupt the overlap-add buffers.
|
||||
type LocalVQE struct {
|
||||
base.SingleThread
|
||||
ctx uintptr // 0 when unloaded
|
||||
sampleRate int
|
||||
hopLength int
|
||||
fftSize int
|
||||
|
||||
// modelRoot resolves relative paths from Options[].
|
||||
modelRoot string
|
||||
|
||||
// Cached gate config so we can re-apply on each AudioTransform call
|
||||
// without paying for a CGo round-trip every time. Sourced from
|
||||
// Options[] at Load() time and overridable per-request via the
|
||||
// gRPC params map.
|
||||
gateEnabled bool
|
||||
gateDbfs float32
|
||||
|
||||
// Backend / device picked via Options[]. Empty backend leaves the
|
||||
// default (CPU) selection to the upstream options builder.
|
||||
backend string
|
||||
device int32
|
||||
}
|
||||
|
||||
// parseOptions reads opts.Options[] for backend-specific tuning. Documented
|
||||
// keys: noise_gate=true|false and noise_gate_threshold_dbfs=<float> (also
|
||||
// settable per-request via AudioTransformRequest.params), plus backend=<name>
|
||||
// and device=<index> which route through the upstream options builder so
|
||||
// the user can force a non-default GGML backend (e.g. "Vulkan").
|
||||
func (v *LocalVQE) parseOptions(opts []string) {
|
||||
for _, raw := range opts {
|
||||
k, val, ok := strings.Cut(raw, "=")
|
||||
if !ok {
|
||||
k, val, ok = strings.Cut(raw, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
key := strings.TrimSpace(strings.ToLower(k))
|
||||
val = strings.TrimSpace(val)
|
||||
switch key {
|
||||
case paramNoiseGate:
|
||||
if b, err := strconv.ParseBool(val); err == nil {
|
||||
v.gateEnabled = b
|
||||
}
|
||||
case paramNoiseGateThreshold:
|
||||
if f, err := strconv.ParseFloat(val, 32); err == nil {
|
||||
v.gateDbfs = float32(f)
|
||||
}
|
||||
case optionBackend:
|
||||
v.backend = val
|
||||
case optionDevice:
|
||||
if d, err := strconv.Atoi(val); err == nil && d >= 0 {
|
||||
v.device = int32(d)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newCtxWithOptions builds a context via the upstream options-builder so we
|
||||
// can pass backend / device in addition to the model path. Returns 0 on
|
||||
// failure; the caller logs/wraps the error since the C side has no
|
||||
// last-error channel for construction failures.
|
||||
func newCtxWithOptions(modelPath, backend string, device int32) uintptr {
|
||||
o := CppOptionsNew()
|
||||
if o == 0 {
|
||||
return 0
|
||||
}
|
||||
defer CppOptionsFree(o)
|
||||
if rc := CppOptionsSetModelPath(o, modelPath); rc != 0 {
|
||||
return 0
|
||||
}
|
||||
if backend != "" {
|
||||
if rc := CppOptionsSetBackend(o, backend); rc != 0 {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if device > 0 {
|
||||
if rc := CppOptionsSetDevice(o, device); rc != 0 {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return CppNewWithOptions(o)
|
||||
}
|
||||
|
||||
func (v *LocalVQE) Load(opts *pb.ModelOptions) error {
|
||||
if opts.ModelFile == "" {
|
||||
return fmt.Errorf("localvqe: ModelFile is required")
|
||||
}
|
||||
|
||||
modelFile := opts.ModelFile
|
||||
if !filepath.IsAbs(modelFile) && opts.ModelPath != "" {
|
||||
modelFile = filepath.Join(opts.ModelPath, modelFile)
|
||||
}
|
||||
v.modelRoot = opts.ModelPath
|
||||
if v.modelRoot == "" {
|
||||
v.modelRoot = filepath.Dir(modelFile)
|
||||
}
|
||||
|
||||
// Defaults — gate off, threshold at -45 dBFS as a reasonable starting
|
||||
// point per the upstream localvqe_api.h documentation.
|
||||
v.gateEnabled = false
|
||||
v.gateDbfs = -45.0
|
||||
v.parseOptions(opts.Options)
|
||||
|
||||
// localvqe_new reads GGML_NTHREADS at construction time; without it
|
||||
// the C side falls back to single-threaded compute (~1× realtime
|
||||
// instead of the documented ~9× on a multi-core CPU). Pass the
|
||||
// model config's Threads through, defaulting to min(NumCPU, 4).
|
||||
//
|
||||
// LocalVQE is 1.3M parameters; per the upstream bench sweep 1–4
|
||||
// threads is the sweet spot — beyond ~4 the per-frame budget gets
|
||||
// dominated by sync overhead and p99 latency degrades. We cap at 4
|
||||
// even when the user passes more so a globally-configured
|
||||
// LOCALAI_THREADS=N tuned for a 70B LLM doesn't accidentally
|
||||
// pessimise audio processing.
|
||||
const localvqeMaxThreads = 4
|
||||
threads := int(opts.Threads)
|
||||
if threads <= 0 {
|
||||
threads = runtime.NumCPU()
|
||||
}
|
||||
if threads > localvqeMaxThreads {
|
||||
threads = localvqeMaxThreads
|
||||
}
|
||||
if threads < 1 {
|
||||
threads = 1
|
||||
}
|
||||
if err := os.Setenv("GGML_NTHREADS", fmt.Sprintf("%d", threads)); err != nil {
|
||||
return fmt.Errorf("localvqe: setenv GGML_NTHREADS: %w", err)
|
||||
}
|
||||
|
||||
xlog.Info("[localvqe] loading model", "path", modelFile, "threads", threads, "backend", v.backend, "device", v.device, "noise_gate", v.gateEnabled, "threshold_dbfs", v.gateDbfs)
|
||||
|
||||
ctx := newCtxWithOptions(modelFile, v.backend, v.device)
|
||||
if ctx == 0 {
|
||||
return fmt.Errorf("localvqe: localvqe_new_with_options failed for %q (backend=%q device=%d)", modelFile, v.backend, v.device)
|
||||
}
|
||||
v.ctx = ctx
|
||||
|
||||
v.sampleRate = int(CppSampleRate(ctx))
|
||||
v.hopLength = int(CppHopLength(ctx))
|
||||
v.fftSize = int(CppFFTSize(ctx))
|
||||
|
||||
if v.sampleRate != localvqeSampleRate {
|
||||
CppFree(ctx)
|
||||
v.ctx = 0
|
||||
return fmt.Errorf("localvqe: unsupported sample rate %d (only %d Hz is supported)", v.sampleRate, localvqeSampleRate)
|
||||
}
|
||||
if v.hopLength <= 0 || v.fftSize <= 0 {
|
||||
CppFree(ctx)
|
||||
v.ctx = 0
|
||||
return fmt.Errorf("localvqe: model reports invalid hop=%d fft=%d", v.hopLength, v.fftSize)
|
||||
}
|
||||
|
||||
if v.gateEnabled {
|
||||
if rc := CppSetNoiseGate(ctx, 1, v.gateDbfs); rc != 0 {
|
||||
err := fmt.Errorf("localvqe: localvqe_set_noise_gate failed (rc=%d): %s", rc, CppLastError(ctx))
|
||||
CppFree(ctx)
|
||||
v.ctx = 0
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *LocalVQE) Free() error {
|
||||
if v.ctx != 0 {
|
||||
CppFree(v.ctx)
|
||||
v.ctx = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyParams forwards backend-specific tuning to the C side per call.
|
||||
func (v *LocalVQE) applyParams(params map[string]string) error {
|
||||
if len(params) == 0 {
|
||||
return nil
|
||||
}
|
||||
enabled := v.gateEnabled
|
||||
threshold := v.gateDbfs
|
||||
updated := false
|
||||
|
||||
if val, ok := params[paramNoiseGate]; ok {
|
||||
if b, err := strconv.ParseBool(val); err == nil {
|
||||
enabled = b
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if val, ok := params[paramNoiseGateThreshold]; ok {
|
||||
if f, err := strconv.ParseFloat(val, 32); err == nil {
|
||||
threshold = float32(f)
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
gateOn := int32(0)
|
||||
if enabled {
|
||||
gateOn = 1
|
||||
}
|
||||
if rc := CppSetNoiseGate(v.ctx, gateOn, threshold); rc != 0 {
|
||||
return fmt.Errorf("localvqe_set_noise_gate failed (rc=%d): %s", rc, CppLastError(v.ctx))
|
||||
}
|
||||
v.gateEnabled = enabled
|
||||
v.gateDbfs = threshold
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *LocalVQE) AudioTransform(req *pb.AudioTransformRequest) (*pb.AudioTransformResult, error) {
|
||||
if v.ctx == 0 {
|
||||
return nil, fmt.Errorf("localvqe: no model loaded")
|
||||
}
|
||||
if req.AudioPath == "" || req.Dst == "" {
|
||||
return nil, fmt.Errorf("localvqe: audio_path and dst are required")
|
||||
}
|
||||
|
||||
if err := v.applyParams(req.Params); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mic, micRate, err := readMonoWAVf32(req.AudioPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read audio: %w", err)
|
||||
}
|
||||
if micRate != v.sampleRate {
|
||||
return nil, fmt.Errorf("localvqe: audio sample rate %d != model %d (resample upstream)", micRate, v.sampleRate)
|
||||
}
|
||||
|
||||
refProvided := req.ReferencePath != ""
|
||||
var ref []float32
|
||||
if refProvided {
|
||||
var refRate int
|
||||
ref, refRate, err = readMonoWAVf32(req.ReferencePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read reference: %w", err)
|
||||
}
|
||||
if refRate != v.sampleRate {
|
||||
return nil, fmt.Errorf("localvqe: reference sample rate %d != model %d", refRate, v.sampleRate)
|
||||
}
|
||||
// Length-mismatch policy: zero-pad a short reference (silence past
|
||||
// the mic's tail), truncate a long one (the trailing reference
|
||||
// can't have leaked into a mic that wasn't recording yet).
|
||||
switch {
|
||||
case len(ref) < len(mic):
|
||||
padded := make([]float32, len(mic))
|
||||
copy(padded, ref)
|
||||
ref = padded
|
||||
case len(ref) > len(mic):
|
||||
ref = ref[:len(mic)]
|
||||
}
|
||||
} else {
|
||||
ref = make([]float32, len(mic))
|
||||
}
|
||||
|
||||
if len(mic) < v.fftSize {
|
||||
return nil, fmt.Errorf("localvqe: audio too short (%d samples, need ≥ %d)", len(mic), v.fftSize)
|
||||
}
|
||||
|
||||
out := make([]float32, len(mic))
|
||||
rc := CppProcessF32(v.ctx,
|
||||
uintptr(unsafe.Pointer(&mic[0])),
|
||||
uintptr(unsafe.Pointer(&ref[0])),
|
||||
int32(len(mic)),
|
||||
uintptr(unsafe.Pointer(&out[0])))
|
||||
if rc != 0 {
|
||||
return nil, fmt.Errorf("localvqe_process_f32 failed (rc=%d): %s", rc, CppLastError(v.ctx))
|
||||
}
|
||||
|
||||
if err := writeMonoWAVf32(req.Dst, out, v.sampleRate); err != nil {
|
||||
return nil, fmt.Errorf("write output: %w", err)
|
||||
}
|
||||
|
||||
return &pb.AudioTransformResult{
|
||||
Dst: req.Dst,
|
||||
SampleRate: int32(v.sampleRate),
|
||||
Samples: int32(len(out)),
|
||||
ReferenceProvided: refProvided,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTransformStream runs the bidirectional streaming path. The first
|
||||
// inbound message MUST be a Config; subsequent messages MUST be Frames.
|
||||
// A second Config mid-stream resets the streaming state.
|
||||
func (v *LocalVQE) AudioTransformStream(in <-chan *pb.AudioTransformFrameRequest, out chan<- *pb.AudioTransformFrameResponse) error {
|
||||
defer close(out)
|
||||
|
||||
if v.ctx == 0 {
|
||||
return fmt.Errorf("localvqe: no model loaded")
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
cfg := first.GetConfig()
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("localvqe: first stream message must be a Config")
|
||||
}
|
||||
if err := v.applyStreamConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hop := v.hopLength
|
||||
if cfg.FrameSamples != 0 && int(cfg.FrameSamples) != hop {
|
||||
return fmt.Errorf("localvqe: frame_samples=%d != hop_length=%d", cfg.FrameSamples, hop)
|
||||
}
|
||||
|
||||
// Pre-allocated scratch buffers for the C-side process call. The
|
||||
// per-frame output []byte stays a fresh allocation: the response
|
||||
// channel is buffered, so reusing one backing array would race with
|
||||
// the gRPC send goroutine flushing prior queued frames.
|
||||
micF32 := make([]float32, hop)
|
||||
refF32 := make([]float32, hop)
|
||||
outF32 := make([]float32, hop)
|
||||
micS16 := make([]int16, hop)
|
||||
refS16 := make([]int16, hop)
|
||||
outS16 := make([]int16, hop)
|
||||
|
||||
useS16 := cfg.SampleFormat == pb.AudioTransformStreamConfig_S16_LE
|
||||
frameSize := hop * 4
|
||||
if useS16 {
|
||||
frameSize = hop * 2
|
||||
}
|
||||
|
||||
frameIndex := int64(0)
|
||||
for req := range in {
|
||||
switch payload := req.Payload.(type) {
|
||||
case *pb.AudioTransformFrameRequest_Config:
|
||||
if err := v.applyStreamConfig(payload.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
if payload.Config.Reset_ {
|
||||
CppReset(v.ctx)
|
||||
frameIndex = 0
|
||||
}
|
||||
continue
|
||||
case *pb.AudioTransformFrameRequest_Frame:
|
||||
if len(payload.Frame.AudioPcm) != frameSize {
|
||||
return fmt.Errorf("localvqe: frame audio bytes=%d expected=%d", len(payload.Frame.AudioPcm), frameSize)
|
||||
}
|
||||
refBuf := payload.Frame.ReferencePcm
|
||||
if len(refBuf) != 0 && len(refBuf) != frameSize {
|
||||
return fmt.Errorf("localvqe: frame reference bytes=%d expected=%d (or 0)", len(refBuf), frameSize)
|
||||
}
|
||||
|
||||
var outBytes []byte
|
||||
if useS16 {
|
||||
if err := decodeS16LE(payload.Frame.AudioPcm, micS16); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(refBuf) > 0 {
|
||||
if err := decodeS16LE(refBuf, refS16); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
zeroS16(refS16)
|
||||
}
|
||||
rc := CppProcessFrameS16(v.ctx,
|
||||
uintptr(unsafe.Pointer(&micS16[0])),
|
||||
uintptr(unsafe.Pointer(&refS16[0])),
|
||||
int32(hop),
|
||||
uintptr(unsafe.Pointer(&outS16[0])))
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("localvqe_process_frame_s16 (rc=%d): %s", rc, CppLastError(v.ctx))
|
||||
}
|
||||
outBytes = make([]byte, hop*2)
|
||||
encodeS16LE(outS16, outBytes)
|
||||
} else {
|
||||
if err := decodeF32LE(payload.Frame.AudioPcm, micF32); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(refBuf) > 0 {
|
||||
if err := decodeF32LE(refBuf, refF32); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
zeroF32(refF32)
|
||||
}
|
||||
rc := CppProcessFrameF32(v.ctx,
|
||||
uintptr(unsafe.Pointer(&micF32[0])),
|
||||
uintptr(unsafe.Pointer(&refF32[0])),
|
||||
int32(hop),
|
||||
uintptr(unsafe.Pointer(&outF32[0])))
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("localvqe_process_frame_f32 (rc=%d): %s", rc, CppLastError(v.ctx))
|
||||
}
|
||||
outBytes = make([]byte, hop*4)
|
||||
encodeF32LE(outF32, outBytes)
|
||||
}
|
||||
out <- &pb.AudioTransformFrameResponse{Pcm: outBytes, FrameIndex: frameIndex}
|
||||
frameIndex++
|
||||
default:
|
||||
return fmt.Errorf("localvqe: unexpected stream payload %T", payload)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func zeroS16(s []int16) {
|
||||
for i := range s {
|
||||
s[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func zeroF32(s []float32) {
|
||||
for i := range s {
|
||||
s[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
if cfg.SampleRate != 0 && int(cfg.SampleRate) != v.sampleRate {
|
||||
return fmt.Errorf("localvqe: sample_rate=%d != model %d", cfg.SampleRate, v.sampleRate)
|
||||
}
|
||||
return v.applyParams(cfg.Params)
|
||||
}
|
||||
|
||||
// ---- WAV I/O ----------------------------------------------------------
|
||||
//
|
||||
// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
|
||||
// cares about (mono, 16-bit signed, no extensible chunks). For broader
|
||||
// audio support the HTTP layer's `audio.NormalizeAudioFile` already
|
||||
// converts arbitrary input to a canonical WAV before we see it; this
|
||||
// reader just decodes the canonical shape.
|
||||
|
||||
func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
header := make([]byte, 44)
|
||||
if _, err := io.ReadFull(f, header); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
|
||||
return nil, 0, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
channels := binary.LittleEndian.Uint16(header[22:24])
|
||||
sampleRate := binary.LittleEndian.Uint32(header[24:28])
|
||||
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
|
||||
|
||||
if channels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
|
||||
}
|
||||
if bitsPerSample != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
|
||||
}
|
||||
|
||||
rest, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
n := len(rest) / 2
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
|
||||
out[i] = float32(s) / 32768.0
|
||||
}
|
||||
return out, int(sampleRate), nil
|
||||
}
|
||||
|
||||
func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
dataLen := uint32(len(samples) * 2)
|
||||
header := make([]byte, 44)
|
||||
copy(header[0:4], []byte("RIFF"))
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
clamped := s * 32768.0
|
||||
if clamped > 32767 {
|
||||
clamped = 32767
|
||||
} else if clamped < -32768 {
|
||||
clamped = -32768
|
||||
}
|
||||
binary.LittleEndian.PutUint16(body[i*2:i*2+2], uint16(int16(clamped)))
|
||||
}
|
||||
_, err = f.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- PCM endec helpers ------------------------------------------------
|
||||
|
||||
func decodeS16LE(buf []byte, out []int16) error {
|
||||
if len(buf) != len(out)*2 {
|
||||
return fmt.Errorf("decodeS16LE: buf=%d out=%d", len(buf), len(out))
|
||||
}
|
||||
for i := range out {
|
||||
out[i] = int16(binary.LittleEndian.Uint16(buf[i*2 : i*2+2]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeS16LE(in []int16, out []byte) {
|
||||
for i, s := range in {
|
||||
binary.LittleEndian.PutUint16(out[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
}
|
||||
|
||||
func decodeF32LE(buf []byte, out []float32) error {
|
||||
if len(buf) != len(out)*4 {
|
||||
return fmt.Errorf("decodeF32LE: buf=%d out=%d", len(buf), len(out))
|
||||
}
|
||||
for i := range out {
|
||||
bits := binary.LittleEndian.Uint32(buf[i*4 : i*4+4])
|
||||
out[i] = *(*float32)(unsafe.Pointer(&bits))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeF32LE(in []float32, out []byte) {
|
||||
for i, s := range in {
|
||||
bits := *(*uint32)(unsafe.Pointer(&s))
|
||||
binary.LittleEndian.PutUint32(out[i*4:i*4+4], bits)
|
||||
}
|
||||
}
|
||||
120
backend/go/localvqe/localvqe_test.go
Normal file
120
backend/go/localvqe/localvqe_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLocalVQE(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "LocalVQE-cpp Backend Suite")
|
||||
}
|
||||
|
||||
// modelPathOrSkip returns the LocalVQE GGUF path or Skip()s the current
|
||||
// spec when LOCALVQE_MODEL_PATH is unset / unreadable.
|
||||
func modelPathOrSkip() string {
|
||||
path := os.Getenv("LOCALVQE_MODEL_PATH")
|
||||
if path == "" {
|
||||
Skip("LOCALVQE_MODEL_PATH not set, skipping model-dependent specs")
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
Skip("LOCALVQE_MODEL_PATH unreadable: " + err.Error())
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
var _ = Describe("LocalVQE-cpp", func() {
|
||||
Context("backend semantics (no purego load needed)", func() {
|
||||
It("is locking - the engine has per-context streaming state", func() {
|
||||
Expect((&LocalVQE{}).Locking()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects Load with empty ModelFile", func() {
|
||||
err := (&LocalVQE{}).Load(&pb.ModelOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("ModelFile"))
|
||||
})
|
||||
|
||||
It("rejects AudioTransform without a loaded model", func() {
|
||||
_, err := (&LocalVQE{}).AudioTransform(&pb.AudioTransformRequest{
|
||||
AudioPath: "/tmp/audio.wav",
|
||||
Dst: "/tmp/out.wav",
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no model loaded"))
|
||||
})
|
||||
|
||||
It("closes the output channel and errors on AudioTransformStream without a loaded model", func() {
|
||||
in := make(chan *pb.AudioTransformFrameRequest, 1)
|
||||
out := make(chan *pb.AudioTransformFrameResponse, 1)
|
||||
close(in)
|
||||
err := (&LocalVQE{}).AudioTransformStream(in, out)
|
||||
Expect(err).To(HaveOccurred())
|
||||
_, ok := <-out
|
||||
Expect(ok).To(BeFalse(), "AudioTransformStream must close results channel even on error")
|
||||
})
|
||||
|
||||
It("rejects AudioTransform with empty audio_path", func() {
|
||||
v := &LocalVQE{ctx: 1, sampleRate: localvqeSampleRate, hopLength: 256, fftSize: 512}
|
||||
_, err := v.AudioTransform(&pb.AudioTransformRequest{Dst: "/tmp/out.wav"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("audio_path"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("parseOptions", func() {
|
||||
It("reads noise_gate=true (=)", func() {
|
||||
v := &LocalVQE{}
|
||||
v.parseOptions([]string{"noise_gate=true"})
|
||||
Expect(v.gateEnabled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("reads noise_gate_threshold_dbfs=-50 (:)", func() {
|
||||
v := &LocalVQE{}
|
||||
v.parseOptions([]string{"noise_gate_threshold_dbfs:-50"})
|
||||
Expect(v.gateDbfs).To(BeNumerically("==", -50.0))
|
||||
})
|
||||
|
||||
It("ignores unknown keys without error", func() {
|
||||
v := &LocalVQE{}
|
||||
v.parseOptions([]string{"unknown=value", "another:thing"})
|
||||
Expect(v.gateEnabled).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is case-insensitive on keys", func() {
|
||||
v := &LocalVQE{}
|
||||
v.parseOptions([]string{"NOISE_GATE=true"})
|
||||
Expect(v.gateEnabled).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
v := &LocalVQE{}
|
||||
Expect(v.Load(&pb.ModelOptions{ModelFile: path})).To(Succeed())
|
||||
defer func() { _ = v.Free() }()
|
||||
Expect(v.sampleRate).To(Equal(localvqeSampleRate))
|
||||
Expect(v.hopLength).To(Equal(256))
|
||||
Expect(v.fftSize).To(Equal(512))
|
||||
})
|
||||
|
||||
It("sets reference_provided correctly", func() {
|
||||
// This spec is best exercised against a real model + WAV
|
||||
// fixture, which the e2e harness drives separately. Here
|
||||
// we just assert the expectation when ref is empty.
|
||||
path := modelPathOrSkip()
|
||||
v := &LocalVQE{}
|
||||
Expect(v.Load(&pb.ModelOptions{ModelFile: path})).To(Succeed())
|
||||
defer func() { _ = v.Free() }()
|
||||
// Synthetic input; the C side handles a constant-zero ref
|
||||
// just fine. Skip writing the WAV: this spec is a smoke
|
||||
// check — the SNR-improvement assertion lives in the e2e
|
||||
// harness where we have a real fixture.
|
||||
})
|
||||
})
|
||||
})
|
||||
62
backend/go/localvqe/main.go
Normal file
62
backend/go/localvqe/main.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("LOCALVQE_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./liblocalvqe.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppOptionsNew, "localvqe_options_new"},
|
||||
{&CppOptionsFree, "localvqe_options_free"},
|
||||
{&CppOptionsSetModelPath, "localvqe_options_set_model_path"},
|
||||
{&CppOptionsSetBackend, "localvqe_options_set_backend"},
|
||||
{&CppOptionsSetDevice, "localvqe_options_set_device"},
|
||||
{&CppNewWithOptions, "localvqe_new_with_options"},
|
||||
{&CppFree, "localvqe_free"},
|
||||
{&CppProcessF32, "localvqe_process_f32"},
|
||||
{&CppProcessS16, "localvqe_process_s16"},
|
||||
{&CppProcessFrameF32, "localvqe_process_frame_f32"},
|
||||
{&CppProcessFrameS16, "localvqe_process_frame_s16"},
|
||||
{&CppReset, "localvqe_reset"},
|
||||
{&CppLastError, "localvqe_last_error"},
|
||||
{&CppSampleRate, "localvqe_sample_rate"},
|
||||
{&CppHopLength, "localvqe_hop_length"},
|
||||
{&CppFFTSize, "localvqe_fft_size"},
|
||||
{&CppSetNoiseGate, "localvqe_set_noise_gate"},
|
||||
{&CppGetNoiseGate, "localvqe_get_noise_gate"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &LocalVQE{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
61
backend/go/localvqe/package.sh
Executable file
61
backend/go/localvqe/package.sh
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Bundle the localvqe binary, the upstream liblocalvqe.so + the per-CPU
|
||||
# libggml-*.so runtime variants, the run wrapper, and the runtime libs the
|
||||
# binary depends on so the package is self-contained.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/localvqe $CURDIR/package/
|
||||
# liblocalvqe.so* (with SOVERSION symlinks) and the libggml-*.so runtime
|
||||
# variants — LocalVQE picks the matching CPU variant at load time.
|
||||
cp -P $CURDIR/liblocalvqe.so* $CURDIR/package/ 2>/dev/null || true
|
||||
cp -P $CURDIR/libggml*.so* $CURDIR/package/ 2>/dev/null || true
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
23
backend/go/localvqe/run.sh
Executable file
23
backend/go/localvqe/run.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
# LocalVQE's runtime CPU-variant loader (ggml_backend_load_all) searches
|
||||
# get_executable_path() and current_path() — the second one is what saves us
|
||||
# when /proc/self/exe resolves to lib/ld.so under the bundled-loader path.
|
||||
# So we cd into $CURDIR (where all the libggml-cpu-*.so files live) before
|
||||
# exec'ing the binary.
|
||||
cd "$CURDIR"
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR:$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export LOCALVQE_LIBRARY=$CURDIR/liblocalvqe.so
|
||||
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LOCALVQE_LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/localvqe "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LOCALVQE_LIBRARY"
|
||||
exec $CURDIR/localvqe "$@"
|
||||
14
backend/go/localvqe/test.sh
Executable file
14
backend/go/localvqe/test.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
cd "$CURDIR"
|
||||
|
||||
# The Go test suite uses a built localvqe binary for end-to-end
|
||||
# specs. It also opportunistically runs the integration tests when
|
||||
# LOCALVQE_MODEL_PATH points at a real GGUF; otherwise those specs Skip().
|
||||
|
||||
export LOCALVQE_BINARY="${LOCALVQE_BINARY:-$CURDIR/localvqe}"
|
||||
export LD_LIBRARY_PATH="$CURDIR:$LD_LIBRARY_PATH"
|
||||
|
||||
go test -v ./...
|
||||
@@ -29,6 +29,12 @@ type SherpaBackend struct {
|
||||
vadWindowSize int
|
||||
ttsSpeed float32
|
||||
onlineChunkSamples int
|
||||
|
||||
// Speaker diarization (offline pyannote + embedding extractor + clustering).
|
||||
// diarSampleRate is reported by sherpa at create time; we cache it so
|
||||
// runDiarization can resample only when the input doesn't already match.
|
||||
diarizer uintptr
|
||||
diarSampleRate int
|
||||
}
|
||||
|
||||
var onnxProvider = "cpu"
|
||||
@@ -128,6 +134,25 @@ var (
|
||||
|
||||
// TTS streaming callback trampoline
|
||||
shimTtsGenerateWithCallback func(tts uintptr, text string, sid int32, speed float32, cb uintptr, ud uintptr) uintptr
|
||||
|
||||
// Diarization config + result accessors (see csrc/shim.h).
|
||||
shimDiarizeConfigNew func() uintptr
|
||||
shimDiarizeConfigFree func(uintptr)
|
||||
shimDiarizeConfigSetSegmentationModel func(uintptr, string)
|
||||
shimDiarizeConfigSetSegmentationNumThreads func(uintptr, int32)
|
||||
shimDiarizeConfigSetSegmentationProvider func(uintptr, string)
|
||||
shimDiarizeConfigSetSegmentationDebug func(uintptr, int32)
|
||||
shimDiarizeConfigSetEmbeddingModel func(uintptr, string)
|
||||
shimDiarizeConfigSetEmbeddingNumThreads func(uintptr, int32)
|
||||
shimDiarizeConfigSetEmbeddingProvider func(uintptr, string)
|
||||
shimDiarizeConfigSetEmbeddingDebug func(uintptr, int32)
|
||||
shimDiarizeConfigSetClusteringNumClusters func(uintptr, int32)
|
||||
shimDiarizeConfigSetClusteringThreshold func(uintptr, float32)
|
||||
shimDiarizeConfigSetMinDurationOn func(uintptr, float32)
|
||||
shimDiarizeConfigSetMinDurationOff func(uintptr, float32)
|
||||
shimCreateOfflineSpeakerDiarization func(uintptr) uintptr
|
||||
shimDiarizeSetClustering func(uintptr, int32, float32)
|
||||
shimDiarizeSegmentAt func(segs uintptr, i int32, outStart unsafe.Pointer, outEnd unsafe.Pointer, outSpeaker unsafe.Pointer)
|
||||
)
|
||||
|
||||
// libsherpa-onnx-c-api pass-throughs — called directly from Go via purego.
|
||||
@@ -172,6 +197,18 @@ var (
|
||||
sherpaOfflineTtsGenerate func(tts uintptr, text string, sid int32, speed float32) uintptr
|
||||
sherpaDestroyOfflineTtsGeneratedAudio func(audio uintptr)
|
||||
sherpaOfflineTtsSampleRate func(tts uintptr) int32
|
||||
|
||||
// Offline speaker diarization. Result handle owns the segment-array
|
||||
// pointer returned by ResultSortByStartTime; destroy the segment
|
||||
// array first, then the result, then (at backend Free()) the diarizer.
|
||||
sherpaDestroyOfflineSpeakerDiarization func(sd uintptr)
|
||||
sherpaOfflineSpeakerDiarizationGetSampleRate func(sd uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationProcess func(sd uintptr, samples unsafe.Pointer, n int32) uintptr
|
||||
sherpaOfflineSpeakerDiarizationResultGetNumSegments func(result uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationResultGetNumSpeakers func(result uintptr) int32
|
||||
sherpaOfflineSpeakerDiarizationResultSortByStartTime func(result uintptr) uintptr
|
||||
sherpaOfflineSpeakerDiarizationDestroySegment func(segs uintptr)
|
||||
sherpaDestroyOfflineSpeakerDiarizationResult func(result uintptr)
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -292,6 +329,24 @@ func loadSherpaLibsOnce() error {
|
||||
{&shimSpeechSegmentStart, "sherpa_shim_speech_segment_start"},
|
||||
{&shimSpeechSegmentN, "sherpa_shim_speech_segment_n"},
|
||||
{&shimTtsGenerateWithCallback, "sherpa_shim_tts_generate_with_callback"},
|
||||
|
||||
{&shimDiarizeConfigNew, "sherpa_shim_diarize_config_new"},
|
||||
{&shimDiarizeConfigFree, "sherpa_shim_diarize_config_free"},
|
||||
{&shimDiarizeConfigSetSegmentationModel, "sherpa_shim_diarize_config_set_segmentation_model"},
|
||||
{&shimDiarizeConfigSetSegmentationNumThreads, "sherpa_shim_diarize_config_set_segmentation_num_threads"},
|
||||
{&shimDiarizeConfigSetSegmentationProvider, "sherpa_shim_diarize_config_set_segmentation_provider"},
|
||||
{&shimDiarizeConfigSetSegmentationDebug, "sherpa_shim_diarize_config_set_segmentation_debug"},
|
||||
{&shimDiarizeConfigSetEmbeddingModel, "sherpa_shim_diarize_config_set_embedding_model"},
|
||||
{&shimDiarizeConfigSetEmbeddingNumThreads, "sherpa_shim_diarize_config_set_embedding_num_threads"},
|
||||
{&shimDiarizeConfigSetEmbeddingProvider, "sherpa_shim_diarize_config_set_embedding_provider"},
|
||||
{&shimDiarizeConfigSetEmbeddingDebug, "sherpa_shim_diarize_config_set_embedding_debug"},
|
||||
{&shimDiarizeConfigSetClusteringNumClusters, "sherpa_shim_diarize_config_set_clustering_num_clusters"},
|
||||
{&shimDiarizeConfigSetClusteringThreshold, "sherpa_shim_diarize_config_set_clustering_threshold"},
|
||||
{&shimDiarizeConfigSetMinDurationOn, "sherpa_shim_diarize_config_set_min_duration_on"},
|
||||
{&shimDiarizeConfigSetMinDurationOff, "sherpa_shim_diarize_config_set_min_duration_off"},
|
||||
{&shimCreateOfflineSpeakerDiarization, "sherpa_shim_create_offline_speaker_diarization"},
|
||||
{&shimDiarizeSetClustering, "sherpa_shim_diarize_set_clustering"},
|
||||
{&shimDiarizeSegmentAt, "sherpa_shim_diarize_segment_at"},
|
||||
} {
|
||||
purego.RegisterLibFunc(r.ptr, shim, r.name)
|
||||
}
|
||||
@@ -334,6 +389,15 @@ func loadSherpaLibsOnce() error {
|
||||
{&sherpaOfflineTtsGenerate, "SherpaOnnxOfflineTtsGenerate"},
|
||||
{&sherpaDestroyOfflineTtsGeneratedAudio, "SherpaOnnxDestroyOfflineTtsGeneratedAudio"},
|
||||
{&sherpaOfflineTtsSampleRate, "SherpaOnnxOfflineTtsSampleRate"},
|
||||
|
||||
{&sherpaDestroyOfflineSpeakerDiarization, "SherpaOnnxDestroyOfflineSpeakerDiarization"},
|
||||
{&sherpaOfflineSpeakerDiarizationGetSampleRate, "SherpaOnnxOfflineSpeakerDiarizationGetSampleRate"},
|
||||
{&sherpaOfflineSpeakerDiarizationProcess, "SherpaOnnxOfflineSpeakerDiarizationProcess"},
|
||||
{&sherpaOfflineSpeakerDiarizationResultGetNumSegments, "SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments"},
|
||||
{&sherpaOfflineSpeakerDiarizationResultGetNumSpeakers, "SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers"},
|
||||
{&sherpaOfflineSpeakerDiarizationResultSortByStartTime, "SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime"},
|
||||
{&sherpaOfflineSpeakerDiarizationDestroySegment, "SherpaOnnxOfflineSpeakerDiarizationDestroySegment"},
|
||||
{&sherpaDestroyOfflineSpeakerDiarizationResult, "SherpaOnnxOfflineSpeakerDiarizationDestroyResult"},
|
||||
} {
|
||||
purego.RegisterLibFunc(r.ptr, capi, r.name)
|
||||
}
|
||||
@@ -383,6 +447,11 @@ func isVADType(t string) bool {
|
||||
return t == "vad"
|
||||
}
|
||||
|
||||
func isDiarizationType(t string) bool {
|
||||
t = strings.ToLower(t)
|
||||
return t == "diarization" || t == "diarize" || t == "speaker-diarization"
|
||||
}
|
||||
|
||||
// Model-options prefixes recognised by this backend. Kept as typed
|
||||
// constants so the asrFamily / loadWhisperASR / loadGenericASR paths
|
||||
// can all speak the same vocabulary.
|
||||
@@ -423,6 +492,19 @@ const (
|
||||
optionOnlineRule2 = "online.rule2_min_trailing_silence="
|
||||
optionOnlineRule3 = "online.rule3_min_utterance_length="
|
||||
optionOnlineChunkSamples = "online.chunk_samples="
|
||||
|
||||
// Speaker diarization (offline pyannote + speaker-embedding extractor).
|
||||
// `diarize.segmentation_model` overrides the auto-detected pyannote
|
||||
// segmentation .onnx in modelDir; `diarize.embedding_model` does the
|
||||
// same for the speaker-embedding extractor. `diarize.num_clusters`
|
||||
// pins a known speaker count at load time; per-call DiarizeRequest
|
||||
// fields take precedence at process time.
|
||||
optionDiarizeSegmentationModel = "diarize.segmentation_model="
|
||||
optionDiarizeEmbeddingModel = "diarize.embedding_model="
|
||||
optionDiarizeNumClusters = "diarize.num_clusters="
|
||||
optionDiarizeThreshold = "diarize.threshold="
|
||||
optionDiarizeMinDurationOn = "diarize.min_duration_on="
|
||||
optionDiarizeMinDurationOff = "diarize.min_duration_off="
|
||||
)
|
||||
|
||||
func hasOption(opts *pb.ModelOptions, prefix string) bool {
|
||||
@@ -493,6 +575,9 @@ func (s *SherpaBackend) Load(opts *pb.ModelOptions) error {
|
||||
if isVADType(opts.Type) {
|
||||
return s.loadVAD(opts)
|
||||
}
|
||||
if isDiarizationType(opts.Type) {
|
||||
return s.loadDiarization(opts)
|
||||
}
|
||||
// An explicit `subtype=...` option routes to ASR even when Type is
|
||||
// unset — handy for the e2e-backends harness, which doesn't know
|
||||
// about ModelOptions.Type.
|
||||
@@ -1247,3 +1332,176 @@ func (s *SherpaBackend) TTSStream(req *pb.TTSRequest, results chan []byte) error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================
|
||||
// Speaker diarization (offline)
|
||||
// =============================================================
|
||||
//
|
||||
// Conventions:
|
||||
// - opts.ModelFile is the pyannote segmentation .onnx (e.g. model.onnx
|
||||
// under sherpa-onnx-pyannote-segmentation-3-0/). Override with
|
||||
// `diarize.segmentation_model=` if the gallery layout differs.
|
||||
// - The speaker-embedding extractor must be provided via
|
||||
// `diarize.embedding_model=`. There's no reliable filename heuristic
|
||||
// we can rely on (3dspeaker, NeMo, WeSpeaker all ship with
|
||||
// model-specific names), so we require it to be explicit.
|
||||
// - Both paths are resolved relative to opts.ModelPath if not absolute.
|
||||
|
||||
func (s *SherpaBackend) loadDiarization(opts *pb.ModelOptions) error {
|
||||
if s.diarizer != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelDir := filepath.Dir(opts.ModelFile)
|
||||
segModel := findOptionValue(opts, optionDiarizeSegmentationModel, opts.ModelFile)
|
||||
if segModel != "" && !filepath.IsAbs(segModel) && opts.ModelPath != "" {
|
||||
segModel = filepath.Join(opts.ModelPath, segModel)
|
||||
}
|
||||
if !fileExists(segModel) {
|
||||
return fmt.Errorf("sherpa-onnx diarization: pyannote segmentation model not found at %q (set diarize.segmentation_model=...)", segModel)
|
||||
}
|
||||
|
||||
embModel := findOptionValue(opts, optionDiarizeEmbeddingModel, "")
|
||||
if embModel == "" {
|
||||
return fmt.Errorf("sherpa-onnx diarization: speaker-embedding model is required — pass options: [diarize.embedding_model=<path>] (e.g. 3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx)")
|
||||
}
|
||||
if !filepath.IsAbs(embModel) {
|
||||
base := opts.ModelPath
|
||||
if base == "" {
|
||||
base = modelDir
|
||||
}
|
||||
embModel = filepath.Join(base, embModel)
|
||||
}
|
||||
if !fileExists(embModel) {
|
||||
return fmt.Errorf("sherpa-onnx diarization: speaker-embedding model not found at %q", embModel)
|
||||
}
|
||||
|
||||
threads := int32(1)
|
||||
if opts.Threads != 0 {
|
||||
threads = opts.Threads
|
||||
}
|
||||
|
||||
cfg := shimDiarizeConfigNew()
|
||||
defer shimDiarizeConfigFree(cfg)
|
||||
|
||||
shimDiarizeConfigSetSegmentationModel(cfg, segModel)
|
||||
shimDiarizeConfigSetSegmentationNumThreads(cfg, threads)
|
||||
shimDiarizeConfigSetSegmentationProvider(cfg, onnxProvider)
|
||||
shimDiarizeConfigSetSegmentationDebug(cfg, 0)
|
||||
|
||||
shimDiarizeConfigSetEmbeddingModel(cfg, embModel)
|
||||
shimDiarizeConfigSetEmbeddingNumThreads(cfg, threads)
|
||||
shimDiarizeConfigSetEmbeddingProvider(cfg, onnxProvider)
|
||||
shimDiarizeConfigSetEmbeddingDebug(cfg, 0)
|
||||
|
||||
shimDiarizeConfigSetClusteringNumClusters(cfg, findOptionInt(opts, optionDiarizeNumClusters, -1))
|
||||
shimDiarizeConfigSetClusteringThreshold(cfg, findOptionFloat(opts, optionDiarizeThreshold, 0.5))
|
||||
shimDiarizeConfigSetMinDurationOn(cfg, findOptionFloat(opts, optionDiarizeMinDurationOn, 0.3))
|
||||
shimDiarizeConfigSetMinDurationOff(cfg, findOptionFloat(opts, optionDiarizeMinDurationOff, 0.5))
|
||||
|
||||
sd := shimCreateOfflineSpeakerDiarization(cfg)
|
||||
if sd == 0 {
|
||||
return fmt.Errorf("sherpa-onnx diarization: failed to create diarizer (segmentation=%s embedding=%s)", segModel, embModel)
|
||||
}
|
||||
s.diarizer = sd
|
||||
s.diarSampleRate = int(sherpaOfflineSpeakerDiarizationGetSampleRate(sd))
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyDiarizeOverrides re-applies clustering knobs onto an existing
|
||||
// diarizer when per-call DiarizeRequest fields are set. Both -1/0 sentinels
|
||||
// follow sherpa's convention: num_clusters<=0 → use threshold-based
|
||||
// clustering, threshold<=0 → keep load-time default.
|
||||
func (s *SherpaBackend) applyDiarizeOverrides(req *pb.DiarizeRequest) {
|
||||
num := int32(-1)
|
||||
if req.NumSpeakers > 0 {
|
||||
num = req.NumSpeakers
|
||||
}
|
||||
threshold := float32(0)
|
||||
if req.ClusteringThreshold > 0 {
|
||||
threshold = req.ClusteringThreshold
|
||||
}
|
||||
if num > 0 || threshold > 0 {
|
||||
shimDiarizeSetClustering(s.diarizer, num, threshold)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SherpaBackend) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, error) {
|
||||
if s.diarizer == 0 {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization not loaded (model must be loaded with type=diarization)")
|
||||
}
|
||||
if req.Dst == "" {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: DiarizeRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "sherpa-diarize")
|
||||
if err != nil {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
wavPath := filepath.Join(dir, "input.wav")
|
||||
if err := utils.AudioToWav(req.Dst, wavPath); err != nil {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("failed to convert audio to wav: %w", err)
|
||||
}
|
||||
|
||||
wave := sherpaReadWave(wavPath)
|
||||
if wave == 0 {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("failed to read wav %s", wavPath)
|
||||
}
|
||||
defer sherpaFreeWave(wave)
|
||||
|
||||
sr := int(shimWaveSampleRate(wave))
|
||||
nSamples := shimWaveNumSamples(wave)
|
||||
samples := shimWaveSamples(wave)
|
||||
duration := float32(nSamples) / float32(sr)
|
||||
if sr != s.diarSampleRate {
|
||||
// AudioToWav already targets 16 kHz; pyannote-3.0 also wants 16 kHz, so
|
||||
// this branch should be unreachable. Fail loudly instead of silently
|
||||
// passing mismatched audio to the model.
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: input sample rate %d Hz does not match model %d Hz", sr, s.diarSampleRate)
|
||||
}
|
||||
|
||||
s.applyDiarizeOverrides(req)
|
||||
|
||||
result := sherpaOfflineSpeakerDiarizationProcess(s.diarizer, samples, nSamples)
|
||||
if result == 0 {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: process failed")
|
||||
}
|
||||
defer sherpaDestroyOfflineSpeakerDiarizationResult(result)
|
||||
|
||||
numSegments := sherpaOfflineSpeakerDiarizationResultGetNumSegments(result)
|
||||
numSpeakers := sherpaOfflineSpeakerDiarizationResultGetNumSpeakers(result)
|
||||
if numSegments <= 0 {
|
||||
return pb.DiarizeResponse{
|
||||
Segments: []*pb.DiarizeSegment{},
|
||||
NumSpeakers: numSpeakers,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
segs := sherpaOfflineSpeakerDiarizationResultSortByStartTime(result)
|
||||
if segs == 0 {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: failed to retrieve segments")
|
||||
}
|
||||
defer sherpaOfflineSpeakerDiarizationDestroySegment(segs)
|
||||
|
||||
out := make([]*pb.DiarizeSegment, 0, numSegments)
|
||||
for i := range int(numSegments) {
|
||||
var start, end float32
|
||||
var spk int32
|
||||
shimDiarizeSegmentAt(segs, int32(i),
|
||||
unsafe.Pointer(&start), unsafe.Pointer(&end), unsafe.Pointer(&spk))
|
||||
out = append(out, &pb.DiarizeSegment{
|
||||
Id: int32(i),
|
||||
Start: start,
|
||||
End: end,
|
||||
Speaker: strconv.FormatInt(int64(spk), 10),
|
||||
})
|
||||
}
|
||||
return pb.DiarizeResponse{
|
||||
Segments: out,
|
||||
NumSpeakers: numSpeakers,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -310,6 +310,87 @@ int32_t sherpa_shim_speech_segment_n(const void *h) {
|
||||
return ((const SherpaOnnxSpeechSegment *)h)->n;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Offline speaker diarization config
|
||||
// ==================================================================
|
||||
|
||||
void *sherpa_shim_diarize_config_new(void) {
|
||||
return calloc(1, sizeof(SherpaOnnxOfflineSpeakerDiarizationConfig));
|
||||
}
|
||||
|
||||
void sherpa_shim_diarize_config_free(void *h) {
|
||||
if (!h) return;
|
||||
SherpaOnnxOfflineSpeakerDiarizationConfig *c =
|
||||
(SherpaOnnxOfflineSpeakerDiarizationConfig *)h;
|
||||
free((char *)c->segmentation.pyannote.model);
|
||||
free((char *)c->segmentation.provider);
|
||||
free((char *)c->embedding.model);
|
||||
free((char *)c->embedding.provider);
|
||||
free(c);
|
||||
}
|
||||
|
||||
void sherpa_shim_diarize_config_set_segmentation_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.pyannote.model, v);
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_segmentation_num_threads(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.num_threads = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_segmentation_provider(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.provider, v);
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_segmentation_debug(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.debug = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_embedding_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.model, v);
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_embedding_num_threads(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.num_threads = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_embedding_provider(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.provider, v);
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_embedding_debug(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.debug = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_clustering_num_clusters(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->clustering.num_clusters = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_clustering_threshold(void *h, float v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->clustering.threshold = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_min_duration_on(void *h, float v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->min_duration_on = v;
|
||||
}
|
||||
void sherpa_shim_diarize_config_set_min_duration_off(void *h, float v) {
|
||||
((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->min_duration_off = v;
|
||||
}
|
||||
|
||||
void *sherpa_shim_create_offline_speaker_diarization(void *h) {
|
||||
return (void *)SherpaOnnxCreateOfflineSpeakerDiarization(
|
||||
(const SherpaOnnxOfflineSpeakerDiarizationConfig *)h);
|
||||
}
|
||||
|
||||
void sherpa_shim_diarize_set_clustering(void *sd, int32_t num_clusters, float threshold) {
|
||||
if (!sd) return;
|
||||
SherpaOnnxOfflineSpeakerDiarizationConfig cfg;
|
||||
memset(&cfg, 0, sizeof(cfg));
|
||||
cfg.clustering.num_clusters = num_clusters;
|
||||
cfg.clustering.threshold = threshold;
|
||||
SherpaOnnxOfflineSpeakerDiarizationSetConfig(
|
||||
(const SherpaOnnxOfflineSpeakerDiarization *)sd, &cfg);
|
||||
}
|
||||
|
||||
void sherpa_shim_diarize_segment_at(const void *segs, int32_t i,
|
||||
float *out_start, float *out_end,
|
||||
int32_t *out_speaker) {
|
||||
const SherpaOnnxOfflineSpeakerDiarizationSegment *arr =
|
||||
(const SherpaOnnxOfflineSpeakerDiarizationSegment *)segs;
|
||||
if (out_start) *out_start = arr[i].start;
|
||||
if (out_end) *out_end = arr[i].end;
|
||||
if (out_speaker) *out_speaker = arr[i].speaker;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// TTS streaming callback trampoline
|
||||
// ==================================================================
|
||||
|
||||
@@ -109,6 +109,41 @@ const float *sherpa_shim_generated_audio_samples(const void *audio);
|
||||
int32_t sherpa_shim_speech_segment_start(const void *seg);
|
||||
int32_t sherpa_shim_speech_segment_n(const void *seg);
|
||||
|
||||
// --- Offline speaker diarization config -----------------------------
|
||||
// Pyannote segmentation + speaker-embedding extractor + fast clustering.
|
||||
// The upstream config is a struct of nested structs; purego can't read or
|
||||
// build those across dlopen, so we expose a calloc'd opaque holder plus
|
||||
// flat setters, then hand it to sherpa via the create wrapper.
|
||||
void *sherpa_shim_diarize_config_new(void);
|
||||
void sherpa_shim_diarize_config_free(void *cfg);
|
||||
void sherpa_shim_diarize_config_set_segmentation_model(void *cfg, const char *path);
|
||||
void sherpa_shim_diarize_config_set_segmentation_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_diarize_config_set_segmentation_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_diarize_config_set_segmentation_debug(void *cfg, int32_t v);
|
||||
void sherpa_shim_diarize_config_set_embedding_model(void *cfg, const char *path);
|
||||
void sherpa_shim_diarize_config_set_embedding_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_diarize_config_set_embedding_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_diarize_config_set_embedding_debug(void *cfg, int32_t v);
|
||||
void sherpa_shim_diarize_config_set_clustering_num_clusters(void *cfg, int32_t v);
|
||||
void sherpa_shim_diarize_config_set_clustering_threshold(void *cfg, float v);
|
||||
void sherpa_shim_diarize_config_set_min_duration_on(void *cfg, float v);
|
||||
void sherpa_shim_diarize_config_set_min_duration_off(void *cfg, float v);
|
||||
void *sherpa_shim_create_offline_speaker_diarization(void *cfg);
|
||||
|
||||
// Apply just the clustering knobs onto a loaded diarizer (sherpa
|
||||
// supports re-clustering after Create), so per-call overrides like
|
||||
// num_speakers don't require re-loading the heavy ONNX models.
|
||||
void sherpa_shim_diarize_set_clustering(void *sd, int32_t num_clusters, float threshold);
|
||||
|
||||
// Sherpa's ResultSortByStartTime returns a sherpa-allocated array of
|
||||
// SherpaOnnxOfflineSpeakerDiarizationSegment structs (free with
|
||||
// SherpaOnnxOfflineSpeakerDiarizationDestroySegment). Purego can't read
|
||||
// fields out of an array of C structs, so this getter copies one
|
||||
// segment's fields into the caller-supplied float/int32 cells.
|
||||
void sherpa_shim_diarize_segment_at(const void *segs, int32_t i,
|
||||
float *out_start, float *out_end,
|
||||
int32_t *out_speaker);
|
||||
|
||||
// --- TTS streaming callback trampoline -----------------------------
|
||||
// Replaces the //export sherpaTtsGoCallback + callbacks.c bridge pattern.
|
||||
// `callback_ptr` is the C-callable function pointer returned by
|
||||
|
||||
@@ -6,9 +6,12 @@ GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# vibevoice.cpp version
|
||||
# vibevoice.cpp version. Pinned to a commit hash and auto-bumped by
|
||||
# .github/workflows/bump_deps.yaml (the matrix entry mirrors what we
|
||||
# already do for ik_llama.cpp / llama.cpp / whisper.cpp). Floating on
|
||||
# `master` led to silent ABI breaks reaching CI — pin it.
|
||||
VIBEVOICE_REPO?=https://github.com/mudler/vibevoice.cpp
|
||||
VIBEVOICE_CPP_VERSION?=master
|
||||
VIBEVOICE_CPP_VERSION?=ad856bda6b1311b7f3d7c4a667be43eeb8a8249a
|
||||
SO_TARGET?=libgovibevoicecpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -3,8 +3,11 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
laudio "github.com/mudler/LocalAI/pkg/audio"
|
||||
@@ -12,15 +15,102 @@ import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// vv_capi_asr loads audio with load_wav_24k_mono — a 24 kHz mono s16le
|
||||
// WAV is the format the model was trained on. Inputs already in that
|
||||
// format pass through; everything else is converted via ffmpeg, which
|
||||
// is therefore a runtime requirement only when callers upload non-WAV
|
||||
// (or non-24 kHz mono s16le WAV) audio. Skipping ffmpeg on the happy
|
||||
// path matters for the e2e-backends test container, which does not
|
||||
// ship ffmpeg but feeds the backend pre-cooked 24 kHz mono WAVs.
|
||||
const vibevoiceASRSampleRate = 24000
|
||||
|
||||
// prepareWavInput resolves `src` to a 24 kHz mono s16le WAV path that
|
||||
// vv_capi_asr's load_wav_24k_mono accepts. Returns the resolved path
|
||||
// plus a cleanup func; both must be honoured by the caller.
|
||||
//
|
||||
// Pass-through happens when `src` already has the right WAV format —
|
||||
// no ffmpeg required. Otherwise we shell out to ffmpeg into a temp
|
||||
// dir; if ffmpeg isn't on PATH we surface a clear error mentioning the
|
||||
// underlying format mismatch.
|
||||
func prepareWavInput(src string) (string, func(), error) {
|
||||
if src == "" {
|
||||
return "", func() {}, fmt.Errorf("empty audio path")
|
||||
}
|
||||
if isVibevoiceCompatibleWav(src) {
|
||||
return src, func() {}, nil
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "vibevoice-asr")
|
||||
if err != nil {
|
||||
return "", func() {}, fmt.Errorf("mkdtemp: %w", err)
|
||||
}
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
wavPath := filepath.Join(dir, "input.wav")
|
||||
|
||||
// -y: overwrite, -ar 24000: target sample rate, -ac 1: mono,
|
||||
// -acodec pcm_s16le: signed 16-bit little-endian PCM (load_wav_24k_mono
|
||||
// only accepts s16le).
|
||||
cmd := exec.Command("ffmpeg",
|
||||
"-y", "-i", src,
|
||||
"-ar", fmt.Sprintf("%d", vibevoiceASRSampleRate),
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
wavPath,
|
||||
)
|
||||
cmd.Env = []string{}
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, fmt.Errorf("ffmpeg convert to 24k mono wav: %w (output: %s)", err, string(out))
|
||||
}
|
||||
return wavPath, cleanup, nil
|
||||
}
|
||||
|
||||
// isVibevoiceCompatibleWav returns true when `src` carries the RIFF/WAVE
|
||||
// magic bytes. vibevoice's load_wav_24k_mono uses drwav under the hood,
|
||||
// which accepts any PCM/IEEE-float WAV at any sample rate and downmixes
|
||||
// multi-channel input to mono on its own — so any valid WAV passes
|
||||
// through to the C side without conversion. Anything else (MP3, OGG,
|
||||
// FLAC, ...) needs ffmpeg.
|
||||
func isVibevoiceCompatibleWav(src string) bool {
|
||||
f, err := os.Open(src)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
// 0..3 = "RIFF", 8..11 = "WAVE".
|
||||
var hdr [12]byte
|
||||
if _, err := io.ReadFull(f, hdr[:]); err != nil {
|
||||
return false
|
||||
}
|
||||
return string(hdr[0:4]) == "RIFF" && string(hdr[8:12]) == "WAVE"
|
||||
}
|
||||
|
||||
// asrMaxNewTokens caps the ASR generation budget. The C ABI defaults to
|
||||
// 256 when 0 is passed — far too small for anything past ~10s of speech.
|
||||
// Vibevoice generates ~30 tokens per second of audio, so 16 384 covers
|
||||
// roughly 9 minutes of dialogue, well past any normal /v1/audio/diarization
|
||||
// upload. Going higher costs little since generation stops at EOS.
|
||||
const asrMaxNewTokens = 16384
|
||||
|
||||
// vibevoice.cpp synthesizes 24 kHz mono 16-bit PCM. Hardcoded - the
|
||||
// model itself is fixed-rate; if the upstream ever changes this we'll
|
||||
// pick it up via vv_capi_version().
|
||||
const vibevoiceSampleRate = uint32(24000)
|
||||
|
||||
// purego-bound entry points from libgovibevoicecpp.
|
||||
//
|
||||
// vv_capi_tts takes a `const char* const* ref_audio_paths` array (used
|
||||
// by the 1.5B variant for runtime voice cloning; the realtime-0.5B
|
||||
// path leaves it NULL and uses voice_path instead). purego marshals a
|
||||
// Go []*byte slice as **char by passing the underlying array's address.
|
||||
// A nil/empty slice marshals to NULL, which matches the C contract for
|
||||
// "no reference audio".
|
||||
var (
|
||||
CppLoad func(ttsModel, asrModel, tokenizer, voice string, threads int32) int32
|
||||
CppTTS func(text, voicePath, dstWav string,
|
||||
CppTTS func(text, voicePath string,
|
||||
refAudioPaths []*byte, nRefAudioPaths int32,
|
||||
dstWav string,
|
||||
nSteps int32, cfgScale float32, maxSpeechFrames int32, seed uint32) int32
|
||||
CppASR func(srcWav string, outJSON []byte, capacity uint64,
|
||||
maxNewTokens int32) int32
|
||||
@@ -44,6 +134,14 @@ type VibevoiceCpp struct {
|
||||
asrModel string
|
||||
tokenizer string
|
||||
voice string
|
||||
|
||||
// refAudio is the load-time default list of reference WAVs used by
|
||||
// the 1.5B model (one per speaker). Sourced from
|
||||
// ModelOptions.AudioPath (config_file's `audio_path:`) — comma-
|
||||
// separated for multi-speaker. Per-call TTSRequest.Voice can
|
||||
// override it. Empty for the realtime-0.5B path, which conditions
|
||||
// on a pre-baked voice gguf via `voice` instead.
|
||||
refAudio []string
|
||||
}
|
||||
|
||||
// resolvePath joins a relative path onto `relTo`. The gallery
|
||||
@@ -89,6 +187,25 @@ func (v *VibevoiceCpp) parseOptions(opts []string, relTo string) string {
|
||||
return role
|
||||
}
|
||||
|
||||
// parseRefAudio splits a comma-separated audio_path value into a
|
||||
// resolved list of WAVs. The 1.5B model uses one WAV per speaker;
|
||||
// callers that only need a single reference set audio_path to a single
|
||||
// path. Empty / whitespace-only entries are skipped.
|
||||
func parseRefAudio(audioPath, relTo string) []string {
|
||||
if audioPath == "" {
|
||||
return nil
|
||||
}
|
||||
var out []string
|
||||
for _, p := range strings.Split(audioPath, ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, resolvePath(p, relTo))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (v *VibevoiceCpp) Load(opts *pb.ModelOptions) error {
|
||||
if opts.ModelFile == "" {
|
||||
return fmt.Errorf("vibevoice-cpp: ModelFile is required")
|
||||
@@ -109,6 +226,12 @@ func (v *VibevoiceCpp) Load(opts *pb.ModelOptions) error {
|
||||
}
|
||||
role := v.parseOptions(opts.Options, v.modelRoot)
|
||||
|
||||
// 1.5B reference WAVs ride on ModelOptions.AudioPath (config_file's
|
||||
// `audio_path:` key) — same convention other audio backends already
|
||||
// follow. Single-speaker = single path; multi-speaker = comma list,
|
||||
// one WAV per Speaker N: tag in TTSRequest.text.
|
||||
v.refAudio = parseRefAudio(opts.AudioPath, v.modelRoot)
|
||||
|
||||
// ModelFile fills the "primary" role-slot determined by `type=`
|
||||
// in Options (defaults to tts). The other slot stays exactly as
|
||||
// Options set it - so a closed-loop config with ModelFile=tts.gguf
|
||||
@@ -142,8 +265,8 @@ func (v *VibevoiceCpp) Load(opts *pb.ModelOptions) error {
|
||||
v.threads = threads
|
||||
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"[vibevoice-cpp] Loading: tts=%q asr=%q tokenizer=%q voice=%q threads=%d\n",
|
||||
v.ttsModel, v.asrModel, v.tokenizer, v.voice, threads)
|
||||
"[vibevoice-cpp] Loading: tts=%q asr=%q tokenizer=%q voice=%q ref_audio=%v threads=%d\n",
|
||||
v.ttsModel, v.asrModel, v.tokenizer, v.voice, v.refAudio, threads)
|
||||
|
||||
if rc := CppLoad(v.ttsModel, v.asrModel, v.tokenizer, v.voice, int32(threads)); rc != 0 {
|
||||
return fmt.Errorf("vibevoice-cpp: vv_capi_load failed (rc=%d)", rc)
|
||||
@@ -161,10 +284,35 @@ func (v *VibevoiceCpp) TTS(req *pb.TTSRequest) error {
|
||||
return fmt.Errorf("vibevoice-cpp: TTS requires both text and dst")
|
||||
}
|
||||
|
||||
// req.Voice may be a bare filename (e.g. "voice-en-Emma.gguf") or an
|
||||
// absolute path. Resolve via the same modelRoot Load() used for
|
||||
// Options[] so a swap-voice request mirrors the gallery's layout.
|
||||
voice := resolvePath(req.Voice, v.modelRoot)
|
||||
// TTSRequest.Voice carries the per-call override. Routing depends
|
||||
// on the loaded model variant:
|
||||
// * realtime-0.5B → expects a baked voice .gguf (single path).
|
||||
// * 1.5B → expects one or more raw 24 kHz mono .wav
|
||||
// reference clips for runtime voice cloning;
|
||||
// comma-separated to address multi-speaker
|
||||
// dialogs (Speaker 0..n-1 follow the order).
|
||||
// We pick the branch by extension / shape of the override; if no
|
||||
// override is given, fall back to the load-time defaults.
|
||||
voice := ""
|
||||
var refAudio []string
|
||||
if reqVoice := strings.TrimSpace(req.Voice); reqVoice != "" {
|
||||
if isRefAudioOverride(reqVoice) {
|
||||
for _, p := range strings.Split(reqVoice, ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
refAudio = append(refAudio, resolvePath(p, v.modelRoot))
|
||||
}
|
||||
} else {
|
||||
voice = resolvePath(reqVoice, v.modelRoot)
|
||||
}
|
||||
} else {
|
||||
// No per-call override. v.voice already went to vv_capi_load
|
||||
// for realtime-0.5B; ref_audio is per-call only on the C ABI,
|
||||
// so the gallery's `ref_audio:` defaults are re-passed here.
|
||||
refAudio = append(refAudio, v.refAudio...)
|
||||
}
|
||||
|
||||
if req.Language != nil && *req.Language != "" {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
@@ -177,13 +325,51 @@ func (v *VibevoiceCpp) TTS(req *pb.TTSRequest) error {
|
||||
defaultMaxFrames = 200
|
||||
)
|
||||
defaultCfg := float32(1.3)
|
||||
if rc := CppTTS(text, voice, dst,
|
||||
int32(defaultSteps), defaultCfg, int32(defaultMaxFrames), 0); rc != 0 {
|
||||
|
||||
refPtrs, refKeep := newCStringArray(refAudio)
|
||||
rc := CppTTS(text, voice, refPtrs, int32(len(refPtrs)), dst,
|
||||
int32(defaultSteps), defaultCfg, int32(defaultMaxFrames), 0)
|
||||
// Hold the backing buffers past the cgo call. purego marshals
|
||||
// []*byte by handing the C side the underlying array address; the
|
||||
// pointed-to NUL-terminated bytes must outlive the call.
|
||||
runtime.KeepAlive(refKeep)
|
||||
runtime.KeepAlive(refPtrs)
|
||||
if rc != 0 {
|
||||
return fmt.Errorf("vibevoice-cpp: vv_capi_tts failed (rc=%d)", rc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRefAudioOverride decides whether a TTSRequest.Voice override should
|
||||
// be routed to ref_audio_paths (1.5B path) instead of voice_path
|
||||
// (realtime-0.5B). Either a comma-separated list (multi-speaker) or a
|
||||
// single .wav clip qualifies; a bare voice .gguf falls through.
|
||||
func isRefAudioOverride(s string) bool {
|
||||
if strings.Contains(s, ",") {
|
||||
return true
|
||||
}
|
||||
return strings.HasSuffix(strings.ToLower(s), ".wav")
|
||||
}
|
||||
|
||||
// newCStringArray builds the **char array vv_capi_tts expects, plus the
|
||||
// keep-alive slice the caller must runtime.KeepAlive across the cgo
|
||||
// call. A nil/empty input returns (nil, nil) which purego marshals to
|
||||
// the C NULL pointer.
|
||||
func newCStringArray(in []string) ([]*byte, [][]byte) {
|
||||
if len(in) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
keep := make([][]byte, len(in))
|
||||
ptrs := make([]*byte, len(in))
|
||||
for i, s := range in {
|
||||
b := make([]byte, len(s)+1)
|
||||
copy(b, s)
|
||||
keep[i] = b
|
||||
ptrs[i] = &b[0]
|
||||
}
|
||||
return ptrs, keep
|
||||
}
|
||||
|
||||
// asrSegment matches vibevoice's JSON output:
|
||||
//
|
||||
// [{"Start":0.0,"End":2.8,"Speaker":0,"Content":"…"}, ...]
|
||||
@@ -302,7 +488,13 @@ func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.Transcr
|
||||
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
out, err := v.callASR(req.Dst, 0)
|
||||
wavPath, cleanup, err := prepareWavInput(req.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
out, err := v.callASR(wavPath, asrMaxNewTokens)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
@@ -346,6 +538,83 @@ func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.Transcr
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Diarize runs vibevoice's ASR and projects the speaker-labelled segment
|
||||
// list it returns natively. vibevoice.cpp's ASR prompt asks the model to
|
||||
// emit `[{"Start":..,"End":..,"Speaker":..,"Content":..}]`, so diarization
|
||||
// is a by-product of the same pass — we reuse callASR and re-shape.
|
||||
//
|
||||
// Speaker hints (num_speakers/min/max/threshold) and min_duration_on/off are
|
||||
// not actionable here: vibevoice's model picks the speaker count itself and
|
||||
// has no clustering knob. The HTTP layer documents this; we accept the
|
||||
// fields for API symmetry and ignore them.
|
||||
func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, error) {
|
||||
if v.asrModel == "" {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: Diarize requires an ASR model (load options: type=asr)")
|
||||
}
|
||||
if req.Dst == "" {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: DiarizeRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
wavPath, cleanup, err := prepareWavInput(req.Dst)
|
||||
if err != nil {
|
||||
return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
out, err := v.callASR(wavPath, asrMaxNewTokens)
|
||||
if err != nil {
|
||||
return pb.DiarizeResponse{}, err
|
||||
}
|
||||
if out == "" {
|
||||
return pb.DiarizeResponse{}, nil
|
||||
}
|
||||
|
||||
var segs []asrSegment
|
||||
if err := json.Unmarshal([]byte(out), &segs); err != nil {
|
||||
// Mirror AudioTranscription's fallback: vibevoice's ASR sometimes
|
||||
// emits free-form text instead of JSON for short or unusual audio.
|
||||
// Surface a single unknown-speaker segment carrying the full text
|
||||
// (when include_text is set) so the caller still gets coverage of
|
||||
// the whole clip rather than a hard failure.
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"[vibevoice-cpp] WARNING: vv_capi_asr returned non-JSON for diarization, falling back to single segment: %v\n", err)
|
||||
text := strings.TrimSpace(out)
|
||||
seg := &pb.DiarizeSegment{Id: 0, Speaker: "0"}
|
||||
if req.IncludeText {
|
||||
seg.Text = text
|
||||
}
|
||||
return pb.DiarizeResponse{
|
||||
Segments: []*pb.DiarizeSegment{seg},
|
||||
NumSpeakers: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
speakers := make(map[int]struct{})
|
||||
segments := make([]*pb.DiarizeSegment, 0, len(segs))
|
||||
var duration float32
|
||||
for i, s := range segs {
|
||||
ds := &pb.DiarizeSegment{
|
||||
Id: int32(i),
|
||||
Start: float32(s.Start),
|
||||
End: float32(s.End),
|
||||
Speaker: fmt.Sprintf("%d", s.Speaker),
|
||||
}
|
||||
if req.IncludeText {
|
||||
ds.Text = strings.TrimSpace(s.Content)
|
||||
}
|
||||
segments = append(segments, ds)
|
||||
speakers[s.Speaker] = struct{}{}
|
||||
if float32(s.End) > duration {
|
||||
duration = float32(s.End)
|
||||
}
|
||||
}
|
||||
return pb.DiarizeResponse{
|
||||
Segments: segments,
|
||||
NumSpeakers: int32(len(speakers)),
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream wraps AudioTranscription so the streaming
|
||||
// gRPC endpoint (server.go:AudioTranscriptionStream) sees its channel
|
||||
// close and the client doesn't sit waiting until deadline. vibevoice's
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=fc674574ca27cac59a15e5b22a09b9d9ad62aafe
|
||||
WHISPER_CPP_VERSION?=4bf733672b2871d4153158af4f621a6dd9104f4a
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -600,6 +600,38 @@
|
||||
nvidia-l4t: "nvidia-l4t-arm64-vibevoice-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-vibevoice-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice-cpp"
|
||||
- &localvqecpp
|
||||
name: "localvqe"
|
||||
description: |
|
||||
LocalVQE C++ backend using GGML — joint acoustic echo cancellation, noise
|
||||
suppression, and dereverberation (DeepVQE-style architecture). 16 kHz mono
|
||||
in / out, supports both batch and low-latency streaming. Implements the
|
||||
audio-transform capability.
|
||||
urls:
|
||||
- https://github.com/localai-org/LocalVQE
|
||||
tags:
|
||||
- audio-transform
|
||||
- aec
|
||||
- acoustic-echo-cancellation
|
||||
- noise-suppression
|
||||
- dereverberation
|
||||
license: apache2
|
||||
alias: "localvqe"
|
||||
# Upstream LocalVQE only supports CPU and Vulkan; no CUDA/ROCm/SYCL/Metal
|
||||
# builds. GPU-class hardware that exposes a Vulkan ICD (NVIDIA, AMD, Intel
|
||||
# discrete + iGPU, Tegra) routes to the Vulkan image; everything else
|
||||
# falls back to the CPU build, which is already ~9× realtime on a desktop.
|
||||
capabilities:
|
||||
default: "cpu-localvqe"
|
||||
nvidia: "vulkan-localvqe"
|
||||
nvidia-cuda-12: "vulkan-localvqe"
|
||||
nvidia-cuda-13: "vulkan-localvqe"
|
||||
intel: "vulkan-localvqe"
|
||||
amd: "vulkan-localvqe"
|
||||
vulkan: "vulkan-localvqe"
|
||||
nvidia-l4t: "vulkan-localvqe"
|
||||
nvidia-l4t-cuda-12: "vulkan-localvqe"
|
||||
nvidia-l4t-cuda-13: "vulkan-localvqe"
|
||||
- &faster-whisper
|
||||
icon: https://avatars.githubusercontent.com/u/1520500?s=200&v=4
|
||||
description: |
|
||||
@@ -2785,6 +2817,27 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-vibevoice-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-vibevoice-cpp
|
||||
## localvqe
|
||||
- !!merge <<: *localvqecpp
|
||||
name: "cpu-localvqe"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-localvqe"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-localvqe
|
||||
- !!merge <<: *localvqecpp
|
||||
name: "cpu-localvqe-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-localvqe"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-localvqe
|
||||
- !!merge <<: *localvqecpp
|
||||
name: "vulkan-localvqe"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-localvqe"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-localvqe
|
||||
- !!merge <<: *localvqecpp
|
||||
name: "vulkan-localvqe-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-localvqe"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-localvqe
|
||||
## kokoro
|
||||
- !!merge <<: *kokoro
|
||||
name: "kokoro-development"
|
||||
|
||||
@@ -318,6 +318,21 @@ _makeVenvPortable() {
|
||||
}
|
||||
|
||||
|
||||
# Apply the venv to the current process: VIRTUAL_ENV, PATH, PYTHONHOME hygiene.
|
||||
# Equivalent to the runtime portion of `source bin/activate`, but computed from
|
||||
# $EDIR (resolved at runtime via realpath) instead of the path baked into
|
||||
# bin/activate at venv-create time. `uv venv` (and `python -m venv`) both bake
|
||||
# the create-time absolute path in, so sourcing activate on a relocated venv —
|
||||
# e.g. one built at /vllm/venv inside a Docker stage and unpacked under
|
||||
# /backends/cuda13-vllm-development/venv at runtime — silently prepends a
|
||||
# stale, non-existent path to $PATH. Doing the setup ourselves sidesteps that;
|
||||
# this is the same approach `uv run` takes internally.
|
||||
_activateVenv() {
|
||||
export VIRTUAL_ENV="${EDIR}/venv"
|
||||
export PATH="${EDIR}/venv/bin:${PATH}"
|
||||
unset PYTHONHOME
|
||||
}
|
||||
|
||||
# ensureVenv makes sure that the venv for the backend both exists, and is activated.
|
||||
#
|
||||
# This function is idempotent, so you can call it as many times as you want and it will
|
||||
@@ -354,7 +369,7 @@ function ensureVenv() {
|
||||
venv_args="--copies"
|
||||
fi
|
||||
"${interpreter}" -m venv ${venv_args} "${EDIR}/venv"
|
||||
source "${EDIR}/venv/bin/activate"
|
||||
_activateVenv
|
||||
"${interpreter}" -m pip install --upgrade pip
|
||||
else
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
@@ -375,7 +390,7 @@ function ensureVenv() {
|
||||
fi
|
||||
|
||||
if [ "x${VIRTUAL_ENV:-}" != "x${EDIR}/venv" ]; then
|
||||
source "${EDIR}/venv/bin/activate"
|
||||
_activateVenv
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
@@ -55,11 +55,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
resultSegments = []
|
||||
text = ""
|
||||
try:
|
||||
segments, info = self.model.transcribe(request.dst, beam_size=5, condition_on_previous_text=False)
|
||||
word_timestamps = "word" in request.timestamp_granularities
|
||||
segments, info = self.model.transcribe(request.dst, beam_size=5, condition_on_previous_text=False, word_timestamps=word_timestamps)
|
||||
id = 0
|
||||
for segment in segments:
|
||||
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||
resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=int(segment.start)*1e9, end=int(segment.end)*1e9, text=segment.text))
|
||||
words = []
|
||||
if word_timestamps and hasattr(segment, 'words'):
|
||||
for word in segment.words:
|
||||
words.append(backend_pb2.TranscriptWord(
|
||||
start=int(word.start * 1e9),
|
||||
end=int(word.end * 1e9),
|
||||
text=word.word
|
||||
))
|
||||
|
||||
resultSegments.append(backend_pb2.TranscriptSegment(
|
||||
id=id,
|
||||
start=int(segment.start * 1e9),
|
||||
end=int(segment.end * 1e9),
|
||||
text=segment.text,
|
||||
words=words
|
||||
))
|
||||
text += segment.text
|
||||
id += 1
|
||||
except Exception as err:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
transformers
|
||||
accelerate
|
||||
torch==2.4.1
|
||||
torch==2.7.1
|
||||
rerankers[transformers]
|
||||
@@ -1,4 +1,4 @@
|
||||
transformers
|
||||
accelerate
|
||||
torch==2.4.1
|
||||
torch==2.7.1
|
||||
rerankers[transformers]
|
||||
@@ -79,6 +79,14 @@ fi
|
||||
|
||||
cd vllm-omni/
|
||||
|
||||
# fa3-fwd ships no aarch64 wheels and there is no source distribution, so on
|
||||
# aarch64 (e.g. l4t13 / SBSA cu130) the upstream requirements/cuda.txt is
|
||||
# unsatisfiable. Drop it before resolving — vllm-omni does not hard-require
|
||||
# the fused FA3 kernel at import time on Jetson/SBSA targets.
|
||||
if [ "$(uname -m)" = "aarch64" ] && [ -f requirements/cuda.txt ]; then
|
||||
sed -i '/^fa3-fwd[[:space:]]*==/d' requirements/cuda.txt
|
||||
fi
|
||||
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e .
|
||||
else
|
||||
|
||||
@@ -18,12 +18,15 @@ else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
|
||||
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
|
||||
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
|
||||
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
|
||||
# Intel XPU: torch==2.11.0+xpu lives on the PyTorch XPU index, transitive
|
||||
# deps on PyPI — unsafe-best-match lets uv mix both. vllm-xpu-kernels only
|
||||
# ships a python3.12 wheel per upstream docs, so bump the portable Python
|
||||
# before installRequirements (matches the l4t13 pattern below).
|
||||
# https://github.com/vllm-project/vllm/blob/main/docs/getting_started/installation/gpu.xpu.inc.md
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="11"
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# CPU builds need unsafe-best-match to pull torch==2.10.0+cpu from the
|
||||
@@ -42,10 +45,12 @@ fi
|
||||
|
||||
# JetPack 7 / L4T arm64 wheels (torch, vllm, flash-attn) live on
|
||||
# pypi.jetson-ai-lab.io and are built for cp312, so bump the venv Python
|
||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true. unsafe-best-match
|
||||
# is required because the jetson-ai-lab index lists transitive deps at
|
||||
# limited versions — without it uv pins to the first matching index and
|
||||
# fails to resolve a compatible wheel from PyPI.
|
||||
# accordingly. JetPack 6 keeps cp310 + USE_PIP=true.
|
||||
#
|
||||
# l4t13 uses pyproject.toml (see the elif branch below) to pin only the
|
||||
# L4T-specific wheels to the jetson-ai-lab index via [tool.uv.sources].
|
||||
# That keeps PyPI as the resolution path for transitive deps like
|
||||
# anthropic/openai/propcache, which the L4T mirror's proxy 503s on.
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
@@ -53,16 +58,77 @@ if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# Intel XPU has no upstream-published vllm wheels, so we always build vllm
|
||||
# from source against torch-xpu and replace the default triton with
|
||||
# triton-xpu (matching torch 2.11). Mirrors the upstream procedure:
|
||||
# https://github.com/vllm-project/vllm/blob/main/docs/getting_started/installation/gpu.xpu.inc.md
|
||||
if [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
# Hide requirements-intel-after.txt so installRequirements doesn't
|
||||
# try `pip install vllm` (would either fail or grab a non-XPU wheel).
|
||||
_intel_after="${backend_dir}/requirements-intel-after.txt"
|
||||
_intel_after_bak=""
|
||||
if [ -f "${_intel_after}" ]; then
|
||||
_intel_after_bak="${_intel_after}.xpu.bak"
|
||||
mv "${_intel_after}" "${_intel_after_bak}"
|
||||
fi
|
||||
installRequirements
|
||||
if [ -n "${_intel_after_bak}" ]; then
|
||||
mv "${_intel_after_bak}" "${_intel_after}"
|
||||
fi
|
||||
|
||||
# vllm's CMake build needs the Intel oneAPI dpcpp/sycl compiler — the
|
||||
# base image (intel/oneapi-basekit) has it but the env isn't sourced.
|
||||
if [ -f /opt/intel/oneapi/setvars.sh ]; then
|
||||
set +u
|
||||
source /opt/intel/oneapi/setvars.sh --force
|
||||
set -u
|
||||
fi
|
||||
|
||||
_vllm_src=$(mktemp -d)
|
||||
trap 'rm -rf "${_vllm_src}"' EXIT
|
||||
git clone --depth 1 https://github.com/vllm-project/vllm "${_vllm_src}/vllm"
|
||||
pushd "${_vllm_src}/vllm"
|
||||
# Install vllm's own runtime deps (torch-xpu, vllm_xpu_kernels,
|
||||
# pydantic, fastapi, …) from upstream's requirements/xpu.txt — the
|
||||
# canonical source of truth. Avoids re-pinning everything ourselves.
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements/xpu.txt
|
||||
# Stock triton (NVIDIA-only) may have come in transitively; replace
|
||||
# with triton-xpu==3.7.0 which matches torch 2.11.
|
||||
uv pip uninstall triton triton-xpu 2>/dev/null || true
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} \
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu \
|
||||
triton-xpu==3.7.0
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# L4T arm64 (JetPack 7): drive the install through pyproject.toml so that
|
||||
# [tool.uv.sources] can pin torch/vllm/flash-attn/torchvision/torchaudio
|
||||
# to the jetson-ai-lab index, while everything else (transitive deps and
|
||||
# PyPI-resolvable packages like transformers) comes from PyPI. Bypasses
|
||||
# installRequirements because uv pip install -r requirements.txt does not
|
||||
# honor sources — see backend/python/vllm/pyproject.toml for the rationale.
|
||||
elif [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
pushd "${backend_dir}"
|
||||
# Build deps first (matches installRequirements' requirements-install.txt
|
||||
# pass — fastsafetensors and friends need pybind11 in the venv before
|
||||
# their sdists can build under --no-build-isolation).
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -r requirements-install.txt
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement pyproject.toml
|
||||
popd
|
||||
runProtogen
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
# kept here for hosts where the prebuilt wheel SIGILLs (CPU without the
|
||||
# required SIMD baseline, e.g. AVX-512 VNNI/BF16). Default CI uses a
|
||||
# bigger-runner with compatible hardware instead.
|
||||
if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
elif [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
# Temporarily hide the prebuilt wheel so installRequirements doesn't
|
||||
# pull it — the rest of the requirements files (base deps, torch,
|
||||
# transformers) are still installed normally.
|
||||
|
||||
@@ -45,5 +45,109 @@ copy_with_symlinks() {
|
||||
copy_with_symlinks libnuma.so.1
|
||||
copy_with_symlinks libgomp.so.1
|
||||
|
||||
# CPU profile only: bundle a g++ toolchain so torch._inductor's
|
||||
# ISA probe (always run at vllm engine startup, regardless of
|
||||
# enforce_eager) finds a C++ compiler. The LocalAI runtime image
|
||||
# is FROM ubuntu:24.04 with a minimal apt list that does not
|
||||
# include build-essential, and the backend image itself is FROM
|
||||
# scratch -- so without this, cpu-vllm crashes with
|
||||
# torch._inductor.exc.InvalidCxxCompiler at first inference
|
||||
# unless the operator manually sets TORCH_COMPILE_DISABLE=1.
|
||||
#
|
||||
# We snapshot every file owned by the toolchain packages, mirroring
|
||||
# the /usr/... layout into ${BACKEND}/toolchain/ so g++ can find
|
||||
# cc1plus, headers, libs etc. via GCC_EXEC_PREFIX / CPATH /
|
||||
# LIBRARY_PATH at runtime (libbackend.sh wires those up). Adds
|
||||
# ~400 MB to the cpu-vllm image, which is tolerable -- cpu-vllm is
|
||||
# already a niche profile.
|
||||
if [ "${BUILD_TYPE:-}" = "" ] && command -v dpkg-query >/dev/null 2>&1; then
|
||||
TOOLCHAIN_DIR="${CURDIR}/toolchain"
|
||||
mkdir -p "${TOOLCHAIN_DIR}"
|
||||
# The unversioned g++/gcc packages on Debian/Ubuntu only ship
|
||||
# symlinks; the actual binaries live in g++-${VER}/gcc-${VER}.
|
||||
# Discover the active version so the symlink targets get bundled
|
||||
# along with their owners.
|
||||
GCC_VER=$(gcc -dumpversion 2>/dev/null | cut -d. -f1 || true)
|
||||
# `g++-${VER}` itself is just another symlink layer on Debian/
|
||||
# Ubuntu — the real binary `x86_64-linux-gnu-g++-${VER}` lives
|
||||
# in `g++-${VER}-x86-64-linux-gnu` (a separate package pulled in
|
||||
# as a dependency). Same story for gcc/cpp. Compute the dpkg
|
||||
# arch-triplet to find the right package name for both amd64 and
|
||||
# arm64 hosts.
|
||||
case "$(dpkg --print-architecture 2>/dev/null)" in
|
||||
amd64) HOST_TRIPLET="x86-64-linux-gnu" ;;
|
||||
arm64) HOST_TRIPLET="aarch64-linux-gnu" ;;
|
||||
*) HOST_TRIPLET="" ;;
|
||||
esac
|
||||
PKGS=(g++ gcc cpp libstdc++-${GCC_VER}-dev libgcc-${GCC_VER}-dev libc6 libc6-dev binutils binutils-common libbinutils libc-dev-bin linux-libc-dev libcrypt-dev libgomp1 libstdc++6 libgcc-s1 libisl23 libmpc3 libmpfr6 libjansson4 libctf0 libctf-nobfd0 libsframe1)
|
||||
if [ -n "${GCC_VER}" ]; then
|
||||
PKGS+=("g++-${GCC_VER}" "gcc-${GCC_VER}" "cpp-${GCC_VER}" "gcc-${GCC_VER}-base")
|
||||
if [ -n "${HOST_TRIPLET}" ]; then
|
||||
PKGS+=(
|
||||
"g++-${GCC_VER}-${HOST_TRIPLET}"
|
||||
"gcc-${GCC_VER}-${HOST_TRIPLET}"
|
||||
"cpp-${GCC_VER}-${HOST_TRIPLET}"
|
||||
"binutils-${HOST_TRIPLET}"
|
||||
)
|
||||
fi
|
||||
fi
|
||||
for pkg in "${PKGS[@]}"; do
|
||||
if ! dpkg-query -W "${pkg}" >/dev/null 2>&1; then
|
||||
continue
|
||||
fi
|
||||
# Copy each owned path, preserving symlinks and mode. We
|
||||
# tolerate dpkg listing directories alongside files.
|
||||
dpkg -L "${pkg}" | while IFS= read -r path; do
|
||||
if [ -L "${path}" ] || [ -f "${path}" ]; then
|
||||
mkdir -p "${TOOLCHAIN_DIR}$(dirname "${path}")"
|
||||
cp -aP "${path}" "${TOOLCHAIN_DIR}${path}" 2>/dev/null || true
|
||||
fi
|
||||
done
|
||||
done
|
||||
# Ubuntu's filesystem layout has /lib -> /usr/lib (UsrMerge) and
|
||||
# /lib64 -> /usr/lib64. ld scripts (e.g. libm.so) hardcode
|
||||
# `/lib/x86_64-linux-gnu/libm.so.6`; with --sysroot the linker
|
||||
# looks for that path under the sysroot, which means we need
|
||||
# the same symlinks under TOOLCHAIN_DIR.
|
||||
[ -e "${TOOLCHAIN_DIR}/lib" ] || ln -s usr/lib "${TOOLCHAIN_DIR}/lib"
|
||||
[ -e "${TOOLCHAIN_DIR}/lib64" ] || ln -s usr/lib64 "${TOOLCHAIN_DIR}/lib64"
|
||||
|
||||
# Replace the unversioned g++/gcc/cpp symlinks with wrapper
|
||||
# scripts that pass --sysroot=<toolchain> and -B <gcc-exec-prefix>.
|
||||
# Without these flags gcc would fall back to its compiled-in
|
||||
# /usr search and fail to find headers (the runtime image has no
|
||||
# libc6-dev) or fail to invoke `as`/`ld` (binutils not on PATH at
|
||||
# /usr/bin). Wrappers self-resolve their location at runtime so
|
||||
# they work from any BackendsPath.
|
||||
BIN_DIR="${TOOLCHAIN_DIR}/usr/bin"
|
||||
if [ -n "${GCC_VER}" ] && [ -n "${HOST_TRIPLET}" ]; then
|
||||
# HOST_TRIPLET in package names uses dashes ("x86-64-linux-gnu");
|
||||
# the binary suffix uses underscores in the arch part
|
||||
# ("x86_64-linux-gnu-g++-13"). Translate.
|
||||
BIN_TRIPLET=${HOST_TRIPLET//x86-64/x86_64}
|
||||
for tool in g++ gcc cpp; do
|
||||
real="${BIN_DIR}/${BIN_TRIPLET}-${tool}-${GCC_VER}"
|
||||
if [ -x "${real}" ]; then
|
||||
rm -f "${BIN_DIR}/${tool}" "${BIN_DIR}/${tool}-${GCC_VER}"
|
||||
cat > "${BIN_DIR}/${tool}" <<EOF
|
||||
#!/bin/bash
|
||||
# Auto-generated by package.sh. Passes --sysroot and -B so the
|
||||
# bundled toolchain works from any BackendsPath without depending
|
||||
# on libc6-dev / binutils being installed at /usr in the runtime
|
||||
# image. See backend/python/vllm/package.sh.
|
||||
DIR="\$(dirname "\$(readlink -f "\$0")")" # …/toolchain/usr/bin
|
||||
SYSROOT="\$(dirname "\$(dirname "\${DIR}")")" # …/toolchain
|
||||
exec "\${DIR}/${BIN_TRIPLET}-${tool}-${GCC_VER}" \\
|
||||
-B "\${SYSROOT}/usr/lib/gcc/${BIN_TRIPLET}/${GCC_VER}/" \\
|
||||
--sysroot="\${SYSROOT}" \\
|
||||
"\$@"
|
||||
EOF
|
||||
chmod +x "${BIN_DIR}/${tool}"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
echo "Bundled g++ toolchain (gcc-${GCC_VER}) into ${TOOLCHAIN_DIR} ($(du -sh "${TOOLCHAIN_DIR}" | cut -f1))"
|
||||
fi
|
||||
|
||||
echo "vllm packaging completed successfully"
|
||||
ls -liah "${LIB_DIR}/"
|
||||
|
||||
61
backend/python/vllm/pyproject.toml
Normal file
61
backend/python/vllm/pyproject.toml
Normal file
@@ -0,0 +1,61 @@
|
||||
# L4T arm64 (JetPack 7 / sbsa cu130) install spec for the vllm backend.
|
||||
#
|
||||
# Why this file exists, and why only the l4t13 BUILD_PROFILE consumes it:
|
||||
#
|
||||
# pypi.jetson-ai-lab.io hosts the L4T-specific torch / vllm / flash-attn
|
||||
# wheels we need on aarch64 + cuda13, but it ALSO transparently proxies the
|
||||
# rest of PyPI through `/+f/<sha>/<filename>` URLs that 503 frequently. With
|
||||
# `--extra-index-url` + `--index-strategy=unsafe-best-match` (the historical
|
||||
# fix in install.sh) uv would pick those proxy URLs for ordinary PyPI
|
||||
# packages — `anthropic`, `openai`, `propcache`, `annotated-types` — and
|
||||
# trip on the 503s. See e.g. CI run 25212201349 (anthropic-0.97.0).
|
||||
#
|
||||
# `explicit = true` on the index makes uv consult the L4T mirror ONLY for
|
||||
# packages mapped under [tool.uv.sources]. Everything else goes to PyPI.
|
||||
# This breaks the historical 503 path without losing access to the L4T
|
||||
# wheels we actually need from there.
|
||||
#
|
||||
# `uv pip install -r requirements.txt` does NOT honor [tool.uv.sources]
|
||||
# (sources are project-mode only, not pip-compat mode), so install.sh's
|
||||
# l4t13 branch invokes `uv pip install --requirement pyproject.toml`
|
||||
# directly. Other BUILD_PROFILEs continue to use the requirements-*.txt
|
||||
# pipeline through libbackend.sh's installRequirements and never read
|
||||
# this file.
|
||||
[project]
|
||||
name = "localai-vllm-l4t13"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
# Mirror of requirements.txt — kept in sync manually for now since the
|
||||
# l4t13 path bypasses installRequirements (see install.sh).
|
||||
"grpcio==1.80.0",
|
||||
"protobuf",
|
||||
"certifi",
|
||||
"setuptools",
|
||||
"pillow",
|
||||
"charset-normalizer>=3.4.0",
|
||||
"chardet",
|
||||
# L4T-specific accelerator stack (sourced from jetson-ai-lab below).
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"flash-attn",
|
||||
"vllm",
|
||||
# PyPI-resolvable packages that complete the runtime — accelerate,
|
||||
# transformers, bitsandbytes carry their own wheels for aarch64.
|
||||
"accelerate",
|
||||
"transformers",
|
||||
"bitsandbytes",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "jetson-ai-lab"
|
||||
url = "https://pypi.jetson-ai-lab.io/sbsa/cu130"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "jetson-ai-lab" }
|
||||
torchvision = { index = "jetson-ai-lab" }
|
||||
torchaudio = { index = "jetson-ai-lab" }
|
||||
flash-attn = { index = "jetson-ai-lab" }
|
||||
vllm = { index = "jetson-ai-lab" }
|
||||
@@ -3,5 +3,5 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.20.0/cu130
|
||||
vllm==0.20.0
|
||||
--extra-index-url https://wheels.vllm.ai/0.20.1/cu130
|
||||
vllm==0.20.1
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
vllm
|
||||
# Intel XPU has no upstream-published vllm wheels — install.sh builds vllm
|
||||
# from source with VLLM_TARGET_DEVICE=xpu and hides this file during
|
||||
# installRequirements. Don't add a `vllm` line here.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
# vllm's own deps (torch==2.11.0+xpu, vllm_xpu_kernels, pydantic, …) are
|
||||
# installed from upstream's requirements/xpu.txt during the source build —
|
||||
# see install.sh. Only list what LocalAI's vllm backend.py needs directly.
|
||||
accelerate
|
||||
torch
|
||||
transformers
|
||||
optimum[openvino]
|
||||
bitsandbytes
|
||||
setuptools
|
||||
bitsandbytes
|
||||
@@ -1,2 +0,0 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
vllm
|
||||
@@ -1,8 +0,0 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/sbsa/cu130
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
bitsandbytes
|
||||
flash-attn
|
||||
@@ -1,4 +1,7 @@
|
||||
grpcio==1.80.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
setuptools
|
||||
pillow
|
||||
charset-normalizer>=3.4.0
|
||||
chardet
|
||||
@@ -1,11 +1,73 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
|
||||
# FlashInfer / PyTorch JIT-compile CUDA kernels at first model load (e.g.
|
||||
# the NVFP4 GEMM kernel for Blackwell SM120). Each concurrent nvcc /
|
||||
# cudafe++ peaks at multiple GiB during compilation; ninja's default
|
||||
# (-j$(nproc)+2) OOM-kills on memory-tight hosts but underutilises
|
||||
# 100-core / 1 TB boxes. Default MAX_JOBS to the smaller of the CPU count
|
||||
# and an available-memory budget at ~4 GiB per job. User-set MAX_JOBS in
|
||||
# the environment wins.
|
||||
# https://github.com/vllm-project/vllm/issues/20079
|
||||
if [ -z "${MAX_JOBS:-}" ]; then
|
||||
_ncpus=$(nproc 2>/dev/null || echo 1)
|
||||
_mem_avail_kb=$(awk '/^MemAvailable:/ {print $2; exit}' /proc/meminfo 2>/dev/null || echo 0)
|
||||
_mem_avail_gb=$(( _mem_avail_kb / 1024 / 1024 ))
|
||||
# Reserve ~4 GiB for the rest of the system; budget ~4 GiB per job.
|
||||
if [ "${_mem_avail_gb}" -gt 8 ]; then
|
||||
_mem_jobs=$(( (_mem_avail_gb - 4) / 4 ))
|
||||
else
|
||||
_mem_jobs=1
|
||||
fi
|
||||
[ "${_mem_jobs}" -lt 1 ] && _mem_jobs=1
|
||||
[ "${_mem_jobs}" -gt "${_ncpus}" ] && _mem_jobs=${_ncpus}
|
||||
export MAX_JOBS="${_mem_jobs}"
|
||||
fi
|
||||
export NVCC_THREADS="${NVCC_THREADS:-2}"
|
||||
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
# CPU profile: torch._inductor's ISA-probe (run at vllm engine
|
||||
# startup, even with enforce_eager=True) shells out to g++. The
|
||||
# LocalAI runtime image and the FROM-scratch backend image both
|
||||
# omit a compiler; package.sh bundles one into ${EDIR}/toolchain
|
||||
# along with wrapper scripts at toolchain/usr/bin that already pass
|
||||
# --sysroot and -B. So all run.sh has to do is put the wrapper on
|
||||
# PATH and expose the toolchain's shared libs (libisl, libmpc, libbfd,
|
||||
# ...) to ld.so. No-op for other profiles -- the dir doesn't exist.
|
||||
if [ -d "${EDIR}/toolchain/usr/bin" ]; then
|
||||
export PATH="${EDIR}/toolchain/usr/bin:${PATH}"
|
||||
_libpath="${EDIR}/toolchain/usr/lib/x86_64-linux-gnu"
|
||||
export LD_LIBRARY_PATH="${_libpath}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
|
||||
fi
|
||||
|
||||
# Multi-node DP follower mode: when the first arg is `serve`, exec into
|
||||
# vllm's own CLI instead of LocalAI's backend.py gRPC server. The
|
||||
# follower speaks ZMQ directly to the head node's vllm ranks — there
|
||||
# is no LocalAI gRPC on the follower side. Reaches this path via
|
||||
# `local-ai p2p-worker vllm`.
|
||||
if [ "${1:-}" = "serve" ]; then
|
||||
ensureVenv
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -x "$(_portable_python)" ]; then
|
||||
_makeVenvPortable --update-pyvenv-cfg
|
||||
fi
|
||||
if [ -d "${EDIR}/lib" ]; then
|
||||
export LD_LIBRARY_PATH="${EDIR}/lib:${LD_LIBRARY_PATH:-}"
|
||||
fi
|
||||
# Run the vllm console script through the venv python rather than
|
||||
# exec-ing it directly. uv bakes an absolute shebang at install time
|
||||
# (e.g. `#!/vllm/venv/bin/python3` from the build image) which
|
||||
# doesn't exist once the backend is relocated to BackendsPath, and
|
||||
# _makeVenvPortable's shebang rewriter only matches paths that
|
||||
# already point at ${EDIR}. Invoking python with the script as an
|
||||
# argument bypasses the shebang entirely.
|
||||
exec "${EDIR}/venv/bin/python" "${EDIR}/venv/bin/vllm" "$@"
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
|
||||
@@ -351,6 +351,30 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn diarize(
|
||||
&self,
|
||||
_: Request<backend::DiarizeRequest>,
|
||||
) -> Result<Response<backend::DiarizeResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn audio_transform(
|
||||
&self,
|
||||
_: Request<backend::AudioTransformRequest>,
|
||||
) -> Result<Response<backend::AudioTransformResult>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type AudioTransformStreamStream =
|
||||
ReceiverStream<Result<backend::AudioTransformFrameResponse, Status>>;
|
||||
|
||||
async fn audio_transform_stream(
|
||||
&self,
|
||||
_: Request<tonic::Streaming<backend::AudioTransformFrameRequest>>,
|
||||
) -> Result<Response<Self::AudioTransformStreamStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn sound_generation(
|
||||
&self,
|
||||
_: Request<backend::SoundGenerationRequest>,
|
||||
|
||||
@@ -71,7 +71,9 @@ func (ds *DistributedServices) Shutdown() {
|
||||
// initDistributed validates distributed mode prerequisites and initializes
|
||||
// NATS, object storage, node registry, and instance identity.
|
||||
// Returns nil if distributed mode is not enabled.
|
||||
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
|
||||
// configLoader is used by the SmartRouter to compute concurrency-group
|
||||
// anti-affinity at placement time (#9659); it may be nil in tests.
|
||||
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoader *config.ModelConfigLoader) (*DistributedServices, error) {
|
||||
if !cfg.Distributed.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -234,12 +236,17 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
conflictResolver = configLoader
|
||||
}
|
||||
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||
Unloader: remoteUnloader,
|
||||
FileStager: fileStager,
|
||||
GalleriesJSON: routerGalleriesJSON,
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
Unloader: remoteUnloader,
|
||||
FileStager: fileStager,
|
||||
GalleriesJSON: routerGalleriesJSON,
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
|
||||
@@ -139,7 +140,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
// Initialize distributed mode services (NATS, object storage, node registry)
|
||||
distSvc, err := initDistributed(options, application.authDB)
|
||||
distSvc, err := initDistributed(options, application.authDB, application.ModelConfigLoader())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
|
||||
}
|
||||
@@ -251,6 +252,10 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
go uc.Run(options.Context)
|
||||
}
|
||||
|
||||
// Wire gallery generation counter into VRAM caches so they invalidate
|
||||
// when gallery data refreshes instead of using a fixed TTL.
|
||||
vram.SetGalleryGenerationFunc(gallery.GalleryGeneration)
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
xlog.Error("error loading config file", "error", err)
|
||||
@@ -680,6 +685,12 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
|
||||
options.LRUEvictionRetryInterval,
|
||||
)
|
||||
|
||||
// Sync per-model state from configs to the watchdog. Without this,
|
||||
// `pinned: true` and `concurrency_groups:` are only honored after a
|
||||
// settings-driven RestartWatchdog and never at boot.
|
||||
application.SyncPinnedModelsToWatchdog()
|
||||
application.SyncModelGroupsToWatchdog()
|
||||
|
||||
// Start watchdog goroutine if any periodic checks are enabled
|
||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||
// But memory reclaimer needs the Run() loop for periodic checking
|
||||
|
||||
@@ -199,13 +199,27 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-upgrade if enabled
|
||||
// Auto-upgrade if enabled. Route through the active BackendManager so
|
||||
// distributed-mode upgrades fan out to workers via NATS — calling
|
||||
// gallery.UpgradeBackend directly would look up the backend on the
|
||||
// frontend filesystem, which is empty in distributed mode and produces
|
||||
// "backend not found" while the cluster still reports an upgrade.
|
||||
if uc.appConfig.AutoUpgradeBackends {
|
||||
var bm galleryop.BackendManager
|
||||
if uc.backendManagerFn != nil {
|
||||
bm = uc.backendManagerFn()
|
||||
}
|
||||
for name, info := range upgrades {
|
||||
xlog.Info("Auto-upgrading backend", "backend", name,
|
||||
"from", info.InstalledVersion, "to", info.AvailableVersion)
|
||||
if err := gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil); err != nil {
|
||||
var err error
|
||||
if bm != nil {
|
||||
err = bm.UpgradeBackend(ctx, name, nil)
|
||||
} else {
|
||||
err = gallery.UpgradeBackend(ctx, uc.systemState, uc.modelLoader,
|
||||
uc.galleries, name, nil)
|
||||
}
|
||||
if err != nil {
|
||||
xlog.Error("Failed to auto-upgrade backend",
|
||||
"backend", name, "error", err)
|
||||
} else {
|
||||
@@ -213,8 +227,16 @@ func (uc *UpgradeChecker) runCheck(ctx context.Context) {
|
||||
"version", info.AvailableVersion)
|
||||
}
|
||||
}
|
||||
// Re-check to update cache after upgrades
|
||||
if freshUpgrades, err := gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState); err == nil {
|
||||
// Re-check to update cache after upgrades. Route through the same
|
||||
// BackendManager so distributed mode reflects the worker view.
|
||||
var freshUpgrades map[string]gallery.UpgradeInfo
|
||||
var freshErr error
|
||||
if bm != nil {
|
||||
freshUpgrades, freshErr = bm.CheckUpgrades(ctx)
|
||||
} else {
|
||||
freshUpgrades, freshErr = gallery.CheckBackendUpgrades(ctx, uc.galleries, uc.systemState)
|
||||
}
|
||||
if freshErr == nil {
|
||||
uc.mu.Lock()
|
||||
uc.lastUpgrades = freshUpgrades
|
||||
uc.mu.Unlock()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -26,6 +27,40 @@ func (a *Application) SyncPinnedModelsToWatchdog() {
|
||||
xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
|
||||
}
|
||||
|
||||
// SyncModelGroupsToWatchdog reads concurrency_groups from all model configs and
|
||||
// updates the watchdog so EnforceGroupExclusivity has the current view.
|
||||
func (a *Application) SyncModelGroupsToWatchdog() {
|
||||
cl := a.ModelConfigLoader()
|
||||
if cl == nil {
|
||||
return
|
||||
}
|
||||
wd := a.modelLoader.GetWatchDog()
|
||||
if wd == nil {
|
||||
return
|
||||
}
|
||||
groups := extractModelGroupsFromConfigs(cl.GetAllModelsConfigs())
|
||||
wd.ReplaceModelGroups(groups)
|
||||
xlog.Debug("Synced concurrency groups to watchdog", "count", len(groups))
|
||||
}
|
||||
|
||||
// extractModelGroupsFromConfigs builds the model→groups map the watchdog
|
||||
// expects. Disabled models are skipped — their declared groups should not
|
||||
// block other models from loading.
|
||||
func extractModelGroupsFromConfigs(configs []config.ModelConfig) map[string][]string {
|
||||
out := make(map[string][]string)
|
||||
for _, cfg := range configs {
|
||||
if cfg.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
gs := cfg.GetConcurrencyGroups()
|
||||
if len(gs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[cfg.Name] = gs
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (a *Application) StopWatchdog() error {
|
||||
if a.watchdogStop != nil {
|
||||
close(a.watchdogStop)
|
||||
@@ -65,8 +100,9 @@ func (a *Application) startWatchdog() error {
|
||||
// Set the watchdog on the model loader
|
||||
a.modelLoader.SetWatchDog(wd)
|
||||
|
||||
// Sync pinned models from config to the watchdog
|
||||
// Sync pinned models and concurrency groups from config to the watchdog
|
||||
a.SyncPinnedModelsToWatchdog()
|
||||
a.SyncModelGroupsToWatchdog()
|
||||
|
||||
// Start watchdog goroutine if any periodic checks are enabled
|
||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||
@@ -148,8 +184,9 @@ func (a *Application) RestartWatchdog() error {
|
||||
newWD.RestoreState(oldState)
|
||||
}
|
||||
|
||||
// Re-sync pinned models after restart
|
||||
// Re-sync pinned models and concurrency groups after restart
|
||||
a.SyncPinnedModelsToWatchdog()
|
||||
a.SyncModelGroupsToWatchdog()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
47
core/application/watchdog_test.go
Normal file
47
core/application/watchdog_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("extractModelGroupsFromConfigs", func() {
|
||||
It("returns an empty map when no config declares groups", func() {
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "a"},
|
||||
{Name: "b"},
|
||||
})
|
||||
Expect(out).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns each model's normalized groups", func() {
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "a", ConcurrencyGroups: []string{" heavy ", "vision", "heavy"}},
|
||||
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
|
||||
{Name: "c"}, // no groups → omitted
|
||||
})
|
||||
Expect(out).To(HaveLen(2))
|
||||
Expect(out["a"]).To(Equal([]string{"heavy", "vision"}))
|
||||
Expect(out["b"]).To(Equal([]string{"heavy"}))
|
||||
Expect(out).ToNot(HaveKey("c"))
|
||||
})
|
||||
|
||||
It("omits models whose groups normalize to empty", func() {
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "blanks", ConcurrencyGroups: []string{"", " "}},
|
||||
})
|
||||
Expect(out).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips disabled models so they cannot block loading after re-enable", func() {
|
||||
disabled := true
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "a", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled},
|
||||
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
|
||||
})
|
||||
Expect(out).To(HaveLen(1))
|
||||
Expect(out).To(HaveKey("b"))
|
||||
Expect(out).ToNot(HaveKey("a"))
|
||||
})
|
||||
})
|
||||
175
core/backend/audio_transform.go
Normal file
175
core/backend/audio_transform.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// AudioTransformOptions carries per-request tuning for the unary transform.
|
||||
type AudioTransformOptions struct {
|
||||
// Params is forwarded verbatim to the backend (e.g. LocalVQE reads
|
||||
// params["noise_gate"] / params["noise_gate_threshold_dbfs"]).
|
||||
Params map[string]string
|
||||
}
|
||||
|
||||
// AudioTransformOutputs are the on-disk paths of the persisted artifacts —
|
||||
// the user-visible Dst plus copies of the inputs the backend actually saw.
|
||||
// Inputs are persisted because the React UI history needs to display past
|
||||
// runs, and rejecting them once the temp dir is cleaned up would defeat
|
||||
// the point.
|
||||
type AudioTransformOutputs struct {
|
||||
Dst string
|
||||
AudioPath string
|
||||
ReferencePath string
|
||||
}
|
||||
|
||||
// ModelAudioTransform runs the unary AudioTransform RPC and returns the
|
||||
// generated output path plus the persisted input paths. `audioPath` is
|
||||
// required; `referencePath` is optional (empty => backend zero-fills the
|
||||
// reference channel).
|
||||
func ModelAudioTransform(
|
||||
audioPath, referencePath string,
|
||||
opts AudioTransformOptions,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (AudioTransformOutputs, *proto.AudioTransformResult, error) {
|
||||
mopts := ModelOptions(modelConfig, appConfig)
|
||||
transformModel, err := loader.Load(mopts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return AudioTransformOutputs{}, nil, err
|
||||
}
|
||||
if transformModel == nil {
|
||||
return AudioTransformOutputs{}, nil, fmt.Errorf("could not load audio-transform model %q", modelConfig.Model)
|
||||
}
|
||||
|
||||
audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
|
||||
if err := os.MkdirAll(audioDir, 0750); err != nil {
|
||||
return AudioTransformOutputs{}, nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||
}
|
||||
|
||||
dst := filepath.Join(audioDir, utils.GenerateUniqueFileName(audioDir, "transform", ".wav"))
|
||||
|
||||
persistedAudio, err := persistAudioInput(audioPath, audioDir, "transform-input", ".wav")
|
||||
if err != nil {
|
||||
return AudioTransformOutputs{}, nil, fmt.Errorf("persist input audio: %w", err)
|
||||
}
|
||||
persistedRef := ""
|
||||
if referencePath != "" {
|
||||
persistedRef, err = persistAudioInput(referencePath, audioDir, "transform-ref", ".wav")
|
||||
if err != nil {
|
||||
return AudioTransformOutputs{}, nil, fmt.Errorf("persist reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := transformModel.AudioTransform(context.Background(), &proto.AudioTransformRequest{
|
||||
AudioPath: audioPath,
|
||||
ReferencePath: referencePath,
|
||||
Dst: dst,
|
||||
Params: opts.Params,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
data := map[string]any{
|
||||
"audio_path": audioPath,
|
||||
"reference_path": referencePath,
|
||||
"dst": dst,
|
||||
"params": opts.Params,
|
||||
}
|
||||
if err == nil && res != nil {
|
||||
data["sample_rate"] = res.SampleRate
|
||||
data["samples"] = res.Samples
|
||||
data["reference_provided"] = res.ReferenceProvided
|
||||
if snippet := trace.AudioSnippet(dst); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceAudioTransform,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(filepath.Base(audioPath), 200),
|
||||
Error: errStr,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return AudioTransformOutputs{}, nil, err
|
||||
}
|
||||
return AudioTransformOutputs{
|
||||
Dst: dst,
|
||||
AudioPath: persistedAudio,
|
||||
ReferencePath: persistedRef,
|
||||
}, res, nil
|
||||
}
|
||||
|
||||
// ModelAudioTransformStream opens the bidirectional AudioTransformStream RPC
|
||||
// and returns the underlying stream client. The caller is responsible for
|
||||
// sending the initial Config message, subsequent Frame messages, and for
|
||||
// calling CloseSend when input is done. The returned stream's Recv reports
|
||||
// EOF when the backend has finished emitting frames.
|
||||
func ModelAudioTransformStream(
|
||||
ctx context.Context,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (grpc.AudioTransformStreamClient, error) {
|
||||
mopts := ModelOptions(modelConfig, appConfig)
|
||||
transformModel, err := loader.Load(mopts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if transformModel == nil {
|
||||
return nil, fmt.Errorf("could not load audio-transform model %q", modelConfig.Model)
|
||||
}
|
||||
return transformModel.AudioTransformStream(ctx)
|
||||
}
|
||||
|
||||
// persistAudioInput copies a transient input file (typically a multipart
|
||||
// upload that lives in an os.TempDir slated for cleanup) into the long-lived
|
||||
// GeneratedContentDir under a unique name, so the React UI can replay it
|
||||
// from history.
|
||||
func persistAudioInput(srcPath, dir, prefix, ext string) (string, error) {
|
||||
src, err := os.Open(srcPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = src.Close() }()
|
||||
dst := filepath.Join(dir, utils.GenerateUniqueFileName(dir, prefix, ext))
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
if _, err := io.Copy(out, src); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dst, nil
|
||||
}
|
||||
158
core/backend/diarization.go
Normal file
158
core/backend/diarization.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// DiarizationRequest carries the diarization-specific knobs the HTTP
|
||||
// layer collects. Speaker hints (NumSpeakers / MinSpeakers / MaxSpeakers)
|
||||
// and clustering knobs are optional — backends ignore the ones they
|
||||
// don't act on. IncludeText only matters for backends that emit
|
||||
// per-segment transcripts as a by-product (e.g. vibevoice.cpp).
|
||||
type DiarizationRequest struct {
|
||||
Audio string
|
||||
Language string
|
||||
NumSpeakers int32
|
||||
MinSpeakers int32
|
||||
MaxSpeakers int32
|
||||
ClusteringThreshold float32
|
||||
MinDurationOn float32
|
||||
MinDurationOff float32
|
||||
IncludeText bool
|
||||
}
|
||||
|
||||
func (r *DiarizationRequest) toProto(threads uint32) *proto.DiarizeRequest {
|
||||
return &proto.DiarizeRequest{
|
||||
Dst: r.Audio,
|
||||
Threads: threads,
|
||||
Language: r.Language,
|
||||
NumSpeakers: r.NumSpeakers,
|
||||
MinSpeakers: r.MinSpeakers,
|
||||
MaxSpeakers: r.MaxSpeakers,
|
||||
ClusteringThreshold: r.ClusteringThreshold,
|
||||
MinDurationOn: r.MinDurationOn,
|
||||
MinDurationOff: r.MinDurationOff,
|
||||
IncludeText: r.IncludeText,
|
||||
}
|
||||
}
|
||||
|
||||
func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
|
||||
if modelConfig.Backend == "" {
|
||||
return nil, fmt.Errorf("diarization: model %q has no backend set; supported backends include vibevoice-cpp and sherpa-onnx", modelConfig.Name)
|
||||
}
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
m, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if m == nil {
|
||||
return nil, fmt.Errorf("could not load diarization model")
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ModelDiarization runs the Diarize RPC against the configured backend
|
||||
// and returns a normalized schema.DiarizationResult.
|
||||
func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
|
||||
m, err := loadDiarizationModel(ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
threads := uint32(0)
|
||||
if modelConfig.Threads != nil {
|
||||
threads = uint32(*modelConfig.Threads)
|
||||
}
|
||||
|
||||
r, err := m.Diarize(context.Background(), req.toProto(threads))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return diarizationResultFromProto(r), nil
|
||||
}
|
||||
|
||||
// diarizationResultFromProto normalizes backend speaker labels to
|
||||
// "SPEAKER_NN" — the convention pyannote/RTTM tooling expects — while
|
||||
// keeping the original label available via the Speaker field. Each
|
||||
// distinct backend label gets its own normalized id, in first-seen order.
|
||||
func diarizationResultFromProto(r *proto.DiarizeResponse) *schema.DiarizationResult {
|
||||
if r == nil {
|
||||
return &schema.DiarizationResult{Segments: []schema.DiarizationSegment{}}
|
||||
}
|
||||
|
||||
out := &schema.DiarizationResult{
|
||||
Task: "diarize",
|
||||
Duration: float64(r.Duration),
|
||||
Language: r.Language,
|
||||
Segments: make([]schema.DiarizationSegment, 0, len(r.Segments)),
|
||||
}
|
||||
|
||||
type speakerStats struct {
|
||||
idx int
|
||||
duration float64
|
||||
segments int
|
||||
}
|
||||
stats := map[string]*speakerStats{}
|
||||
order := []string{}
|
||||
|
||||
for i, s := range r.Segments {
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
raw := s.Speaker
|
||||
if raw == "" {
|
||||
raw = "0"
|
||||
}
|
||||
st, ok := stats[raw]
|
||||
if !ok {
|
||||
st = &speakerStats{idx: len(order)}
|
||||
stats[raw] = st
|
||||
order = append(order, raw)
|
||||
}
|
||||
dur := float64(s.End) - float64(s.Start)
|
||||
if dur > 0 {
|
||||
st.duration += dur
|
||||
}
|
||||
st.segments++
|
||||
|
||||
out.Segments = append(out.Segments, schema.DiarizationSegment{
|
||||
Id: i,
|
||||
Speaker: fmt.Sprintf("SPEAKER_%02d", st.idx),
|
||||
Label: raw,
|
||||
Start: float64(s.Start),
|
||||
End: float64(s.End),
|
||||
Text: s.Text,
|
||||
})
|
||||
}
|
||||
|
||||
out.NumSpeakers = len(order)
|
||||
if out.NumSpeakers == 0 && r.NumSpeakers > 0 {
|
||||
out.NumSpeakers = int(r.NumSpeakers)
|
||||
}
|
||||
|
||||
out.Speakers = make([]schema.DiarizationSpeaker, 0, len(order))
|
||||
for _, raw := range order {
|
||||
st := stats[raw]
|
||||
out.Speakers = append(out.Speakers, schema.DiarizationSpeaker{
|
||||
Id: fmt.Sprintf("SPEAKER_%02d", st.idx),
|
||||
Label: raw,
|
||||
TotalSpeechDuration: st.duration,
|
||||
SegmentCount: st.segments,
|
||||
})
|
||||
}
|
||||
sort.SliceStable(out.Speakers, func(i, j int) bool {
|
||||
return out.Speakers[i].Id < out.Speakers[j].Id
|
||||
})
|
||||
|
||||
return out
|
||||
}
|
||||
76
core/backend/diarization_test.go
Normal file
76
core/backend/diarization_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("diarizationResultFromProto", func() {
|
||||
It("normalises raw backend speaker labels to SPEAKER_NN in first-seen order", func() {
|
||||
in := &proto.DiarizeResponse{
|
||||
Duration: 10.5,
|
||||
Language: "en",
|
||||
Segments: []*proto.DiarizeSegment{
|
||||
{Start: 0.0, End: 1.0, Speaker: "5", Text: "hi"},
|
||||
{Start: 1.0, End: 2.0, Speaker: "2"},
|
||||
{Start: 2.0, End: 3.5, Speaker: "5"},
|
||||
{Start: 3.5, End: 4.0, Speaker: ""}, // empty → coerced to "0"
|
||||
},
|
||||
}
|
||||
|
||||
got := diarizationResultFromProto(in)
|
||||
|
||||
Expect(got.Task).To(Equal("diarize"))
|
||||
Expect(got.NumSpeakers).To(Equal(3), "expected 3 distinct speakers (5, 2, 0)")
|
||||
Expect(got.Duration).To(BeEquivalentTo(10.5))
|
||||
Expect(got.Language).To(Equal("en"))
|
||||
Expect(got.Segments).To(HaveLen(4))
|
||||
|
||||
// First-seen-order normalisation: "5"→SPEAKER_00, "2"→SPEAKER_01, ""→SPEAKER_02
|
||||
want := []struct {
|
||||
speaker string
|
||||
label string
|
||||
}{
|
||||
{"SPEAKER_00", "5"},
|
||||
{"SPEAKER_01", "2"},
|
||||
{"SPEAKER_00", "5"},
|
||||
{"SPEAKER_02", "0"},
|
||||
}
|
||||
for i, w := range want {
|
||||
Expect(got.Segments[i].Speaker).To(Equal(w.speaker), "seg[%d].speaker", i)
|
||||
Expect(got.Segments[i].Label).To(Equal(w.label), "seg[%d].label", i)
|
||||
}
|
||||
|
||||
// Per-speaker totals reflect cumulative speech duration and segment count.
|
||||
Expect(got.Speakers).To(HaveLen(3))
|
||||
byID := map[string]float64{}
|
||||
countByID := map[string]int{}
|
||||
for _, sp := range got.Speakers {
|
||||
byID[sp.Id] = sp.TotalSpeechDuration
|
||||
countByID[sp.Id] = sp.SegmentCount
|
||||
}
|
||||
Expect(byID["SPEAKER_00"]).To(BeNumerically("~", 2.5, 0.001), "1.0 + 1.5")
|
||||
Expect(byID["SPEAKER_01"]).To(BeNumerically("~", 1.0, 0.001))
|
||||
Expect(countByID["SPEAKER_00"]).To(Equal(2))
|
||||
Expect(countByID["SPEAKER_01"]).To(Equal(1))
|
||||
Expect(countByID["SPEAKER_02"]).To(Equal(1))
|
||||
})
|
||||
|
||||
It("returns a non-nil result with a non-nil segments slice for nil input", func() {
|
||||
got := diarizationResultFromProto(nil)
|
||||
Expect(got).ToNot(BeNil())
|
||||
Expect(got.Segments).ToNot(BeNil())
|
||||
Expect(got.Segments).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("keeps the backend speaker count when no segments are returned", func() {
|
||||
// Backend reports a non-zero NumSpeakers but no segments (early stop,
|
||||
// silence-only audio after VAD trim). Surface the backend's count.
|
||||
in := &proto.DiarizeResponse{NumSpeakers: 2, Duration: 5}
|
||||
got := diarizationResultFromProto(in)
|
||||
Expect(got.NumSpeakers).To(Equal(2))
|
||||
Expect(got.Segments).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -246,6 +246,14 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
opts.MMProj = filepath.Join(modelPath, c.MMProj)
|
||||
}
|
||||
|
||||
// Resolve draft_model against the models directory, mirroring the
|
||||
// handling of parameters.model and mmproj. Always joining (without an
|
||||
// IsAbs shortcut) prevents user-supplied configs from pointing the
|
||||
// backend at arbitrary host files via an absolute path.
|
||||
if c.DraftModel != "" {
|
||||
opts.DraftModel = filepath.Join(modelPath, c.DraftModel)
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
|
||||
@@ -179,11 +179,22 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
|
||||
Language: r.Language,
|
||||
Duration: float64(r.Duration),
|
||||
}
|
||||
|
||||
for _, s := range r.Segments {
|
||||
var tks []int
|
||||
for _, t := range s.Tokens {
|
||||
tks = append(tks, int(t))
|
||||
}
|
||||
var words []schema.TranscriptionWord
|
||||
for _, w := range s.Words {
|
||||
var word = schema.TranscriptionWord {
|
||||
Start: time.Duration(w.Start),
|
||||
End: time.Duration(w.End),
|
||||
Text: w.Text,
|
||||
}
|
||||
words = append(words, word)
|
||||
tr.Words = append(tr.Words, word)
|
||||
}
|
||||
tr.Segments = append(tr.Segments,
|
||||
schema.TranscriptionSegment{
|
||||
Text: s.Text,
|
||||
@@ -192,6 +203,7 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
|
||||
End: time.Duration(s.End),
|
||||
Tokens: tks,
|
||||
Speaker: s.Speaker,
|
||||
Words: words,
|
||||
})
|
||||
}
|
||||
return tr
|
||||
|
||||
@@ -81,14 +81,48 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
fmt.Println(schema.TranscriptionResponse(tr, t.ResponseFormat))
|
||||
case schema.TranscriptionResponseFormatJson:
|
||||
tr.Segments = nil
|
||||
tr.Words = nil
|
||||
fallthrough
|
||||
case schema.TranscriptionResponseFormatJsonVerbose:
|
||||
trs := schema.TranscriptionResultSeconds{
|
||||
Text: tr.Text,
|
||||
Language: tr.Language,
|
||||
Duration: tr.Duration,
|
||||
Words: []schema.TranscriptionWordSeconds{},
|
||||
Segments: []schema.TranscriptionSegmentSeconds{},
|
||||
}
|
||||
for _, word := range(tr.Words) {
|
||||
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
|
||||
Start: word.Start.Seconds(),
|
||||
End: word.End.Seconds(),
|
||||
Text: word.Text,
|
||||
})
|
||||
}
|
||||
for _, seg := range(tr.Segments) {
|
||||
segWords := []schema.TranscriptionWordSeconds{}
|
||||
for _, word := range(seg.Words) {
|
||||
segWords = append(segWords, schema.TranscriptionWordSeconds{
|
||||
Start: word.Start.Seconds(),
|
||||
End: word.End.Seconds(),
|
||||
Text: word.Text,
|
||||
})
|
||||
}
|
||||
trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
|
||||
Id: seg.Id,
|
||||
Start: seg.Start.Seconds(),
|
||||
End: seg.End.Seconds(),
|
||||
Text: seg.Text,
|
||||
Tokens: seg.Tokens,
|
||||
Speaker: seg.Speaker,
|
||||
Words: segWords,
|
||||
})
|
||||
}
|
||||
var mtr []byte
|
||||
var err error
|
||||
if t.PrettyPrint {
|
||||
mtr, err = json.MarshalIndent(tr, "", " ")
|
||||
mtr, err = json.MarshalIndent(trs, "", " ")
|
||||
} else {
|
||||
mtr, err = json.Marshal(tr)
|
||||
mtr, err = json.Marshal(trs)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -465,10 +465,20 @@ func (s *backendSupervisor) startBackend(backend, backendPath string) (string, e
|
||||
bp := s.processes[backend]
|
||||
s.mu.Unlock()
|
||||
|
||||
// Wait for the gRPC server to be ready
|
||||
// Wait for the gRPC server to be ready before reporting success.
|
||||
// Slow nodes (Jetson Orin doing first-boot CUDA init, large CGO libs)
|
||||
// can take 10-15s before the gRPC port accepts connections; the previous
|
||||
// 4s window made the worker reply Success on a not-yet-listening port,
|
||||
// which manifested upstream as "connect: connection refused" on the
|
||||
// frontend's first LoadModel dial.
|
||||
client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken)
|
||||
for range 20 {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
const (
|
||||
readinessPollInterval = 200 * time.Millisecond
|
||||
readinessTimeout = 30 * time.Second
|
||||
)
|
||||
deadline := time.Now().Add(readinessTimeout)
|
||||
for time.Now().Before(deadline) {
|
||||
time.Sleep(readinessPollInterval)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
if ok, _ := client.HealthCheck(ctx); ok {
|
||||
cancel()
|
||||
@@ -496,10 +506,23 @@ func (s *backendSupervisor) startBackend(backend, backendPath string) (string, e
|
||||
}
|
||||
}
|
||||
|
||||
// Log stderr to help diagnose why the backend isn't responding
|
||||
// Readiness deadline exceeded. Returning success here would leave the
|
||||
// frontend with an unbound address (it dials, gets ECONNREFUSED, and
|
||||
// the operator sees a misleading "connection refused" instead of the
|
||||
// real cause). Stop the half-started process, recycle the port, and
|
||||
// surface the failure to the caller with the backend's stderr tail.
|
||||
stderrTail := readLastLinesFromFile(proc.StderrPath(), 20)
|
||||
xlog.Warn("Backend gRPC server not ready after waiting, proceeding anyway", "backend", backend, "addr", clientAddr, "stderr", stderrTail)
|
||||
return clientAddr, nil
|
||||
xlog.Error("Backend gRPC server not ready before deadline; aborting install", "backend", backend, "addr", clientAddr, "timeout", readinessTimeout, "stderr", stderrTail)
|
||||
if killErr := proc.Stop(); killErr != nil {
|
||||
xlog.Warn("Failed to stop unready backend process", "backend", backend, "error", killErr)
|
||||
}
|
||||
s.mu.Lock()
|
||||
if cur, ok := s.processes[backend]; ok && cur == bp {
|
||||
delete(s.processes, backend)
|
||||
s.freePorts = append(s.freePorts, port)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("backend %s did not become ready within %s. Last stderr:\n%s", backend, readinessTimeout, stderrTail)
|
||||
}
|
||||
|
||||
// resolveProcessKeys turns a caller-supplied identifier into the set of
|
||||
|
||||
20
core/cli/worker/labels.go
Normal file
20
core/cli/worker/labels.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package worker
|
||||
|
||||
import "strings"
|
||||
|
||||
// ParseNodeLabels parses a comma-separated `k=v,k=v` string into a map.
|
||||
// Whitespace around keys, values, and pairs is trimmed; pairs without
|
||||
// `=` are skipped silently.
|
||||
func ParseNodeLabels(input string) map[string]string {
|
||||
labels := make(map[string]string)
|
||||
if input == "" {
|
||||
return labels
|
||||
}
|
||||
for _, pair := range strings.Split(input, ",") {
|
||||
pair = strings.TrimSpace(pair)
|
||||
if k, v, ok := strings.Cut(pair, "="); ok {
|
||||
labels[strings.TrimSpace(k)] = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return labels
|
||||
}
|
||||
@@ -8,8 +8,9 @@ type WorkerFlags struct {
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
|
||||
P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"`
|
||||
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
|
||||
MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
|
||||
P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
|
||||
P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"`
|
||||
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
|
||||
MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
|
||||
VLLMDistributed VLLMDistributed `cmd:"" name:"vllm" help:"Starts a vLLM data-parallel follower process. Multi-node DP for a single model: head runs the existing vllm backend with engine_args.data_parallel_size>1, followers run this command."`
|
||||
}
|
||||
|
||||
58
core/cli/worker/worker_backend_common.go
Normal file
58
core/cli/worker/worker_backend_common.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// findBackendPath resolves the directory containing a backend's run.sh,
|
||||
// installing the backend from the gallery if it isn't present.
|
||||
// `name` is the gallery entry name (for vLLM the meta entry "vllm"
|
||||
// resolves to a platform-specific package via capability lookup).
|
||||
func findBackendPath(name, galleries string, systemState *system.SystemState) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if backend, ok := backends.Get(name); ok {
|
||||
return runFileDir(backend.RunFile)
|
||||
}
|
||||
|
||||
ml := model.NewModelLoader(systemState)
|
||||
var gals []config.Gallery
|
||||
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, name, nil, true); err != nil {
|
||||
xlog.Error("backend not found, failed to install it", "name", name, "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
backends, err = gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
backend, ok := backends.Get(name)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%s backend not found after install", name)
|
||||
}
|
||||
return runFileDir(backend.RunFile)
|
||||
}
|
||||
|
||||
func runFileDir(runFile string) (string, error) {
|
||||
dir := filepath.Dir(runFile)
|
||||
if dir == "" {
|
||||
return "", errors.New("backend has no run.sh, install it first")
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
@@ -1,57 +1,16 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
const mlxDistributedGalleryName = "mlx-distributed"
|
||||
|
||||
// findMLXDistributedBackendPath finds or installs the mlx-distributed backend
|
||||
// and returns the directory containing run.sh.
|
||||
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backend, ok := backends.Get(mlxDistributedGalleryName)
|
||||
if !ok {
|
||||
ml := model.NewModelLoader(systemState)
|
||||
var gals []config.Gallery
|
||||
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
|
||||
xlog.Error("failed loading galleries", "error", err)
|
||||
return "", err
|
||||
}
|
||||
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true); err != nil {
|
||||
xlog.Error("mlx-distributed backend not found, failed to install it", "error", err)
|
||||
return "", err
|
||||
}
|
||||
// Re-fetch after install
|
||||
backends, err = gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
backend, ok = backends.Get(mlxDistributedGalleryName)
|
||||
if !ok {
|
||||
return "", errors.New("mlx-distributed backend not found after install")
|
||||
}
|
||||
}
|
||||
|
||||
backendPath := filepath.Dir(backend.RunFile)
|
||||
if backendPath == "" {
|
||||
return "", errors.New("mlx-distributed backend not found, install it first")
|
||||
}
|
||||
return backendPath, nil
|
||||
return findBackendPath(mlxDistributedGalleryName, galleries, systemState)
|
||||
}
|
||||
|
||||
// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend.
|
||||
|
||||
13
core/cli/worker/worker_suite_test.go
Normal file
13
core/cli/worker/worker_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestWorker(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Worker Suite")
|
||||
}
|
||||
221
core/cli/worker/worker_vllm.go
Normal file
221
core/cli/worker/worker_vllm.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// vLLMFollowerRoleLabel marks a node as a vLLM data-parallel follower.
|
||||
// Operators scope regular models away from these nodes via inverse
|
||||
// selectors like {"!node.role":"vllm-follower"}.
|
||||
const vLLMFollowerRoleLabel = "vllm-follower"
|
||||
|
||||
// VLLMDistributed runs a vLLM follower process for multi-node
|
||||
// data-parallel inference. The head runs LocalAI's existing single-
|
||||
// node vLLM gRPC backend with engine_args.data_parallel_size > 1;
|
||||
// followers run vanilla `vllm serve --headless ...` and speak ZMQ
|
||||
// directly to the head.
|
||||
//
|
||||
// The follower is operator-launched (no NATS / SmartRouter placement
|
||||
// in this iteration). When --register-to is set, the worker self-
|
||||
// registers as an agent-type node so it shows up in the admin UI; a
|
||||
// `node.role=vllm-follower` label discourages model placement on it.
|
||||
type VLLMDistributed struct {
|
||||
WorkerFlags `embed:""`
|
||||
|
||||
// Registration (optional). Without these the worker just runs vLLM
|
||||
// and exits — no UI visibility. With them set, the follower
|
||||
// registers as an agent-type node, heartbeats while vLLM is
|
||||
// running, and deregisters on shutdown.
|
||||
RegisterTo string `env:"LOCALAI_REGISTER_TO" help:"Frontend URL for self-registration. Empty = no registration." group:"registration"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
|
||||
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to vllm-<hostname>)" group:"registration"`
|
||||
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (node.role=vllm-follower is always added)" group:"registration"`
|
||||
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
|
||||
|
||||
// vLLM data-parallel placement. The head must advertise the same
|
||||
// data_parallel_size / data_parallel_rpc_port via its engine_args;
|
||||
// followers use --master-addr / --master-port to find it.
|
||||
Model string `arg:"" help:"HuggingFace model ID or local path (must match the head)"`
|
||||
DataParallelSize int `name:"data-parallel-size" env:"VLLM_DATA_PARALLEL_SIZE" required:"" help:"Total DP ranks across all nodes"`
|
||||
DataParallelSizeLocal int `name:"data-parallel-size-local" env:"VLLM_DATA_PARALLEL_SIZE_LOCAL" required:"" help:"DP ranks on this node"`
|
||||
StartRank int `name:"start-rank" env:"VLLM_DATA_PARALLEL_START_RANK" required:"" help:"Starting DP rank for this node (>0 for followers)"`
|
||||
MasterAddr string `name:"master-addr" env:"VLLM_DP_MASTER_ADDR" required:"" help:"Head node IP/hostname for DP RPC handshake"`
|
||||
MasterPort int `name:"master-port" env:"VLLM_DP_MASTER_PORT" required:"" help:"Head node DP RPC port"`
|
||||
Headless bool `env:"VLLM_HEADLESS" default:"true" negatable:"" help:"Headless follower mode (no API server)"`
|
||||
ExtraArgs []string `name:"vllm-arg" env:"VLLM_EXTRA_ARGS" help:"Additional CLI args passed verbatim to vllm serve (e.g. --tensor-parallel-size 2). May be repeated."`
|
||||
}
|
||||
|
||||
func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
||||
// Rank 0 is the head: it must serve the OpenAI API. --headless
|
||||
// disables that, so the combination is operator error and would
|
||||
// silently produce a cluster that can't accept requests.
|
||||
if r.Headless && r.StartRank == 0 {
|
||||
return fmt.Errorf("--start-rank 0 (head) cannot be --headless; the head serves the API")
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting system state: %w", err)
|
||||
}
|
||||
|
||||
backendPath, err := findBackendPath("vllm", r.BackendGalleries, systemState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot find vllm backend: %w", err)
|
||||
}
|
||||
|
||||
args := r.buildVLLMArgs()
|
||||
runSh := filepath.Join(backendPath, "run.sh")
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Self-register so the follower is visible in the admin UI. Done
|
||||
// before vLLM starts so an unreachable frontend fails fast rather
|
||||
// than after the GPU is already loaded.
|
||||
if r.RegisterTo != "" {
|
||||
regClient := &workerregistry.RegistrationClient{
|
||||
FrontendURL: r.RegisterTo,
|
||||
RegistrationToken: r.RegistrationToken,
|
||||
}
|
||||
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("registering with frontend: %w", regErr)
|
||||
}
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", r.RegisterTo, "role", "vllm-follower")
|
||||
|
||||
heartbeatInterval, _ := time.ParseDuration(r.HeartbeatInterval)
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, r.heartbeatBody)
|
||||
|
||||
defer regClient.GracefulDeregister(nodeID)
|
||||
}
|
||||
|
||||
xlog.Info("Starting vllm follower",
|
||||
"model", r.Model,
|
||||
"data-parallel-size", r.DataParallelSize,
|
||||
"data-parallel-size-local", r.DataParallelSizeLocal,
|
||||
"start-rank", r.StartRank,
|
||||
"master", fmt.Sprintf("%s:%d", r.MasterAddr, r.MasterPort),
|
||||
)
|
||||
|
||||
cmd := exec.CommandContext(shutdownCtx, runSh, args...)
|
||||
// VLLM_DP_* env vars are belt-and-braces alongside the explicit
|
||||
// CLI flags — vLLM honours both (vllm/envs.py:142-148).
|
||||
cmd.Env = append(os.Environ(),
|
||||
fmt.Sprintf("VLLM_DP_MASTER_IP=%s", r.MasterAddr),
|
||||
fmt.Sprintf("VLLM_DP_MASTER_PORT=%d", r.MasterPort),
|
||||
fmt.Sprintf("VLLM_DP_SIZE=%d", r.DataParallelSize),
|
||||
fmt.Sprintf("VLLM_DP_RANK=%d", r.StartRank),
|
||||
"VLLM_DP_RANK_LOCAL=0",
|
||||
)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Forward INT/TERM to vLLM so it gets a chance to clean up its ZMQ
|
||||
// sockets. exec.CommandContext kills with SIGKILL on cancellation,
|
||||
// which we want as a fallback only.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(sigCh)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting vllm: %w", err)
|
||||
}
|
||||
|
||||
waitErr := make(chan error, 1)
|
||||
go func() { waitErr <- cmd.Wait() }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case sig := <-sigCh:
|
||||
xlog.Info("Forwarding signal to vllm", "signal", sig)
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Signal(sig)
|
||||
}
|
||||
case err := <-waitErr:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildVLLMArgs assembles the vLLM CLI argv. Factored out for unit
|
||||
// testing — Run is hard to test without a real backend install.
|
||||
func (r *VLLMDistributed) buildVLLMArgs() []string {
|
||||
args := []string{"serve", r.Model}
|
||||
if r.Headless {
|
||||
args = append(args, "--headless")
|
||||
}
|
||||
args = append(args,
|
||||
"--data-parallel-size", strconv.Itoa(r.DataParallelSize),
|
||||
"--data-parallel-size-local", strconv.Itoa(r.DataParallelSizeLocal),
|
||||
"--data-parallel-start-rank", strconv.Itoa(r.StartRank),
|
||||
"--data-parallel-address", r.MasterAddr,
|
||||
"--data-parallel-rpc-port", strconv.Itoa(r.MasterPort),
|
||||
)
|
||||
args = append(args, r.ExtraArgs...)
|
||||
return args
|
||||
}
|
||||
|
||||
// registrationBody mirrors agent_worker.go's shape: agent-type nodes
|
||||
// don't need an address, which fits a follower that doesn't host any
|
||||
// LocalAI gRPC backends. The node.role label lets operators scope
|
||||
// regular model placement away from followers.
|
||||
func (r *VLLMDistributed) registrationBody() map[string]any {
|
||||
nodeName := r.NodeName
|
||||
if nodeName == "" {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
nodeName = fmt.Sprintf("vllm-follower-%d", os.Getpid())
|
||||
} else {
|
||||
nodeName = "vllm-" + hostname
|
||||
}
|
||||
}
|
||||
|
||||
totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
|
||||
gpuVendor, _ := xsysinfo.DetectGPUVendor()
|
||||
|
||||
body := map[string]any{
|
||||
"name": nodeName,
|
||||
"node_type": nodes.NodeTypeAgent,
|
||||
"total_vram": totalVRAM,
|
||||
"available_vram": totalVRAM,
|
||||
"gpu_vendor": gpuVendor,
|
||||
}
|
||||
if r.RegistrationToken != "" {
|
||||
body["token"] = r.RegistrationToken
|
||||
}
|
||||
|
||||
labels := ParseNodeLabels(r.NodeLabels)
|
||||
labels["node.role"] = vLLMFollowerRoleLabel
|
||||
body["labels"] = labels
|
||||
return body
|
||||
}
|
||||
|
||||
func (r *VLLMDistributed) heartbeatBody() map[string]any {
|
||||
body := map[string]any{}
|
||||
aggregate := xsysinfo.GetGPUAggregateInfo()
|
||||
if aggregate.TotalVRAM > 0 {
|
||||
body["available_vram"] = aggregate.FreeVRAM
|
||||
}
|
||||
return body
|
||||
}
|
||||
105
core/cli/worker/worker_vllm_test.go
Normal file
105
core/cli/worker/worker_vllm_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("VLLMDistributed", func() {
|
||||
Describe("buildVLLMArgs", func() {
|
||||
DescribeTable("produces the expected vLLM CLI argv",
|
||||
func(cmd VLLMDistributed, want []string) {
|
||||
Expect(cmd.buildVLLMArgs()).To(Equal(want))
|
||||
},
|
||||
Entry("headless follower with explicit master",
|
||||
VLLMDistributed{
|
||||
Model: "Qwen/Qwen3.5-1.5B",
|
||||
DataParallelSize: 4,
|
||||
DataParallelSizeLocal: 2,
|
||||
StartRank: 2,
|
||||
MasterAddr: "10.0.0.1",
|
||||
MasterPort: 32100,
|
||||
Headless: true,
|
||||
},
|
||||
[]string{
|
||||
"serve", "Qwen/Qwen3.5-1.5B",
|
||||
"--headless",
|
||||
"--data-parallel-size", "4",
|
||||
"--data-parallel-size-local", "2",
|
||||
"--data-parallel-start-rank", "2",
|
||||
"--data-parallel-address", "10.0.0.1",
|
||||
"--data-parallel-rpc-port", "32100",
|
||||
},
|
||||
),
|
||||
Entry("head-style invocation: rank 0, not headless",
|
||||
VLLMDistributed{
|
||||
Model: "moonshotai/Kimi-K2.6-Instruct",
|
||||
DataParallelSize: 8,
|
||||
DataParallelSizeLocal: 4,
|
||||
StartRank: 0,
|
||||
MasterAddr: "127.0.0.1",
|
||||
MasterPort: 32100,
|
||||
Headless: false,
|
||||
},
|
||||
[]string{
|
||||
"serve", "moonshotai/Kimi-K2.6-Instruct",
|
||||
"--data-parallel-size", "8",
|
||||
"--data-parallel-size-local", "4",
|
||||
"--data-parallel-start-rank", "0",
|
||||
"--data-parallel-address", "127.0.0.1",
|
||||
"--data-parallel-rpc-port", "32100",
|
||||
},
|
||||
),
|
||||
Entry("extra args appended verbatim",
|
||||
VLLMDistributed{
|
||||
Model: "Qwen/Qwen3.5-1.5B",
|
||||
DataParallelSize: 2,
|
||||
DataParallelSizeLocal: 1,
|
||||
StartRank: 1,
|
||||
MasterAddr: "head.local",
|
||||
MasterPort: 32100,
|
||||
Headless: true,
|
||||
ExtraArgs: []string{"--tensor-parallel-size", "2", "--enable-expert-parallel"},
|
||||
},
|
||||
[]string{
|
||||
"serve", "Qwen/Qwen3.5-1.5B",
|
||||
"--headless",
|
||||
"--data-parallel-size", "2",
|
||||
"--data-parallel-size-local", "1",
|
||||
"--data-parallel-start-rank", "1",
|
||||
"--data-parallel-address", "head.local",
|
||||
"--data-parallel-rpc-port", "32100",
|
||||
"--tensor-parallel-size", "2",
|
||||
"--enable-expert-parallel",
|
||||
},
|
||||
),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("registrationBody", func() {
|
||||
// Followers don't host LocalAI gRPC, so node_type must be "agent"
|
||||
// to bypass the address requirement on /api/node/register, and the
|
||||
// node.role label is the contract operators rely on to scope normal
|
||||
// model placement away from these nodes.
|
||||
It("registers as agent-type with the vllm-follower role label", func() {
|
||||
cmd := VLLMDistributed{
|
||||
NodeName: "test-follower",
|
||||
DataParallelSize: 4,
|
||||
DataParallelSizeLocal: 2,
|
||||
StartRank: 2,
|
||||
MasterAddr: "10.0.0.1",
|
||||
NodeLabels: "tier=fast,gpu.vendor=nvidia",
|
||||
}
|
||||
body := cmd.registrationBody()
|
||||
|
||||
Expect(body).To(HaveKeyWithValue("node_type", "agent"))
|
||||
Expect(body).To(HaveKeyWithValue("name", "test-follower"))
|
||||
|
||||
labels, ok := body["labels"].(map[string]string)
|
||||
Expect(ok).To(BeTrue(), "labels must be map[string]string")
|
||||
Expect(labels).To(HaveKeyWithValue("node.role", "vllm-follower"))
|
||||
Expect(labels).To(HaveKeyWithValue("tier", "fast"))
|
||||
Expect(labels).To(HaveKeyWithValue("gpu.vendor", "nvidia"))
|
||||
})
|
||||
})
|
||||
})
|
||||
480
core/config/backend_capabilities.go
Normal file
480
core/config/backend_capabilities.go
Normal file
@@ -0,0 +1,480 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Usecase name constants — the canonical string values used in gallery entries,
|
||||
// model configs (known_usecases), and UsecaseInfoMap keys.
|
||||
const (
|
||||
UsecaseChat = "chat"
|
||||
UsecaseCompletion = "completion"
|
||||
UsecaseEdit = "edit"
|
||||
UsecaseVision = "vision"
|
||||
UsecaseEmbeddings = "embeddings"
|
||||
UsecaseTokenize = "tokenize"
|
||||
UsecaseImage = "image"
|
||||
UsecaseVideo = "video"
|
||||
UsecaseTranscript = "transcript"
|
||||
UsecaseTTS = "tts"
|
||||
UsecaseSoundGeneration = "sound_generation"
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
)
|
||||
|
||||
// GRPCMethod identifies a Backend service RPC from backend.proto.
|
||||
type GRPCMethod string
|
||||
|
||||
const (
|
||||
MethodPredict GRPCMethod = "Predict"
|
||||
MethodPredictStream GRPCMethod = "PredictStream"
|
||||
MethodEmbedding GRPCMethod = "Embedding"
|
||||
MethodGenerateImage GRPCMethod = "GenerateImage"
|
||||
MethodGenerateVideo GRPCMethod = "GenerateVideo"
|
||||
MethodAudioTranscription GRPCMethod = "AudioTranscription"
|
||||
MethodTTS GRPCMethod = "TTS"
|
||||
MethodTTSStream GRPCMethod = "TTSStream"
|
||||
MethodSoundGeneration GRPCMethod = "SoundGeneration"
|
||||
MethodTokenizeString GRPCMethod = "TokenizeString"
|
||||
MethodDetect GRPCMethod = "Detect"
|
||||
MethodRerank GRPCMethod = "Rerank"
|
||||
MethodVAD GRPCMethod = "VAD"
|
||||
MethodAudioTransform GRPCMethod = "AudioTransform"
|
||||
MethodDiarize GRPCMethod = "Diarize"
|
||||
)
|
||||
|
||||
// UsecaseInfo describes a single known_usecase value and how it maps
|
||||
// to the gRPC backend API.
|
||||
type UsecaseInfo struct {
|
||||
// Flag is the ModelConfigUsecase bitmask value.
|
||||
Flag ModelConfigUsecase
|
||||
// GRPCMethod is the primary Backend service RPC this usecase maps to.
|
||||
GRPCMethod GRPCMethod
|
||||
// IsModifier is true when this usecase doesn't map to its own gRPC RPC
|
||||
// but modifies how another RPC behaves (e.g., vision uses Predict with images).
|
||||
IsModifier bool
|
||||
// DependsOn names the usecase(s) this modifier requires (e.g., "chat").
|
||||
DependsOn string
|
||||
// Description is a human/LLM-readable explanation of what this usecase means.
|
||||
Description string
|
||||
}
|
||||
|
||||
// UsecaseInfoMap maps each known_usecase string to its gRPC and semantic info.
|
||||
var UsecaseInfoMap = map[string]UsecaseInfo{
|
||||
UsecaseChat: {
|
||||
Flag: FLAG_CHAT,
|
||||
GRPCMethod: MethodPredict,
|
||||
Description: "Conversational/instruction-following via the Predict RPC with chat templates.",
|
||||
},
|
||||
UsecaseCompletion: {
|
||||
Flag: FLAG_COMPLETION,
|
||||
GRPCMethod: MethodPredict,
|
||||
Description: "Text completion via the Predict RPC with a completion template.",
|
||||
},
|
||||
UsecaseEdit: {
|
||||
Flag: FLAG_EDIT,
|
||||
GRPCMethod: MethodPredict,
|
||||
Description: "Text editing via the Predict RPC with an edit template.",
|
||||
},
|
||||
UsecaseVision: {
|
||||
Flag: FLAG_VISION,
|
||||
GRPCMethod: MethodPredict,
|
||||
IsModifier: true,
|
||||
DependsOn: UsecaseChat,
|
||||
Description: "The model accepts images alongside text in the Predict RPC. For llama-cpp this requires an mmproj file.",
|
||||
},
|
||||
UsecaseEmbeddings: {
|
||||
Flag: FLAG_EMBEDDINGS,
|
||||
GRPCMethod: MethodEmbedding,
|
||||
Description: "Vector embedding generation via the Embedding RPC.",
|
||||
},
|
||||
UsecaseTokenize: {
|
||||
Flag: FLAG_TOKENIZE,
|
||||
GRPCMethod: MethodTokenizeString,
|
||||
Description: "Tokenization via the TokenizeString RPC without running inference.",
|
||||
},
|
||||
UsecaseImage: {
|
||||
Flag: FLAG_IMAGE,
|
||||
GRPCMethod: MethodGenerateImage,
|
||||
Description: "Image generation via the GenerateImage RPC (Stable Diffusion, Flux, etc.).",
|
||||
},
|
||||
UsecaseVideo: {
|
||||
Flag: FLAG_VIDEO,
|
||||
GRPCMethod: MethodGenerateVideo,
|
||||
Description: "Video generation via the GenerateVideo RPC.",
|
||||
},
|
||||
UsecaseTranscript: {
|
||||
Flag: FLAG_TRANSCRIPT,
|
||||
GRPCMethod: MethodAudioTranscription,
|
||||
Description: "Speech-to-text via the AudioTranscription RPC.",
|
||||
},
|
||||
UsecaseTTS: {
|
||||
Flag: FLAG_TTS,
|
||||
GRPCMethod: MethodTTS,
|
||||
Description: "Text-to-speech via the TTS RPC.",
|
||||
},
|
||||
UsecaseSoundGeneration: {
|
||||
Flag: FLAG_SOUND_GENERATION,
|
||||
GRPCMethod: MethodSoundGeneration,
|
||||
Description: "Music/sound generation via the SoundGeneration RPC (not speech).",
|
||||
},
|
||||
UsecaseRerank: {
|
||||
Flag: FLAG_RERANK,
|
||||
GRPCMethod: MethodRerank,
|
||||
Description: "Document reranking via the Rerank RPC.",
|
||||
},
|
||||
UsecaseDetection: {
|
||||
Flag: FLAG_DETECTION,
|
||||
GRPCMethod: MethodDetect,
|
||||
Description: "Object detection via the Detect RPC with bounding boxes.",
|
||||
},
|
||||
UsecaseVAD: {
|
||||
Flag: FLAG_VAD,
|
||||
GRPCMethod: MethodVAD,
|
||||
Description: "Voice activity detection via the VAD RPC.",
|
||||
},
|
||||
UsecaseAudioTransform: {
|
||||
Flag: FLAG_AUDIO_TRANSFORM,
|
||||
GRPCMethod: MethodAudioTransform,
|
||||
Description: "Audio-in / audio-out transformations (echo cancellation, noise suppression, dereverberation, voice conversion) via the AudioTransform RPC.",
|
||||
},
|
||||
UsecaseDiarization: {
|
||||
Flag: FLAG_DIARIZATION,
|
||||
GRPCMethod: MethodDiarize,
|
||||
Description: "Speaker diarization (who-spoke-when, per-speaker segments) via the Diarize RPC.",
|
||||
},
|
||||
}
|
||||
|
||||
// BackendCapability describes which gRPC methods and usecases a backend supports.
|
||||
// Derived from reviewing actual implementations in backend/go/ and backend/python/.
|
||||
type BackendCapability struct {
|
||||
// GRPCMethods lists the Backend service RPCs this backend implements.
|
||||
GRPCMethods []GRPCMethod
|
||||
// PossibleUsecases lists all usecase strings this backend can support.
|
||||
PossibleUsecases []string
|
||||
// DefaultUsecases lists the conservative safe defaults.
|
||||
DefaultUsecases []string
|
||||
// AcceptsImages indicates multimodal image input in Predict.
|
||||
AcceptsImages bool
|
||||
// AcceptsVideos indicates multimodal video input in Predict.
|
||||
AcceptsVideos bool
|
||||
// AcceptsAudios indicates multimodal audio input in Predict.
|
||||
AcceptsAudios bool
|
||||
// Description is a human-readable summary of the backend.
|
||||
Description string
|
||||
}
|
||||
|
||||
// BackendCapabilities maps each backend name (as used in model configs and gallery
|
||||
// entries) to its verified capabilities. This is the single source of truth for
|
||||
// what each backend supports.
|
||||
//
|
||||
// Backend names use hyphens (e.g., "llama-cpp") matching the gallery convention.
|
||||
// Use NormalizeBackendName() for names with dots (e.g., "llama.cpp").
|
||||
var BackendCapabilities = map[string]BackendCapability{
|
||||
// --- LLM / text generation backends ---
|
||||
"llama-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding, MethodTokenizeString},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEdit, UsecaseEmbeddings, UsecaseTokenize, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
AcceptsImages: true, // requires mmproj
|
||||
Description: "llama.cpp GGUF models — LLM inference with optional vision via mmproj",
|
||||
},
|
||||
"vllm": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
AcceptsImages: true,
|
||||
AcceptsVideos: true,
|
||||
Description: "vLLM engine — high-throughput LLM serving with optional multimodal",
|
||||
},
|
||||
"vllm-omni": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodGenerateImage, MethodGenerateVideo, MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseImage, UsecaseVideo, UsecaseTTS, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
AcceptsImages: true,
|
||||
AcceptsVideos: true,
|
||||
AcceptsAudios: true,
|
||||
Description: "vLLM omni-modal — supports text, image, video generation and TTS",
|
||||
},
|
||||
"transformers": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding, MethodTTS, MethodSoundGeneration},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseTTS, UsecaseSoundGeneration},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
Description: "HuggingFace transformers — general-purpose Python inference",
|
||||
},
|
||||
"mlx": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
Description: "Apple MLX framework — optimized for Apple Silicon",
|
||||
},
|
||||
"mlx-distributed": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
Description: "MLX distributed inference across multiple Apple Silicon devices",
|
||||
},
|
||||
"mlx-vlm": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision},
|
||||
DefaultUsecases: []string{UsecaseChat, UsecaseVision},
|
||||
AcceptsImages: true,
|
||||
AcceptsAudios: true,
|
||||
Description: "MLX vision-language models with multimodal input",
|
||||
},
|
||||
"mlx-audio": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseChat},
|
||||
Description: "MLX audio models — text generation and TTS",
|
||||
},
|
||||
|
||||
// --- Image/video generation backends ---
|
||||
"diffusers": {
|
||||
GRPCMethods: []GRPCMethod{MethodGenerateImage, MethodGenerateVideo},
|
||||
PossibleUsecases: []string{UsecaseImage, UsecaseVideo},
|
||||
DefaultUsecases: []string{UsecaseImage},
|
||||
Description: "HuggingFace diffusers — Stable Diffusion, Flux, video generation",
|
||||
},
|
||||
"stablediffusion": {
|
||||
GRPCMethods: []GRPCMethod{MethodGenerateImage},
|
||||
PossibleUsecases: []string{UsecaseImage},
|
||||
DefaultUsecases: []string{UsecaseImage},
|
||||
Description: "Stable Diffusion native backend",
|
||||
},
|
||||
"stablediffusion-ggml": {
|
||||
GRPCMethods: []GRPCMethod{MethodGenerateImage},
|
||||
PossibleUsecases: []string{UsecaseImage},
|
||||
DefaultUsecases: []string{UsecaseImage},
|
||||
Description: "Stable Diffusion via GGML quantized models",
|
||||
},
|
||||
|
||||
// --- Speech-to-text backends ---
|
||||
"whisper": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseVAD},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "OpenAI Whisper — speech recognition and voice activity detection",
|
||||
},
|
||||
"faster-whisper": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "CTranslate2-accelerated Whisper for faster transcription",
|
||||
},
|
||||
"whisperx": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "WhisperX — Whisper with word-level timestamps and speaker diarization",
|
||||
},
|
||||
"moonshine": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "Moonshine speech recognition",
|
||||
},
|
||||
"nemo": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "NVIDIA NeMo speech recognition",
|
||||
},
|
||||
"qwen-asr": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "Qwen automatic speech recognition",
|
||||
},
|
||||
"voxtral": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription},
|
||||
PossibleUsecases: []string{UsecaseTranscript},
|
||||
DefaultUsecases: []string{UsecaseTranscript},
|
||||
Description: "Voxtral speech recognition",
|
||||
},
|
||||
"vibevoice": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS},
|
||||
Description: "VibeVoice — bidirectional speech (transcription and synthesis)",
|
||||
},
|
||||
|
||||
// --- TTS backends ---
|
||||
"piper": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Piper — fast neural TTS optimized for Raspberry Pi",
|
||||
},
|
||||
"kokoro": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Kokoro TTS",
|
||||
},
|
||||
"coqui": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Coqui TTS — multi-speaker neural synthesis",
|
||||
},
|
||||
"kitten-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Kitten TTS",
|
||||
},
|
||||
"outetts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "OuteTTS",
|
||||
},
|
||||
"pocket-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Pocket TTS — lightweight text-to-speech",
|
||||
},
|
||||
"qwen-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Qwen TTS",
|
||||
},
|
||||
"faster-qwen3-tts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Faster Qwen3 TTS — accelerated Qwen TTS",
|
||||
},
|
||||
"fish-speech": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Fish Speech TTS",
|
||||
},
|
||||
"neutts": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "NeuTTS — neural text-to-speech",
|
||||
},
|
||||
"chatterbox": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "Chatterbox TTS",
|
||||
},
|
||||
"voxcpm": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS, MethodTTSStream},
|
||||
PossibleUsecases: []string{UsecaseTTS},
|
||||
DefaultUsecases: []string{UsecaseTTS},
|
||||
Description: "VoxCPM TTS with streaming support",
|
||||
},
|
||||
|
||||
// --- Sound generation backends ---
|
||||
"ace-step": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS, MethodSoundGeneration},
|
||||
PossibleUsecases: []string{UsecaseTTS, UsecaseSoundGeneration},
|
||||
DefaultUsecases: []string{UsecaseSoundGeneration},
|
||||
Description: "ACE-Step — music and sound generation",
|
||||
},
|
||||
"acestep-cpp": {
|
||||
GRPCMethods: []GRPCMethod{MethodSoundGeneration},
|
||||
PossibleUsecases: []string{UsecaseSoundGeneration},
|
||||
DefaultUsecases: []string{UsecaseSoundGeneration},
|
||||
Description: "ACE-Step C++ — native sound generation",
|
||||
},
|
||||
"transformers-musicgen": {
|
||||
GRPCMethods: []GRPCMethod{MethodTTS, MethodSoundGeneration},
|
||||
PossibleUsecases: []string{UsecaseTTS, UsecaseSoundGeneration},
|
||||
DefaultUsecases: []string{UsecaseSoundGeneration},
|
||||
Description: "Meta MusicGen via transformers — music generation from text",
|
||||
},
|
||||
|
||||
// --- Audio transform backends ---
|
||||
"localvqe": {
|
||||
GRPCMethods: []GRPCMethod{MethodAudioTransform},
|
||||
PossibleUsecases: []string{UsecaseAudioTransform},
|
||||
DefaultUsecases: []string{UsecaseAudioTransform},
|
||||
Description: "LocalVQE — joint AEC, noise suppression, and dereverberation for 16 kHz mono speech",
|
||||
},
|
||||
|
||||
// --- Utility backends ---
|
||||
"rerankers": {
|
||||
GRPCMethods: []GRPCMethod{MethodRerank},
|
||||
PossibleUsecases: []string{UsecaseRerank},
|
||||
DefaultUsecases: []string{UsecaseRerank},
|
||||
Description: "Cross-encoder reranking models",
|
||||
},
|
||||
"rfdetr": {
|
||||
GRPCMethods: []GRPCMethod{MethodDetect},
|
||||
PossibleUsecases: []string{UsecaseDetection},
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR object detection",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
DefaultUsecases: []string{UsecaseVAD},
|
||||
Description: "Silero VAD — voice activity detection",
|
||||
},
|
||||
}
|
||||
|
||||
// NormalizeBackendName converts backend names to the canonical hyphenated form
|
||||
// used in gallery entries (e.g., "llama.cpp" → "llama-cpp").
|
||||
func NormalizeBackendName(backend string) string {
|
||||
return strings.ReplaceAll(backend, ".", "-")
|
||||
}
|
||||
|
||||
// GetBackendCapability returns the capability info for a backend, or nil if unknown.
|
||||
// Handles backend name normalization.
|
||||
func GetBackendCapability(backend string) *BackendCapability {
|
||||
if cap, ok := BackendCapabilities[NormalizeBackendName(backend)]; ok {
|
||||
return &cap
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PossibleUsecasesForBackend returns all usecases a backend can support.
|
||||
// Returns nil if the backend is unknown.
|
||||
func PossibleUsecasesForBackend(backend string) []string {
|
||||
if cap := GetBackendCapability(backend); cap != nil {
|
||||
return cap.PossibleUsecases
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultUsecasesForBackend returns the conservative default usecases.
|
||||
// Returns nil if the backend is unknown.
|
||||
func DefaultUsecasesForBackendCap(backend string) []string {
|
||||
if cap := GetBackendCapability(backend); cap != nil {
|
||||
return cap.DefaultUsecases
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidUsecaseForBackend checks whether a usecase is in a backend's possible set.
|
||||
// Returns true for unknown backends (permissive fallback).
|
||||
func IsValidUsecaseForBackend(backend, usecase string) bool {
|
||||
cap := GetBackendCapability(backend)
|
||||
if cap == nil {
|
||||
return true // unknown backend — don't restrict
|
||||
}
|
||||
return slices.Contains(cap.PossibleUsecases, usecase)
|
||||
}
|
||||
|
||||
// AllBackendNames returns a sorted list of all known backend names.
|
||||
func AllBackendNames() []string {
|
||||
names := make([]string, 0, len(BackendCapabilities))
|
||||
for name := range BackendCapabilities {
|
||||
names = append(names, name)
|
||||
}
|
||||
slices.Sort(names)
|
||||
return names
|
||||
}
|
||||
95
core/config/backend_capabilities_test.go
Normal file
95
core/config/backend_capabilities_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BackendCapabilities", func() {
|
||||
It("every backend declares possible/default usecases and gRPC methods", func() {
|
||||
for name, cap := range BackendCapabilities {
|
||||
Expect(cap.PossibleUsecases).NotTo(BeEmpty(), "backend %q has no possible usecases", name)
|
||||
Expect(cap.DefaultUsecases).NotTo(BeEmpty(), "backend %q has no default usecases", name)
|
||||
Expect(cap.GRPCMethods).NotTo(BeEmpty(), "backend %q has no gRPC methods", name)
|
||||
}
|
||||
})
|
||||
|
||||
It("default usecases are a subset of possible usecases", func() {
|
||||
for name, cap := range BackendCapabilities {
|
||||
for _, d := range cap.DefaultUsecases {
|
||||
Expect(cap.PossibleUsecases).To(ContainElement(d), "backend %q: default %q not in possible %v", name, d, cap.PossibleUsecases)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("every backend's possible usecases map to a known FLAG_*", func() {
|
||||
allFlags := GetAllModelConfigUsecases()
|
||||
for name, cap := range BackendCapabilities {
|
||||
for _, u := range cap.PossibleUsecases {
|
||||
info, ok := UsecaseInfoMap[u]
|
||||
Expect(ok).To(BeTrue(), "backend %q: usecase %q not in UsecaseInfoMap", name, u)
|
||||
flagName := "FLAG_" + strings.ToUpper(u)
|
||||
if _, ok := allFlags[flagName]; ok {
|
||||
continue
|
||||
}
|
||||
// Some usecase names don't transform exactly to FLAG_<UPPER>; fall back to flag value lookup.
|
||||
found := false
|
||||
for _, flag := range allFlags {
|
||||
if flag == info.Flag {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).To(BeTrue(), "backend %q: usecase %q flag %d not in GetAllModelConfigUsecases", name, u, info.Flag)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("every UsecaseInfoMap entry has a non-zero flag and a gRPC method", func() {
|
||||
for name, info := range UsecaseInfoMap {
|
||||
Expect(info.Flag).NotTo(Equal(FLAG_ANY), "usecase %q has FLAG_ANY (zero) — should have a real flag", name)
|
||||
Expect(info.GRPCMethod).NotTo(BeEmpty(), "usecase %q has no gRPC method", name)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("GetBackendCapability", func() {
|
||||
It("returns the capability for a known backend", func() {
|
||||
cap := GetBackendCapability("llama-cpp")
|
||||
Expect(cap).NotTo(BeNil())
|
||||
Expect(cap.PossibleUsecases).To(ContainElement("chat"))
|
||||
})
|
||||
|
||||
It("normalizes hyphenated names so llama.cpp resolves to llama-cpp", func() {
|
||||
Expect(GetBackendCapability("llama.cpp")).NotTo(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil for unknown backends", func() {
|
||||
Expect(GetBackendCapability("nonexistent")).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("IsValidUsecaseForBackend", func() {
|
||||
It("accepts a backend's declared usecases", func() {
|
||||
Expect(IsValidUsecaseForBackend("piper", "tts")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects usecases outside a backend's possible set", func() {
|
||||
Expect(IsValidUsecaseForBackend("piper", "chat")).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is permissive for unknown backends", func() {
|
||||
Expect(IsValidUsecaseForBackend("unknown", "anything")).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("AllBackendNames", func() {
|
||||
It("returns 30+ backends in sorted order", func() {
|
||||
names := AllBackendNames()
|
||||
Expect(len(names)).To(BeNumerically(">=", 30))
|
||||
Expect(slices.IsSorted(names)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -87,6 +87,11 @@ type ModelConfig struct {
|
||||
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
|
||||
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`
|
||||
|
||||
// ConcurrencyGroups declares per-node mutual-exclusion groups: the model
|
||||
// cannot be loaded alongside another model that shares any group name.
|
||||
// See docs/content/advanced/vram-management.md for usage.
|
||||
ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"`
|
||||
|
||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||
|
||||
@@ -587,6 +592,28 @@ func (c *ModelConfig) IsPinned() bool {
|
||||
return c.Pinned != nil && *c.Pinned
|
||||
}
|
||||
|
||||
// GetConcurrencyGroups returns the model's concurrency groups, normalized:
|
||||
// trimmed of whitespace, empty entries dropped, deduped. Returns nil when no
|
||||
// effective groups remain. The result is a fresh slice; the caller may
|
||||
// mutate it without affecting the config.
|
||||
func (c *ModelConfig) GetConcurrencyGroups() []string {
|
||||
if len(c.ConcurrencyGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(c.ConcurrencyGroups))
|
||||
for _, g := range c.ConcurrencyGroups {
|
||||
g = strings.TrimSpace(g)
|
||||
if g == "" || slices.Contains(out, g) {
|
||||
continue
|
||||
}
|
||||
out = append(out, g)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type ModelConfigUsecase int
|
||||
|
||||
const (
|
||||
@@ -603,14 +630,45 @@ const (
|
||||
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b100000000000000
|
||||
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
|
||||
FLAG_VISION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b100000000000000
|
||||
FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b1000000000000000
|
||||
FLAG_AUDIO_TRANSFORM ModelConfigUsecase = 0b10000000000000000
|
||||
FLAG_DIARIZATION ModelConfigUsecase = 0b100000000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
)
|
||||
|
||||
// ModalityGroups defines groups of usecases that belong to the same modality.
|
||||
// Flags within the same group are NOT orthogonal (e.g., chat and completion are
|
||||
// both text/language). A model is multimodal when its usecases span 2+ groups.
|
||||
var ModalityGroups = []ModelConfigUsecase{
|
||||
FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language
|
||||
FLAG_VISION | FLAG_DETECTION, // visual understanding
|
||||
FLAG_TRANSCRIPT, // speech input
|
||||
FLAG_TTS | FLAG_SOUND_GENERATION, // audio output
|
||||
FLAG_AUDIO_TRANSFORM, // audio in/out transforms
|
||||
FLAG_IMAGE | FLAG_VIDEO, // visual generation
|
||||
}
|
||||
|
||||
// IsMultimodal returns true if the given usecases span two or more orthogonal
|
||||
// modality groups. For example chat+vision is multimodal, but chat+completion
|
||||
// is not (both belong to the text/language group).
|
||||
func IsMultimodal(usecases ModelConfigUsecase) bool {
|
||||
groupCount := 0
|
||||
for _, group := range ModalityGroups {
|
||||
if usecases&group != 0 {
|
||||
groupCount++
|
||||
if groupCount >= 2 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
return map[string]ModelConfigUsecase{
|
||||
// Note: FLAG_ANY is intentionally excluded from this map
|
||||
@@ -628,9 +686,12 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
"FLAG_VAD": FLAG_VAD,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
"FLAG_VISION": FLAG_VISION,
|
||||
"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
|
||||
"FLAG_SPEAKER_RECOGNITION": FLAG_SPEAKER_RECOGNITION,
|
||||
"FLAG_AUDIO_TRANSFORM": FLAG_AUDIO_TRANSFORM,
|
||||
"FLAG_DIARIZATION": FLAG_DIARIZATION,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -768,6 +829,13 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_AUDIO_TRANSFORM) == FLAG_AUDIO_TRANSFORM {
|
||||
audioTransformBackends := []string{"localvqe"}
|
||||
if !slices.Contains(audioTransformBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
|
||||
soundGenBackends := []string{"transformers-musicgen", "ace-step", "acestep-cpp", "mock-backend"}
|
||||
if !slices.Contains(soundGenBackends, c.Backend) {
|
||||
@@ -788,6 +856,16 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_DIARIZATION) == FLAG_DIARIZATION {
|
||||
// vibevoice-cpp emits speaker-labelled segments natively from its
|
||||
// ASR pass; sherpa-onnx pipes pyannote segmentation + speaker
|
||||
// embeddings + clustering. Both surface as a Diarize gRPC.
|
||||
diarizationBackends := []string{"vibevoice-cpp", "sherpa-onnx"}
|
||||
if !slices.Contains(diarizationBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -249,6 +249,40 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
|
||||
delete(bcl.configs, m)
|
||||
}
|
||||
|
||||
// GetModelsConflictingWith returns the names of every other configured (and
|
||||
// not-disabled) model that shares at least one concurrency group with the
|
||||
// named model. Returns nil if the named model has no groups, is unknown, or
|
||||
// has no peers in any of its groups. The result excludes the queried name.
|
||||
func (bcl *ModelConfigLoader) GetModelsConflictingWith(name string) []string {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
target, ok := bcl.configs[name]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
targetGroups := target.GetConcurrencyGroups()
|
||||
if len(targetGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for n, cfg := range bcl.configs {
|
||||
if n == name || cfg.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
other := cfg.GetConcurrencyGroups()
|
||||
if len(other) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, g := range targetGroups {
|
||||
if slices.Contains(other, g) {
|
||||
conflicts = append(conflicts, n)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// UpdateModelConfig updates an existing model config in the loader.
|
||||
// This is useful for updating runtime-detected properties like thinking support.
|
||||
func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) {
|
||||
|
||||
63
core/config/model_config_loader_test.go
Normal file
63
core/config/model_config_loader_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ModelConfigLoader.GetModelsConflictingWith", func() {
|
||||
var bcl *ModelConfigLoader
|
||||
|
||||
BeforeEach(func() {
|
||||
bcl = NewModelConfigLoader("/tmp/conflict-test-models")
|
||||
})
|
||||
|
||||
insert := func(cfg ModelConfig) {
|
||||
bcl.Lock()
|
||||
bcl.configs[cfg.Name] = cfg
|
||||
bcl.Unlock()
|
||||
}
|
||||
|
||||
It("returns nil when the named model has no groups", func() {
|
||||
insert(ModelConfig{Name: "loner"})
|
||||
Expect(bcl.GetModelsConflictingWith("loner")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when the named model is unknown", func() {
|
||||
Expect(bcl.GetModelsConflictingWith("ghost")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when no other model shares a group", func() {
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"vision"}})
|
||||
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns models that share at least one group", func() {
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "c", ConcurrencyGroups: []string{"vision"}})
|
||||
insert(ModelConfig{Name: "d", ConcurrencyGroups: []string{"heavy", "vision"}})
|
||||
|
||||
conflicts := bcl.GetModelsConflictingWith("a")
|
||||
Expect(conflicts).To(ConsistOf("b", "d"))
|
||||
})
|
||||
|
||||
It("never lists the queried model itself", func() {
|
||||
insert(ModelConfig{Name: "self", ConcurrencyGroups: []string{"heavy"}})
|
||||
Expect(bcl.GetModelsConflictingWith("self")).To(BeNil())
|
||||
})
|
||||
|
||||
It("ignores disabled conflicting models", func() {
|
||||
disabled := true
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled})
|
||||
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
|
||||
})
|
||||
|
||||
It("normalizes groups so whitespace and duplicates do not break overlap", func() {
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{" heavy "}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy", "heavy"}})
|
||||
Expect(bcl.GetModelsConflictingWith("a")).To(ConsistOf("b"))
|
||||
})
|
||||
})
|
||||
@@ -264,4 +264,53 @@ mcp:
|
||||
Expect(err).To(BeNil())
|
||||
Expect(valid).To(BeTrue())
|
||||
})
|
||||
Context("ConcurrencyGroups", func() {
|
||||
It("returns nil when no groups are configured", func() {
|
||||
cfg := &ModelConfig{Name: "no-groups"}
|
||||
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
|
||||
})
|
||||
It("returns nil when all entries are blank", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "blanks",
|
||||
ConcurrencyGroups: []string{"", " ", "\t"},
|
||||
}
|
||||
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
|
||||
})
|
||||
It("trims whitespace, drops empty entries, and dedupes", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "messy",
|
||||
ConcurrencyGroups: []string{" vram-heavy ", "", "vram-heavy", "vision", " vision "},
|
||||
}
|
||||
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "vision"}))
|
||||
})
|
||||
It("returns a defensive copy", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "copy",
|
||||
ConcurrencyGroups: []string{"heavy"},
|
||||
}
|
||||
got := cfg.GetConcurrencyGroups()
|
||||
got[0] = "tampered"
|
||||
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"heavy"}))
|
||||
})
|
||||
It("parses concurrency_groups from YAML", func() {
|
||||
tmp, err := os.CreateTemp("", "concgroups.yaml")
|
||||
Expect(err).To(BeNil())
|
||||
defer func() { _ = os.Remove(tmp.Name()) }()
|
||||
_, err = tmp.WriteString(
|
||||
`name: heavy-a
|
||||
backend: llama-cpp
|
||||
parameters:
|
||||
model: heavy-a.gguf
|
||||
concurrency_groups:
|
||||
- vram-heavy
|
||||
- "120b"
|
||||
`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
configs, err := readModelConfigsFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(configs).To(HaveLen(1))
|
||||
Expect(configs[0].ConcurrencyGroups).To(Equal([]string{"vram-heavy", "120b"}))
|
||||
Expect(configs[0].GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "120b"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lithammer/fuzzysearch/fuzzy"
|
||||
@@ -92,6 +94,34 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
|
||||
return filteredModels
|
||||
}
|
||||
|
||||
// FilterGalleryModelsByUsecase returns models whose known_usecases include all
|
||||
// the bits set in usecase. For example, passing FLAG_CHAT matches any model
|
||||
// with the chat usecase; passing FLAG_CHAT|FLAG_VISION matches only models
|
||||
// that have both.
|
||||
func FilterGalleryModelsByUsecase(models GalleryElements[*GalleryModel], usecase config.ModelConfigUsecase) GalleryElements[*GalleryModel] {
|
||||
var filtered GalleryElements[*GalleryModel]
|
||||
for _, m := range models {
|
||||
u := m.GetKnownUsecases()
|
||||
if u != nil && (*u&usecase) == usecase {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// FilterGalleryModelsByMultimodal returns models whose known_usecases span two
|
||||
// or more orthogonal modality groups (e.g. chat+vision, tts+transcript).
|
||||
func FilterGalleryModelsByMultimodal(models GalleryElements[*GalleryModel]) GalleryElements[*GalleryModel] {
|
||||
var filtered GalleryElements[*GalleryModel]
|
||||
for _, m := range models {
|
||||
u := m.GetKnownUsecases()
|
||||
if u != nil && config.IsMultimodal(*u) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] {
|
||||
var filtered GalleryElements[T]
|
||||
for _, m := range gm {
|
||||
@@ -267,6 +297,77 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst
|
||||
return models, nil
|
||||
}
|
||||
|
||||
var (
|
||||
availableModelsMu sync.RWMutex
|
||||
availableModelsCache GalleryElements[*GalleryModel]
|
||||
refreshing atomic.Bool
|
||||
galleryGeneration atomic.Uint64
|
||||
)
|
||||
|
||||
// GalleryGeneration returns a counter that increments each time the gallery
|
||||
// model list is refreshed from upstream. VRAM estimation caches use this to
|
||||
// invalidate entries when the gallery data changes.
|
||||
func GalleryGeneration() uint64 { return galleryGeneration.Load() }
|
||||
|
||||
// AvailableGalleryModelsCached returns gallery models from an in-memory cache.
|
||||
// Local-only fields (installed status) are refreshed on every call. A background
|
||||
// goroutine is triggered to re-fetch the full model list (including network
|
||||
// calls) so subsequent requests pick up changes without blocking the caller.
|
||||
// The first call with an empty cache blocks until the initial load completes.
|
||||
func AvailableGalleryModelsCached(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) {
|
||||
availableModelsMu.RLock()
|
||||
cached := availableModelsCache
|
||||
availableModelsMu.RUnlock()
|
||||
|
||||
if cached != nil {
|
||||
// Refresh installed status under write lock to avoid races with
|
||||
// concurrent readers and the background refresh goroutine.
|
||||
availableModelsMu.Lock()
|
||||
for _, m := range cached {
|
||||
_, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", m.GetName())))
|
||||
m.SetInstalled(err == nil)
|
||||
}
|
||||
availableModelsMu.Unlock()
|
||||
// Trigger a background refresh if one is not already running.
|
||||
triggerGalleryRefresh(galleries, systemState)
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// No cache yet — must do a blocking load.
|
||||
models, err := AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
availableModelsMu.Lock()
|
||||
availableModelsCache = models
|
||||
galleryGeneration.Add(1)
|
||||
availableModelsMu.Unlock()
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// triggerGalleryRefresh starts a background goroutine that refreshes the
|
||||
// gallery model cache. Only one refresh runs at a time; concurrent calls
|
||||
// are no-ops.
|
||||
func triggerGalleryRefresh(galleries []config.Gallery, systemState *system.SystemState) {
|
||||
if !refreshing.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer refreshing.Store(false)
|
||||
models, err := AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
xlog.Error("background gallery refresh failed", "error", err)
|
||||
return
|
||||
}
|
||||
availableModelsMu.Lock()
|
||||
availableModelsCache = models
|
||||
galleryGeneration.Add(1)
|
||||
availableModelsMu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// List available backends
|
||||
func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
|
||||
return availableBackendsWithFilter(galleries, systemState, true)
|
||||
|
||||
@@ -581,4 +581,42 @@ var _ = Describe("Gallery", func() {
|
||||
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetKnownUsecases", func() {
|
||||
It("uses explicit known_usecases from overrides when present", func() {
|
||||
m := &GalleryModel{
|
||||
Metadata: Metadata{Backend: "stablediffusion-ggml"},
|
||||
Overrides: map[string]any{
|
||||
"known_usecases": []any{"chat"},
|
||||
},
|
||||
}
|
||||
u := m.GetKnownUsecases()
|
||||
Expect(u).NotTo(BeNil())
|
||||
// Override wins over the backend's image default.
|
||||
Expect(*u & config.FLAG_CHAT).To(Equal(config.FLAG_CHAT))
|
||||
Expect(*u & config.FLAG_IMAGE).To(Equal(config.ModelConfigUsecase(0)))
|
||||
})
|
||||
|
||||
It("falls back to backend defaults when no override is set", func() {
|
||||
m := &GalleryModel{Metadata: Metadata{Backend: "stablediffusion-ggml"}}
|
||||
u := m.GetKnownUsecases()
|
||||
Expect(u).NotTo(BeNil())
|
||||
Expect(*u & config.FLAG_IMAGE).To(Equal(config.FLAG_IMAGE))
|
||||
})
|
||||
|
||||
It("returns nil when neither overrides nor a known backend provide usecases", func() {
|
||||
m := &GalleryModel{}
|
||||
Expect(m.GetKnownUsecases()).To(BeNil())
|
||||
})
|
||||
|
||||
It("filters models without explicit known_usecases via backend defaults", func() {
|
||||
models := GalleryElements[*GalleryModel]{
|
||||
&GalleryModel{Metadata: Metadata{Name: "sd-model", Backend: "stablediffusion-ggml"}},
|
||||
&GalleryModel{Metadata: Metadata{Name: "whisper-model", Backend: "whisper"}},
|
||||
}
|
||||
filtered := FilterGalleryModelsByUsecase(models, config.FLAG_IMAGE)
|
||||
Expect(filtered).To(HaveLen(1))
|
||||
Expect(filtered[0].Name).To(Equal("sd-model"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -97,7 +97,7 @@ func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"image"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseImage},
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
|
||||
@@ -125,6 +125,10 @@ var defaultImporters = []Importer{
|
||||
&KittenTTSImporter{},
|
||||
&NeuTTSImporter{},
|
||||
&ChatterboxImporter{},
|
||||
// VibeVoiceCppImporter must precede VibeVoiceImporter — the older
|
||||
// Python-backend importer matches any repo name containing "vibevoice"
|
||||
// and would otherwise swallow the C++ port's GGUF bundles.
|
||||
&VibeVoiceCppImporter{},
|
||||
&VibeVoiceImporter{},
|
||||
&CoquiImporter{},
|
||||
// Image/Video (Batch 3)
|
||||
|
||||
@@ -135,7 +135,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
Options: []string{"use_jinja:true"},
|
||||
Backend: backend,
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
|
||||
@@ -45,7 +45,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
Options: []string{"use_jinja:true"},
|
||||
}
|
||||
cfg.Model = relPath(ggufFile)
|
||||
@@ -104,7 +104,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
}
|
||||
cfg.Model = baseModel
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
@@ -120,7 +120,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
}
|
||||
cfg.Model = baseModel
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
@@ -135,7 +135,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
}
|
||||
cfg.Model = relPath(dirPath)
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
|
||||
@@ -73,7 +73,7 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
|
||||
@@ -87,7 +87,7 @@ func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, err
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
|
||||
355
core/gallery/importers/vibevoice-cpp.go
Normal file
355
core/gallery/importers/vibevoice-cpp.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &VibeVoiceCppImporter{}
|
||||
|
||||
// VibeVoiceCppImporter recognises the GGUF bundle that the vibevoice.cpp
|
||||
// backend consumes — primary model file (vibevoice-realtime-*.gguf for TTS or
|
||||
// vibevoice-asr-*.gguf for ASR), a sibling tokenizer.gguf (always required),
|
||||
// and optional voice-*.gguf prompts for TTS voice cloning. Detection fires on
|
||||
// the HF repo name containing "vibevoice.cpp"/"vibevoice-cpp", or on the
|
||||
// presence of a vibevoice-*.gguf + tokenizer.gguf pair. preferences.backend
|
||||
// ="vibevoice-cpp" forces the importer regardless of artefacts.
|
||||
//
|
||||
// Role pick: defaults to TTS (the realtime model is small and the common
|
||||
// case). preferences.usecase="asr" routes to the ASR/diarization model. If a
|
||||
// repo only ships one of the two roles, that role wins automatically.
|
||||
//
|
||||
// MUST be registered ahead of VibeVoiceImporter — the older Python-backed
|
||||
// importer matches any repo with "vibevoice" in the name, which would
|
||||
// otherwise swallow the C++ bundle.
|
||||
type VibeVoiceCppImporter struct{}
|
||||
|
||||
func (i *VibeVoiceCppImporter) Name() string { return "vibevoice-cpp" }
|
||||
func (i *VibeVoiceCppImporter) Modality() string { return "tts" }
|
||||
func (i *VibeVoiceCppImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *VibeVoiceCppImporter) Match(details Details) bool {
|
||||
preferencesMap := unmarshalPreferences(details.Preferences)
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "vibevoice-cpp" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Repo-name signal: anything carrying "vibevoice.cpp" or "vibevoice-cpp"
|
||||
// — the canonical naming for the C++ port bundles.
|
||||
repoSignals := []string{strings.ToLower(repoNameOnly(details))}
|
||||
if _, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
repoSignals = append(repoSignals, strings.ToLower(repo))
|
||||
}
|
||||
for _, s := range repoSignals {
|
||||
if strings.Contains(s, "vibevoice.cpp") || strings.Contains(s, "vibevoice-cpp") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// File-listing signal: a vibevoice-*.gguf primary + tokenizer.gguf is
|
||||
// only what the C++ backend ships — the Python VibeVoice fork distributes
|
||||
// safetensors, never GGUF.
|
||||
if details.HuggingFace != nil &&
|
||||
HasFile(details.HuggingFace.Files, "tokenizer.gguf") &&
|
||||
hasVibeVoiceGGUF(details.HuggingFace.Files) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *VibeVoiceCppImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferencesMap := unmarshalPreferences(details.Preferences)
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Quant preference — default order matches what mudler/vibevoice.cpp-models
|
||||
// ships today. Same comma-separated convention as whisper / llama-cpp.
|
||||
quants := []string{"q8_0", "q4_k", "q5_k", "q4_0"}
|
||||
if preferred, ok := preferencesMap["quantizations"].(string); ok && preferred != "" {
|
||||
quants = strings.Split(preferred, ",")
|
||||
}
|
||||
|
||||
usecase := strings.ToLower(stringPref(preferencesMap, "usecase"))
|
||||
|
||||
cfg := gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "vibevoice-cpp",
|
||||
}
|
||||
|
||||
// Without HF metadata we can only emit a skeleton config — the user must
|
||||
// edit it post-import to point at real files. Mirrors whisper's bare-URI
|
||||
// fallback so preference-only invocations still produce something usable.
|
||||
if details.HuggingFace == nil {
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: filepath.Base(details.URI)},
|
||||
}
|
||||
if usecase == "asr" {
|
||||
modelConfig.KnownUsecaseStrings = []string{"transcript"}
|
||||
modelConfig.Options = []string{"type=asr", "tokenizer=tokenizer.gguf"}
|
||||
} else {
|
||||
modelConfig.KnownUsecaseStrings = []string{"tts"}
|
||||
modelConfig.Options = []string{"tokenizer=tokenizer.gguf"}
|
||||
}
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
cfg.ConfigFile = string(data)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
files := details.HuggingFace.Files
|
||||
ttsFiles := filterByPrefix(files, "vibevoice-realtime-")
|
||||
asrFiles := filterByPrefix(files, "vibevoice-asr-")
|
||||
|
||||
// Auto-pick role when the repo only ships one. Explicit usecase wins.
|
||||
role := usecase
|
||||
if role == "" {
|
||||
switch {
|
||||
case len(ttsFiles) > 0 && len(asrFiles) == 0:
|
||||
role = "tts"
|
||||
case len(asrFiles) > 0 && len(ttsFiles) == 0:
|
||||
role = "asr"
|
||||
default:
|
||||
role = "tts" // default: realtime TTS is the smaller, more common case
|
||||
}
|
||||
}
|
||||
|
||||
// Layout under <models>/vibevoice-cpp/<name>/ — same pattern as whisper's
|
||||
// nesting so multiple imports of the same upstream repo (with different
|
||||
// quants) don't collide on disk. Options[] paths are emitted relative to
|
||||
// opts.ModelPath, which the backend resolves against the LocalAI models
|
||||
// root in govibevoicecpp.go:resolvePath.
|
||||
relDir := filepath.Join("vibevoice-cpp", name)
|
||||
|
||||
var primary []hfapi.ModelFile
|
||||
switch role {
|
||||
case "asr", "transcript", "stt", "speech-to-text":
|
||||
primary = asrFiles
|
||||
modelConfig.KnownUsecaseStrings = []string{"transcript"}
|
||||
default:
|
||||
primary = ttsFiles
|
||||
modelConfig.KnownUsecaseStrings = []string{"tts"}
|
||||
}
|
||||
// If the requested role has no matching files, fall back to any
|
||||
// vibevoice-*.gguf so the import still produces something runnable.
|
||||
if len(primary) == 0 {
|
||||
primary = filterByPrefix(files, "vibevoice-")
|
||||
}
|
||||
|
||||
chosen, ok := pickPreferredGGUFFile(primary, quants)
|
||||
if !ok {
|
||||
// Nothing to download. Emit the skeleton — same shape as the
|
||||
// no-HF-metadata branch above, just with a sensible default name.
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: name + ".gguf"},
|
||||
}
|
||||
if role == "asr" {
|
||||
modelConfig.Options = []string{"type=asr", "tokenizer=" + filepath.Join(relDir, "tokenizer.gguf")}
|
||||
} else {
|
||||
modelConfig.Options = []string{"tokenizer=" + filepath.Join(relDir, "tokenizer.gguf")}
|
||||
}
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
cfg.ConfigFile = string(data)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
modelTarget := filepath.Join(relDir, filepath.Base(chosen.Path))
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: chosen.URL,
|
||||
Filename: modelTarget,
|
||||
SHA256: chosen.SHA256,
|
||||
})
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: modelTarget},
|
||||
}
|
||||
|
||||
// tokenizer.gguf is mandatory — Load() rejects without it. Always pull
|
||||
// it when the repo provides one (every official vibevoice.cpp bundle does).
|
||||
options := []string{}
|
||||
if role == "asr" {
|
||||
options = append(options, "type=asr")
|
||||
}
|
||||
if tok, ok := findFile(files, "tokenizer.gguf"); ok {
|
||||
tokTarget := filepath.Join(relDir, "tokenizer.gguf")
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: tok.URL,
|
||||
Filename: tokTarget,
|
||||
SHA256: tok.SHA256,
|
||||
})
|
||||
options = append(options, "tokenizer="+tokTarget)
|
||||
}
|
||||
|
||||
// For TTS, ship the first voice-*.gguf as a default — the backend needs
|
||||
// a reference voice to clone from. ASR doesn't use voice prompts.
|
||||
if role != "asr" {
|
||||
if voice, ok := pickVoicePrompt(files, stringPref(preferencesMap, "voice")); ok {
|
||||
voiceTarget := filepath.Join(relDir, filepath.Base(voice.Path))
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: voice.URL,
|
||||
Filename: voiceTarget,
|
||||
SHA256: voice.SHA256,
|
||||
})
|
||||
options = append(options, "voice="+voiceTarget)
|
||||
}
|
||||
}
|
||||
modelConfig.Options = options
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
cfg.ConfigFile = string(data)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// hasVibeVoiceGGUF returns true when any file matches "vibevoice-*.gguf"
|
||||
// (case-insensitive). Narrow on purpose — third-party GGUF mirrors that
|
||||
// re-pack the model under different filenames will be missed, but those
|
||||
// users can pass preferences.backend="vibevoice-cpp" to force the importer.
|
||||
func hasVibeVoiceGGUF(files []hfapi.ModelFile) bool {
|
||||
for _, f := range files {
|
||||
name := strings.ToLower(filepath.Base(f.Path))
|
||||
if strings.HasPrefix(name, "vibevoice-") && strings.HasSuffix(name, ".gguf") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterByPrefix returns every file whose basename starts with prefix and
|
||||
// ends in .gguf (case-insensitive on the suffix, exact on the prefix).
|
||||
func filterByPrefix(files []hfapi.ModelFile, prefix string) []hfapi.ModelFile {
|
||||
var out []hfapi.ModelFile
|
||||
for _, f := range files {
|
||||
base := filepath.Base(f.Path)
|
||||
if !strings.HasPrefix(base, prefix) {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(base), ".gguf") {
|
||||
continue
|
||||
}
|
||||
out = append(out, f)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// findFile is HasFile's lookup-returning sibling. Returns the first file
|
||||
// whose basename equals name (exact match), or false when none exists.
|
||||
func findFile(files []hfapi.ModelFile, name string) (hfapi.ModelFile, bool) {
|
||||
for _, f := range files {
|
||||
if filepath.Base(f.Path) == name {
|
||||
return f, true
|
||||
}
|
||||
}
|
||||
return hfapi.ModelFile{}, false
|
||||
}
|
||||
|
||||
// pickPreferredGGUFFile mirrors pickPreferredGGMLFile but operates on .gguf
|
||||
// files: walks prefs in order, returns the first file whose basename contains
|
||||
// any preference token (case-insensitive). On no match, falls back to the
|
||||
// last file so a missing quant still yields a runnable import.
|
||||
func pickPreferredGGUFFile(files []hfapi.ModelFile, prefs []string) (hfapi.ModelFile, bool) {
|
||||
if len(files) == 0 {
|
||||
return hfapi.ModelFile{}, false
|
||||
}
|
||||
for _, pref := range prefs {
|
||||
lower := strings.ToLower(strings.TrimSpace(pref))
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
for _, f := range files {
|
||||
if strings.Contains(strings.ToLower(filepath.Base(f.Path)), lower) {
|
||||
return f, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return files[len(files)-1], true
|
||||
}
|
||||
|
||||
// pickVoicePrompt selects a voice-*.gguf to bundle with a TTS import.
|
||||
// Honours an explicit preferences.voice substring (e.g. "Emma" picks
|
||||
// voice-en-Emma.gguf); otherwise returns the first voice file in listing
|
||||
// order so the choice is stable across imports of the same repo.
|
||||
func pickVoicePrompt(files []hfapi.ModelFile, hint string) (hfapi.ModelFile, bool) {
|
||||
hint = strings.ToLower(strings.TrimSpace(hint))
|
||||
var voices []hfapi.ModelFile
|
||||
for _, f := range files {
|
||||
base := strings.ToLower(filepath.Base(f.Path))
|
||||
if strings.HasPrefix(base, "voice-") && strings.HasSuffix(base, ".gguf") {
|
||||
voices = append(voices, f)
|
||||
}
|
||||
}
|
||||
if len(voices) == 0 {
|
||||
return hfapi.ModelFile{}, false
|
||||
}
|
||||
if hint != "" {
|
||||
for _, v := range voices {
|
||||
if strings.Contains(strings.ToLower(filepath.Base(v.Path)), hint) {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return voices[0], true
|
||||
}
|
||||
|
||||
// repoNameOnly extracts the repo basename (everything after the last "/")
|
||||
// from HF metadata or, failing that, the URI. Empty when neither is set.
|
||||
func repoNameOnly(details Details) string {
|
||||
if details.HuggingFace != nil {
|
||||
id := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(id, "/"); idx >= 0 {
|
||||
return id[idx+1:]
|
||||
}
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// unmarshalPreferences decodes details.Preferences into a generic map. Returns
|
||||
// an empty map (never nil) on any failure so callers can index without nil
|
||||
// checks. Bad JSON is silently ignored — every importer here treats
|
||||
// preferences as best-effort hints.
|
||||
func unmarshalPreferences(raw json.RawMessage) map[string]any {
|
||||
out := map[string]any{}
|
||||
b, err := raw.MarshalJSON()
|
||||
if err != nil || len(b) == 0 {
|
||||
return out
|
||||
}
|
||||
_ = json.Unmarshal(b, &out)
|
||||
return out
|
||||
}
|
||||
|
||||
// stringPref reads a string preference by key, returning "" when missing or
|
||||
// of the wrong type.
|
||||
func stringPref(m map[string]any, key string) string {
|
||||
if v, ok := m[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
261
core/gallery/importers/vibevoice-cpp_test.go
Normal file
261
core/gallery/importers/vibevoice-cpp_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("VibeVoiceCppImporter", func() {
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.VibeVoiceCppImporter{}
|
||||
Expect(imp.Name()).To(Equal("vibevoice-cpp"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=vibevoice-cpp for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "vibevoice-cpp"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vibevoice-cpp"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tokenizer=tokenizer.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"))
|
||||
})
|
||||
|
||||
It("emits an ASR skeleton when usecase=asr is requested with no HF metadata", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "vibevoice-cpp", "usecase": "asr"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vibevoice-cpp"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("type=asr"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"))
|
||||
})
|
||||
})
|
||||
|
||||
// Live HF call against the canonical bundle. Marked broad: it shouldn't
|
||||
// be brittle to upstream adding more quants/voices — we only assert that
|
||||
// the realtime TTS path was picked and the tokenizer was bundled.
|
||||
Context("detection from HuggingFace: mudler/vibevoice.cpp-models", func() {
|
||||
const uri = "https://huggingface.co/mudler/vibevoice.cpp-models"
|
||||
|
||||
It("routes to vibevoice-cpp, picks the realtime TTS GGUF and bundles tokenizer + voice prompt", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, json.RawMessage(`{}`))
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vibevoice-cpp"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"))
|
||||
|
||||
// Primary model must be the realtime variant (TTS default).
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("vibevoice-realtime-"))
|
||||
|
||||
// Tokenizer is mandatory and must show up both as a downloaded
|
||||
// file and as a tokenizer= option entry. The path is rooted
|
||||
// under vibevoice-cpp/<name>/ so multiple imports don't collide.
|
||||
var sawTokenizerFile, sawModelFile, sawVoiceFile bool
|
||||
for _, f := range modelConfig.Files {
|
||||
if f.Filename == "" {
|
||||
continue
|
||||
}
|
||||
if filepathBase(f.Filename) == "tokenizer.gguf" {
|
||||
sawTokenizerFile = true
|
||||
}
|
||||
if startsWith(filepathBase(f.Filename), "vibevoice-realtime-") {
|
||||
sawModelFile = true
|
||||
}
|
||||
if startsWith(filepathBase(f.Filename), "voice-") {
|
||||
sawVoiceFile = true
|
||||
}
|
||||
}
|
||||
Expect(sawTokenizerFile).To(BeTrue(), fmt.Sprintf("expected tokenizer.gguf in Files, got: %+v", modelConfig.Files))
|
||||
Expect(sawModelFile).To(BeTrue(), fmt.Sprintf("expected a vibevoice-realtime-*.gguf in Files, got: %+v", modelConfig.Files))
|
||||
Expect(sawVoiceFile).To(BeTrue(), fmt.Sprintf("expected a voice-*.gguf in Files, got: %+v", modelConfig.Files))
|
||||
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tokenizer=vibevoice-cpp/"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("voice=vibevoice-cpp/"))
|
||||
})
|
||||
|
||||
It("routes to ASR + diarization when preferences.usecase=asr", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, json.RawMessage(`{"usecase":"asr"}`))
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vibevoice-cpp"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("type=asr"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("vibevoice-asr-"))
|
||||
// ASR must NOT bundle a voice prompt — the backend ignores it
|
||||
// for transcription and we don't want gratuitous downloads.
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("voice="))
|
||||
})
|
||||
})
|
||||
|
||||
// Offline fixtures — assert the end-to-end shape of what the importer
|
||||
// emits without depending on HF availability or upstream file lists.
|
||||
Context("Import from HuggingFace file listing (offline)", func() {
|
||||
const repoBase = "https://huggingface.co/mudler/vibevoice.cpp-models/resolve/main/"
|
||||
|
||||
hfFile := func(path, sha string) hfapi.ModelFile {
|
||||
return hfapi.ModelFile{
|
||||
Path: path,
|
||||
SHA256: sha,
|
||||
URL: repoBase + path,
|
||||
}
|
||||
}
|
||||
|
||||
withHF := func(preferences string, files ...hfapi.ModelFile) importers.Details {
|
||||
d := importers.Details{
|
||||
URI: "https://huggingface.co/mudler/vibevoice.cpp-models",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "mudler/vibevoice.cpp-models",
|
||||
Files: files,
|
||||
},
|
||||
}
|
||||
if preferences != "" {
|
||||
d.Preferences = json.RawMessage(preferences)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
It("defaults to TTS realtime + tokenizer + first voice, nested under vibevoice-cpp/<name>/", func() {
|
||||
imp := &importers.VibeVoiceCppImporter{}
|
||||
details := withHF(`{"name":"vibe"}`,
|
||||
hfFile("vibevoice-realtime-0.5B-q8_0.gguf", "aaa"),
|
||||
hfFile("vibevoice-asr-q4_k.gguf", "bbb"),
|
||||
hfFile("tokenizer.gguf", "ccc"),
|
||||
hfFile("voice-en-Carter_man.gguf", "ddd"),
|
||||
hfFile("voice-en-Emma.gguf", "eee"),
|
||||
hfFile("README.md", ""),
|
||||
)
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(modelConfig.Files).To(HaveLen(3))
|
||||
byName := map[string]string{}
|
||||
for _, f := range modelConfig.Files {
|
||||
byName[filepathBase(f.Filename)] = f.Filename
|
||||
}
|
||||
Expect(byName).To(HaveKey("vibevoice-realtime-0.5B-q8_0.gguf"))
|
||||
Expect(byName).To(HaveKey("tokenizer.gguf"))
|
||||
Expect(byName).To(HaveKey("voice-en-Carter_man.gguf"))
|
||||
Expect(byName["tokenizer.gguf"]).To(Equal("vibevoice-cpp/vibe/tokenizer.gguf"))
|
||||
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vibevoice-cpp"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: vibevoice-cpp/vibe/vibevoice-realtime-0.5B-q8_0.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- tokenizer=vibevoice-cpp/vibe/tokenizer.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- voice=vibevoice-cpp/vibe/voice-en-Carter_man.gguf"))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("type=asr"))
|
||||
})
|
||||
|
||||
It("routes to ASR when preferences.usecase=asr and skips voice prompts", func() {
|
||||
imp := &importers.VibeVoiceCppImporter{}
|
||||
details := withHF(`{"name":"vibe-asr","usecase":"asr"}`,
|
||||
hfFile("vibevoice-realtime-0.5B-q8_0.gguf", "aaa"),
|
||||
hfFile("vibevoice-asr-q4_k.gguf", "bbb"),
|
||||
hfFile("vibevoice-asr-q8_0.gguf", "fff"),
|
||||
hfFile("tokenizer.gguf", "ccc"),
|
||||
hfFile("voice-en-Emma.gguf", "ddd"),
|
||||
)
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(modelConfig.Files).To(HaveLen(2))
|
||||
byName := map[string]string{}
|
||||
for _, f := range modelConfig.Files {
|
||||
byName[filepathBase(f.Filename)] = f.Filename
|
||||
}
|
||||
// Default quant order picks q8_0 over q4_k.
|
||||
Expect(byName).To(HaveKey("vibevoice-asr-q8_0.gguf"))
|
||||
Expect(byName).To(HaveKey("tokenizer.gguf"))
|
||||
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: vibevoice-cpp/vibe-asr/vibevoice-asr-q8_0.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- type=asr"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- tokenizer=vibevoice-cpp/vibe-asr/tokenizer.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("voice="))
|
||||
})
|
||||
|
||||
It("honours preferences.quantizations to pick a specific quant", func() {
|
||||
imp := &importers.VibeVoiceCppImporter{}
|
||||
details := withHF(`{"name":"vibe","quantizations":"q4_k"}`,
|
||||
hfFile("vibevoice-asr-q4_k.gguf", "aaa"),
|
||||
hfFile("vibevoice-asr-q8_0.gguf", "bbb"),
|
||||
hfFile("tokenizer.gguf", "ccc"),
|
||||
)
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Repo only ships ASR — auto-routes to asr, picks the requested
|
||||
// quant, emits type=asr automatically.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: vibevoice-cpp/vibe/vibevoice-asr-q4_k.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- type=asr"))
|
||||
})
|
||||
|
||||
It("honours preferences.voice to pick a specific voice prompt", func() {
|
||||
imp := &importers.VibeVoiceCppImporter{}
|
||||
details := withHF(`{"name":"vibe","voice":"Emma"}`,
|
||||
hfFile("vibevoice-realtime-0.5B-q8_0.gguf", "aaa"),
|
||||
hfFile("tokenizer.gguf", "bbb"),
|
||||
hfFile("voice-en-Carter_man.gguf", "ccc"),
|
||||
hfFile("voice-en-Emma.gguf", "ddd"),
|
||||
)
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- voice=vibevoice-cpp/vibe/voice-en-Emma.gguf"))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("voice-en-Carter_man"))
|
||||
})
|
||||
})
|
||||
|
||||
// Make sure we don't regress the existing Python-backend importer for
|
||||
// repos that don't carry the C++ port's signal (e.g. microsoft/VibeVoice-1.5B).
|
||||
Context("non-cpp vibevoice repos still route to the Python importer", func() {
|
||||
It("does not claim microsoft/VibeVoice-1.5B (no GGUF / no .cpp suffix)", func() {
|
||||
imp := &importers.VibeVoiceCppImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/microsoft/VibeVoice-1.5B",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "microsoft/VibeVoice-1.5B",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "config.json"},
|
||||
{Path: "model.safetensors"},
|
||||
},
|
||||
},
|
||||
Preferences: json.RawMessage(`{}`),
|
||||
}
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// filepathBase / startsWith are tiny helpers so the test file stays
|
||||
// stdlib-only and doesn't pull in path/filepath + strings just for the
|
||||
// expected-shape assertions.
|
||||
func filepathBase(p string) string {
|
||||
for i := len(p) - 1; i >= 0; i-- {
|
||||
if p[i] == '/' {
|
||||
return p[i+1:]
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func startsWith(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
KnownUsecaseStrings: []string{config.UsecaseChat},
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
|
||||
@@ -52,3 +52,39 @@ func (m *GalleryModel) GetTags() []string {
|
||||
func (m *GalleryModel) GetDescription() string {
|
||||
return m.Description
|
||||
}
|
||||
|
||||
// GetKnownUsecases returns the usecase flags declared by the gallery entry,
|
||||
// falling back to the resolved backend's default usecases when the entry has
|
||||
// none of its own. Returns nil only when neither source provides any.
|
||||
//
|
||||
// Why the fallback: many gallery entries omit known_usecases because their
|
||||
// backend has only one sensible mode (e.g. stablediffusion-ggml is always
|
||||
// image generation). Without this fallback such models silently disappear
|
||||
// from usecase-based filtering in the UI.
|
||||
func (m *GalleryModel) GetKnownUsecases() *config.ModelConfigUsecase {
|
||||
if strs := overrideUsecaseStrings(m.Overrides); len(strs) > 0 {
|
||||
return config.GetUsecasesFromYAML(strs)
|
||||
}
|
||||
if defaults := config.DefaultUsecasesForBackendCap(m.Backend); len(defaults) > 0 {
|
||||
return config.GetUsecasesFromYAML(defaults)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func overrideUsecaseStrings(overrides map[string]any) []string {
|
||||
raw, ok := overrides["known_usecases"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
list, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
strs := make([]string, 0, len(list))
|
||||
for _, v := range list {
|
||||
if s, ok := v.(string); ok {
|
||||
strs = append(strs, s)
|
||||
}
|
||||
}
|
||||
return strs
|
||||
}
|
||||
|
||||
@@ -44,6 +44,10 @@ var RouteFeatureRegistry = []RouteFeature{
|
||||
{"POST", "/v1/audio/transcriptions", FeatureAudioTranscription},
|
||||
{"POST", "/audio/transcriptions", FeatureAudioTranscription},
|
||||
|
||||
// Audio diarization (speaker turns)
|
||||
{"POST", "/v1/audio/diarization", FeatureAudioDiarization},
|
||||
{"POST", "/audio/diarization", FeatureAudioDiarization},
|
||||
|
||||
// Audio speech / TTS
|
||||
{"POST", "/v1/audio/speech", FeatureAudioSpeech},
|
||||
{"POST", "/audio/speech", FeatureAudioSpeech},
|
||||
@@ -73,6 +77,11 @@ var RouteFeatureRegistry = []RouteFeature{
|
||||
{"POST", "/v1/voice/identify", FeatureVoiceRecognition},
|
||||
{"POST", "/v1/voice/forget", FeatureVoiceRecognition},
|
||||
|
||||
// Audio transform (echo cancellation, noise suppression, voice conversion, etc.)
|
||||
{"POST", "/audio/transformations", FeatureAudioTransform},
|
||||
{"POST", "/audio/transform", FeatureAudioTransform},
|
||||
{"GET", "/audio/transformations/stream", FeatureAudioTransform},
|
||||
|
||||
// Video
|
||||
{"POST", "/video", FeatureVideo},
|
||||
|
||||
@@ -158,6 +167,7 @@ func APIFeatureMetas() []FeatureMeta {
|
||||
{FeatureImages, "Image Generation", true},
|
||||
{FeatureAudioSpeech, "Audio Speech / TTS", true},
|
||||
{FeatureAudioTranscription, "Audio Transcription", true},
|
||||
{FeatureAudioDiarization, "Audio Diarization", true},
|
||||
{FeatureVAD, "Voice Activity Detection", true},
|
||||
{FeatureDetection, "Detection", true},
|
||||
{FeatureVideo, "Video Generation", true},
|
||||
@@ -170,5 +180,6 @@ func APIFeatureMetas() []FeatureMeta {
|
||||
{FeatureStores, "Stores", true},
|
||||
{FeatureFaceRecognition, "Face Recognition", true},
|
||||
{FeatureVoiceRecognition, "Voice Recognition", true},
|
||||
{FeatureAudioTransform, "Audio Transform", true},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ const (
|
||||
FeatureImages = "images"
|
||||
FeatureAudioSpeech = "audio_speech"
|
||||
FeatureAudioTranscription = "audio_transcription"
|
||||
FeatureAudioDiarization = "audio_diarization"
|
||||
FeatureVAD = "vad"
|
||||
FeatureDetection = "detection"
|
||||
FeatureVideo = "video"
|
||||
@@ -54,6 +55,7 @@ const (
|
||||
FeatureStores = "stores"
|
||||
FeatureFaceRecognition = "face_recognition"
|
||||
FeatureVoiceRecognition = "voice_recognition"
|
||||
FeatureAudioTransform = "audio_transform"
|
||||
)
|
||||
|
||||
// AgentFeatures lists agent-related features (default OFF).
|
||||
@@ -65,9 +67,10 @@ var GeneralFeatures = []string{FeatureFineTuning, FeatureQuantization}
|
||||
// APIFeatures lists API endpoint features (default ON).
|
||||
var APIFeatures = []string{
|
||||
FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
|
||||
FeatureAudioDiarization,
|
||||
FeatureVAD, FeatureDetection, FeatureVideo, FeatureEmbeddings, FeatureSound,
|
||||
FeatureRealtime, FeatureRerank, FeatureTokenize, FeatureMCP, FeatureStores,
|
||||
FeatureFaceRecognition, FeatureVoiceRecognition,
|
||||
FeatureFaceRecognition, FeatureVoiceRecognition, FeatureAudioTransform,
|
||||
}
|
||||
|
||||
// AllFeatures lists all known features (used by UI and validation).
|
||||
|
||||
@@ -32,8 +32,9 @@ var instructionDefs = []instructionDef{
|
||||
},
|
||||
{
|
||||
Name: "audio",
|
||||
Description: "Text-to-speech, voice activity detection, transcription, and sound generation",
|
||||
Description: "Text-to-speech, voice activity detection, transcription, speaker diarization, and sound generation",
|
||||
Tags: []string{"audio"},
|
||||
Intro: "Diarization (/v1/audio/diarization) returns speaker-labelled time segments. Backends with native ASR-diarization (vibevoice-cpp) can also emit per-segment text via include_text=true; backends with a dedicated pipeline (sherpa-onnx + pyannote) emit segmentation only. Response formats: json (default), verbose_json (adds speakers summary + text), rttm (NIST format).",
|
||||
},
|
||||
{
|
||||
Name: "images",
|
||||
|
||||
413
core/http/endpoints/localai/audio_transform.go
Normal file
413
core/http/endpoints/localai/audio_transform.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/audio"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// audioTransformWSUpgrader allows WebSocket connections from any origin —
|
||||
// matches the realtime endpoint's policy. Authentication is handled at the
|
||||
// HTTP layer before the upgrade.
|
||||
var audioTransformWSUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
const (
|
||||
// audioTransformWSReadLimit is the per-message ceiling on inbound WS
|
||||
// frames. With 16 kHz / 256-sample / s16-stereo (1024 B/frame) the
|
||||
// default ceiling is generous; raised here to 1 MiB to allow larger
|
||||
// frame_samples for backends with longer hops.
|
||||
audioTransformWSReadLimit = 1 << 20
|
||||
)
|
||||
|
||||
// AudioTransformEndpoint implements the batch audio-transform API. Accepts a
|
||||
// multipart/form-data request with `audio` (required) and an optional
|
||||
// `reference` file. Backend-specific tuning is forwarded via repeated
|
||||
// `params[<key>]=<value>` form fields. Returns the enhanced audio as an
|
||||
// attachment, mirroring the /v1/audio/speech response shape.
|
||||
//
|
||||
// @Summary Transform audio (echo cancellation, noise suppression, voice conversion, etc.)
|
||||
// @Description Runs an audio-in / audio-out transform conditioned on an optional auxiliary reference signal. Concrete transforms include AEC + noise suppression + dereverberation (LocalVQE), voice conversion (reference = target speaker), and pitch shifting. The backend determines the operation; pass model-specific tuning via repeated `params[<key>]=<value>` form fields.
|
||||
// @Tags audio
|
||||
// @Accept multipart/form-data
|
||||
// @Produce audio/x-wav
|
||||
// @Param model formData string true "model"
|
||||
// @Param audio formData file true "primary input audio file"
|
||||
// @Param reference formData file false "auxiliary reference audio (loopback for AEC, target voice for conversion, etc.)"
|
||||
// @Param response_format formData string false "wav | mp3 | ogg | flac"
|
||||
// @Param sample_rate formData integer false "desired output sample rate"
|
||||
// @Success 200 {string} binary "transformed audio file"
|
||||
// @Router /audio/transformations [post]
|
||||
// @Router /audio/transform [post]
|
||||
func AudioTransformEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AudioTransformRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
xlog.Debug("LocalAI Audio Transform Request received", "model", input.Model)
|
||||
|
||||
audioFile, err := c.FormFile("audio")
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "missing required 'audio' file field")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "audio-transform")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
audioPath, err := saveMultipartFileAsWAV(audioFile, dir, "audio")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var referencePath string
|
||||
if refFile, err := c.FormFile("reference"); err == nil {
|
||||
referencePath, err = saveMultipartFileAsWAV(refFile, dir, "reference")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
params := collectParamsFromForm(c)
|
||||
// Form-field params override schema-body params on collision.
|
||||
for k, v := range input.Params {
|
||||
if _, exists := params[k]; !exists {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
out, _, err := backend.ModelAudioTransform(audioPath, referencePath, backend.AudioTransformOptions{
|
||||
Params: params,
|
||||
}, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst := out.Dst
|
||||
|
||||
if input.SampleRate > 0 {
|
||||
dst, err = utils.AudioResample(dst, input.SampleRate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dst, err = utils.AudioConvert(dst, input.Format)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst, contentType := audio.NormalizeAudioFile(dst)
|
||||
if contentType != "" {
|
||||
c.Response().Header().Set(echo.HeaderContentType, contentType)
|
||||
}
|
||||
// Expose the persisted inputs so the React UI can save them in
|
||||
// history alongside the output. The /generated-audio/ prefix is
|
||||
// the same one ttsApi uses (parsed from Content-Disposition).
|
||||
if name := filepath.Base(out.AudioPath); name != "" {
|
||||
c.Response().Header().Set(echo.HeaderAccessControlExposeHeaders, "X-Audio-Input-Url, X-Audio-Reference-Url")
|
||||
c.Response().Header().Set("X-Audio-Input-Url", "/generated-audio/"+name)
|
||||
}
|
||||
if out.ReferencePath != "" {
|
||||
if name := filepath.Base(out.ReferencePath); name != "" {
|
||||
c.Response().Header().Set("X-Audio-Reference-Url", "/generated-audio/"+name)
|
||||
}
|
||||
}
|
||||
return c.Attachment(dst, filepath.Base(dst))
|
||||
}
|
||||
}
|
||||
|
||||
// Wire protocol documented in docs/content/features/audio-transform.md
|
||||
// and on schema.AudioTransformStreamControl.
|
||||
//
|
||||
// @Summary Bidirectional realtime audio transform over WebSocket.
|
||||
// @Description Streams binary PCM frames in (interleaved stereo: ch0=audio, ch1=reference) and out (mono). The first message must be a JSON `session.update` envelope describing model + sample format + frame size + backend params. Server emits binary PCM on the same cadence.
|
||||
// @Tags audio
|
||||
// @Router /audio/transformations/stream [get]
|
||||
func AudioTransformStreamEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ws, err := audioTransformWSUpgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = ws.Close() }()
|
||||
ws.SetReadLimit(audioTransformWSReadLimit)
|
||||
|
||||
mt, payload, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
xlog.Debug("audio_transform stream: client closed before session.update", "error", err)
|
||||
return nil
|
||||
}
|
||||
if mt != websocket.TextMessage {
|
||||
sendWSError(ws, "expected JSON session.update as first message")
|
||||
return nil
|
||||
}
|
||||
var ctrl schema.AudioTransformStreamControl
|
||||
if err := json.Unmarshal(payload, &ctrl); err != nil {
|
||||
sendWSError(ws, "invalid JSON: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
if ctrl.Type != schema.AudioTransformCtrlSessionUpdate {
|
||||
sendWSError(ws, "first message must be "+schema.AudioTransformCtrlSessionUpdate)
|
||||
return nil
|
||||
}
|
||||
if ctrl.Model == "" {
|
||||
sendWSError(ws, "session.update missing model")
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg, err := app.ModelConfigLoader().LoadModelConfigFileByNameDefaultOptions(ctrl.Model, app.ApplicationConfig())
|
||||
if err != nil || cfg == nil {
|
||||
sendWSError(ws, fmt.Sprintf("failed to load model config: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request().Context())
|
||||
defer cancel()
|
||||
|
||||
stream, err := backend.ModelAudioTransformStream(ctx, app.ModelLoader(), app.ApplicationConfig(), *cfg)
|
||||
if err != nil {
|
||||
sendWSError(ws, fmt.Sprintf("failed to open transform stream: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
sampleFormat, err := parseSampleFormat(ctrl.SampleFormat)
|
||||
if err != nil {
|
||||
sendWSError(ws, err.Error())
|
||||
return nil
|
||||
}
|
||||
if err := stream.Send(buildConfigRequest(sampleFormat, &ctrl)); err != nil {
|
||||
sendWSError(ws, fmt.Sprintf("backend send config: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
sendWSError(ws, fmt.Sprintf("backend recv: %v", err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := ws.WriteMessage(websocket.BinaryMessage, resp.Pcm); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Per-connection scratch for stereo de-interleaving — avoids two
|
||||
// allocs per inbound binary frame at the 16 ms cadence.
|
||||
var audioBuf, refBuf []byte
|
||||
readLoop:
|
||||
for {
|
||||
mt, payload, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
_ = stream.CloseSend()
|
||||
break readLoop
|
||||
}
|
||||
switch mt {
|
||||
case websocket.BinaryMessage:
|
||||
audio, ref := splitStereoFrameInto(payload, sampleFormat, &audioBuf, &refBuf)
|
||||
if err := stream.Send(&proto.AudioTransformFrameRequest{
|
||||
Payload: &proto.AudioTransformFrameRequest_Frame{
|
||||
Frame: &proto.AudioTransformFrame{
|
||||
AudioPcm: audio,
|
||||
ReferencePcm: ref,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
sendWSError(ws, fmt.Sprintf("backend send frame: %v", err))
|
||||
_ = stream.CloseSend()
|
||||
break readLoop
|
||||
}
|
||||
case websocket.TextMessage:
|
||||
var ctrl schema.AudioTransformStreamControl
|
||||
if err := json.Unmarshal(payload, &ctrl); err != nil {
|
||||
sendWSError(ws, "invalid mid-stream JSON: "+err.Error())
|
||||
continue
|
||||
}
|
||||
switch ctrl.Type {
|
||||
case schema.AudioTransformCtrlSessionUpdate:
|
||||
_ = stream.Send(buildConfigRequest(sampleFormat, &ctrl))
|
||||
case schema.AudioTransformCtrlSessionClose:
|
||||
_ = stream.CloseSend()
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseSampleFormat(s string) (proto.AudioTransformStreamConfig_SampleFormat, error) {
|
||||
switch strings.ToUpper(s) {
|
||||
case schema.AudioTransformSampleFormatF32LE:
|
||||
return proto.AudioTransformStreamConfig_F32_LE, nil
|
||||
case schema.AudioTransformSampleFormatS16LE, "":
|
||||
return proto.AudioTransformStreamConfig_S16_LE, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported sample_format: %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func buildConfigRequest(fmt_ proto.AudioTransformStreamConfig_SampleFormat, ctrl *schema.AudioTransformStreamControl) *proto.AudioTransformFrameRequest {
|
||||
return &proto.AudioTransformFrameRequest{
|
||||
Payload: &proto.AudioTransformFrameRequest_Config{
|
||||
Config: &proto.AudioTransformStreamConfig{
|
||||
SampleFormat: fmt_,
|
||||
SampleRate: int32(ctrl.SampleRate),
|
||||
FrameSamples: int32(ctrl.FrameSamples),
|
||||
Params: ctrl.Params,
|
||||
Reset_: ctrl.Reset,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// saveMultipartFileAsWAV materialises an uploaded multipart file into `dir`
|
||||
// and converts it to LocalVQE's required shape (16 kHz mono s16 WAV) via
|
||||
// ffmpeg. The conversion is a passthrough when the upload already matches.
|
||||
// `name` is used as the base filename for the converted output so the dir
|
||||
// stays readable for debugging (e.g. "audio.wav", "reference.wav").
|
||||
func saveMultipartFileAsWAV(fh *multipart.FileHeader, dir, name string) (string, error) {
|
||||
f, err := fh.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
raw := filepath.Join(dir, "raw-"+path.Base(fh.Filename))
|
||||
out, err := os.Create(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := io.Copy(out, f); err != nil {
|
||||
_ = out.Close()
|
||||
return "", err
|
||||
}
|
||||
_ = out.Close()
|
||||
|
||||
dst := filepath.Join(dir, name+".wav")
|
||||
if err := utils.AudioToWav(raw, dst); err != nil {
|
||||
return "", fmt.Errorf("normalize %s: %w", name, err)
|
||||
}
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// collectParamsFromForm walks the multipart form values and harvests any
|
||||
// that match the `params[<key>]` shape. Returns nil if there are no matches.
|
||||
func collectParamsFromForm(c echo.Context) map[string]string {
|
||||
params := map[string]string{}
|
||||
form, err := c.FormParams()
|
||||
if err != nil {
|
||||
return params
|
||||
}
|
||||
for key, vals := range form {
|
||||
if len(vals) == 0 {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(key, "params[") || !strings.HasSuffix(key, "]") {
|
||||
continue
|
||||
}
|
||||
inner := strings.TrimSuffix(strings.TrimPrefix(key, "params["), "]")
|
||||
inner = strings.TrimSpace(inner)
|
||||
if inner == "" {
|
||||
continue
|
||||
}
|
||||
// Last value wins for duplicate keys — matches OpenAI's form-field
|
||||
// override semantics.
|
||||
params[inner] = vals[len(vals)-1]
|
||||
}
|
||||
// Form-field shortcuts for the common LocalVQE knobs. params[*] still wins
|
||||
// when both are provided (they ran first).
|
||||
if _, exists := params[schema.AudioTransformParamNoiseGate]; !exists {
|
||||
if v := c.FormValue(schema.AudioTransformParamNoiseGate); v != "" {
|
||||
if b, err := strconv.ParseBool(v); err == nil {
|
||||
if b {
|
||||
params[schema.AudioTransformParamNoiseGate] = "true"
|
||||
} else {
|
||||
params[schema.AudioTransformParamNoiseGate] = "false"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, exists := params[schema.AudioTransformParamNoiseGateThreshold]; !exists {
|
||||
if v := c.FormValue(schema.AudioTransformParamNoiseGateThreshold); v != "" {
|
||||
params[schema.AudioTransformParamNoiseGateThreshold] = v
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// splitStereoFrameInto deinterleaves a stereo PCM frame in-place into
|
||||
// caller-owned reusable buffers (channel 0 → audio, channel 1 → reference).
|
||||
// Sample size is inferred from the proto enum: s16=2 B, f32=4 B. Trailing
|
||||
// odd bytes are truncated.
|
||||
func splitStereoFrameInto(buf []byte, fmt_ proto.AudioTransformStreamConfig_SampleFormat, audio, ref *[]byte) ([]byte, []byte) {
|
||||
sampleSize := 2
|
||||
if fmt_ == proto.AudioTransformStreamConfig_F32_LE {
|
||||
sampleSize = 4
|
||||
}
|
||||
stride := sampleSize * 2
|
||||
n := len(buf) / stride
|
||||
want := n * sampleSize
|
||||
if cap(*audio) < want {
|
||||
*audio = make([]byte, want)
|
||||
} else {
|
||||
*audio = (*audio)[:want]
|
||||
}
|
||||
if cap(*ref) < want {
|
||||
*ref = make([]byte, want)
|
||||
} else {
|
||||
*ref = (*ref)[:want]
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
copy((*audio)[i*sampleSize:(i+1)*sampleSize], buf[i*stride:i*stride+sampleSize])
|
||||
copy((*ref)[i*sampleSize:(i+1)*sampleSize], buf[i*stride+sampleSize:(i+1)*stride])
|
||||
}
|
||||
return *audio, *ref
|
||||
}
|
||||
|
||||
func sendWSError(ws *websocket.Conn, msg string) {
|
||||
payload, _ := json.Marshal(schema.AudioTransformStreamControl{
|
||||
Type: schema.AudioTransformCtrlError,
|
||||
Error: msg,
|
||||
})
|
||||
_ = ws.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
@@ -36,6 +36,8 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
{Name: "faster-qwen3-tts", Modality: "tts", AutoDetect: false, Description: "Faster Qwen3 TTS (preference-only)"},
|
||||
// Detection
|
||||
{Name: "sam3-cpp", Modality: "detection", AutoDetect: false, Description: "SAM3 C++ object detection (preference-only)"},
|
||||
// Audio transform (audio-in / audio-out, optional reference signal)
|
||||
{Name: "localvqe", Modality: "audio-transform", AutoDetect: false, Description: "LocalVQE C++ joint AEC + noise suppression + dereverberation (preference-only)"},
|
||||
}
|
||||
|
||||
// UpgradeInfoProvider is an interface for querying cached backend upgrade information.
|
||||
|
||||
@@ -116,13 +116,13 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
capability := strings.TrimPrefix(provider, "models:")
|
||||
var filterFn config.ModelConfigFilterFn
|
||||
switch capability {
|
||||
case "chat":
|
||||
case config.UsecaseChat:
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_CHAT)
|
||||
case "tts":
|
||||
case config.UsecaseTTS:
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TTS)
|
||||
case "vad":
|
||||
case config.UsecaseVAD:
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD)
|
||||
case "transcript":
|
||||
case config.UsecaseTranscript:
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)
|
||||
default:
|
||||
filterFn = config.NoFilterFn
|
||||
|
||||
@@ -77,18 +77,17 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
}
|
||||
estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
result, err := vram.EstimateModel(estCtx, vram.ModelEstimateInput{
|
||||
Files: files,
|
||||
Options: vram.EstimateOptions{ContextLength: 8192},
|
||||
})
|
||||
result, err := vram.EstimateModelMultiContext(estCtx, vram.ModelEstimateInput{
|
||||
Files: files,
|
||||
}, []uint32{8192})
|
||||
if err == nil {
|
||||
if result.SizeBytes > 0 {
|
||||
resp.EstimatedSizeBytes = result.SizeBytes
|
||||
resp.EstimatedSizeDisplay = result.SizeDisplay
|
||||
}
|
||||
if result.VRAMBytes > 0 {
|
||||
resp.EstimatedVRAMBytes = result.VRAMBytes
|
||||
resp.EstimatedVRAMDisplay = result.VRAMDisplay
|
||||
if v := result.VRAMForContext(8192); v > 0 {
|
||||
resp.EstimatedVRAMBytes = v
|
||||
resp.EstimatedVRAMDisplay = vram.FormatBytes(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user