mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-20 06:35:41 -04:00
Compare commits
25 Commits
feat/tq-ik
...
v4.1.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad232fdb1a | ||
|
|
11637b5a1b | ||
|
|
0dda4fe6f0 | ||
|
|
773489eeb1 | ||
|
|
06fbe48b3f | ||
|
|
232e324a68 | ||
|
|
39c954764c | ||
|
|
9b7d5513fc | ||
|
|
84cd8c0e7f | ||
|
|
d990f2790c | ||
|
|
53deeb1107 | ||
|
|
c5a840f6af | ||
|
|
6d9d77d590 | ||
|
|
6f304d1201 | ||
|
|
557d0f0f04 | ||
|
|
b7e3589875 | ||
|
|
716ddd697b | ||
|
|
223deb908d | ||
|
|
9f8821bba8 | ||
|
|
84e51b68ef | ||
|
|
7962dd16f7 | ||
|
|
a1466b305a | ||
|
|
57c0026715 | ||
|
|
1ed6b9e5ed | ||
|
|
e4ee74354f |
92
.github/workflows/backend.yml
vendored
92
.github/workflows/backend.yml
vendored
@@ -1828,98 +1828,6 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# llama-cpp-tq (TurboQuant fork)
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-llama-cpp-tq'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-llama-cpp-tq'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-llama-cpp-tq'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-llama-cpp-tq'
|
||||
base-image: "ubuntu:24.04"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
ubuntu-version: '2404'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-llama-cpp-tq'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:6.4.4"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-llama-cpp-tq'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-llama-cpp-tq'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "llama-cpp-tq"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# Stablediffusion-ggml
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
|
||||
7
.github/workflows/bump_deps.yaml
vendored
7
.github/workflows/bump_deps.yaml
vendored
@@ -14,11 +14,6 @@ jobs:
|
||||
variable: "LLAMA_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/cpp/llama-cpp/Makefile"
|
||||
- repository: "TheTom/llama-cpp-turboquant"
|
||||
variable: "LLAMA_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/cpp/llama-cpp-tq/Makefile"
|
||||
branch_suffix: "-tq"
|
||||
- repository: "ggml-org/whisper.cpp"
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
@@ -65,7 +60,7 @@ jobs:
|
||||
push-to-fork: ci-forks/LocalAI
|
||||
commit-message: ':arrow_up: Update ${{ matrix.repository }}'
|
||||
title: 'chore: :arrow_up: Update ${{ matrix.repository }} to `${{ steps.bump.outputs.commit }}`'
|
||||
branch: "update/${{ matrix.variable }}${{ matrix.branch_suffix }}"
|
||||
branch: "update/${{ matrix.variable }}"
|
||||
body: ${{ steps.bump.outputs.message }}
|
||||
signoff: true
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,7 +9,6 @@ prepare-sources
|
||||
/backend/cpp/llama-cpp/llama.cpp
|
||||
/backend/cpp/llama-*
|
||||
!backend/cpp/llama-cpp
|
||||
!backend/cpp/llama-cpp-tq
|
||||
/backends
|
||||
/backend-images
|
||||
/result.yaml
|
||||
|
||||
4
Makefile
4
Makefile
@@ -544,9 +544,8 @@ backend-images:
|
||||
mkdir -p backend-images
|
||||
|
||||
# Backend metadata: BACKEND_NAME | DOCKERFILE_TYPE | BUILD_CONTEXT | PROGRESS_FLAG | NEEDS_BACKEND_ARG
|
||||
# llama-cpp and forks - use llama-cpp Dockerfile
|
||||
# llama-cpp is special - uses llama-cpp Dockerfile and doesn't need BACKEND arg
|
||||
BACKEND_LLAMA_CPP = llama-cpp|llama-cpp|.|false|false
|
||||
BACKEND_LLAMA_CPP_TQ = llama-cpp-tq|llama-cpp|.|false|true
|
||||
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
@@ -610,7 +609,6 @@ endef
|
||||
|
||||
# Generate all docker-build targets
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_TQ)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
|
||||
@@ -58,9 +58,7 @@ ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG BACKEND=llama-cpp
|
||||
ARG LLAMA_BACKEND_DIR=${BACKEND}
|
||||
ENV LLAMA_BACKEND_DIR=${LLAMA_BACKEND_DIR}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
@@ -257,27 +255,32 @@ if [[ -n "${CUDA_DOCKER_ARCH:-}" ]]; then
|
||||
CUDA_ARCH_ESC="${CUDA_DOCKER_ARCH//;/\\;}"
|
||||
export CMAKE_ARGS="${CMAKE_ARGS:-} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH_ESC}"
|
||||
echo "CMAKE_ARGS(env) = ${CMAKE_ARGS}"
|
||||
rm -rf /LocalAI/backend/cpp/${LLAMA_BACKEND_DIR}-*-build
|
||||
rm -rf /LocalAI/backend/cpp/llama-cpp-*-build
|
||||
fi
|
||||
|
||||
cd /LocalAI/backend/cpp/${LLAMA_BACKEND_DIR}
|
||||
|
||||
if [ "${TARGETARCH}" = "arm64" ] || [ "${BUILD_TYPE}" = "hipblas" ]; then
|
||||
make ARCH=aarch64 build-variants
|
||||
cd /LocalAI/backend/cpp/llama-cpp
|
||||
make llama-cpp-fallback
|
||||
make llama-cpp-grpc
|
||||
make llama-cpp-rpc-server
|
||||
else
|
||||
make build-variants
|
||||
cd /LocalAI/backend/cpp/llama-cpp
|
||||
make llama-cpp-avx
|
||||
make llama-cpp-avx2
|
||||
make llama-cpp-avx512
|
||||
make llama-cpp-fallback
|
||||
make llama-cpp-grpc
|
||||
make llama-cpp-rpc-server
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# Copy libraries using a script to handle architecture differences
|
||||
RUN make -BC /LocalAI/backend/cpp/${LLAMA_BACKEND_DIR} package
|
||||
RUN make -BC /LocalAI/backend/cpp/llama-cpp package
|
||||
|
||||
|
||||
FROM scratch
|
||||
|
||||
ARG BACKEND=llama-cpp
|
||||
ARG LLAMA_BACKEND_DIR=${BACKEND}
|
||||
|
||||
# Copy all available binaries (the build process only creates the appropriate ones for the target architecture)
|
||||
COPY --from=builder /LocalAI/backend/cpp/${LLAMA_BACKEND_DIR}/package/. ./
|
||||
COPY --from=builder /LocalAI/backend/cpp/llama-cpp/package/. ./
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
LLAMA_VERSION?=master
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
BACKEND_NAME?=llama-cpp-tq
|
||||
SHARED_DIR?=$(CURDIR)/../llama-cpp
|
||||
|
||||
include ../llama-cpp/Makefile
|
||||
@@ -59,11 +59,6 @@ add_library(hw_grpc_proto
|
||||
|
||||
add_executable(${TARGET} grpc-server.cpp json.hpp httplib.h)
|
||||
|
||||
# Enable autoparser support if the header exists (not present in all llama.cpp forks)
|
||||
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/chat-auto-parser.h")
|
||||
target_compile_definitions(${TARGET} PRIVATE HAS_AUTOPARSER)
|
||||
endif()
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ../llava)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
|
||||
LLAMA_VERSION?=a1cfb645307edc61a89e41557f290f441043d3c2
|
||||
LLAMA_VERSION?=761797ffdf2ce3f118e82c663b1ad7d935fbd656
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
BACKEND_NAME?=llama-cpp
|
||||
SHARED_DIR?=$(CURDIR)
|
||||
GRPC_SERVER_DIR?=tools/grpc-server
|
||||
SERVER_SOURCE_DIR?=tools/server
|
||||
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
@@ -71,17 +67,6 @@ ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
-DCMAKE_CXX_FLAGS="-fsycl"
|
||||
endif
|
||||
|
||||
# Variants to build for each architecture (can be overridden by forks)
|
||||
X86_64_VARIANTS ?= llama-cpp-avx llama-cpp-avx2 llama-cpp-avx512 llama-cpp-fallback llama-cpp-grpc llama-cpp-rpc-server
|
||||
ARM64_VARIANTS ?= llama-cpp-fallback llama-cpp-grpc llama-cpp-rpc-server
|
||||
|
||||
build-variants:
|
||||
ifeq ($(ARCH),aarch64)
|
||||
@for v in $(ARM64_VARIANTS); do $(MAKE) $$v || exit 1; done
|
||||
else
|
||||
@for v in $(X86_64_VARIANTS); do $(MAKE) $$v || exit 1; done
|
||||
endif
|
||||
|
||||
INSTALLED_PACKAGES=$(CURDIR)/../grpc/installed_packages
|
||||
INSTALLED_LIB_CMAKE=$(INSTALLED_PACKAGES)/lib/cmake
|
||||
ADDED_CMAKE_ARGS=-Dabsl_DIR=${INSTALLED_LIB_CMAKE}/absl \
|
||||
@@ -105,42 +90,42 @@ else
|
||||
endif
|
||||
|
||||
llama-cpp-avx2: llama.cpp
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME) $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx2-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx2-build purge
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build purge
|
||||
$(info ${GREEN}I llama-cpp build info:avx2${RESET})
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="$(BACKEND_NAME)-avx2-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx2-build/grpc-server llama-cpp-avx2
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx2-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build/grpc-server llama-cpp-avx2
|
||||
|
||||
llama-cpp-avx512: llama.cpp
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME) $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx512-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx512-build purge
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build purge
|
||||
$(info ${GREEN}I llama-cpp build info:avx512${RESET})
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="$(BACKEND_NAME)-avx512-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx512-build/grpc-server llama-cpp-avx512
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx512-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build/grpc-server llama-cpp-avx512
|
||||
|
||||
llama-cpp-avx: llama.cpp
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME) $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx-build purge
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build purge
|
||||
$(info ${GREEN}I llama-cpp build info:avx${RESET})
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="$(BACKEND_NAME)-avx-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-avx-build/grpc-server llama-cpp-avx
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="llama-cpp-avx-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build/grpc-server llama-cpp-avx
|
||||
|
||||
llama-cpp-fallback: llama.cpp
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME) $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-fallback-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-fallback-build purge
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build purge
|
||||
$(info ${GREEN}I llama-cpp build info:fallback${RESET})
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="$(BACKEND_NAME)-fallback-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-fallback-build/grpc-server llama-cpp-fallback
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="llama-cpp-fallback-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build/grpc-server llama-cpp-fallback
|
||||
|
||||
llama-cpp-grpc: llama.cpp
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME) $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-grpc-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-grpc-build purge
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build purge
|
||||
$(info ${GREEN}I llama-cpp build info:grpc${RESET})
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" TARGET="--target grpc-server --target rpc-server" $(MAKE) VARIANT="$(BACKEND_NAME)-grpc-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-grpc-build/grpc-server llama-cpp-grpc
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" TARGET="--target grpc-server --target rpc-server" $(MAKE) VARIANT="llama-cpp-grpc-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/grpc-server llama-cpp-grpc
|
||||
|
||||
llama-cpp-rpc-server: llama-cpp-grpc
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../$(BACKEND_NAME)-grpc-build/llama.cpp/build/bin/rpc-server llama-cpp-rpc-server
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/llama.cpp/build/bin/rpc-server llama-cpp-rpc-server
|
||||
|
||||
llama.cpp:
|
||||
mkdir -p llama.cpp
|
||||
@@ -148,30 +133,30 @@ llama.cpp:
|
||||
git init && \
|
||||
git remote add origin $(LLAMA_REPO) && \
|
||||
git fetch origin && \
|
||||
(git checkout -b build $(LLAMA_VERSION) || git checkout -b build origin/$(LLAMA_VERSION)) && \
|
||||
git checkout -b build $(LLAMA_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
llama.cpp/$(GRPC_SERVER_DIR): llama.cpp
|
||||
mkdir -p llama.cpp/$(GRPC_SERVER_DIR)
|
||||
SHARED_DIR=$(SHARED_DIR) SERVER_SOURCE_DIR=$(SERVER_SOURCE_DIR) GRPC_SERVER_DIR=$(GRPC_SERVER_DIR) bash $(SHARED_DIR)/prepare.sh
|
||||
llama.cpp/tools/grpc-server: llama.cpp
|
||||
mkdir -p llama.cpp/tools/grpc-server
|
||||
bash prepare.sh
|
||||
|
||||
rebuild:
|
||||
SHARED_DIR=$(SHARED_DIR) SERVER_SOURCE_DIR=$(SERVER_SOURCE_DIR) GRPC_SERVER_DIR=$(GRPC_SERVER_DIR) bash $(SHARED_DIR)/prepare.sh
|
||||
bash prepare.sh
|
||||
rm -rf grpc-server
|
||||
$(MAKE) grpc-server
|
||||
|
||||
package:
|
||||
bash $(SHARED_DIR)/package.sh
|
||||
bash package.sh
|
||||
|
||||
purge:
|
||||
rm -rf llama.cpp/build
|
||||
rm -rf llama.cpp/$(GRPC_SERVER_DIR)
|
||||
rm -rf llama.cpp/tools/grpc-server
|
||||
rm -rf grpc-server
|
||||
|
||||
clean: purge
|
||||
rm -rf llama.cpp
|
||||
|
||||
grpc-server: llama.cpp llama.cpp/$(GRPC_SERVER_DIR)
|
||||
grpc-server: llama.cpp llama.cpp/tools/grpc-server
|
||||
@echo "Building grpc-server with $(BUILD_TYPE) build type and $(CMAKE_ARGS)"
|
||||
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||
+bash -c "source $(ONEAPI_VARS); \
|
||||
|
||||
@@ -17,9 +17,7 @@
|
||||
#include "backend.pb.h"
|
||||
#include "backend.grpc.pb.h"
|
||||
#include "common.h"
|
||||
#ifdef HAS_AUTOPARSER
|
||||
#include "chat-auto-parser.h"
|
||||
#endif
|
||||
#include <getopt.h>
|
||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
@@ -42,45 +40,41 @@ using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
|
||||
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
|
||||
// gRPC bearer token auth for distributed mode.
|
||||
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
|
||||
public:
|
||||
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
|
||||
|
||||
bool IsBlocking() const override { return false; }
|
||||
// Cached auth token — empty means auth is disabled.
|
||||
static std::string g_grpc_auth_token;
|
||||
|
||||
grpc::Status Process(const InputMetadata& auth_metadata,
|
||||
grpc::AuthContext* /*context*/,
|
||||
OutputMetadata* /*consumed_auth_metadata*/,
|
||||
OutputMetadata* /*response_metadata*/) override {
|
||||
auto it = auth_metadata.find("authorization");
|
||||
if (it != auth_metadata.end()) {
|
||||
std::string expected = "Bearer " + token_;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
// Constant-time comparison
|
||||
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string token_;
|
||||
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
// Returns OK when auth is disabled or the token matches.
|
||||
static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||
if (g_grpc_auth_token.empty()) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
};
|
||||
auto metadata = context->client_metadata();
|
||||
auto it = metadata.find("authorization");
|
||||
if (it != metadata.end()) {
|
||||
std::string expected = "Bearer " + g_grpc_auth_token;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
if (expected.size() == got.size() &&
|
||||
ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
}
|
||||
|
||||
// END LocalAI
|
||||
|
||||
@@ -290,6 +284,12 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["ignore_eos"] = predict->ignoreeos();
|
||||
data["embeddings"] = predict->embeddings();
|
||||
|
||||
// Speculative decoding per-request overrides
|
||||
// NDraft maps to speculative.n_max (maximum draft tokens per speculation step)
|
||||
if (predict->ndraft() > 0) {
|
||||
data["speculative.n_max"] = predict->ndraft();
|
||||
}
|
||||
|
||||
// Add the correlationid to json data
|
||||
data["correlation_id"] = predict->correlationid();
|
||||
|
||||
@@ -408,6 +408,16 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->mmproj().empty()) {
|
||||
params.mmproj.path = request->mmproj();
|
||||
}
|
||||
|
||||
// Draft model for speculative decoding
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.mparams_dft.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
params.model_alias.insert(request->modelfile());
|
||||
if (!request->cachetypekey().empty()) {
|
||||
@@ -615,6 +625,48 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// If conversion fails, keep default value (8)
|
||||
}
|
||||
}
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
auto type = common_speculative_type_from_name(optval_str);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_min") || !strcmp(optname, "draft_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_min = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_min") || !strcmp(optname, "draft_p_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_min = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_split")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_split = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_n") || !strcmp(optname, "ngram_size_n")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_n = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_m") || !strcmp(optname, "ngram_size_m")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_m = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_min_hits") || !strcmp(optname, "ngram_min_hits")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_min_hits = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_gpu_layers")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_gpu_layers = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_ctx_size")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_ctx = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -759,13 +811,17 @@ private:
|
||||
public:
|
||||
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
||||
|
||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
grpc::Status Health(ServerContext* context, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
// Implement Health RPC
|
||||
reply->set_message("OK");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
// Implement LoadModel RPC
|
||||
common_params params;
|
||||
params_parse(ctx_server, request, params);
|
||||
@@ -964,6 +1020,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -1251,6 +1309,7 @@ public:
|
||||
|
||||
body_json["messages"] = messages_json;
|
||||
body_json["stream"] = true; // PredictStream is always streaming
|
||||
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
|
||||
|
||||
// Check if grammar is provided from Go layer (NoGrammar=false)
|
||||
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
||||
@@ -1558,8 +1617,11 @@ public:
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
// Without this, the PEG parser never produces diffs and the Go side
|
||||
// cannot detect tool calls or separate reasoning from content.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -1584,19 +1646,47 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result.
|
||||
// Handles both native format ({"content": "..."}) and OAI chat format
|
||||
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
|
||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||
backend::Reply reply;
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
reply.set_message(completion_text);
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
std::string completion_text;
|
||||
|
||||
if (res_json.contains("choices")) {
|
||||
// OAI chat format — extract content from choices[0].delta
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & delta = choices[0].value("delta", json::object());
|
||||
if (delta.contains("content") && !delta.at("content").is_null()) {
|
||||
completion_text = delta.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = res_json.value("content", "");
|
||||
}
|
||||
|
||||
reply.set_message(completion_text);
|
||||
|
||||
// Token counts: native format has top-level fields,
|
||||
// OAI format has them in "usage" (final chunk only)
|
||||
if (res_json.contains("usage")) {
|
||||
const auto & usage = res_json.at("usage");
|
||||
reply.set_tokens(usage.value("completion_tokens", 0));
|
||||
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
|
||||
} else {
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
}
|
||||
|
||||
// Timings: present as top-level "timings" in both formats
|
||||
if (res_json.contains("timings")) {
|
||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply.set_logprobs(logprobs_json.dump());
|
||||
@@ -1605,6 +1695,12 @@ public:
|
||||
return reply;
|
||||
};
|
||||
|
||||
// Attach chat deltas from the autoparser to a Reply.
|
||||
// When diffs are available, populate ChatDeltas on the reply.
|
||||
// The raw message is always preserved so the Go side can use it
|
||||
// for reasoning extraction and tool call parsing as a fallback
|
||||
// (important in distributed mode where ChatDeltas may not be
|
||||
// the primary parsing path).
|
||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||
// Try streaming partial result first
|
||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||
@@ -1667,6 +1763,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2287,8 +2385,9 @@ public:
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -2319,25 +2418,48 @@ public:
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||
GGML_ASSERT(final_res != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
// Handle both native format ({"content": "...", "tokens_predicted": N})
|
||||
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
|
||||
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
|
||||
std::string completion_text;
|
||||
int32_t tokens_predicted = 0;
|
||||
int32_t tokens_evaluated = 0;
|
||||
|
||||
if (result_json.contains("choices")) {
|
||||
// OAI chat format
|
||||
const auto & choices = result_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
completion_text = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
if (result_json.contains("usage")) {
|
||||
const auto & usage = result_json.at("usage");
|
||||
tokens_predicted = usage.value("completion_tokens", 0);
|
||||
tokens_evaluated = usage.value("prompt_tokens", 0);
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = result_json.value("content", "");
|
||||
tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
}
|
||||
reply->set_message(completion_text);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
// Timings: present in both formats as a top-level "timings" object
|
||||
if (result_json.contains("timings")) {
|
||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
reply->set_logprobs(logprobs_json.dump());
|
||||
}
|
||||
|
||||
// Populate chat deltas from the autoparser's final parsed message
|
||||
@@ -2353,7 +2475,20 @@ public:
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
json res_json = res->to_json();
|
||||
arr.push_back(res_json.value("content", ""));
|
||||
// Handle both native and OAI chat formats
|
||||
std::string result_content;
|
||||
if (res_json.contains("choices")) {
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
result_content = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result_content = res_json.value("content", "");
|
||||
}
|
||||
arr.push_back(result_content);
|
||||
|
||||
// Extract logprobs for each result
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
@@ -2385,6 +2520,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2565,7 +2702,9 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2667,7 +2806,6 @@ public:
|
||||
|
||||
response->set_rendered_template(rendered_template);
|
||||
|
||||
#ifdef HAS_AUTOPARSER
|
||||
// Run differential template analysis to detect tool format markers
|
||||
if (params_base.use_jinja) {
|
||||
try {
|
||||
@@ -2773,7 +2911,6 @@ public:
|
||||
SRV_WRN("ModelMetadata: failed to run autoparser analysis: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
@@ -2807,19 +2944,14 @@ int main(int argc, char** argv) {
|
||||
BackendServiceImpl service(ctx_server);
|
||||
|
||||
ServerBuilder builder;
|
||||
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
std::shared_ptr<grpc::ServerCredentials> creds;
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
creds = grpc::InsecureServerCredentials();
|
||||
creds->SetAuthMetadataProcessor(
|
||||
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
} else {
|
||||
creds = grpc::InsecureServerCredentials();
|
||||
}
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
|
||||
builder.AddListeningPort(server_address, creds);
|
||||
// Initialize bearer token auth if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
g_grpc_auth_token = auth_token;
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||
|
||||
@@ -5,21 +5,14 @@
|
||||
|
||||
set -e
|
||||
|
||||
# Use working directory (not script location) so forks that share this script work correctly
|
||||
CURDIR=$(pwd)
|
||||
SCRIPT_DIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${SCRIPT_DIR}/../../.."
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avrf $CURDIR/llama-cpp-* $CURDIR/package/
|
||||
# Copy run.sh — prefer local copy, fall back to shared dir (script location)
|
||||
if [ -f "$CURDIR/run.sh" ]; then
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
else
|
||||
cp -rfv $SCRIPT_DIR/run.sh $CURDIR/package/
|
||||
fi
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
|
||||
@@ -1,43 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
SHARED_DIR="${SHARED_DIR:-.}"
|
||||
SERVER_SOURCE_DIR="${SERVER_SOURCE_DIR:-tools/server}"
|
||||
GRPC_SERVER_DIR="${GRPC_SERVER_DIR:-tools/grpc-server}"
|
||||
## Patches
|
||||
|
||||
## Apply patches from the `patches` directory
|
||||
if [ -d "patches" ]; then
|
||||
for patch in $(ls patches); do
|
||||
echo "Applying patch $patch"
|
||||
patch -d llama.cpp/ -p1 < patches/$patch
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
set -e
|
||||
|
||||
# Copy server source files into grpc-server build directory
|
||||
for file in $(ls llama.cpp/${SERVER_SOURCE_DIR}/); do
|
||||
cp -rfv llama.cpp/${SERVER_SOURCE_DIR}/$file llama.cpp/${GRPC_SERVER_DIR}/
|
||||
for file in $(ls llama.cpp/tools/server/); do
|
||||
cp -rfv llama.cpp/tools/server/$file llama.cpp/tools/grpc-server/
|
||||
done
|
||||
|
||||
# Copy build files — prefer local overrides, fall back to SHARED_DIR
|
||||
for f in CMakeLists.txt grpc-server.cpp; do
|
||||
if [ -f "$f" ]; then
|
||||
cp -r "$f" llama.cpp/${GRPC_SERVER_DIR}/
|
||||
else
|
||||
cp -r "$SHARED_DIR/$f" llama.cpp/${GRPC_SERVER_DIR}/
|
||||
fi
|
||||
done
|
||||
|
||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/${GRPC_SERVER_DIR}/
|
||||
cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/${GRPC_SERVER_DIR}/
|
||||
|
||||
# Add grpc-server subdirectory to the parent CMakeLists.txt
|
||||
PARENT_CMAKELISTS="llama.cpp/$(dirname ${GRPC_SERVER_DIR})/CMakeLists.txt"
|
||||
cp -r CMakeLists.txt llama.cpp/tools/grpc-server/
|
||||
cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/
|
||||
|
||||
set +e
|
||||
if grep -q "grpc-server" "$PARENT_CMAKELISTS"; then
|
||||
if grep -q "grpc-server" llama.cpp/tools/CMakeLists.txt; then
|
||||
echo "grpc-server already added"
|
||||
else
|
||||
echo "add_subdirectory(grpc-server)" >> "$PARENT_CMAKELISTS"
|
||||
echo "add_subdirectory(grpc-server)" >> llama.cpp/tools/CMakeLists.txt
|
||||
fi
|
||||
set -e
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=87ecb95cbc65dc8e58e3d88f4f4a59a0939796f5
|
||||
STABLEDIFFUSION_GGML_VERSION?=7397ddaa86f4e8837d5261724678cde0f36d4d89
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -29,34 +29,6 @@
|
||||
nvidia-cuda-12: "cuda12-llama-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-llama-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-llama-cpp"
|
||||
- &llamacpp_tq
|
||||
name: "llama-cpp-tq"
|
||||
alias: "llama-cpp-tq"
|
||||
license: mit
|
||||
description: |
|
||||
TurboQuant llama.cpp fork - quantization research
|
||||
urls:
|
||||
- https://github.com/TheTom/llama-cpp-turboquant
|
||||
tags:
|
||||
- text-to-text
|
||||
- LLM
|
||||
- CPU
|
||||
- GPU
|
||||
- Metal
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-llama-cpp-tq"
|
||||
nvidia: "cuda12-llama-cpp-tq"
|
||||
intel: "intel-sycl-f16-llama-cpp-tq"
|
||||
amd: "rocm-llama-cpp-tq"
|
||||
metal: "metal-llama-cpp-tq"
|
||||
vulkan: "vulkan-llama-cpp-tq"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-llama-cpp-tq"
|
||||
nvidia-cuda-13: "cuda13-llama-cpp-tq"
|
||||
nvidia-cuda-12: "cuda12-llama-cpp-tq"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-llama-cpp-tq"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-llama-cpp-tq"
|
||||
- &whispercpp
|
||||
name: "whisper"
|
||||
alias: "whisper"
|
||||
@@ -1280,57 +1252,6 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-llama-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-llama-cpp
|
||||
# llama-cpp-tq (TurboQuant) concrete backends
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "cpu-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "cuda12-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "cuda13-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "rocm-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "intel-sycl-f16-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "intel-sycl-f32-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "vulkan-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "metal-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "nvidia-l4t-arm64-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-llama-cpp-tq
|
||||
- !!merge <<: *llamacpp_tq
|
||||
name: "cuda13-nvidia-l4t-arm64-llama-cpp-tq"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-llama-cpp-tq"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-llama-cpp-tq
|
||||
## whisper
|
||||
- !!merge <<: *whispercpp
|
||||
name: "nvidia-l4t-arm64-whisper"
|
||||
|
||||
@@ -36,6 +36,27 @@ type TokenUsage struct {
|
||||
Completion int
|
||||
TimingPromptProcessing float64
|
||||
TimingTokenGeneration float64
|
||||
ChatDeltas []*proto.ChatDelta // per-chunk deltas from C++ autoparser (only set during streaming)
|
||||
}
|
||||
|
||||
// HasChatDeltaContent returns true if any chat delta carries content or reasoning text.
|
||||
// Used to decide whether to prefer C++ autoparser deltas over Go-side tag extraction.
|
||||
func (t TokenUsage) HasChatDeltaContent() bool {
|
||||
for _, d := range t.ChatDeltas {
|
||||
if d.Content != "" || d.ReasoningContent != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChatDeltaReasoningAndContent extracts accumulated reasoning and content from chat deltas.
|
||||
func (t TokenUsage) ChatDeltaReasoningAndContent() (reasoning, content string) {
|
||||
for _, d := range t.ChatDeltas {
|
||||
content += d.Content
|
||||
reasoning += d.ReasoningContent
|
||||
}
|
||||
return reasoning, content
|
||||
}
|
||||
|
||||
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
|
||||
@@ -171,6 +192,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
|
||||
}
|
||||
|
||||
// Attach per-chunk chat deltas to tokenUsage so the callback can use them
|
||||
tokenUsage.ChatDeltas = reply.ChatDeltas
|
||||
|
||||
// Parse logprobs from reply if present (collect from last chunk that has them)
|
||||
if len(reply.Logprobs) > 0 {
|
||||
var parsedLogprobs schema.Logprobs
|
||||
@@ -200,6 +224,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if len(msg) == 0 {
|
||||
tokenCallback("", tokenUsage)
|
||||
}
|
||||
|
||||
// Clear per-chunk deltas so they don't leak to the next chunk
|
||||
tokenUsage.ChatDeltas = nil
|
||||
})
|
||||
if len(allChatDeltas) > 0 {
|
||||
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
. "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -107,3 +108,111 @@ var _ = Describe("LLM tests", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TokenUsage ChatDelta helpers", func() {
|
||||
Describe("HasChatDeltaContent", func() {
|
||||
It("should return false when ChatDeltas is nil", func() {
|
||||
usage := TokenUsage{}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when ChatDeltas is empty", func() {
|
||||
usage := TokenUsage{ChatDeltas: []*pb.ChatDelta{}}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return false when all deltas have empty content and reasoning", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "", ReasoningContent: ""},
|
||||
{Content: ""},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return true when a delta has content", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "hello"},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should return true when a delta has reasoning content", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking..."},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should return true when a delta has both content and reasoning", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "hello", ReasoningContent: "thinking..."},
|
||||
},
|
||||
}
|
||||
Expect(usage.HasChatDeltaContent()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ChatDeltaReasoningAndContent", func() {
|
||||
It("should return empty strings when ChatDeltas is nil", func() {
|
||||
usage := TokenUsage{}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
Expect(content).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should concatenate content from multiple deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "Hello"},
|
||||
{Content: " world"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(content).To(Equal("Hello world"))
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should concatenate reasoning from multiple deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "step 1"},
|
||||
{ReasoningContent: " step 2"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("step 1 step 2"))
|
||||
Expect(content).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should separate reasoning and content from mixed deltas", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking"},
|
||||
{Content: "answer"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("thinking"))
|
||||
Expect(content).To(Equal("answer"))
|
||||
})
|
||||
|
||||
It("should handle deltas with both fields set", func() {
|
||||
usage := TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "a", ReasoningContent: "r1"},
|
||||
{Content: "b", ReasoningContent: "r2"},
|
||||
},
|
||||
}
|
||||
reasoning, content := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(reasoning).To(Equal("r1r2"))
|
||||
Expect(content).To(Equal("ab"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -512,11 +512,9 @@ func (s *backendSupervisor) stopBackend(backend string) {
|
||||
|
||||
// Network I/O outside the lock
|
||||
client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
|
||||
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
|
||||
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
|
||||
if err := freeFunc.Free(context.Background()); err != nil {
|
||||
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
|
||||
}
|
||||
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
|
||||
if err := client.Free(context.Background()); err != nil {
|
||||
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
|
||||
}
|
||||
|
||||
xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
|
||||
@@ -692,13 +690,13 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||
|
||||
// backend.delete — stop backend + delete files (request-reply)
|
||||
s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) {
|
||||
xlog.Info("Received NATS backend.delete event")
|
||||
var req messaging.BackendDeleteRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)}
|
||||
replyJSON(reply, resp)
|
||||
return
|
||||
}
|
||||
xlog.Info("Received NATS backend.delete event", "backend", req.Backend)
|
||||
|
||||
// Stop if running this backend
|
||||
if s.isRunning(req.Backend) {
|
||||
@@ -774,10 +772,8 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||
if targetAddr != "" {
|
||||
// Best-effort gRPC Free()
|
||||
client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
|
||||
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
|
||||
if err := freeFunc.Free(context.Background()); err != nil {
|
||||
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
|
||||
}
|
||||
if err := client.Free(context.Background()); err != nil {
|
||||
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
"qwen2-vl": {"min_p":0.1,"repeat_penalty":1,"temperature":1.5,"top_k":-1,"top_p":0.95},
|
||||
"qwen2": {"min_p":0,"repeat_penalty":1,"temperature":0.7,"top_k":20,"top_p":0.8},
|
||||
"qwq": {"min_p":0,"repeat_penalty":1,"temperature":0.6,"top_k":40,"top_p":0.95},
|
||||
"gemma-4": {"min_p":0,"presence_penalty":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"gemma-3n": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"gemma-3": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
"medgemma": {"min_p":0,"repeat_penalty":1,"temperature":1,"top_k":64,"top_p":0.95},
|
||||
@@ -53,5 +54,5 @@
|
||||
"grok": {"min_p":0.01,"repeat_penalty":1,"temperature":1,"top_k":-1,"top_p":0.95},
|
||||
"mimo": {"min_p":0.01,"repeat_penalty":1,"temperature":0.7,"top_k":-1,"top_p":0.95}
|
||||
},
|
||||
"patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
|
||||
"patterns": ["qwen3.5","qwen3-coder","qwen3-next","qwen3-vl","qwen3","qwen2.5-coder","qwen2.5-vl","qwen2.5-omni","qwen2.5-math","qwen2.5","qwen2-vl","qwen2","qwq","gemma-4","gemma-3n","gemma-3","medgemma","gemma-2","llama-4","llama-3.3","llama-3.2","llama-3.1","llama-3","phi-4","phi-3","mistral-nemo","mistral-small","mistral-large","magistral","ministral","devstral","pixtral","deepseek-r1","deepseek-v3","deepseek-ocr","glm-5","glm-4","nemotron","minimax-m2.5","minimax","gpt-oss","granite-4","kimi-k2","kimi","lfm2","smollm","olmo","falcon","ernie","seed","grok","mimo"]
|
||||
}
|
||||
|
||||
132
core/config/meta/build.go
Normal file
132
core/config/meta/build.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
cachedMetadata *ConfigMetadata
|
||||
cacheMu sync.RWMutex
|
||||
)
|
||||
|
||||
// BuildConfigMetadata reflects on the given struct type (ModelConfig),
|
||||
// merges the enrichment registry, and returns the full ConfigMetadata.
|
||||
// The result is cached in memory after the first call.
|
||||
func BuildConfigMetadata(modelConfigType reflect.Type) *ConfigMetadata {
|
||||
cacheMu.RLock()
|
||||
if cachedMetadata != nil {
|
||||
cacheMu.RUnlock()
|
||||
return cachedMetadata
|
||||
}
|
||||
cacheMu.RUnlock()
|
||||
|
||||
cacheMu.Lock()
|
||||
defer cacheMu.Unlock()
|
||||
|
||||
if cachedMetadata != nil {
|
||||
return cachedMetadata
|
||||
}
|
||||
|
||||
cachedMetadata = buildConfigMetadataUncached(modelConfigType, DefaultRegistry())
|
||||
return cachedMetadata
|
||||
}
|
||||
|
||||
// buildConfigMetadataUncached does the actual work without caching.
|
||||
func buildConfigMetadataUncached(modelConfigType reflect.Type, registry map[string]FieldMetaOverride) *ConfigMetadata {
|
||||
fields := WalkModelConfig(modelConfigType)
|
||||
|
||||
for i := range fields {
|
||||
override, ok := registry[fields[i].Path]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
applyOverride(&fields[i], override)
|
||||
}
|
||||
|
||||
allSections := DefaultSections()
|
||||
|
||||
sectionOrder := make(map[string]int, len(allSections))
|
||||
for _, s := range allSections {
|
||||
sectionOrder[s.ID] = s.Order
|
||||
}
|
||||
|
||||
sort.SliceStable(fields, func(i, j int) bool {
|
||||
si := sectionOrder[fields[i].Section]
|
||||
sj := sectionOrder[fields[j].Section]
|
||||
if si != sj {
|
||||
return si < sj
|
||||
}
|
||||
return fields[i].Order < fields[j].Order
|
||||
})
|
||||
|
||||
usedSections := make(map[string]bool)
|
||||
for _, f := range fields {
|
||||
usedSections[f.Section] = true
|
||||
}
|
||||
|
||||
var sections []Section
|
||||
for _, s := range allSections {
|
||||
if usedSections[s.ID] {
|
||||
sections = append(sections, s)
|
||||
}
|
||||
}
|
||||
|
||||
return &ConfigMetadata{
|
||||
Sections: sections,
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
// applyOverride merges non-zero override values into the field.
|
||||
func applyOverride(f *FieldMeta, o FieldMetaOverride) {
|
||||
if o.Section != "" {
|
||||
f.Section = o.Section
|
||||
}
|
||||
if o.Label != "" {
|
||||
f.Label = o.Label
|
||||
}
|
||||
if o.Description != "" {
|
||||
f.Description = o.Description
|
||||
}
|
||||
if o.Component != "" {
|
||||
f.Component = o.Component
|
||||
}
|
||||
if o.Placeholder != "" {
|
||||
f.Placeholder = o.Placeholder
|
||||
}
|
||||
if o.Default != nil {
|
||||
f.Default = o.Default
|
||||
}
|
||||
if o.Min != nil {
|
||||
f.Min = o.Min
|
||||
}
|
||||
if o.Max != nil {
|
||||
f.Max = o.Max
|
||||
}
|
||||
if o.Step != nil {
|
||||
f.Step = o.Step
|
||||
}
|
||||
if o.Options != nil {
|
||||
f.Options = o.Options
|
||||
}
|
||||
if o.AutocompleteProvider != "" {
|
||||
f.AutocompleteProvider = o.AutocompleteProvider
|
||||
}
|
||||
if o.VRAMImpact {
|
||||
f.VRAMImpact = true
|
||||
}
|
||||
if o.Advanced {
|
||||
f.Advanced = true
|
||||
}
|
||||
if o.Order != 0 {
|
||||
f.Order = o.Order
|
||||
}
|
||||
}
|
||||
|
||||
// BuildForTest builds metadata without caching, for use in tests.
|
||||
func BuildForTest(modelConfigType reflect.Type, registry map[string]FieldMetaOverride) *ConfigMetadata {
|
||||
return buildConfigMetadataUncached(modelConfigType, registry)
|
||||
}
|
||||
|
||||
211
core/config/meta/build_test.go
Normal file
211
core/config/meta/build_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
)
|
||||
|
||||
func TestBuildConfigMetadata(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
if len(md.Sections) == 0 {
|
||||
t.Fatal("expected sections, got 0")
|
||||
}
|
||||
if len(md.Fields) == 0 {
|
||||
t.Fatal("expected fields, got 0")
|
||||
}
|
||||
|
||||
// Verify sections are ordered
|
||||
for i := 1; i < len(md.Sections); i++ {
|
||||
if md.Sections[i].Order < md.Sections[i-1].Order {
|
||||
t.Errorf("sections not ordered: %s (order=%d) before %s (order=%d)",
|
||||
md.Sections[i-1].ID, md.Sections[i-1].Order,
|
||||
md.Sections[i].ID, md.Sections[i].Order)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryOverrides(t *testing.T) {
|
||||
registry := map[string]meta.FieldMetaOverride{
|
||||
"name": {
|
||||
Label: "My Custom Label",
|
||||
Description: "Custom description",
|
||||
Component: "textarea",
|
||||
Order: 999,
|
||||
},
|
||||
}
|
||||
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), registry)
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
f, ok := byPath["name"]
|
||||
if !ok {
|
||||
t.Fatal("field 'name' not found")
|
||||
}
|
||||
if f.Label != "My Custom Label" {
|
||||
t.Errorf("expected label 'My Custom Label', got %q", f.Label)
|
||||
}
|
||||
if f.Description != "Custom description" {
|
||||
t.Errorf("expected description 'Custom description', got %q", f.Description)
|
||||
}
|
||||
if f.Component != "textarea" {
|
||||
t.Errorf("expected component 'textarea', got %q", f.Component)
|
||||
}
|
||||
if f.Order != 999 {
|
||||
t.Errorf("expected order 999, got %d", f.Order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregisteredFieldsGetDefaults(t *testing.T) {
|
||||
// Use empty registry - all fields should still get auto-generated metadata
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), map[string]meta.FieldMetaOverride{})
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// context_size should still exist with auto-generated label
|
||||
f, ok := byPath["context_size"]
|
||||
if !ok {
|
||||
t.Fatal("field 'context_size' not found")
|
||||
}
|
||||
if f.Label == "" {
|
||||
t.Error("expected auto-generated label, got empty")
|
||||
}
|
||||
if f.UIType != "int" {
|
||||
t.Errorf("expected UIType 'int', got %q", f.UIType)
|
||||
}
|
||||
if f.Component == "" {
|
||||
t.Error("expected auto-generated component, got empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistryOverridesApply(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Verify enriched fields got their overrides
|
||||
tests := []struct {
|
||||
path string
|
||||
label string
|
||||
description string
|
||||
vramImpact bool
|
||||
}{
|
||||
{"context_size", "Context Size", "Maximum context window in tokens", true},
|
||||
{"gpu_layers", "GPU Layers", "Number of layers to offload to GPU (-1 = all)", true},
|
||||
{"backend", "Backend", "The inference backend to use (e.g. llama-cpp, vllm, diffusers)", false},
|
||||
{"parameters.temperature", "Temperature", "Sampling temperature (higher = more creative, lower = more deterministic)", false},
|
||||
{"template.chat", "Chat Template", "Go template for chat completion requests", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.Label != tt.label {
|
||||
t.Errorf("field %q: expected label %q, got %q", tt.path, tt.label, f.Label)
|
||||
}
|
||||
if f.Description != tt.description {
|
||||
t.Errorf("field %q: expected description %q, got %q", tt.path, tt.description, f.Description)
|
||||
}
|
||||
if f.VRAMImpact != tt.vramImpact {
|
||||
t.Errorf("field %q: expected vramImpact=%v, got %v", tt.path, tt.vramImpact, f.VRAMImpact)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticOptionsFields(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Fields with static options should have Options populated and no AutocompleteProvider
|
||||
staticFields := []string{"quantization", "cache_type_k", "cache_type_v", "diffusers.pipeline_type", "diffusers.scheduler_type"}
|
||||
for _, path := range staticFields {
|
||||
f, ok := byPath[path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", path)
|
||||
continue
|
||||
}
|
||||
if len(f.Options) == 0 {
|
||||
t.Errorf("field %q: expected Options to be populated", path)
|
||||
}
|
||||
if f.AutocompleteProvider != "" {
|
||||
t.Errorf("field %q: expected no AutocompleteProvider, got %q", path, f.AutocompleteProvider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicProviderFields(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Fields with dynamic providers should have AutocompleteProvider and no Options
|
||||
dynamicFields := map[string]string{
|
||||
"backend": meta.ProviderBackends,
|
||||
"pipeline.llm": meta.ProviderModelsChat,
|
||||
"pipeline.tts": meta.ProviderModelsTTS,
|
||||
"pipeline.transcription": meta.ProviderModelsTranscript,
|
||||
"pipeline.vad": meta.ProviderModelsVAD,
|
||||
}
|
||||
for path, expectedProvider := range dynamicFields {
|
||||
f, ok := byPath[path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", path)
|
||||
continue
|
||||
}
|
||||
if f.AutocompleteProvider != expectedProvider {
|
||||
t.Errorf("field %q: expected AutocompleteProvider %q, got %q", path, expectedProvider, f.AutocompleteProvider)
|
||||
}
|
||||
if len(f.Options) != 0 {
|
||||
t.Errorf("field %q: expected no Options, got %d", path, len(f.Options))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVRAMImpactFields(t *testing.T) {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
|
||||
var vramFields []string
|
||||
for _, f := range md.Fields {
|
||||
if f.VRAMImpact {
|
||||
vramFields = append(vramFields, f.Path)
|
||||
}
|
||||
}
|
||||
|
||||
if len(vramFields) == 0 {
|
||||
t.Error("expected some VRAM impact fields, got 0")
|
||||
}
|
||||
|
||||
// context_size and gpu_layers should be marked
|
||||
expected := map[string]bool{"context_size": true, "gpu_layers": true}
|
||||
for _, path := range vramFields {
|
||||
if expected[path] {
|
||||
delete(expected, path)
|
||||
}
|
||||
}
|
||||
for path := range expected {
|
||||
t.Errorf("expected VRAM impact field %q not found", path)
|
||||
}
|
||||
}
|
||||
63
core/config/meta/constants.go
Normal file
63
core/config/meta/constants.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package meta
|
||||
|
||||
// Dynamic autocomplete provider constants (runtime lookup required).
|
||||
const (
|
||||
ProviderBackends = "backends"
|
||||
ProviderModels = "models"
|
||||
ProviderModelsChat = "models:chat"
|
||||
ProviderModelsTTS = "models:tts"
|
||||
ProviderModelsTranscript = "models:transcript"
|
||||
ProviderModelsVAD = "models:vad"
|
||||
)
|
||||
|
||||
// Static option lists embedded directly in field metadata.
|
||||
|
||||
var QuantizationOptions = []FieldOption{
|
||||
{Value: "q4_0", Label: "Q4_0"},
|
||||
{Value: "q4_1", Label: "Q4_1"},
|
||||
{Value: "q5_0", Label: "Q5_0"},
|
||||
{Value: "q5_1", Label: "Q5_1"},
|
||||
{Value: "q8_0", Label: "Q8_0"},
|
||||
{Value: "q2_K", Label: "Q2_K"},
|
||||
{Value: "q3_K_S", Label: "Q3_K_S"},
|
||||
{Value: "q3_K_M", Label: "Q3_K_M"},
|
||||
{Value: "q3_K_L", Label: "Q3_K_L"},
|
||||
{Value: "q4_K_S", Label: "Q4_K_S"},
|
||||
{Value: "q4_K_M", Label: "Q4_K_M"},
|
||||
{Value: "q5_K_S", Label: "Q5_K_S"},
|
||||
{Value: "q5_K_M", Label: "Q5_K_M"},
|
||||
{Value: "q6_K", Label: "Q6_K"},
|
||||
}
|
||||
|
||||
var CacheTypeOptions = []FieldOption{
|
||||
{Value: "f16", Label: "F16"},
|
||||
{Value: "f32", Label: "F32"},
|
||||
{Value: "q8_0", Label: "Q8_0"},
|
||||
{Value: "q4_0", Label: "Q4_0"},
|
||||
{Value: "q4_1", Label: "Q4_1"},
|
||||
{Value: "q5_0", Label: "Q5_0"},
|
||||
{Value: "q5_1", Label: "Q5_1"},
|
||||
}
|
||||
|
||||
var DiffusersPipelineOptions = []FieldOption{
|
||||
{Value: "StableDiffusionPipeline", Label: "StableDiffusionPipeline"},
|
||||
{Value: "StableDiffusionImg2ImgPipeline", Label: "StableDiffusionImg2ImgPipeline"},
|
||||
{Value: "StableDiffusionXLPipeline", Label: "StableDiffusionXLPipeline"},
|
||||
{Value: "StableDiffusionXLImg2ImgPipeline", Label: "StableDiffusionXLImg2ImgPipeline"},
|
||||
{Value: "StableDiffusionDepth2ImgPipeline", Label: "StableDiffusionDepth2ImgPipeline"},
|
||||
{Value: "DiffusionPipeline", Label: "DiffusionPipeline"},
|
||||
{Value: "StableVideoDiffusionPipeline", Label: "StableVideoDiffusionPipeline"},
|
||||
}
|
||||
|
||||
var DiffusersSchedulerOptions = []FieldOption{
|
||||
{Value: "ddim", Label: "DDIM"},
|
||||
{Value: "ddpm", Label: "DDPM"},
|
||||
{Value: "pndm", Label: "PNDM"},
|
||||
{Value: "lms", Label: "LMS"},
|
||||
{Value: "euler", Label: "Euler"},
|
||||
{Value: "euler_a", Label: "Euler A"},
|
||||
{Value: "dpm_multistep", Label: "DPM Multistep"},
|
||||
{Value: "dpm_singlestep", Label: "DPM Singlestep"},
|
||||
{Value: "heun", Label: "Heun"},
|
||||
{Value: "unipc", Label: "UniPC"},
|
||||
}
|
||||
241
core/config/meta/reflect.go
Normal file
241
core/config/meta/reflect.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// WalkModelConfig uses reflection to discover all exported, YAML-tagged fields
|
||||
// in the given struct type (expected to be config.ModelConfig) and returns a
|
||||
// slice of FieldMeta with sensible defaults derived from the type information.
|
||||
func WalkModelConfig(t reflect.Type) []FieldMeta {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
var fields []FieldMeta
|
||||
walkStruct(t, "", &fields)
|
||||
return fields
|
||||
}
|
||||
|
||||
// walkStruct recursively walks a struct type, collecting FieldMeta entries.
|
||||
// prefix is the dot-path prefix for nested structs (e.g. "function.grammar.").
|
||||
func walkStruct(t reflect.Type, prefix string, out *[]FieldMeta) {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
for sf := range t.Fields() {
|
||||
if !sf.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
yamlTag := sf.Tag.Get("yaml")
|
||||
if yamlTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
yamlKey, opts := parseTag(yamlTag)
|
||||
|
||||
// Handle inline embedding (e.g. LLMConfig `yaml:",inline"`)
|
||||
if opts.contains("inline") {
|
||||
ft := sf.Type
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
if ft.Kind() == reflect.Struct {
|
||||
walkStruct(ft, prefix, out)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If no yaml key and it's an embedded struct without inline, skip unknown pattern
|
||||
if yamlKey == "" {
|
||||
ft := sf.Type
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
// Anonymous struct without yaml tag - treat as inline
|
||||
if sf.Anonymous && ft.Kind() == reflect.Struct {
|
||||
walkStruct(ft, prefix, out)
|
||||
continue
|
||||
}
|
||||
// Named field without yaml tag - skip
|
||||
continue
|
||||
}
|
||||
|
||||
ft := sf.Type
|
||||
isPtr := ft.Kind() == reflect.Pointer
|
||||
if isPtr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
// Named nested struct (not a special type) -> recurse with prefix
|
||||
if ft.Kind() == reflect.Struct && !isSpecialType(ft) {
|
||||
nestedPrefix := prefix + yamlKey + "."
|
||||
walkStruct(ft, nestedPrefix, out)
|
||||
continue
|
||||
}
|
||||
|
||||
// Leaf field
|
||||
path := prefix + yamlKey
|
||||
goType := sf.Type.String()
|
||||
uiType, component := inferUIType(sf.Type)
|
||||
section := inferSection(prefix)
|
||||
label := labelFromKey(yamlKey)
|
||||
|
||||
*out = append(*out, FieldMeta{
|
||||
Path: path,
|
||||
YAMLKey: yamlKey,
|
||||
GoType: goType,
|
||||
UIType: uiType,
|
||||
Pointer: isPtr,
|
||||
Section: section,
|
||||
Label: label,
|
||||
Component: component,
|
||||
Order: len(*out),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// isSpecialType returns true for struct types that should be treated as leaf
|
||||
// values rather than recursed into (e.g. custom JSON marshalers).
|
||||
func isSpecialType(t reflect.Type) bool {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
name := t.Name()
|
||||
// LogprobsValue, URI types are leaf values despite being structs
|
||||
switch name {
|
||||
case "LogprobsValue", "URI":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// inferUIType maps a Go reflect.Type to a UI type string and default component.
|
||||
func inferUIType(t reflect.Type) (uiType, component string) {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool", "toggle"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return "int", "number"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "int", "number"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "float", "number"
|
||||
case reflect.String:
|
||||
return "string", "input"
|
||||
case reflect.Slice:
|
||||
elem := t.Elem()
|
||||
if elem.Kind() == reflect.String {
|
||||
return "[]string", "string-list"
|
||||
}
|
||||
if elem.Kind() == reflect.Pointer {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if elem.Kind() == reflect.Struct {
|
||||
return "[]object", "json-editor"
|
||||
}
|
||||
return "[]any", "json-editor"
|
||||
case reflect.Map:
|
||||
return "map", "map-editor"
|
||||
case reflect.Struct:
|
||||
// Special types treated as leaves
|
||||
if isSpecialType(t) {
|
||||
return "bool", "toggle" // LogprobsValue
|
||||
}
|
||||
return "object", "json-editor"
|
||||
default:
|
||||
return "any", "input"
|
||||
}
|
||||
}
|
||||
|
||||
// inferSection determines the config section from the dot-path prefix.
|
||||
func inferSection(prefix string) string {
|
||||
if prefix == "" {
|
||||
return "general"
|
||||
}
|
||||
// Remove trailing dot
|
||||
p := strings.TrimSuffix(prefix, ".")
|
||||
|
||||
// Use the top-level prefix to determine section
|
||||
parts := strings.SplitN(p, ".", 2)
|
||||
top := parts[0]
|
||||
|
||||
switch top {
|
||||
case "parameters":
|
||||
return "parameters"
|
||||
case "template":
|
||||
return "templates"
|
||||
case "function":
|
||||
return "functions"
|
||||
case "reasoning":
|
||||
return "reasoning"
|
||||
case "diffusers":
|
||||
return "diffusers"
|
||||
case "tts":
|
||||
return "tts"
|
||||
case "pipeline":
|
||||
return "pipeline"
|
||||
case "grpc":
|
||||
return "grpc"
|
||||
case "agent":
|
||||
return "agent"
|
||||
case "mcp":
|
||||
return "mcp"
|
||||
case "feature_flags":
|
||||
return "other"
|
||||
case "limit_mm_per_prompt":
|
||||
return "llm"
|
||||
default:
|
||||
return "other"
|
||||
}
|
||||
}
|
||||
|
||||
// labelFromKey converts a yaml key like "context_size" to "Context Size".
|
||||
func labelFromKey(key string) string {
|
||||
parts := strings.Split(key, "_")
|
||||
for i, p := range parts {
|
||||
if len(p) > 0 {
|
||||
runes := []rune(p)
|
||||
runes[0] = unicode.ToUpper(runes[0])
|
||||
parts[i] = string(runes)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// tagOptions is a set of comma-separated yaml tag options.
|
||||
type tagOptions string
|
||||
|
||||
func (o tagOptions) contains(optName string) bool {
|
||||
s := string(o)
|
||||
for s != "" {
|
||||
var name string
|
||||
if name, s, _ = strings.Cut(s, ","); name == optName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseTag splits a yaml struct tag into the key name and options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if tag == "" {
|
||||
return "", ""
|
||||
}
|
||||
before, after, found := strings.Cut(tag, ",")
|
||||
if found {
|
||||
return before, tagOptions(after)
|
||||
}
|
||||
return tag, ""
|
||||
}
|
||||
|
||||
208
core/config/meta/reflect_test.go
Normal file
208
core/config/meta/reflect_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
)
|
||||
|
||||
func TestWalkModelConfig(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
if len(fields) == 0 {
|
||||
t.Fatal("expected fields from ModelConfig, got 0")
|
||||
}
|
||||
|
||||
// Build a lookup by path
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// Verify some top-level fields exist
|
||||
for _, path := range []string{"name", "backend", "cuda", "step"} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify inline LLMConfig fields appear at top level (no prefix)
|
||||
for _, path := range []string{"context_size", "gpu_layers", "threads", "mmap"} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected inline LLMConfig field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify nested struct fields have correct prefix
|
||||
for _, path := range []string{
|
||||
"template.chat",
|
||||
"template.completion",
|
||||
"template.use_tokenizer_template",
|
||||
"function.grammar.parallel_calls",
|
||||
"function.grammar.mixed_mode",
|
||||
"diffusers.pipeline_type",
|
||||
"diffusers.cuda",
|
||||
"pipeline.llm",
|
||||
"pipeline.tts",
|
||||
"reasoning.disable",
|
||||
"agent.max_iterations",
|
||||
"grpc.attempts",
|
||||
} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected nested field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify PredictionOptions fields have parameters. prefix
|
||||
for _, path := range []string{
|
||||
"parameters.temperature",
|
||||
"parameters.top_p",
|
||||
"parameters.top_k",
|
||||
"parameters.max_tokens",
|
||||
"parameters.seed",
|
||||
} {
|
||||
if _, ok := byPath[path]; !ok {
|
||||
t.Errorf("expected parameters field %q not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify TTSConfig fields have tts. prefix
|
||||
if _, ok := byPath["tts.voice"]; !ok {
|
||||
t.Error("expected tts.voice field not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipsYAMLDashFields(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
// modelConfigFile has yaml:"-" tag, should be skipped
|
||||
for _, f := range fields {
|
||||
if f.Path == "modelConfigFile" || f.Path == "modelTemplate" {
|
||||
t.Errorf("field %q should have been skipped (yaml:\"-\")", f.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapping(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
uiType string
|
||||
pointer bool
|
||||
}{
|
||||
{"name", "string", false},
|
||||
{"cuda", "bool", false},
|
||||
{"context_size", "int", true},
|
||||
{"gpu_layers", "int", true},
|
||||
{"threads", "int", true},
|
||||
{"f16", "bool", true},
|
||||
{"mmap", "bool", true},
|
||||
{"stopwords", "[]string", false},
|
||||
{"roles", "map", false},
|
||||
{"parameters.temperature", "float", true},
|
||||
{"parameters.top_k", "int", true},
|
||||
{"function.grammar.parallel_calls", "bool", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.UIType != tt.uiType {
|
||||
t.Errorf("field %q: expected UIType %q, got %q", tt.path, tt.uiType, f.UIType)
|
||||
}
|
||||
if f.Pointer != tt.pointer {
|
||||
t.Errorf("field %q: expected Pointer=%v, got %v", tt.path, tt.pointer, f.Pointer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSectionAssignment(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
section string
|
||||
}{
|
||||
{"name", "general"},
|
||||
{"backend", "general"},
|
||||
{"context_size", "general"}, // inline LLMConfig -> no prefix -> general
|
||||
{"parameters.temperature", "parameters"},
|
||||
{"template.chat", "templates"},
|
||||
{"function.grammar.parallel_calls", "functions"},
|
||||
{"diffusers.cuda", "diffusers"},
|
||||
{"pipeline.llm", "pipeline"},
|
||||
{"reasoning.disable", "reasoning"},
|
||||
{"agent.max_iterations", "agent"},
|
||||
{"grpc.attempts", "grpc"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.Section != tt.section {
|
||||
t.Errorf("field %q: expected section %q, got %q", tt.path, tt.section, f.Section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLabelGeneration(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
byPath := make(map[string]meta.FieldMeta, len(fields))
|
||||
for _, f := range fields {
|
||||
byPath[f.Path] = f
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
label string
|
||||
}{
|
||||
{"context_size", "Context Size"},
|
||||
{"gpu_layers", "Gpu Layers"},
|
||||
{"name", "Name"},
|
||||
{"cuda", "Cuda"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
f, ok := byPath[tt.path]
|
||||
if !ok {
|
||||
t.Errorf("field %q not found", tt.path)
|
||||
continue
|
||||
}
|
||||
if f.Label != tt.label {
|
||||
t.Errorf("field %q: expected label %q, got %q", tt.path, tt.label, f.Label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldCount(t *testing.T) {
|
||||
fields := meta.WalkModelConfig(reflect.TypeOf(config.ModelConfig{}))
|
||||
// We expect a large number of fields (100+) given the config complexity
|
||||
if len(fields) < 80 {
|
||||
t.Errorf("expected at least 80 fields, got %d", len(fields))
|
||||
}
|
||||
t.Logf("Total fields discovered: %d", len(fields))
|
||||
}
|
||||
314
core/config/meta/registry.go
Normal file
314
core/config/meta/registry.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package meta
|
||||
|
||||
// DefaultRegistry returns enrichment overrides for the ~30 most commonly used
|
||||
// config fields. Fields not listed here still appear with auto-generated
|
||||
// labels and type-inferred components.
|
||||
func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
f64 := func(v float64) *float64 { return &v }
|
||||
|
||||
return map[string]FieldMetaOverride{
|
||||
// --- General ---
|
||||
"name": {
|
||||
Section: "general",
|
||||
Label: "Model Name",
|
||||
Description: "Unique identifier for this model configuration",
|
||||
Component: "input",
|
||||
Order: 0,
|
||||
},
|
||||
"backend": {
|
||||
Section: "general",
|
||||
Label: "Backend",
|
||||
Description: "The inference backend to use (e.g. llama-cpp, vllm, diffusers)",
|
||||
Component: "select",
|
||||
AutocompleteProvider: ProviderBackends,
|
||||
Order: 1,
|
||||
},
|
||||
"description": {
|
||||
Section: "general",
|
||||
Label: "Description",
|
||||
Description: "Human-readable description of what this model does",
|
||||
Component: "textarea",
|
||||
Order: 2,
|
||||
},
|
||||
"usage": {
|
||||
Section: "general",
|
||||
Label: "Usage",
|
||||
Description: "Usage instructions or notes",
|
||||
Component: "textarea",
|
||||
Advanced: true,
|
||||
Order: 3,
|
||||
},
|
||||
"cuda": {
|
||||
Section: "general",
|
||||
Label: "CUDA",
|
||||
Description: "Explicitly enable CUDA acceleration",
|
||||
Order: 5,
|
||||
},
|
||||
"known_usecases": {
|
||||
Section: "general",
|
||||
Label: "Known Use Cases",
|
||||
Description: "Capabilities this model supports (e.g. FLAG_CHAT, FLAG_COMPLETION)",
|
||||
Component: "string-list",
|
||||
Order: 6,
|
||||
},
|
||||
|
||||
// --- LLM ---
|
||||
"context_size": {
|
||||
Section: "llm",
|
||||
Label: "Context Size",
|
||||
Description: "Maximum context window in tokens",
|
||||
Component: "number",
|
||||
VRAMImpact: true,
|
||||
Order: 10,
|
||||
},
|
||||
"gpu_layers": {
|
||||
Section: "llm",
|
||||
Label: "GPU Layers",
|
||||
Description: "Number of layers to offload to GPU (-1 = all)",
|
||||
Component: "number",
|
||||
Min: f64(-1),
|
||||
VRAMImpact: true,
|
||||
Order: 11,
|
||||
},
|
||||
"threads": {
|
||||
Section: "llm",
|
||||
Label: "Threads",
|
||||
Description: "Number of CPU threads for inference",
|
||||
Component: "number",
|
||||
Min: f64(1),
|
||||
Order: 12,
|
||||
},
|
||||
"f16": {
|
||||
Section: "llm",
|
||||
Label: "F16",
|
||||
Description: "Use 16-bit floating point for key/value cache",
|
||||
Order: 13,
|
||||
},
|
||||
"mmap": {
|
||||
Section: "llm",
|
||||
Label: "Memory Map",
|
||||
Description: "Use memory-mapped files for model loading",
|
||||
Order: 14,
|
||||
},
|
||||
"mmlock": {
|
||||
Section: "llm",
|
||||
Label: "Memory Lock",
|
||||
Description: "Lock model memory to prevent swapping",
|
||||
Advanced: true,
|
||||
Order: 15,
|
||||
},
|
||||
"low_vram": {
|
||||
Section: "llm",
|
||||
Label: "Low VRAM",
|
||||
Description: "Optimize for systems with limited GPU memory",
|
||||
VRAMImpact: true,
|
||||
Order: 16,
|
||||
},
|
||||
"embeddings": {
|
||||
Section: "llm",
|
||||
Label: "Embeddings",
|
||||
Description: "Enable embedding generation mode",
|
||||
Order: 17,
|
||||
},
|
||||
"quantization": {
|
||||
Section: "llm",
|
||||
Label: "Quantization",
|
||||
Description: "Quantization method (e.g. q4_0, q5_1, q8_0)",
|
||||
Component: "select",
|
||||
Options: QuantizationOptions,
|
||||
Advanced: true,
|
||||
Order: 20,
|
||||
},
|
||||
"flash_attention": {
|
||||
Section: "llm",
|
||||
Label: "Flash Attention",
|
||||
Description: "Enable flash attention for faster inference",
|
||||
Component: "input",
|
||||
Advanced: true,
|
||||
Order: 21,
|
||||
},
|
||||
"cache_type_k": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (K)",
|
||||
Description: "Quantization type for key cache (e.g. f16, q8_0, q4_0)",
|
||||
Component: "select",
|
||||
Options: CacheTypeOptions,
|
||||
VRAMImpact: true,
|
||||
Advanced: true,
|
||||
Order: 22,
|
||||
},
|
||||
"cache_type_v": {
|
||||
Section: "llm",
|
||||
Label: "KV Cache Type (V)",
|
||||
Description: "Quantization type for value cache",
|
||||
Component: "select",
|
||||
Options: CacheTypeOptions,
|
||||
VRAMImpact: true,
|
||||
Advanced: true,
|
||||
Order: 23,
|
||||
},
|
||||
|
||||
// --- Parameters ---
|
||||
"parameters.temperature": {
|
||||
Section: "parameters",
|
||||
Label: "Temperature",
|
||||
Description: "Sampling temperature (higher = more creative, lower = more deterministic)",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(2),
|
||||
Step: f64(0.05),
|
||||
Order: 30,
|
||||
},
|
||||
"parameters.top_p": {
|
||||
Section: "parameters",
|
||||
Label: "Top P",
|
||||
Description: "Nucleus sampling threshold",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.01),
|
||||
Order: 31,
|
||||
},
|
||||
"parameters.top_k": {
|
||||
Section: "parameters",
|
||||
Label: "Top K",
|
||||
Description: "Top-K sampling: consider only the K most likely tokens",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 32,
|
||||
},
|
||||
"parameters.max_tokens": {
|
||||
Section: "parameters",
|
||||
Label: "Max Tokens",
|
||||
Description: "Maximum number of tokens to generate (0 = unlimited)",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Order: 33,
|
||||
},
|
||||
"parameters.repeat_penalty": {
|
||||
Section: "parameters",
|
||||
Label: "Repeat Penalty",
|
||||
Description: "Penalize repeated tokens (1.0 = no penalty)",
|
||||
Component: "number",
|
||||
Min: f64(0),
|
||||
Advanced: true,
|
||||
Order: 34,
|
||||
},
|
||||
"parameters.seed": {
|
||||
Section: "parameters",
|
||||
Label: "Seed",
|
||||
Description: "Random seed (-1 = random)",
|
||||
Component: "number",
|
||||
Advanced: true,
|
||||
Order: 35,
|
||||
},
|
||||
|
||||
// --- Templates ---
|
||||
"template.chat": {
|
||||
Section: "templates",
|
||||
Label: "Chat Template",
|
||||
Description: "Go template for chat completion requests",
|
||||
Component: "code-editor",
|
||||
Order: 40,
|
||||
},
|
||||
"template.chat_message": {
|
||||
Section: "templates",
|
||||
Label: "Chat Message Template",
|
||||
Description: "Go template for individual chat messages",
|
||||
Component: "code-editor",
|
||||
Order: 41,
|
||||
},
|
||||
"template.completion": {
|
||||
Section: "templates",
|
||||
Label: "Completion Template",
|
||||
Description: "Go template for completion requests",
|
||||
Component: "code-editor",
|
||||
Order: 42,
|
||||
},
|
||||
"template.use_tokenizer_template": {
|
||||
Section: "templates",
|
||||
Label: "Use Tokenizer Template",
|
||||
Description: "Use the chat template from the model's tokenizer config",
|
||||
Order: 43,
|
||||
},
|
||||
|
||||
// --- Pipeline ---
|
||||
"pipeline.llm": {
|
||||
Section: "pipeline",
|
||||
Label: "LLM Model",
|
||||
Description: "Model to use for LLM inference in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsChat,
|
||||
Order: 60,
|
||||
},
|
||||
"pipeline.tts": {
|
||||
Section: "pipeline",
|
||||
Label: "TTS Model",
|
||||
Description: "Model to use for text-to-speech in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsTTS,
|
||||
Order: 61,
|
||||
},
|
||||
"pipeline.transcription": {
|
||||
Section: "pipeline",
|
||||
Label: "Transcription Model",
|
||||
Description: "Model to use for speech-to-text in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsTranscript,
|
||||
Order: 62,
|
||||
},
|
||||
"pipeline.vad": {
|
||||
Section: "pipeline",
|
||||
Label: "VAD Model",
|
||||
Description: "Model to use for voice activity detection in the pipeline",
|
||||
Component: "model-select",
|
||||
AutocompleteProvider: ProviderModelsVAD,
|
||||
Order: 63,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
Section: "functions",
|
||||
Label: "Parallel Calls",
|
||||
Description: "Allow the LLM to return multiple function calls in one response",
|
||||
Order: 70,
|
||||
},
|
||||
"function.grammar.mixed_mode": {
|
||||
Section: "functions",
|
||||
Label: "Mixed Mode",
|
||||
Description: "Allow the LLM to return both text and function calls",
|
||||
Order: 71,
|
||||
},
|
||||
"function.grammar.disable": {
|
||||
Section: "functions",
|
||||
Label: "Disable Grammar",
|
||||
Description: "Disable grammar-constrained generation for function calls",
|
||||
Advanced: true,
|
||||
Order: 72,
|
||||
},
|
||||
|
||||
// --- Diffusers ---
|
||||
"diffusers.pipeline_type": {
|
||||
Section: "diffusers",
|
||||
Label: "Pipeline Type",
|
||||
Description: "Diffusers pipeline type (e.g. StableDiffusionPipeline)",
|
||||
Component: "select",
|
||||
Options: DiffusersPipelineOptions,
|
||||
Order: 80,
|
||||
},
|
||||
"diffusers.scheduler_type": {
|
||||
Section: "diffusers",
|
||||
Label: "Scheduler Type",
|
||||
Description: "Noise scheduler type",
|
||||
Component: "select",
|
||||
Options: DiffusersSchedulerOptions,
|
||||
Order: 81,
|
||||
},
|
||||
"diffusers.cuda": {
|
||||
Section: "diffusers",
|
||||
Label: "CUDA",
|
||||
Description: "Enable CUDA for diffusers",
|
||||
Order: 82,
|
||||
},
|
||||
}
|
||||
}
|
||||
83
core/config/meta/types.go
Normal file
83
core/config/meta/types.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package meta
|
||||
|
||||
// FieldMeta describes a single configuration field for UI rendering and agent discovery.
|
||||
type FieldMeta struct {
|
||||
Path string `json:"path"` // dot-path: "context_size", "function.grammar.parallel_calls"
|
||||
YAMLKey string `json:"yaml_key"` // leaf yaml key
|
||||
GoType string `json:"go_type"` // "*int", "string", "[]string"
|
||||
UIType string `json:"ui_type"` // "string", "int", "float", "bool", "[]string", "map", "object"
|
||||
Pointer bool `json:"pointer,omitempty"` // true = nil means "not set"
|
||||
Section string `json:"section"` // "general", "llm", "templates", etc.
|
||||
Label string `json:"label"` // human-readable label
|
||||
Description string `json:"description,omitempty"` // help text
|
||||
Component string `json:"component"` // "input", "number", "toggle", "select", "slider", etc.
|
||||
Placeholder string `json:"placeholder,omitempty"`
|
||||
Default any `json:"default,omitempty"`
|
||||
Min *float64 `json:"min,omitempty"`
|
||||
Max *float64 `json:"max,omitempty"`
|
||||
Step *float64 `json:"step,omitempty"`
|
||||
Options []FieldOption `json:"options,omitempty"`
|
||||
|
||||
AutocompleteProvider string `json:"autocomplete_provider,omitempty"` // "backends", "models:chat", etc.
|
||||
VRAMImpact bool `json:"vram_impact,omitempty"`
|
||||
Advanced bool `json:"advanced,omitempty"`
|
||||
Order int `json:"order"`
|
||||
}
|
||||
|
||||
// FieldOption represents a choice in a select/enum field.
|
||||
type FieldOption struct {
|
||||
Value string `json:"value"`
|
||||
Label string `json:"label"`
|
||||
}
|
||||
|
||||
// Section groups related fields in the UI.
|
||||
type Section struct {
|
||||
ID string `json:"id"`
|
||||
Label string `json:"label"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Order int `json:"order"`
|
||||
}
|
||||
|
||||
// ConfigMetadata is the top-level response for the metadata API.
|
||||
type ConfigMetadata struct {
|
||||
Sections []Section `json:"sections"`
|
||||
Fields []FieldMeta `json:"fields"`
|
||||
}
|
||||
|
||||
// FieldMetaOverride holds registry overrides that are merged on top of
|
||||
// the reflection-discovered defaults. Only non-zero fields override.
|
||||
type FieldMetaOverride struct {
|
||||
Section string
|
||||
Label string
|
||||
Description string
|
||||
Component string
|
||||
Placeholder string
|
||||
Default any
|
||||
Min *float64
|
||||
Max *float64
|
||||
Step *float64
|
||||
Options []FieldOption
|
||||
AutocompleteProvider string
|
||||
VRAMImpact bool
|
||||
Advanced bool
|
||||
Order int
|
||||
}
|
||||
|
||||
// DefaultSections defines the well-known config sections in display order.
|
||||
func DefaultSections() []Section {
|
||||
return []Section{
|
||||
{ID: "general", Label: "General", Icon: "settings", Order: 0},
|
||||
{ID: "llm", Label: "LLM", Icon: "cpu", Order: 10},
|
||||
{ID: "parameters", Label: "Parameters", Icon: "sliders", Order: 20},
|
||||
{ID: "templates", Label: "Templates", Icon: "file-text", Order: 30},
|
||||
{ID: "functions", Label: "Functions / Tools", Icon: "tool", Order: 40},
|
||||
{ID: "reasoning", Label: "Reasoning", Icon: "brain", Order: 45},
|
||||
{ID: "diffusers", Label: "Diffusers", Icon: "image", Order: 50},
|
||||
{ID: "tts", Label: "TTS", Icon: "volume-2", Order: 55},
|
||||
{ID: "pipeline", Label: "Pipeline", Icon: "git-merge", Order: 60},
|
||||
{ID: "grpc", Label: "gRPC", Icon: "server", Order: 65},
|
||||
{ID: "agent", Label: "Agent", Icon: "bot", Order: 70},
|
||||
{ID: "mcp", Label: "MCP", Icon: "plug", Order: 75},
|
||||
{ID: "other", Label: "Other", Icon: "more-horizontal", Order: 100},
|
||||
}
|
||||
}
|
||||
@@ -300,14 +300,29 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||
|
||||
backend, ok := backends.Get(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||
// Not found by direct key — try matching by gallery name (metadata.Name)
|
||||
// The UI may send gallery-style names like "localai@llama-cpp" which
|
||||
// don't match the directory-based keys used in the backends map.
|
||||
for _, b := range backends {
|
||||
if b.Metadata != nil && b.Metadata.Name == name && !b.IsMeta {
|
||||
backend = b
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q: %w", name, ErrBackendNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
if backend.IsSystem {
|
||||
return fmt.Errorf("system backend %q cannot be deleted", name)
|
||||
}
|
||||
|
||||
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)
|
||||
// Use the backend's actual Name (directory key) for path resolution,
|
||||
// not the caller-supplied name which may be a gallery-style name.
|
||||
dirName := backend.Name
|
||||
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, dirName)
|
||||
|
||||
// check if the backend dir exists
|
||||
if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
|
||||
@@ -325,7 +340,7 @@ func DeleteBackendFromSystem(systemState *system.SystemState, name string) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if metadata != nil && metadata.Alias == name {
|
||||
if metadata != nil && (metadata.Alias == name || metadata.Alias == dirName) {
|
||||
backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
|
||||
foundBackend = true
|
||||
break
|
||||
|
||||
@@ -52,9 +52,42 @@ var quietPaths = []string{"/api/operations", "/api/resources", "/healthz", "/rea
|
||||
// @license.name MIT
|
||||
// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
|
||||
// @BasePath /
|
||||
// @schemes http https
|
||||
// @securityDefinitions.apikey BearerAuth
|
||||
// @in header
|
||||
// @name Authorization
|
||||
// @tag.name inference
|
||||
// @tag.description Chat completions, text completions, edits, and responses (OpenAI-compatible)
|
||||
// @tag.name embeddings
|
||||
// @tag.description Vector embeddings (OpenAI-compatible)
|
||||
// @tag.name audio
|
||||
// @tag.description Text-to-speech, transcription, voice activity detection, sound generation
|
||||
// @tag.name images
|
||||
// @tag.description Image generation and inpainting
|
||||
// @tag.name video
|
||||
// @tag.description Video generation from prompts
|
||||
// @tag.name detection
|
||||
// @tag.description Object detection in images
|
||||
// @tag.name tokenize
|
||||
// @tag.description Tokenization and token metrics
|
||||
// @tag.name models
|
||||
// @tag.description Model gallery browsing, installation, deletion, and listing
|
||||
// @tag.name backends
|
||||
// @tag.description Backend gallery browsing, installation, deletion, and listing
|
||||
// @tag.name config
|
||||
// @tag.description Model configuration metadata, autocomplete, PATCH updates, VRAM estimation
|
||||
// @tag.name monitoring
|
||||
// @tag.description Prometheus metrics, backend status, system information
|
||||
// @tag.name mcp
|
||||
// @tag.description Model Context Protocol — tool-augmented chat with MCP servers
|
||||
// @tag.name agent-jobs
|
||||
// @tag.description Agent task and job management
|
||||
// @tag.name p2p
|
||||
// @tag.description Peer-to-peer networking nodes and tokens
|
||||
// @tag.name rerank
|
||||
// @tag.description Document reranking
|
||||
// @tag.name instructions
|
||||
// @tag.description API instruction discovery — browse instruction areas and get endpoint guides
|
||||
|
||||
func API(application *application.Application) (*echo.Echo, error) {
|
||||
e := echo.New()
|
||||
@@ -360,7 +393,7 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
|
||||
// Serve React SPA from / with SPA fallback via 404 handler
|
||||
reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist")
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
// MessagesEndpoint is the Anthropic Messages API endpoint
|
||||
// https://docs.anthropic.com/claude/reference/messages_post
|
||||
// @Summary Generate a message response for the given messages and model.
|
||||
// @Tags inference
|
||||
// @Param request body schema.AnthropicRequest true "query params"
|
||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||
// @Router /v1/messages [post]
|
||||
@@ -357,7 +358,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// Send initial content_block_start event
|
||||
contentBlockStart := schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""},
|
||||
}
|
||||
sendAnthropicSSE(c, contentBlockStart)
|
||||
@@ -376,7 +377,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
inToolCall = true
|
||||
@@ -386,7 +387,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
tc := toolCalls[i]
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fmt.Sprintf("toolu_%s_%d", id, i),
|
||||
@@ -395,7 +396,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: tc.Arguments,
|
||||
@@ -403,7 +404,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
}
|
||||
@@ -416,7 +417,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !inToolCall {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: token,
|
||||
@@ -516,7 +517,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// Close the text content block
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
inToolCall = true
|
||||
@@ -528,7 +529,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
ContentBlock: &schema.AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
@@ -537,7 +538,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: fc.Arguments,
|
||||
@@ -545,7 +546,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
})
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: currentBlockIndex,
|
||||
Index: intPtr(currentBlockIndex),
|
||||
})
|
||||
currentBlockIndex++
|
||||
toolCallsEmitted++
|
||||
@@ -557,7 +558,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !inToolCall {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
Index: intPtr(0),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -598,6 +599,8 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
return tools
|
||||
}
|
||||
|
||||
func intPtr(i int) *int { return &i }
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
|
||||
// @Summary Generates audio from the input text.
|
||||
// @Tags audio
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||
// @Summary Generates audio from the input text.
|
||||
// @Tags audio
|
||||
// @Param voice-id path string true "Account ID"
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// JINARerankEndpoint acts like the Jina reranker endpoint (https://jina.ai/reranker/)
|
||||
// @Summary Reranks a list of phrases by relevance to a given text query.
|
||||
// @Tags rerank
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
|
||||
@@ -30,6 +30,15 @@ func getJobService(app *application.Application, c echo.Context) *agentpool.Agen
|
||||
return jobSvc
|
||||
}
|
||||
|
||||
// CreateTaskEndpoint creates a new agent task definition.
|
||||
// @Summary Create a new agent task
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.Task true "Task definition"
|
||||
// @Success 201 {object} map[string]string "id"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks [post]
|
||||
func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var task schema.Task
|
||||
@@ -46,6 +55,17 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskEndpoint updates an existing agent task.
|
||||
// @Summary Update an agent task
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Param request body schema.Task true "Updated task definition"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{id} [put]
|
||||
func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -65,6 +85,14 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteTaskEndpoint deletes an agent task.
|
||||
// @Summary Delete an agent task
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{id} [delete]
|
||||
func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -79,6 +107,13 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ListTasksEndpoint lists all agent tasks for the current user.
|
||||
// @Summary List agent tasks
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param all_users query string false "Set to 'true' for admin cross-user listing"
|
||||
// @Success 200 {object} []schema.Task "tasks"
|
||||
// @Router /api/agent/tasks [get]
|
||||
func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
jobSvc := getJobService(app, c)
|
||||
@@ -121,6 +156,14 @@ func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetTaskEndpoint returns a single agent task by ID.
|
||||
// @Summary Get an agent task
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Success 200 {object} schema.Task "task"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{id} [get]
|
||||
func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -133,6 +176,15 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteJobEndpoint creates and runs a new job for a task.
|
||||
// @Summary Execute an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.JobExecutionRequest true "Job execution request"
|
||||
// @Success 201 {object} schema.JobExecutionResponse "job created"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/execute [post]
|
||||
func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.JobExecutionRequest
|
||||
@@ -168,6 +220,14 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetJobEndpoint returns a single job by ID.
|
||||
// @Summary Get an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} schema.Job "job"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/{id} [get]
|
||||
func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -180,6 +240,16 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ListJobsEndpoint lists jobs, optionally filtered by task or status.
|
||||
// @Summary List agent jobs
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param task_id query string false "Filter by task ID"
|
||||
// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)"
|
||||
// @Param limit query integer false "Max number of jobs to return"
|
||||
// @Param all_users query string false "Set to 'true' for admin cross-user listing"
|
||||
// @Success 200 {object} []schema.Job "jobs"
|
||||
// @Router /api/agent/jobs [get]
|
||||
func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var taskID *string
|
||||
@@ -241,6 +311,15 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// CancelJobEndpoint cancels a running job.
|
||||
// @Summary Cancel an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/{id}/cancel [post]
|
||||
func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -255,6 +334,14 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteJobEndpoint deletes a job by ID.
|
||||
// @Summary Delete an agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} map[string]string "message"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/jobs/{id} [delete]
|
||||
func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -269,6 +356,17 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteTaskByNameEndpoint looks up a task by name and executes it.
|
||||
// @Summary Execute an agent task by name
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param name path string true "Task name"
|
||||
// @Param parameters body object false "Optional template parameters"
|
||||
// @Success 201 {object} schema.JobExecutionResponse "job created"
|
||||
// @Failure 400 {object} map[string]string "error"
|
||||
// @Failure 404 {object} map[string]string "error"
|
||||
// @Router /api/agent/tasks/{name}/execute [post]
|
||||
func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
name := c.Param("name")
|
||||
|
||||
489
core/http/endpoints/localai/api_instructions.go
Normal file
489
core/http/endpoints/localai/api_instructions.go
Normal file
@@ -0,0 +1,489 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/swagger"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
const swaggerDefsPrefix = "#/definitions/"
|
||||
|
||||
// instructionDef is a lightweight instruction definition that maps to swagger tags.
|
||||
type instructionDef struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
Intro string `json:"-"` // brief context not in swagger
|
||||
}
|
||||
|
||||
var instructionDefs = []instructionDef{
|
||||
{
|
||||
Name: "chat-inference",
|
||||
Description: "OpenAI-compatible chat completions, text completions, and embeddings",
|
||||
Tags: []string{"inference", "embeddings"},
|
||||
Intro: "Set \"stream\": true for SSE streaming. Supports tool/function calling when the model config has function templates configured.",
|
||||
},
|
||||
{
|
||||
Name: "audio",
|
||||
Description: "Text-to-speech, voice activity detection, transcription, and sound generation",
|
||||
Tags: []string{"audio"},
|
||||
},
|
||||
{
|
||||
Name: "images",
|
||||
Description: "Image generation and inpainting",
|
||||
Tags: []string{"images"},
|
||||
},
|
||||
{
|
||||
Name: "model-management",
|
||||
Description: "Browse the gallery, install, delete, and manage models and backends",
|
||||
Tags: []string{"models", "backends"},
|
||||
},
|
||||
{
|
||||
Name: "config-management",
|
||||
Description: "Discover, read, and modify model configuration fields with VRAM estimation",
|
||||
Tags: []string{"config"},
|
||||
Intro: "Fields with static options include an \"options\" array in metadata. Fields with dynamic values have an \"autocomplete_provider\" for runtime lookup.",
|
||||
},
|
||||
{
|
||||
Name: "monitoring",
|
||||
Description: "System metrics, backend status, API and backend traces, backend process logs, and system information",
|
||||
Tags: []string{"monitoring"},
|
||||
Intro: "Includes real-time backend log streaming via WebSocket at /ws/backend-logs/:modelId.",
|
||||
},
|
||||
{
|
||||
Name: "mcp",
|
||||
Description: "Model Context Protocol — tool-augmented chat with MCP servers",
|
||||
Tags: []string{"mcp"},
|
||||
Intro: "The model's config must define MCP servers. The endpoint handles tool execution automatically.",
|
||||
},
|
||||
{
|
||||
Name: "agents",
|
||||
Description: "Agent task and job management for CI/automation workflows",
|
||||
Tags: []string{"agent-jobs"},
|
||||
},
|
||||
{
|
||||
Name: "video",
|
||||
Description: "Video generation from text prompts",
|
||||
Tags: []string{"video"},
|
||||
},
|
||||
}
|
||||
|
||||
// swaggerState holds parsed swagger spec data, initialised once.
|
||||
type swaggerState struct {
|
||||
once sync.Once
|
||||
spec map[string]any // full parsed swagger JSON
|
||||
ready bool
|
||||
}
|
||||
|
||||
var swState swaggerState
|
||||
|
||||
func (s *swaggerState) init() {
|
||||
s.once.Do(func() {
|
||||
var spec map[string]any
|
||||
if err := json.Unmarshal(swagger.SwaggerJSON, &spec); err != nil {
|
||||
xlog.Error("failed to parse embedded swagger spec", "err", err)
|
||||
return
|
||||
}
|
||||
s.spec = spec
|
||||
s.ready = true
|
||||
})
|
||||
}
|
||||
|
||||
// filterSwaggerByTags returns a swagger fragment containing only paths whose
|
||||
// operations carry at least one of the given tags, plus the definitions they
|
||||
// reference.
|
||||
func filterSwaggerByTags(spec map[string]any, tags []string) map[string]any {
|
||||
tagSet := make(map[string]bool, len(tags))
|
||||
for _, t := range tags {
|
||||
tagSet[t] = true
|
||||
}
|
||||
|
||||
paths, _ := spec["paths"].(map[string]any)
|
||||
allDefs, _ := spec["definitions"].(map[string]any)
|
||||
|
||||
filteredPaths := make(map[string]any)
|
||||
for path, methods := range paths {
|
||||
methodMap, ok := methods.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
filteredMethods := make(map[string]any)
|
||||
for method, opRaw := range methodMap {
|
||||
op, ok := opRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
opTags, _ := op["tags"].([]any)
|
||||
for _, t := range opTags {
|
||||
if ts, ok := t.(string); ok && tagSet[ts] {
|
||||
filteredMethods[method] = op
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(filteredMethods) > 0 {
|
||||
filteredPaths[path] = filteredMethods
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all $ref definitions used by the filtered paths.
|
||||
neededDefs := make(map[string]bool)
|
||||
collectRefs(filteredPaths, neededDefs)
|
||||
|
||||
// Resolve nested refs from definitions themselves.
|
||||
changed := true
|
||||
for changed {
|
||||
changed = false
|
||||
for name := range neededDefs {
|
||||
if def, ok := allDefs[name]; ok {
|
||||
before := len(neededDefs)
|
||||
collectRefs(def, neededDefs)
|
||||
if len(neededDefs) > before {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filteredDefs := make(map[string]any)
|
||||
for name := range neededDefs {
|
||||
if def, ok := allDefs[name]; ok {
|
||||
filteredDefs[name] = def
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"paths": filteredPaths,
|
||||
}
|
||||
if len(filteredDefs) > 0 {
|
||||
result["definitions"] = filteredDefs
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// collectRefs walks a JSON structure and collects all $ref definition names.
|
||||
func collectRefs(v any, refs map[string]bool) {
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
if ref, ok := val["$ref"].(string); ok {
|
||||
if strings.HasPrefix(ref, swaggerDefsPrefix) {
|
||||
refs[ref[len(swaggerDefsPrefix):]] = true
|
||||
}
|
||||
}
|
||||
for _, child := range val {
|
||||
collectRefs(child, refs)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range val {
|
||||
collectRefs(child, refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swaggerToMarkdown renders a filtered swagger fragment into concise markdown.
|
||||
func swaggerToMarkdown(skillName, intro string, fragment map[string]any) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("# ")
|
||||
b.WriteString(skillName)
|
||||
b.WriteString("\n")
|
||||
if intro != "" {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(intro)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
paths, _ := fragment["paths"].(map[string]any)
|
||||
defs, _ := fragment["definitions"].(map[string]any)
|
||||
|
||||
// Sort paths for stable output.
|
||||
sortedPaths := make([]string, 0, len(paths))
|
||||
for p := range paths {
|
||||
sortedPaths = append(sortedPaths, p)
|
||||
}
|
||||
sort.Strings(sortedPaths)
|
||||
|
||||
for _, path := range sortedPaths {
|
||||
methods, ok := paths[path].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
sortedMethods := sortMethods(methods)
|
||||
for _, method := range sortedMethods {
|
||||
op, ok := methods[method].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
summary, _ := op["summary"].(string)
|
||||
b.WriteString(fmt.Sprintf("\n## %s %s\n", strings.ToUpper(method), path))
|
||||
if summary != "" {
|
||||
b.WriteString(summary)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Parameters
|
||||
params, _ := op["parameters"].([]any)
|
||||
bodyParams, nonBodyParams := splitParams(params)
|
||||
|
||||
if len(nonBodyParams) > 0 {
|
||||
b.WriteString("\n**Parameters:**\n")
|
||||
b.WriteString("| Name | In | Type | Required | Description |\n")
|
||||
b.WriteString("|------|----|------|----------|-------------|\n")
|
||||
for _, p := range nonBodyParams {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := pm["name"].(string)
|
||||
in, _ := pm["in"].(string)
|
||||
typ, _ := pm["type"].(string)
|
||||
req, _ := pm["required"].(bool)
|
||||
desc, _ := pm["description"].(string)
|
||||
b.WriteString(fmt.Sprintf("| %s | %s | %s | %v | %s |\n", name, in, typ, req, desc))
|
||||
}
|
||||
}
|
||||
|
||||
if len(bodyParams) > 0 {
|
||||
for _, p := range bodyParams {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
schema, _ := pm["schema"].(map[string]any)
|
||||
refName := resolveRefName(schema)
|
||||
if refName != "" {
|
||||
b.WriteString(fmt.Sprintf("\n**Request body** (`%s`):\n", refName))
|
||||
renderSchemaFields(&b, refName, defs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Responses
|
||||
responses, _ := op["responses"].(map[string]any)
|
||||
if len(responses) > 0 {
|
||||
sortedCodes := make([]string, 0, len(responses))
|
||||
for code := range responses {
|
||||
sortedCodes = append(sortedCodes, code)
|
||||
}
|
||||
sort.Strings(sortedCodes)
|
||||
for _, code := range sortedCodes {
|
||||
resp, ok := responses[code].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
desc, _ := resp["description"].(string)
|
||||
respSchema, _ := resp["schema"].(map[string]any)
|
||||
refName := resolveRefName(respSchema)
|
||||
if refName != "" {
|
||||
b.WriteString(fmt.Sprintf("\n**Response %s** (`%s`): %s\n", code, refName, desc))
|
||||
renderSchemaFields(&b, refName, defs)
|
||||
} else if desc != "" {
|
||||
b.WriteString(fmt.Sprintf("\n**Response %s**: %s\n", code, desc))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// sortMethods returns HTTP methods in a conventional order.
|
||||
func sortMethods(methods map[string]any) []string {
|
||||
order := map[string]int{"get": 0, "post": 1, "put": 2, "patch": 3, "delete": 4}
|
||||
keys := make([]string, 0, len(methods))
|
||||
for k := range methods {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
oi, oki := order[keys[i]]
|
||||
oj, okj := order[keys[j]]
|
||||
if !oki {
|
||||
oi = 99
|
||||
}
|
||||
if !okj {
|
||||
oj = 99
|
||||
}
|
||||
return oi < oj
|
||||
})
|
||||
return keys
|
||||
}
|
||||
|
||||
// splitParams separates body parameters from non-body parameters.
|
||||
func splitParams(params []any) (body, nonBody []any) {
|
||||
for _, p := range params {
|
||||
pm, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if in, _ := pm["in"].(string); in == "body" {
|
||||
body = append(body, p)
|
||||
} else {
|
||||
nonBody = append(nonBody, p)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// resolveRefName extracts the definition name from a $ref or returns "".
|
||||
func resolveRefName(schema map[string]any) string {
|
||||
if schema == nil {
|
||||
return ""
|
||||
}
|
||||
if ref, ok := schema["$ref"].(string); ok {
|
||||
if strings.HasPrefix(ref, swaggerDefsPrefix) {
|
||||
return ref[len(swaggerDefsPrefix):]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// renderSchemaFields writes a markdown field table for a definition.
|
||||
func renderSchemaFields(b *strings.Builder, defName string, defs map[string]any) {
|
||||
if defs == nil {
|
||||
return
|
||||
}
|
||||
def, ok := defs[defName].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
props, ok := def["properties"].(map[string]any)
|
||||
if !ok || len(props) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Sort fields
|
||||
fields := make([]string, 0, len(props))
|
||||
for f := range props {
|
||||
fields = append(fields, f)
|
||||
}
|
||||
sort.Strings(fields)
|
||||
|
||||
b.WriteString("| Field | Type | Description |\n")
|
||||
b.WriteString("|-------|------|-------------|\n")
|
||||
for _, field := range fields {
|
||||
prop, ok := props[field].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typ := schemaTypeString(prop)
|
||||
desc, _ := prop["description"].(string)
|
||||
b.WriteString(fmt.Sprintf("| %s | %s | %s |\n", field, typ, desc))
|
||||
}
|
||||
}
|
||||
|
||||
// schemaTypeString returns a human-readable type string for a schema property.
|
||||
func schemaTypeString(prop map[string]any) string {
|
||||
if ref := resolveRefName(prop); ref != "" {
|
||||
return ref
|
||||
}
|
||||
typ, _ := prop["type"].(string)
|
||||
if typ == "array" {
|
||||
items, _ := prop["items"].(map[string]any)
|
||||
if items != nil {
|
||||
if ref := resolveRefName(items); ref != "" {
|
||||
return "[]" + ref
|
||||
}
|
||||
it, _ := items["type"].(string)
|
||||
if it != "" {
|
||||
return "[]" + it
|
||||
}
|
||||
}
|
||||
return "[]any"
|
||||
}
|
||||
if typ != "" {
|
||||
return typ
|
||||
}
|
||||
return "object"
|
||||
}
|
||||
|
||||
// APIInstructionResponse is the JSON response for a single instruction (?format=json).
|
||||
type APIInstructionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
SwaggerFragment map[string]any `json:"swagger_fragment,omitempty"`
|
||||
}
|
||||
|
||||
// ListAPIInstructionsEndpoint returns all instructions (compact list without guides).
|
||||
// @Summary List available API instruction areas
|
||||
// @Description Returns a compact list of instruction areas with descriptions and URLs for detailed guides
|
||||
// @Tags instructions
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]any "instructions list with hint"
|
||||
// @Router /api/instructions [get]
|
||||
func ListAPIInstructionsEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
type compactInstruction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
instructions := make([]compactInstruction, len(instructionDefs))
|
||||
for i, s := range instructionDefs {
|
||||
instructions[i] = compactInstruction{
|
||||
Name: s.Name,
|
||||
Description: s.Description,
|
||||
Tags: s.Tags,
|
||||
URL: "/api/instructions/" + s.Name,
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"instructions": instructions,
|
||||
"hint": "Fetch GET {url} for a markdown API guide. Add ?format=json for a raw OpenAPI fragment.",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetAPIInstructionEndpoint returns a single instruction by name.
|
||||
// @Summary Get an instruction's API guide or OpenAPI fragment
|
||||
// @Description Returns a markdown guide (default) or filtered OpenAPI fragment (format=json) for a named instruction
|
||||
// @Tags instructions
|
||||
// @Produce json
|
||||
// @Produce text/markdown
|
||||
// @Param name path string true "Instruction name (e.g. chat-inference, config-management)"
|
||||
// @Param format query string false "Response format: json for OpenAPI fragment, omit for markdown"
|
||||
// @Success 200 {object} APIInstructionResponse "instruction documentation"
|
||||
// @Failure 404 {object} map[string]string "instruction not found"
|
||||
// @Router /api/instructions/{name} [get]
|
||||
func GetAPIInstructionEndpoint() echo.HandlerFunc {
|
||||
byName := make(map[string]*instructionDef, len(instructionDefs))
|
||||
for i := range instructionDefs {
|
||||
byName[instructionDefs[i].Name] = &instructionDefs[i]
|
||||
}
|
||||
|
||||
return func(c echo.Context) error {
|
||||
name := c.Param("name")
|
||||
inst, ok := byName[name]
|
||||
if !ok {
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"error": "instruction not found: " + name})
|
||||
}
|
||||
|
||||
swState.init()
|
||||
if !swState.ready {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "swagger spec not available"})
|
||||
}
|
||||
|
||||
fragment := filterSwaggerByTags(swState.spec, inst.Tags)
|
||||
|
||||
format := c.QueryParam("format")
|
||||
if format == "json" {
|
||||
return c.JSON(http.StatusOK, APIInstructionResponse{
|
||||
Name: inst.Name,
|
||||
Description: inst.Description,
|
||||
Tags: inst.Tags,
|
||||
SwaggerFragment: fragment,
|
||||
})
|
||||
}
|
||||
|
||||
guide := swaggerToMarkdown(inst.Name, inst.Intro, fragment)
|
||||
return c.Blob(http.StatusOK, "text/markdown; charset=utf-8", []byte(guide))
|
||||
}
|
||||
}
|
||||
222
core/http/endpoints/localai/api_instructions_test.go
Normal file
222
core/http/endpoints/localai/api_instructions_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("API Instructions Endpoints", func() {
|
||||
var app *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
app.GET("/api/instructions", ListAPIInstructionsEndpoint())
|
||||
app.GET("/api/instructions/:name", GetAPIInstructionEndpoint())
|
||||
})
|
||||
|
||||
Context("GET /api/instructions", func() {
|
||||
It("should return all instruction definitions", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(resp).To(HaveKey("hint"))
|
||||
Expect(resp).To(HaveKey("instructions"))
|
||||
|
||||
instructions, ok := resp["instructions"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(instructions).To(HaveLen(9))
|
||||
|
||||
// Verify each instruction has required fields and correct URL format
|
||||
for _, s := range instructions {
|
||||
inst, ok := s.(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(inst["name"]).NotTo(BeEmpty())
|
||||
Expect(inst["description"]).NotTo(BeEmpty())
|
||||
Expect(inst["tags"]).NotTo(BeNil())
|
||||
Expect(inst["url"]).To(HavePrefix("/api/instructions/"))
|
||||
Expect(inst["url"]).To(Equal("/api/instructions/" + inst["name"].(string)))
|
||||
}
|
||||
})
|
||||
|
||||
It("should include known instruction names", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
instructions := resp["instructions"].([]any)
|
||||
names := make([]string, len(instructions))
|
||||
for i, s := range instructions {
|
||||
names[i] = s.(map[string]any)["name"].(string)
|
||||
}
|
||||
|
||||
Expect(names).To(ContainElements(
|
||||
"chat-inference",
|
||||
"config-management",
|
||||
"model-management",
|
||||
"monitoring",
|
||||
"agents",
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
Context("GET /api/instructions/:name", func() {
|
||||
It("should return 404 for unknown instruction", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/nonexistent", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["error"]).To(ContainSubstring("instruction not found"))
|
||||
})
|
||||
|
||||
It("should return markdown by default", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Header().Get("Content-Type")).To(ContainSubstring("text/markdown"))
|
||||
|
||||
body, err := io.ReadAll(rec.Body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
md := string(body)
|
||||
|
||||
Expect(md).To(HavePrefix("# chat-inference"))
|
||||
// Should contain at least one endpoint heading
|
||||
Expect(md).To(MatchRegexp(`## (GET|POST|PUT|PATCH|DELETE) /`))
|
||||
})
|
||||
|
||||
It("should include intro text for instructions that have one", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
body, _ := io.ReadAll(rec.Body)
|
||||
// chat-inference has an intro about streaming
|
||||
Expect(string(body)).To(ContainSubstring("stream"))
|
||||
})
|
||||
|
||||
It("should return JSON fragment when format=json", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference?format=json", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["name"]).To(Equal("chat-inference"))
|
||||
Expect(resp["tags"]).To(ContainElements("inference", "embeddings"))
|
||||
|
||||
fragment, ok := resp["swagger_fragment"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(fragment).To(HaveKey("paths"))
|
||||
|
||||
paths, ok := fragment["paths"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(paths).NotTo(BeEmpty())
|
||||
})
|
||||
|
||||
It("should include referenced definitions in JSON fragment", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference?format=json", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
fragment := resp["swagger_fragment"].(map[string]any)
|
||||
Expect(fragment).To(HaveKey("definitions"))
|
||||
|
||||
defs, ok := fragment["definitions"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(defs).NotTo(BeEmpty())
|
||||
})
|
||||
|
||||
It("should only include paths matching the instruction tags in JSON fragment", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/config-management?format=json", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
fragment := resp["swagger_fragment"].(map[string]any)
|
||||
paths := fragment["paths"].(map[string]any)
|
||||
Expect(paths).NotTo(BeEmpty())
|
||||
|
||||
// Every operation in every path should have the "config" tag
|
||||
for _, methods := range paths {
|
||||
methodMap := methods.(map[string]any)
|
||||
for _, opRaw := range methodMap {
|
||||
op := opRaw.(map[string]any)
|
||||
tags, _ := op["tags"].([]any)
|
||||
tagStrs := make([]string, len(tags))
|
||||
for i, t := range tags {
|
||||
tagStrs[i] = t.(string)
|
||||
}
|
||||
Expect(tagStrs).To(ContainElement("config"))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("should produce stable output across calls", func() {
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec1, req1)
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/instructions/chat-inference", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec2, req2)
|
||||
|
||||
body1, _ := io.ReadAll(rec1.Body)
|
||||
body2, _ := io.ReadAll(rec2.Body)
|
||||
Expect(string(body1)).To(Equal(string(body2)))
|
||||
})
|
||||
|
||||
It("should return markdown for every defined instruction", func() {
|
||||
// First get the list
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/instructions", nil)
|
||||
listRec := httptest.NewRecorder()
|
||||
app.ServeHTTP(listRec, listReq)
|
||||
|
||||
var listResp map[string]any
|
||||
Expect(json.Unmarshal(listRec.Body.Bytes(), &listResp)).To(Succeed())
|
||||
|
||||
instructions := listResp["instructions"].([]any)
|
||||
for _, s := range instructions {
|
||||
name := s.(map[string]any)["name"].(string)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/instructions/"+name, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"instruction %q should return 200", name)
|
||||
body, _ := io.ReadAll(rec.Body)
|
||||
Expect(strings.TrimSpace(string(body))).NotTo(BeEmpty(),
|
||||
"instruction %q should return non-empty markdown", name)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -37,6 +37,7 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Tags backends
|
||||
// @Success 200 {object} galleryop.OpStatus "Response"
|
||||
// @Router /backends/jobs/{uuid} [get]
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
@@ -51,6 +52,7 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Tags backends
|
||||
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
||||
// @Router /backends/jobs [get]
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
@@ -61,6 +63,7 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
||||
// @Summary Install backends to LocalAI.
|
||||
// @Tags backends
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/apply [post]
|
||||
@@ -88,6 +91,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
|
||||
|
||||
// DeleteBackendEndpoint lets delete backends from a LocalAI instance
|
||||
// @Summary delete backends from LocalAI.
|
||||
// @Tags backends
|
||||
// @Param name path string true "Backend name"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/delete/{name} [post]
|
||||
@@ -112,6 +116,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
||||
|
||||
// ListBackendsEndpoint list the available backends configured in LocalAI
|
||||
// @Summary List all Backends
|
||||
// @Tags backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
||||
@@ -126,6 +131,7 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint() echo.HandlerFunc {
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Tags backends
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /backends/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
@@ -142,6 +148,7 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFu
|
||||
|
||||
// ListAvailableBackendsEndpoint list the available backends in the galleries configured in LocalAI
|
||||
// @Summary List all available Backends
|
||||
// @Tags backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
|
||||
179
core/http/endpoints/localai/backend_logs.go
Normal file
179
core/http/endpoints/localai/backend_logs.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
var backendLogsUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // no origin header = same-origin or non-browser
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == r.Host
|
||||
},
|
||||
}
|
||||
|
||||
// backendLogsConn wraps a websocket connection with a mutex for safe concurrent writes
|
||||
type backendLogsConn struct {
|
||||
*websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *backendLogsConn) writeJSON(v any) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal error: %w", err)
|
||||
}
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (c *backendLogsConn) writePing() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
return c.Conn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
// ListBackendLogsEndpoint returns model IDs that have log buffers
|
||||
// @Summary List models with backend logs
|
||||
// @Description Returns a sorted list of model IDs that have captured backend process output
|
||||
// @Tags monitoring
|
||||
// @Produce json
|
||||
// @Success 200 {array} string "Model IDs with logs"
|
||||
// @Router /api/backend-logs [get]
|
||||
func ListBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, ml.BackendLogs().ListModels())
|
||||
}
|
||||
}
|
||||
|
||||
// GetBackendLogsEndpoint returns log lines for a specific model
|
||||
// @Summary Get backend logs for a model
|
||||
// @Description Returns all captured log lines (stdout/stderr) for the specified model's backend process
|
||||
// @Tags monitoring
|
||||
// @Produce json
|
||||
// @Param modelId path string true "Model ID"
|
||||
// @Success 200 {array} model.BackendLogLine "Log lines"
|
||||
// @Router /api/backend-logs/{modelId} [get]
|
||||
func GetBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelID := c.Param("modelId")
|
||||
return c.JSON(200, ml.BackendLogs().GetLines(modelID))
|
||||
}
|
||||
}
|
||||
|
||||
// ClearBackendLogsEndpoint clears log lines for a specific model
|
||||
// @Summary Clear backend logs for a model
|
||||
// @Description Removes all captured log lines for the specified model's backend process
|
||||
// @Tags monitoring
|
||||
// @Param modelId path string true "Model ID"
|
||||
// @Success 204 "Logs cleared"
|
||||
// @Router /api/backend-logs/{modelId}/clear [post]
|
||||
func ClearBackendLogsEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ml.BackendLogs().Clear(c.Param("modelId"))
|
||||
return c.NoContent(204)
|
||||
}
|
||||
}
|
||||
|
||||
// BackendLogsWebSocketEndpoint streams backend logs in real-time over WebSocket
|
||||
// @Summary Stream backend logs via WebSocket
|
||||
// @Description Opens a WebSocket connection for real-time backend log streaming. Sends an initial batch of existing lines (type "initial"), then streams new lines as they appear (type "line"). Supports ping/pong keepalive.
|
||||
// @Tags monitoring
|
||||
// @Param modelId path string true "Model ID"
|
||||
// @Router /ws/backend-logs/{modelId} [get]
|
||||
func BackendLogsWebSocketEndpoint(ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelID := c.Param("modelId")
|
||||
|
||||
ws, err := backendLogsUpgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
ws.SetReadLimit(4096)
|
||||
|
||||
// Set up ping/pong for keepalive
|
||||
ws.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
conn := &backendLogsConn{Conn: ws}
|
||||
|
||||
// Send existing lines as initial batch
|
||||
existingLines := ml.BackendLogs().GetLines(modelID)
|
||||
initialMsg := map[string]any{
|
||||
"type": "initial",
|
||||
"lines": existingLines,
|
||||
}
|
||||
if err := conn.writeJSON(initialMsg); err != nil {
|
||||
xlog.Debug("WebSocket backend-logs initial write failed", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe to new lines
|
||||
lineCh, unsubscribe := ml.BackendLogs().Subscribe(modelID)
|
||||
defer unsubscribe()
|
||||
|
||||
// Handle close from client side
|
||||
closeCh := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
_, _, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
close(closeCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Ping ticker for keepalive
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
// Forward new lines to WebSocket
|
||||
for {
|
||||
select {
|
||||
case line, ok := <-lineCh:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
lineMsg := map[string]any{
|
||||
"type": "line",
|
||||
"line": line,
|
||||
}
|
||||
if err := conn.writeJSON(lineMsg); err != nil {
|
||||
xlog.Debug("WebSocket backend-logs write error", "error", err)
|
||||
return nil
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
if err := conn.writePing(); err != nil {
|
||||
return nil
|
||||
}
|
||||
case <-closeCh:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
196
core/http/endpoints/localai/backend_logs_test.go
Normal file
196
core/http/endpoints/localai/backend_logs_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Backend Logs Endpoints", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
tempDir string
|
||||
modelLoader *model.ModelLoader
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "backend-logs-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
modelsPath := filepath.Join(tempDir, "models")
|
||||
Expect(os.MkdirAll(modelsPath, 0750)).To(Succeed())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(modelsPath),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(systemState)
|
||||
|
||||
app = echo.New()
|
||||
app.GET("/api/backend-logs", ListBackendLogsEndpoint(modelLoader))
|
||||
app.GET("/api/backend-logs/:modelId", GetBackendLogsEndpoint(modelLoader))
|
||||
app.POST("/api/backend-logs/:modelId/clear", ClearBackendLogsEndpoint(modelLoader))
|
||||
app.GET("/ws/backend-logs/:modelId", BackendLogsWebSocketEndpoint(modelLoader))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
Context("REST endpoints", func() {
|
||||
It("should return empty list of models with logs", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var models []string
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &models)).To(Succeed())
|
||||
Expect(models).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should list models that have logs", func() {
|
||||
modelLoader.BackendLogs().AppendLine("my-model", "stdout", "hello")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var models []string
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &models)).To(Succeed())
|
||||
Expect(models).To(ContainElement("my-model"))
|
||||
})
|
||||
|
||||
It("should return log lines for a model", func() {
|
||||
modelLoader.BackendLogs().AppendLine("my-model", "stdout", "line one")
|
||||
modelLoader.BackendLogs().AppendLine("my-model", "stderr", "line two")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs/my-model", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var lines []model.BackendLogLine
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &lines)).To(Succeed())
|
||||
Expect(lines).To(HaveLen(2))
|
||||
Expect(lines[0].Text).To(Equal("line one"))
|
||||
Expect(lines[0].Stream).To(Equal("stdout"))
|
||||
Expect(lines[1].Text).To(Equal("line two"))
|
||||
Expect(lines[1].Stream).To(Equal("stderr"))
|
||||
})
|
||||
|
||||
It("should return empty log lines for unknown model", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/backend-logs/unknown-model", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
It("should clear logs for a model", func() {
|
||||
modelLoader.BackendLogs().AppendLine("my-model", "stdout", "hello")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/backend-logs/my-model/clear", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNoContent))
|
||||
|
||||
// Verify logs are cleared
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/backend-logs/my-model", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
var lines []model.BackendLogLine
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &lines)).To(Succeed())
|
||||
Expect(lines).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WebSocket endpoint", func() {
|
||||
It("should send initial lines and stream new lines", func() {
|
||||
// Seed some existing lines before connecting
|
||||
modelLoader.BackendLogs().AppendLine("ws-model", "stdout", "existing line")
|
||||
|
||||
// Start a real HTTP server for WebSocket
|
||||
srv := httptest.NewServer(app)
|
||||
defer srv.Close()
|
||||
|
||||
// Dial the WebSocket
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/backend-logs/ws-model"
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 2 * time.Second}
|
||||
conn, _, err := dialer.Dial(wsURL, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer conn.Close()
|
||||
|
||||
// Read the initial message
|
||||
var initialMsg map[string]any
|
||||
err = conn.ReadJSON(&initialMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(initialMsg["type"]).To(Equal("initial"))
|
||||
|
||||
initialLines, ok := initialMsg["lines"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(initialLines).To(HaveLen(1))
|
||||
|
||||
firstLine := initialLines[0].(map[string]any)
|
||||
Expect(firstLine["text"]).To(Equal("existing line"))
|
||||
|
||||
// Now append a new line and verify it streams through
|
||||
modelLoader.BackendLogs().AppendLine("ws-model", "stderr", "streamed line")
|
||||
|
||||
var lineMsg map[string]any
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
err = conn.ReadJSON(&lineMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(lineMsg["type"]).To(Equal("line"))
|
||||
|
||||
lineData, ok := lineMsg["line"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(lineData["text"]).To(Equal("streamed line"))
|
||||
Expect(lineData["stream"]).To(Equal("stderr"))
|
||||
})
|
||||
|
||||
It("should handle connection close gracefully", func() {
|
||||
srv := httptest.NewServer(app)
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/backend-logs/close-model"
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 2 * time.Second}
|
||||
conn, _, err := dialer.Dial(wsURL, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Read initial message
|
||||
var initialMsg map[string]any
|
||||
err = conn.ReadJSON(&initialMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(initialMsg["type"]).To(Equal("initial"))
|
||||
|
||||
// Close the connection from client side
|
||||
conn.Close()
|
||||
|
||||
// Give the server goroutine time to detect the close
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// No panic or hang — the test passing is the assertion
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
// BackendMonitorEndpoint returns the status of the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Tags monitoring
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Success 200 {object} proto.StatusResponse "Response"
|
||||
// @Router /backend/monitor [get]
|
||||
@@ -29,7 +30,8 @@ func BackendMonitorEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFu
|
||||
}
|
||||
|
||||
// BackendShutdownEndpoint shuts down the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Summary Backend shutdown endpoint
|
||||
// @Tags monitoring
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Router /backend/shutdown [post]
|
||||
func BackendShutdownEndpoint(bm *monitoring.BackendMonitorService) echo.HandlerFunc {
|
||||
|
||||
242
core/http/endpoints/localai/config_meta.go
Normal file
242
core/http/endpoints/localai/config_meta.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigMetadataEndpoint returns field metadata for config fields.
|
||||
// Without ?section, returns just the section index (lightweight).
|
||||
// With ?section=<id>, returns fields for that section only.
|
||||
// With ?section=all, returns all fields grouped by section.
|
||||
// @Summary List model configuration field metadata
|
||||
// @Description Returns config field metadata. Use ?section=<id> to filter by section, or omit for a section index.
|
||||
// @Tags config
|
||||
// @Produce json
|
||||
// @Param section query string false "Section ID to filter (e.g. 'general', 'llm', 'parameters') or 'all' for everything"
|
||||
// @Success 200 {object} map[string]any "Section index or filtered field metadata"
|
||||
// @Router /api/models/config-metadata [get]
|
||||
func ConfigMetadataEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
sectionParam := c.QueryParam("section")
|
||||
|
||||
// No section param: return lightweight section index.
|
||||
if sectionParam == "" {
|
||||
sections := meta.DefaultSections()
|
||||
type sectionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Label string `json:"label"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
index := make([]sectionInfo, len(sections))
|
||||
for i, s := range sections {
|
||||
index[i] = sectionInfo{
|
||||
ID: s.ID,
|
||||
Label: s.Label,
|
||||
URL: "/api/models/config-metadata?section=" + s.ID,
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"hint": "Fetch a section URL to see its fields. Use ?section=all for everything.",
|
||||
"sections": index,
|
||||
})
|
||||
}
|
||||
|
||||
md := meta.BuildConfigMetadata(reflect.TypeOf(config.ModelConfig{}))
|
||||
|
||||
// section=all: return everything.
|
||||
if sectionParam == "all" {
|
||||
return c.JSON(http.StatusOK, md)
|
||||
}
|
||||
|
||||
// Filter to requested section.
|
||||
var filtered []meta.FieldMeta
|
||||
for _, f := range md.Fields {
|
||||
if f.Section == sectionParam {
|
||||
filtered = append(filtered, f)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"error": "unknown section: " + sectionParam})
|
||||
}
|
||||
return c.JSON(http.StatusOK, filtered)
|
||||
}
|
||||
}
|
||||
|
||||
// AutocompleteEndpoint handles dynamic autocomplete lookups for config fields.
|
||||
// Static option lists (quantizations, cache types, diffusers pipelines/schedulers)
|
||||
// are embedded directly in the field metadata Options; only truly dynamic values
|
||||
// that require runtime lookup are served here.
|
||||
// @Summary Get dynamic autocomplete values for a config field
|
||||
// @Description Returns runtime-resolved values for dynamic providers (backends, models)
|
||||
// @Tags config
|
||||
// @Produce json
|
||||
// @Param provider path string true "Provider name (backends, models, models:chat, models:tts, models:transcript, models:vad)"
|
||||
// @Success 200 {object} map[string]any "values array"
|
||||
// @Router /api/models/config-metadata/autocomplete/{provider} [get]
|
||||
func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
provider := c.Param("provider")
|
||||
var values []string
|
||||
|
||||
switch {
|
||||
case provider == meta.ProviderBackends:
|
||||
installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
|
||||
if err == nil {
|
||||
for name := range installedBackends {
|
||||
values = append(values, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(values)
|
||||
|
||||
case provider == meta.ProviderModels:
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
for _, cfg := range modelConfigs {
|
||||
values = append(values, cfg.Name)
|
||||
}
|
||||
modelsWithoutConfig, _ := galleryop.ListModels(cl, ml, config.NoFilterFn, galleryop.LOOSE_ONLY)
|
||||
values = append(values, modelsWithoutConfig...)
|
||||
sort.Strings(values)
|
||||
|
||||
case strings.HasPrefix(provider, "models:"):
|
||||
capability := strings.TrimPrefix(provider, "models:")
|
||||
var filterFn config.ModelConfigFilterFn
|
||||
switch capability {
|
||||
case "chat":
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_CHAT)
|
||||
case "tts":
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TTS)
|
||||
case "vad":
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD)
|
||||
case "transcript":
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)
|
||||
default:
|
||||
filterFn = config.NoFilterFn
|
||||
}
|
||||
filteredConfigs := cl.GetModelConfigsByFilter(filterFn)
|
||||
for _, cfg := range filteredConfigs {
|
||||
values = append(values, cfg.Name)
|
||||
}
|
||||
sort.Strings(values)
|
||||
|
||||
default:
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"error": "unknown provider: " + provider})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{"values": values})
|
||||
}
|
||||
}
|
||||
|
||||
// PatchConfigEndpoint handles PATCH requests to partially update a model config
|
||||
// using nested JSON merge.
|
||||
// @Summary Partially update a model configuration
|
||||
// @Description Deep-merges the JSON patch body into the existing model config
|
||||
// @Tags config
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} map[string]any "success message"
|
||||
// @Router /api/models/config-json/{name} [patch]
|
||||
func PatchConfigEndpoint(cl *config.ModelConfigLoader, _ *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||
modelName = decoded
|
||||
}
|
||||
if modelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "model name is required"})
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"error": "model configuration not found"})
|
||||
}
|
||||
|
||||
patchBody, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil || len(patchBody) == 0 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "request body is empty or unreadable"})
|
||||
}
|
||||
|
||||
var patchMap map[string]any
|
||||
if err := json.Unmarshal(patchBody, &patchMap); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid JSON: " + err.Error()})
|
||||
}
|
||||
|
||||
existingJSON, err := json.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal existing config"})
|
||||
}
|
||||
|
||||
var existingMap map[string]any
|
||||
if err := json.Unmarshal(existingJSON, &existingMap); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to parse existing config"})
|
||||
}
|
||||
|
||||
if err := mergo.Merge(&existingMap, patchMap, mergo.WithOverride); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to merge configs: " + err.Error()})
|
||||
}
|
||||
|
||||
mergedJSON, err := json.Marshal(existingMap)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal merged config"})
|
||||
}
|
||||
|
||||
var updatedConfig config.ModelConfig
|
||||
if err := json.Unmarshal(mergedJSON, &updatedConfig); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "merged config is invalid: " + err.Error()})
|
||||
}
|
||||
|
||||
if valid, err := updatedConfig.Validate(); !valid {
|
||||
errMsg := "validation failed"
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": errMsg})
|
||||
}
|
||||
|
||||
configPath := modelConfig.GetModelConfigFile()
|
||||
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
return c.JSON(http.StatusForbidden, map[string]any{"error": "config path not trusted: " + err.Error()})
|
||||
}
|
||||
|
||||
yamlData, err := yaml.Marshal(updatedConfig)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to marshal YAML"})
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to write config file"})
|
||||
}
|
||||
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": "failed to reload configs: " + err.Error()})
|
||||
}
|
||||
|
||||
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
xlog.Warn("Failed to preload after PATCH", "error", err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Model '%s' updated successfully", modelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
243
core/http/endpoints/localai/config_meta_test.go
Normal file
243
core/http/endpoints/localai/config_meta_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Config Metadata Endpoints", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
tempDir string
|
||||
configLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "config-meta-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
appConfig = config.NewApplicationConfig(
|
||||
config.WithSystemState(systemState),
|
||||
)
|
||||
configLoader = config.NewModelConfigLoader(tempDir)
|
||||
modelLoader = model.NewModelLoader(systemState)
|
||||
|
||||
app = echo.New()
|
||||
app.GET("/api/models/config-metadata", ConfigMetadataEndpoint())
|
||||
app.GET("/api/models/config-metadata/autocomplete/:provider", AutocompleteEndpoint(configLoader, modelLoader, appConfig))
|
||||
app.PATCH("/api/models/config-json/:name", PatchConfigEndpoint(configLoader, modelLoader, appConfig))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
Context("GET /api/models/config-metadata", func() {
|
||||
It("should return section index when no section param", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp).To(HaveKey("hint"))
|
||||
Expect(resp).To(HaveKey("sections"))
|
||||
|
||||
sections, ok := resp["sections"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(sections).NotTo(BeEmpty())
|
||||
|
||||
// Verify known section IDs are present
|
||||
ids := make([]string, len(sections))
|
||||
for i, s := range sections {
|
||||
sec := s.(map[string]any)
|
||||
Expect(sec).To(HaveKey("id"))
|
||||
Expect(sec).To(HaveKey("label"))
|
||||
Expect(sec).To(HaveKey("url"))
|
||||
ids[i] = sec["id"].(string)
|
||||
}
|
||||
Expect(ids).To(ContainElements("general", "parameters"))
|
||||
})
|
||||
|
||||
It("should return all fields when section=all", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=all", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp).To(HaveKey("fields"))
|
||||
|
||||
fields, ok := resp["fields"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(len(fields)).To(BeNumerically(">=", 80))
|
||||
})
|
||||
|
||||
It("should filter by section", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=general", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var fields []map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &fields)).To(Succeed())
|
||||
Expect(fields).NotTo(BeEmpty())
|
||||
|
||||
for _, f := range fields {
|
||||
Expect(f["section"]).To(Equal("general"))
|
||||
}
|
||||
})
|
||||
|
||||
It("should return 404 for unknown section", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata?section=nonexistent", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Context("GET /api/models/config-metadata/autocomplete/:provider", func() {
|
||||
It("should return values for backends provider", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/backends", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp).To(HaveKey("values"))
|
||||
})
|
||||
|
||||
It("should return model names for models provider", func() {
|
||||
// Seed a model config
|
||||
seedConfig := `name: test-model
|
||||
backend: llama-cpp
|
||||
`
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/models", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
|
||||
values, ok := resp["values"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(values).To(ContainElement("test-model"))
|
||||
})
|
||||
|
||||
It("should return 404 for unknown provider", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models/config-metadata/autocomplete/unknown", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Context("PATCH /api/models/config-json/:name", func() {
|
||||
It("should return 404 for nonexistent model", func() {
|
||||
body := bytes.NewBufferString(`{"backend": "bar"}`)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/nonexistent", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
})
|
||||
|
||||
It("should return 400 for empty body", func() {
|
||||
// Seed a model config
|
||||
seedConfig := `name: test-model
|
||||
backend: llama-cpp
|
||||
`
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
It("should return 400 for invalid JSON", func() {
|
||||
seedConfig := `name: test-model
|
||||
backend: llama-cpp
|
||||
`
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
body := bytes.NewBufferString(`not json`)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
It("should merge a field update and persist to disk", func() {
|
||||
seedConfig := `name: test-model
|
||||
backend: llama-cpp
|
||||
`
|
||||
configPath := filepath.Join(tempDir, "test-model.yaml")
|
||||
Expect(os.WriteFile(configPath, []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
body := bytes.NewBufferString(`{"backend": "vllm"}`)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/models/config-json/test-model", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["success"]).To(BeTrue())
|
||||
|
||||
// Verify the reloaded config has the updated value
|
||||
updatedConfig, exists := configLoader.GetModelConfig("test-model")
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(updatedConfig.Backend).To(Equal("vllm"))
|
||||
|
||||
// Verify the file on disk was updated
|
||||
data, err := os.ReadFile(configPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(data)).To(ContainSubstring("vllm"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
|
||||
// @Summary Detects objects in the input image.
|
||||
// @Tags detection
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
|
||||
@@ -40,6 +40,7 @@ func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGaller
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Tags models
|
||||
// @Success 200 {object} galleryop.OpStatus "Response"
|
||||
// @Router /models/jobs/{uuid} [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
@@ -54,6 +55,7 @@ func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Tags models
|
||||
// @Success 200 {object} map[string]galleryop.OpStatus "Response"
|
||||
// @Router /models/jobs [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
@@ -64,6 +66,7 @@ func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc
|
||||
|
||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||
// @Summary Install models to LocalAI.
|
||||
// @Tags models
|
||||
// @Param request body GalleryModel true "query params"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/apply [post]
|
||||
@@ -93,6 +96,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler
|
||||
|
||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||
// @Summary delete models to LocalAI.
|
||||
// @Tags models
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/delete/{name} [post]
|
||||
@@ -118,7 +122,8 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
|
||||
|
||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Tags models
|
||||
// @Success 200 {object} []gallery.Metadata "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@@ -149,6 +154,7 @@ func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Tags models
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /models/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
|
||||
//
|
||||
// @Summary Get TokenMetrics for Active Slot.
|
||||
// @Tags tokenize
|
||||
// @Accept json
|
||||
// @Produce audio/x-wav
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
|
||||
@@ -53,6 +53,7 @@ type MCPErrorEvent struct {
|
||||
// which handles MCP tool injection and server-side execution.
|
||||
// Both streaming and non-streaming modes use standard OpenAI response format.
|
||||
// @Summary MCP chat completions with automatic tool execution
|
||||
// @Tags mcp
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/mcp/chat/completions [post]
|
||||
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
|
||||
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
||||
// @Summary Prometheus metrics endpoint
|
||||
// @Param request body config.Gallery true "Gallery details"
|
||||
// @Tags monitoring
|
||||
// @Produce text/plain
|
||||
// @Success 200 {string} string "Prometheus metrics"
|
||||
// @Router /metrics [get]
|
||||
func LocalAIMetricsEndpoint() echo.HandlerFunc {
|
||||
return echo.WrapHandler(promhttp.Handler())
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
// ShowP2PNodes returns the P2P Nodes
|
||||
// @Summary Returns available P2P nodes
|
||||
// @Tags p2p
|
||||
// @Success 200 {object} []schema.P2PNodesResponse "Response"
|
||||
// @Router /api/p2p [get]
|
||||
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
@@ -24,6 +25,7 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
|
||||
// ShowP2PToken returns the P2P token
|
||||
// @Summary Show the P2P token
|
||||
// @Tags p2p
|
||||
// @Success 200 {string} string "Response"
|
||||
// @Router /api/p2p/token [get]
|
||||
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
// SystemInformations returns the system informations
|
||||
// @Summary Show the LocalAI instance information
|
||||
// @Tags monitoring
|
||||
// @Success 200 {object} schema.SystemInformationResponse "Response"
|
||||
// @Router /system [get]
|
||||
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
// TokenizeEndpoint exposes a REST API to tokenize the content
|
||||
// @Summary Tokenize the input.
|
||||
// @Tags tokenize
|
||||
// @Param request body schema.TokenizeRequest true "Request"
|
||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
|
||||
59
core/http/endpoints/localai/traces.go
Normal file
59
core/http/endpoints/localai/traces.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
)
|
||||
|
||||
// GetAPITracesEndpoint returns all API request/response traces
|
||||
// @Summary List API request/response traces
|
||||
// @Description Returns captured API exchange traces (request/response pairs) in reverse chronological order
|
||||
// @Tags monitoring
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]any "Traced API exchanges"
|
||||
// @Router /api/traces [get]
|
||||
func GetAPITracesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, middleware.GetTraces())
|
||||
}
|
||||
}
|
||||
|
||||
// ClearAPITracesEndpoint clears all API traces
|
||||
// @Summary Clear API traces
|
||||
// @Description Removes all captured API request/response traces from the buffer
|
||||
// @Tags monitoring
|
||||
// @Success 204 "Traces cleared"
|
||||
// @Router /api/traces/clear [post]
|
||||
func ClearAPITracesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
middleware.ClearTraces()
|
||||
return c.NoContent(204)
|
||||
}
|
||||
}
|
||||
|
||||
// GetBackendTracesEndpoint returns all backend operation traces
|
||||
// @Summary List backend operation traces
|
||||
// @Description Returns captured backend traces (LLM calls, embeddings, TTS, etc.) in reverse chronological order
|
||||
// @Tags monitoring
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]any "Backend operation traces"
|
||||
// @Router /api/backend-traces [get]
|
||||
func GetBackendTracesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, trace.GetBackendTraces())
|
||||
}
|
||||
}
|
||||
|
||||
// ClearBackendTracesEndpoint clears all backend traces
|
||||
// @Summary Clear backend traces
|
||||
// @Description Removes all captured backend operation traces from the buffer
|
||||
// @Tags monitoring
|
||||
// @Success 204 "Traces cleared"
|
||||
// @Router /api/backend-traces/clear [post]
|
||||
func ClearBackendTracesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
trace.ClearBackendTraces()
|
||||
return c.NoContent(204)
|
||||
}
|
||||
}
|
||||
55
core/http/endpoints/localai/traces_test.go
Normal file
55
core/http/endpoints/localai/traces_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Traces Endpoints", func() {
|
||||
var app *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
app.GET("/api/traces", GetAPITracesEndpoint())
|
||||
app.POST("/api/traces/clear", ClearAPITracesEndpoint())
|
||||
app.GET("/api/backend-traces", GetBackendTracesEndpoint())
|
||||
app.POST("/api/backend-traces/clear", ClearBackendTracesEndpoint())
|
||||
})
|
||||
|
||||
It("should return API traces", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/traces", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
It("should clear API traces", func() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/traces/clear", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNoContent))
|
||||
})
|
||||
|
||||
It("should return backend traces", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/backend-traces", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
It("should clear backend traces", func() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/backend-traces/clear", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNoContent))
|
||||
})
|
||||
})
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||
//
|
||||
// @Summary Generates audio from the input text.
|
||||
// @Tags audio
|
||||
// @Accept json
|
||||
// @Produce audio/x-wav
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
// VADEndpoint is Voice-Activation-Detection endpoint
|
||||
// @Summary Detect voice fragments in an audio stream
|
||||
// @Tags audio
|
||||
// @Accept json
|
||||
// @Param request body schema.VADRequest true "query params"
|
||||
// @Success 200 {object} proto.VADResponse "Response"
|
||||
|
||||
@@ -62,6 +62,7 @@ func downloadFile(url string) (string, error) {
|
||||
*/
|
||||
// VideoEndpoint
|
||||
// @Summary Creates a video given a prompt.
|
||||
// @Tags video
|
||||
// @Param request body schema.VideoRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /video [post]
|
||||
|
||||
145
core/http/endpoints/localai/vram.go
Normal file
145
core/http/endpoints/localai/vram.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
)
|
||||
|
||||
type vramEstimateRequest struct {
|
||||
Model string `json:"model"` // model name (must be installed)
|
||||
ContextSize uint32 `json:"context_size,omitempty"` // context length to estimate for (default 8192)
|
||||
GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all)
|
||||
KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16)
|
||||
}
|
||||
|
||||
type vramEstimateResponse struct {
|
||||
vram.EstimateResult
|
||||
ContextNote string `json:"context_note,omitempty"` // note when context_size was defaulted
|
||||
ModelMaxContext uint64 `json:"model_max_context,omitempty"` // model's trained maximum context length
|
||||
}
|
||||
|
||||
// resolveModelURI converts a relative model path to a file:// URI so the
|
||||
// size resolver can stat it on disk. URIs that already have a scheme are
|
||||
// returned unchanged.
|
||||
func resolveModelURI(uri, modelsPath string) string {
|
||||
if strings.Contains(uri, "://") {
|
||||
return uri
|
||||
}
|
||||
return "file://" + filepath.Join(modelsPath, uri)
|
||||
}
|
||||
|
||||
// addWeightFile appends a resolved weight file to files and tracks the first GGUF.
|
||||
func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *string, seen map[string]bool) {
|
||||
if !vram.IsWeightFile(uri) {
|
||||
return
|
||||
}
|
||||
resolved := resolveModelURI(uri, modelsPath)
|
||||
if seen[resolved] {
|
||||
return
|
||||
}
|
||||
seen[resolved] = true
|
||||
*files = append(*files, vram.FileInput{URI: resolved, Size: 0})
|
||||
if *firstGGUF == "" && vram.IsGGUF(uri) {
|
||||
*firstGGUF = resolved
|
||||
}
|
||||
}
|
||||
|
||||
// VRAMEstimateEndpoint returns a handler that estimates VRAM usage for an
|
||||
// installed model configuration. For uninstalled models (gallery URLs), use
|
||||
// the gallery-level estimates in /api/models instead.
|
||||
// @Summary Estimate VRAM usage for a model
|
||||
// @Description Estimates VRAM based on model weight files, context size, and GPU layers
|
||||
// @Tags config
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body vramEstimateRequest true "VRAM estimation parameters"
|
||||
// @Success 200 {object} vramEstimateResponse "VRAM estimate"
|
||||
// @Router /api/models/vram-estimate [post]
|
||||
func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req vramEstimateRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid request body"})
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{"error": "model name is required"})
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(req.Model)
|
||||
if !exists {
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"error": "model configuration not found"})
|
||||
}
|
||||
|
||||
modelsPath := appConfig.SystemState.Model.ModelsPath
|
||||
|
||||
var files []vram.FileInput
|
||||
var firstGGUF string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, f := range modelConfig.DownloadFiles {
|
||||
addWeightFile(string(f.URI), modelsPath, &files, &firstGGUF, seen)
|
||||
}
|
||||
if modelConfig.Model != "" {
|
||||
addWeightFile(modelConfig.Model, modelsPath, &files, &firstGGUF, seen)
|
||||
}
|
||||
if modelConfig.MMProj != "" {
|
||||
addWeightFile(modelConfig.MMProj, modelsPath, &files, &firstGGUF, seen)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"message": "no weight files found for estimation",
|
||||
})
|
||||
}
|
||||
|
||||
contextDefaulted := false
|
||||
opts := vram.EstimateOptions{
|
||||
ContextLength: req.ContextSize,
|
||||
GPULayers: req.GPULayers,
|
||||
KVQuantBits: req.KVQuantBits,
|
||||
}
|
||||
if opts.ContextLength == 0 {
|
||||
if modelConfig.ContextSize != nil {
|
||||
opts.ContextLength = uint32(*modelConfig.ContextSize)
|
||||
} else {
|
||||
opts.ContextLength = 8192
|
||||
contextDefaulted = true
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := vram.Estimate(ctx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader())
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
resp := vramEstimateResponse{EstimateResult: result}
|
||||
|
||||
// When context was defaulted to 8192, read the GGUF metadata to report
|
||||
// the model's trained maximum context length so callers know the estimate
|
||||
// may be conservative.
|
||||
if contextDefaulted && firstGGUF != "" {
|
||||
ggufMeta, err := vram.DefaultCachedGGUFReader().ReadMetadata(ctx, firstGGUF)
|
||||
if err == nil && ggufMeta != nil && ggufMeta.MaximumContextLength > 0 {
|
||||
resp.ModelMaxContext = ggufMeta.MaximumContextLength
|
||||
resp.ContextNote = fmt.Sprintf(
|
||||
"Estimate used default context_size=8192. The model's trained maximum context is %d; VRAM usage will be higher at larger context sizes.",
|
||||
ggufMeta.MaximumContextLength,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
133
core/http/endpoints/localai/vram_test.go
Normal file
133
core/http/endpoints/localai/vram_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("VRAM Estimate Endpoint", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
tempDir string
|
||||
configLoader *config.ModelConfigLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "vram-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
appConfig = config.NewApplicationConfig(
|
||||
config.WithSystemState(systemState),
|
||||
)
|
||||
configLoader = config.NewModelConfigLoader(tempDir)
|
||||
|
||||
app = echo.New()
|
||||
app.POST("/api/models/vram-estimate", VRAMEstimateEndpoint(configLoader, appConfig))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
It("should return 400 for invalid request body", func() {
|
||||
body := bytes.NewBufferString(`not json`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/vram-estimate", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
It("should return 400 when model name is missing", func() {
|
||||
body := bytes.NewBufferString(`{"context_size": 4096}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/vram-estimate", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["error"]).To(ContainSubstring("model name is required"))
|
||||
})
|
||||
|
||||
It("should return 404 when model config does not exist", func() {
|
||||
body := bytes.NewBufferString(`{"model": "nonexistent"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/vram-estimate", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
})
|
||||
|
||||
It("should return no-weight-files message when model has no weight files", func() {
|
||||
seedConfig := "name: test-model\nbackend: llama-cpp\n"
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
body := bytes.NewBufferString(`{"model": "test-model"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/vram-estimate", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["message"]).To(ContainSubstring("no weight files"))
|
||||
})
|
||||
|
||||
It("should return an estimate for a model with a weight file on disk", func() {
|
||||
// Create a dummy GGUF file (not valid GGUF, but the size resolver
|
||||
// will stat it and Estimate falls back to size-only estimation).
|
||||
dummyData := make([]byte, 1024*1024) // 1 MiB
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, "model.gguf"), dummyData, 0644)).To(Succeed())
|
||||
|
||||
seedConfig := "name: test-model\nbackend: llama-cpp\nparameters:\n model: model.gguf\n"
|
||||
Expect(os.WriteFile(filepath.Join(tempDir, "test-model.yaml"), []byte(seedConfig), 0644)).To(Succeed())
|
||||
Expect(configLoader.LoadModelConfigsFromPath(tempDir)).To(Succeed())
|
||||
|
||||
body := bytes.NewBufferString(`{"model": "test-model", "context_size": 4096}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models/vram-estimate", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
// The response should have non-zero size and vram estimates.
|
||||
// JSON numbers unmarshal as float64.
|
||||
sizeBytes, ok := resp["sizeBytes"].(float64)
|
||||
Expect(ok).To(BeTrue(), "sizeBytes should be a number, got: %v (response: %s)", resp["sizeBytes"], rec.Body.String())
|
||||
Expect(sizeBytes).To(BeNumerically(">", 0))
|
||||
vramBytes, ok := resp["vramBytes"].(float64)
|
||||
Expect(ok).To(BeTrue(), "vramBytes should be a number")
|
||||
Expect(vramBytes).To(BeNumerically(">", 0))
|
||||
Expect(resp["sizeDisplay"]).NotTo(BeEmpty())
|
||||
Expect(resp["vramDisplay"]).NotTo(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -55,6 +55,7 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [
|
||||
|
||||
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
|
||||
// @Summary Generate a chat completions for a given prompt and model.
|
||||
// @Tags inference
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
@@ -81,7 +82,23 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(s)
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
// Always keep the Go-side extractor in sync with raw tokens so it
|
||||
// can serve as fallback for backends without an autoparser (e.g. vLLM).
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
// When C++ autoparser chat deltas are available, prefer them — they
|
||||
// handle model-specific formats (Gemma 4, etc.) without Go-side tags.
|
||||
// Otherwise fall back to Go-side extraction.
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
@@ -130,10 +147,35 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(s)
|
||||
|
||||
// Track whether ChatDeltas from the C++ autoparser contain
|
||||
// tool calls or content, so the retry decision can account for them.
|
||||
for _, d := range usage.ChatDeltas {
|
||||
if len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaToolCalls = true
|
||||
}
|
||||
if d.Content != "" {
|
||||
hasChatDeltaContent = true
|
||||
}
|
||||
}
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
|
||||
if usage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := usage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
@@ -280,15 +322,22 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry)
|
||||
hasToolCalls := lastEmittedCount > 0
|
||||
if cleaned == "" && !hasToolCalls {
|
||||
// but we need to know here whether to retry).
|
||||
// Also check ChatDelta flags — when the C++ autoparser is active,
|
||||
// tool calls and content are delivered via ChatDeltas while the
|
||||
// raw message is cleared. Without this check, we'd retry
|
||||
// unnecessarily, losing valid results and concatenating output.
|
||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||
hasContent := cleaned != "" || hasChatDeltaContent
|
||||
if !hasContent && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
hasChatDeltaToolCalls = false
|
||||
hasChatDeltaContent = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -963,6 +1012,29 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
|
||||
// For non-tool requests: prefer C++ autoparser chat deltas over
|
||||
// Go-side tag extraction (which can mangle output when thinkingStartToken
|
||||
// differs from the model's actual reasoning tags, e.g. Gemma 4).
|
||||
if !shouldUseFn && len(chatDeltas) > 0 {
|
||||
deltaContent := functions.ContentFromChatDeltas(chatDeltas)
|
||||
deltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
if deltaContent != "" || deltaReasoning != "" {
|
||||
xlog.Debug("[ChatDeltas] non-SSE no-tools: overriding result with C++ autoparser deltas",
|
||||
"content_len", len(deltaContent), "reasoning_len", len(deltaReasoning))
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &deltaContent}
|
||||
if deltaReasoning != "" {
|
||||
message.Reasoning = &deltaReasoning
|
||||
}
|
||||
newChoice := schema.Choice{FinishReason: &stopReason, Index: 0, Message: message}
|
||||
// Preserve logprobs from the original result
|
||||
if len(result) > 0 && result[0].Logprobs != nil {
|
||||
newChoice.Logprobs = result[0].Logprobs
|
||||
}
|
||||
result = []schema.Choice{newChoice}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available
|
||||
if shouldUseFn {
|
||||
var funcResults []functions.FuncCallResults
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
|
||||
// @Summary Generate completions for a given prompt and model.
|
||||
// @Tags inference
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
// EditEndpoint is the OpenAI edit API endpoint
|
||||
// @Summary OpenAI edit endpoint
|
||||
// @Tags inference
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
|
||||
@@ -42,6 +42,7 @@ func embeddingItem(embeddings []float32, index int, encodingFormat string) schem
|
||||
|
||||
// EmbeddingsEndpoint is the OpenAI Embeddings API endpoint https://platform.openai.com/docs/api-reference/embeddings
|
||||
// @Summary Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
|
||||
// @Tags embeddings
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/embeddings [post]
|
||||
|
||||
@@ -68,6 +68,7 @@ func downloadFile(url string) (string, error) {
|
||||
*/
|
||||
// ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create
|
||||
// @Summary Creates an image given a prompt.
|
||||
// @Tags images
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/images/generations [post]
|
||||
|
||||
@@ -113,11 +113,23 @@ func ComputeChoices(
|
||||
}
|
||||
prediction = p
|
||||
|
||||
// Built-in: retry on truly empty response (no tokens at all)
|
||||
// Built-in: retry on truly empty response (no tokens at all).
|
||||
// However, when the C++ autoparser is active, it clears the raw
|
||||
// message and delivers content via ChatDeltas instead. Do NOT
|
||||
// retry if ChatDeltas contain tool calls or content.
|
||||
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
|
||||
xlog.Warn("Backend returned empty response, retrying",
|
||||
"attempt", attempt+1, "maxRetries", maxRetries)
|
||||
continue
|
||||
hasChatDeltaData := false
|
||||
for _, d := range prediction.ChatDeltas {
|
||||
if d.Content != "" || len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasChatDeltaData {
|
||||
xlog.Warn("Backend returned empty response, retrying",
|
||||
"attempt", attempt+1, "maxRetries", maxRetries)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tokenUsage.Prompt = prediction.Usage.Prompt
|
||||
@@ -130,8 +142,21 @@ func ComputeChoices(
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.)
|
||||
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.).
|
||||
// When the C++ autoparser is active, it clears the raw response
|
||||
// and delivers data via ChatDeltas. If the response is empty but
|
||||
// ChatDeltas contain actionable data, skip the caller retry —
|
||||
// the autoparser already parsed the response successfully.
|
||||
skipCallerRetry := false
|
||||
if strings.TrimSpace(prediction.Response) == "" && len(prediction.ChatDeltas) > 0 {
|
||||
for _, d := range prediction.ChatDeltas {
|
||||
if d.Content != "" || len(d.ToolCalls) > 0 {
|
||||
skipCallerRetry = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if shouldRetryFn != nil && !skipCallerRetry && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller has already reset its state inside shouldRetry
|
||||
result = result[:0]
|
||||
allChatDeltas = nil
|
||||
|
||||
@@ -398,5 +398,124 @@ var _ = Describe("ComputeChoices", func() {
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(streamedTokens).To(Equal([]string{"Hello", " world"}))
|
||||
})
|
||||
|
||||
It("should pass chat deltas through TokenUsage during streaming", func() {
|
||||
var receivedDeltas [][]*pb.ChatDelta
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
if tokenCallback != nil {
|
||||
// Simulate C++ autoparser sending reasoning in chat deltas
|
||||
tokenCallback("<|channel>thought\nthinking\n<channel|>", backend.TokenUsage{
|
||||
Prompt: 5,
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking"},
|
||||
},
|
||||
})
|
||||
tokenCallback("Hello!", backend.TokenUsage{
|
||||
Prompt: 5, Completion: 3,
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: "Hello!"},
|
||||
},
|
||||
})
|
||||
}
|
||||
return backend.LLMResponse{
|
||||
Response: "<|channel>thought\nthinking\n<channel|>Hello!",
|
||||
Usage: backend.TokenUsage{Prompt: 5, Completion: 3},
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "thinking"},
|
||||
{Content: "Hello!"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
|
||||
choices, _, deltas, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
func(s string, usage backend.TokenUsage) bool {
|
||||
// Capture chat deltas received per-chunk
|
||||
if len(usage.ChatDeltas) > 0 {
|
||||
receivedDeltas = append(receivedDeltas, usage.ChatDeltas)
|
||||
}
|
||||
return true
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
|
||||
// Verify per-chunk deltas were received during streaming
|
||||
Expect(receivedDeltas).To(HaveLen(2))
|
||||
Expect(receivedDeltas[0][0].ReasoningContent).To(Equal("thinking"))
|
||||
Expect(receivedDeltas[1][0].Content).To(Equal("Hello!"))
|
||||
|
||||
// Verify final accumulated deltas are also returned
|
||||
Expect(deltas).To(HaveLen(2))
|
||||
Expect(deltas[0].ReasoningContent).To(Equal("thinking"))
|
||||
Expect(deltas[1].Content).To(Equal("Hello!"))
|
||||
})
|
||||
|
||||
It("should prefer chat deltas over raw text when HasChatDeltaContent is true", func() {
|
||||
// Verify that the callback can distinguish between
|
||||
// chunks with and without chat deltas
|
||||
var withDeltas, withoutDeltas int
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
if tokenCallback != nil {
|
||||
// Chunk with chat deltas (C++ autoparser active)
|
||||
tokenCallback("raw-text", backend.TokenUsage{
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "parsed-content"}},
|
||||
})
|
||||
// Chunk without chat deltas (fallback)
|
||||
tokenCallback("fallback-text", backend.TokenUsage{})
|
||||
}
|
||||
return backend.LLMResponse{Response: "raw-textfallback-text"}, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
|
||||
_, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
func(s string, usage backend.TokenUsage) bool {
|
||||
if usage.HasChatDeltaContent() {
|
||||
withDeltas++
|
||||
r, c := usage.ChatDeltaReasoningAndContent()
|
||||
Expect(c).To(Equal("parsed-content"))
|
||||
Expect(r).To(BeEmpty())
|
||||
} else {
|
||||
withoutDeltas++
|
||||
}
|
||||
return true
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(withDeltas).To(Equal(1))
|
||||
Expect(withoutDeltas).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
// ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Tags models
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, db ...*gorm.DB) echo.HandlerFunc {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
// TranscriptEndpoint is the OpenAI Whisper API endpoint https://platform.openai.com/docs/api-reference/audio/create
|
||||
// @Summary Transcribes audio into the input language.
|
||||
// @Tags audio
|
||||
// @accept multipart/form-data
|
||||
// @Param model formData string true "model"
|
||||
// @Param file formData file true "file"
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
// ResponsesEndpoint is the Open Responses API endpoint
|
||||
// https://www.openresponses.org/specification
|
||||
// @Summary Create a response using the Open Responses API
|
||||
// @Tags inference
|
||||
// @Param request body schema.OpenResponsesRequest true "Request body"
|
||||
// @Success 200 {object} schema.ORResponseResource "Response"
|
||||
// @Router /v1/responses [post]
|
||||
@@ -1819,7 +1820,17 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
// If no tool calls detected yet, handle reasoning and text
|
||||
if !inToolCallMode {
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(token)
|
||||
var reasoningDelta, contentDelta string
|
||||
goReasoning, goContent := extractor.ProcessToken(token)
|
||||
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Handle reasoning item
|
||||
if extractor.Reasoning() != "" {
|
||||
@@ -2338,7 +2349,18 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Stream text deltas with reasoning extraction
|
||||
tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedText += token
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(token)
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
goReasoning, goContent := extractor.ProcessToken(token)
|
||||
|
||||
if tokenUsage.HasChatDeltaContent() {
|
||||
rawReasoning, cd := tokenUsage.ChatDeltaReasoningAndContent()
|
||||
contentDelta = cd
|
||||
reasoningDelta = extractor.ProcessChatDeltaReasoning(rawReasoning)
|
||||
} else {
|
||||
reasoningDelta = goReasoning
|
||||
contentDelta = goContent
|
||||
}
|
||||
|
||||
// Handle reasoning item
|
||||
if extractor.Reasoning() != "" {
|
||||
@@ -2931,6 +2953,7 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T
|
||||
// GetResponseEndpoint returns a handler for GET /responses/:id
|
||||
// This endpoint is used for polling background responses or resuming streaming
|
||||
// @Summary Get a response by ID
|
||||
// @Tags inference
|
||||
// @Description Retrieve a response by ID. Can be used for polling background responses or resuming streaming responses.
|
||||
// @Param id path string true "Response ID"
|
||||
// @Param stream query string false "Set to 'true' to resume streaming"
|
||||
@@ -3072,6 +3095,7 @@ func handleStreamResume(c echo.Context, store *ResponseStore, responseID string,
|
||||
// CancelResponseEndpoint returns a handler for POST /responses/:id/cancel
|
||||
// This endpoint cancels a background response if it's still in progress
|
||||
// @Summary Cancel a response
|
||||
// @Tags inference
|
||||
// @Description Cancel a background response if it's still in progress
|
||||
// @Param id path string true "Response ID"
|
||||
// @Success 200 {object} schema.ORResponseResource "Response"
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Navigate } from 'react-router-dom'
|
||||
import { useAuth } from '../context/AuthContext'
|
||||
|
||||
export default function RequireAuth({ children }) {
|
||||
const { authEnabled, user, loading } = useAuth()
|
||||
const { authEnabled, staticApiKeyRequired, user, loading } = useAuth()
|
||||
if (loading) return null
|
||||
if (authEnabled && !user) return <Navigate to="/login" replace />
|
||||
if ((authEnabled || staticApiKeyRequired) && !user) return <Navigate to="/login" replace />
|
||||
return children
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ export function AuthProvider({ children }) {
|
||||
const [state, setState] = useState({
|
||||
loading: true,
|
||||
authEnabled: false,
|
||||
staticApiKeyRequired: false,
|
||||
user: null,
|
||||
permissions: {},
|
||||
})
|
||||
@@ -20,12 +21,13 @@ export function AuthProvider({ children }) {
|
||||
setState({
|
||||
loading: false,
|
||||
authEnabled: data.authEnabled || false,
|
||||
staticApiKeyRequired: data.staticApiKeyRequired || false,
|
||||
user,
|
||||
permissions,
|
||||
})
|
||||
})
|
||||
.catch(() => {
|
||||
setState({ loading: false, authEnabled: false, user: null, permissions: {} })
|
||||
setState({ loading: false, authEnabled: false, staticApiKeyRequired: false, user: null, permissions: {} })
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,17 +47,20 @@ export function AuthProvider({ children }) {
|
||||
|
||||
const refresh = () => fetchStatus()
|
||||
|
||||
const noAuthRequired = !state.authEnabled && !state.staticApiKeyRequired
|
||||
|
||||
const hasFeature = (name) => {
|
||||
if (state.user?.role === 'admin' || !state.authEnabled) return true
|
||||
if (state.user?.role === 'admin' || noAuthRequired) return true
|
||||
return !!state.permissions[name]
|
||||
}
|
||||
|
||||
const value = {
|
||||
loading: state.loading,
|
||||
authEnabled: state.authEnabled,
|
||||
staticApiKeyRequired: state.staticApiKeyRequired,
|
||||
user: state.user,
|
||||
permissions: state.permissions,
|
||||
isAdmin: state.user?.role === 'admin' || !state.authEnabled,
|
||||
isAdmin: state.user?.role === 'admin' || noAuthRequired,
|
||||
hasFeature,
|
||||
logout,
|
||||
refresh,
|
||||
|
||||
12
core/http/react-ui/src/hooks/useChat.js
vendored
12
core/http/react-ui/src/hooks/useChat.js
vendored
@@ -2,9 +2,9 @@ import { useState, useCallback, useRef, useEffect } from 'react'
|
||||
import { API_CONFIG } from '../utils/config'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
|
||||
const thinkingTagRegex = /<thinking>([\s\S]*?)<\/thinking>|<think>([\s\S]*?)<\/think>/g
|
||||
const openThinkTagRegex = /<thinking>|<think>/
|
||||
const closeThinkTagRegex = /<\/thinking>|<\/think>/
|
||||
const thinkingTagRegex = /<thinking>([\s\S]*?)<\/thinking>|<think>([\s\S]*?)<\/think>|<\|channel>thought([\s\S]*?)<channel\|>/g
|
||||
const openThinkTagRegex = /<thinking>|<think>|<\|channel>thought/
|
||||
const closeThinkTagRegex = /<\/thinking>|<\/think>|<channel\|>/
|
||||
|
||||
async function extractHttpError(response) {
|
||||
let errorMsg = `HTTP ${response.status}`
|
||||
@@ -23,7 +23,7 @@ function extractThinking(text) {
|
||||
thinkingTagRegex.lastIndex = 0
|
||||
while ((match = thinkingTagRegex.exec(text)) !== null) {
|
||||
regularContent += text.slice(lastIdx, match.index)
|
||||
thinkingContent += match[1] || match[2] || ''
|
||||
thinkingContent += match[1] || match[2] || match[3] || ''
|
||||
lastIdx = match.index + match[0].length
|
||||
}
|
||||
regularContent += text.slice(lastIdx)
|
||||
@@ -578,9 +578,9 @@ export function useChat(initialModel = '') {
|
||||
}
|
||||
|
||||
if (insideThinkTag) {
|
||||
const lastOpen = Math.max(rawContent.lastIndexOf('<thinking>'), rawContent.lastIndexOf('<think>'))
|
||||
const lastOpen = Math.max(rawContent.lastIndexOf('<thinking>'), rawContent.lastIndexOf('<think>'), rawContent.lastIndexOf('<|channel>thought'))
|
||||
if (lastOpen >= 0) {
|
||||
const partial = rawContent.slice(lastOpen).replace(/<thinking>|<think>/, '')
|
||||
const partial = rawContent.slice(lastOpen).replace(/<thinking>|<think>|<\|channel>thought/, '')
|
||||
setStreamingReasoning(partial)
|
||||
const beforeThink = rawContent.slice(0, lastOpen)
|
||||
const { regularContent: contentBeforeThink } = extractThinking(beforeThink)
|
||||
|
||||
@@ -8,7 +8,7 @@ export default function Login() {
|
||||
const navigate = useNavigate()
|
||||
const { code: urlInviteCode } = useParams()
|
||||
const [searchParams] = useSearchParams()
|
||||
const { authEnabled, user, loading: authLoading, refresh } = useAuth()
|
||||
const { authEnabled, staticApiKeyRequired, user, loading: authLoading, refresh } = useAuth()
|
||||
const [providers, setProviders] = useState([])
|
||||
const [hasUsers, setHasUsers] = useState(true)
|
||||
const [registrationMode, setRegistrationMode] = useState('open')
|
||||
@@ -66,7 +66,7 @@ export default function Login() {
|
||||
|
||||
// Redirect if auth is disabled or user is already logged in
|
||||
useEffect(() => {
|
||||
if (!authLoading && (!authEnabled || user)) {
|
||||
if (!authLoading && ((!authEnabled && !staticApiKeyRequired) || user)) {
|
||||
navigate('/app', { replace: true })
|
||||
}
|
||||
}, [authLoading, authEnabled, user, navigate])
|
||||
@@ -176,6 +176,40 @@ export default function Login() {
|
||||
|
||||
if (authLoading || statusLoading) return null
|
||||
|
||||
// Legacy API key-only mode: show a simplified login with just the token input
|
||||
if (staticApiKeyRequired && !authEnabled) {
|
||||
return (
|
||||
<div className="login-page">
|
||||
<div className="card login-card">
|
||||
<div className="login-header">
|
||||
<img src={apiUrl('/static/logo.png')} alt="LocalAI" className="login-logo" />
|
||||
<p className="login-subtitle">Enter your API key to continue</p>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="login-alert login-alert-error">{error}</div>
|
||||
)}
|
||||
|
||||
<form onSubmit={handleTokenLogin}>
|
||||
<div className="form-group">
|
||||
<input
|
||||
className="input"
|
||||
type="password"
|
||||
value={token}
|
||||
onChange={(e) => { setToken(e.target.value); setError('') }}
|
||||
placeholder="Enter API key..."
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
<button type="submit" className="btn btn-primary login-btn-full" disabled={submitting}>
|
||||
{submitting ? 'Signing in...' : 'Sign In'}
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const hasGitHub = providers.includes('github')
|
||||
const hasOIDC = providers.includes('oidc')
|
||||
const hasLocal = providers.includes('local')
|
||||
|
||||
@@ -157,10 +157,11 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"authEnabled": authEnabled,
|
||||
"providers": providers,
|
||||
"hasUsers": hasUsers,
|
||||
"registrationMode": registrationMode,
|
||||
"authEnabled": authEnabled,
|
||||
"staticApiKeyRequired": !authEnabled && len(appConfig.ApiKeys) > 0,
|
||||
"providers": providers,
|
||||
"hasUsers": hasUsers,
|
||||
"registrationMode": registrationMode,
|
||||
}
|
||||
|
||||
// Include current user if authenticated
|
||||
|
||||
@@ -45,9 +45,10 @@ func newTestAuthApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"authEnabled": authEnabled,
|
||||
"providers": providers,
|
||||
"hasUsers": hasUsers,
|
||||
"authEnabled": authEnabled,
|
||||
"staticApiKeyRequired": !authEnabled && len(appConfig.ApiKeys) > 0,
|
||||
"providers": providers,
|
||||
"hasUsers": hasUsers,
|
||||
}
|
||||
|
||||
user := auth.GetUser(c)
|
||||
@@ -407,6 +408,29 @@ var _ = Describe("Auth Routes", Label("auth"), func() {
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
Expect(resp["hasUsers"]).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns staticApiKeyRequired=true when no DB but API keys configured", func() {
|
||||
cfg := config.NewApplicationConfig()
|
||||
config.WithApiKeys([]string{"test-key-123"})(cfg)
|
||||
app := newTestAuthApp(nil, cfg)
|
||||
rec := doAuthRequest(app, "GET", "/api/auth/status", nil)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
Expect(resp["authEnabled"]).To(BeFalse())
|
||||
Expect(resp["staticApiKeyRequired"]).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns staticApiKeyRequired=false when no DB and no API keys", func() {
|
||||
app := newTestAuthApp(nil, config.NewApplicationConfig())
|
||||
rec := doAuthRequest(app, "GET", "/api/auth/status", nil)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp map[string]any
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
Expect(resp["staticApiKeyRequired"]).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("POST /api/auth/logout", func() {
|
||||
|
||||
@@ -29,7 +29,9 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
mcpJobsMw echo.MiddlewareFunc,
|
||||
mcpMw echo.MiddlewareFunc) {
|
||||
|
||||
router.GET("/swagger/*", echoswagger.WrapHandler) // default
|
||||
router.GET("/swagger/*", echoswagger.EchoWrapHandler(func(c *echoswagger.Config) {
|
||||
c.URLs = []string{"doc.json"}
|
||||
}))
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
@@ -124,6 +126,19 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService), adminMiddleware)
|
||||
router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService), adminMiddleware)
|
||||
|
||||
// Traces and backend logs (monitoring)
|
||||
router.GET("/api/traces", localai.GetAPITracesEndpoint(), adminMiddleware)
|
||||
router.POST("/api/traces/clear", localai.ClearAPITracesEndpoint(), adminMiddleware)
|
||||
router.GET("/api/backend-traces", localai.GetBackendTracesEndpoint(), adminMiddleware)
|
||||
router.POST("/api/backend-traces/clear", localai.ClearBackendTracesEndpoint(), adminMiddleware)
|
||||
// Backend logs — standalone only (distributed mode uses node-proxied routes)
|
||||
if !appConfig.Distributed.Enabled {
|
||||
router.GET("/api/backend-logs", localai.ListBackendLogsEndpoint(ml), adminMiddleware)
|
||||
router.GET("/api/backend-logs/:modelId", localai.GetBackendLogsEndpoint(ml), adminMiddleware)
|
||||
router.POST("/api/backend-logs/:modelId/clear", localai.ClearBackendLogsEndpoint(ml), adminMiddleware)
|
||||
router.GET("/ws/backend-logs/:modelId", localai.BackendLogsWebSocketEndpoint(ml), adminMiddleware)
|
||||
}
|
||||
|
||||
// p2p
|
||||
router.GET("/api/p2p", localai.ShowP2PNodes(appConfig), adminMiddleware)
|
||||
router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig), adminMiddleware)
|
||||
@@ -134,6 +149,127 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
// Agent discovery endpoint
|
||||
router.GET("/.well-known/localai.json", func(c echo.Context) error {
|
||||
monitoringRoutes := map[string]string{
|
||||
"metrics": "/metrics",
|
||||
"backend_monitor": "/backend/monitor",
|
||||
"backend_shutdown": "/backend/shutdown",
|
||||
"system": "/system",
|
||||
"version": "/version",
|
||||
"traces": "/api/traces",
|
||||
"traces_clear": "/api/traces/clear",
|
||||
"backend_traces": "/api/backend-traces",
|
||||
"backend_traces_clear": "/api/backend-traces/clear",
|
||||
}
|
||||
if !appConfig.Distributed.Enabled {
|
||||
monitoringRoutes["backend_logs"] = "/api/backend-logs"
|
||||
monitoringRoutes["backend_logs_model"] = "/api/backend-logs/:modelId"
|
||||
monitoringRoutes["backend_logs_clear"] = "/api/backend-logs/:modelId/clear"
|
||||
monitoringRoutes["backend_logs_ws"] = "/ws/backend-logs/:modelId"
|
||||
} else {
|
||||
monitoringRoutes["node_backend_logs"] = "/api/nodes/:id/backend-logs"
|
||||
monitoringRoutes["node_backend_logs_model"] = "/api/nodes/:id/backend-logs/:modelId"
|
||||
monitoringRoutes["node_backend_logs_ws"] = "/ws/nodes/:id/backend-logs/:modelId"
|
||||
}
|
||||
return c.JSON(200, map[string]any{
|
||||
"version": internal.PrintableVersion(),
|
||||
// Flat endpoint list for backwards compatibility
|
||||
"endpoints": map[string]any{
|
||||
"models": "/v1/models",
|
||||
"chat_completions": "/v1/chat/completions",
|
||||
"completions": "/v1/completions",
|
||||
"embeddings": "/v1/embeddings",
|
||||
"config_metadata": "/api/models/config-metadata",
|
||||
"config_json": "/api/models/config-json/:name",
|
||||
"config_patch": "/api/models/config-json/:name",
|
||||
"autocomplete": "/api/models/config-metadata/autocomplete/:provider",
|
||||
"vram_estimate": "/api/models/vram-estimate",
|
||||
"tts": "/tts",
|
||||
"transcription": "/v1/audio/transcriptions",
|
||||
"image_generation": "/v1/images/generations",
|
||||
"swagger": "/swagger/index.html",
|
||||
"instructions": "/api/instructions",
|
||||
},
|
||||
// Categorized endpoint groups for structured discovery
|
||||
"endpoint_groups": map[string]any{
|
||||
"openai_compatible": map[string]string{
|
||||
"models": "/v1/models",
|
||||
"chat_completions": "/v1/chat/completions",
|
||||
"completions": "/v1/completions",
|
||||
"embeddings": "/v1/embeddings",
|
||||
"transcription": "/v1/audio/transcriptions",
|
||||
"image_generation": "/v1/images/generations",
|
||||
},
|
||||
"config_management": map[string]string{
|
||||
"config_metadata": "/api/models/config-metadata",
|
||||
"config_json": "/api/models/config-json/:name",
|
||||
"config_patch": "/api/models/config-json/:name",
|
||||
"autocomplete": "/api/models/config-metadata/autocomplete/:provider",
|
||||
"vram_estimate": "/api/models/vram-estimate",
|
||||
},
|
||||
"model_management": map[string]string{
|
||||
"list_gallery": "/models/available",
|
||||
"install": "/models/apply",
|
||||
"delete": "/models/delete/:name",
|
||||
"edit": "/models/edit/:name",
|
||||
"import": "/models/import",
|
||||
"reload": "/models/reload",
|
||||
},
|
||||
"ai_functions": map[string]string{
|
||||
"tts": "/tts",
|
||||
"vad": "/vad",
|
||||
"video": "/video",
|
||||
"detection": "/v1/detection",
|
||||
"tokenize": "/v1/tokenize",
|
||||
},
|
||||
"monitoring": monitoringRoutes,
|
||||
"mcp": map[string]string{
|
||||
"chat_completions": "/v1/mcp/chat/completions",
|
||||
"servers": "/v1/mcp/servers/:model",
|
||||
"prompts": "/v1/mcp/prompts/:model",
|
||||
"resources": "/v1/mcp/resources/:model",
|
||||
},
|
||||
"p2p": map[string]string{
|
||||
"nodes": "/api/p2p",
|
||||
"token": "/api/p2p/token",
|
||||
},
|
||||
"agents": map[string]string{
|
||||
"tasks": "/api/agent/tasks",
|
||||
"jobs": "/api/agent/jobs",
|
||||
"execute": "/api/agent/jobs/execute",
|
||||
},
|
||||
"settings": map[string]string{
|
||||
"get": "/api/settings",
|
||||
"update": "/api/settings",
|
||||
},
|
||||
"stores": map[string]string{
|
||||
"set": "/stores/set",
|
||||
"get": "/stores/get",
|
||||
"find": "/stores/find",
|
||||
"delete": "/stores/delete",
|
||||
},
|
||||
"docs": map[string]string{
|
||||
"swagger": "/swagger/index.html",
|
||||
"instructions": "/api/instructions",
|
||||
},
|
||||
},
|
||||
"capabilities": map[string]bool{
|
||||
"config_metadata": true,
|
||||
"config_patch": true,
|
||||
"vram_estimate": true,
|
||||
"mcp": !appConfig.DisableMCP,
|
||||
"agents": appConfig.AgentPool.Enabled,
|
||||
"p2p": appConfig.P2PToken != "",
|
||||
"tracing": true,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// API instructions for agent discovery (no auth — agents should discover these without credentials)
|
||||
router.GET("/api/instructions", localai.ListAPIInstructionsEndpoint())
|
||||
router.GET("/api/instructions/:name", localai.GetAPIInstructionEndpoint())
|
||||
|
||||
router.GET("/api/features", func(c echo.Context) error {
|
||||
return c.JSON(200, map[string]bool{
|
||||
"agents": appConfig.AgentPool.Enabled,
|
||||
|
||||
@@ -2,41 +2,15 @@ package routes
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
var backendLogsUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // no origin header = same-origin or non-browser
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == r.Host
|
||||
},
|
||||
}
|
||||
|
||||
func RegisterUIRoutes(app *echo.Echo,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *galleryop.GalleryService,
|
||||
adminMiddleware echo.MiddlewareFunc) {
|
||||
@@ -78,142 +52,4 @@ func RegisterUIRoutes(app *echo.Echo,
|
||||
|
||||
return c.JSON(200, models)
|
||||
})
|
||||
|
||||
app.GET("/api/traces", func(c echo.Context) error {
|
||||
return c.JSON(200, middleware.GetTraces())
|
||||
}, adminMiddleware)
|
||||
|
||||
app.POST("/api/traces/clear", func(c echo.Context) error {
|
||||
middleware.ClearTraces()
|
||||
return c.NoContent(204)
|
||||
}, adminMiddleware)
|
||||
|
||||
app.GET("/api/backend-traces", func(c echo.Context) error {
|
||||
return c.JSON(200, trace.GetBackendTraces())
|
||||
}, adminMiddleware)
|
||||
|
||||
app.POST("/api/backend-traces/clear", func(c echo.Context) error {
|
||||
trace.ClearBackendTraces()
|
||||
return c.NoContent(204)
|
||||
}, adminMiddleware)
|
||||
|
||||
// Backend logs endpoints — only in standalone mode.
|
||||
// In distributed mode, backend processes run on workers and logs are
|
||||
// streamed via /api/nodes/:id/backend-logs and /ws/nodes/:id/backend-logs/:modelId.
|
||||
if !appConfig.Distributed.Enabled {
|
||||
app.GET("/api/backend-logs", func(c echo.Context) error {
|
||||
return c.JSON(200, ml.BackendLogs().ListModels())
|
||||
}, adminMiddleware)
|
||||
|
||||
app.GET("/api/backend-logs/:modelId", func(c echo.Context) error {
|
||||
modelID := c.Param("modelId")
|
||||
return c.JSON(200, ml.BackendLogs().GetLines(modelID))
|
||||
}, adminMiddleware)
|
||||
|
||||
app.POST("/api/backend-logs/:modelId/clear", func(c echo.Context) error {
|
||||
ml.BackendLogs().Clear(c.Param("modelId"))
|
||||
return c.NoContent(204)
|
||||
}, adminMiddleware)
|
||||
|
||||
// Backend logs WebSocket endpoint for real-time streaming
|
||||
app.GET("/ws/backend-logs/:modelId", func(c echo.Context) error {
|
||||
modelID := c.Param("modelId")
|
||||
|
||||
ws, err := backendLogsUpgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
ws.SetReadLimit(4096)
|
||||
|
||||
// Set up ping/pong for keepalive
|
||||
ws.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(90 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
conn := &backendLogsConn{Conn: ws}
|
||||
|
||||
// Send existing lines as initial batch
|
||||
existingLines := ml.BackendLogs().GetLines(modelID)
|
||||
initialMsg := map[string]any{
|
||||
"type": "initial",
|
||||
"lines": existingLines,
|
||||
}
|
||||
if err := conn.writeJSON(initialMsg); err != nil {
|
||||
xlog.Debug("WebSocket backend-logs initial write failed", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe to new lines
|
||||
lineCh, unsubscribe := ml.BackendLogs().Subscribe(modelID)
|
||||
defer unsubscribe()
|
||||
|
||||
// Handle close from client side
|
||||
closeCh := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
_, _, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
close(closeCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Ping ticker for keepalive
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
// Forward new lines to WebSocket
|
||||
for {
|
||||
select {
|
||||
case line, ok := <-lineCh:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
lineMsg := map[string]any{
|
||||
"type": "line",
|
||||
"line": line,
|
||||
}
|
||||
if err := conn.writeJSON(lineMsg); err != nil {
|
||||
xlog.Debug("WebSocket backend-logs write error", "error", err)
|
||||
return nil
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
if err := conn.writePing(); err != nil {
|
||||
return nil
|
||||
}
|
||||
case <-closeCh:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}, adminMiddleware)
|
||||
}
|
||||
}
|
||||
|
||||
// backendLogsConn wraps a websocket connection with a mutex for safe concurrent writes
|
||||
type backendLogsConn struct {
|
||||
*websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *backendLogsConn) writeJSON(v any) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal error: %w", err)
|
||||
}
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (c *backendLogsConn) writePing() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
return c.Conn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
@@ -690,6 +690,18 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
return c.JSON(http.StatusOK, modelConfig)
|
||||
}, adminMiddleware)
|
||||
|
||||
// Config metadata API - returns field metadata for all ~170 config fields
|
||||
app.GET("/api/models/config-metadata", localai.ConfigMetadataEndpoint(), adminMiddleware)
|
||||
|
||||
// Autocomplete providers for config fields (dynamic values only)
|
||||
app.GET("/api/models/config-metadata/autocomplete/:provider", localai.AutocompleteEndpoint(cl, ml, appConfig), adminMiddleware)
|
||||
|
||||
// PATCH config endpoint - partial update using nested JSON merge
|
||||
app.PATCH("/api/models/config-json/:name", localai.PatchConfigEndpoint(cl, ml, appConfig), adminMiddleware)
|
||||
|
||||
// VRAM estimation endpoint
|
||||
app.POST("/api/models/vram-estimate", localai.VRAMEstimateEndpoint(cl, appConfig), adminMiddleware)
|
||||
|
||||
// Get installed model YAML config for the React model editor
|
||||
app.GET("/api/models/edit/:name", func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
@@ -1313,3 +1325,4 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
||||
})
|
||||
}, adminMiddleware)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,21 +10,18 @@ type Task struct {
|
||||
Name string `json:"name"` // User-friendly name
|
||||
Description string `json:"description"` // Optional description
|
||||
Model string `json:"model"` // Model name (must have MCP config)
|
||||
Prompt string `json:"prompt"` // Template prompt (supports {{.param}} syntax)
|
||||
Prompt string `json:"prompt"` // Template prompt (supports Go template .param syntax)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Enabled bool `json:"enabled"` // Can be disabled without deletion
|
||||
Cron string `json:"cron,omitempty"` // Optional cron expression
|
||||
CronParameters map[string]string `json:"cron_parameters,omitempty"` // Parameters to use when executing cron jobs
|
||||
|
||||
// Webhook configuration (for notifications)
|
||||
// Support multiple webhook endpoints
|
||||
// Webhook configuration (for notifications).
|
||||
// Supports multiple webhook endpoints.
|
||||
// Webhooks can handle both success and failure cases using template variables:
|
||||
// - {{.Job}} - Job object with all fields
|
||||
// - {{.Task}} - Task object
|
||||
// - {{.Result}} - Job result (if successful)
|
||||
// - {{.Error}} - Error message (if failed, empty string if successful)
|
||||
// - {{.Status}} - Job status string
|
||||
// .Job (Job object), .Task (Task object), .Result (if successful),
|
||||
// .Error (if failed), .Status (job status string).
|
||||
Webhooks []WebhookConfig `json:"webhooks,omitempty"` // Webhook configs for job completion notifications
|
||||
|
||||
// Multimedia sources (for cron jobs)
|
||||
@@ -39,13 +36,8 @@ type WebhookConfig struct {
|
||||
Method string `json:"method"` // HTTP method (POST, PUT, PATCH) - default: POST
|
||||
Headers map[string]string `json:"headers,omitempty"` // Custom headers (e.g., Authorization)
|
||||
PayloadTemplate string `json:"payload_template,omitempty"` // Optional template for payload
|
||||
// If PayloadTemplate is empty, uses default JSON structure
|
||||
// Available template variables:
|
||||
// - {{.Job}} - Job object with all fields
|
||||
// - {{.Task}} - Task object
|
||||
// - {{.Result}} - Job result (if successful)
|
||||
// - {{.Error}} - Error message (if failed, empty string if successful)
|
||||
// - {{.Status}} - Job status string
|
||||
// If PayloadTemplate is empty, uses default JSON structure.
|
||||
// Available template variables: .Job, .Task, .Result, .Error, .Status.
|
||||
}
|
||||
|
||||
// MultimediaSourceConfig represents configuration for fetching multimedia content
|
||||
@@ -126,9 +118,9 @@ type JobExecutionRequest struct {
|
||||
|
||||
// JobExecutionResponse represents the response after creating a job
|
||||
type JobExecutionResponse struct {
|
||||
JobID string `json:"job_id"`
|
||||
Status string `json:"status"`
|
||||
URL string `json:"url"` // URL to check job status
|
||||
JobID string `json:"job_id"` // unique job identifier
|
||||
Status string `json:"status"` // initial status (pending)
|
||||
URL string `json:"url"` // URL to poll for job status
|
||||
}
|
||||
|
||||
// TasksFile represents the structure of agent_tasks.json
|
||||
|
||||
@@ -78,7 +78,7 @@ type AnthropicMessage struct {
|
||||
// AnthropicContentBlock represents a content block in an Anthropic message
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Text string `json:"text"`
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -116,7 +116,7 @@ type AnthropicUsage struct {
|
||||
// AnthropicStreamEvent represents a streaming event from the Anthropic API
|
||||
type AnthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta *AnthropicStreamDelta `json:"delta,omitempty"`
|
||||
Message *AnthropicStreamMessage `json:"message,omitempty"`
|
||||
|
||||
@@ -33,31 +33,31 @@ type GalleryResponse struct {
|
||||
|
||||
type VideoRequest struct {
|
||||
BasicModelRequest
|
||||
Prompt string `json:"prompt" yaml:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
||||
StartImage string `json:"start_image" yaml:"start_image"`
|
||||
EndImage string `json:"end_image" yaml:"end_image"`
|
||||
Width int32 `json:"width" yaml:"width"`
|
||||
Height int32 `json:"height" yaml:"height"`
|
||||
NumFrames int32 `json:"num_frames" yaml:"num_frames"`
|
||||
FPS int32 `json:"fps" yaml:"fps"`
|
||||
Seconds string `json:"seconds,omitempty" yaml:"seconds,omitempty"`
|
||||
Size string `json:"size,omitempty" yaml:"size,omitempty"`
|
||||
InputReference string `json:"input_reference,omitempty" yaml:"input_reference,omitempty"`
|
||||
Seed int32 `json:"seed" yaml:"seed"`
|
||||
CFGScale float32 `json:"cfg_scale" yaml:"cfg_scale"`
|
||||
Step int32 `json:"step" yaml:"step"`
|
||||
ResponseFormat string `json:"response_format" yaml:"response_format"`
|
||||
Prompt string `json:"prompt" yaml:"prompt"` // text description of the video to generate
|
||||
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"` // things to avoid in the output
|
||||
StartImage string `json:"start_image" yaml:"start_image"` // URL or base64 of the first frame
|
||||
EndImage string `json:"end_image" yaml:"end_image"` // URL or base64 of the last frame
|
||||
Width int32 `json:"width" yaml:"width"` // output width in pixels
|
||||
Height int32 `json:"height" yaml:"height"` // output height in pixels
|
||||
NumFrames int32 `json:"num_frames" yaml:"num_frames"` // total number of frames to generate
|
||||
FPS int32 `json:"fps" yaml:"fps"` // frames per second
|
||||
Seconds string `json:"seconds,omitempty" yaml:"seconds,omitempty"` // duration in seconds (alternative to num_frames)
|
||||
Size string `json:"size,omitempty" yaml:"size,omitempty"` // WxH shorthand (e.g. "512x512")
|
||||
InputReference string `json:"input_reference,omitempty" yaml:"input_reference,omitempty"` // reference image or video URL
|
||||
Seed int32 `json:"seed" yaml:"seed"` // random seed for reproducibility
|
||||
CFGScale float32 `json:"cfg_scale" yaml:"cfg_scale"` // classifier-free guidance scale
|
||||
Step int32 `json:"step" yaml:"step"` // number of diffusion steps
|
||||
ResponseFormat string `json:"response_format" yaml:"response_format"` // output format (url or b64_json)
|
||||
}
|
||||
|
||||
// @Description TTS request body
|
||||
type TTSRequest struct {
|
||||
BasicModelRequest
|
||||
Input string `json:"input" yaml:"input"` // text input
|
||||
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
|
||||
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
||||
Input string `json:"input" yaml:"input"` // text input
|
||||
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
|
||||
Backend string `json:"backend" yaml:"backend"` // backend engine override
|
||||
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
|
||||
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
||||
Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS
|
||||
SampleRate int `json:"sample_rate,omitempty" yaml:"sample_rate,omitempty"` // (optional) desired output sample rate
|
||||
}
|
||||
@@ -65,7 +65,7 @@ type TTSRequest struct {
|
||||
// @Description VAD request body
|
||||
type VADRequest struct {
|
||||
BasicModelRequest
|
||||
Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
|
||||
Audio []float32 `json:"audio" yaml:"audio"` // raw audio samples as float32 PCM
|
||||
}
|
||||
|
||||
type VADSegment struct {
|
||||
@@ -146,13 +146,13 @@ type SysInfoModel struct {
|
||||
}
|
||||
|
||||
type SystemInformationResponse struct {
|
||||
Backends []string `json:"backends"`
|
||||
Models []SysInfoModel `json:"loaded_models"`
|
||||
Backends []string `json:"backends"` // available backend engines
|
||||
Models []SysInfoModel `json:"loaded_models"` // currently loaded models
|
||||
}
|
||||
|
||||
type DetectionRequest struct {
|
||||
BasicModelRequest
|
||||
Image string `json:"image"`
|
||||
Image string `json:"image"` // URL or base64-encoded image to analyze
|
||||
}
|
||||
|
||||
type DetectionResponse struct {
|
||||
|
||||
@@ -2,9 +2,9 @@ package schema
|
||||
|
||||
type TokenizeRequest struct {
|
||||
BasicModelRequest
|
||||
Content string `json:"content"`
|
||||
Content string `json:"content"` // text to tokenize
|
||||
}
|
||||
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int32 `json:"tokens"`
|
||||
Tokens []int32 `json:"tokens"` // token IDs
|
||||
}
|
||||
|
||||
@@ -231,6 +231,9 @@ func (c *fakeBackendClient) QuantizationProgress(_ context.Context, _ *pb.Quanti
|
||||
func (c *fakeBackendClient) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *fakeBackendClient) Free(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- fakeBackendClientFactory ---
|
||||
|
||||
|
||||
@@ -175,6 +175,10 @@ func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.Quantization
|
||||
return &pb.Result{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeGRPCBackend) Free(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
var _ = Describe("InFlightTrackingClient", func() {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
// DistributedModelManager wraps a local ModelManager and adds NATS fan-out
|
||||
@@ -84,6 +85,13 @@ func (d *DistributedBackendManager) DeleteBackend(name string) error {
|
||||
continue
|
||||
}
|
||||
if _, delErr := d.adapter.DeleteBackend(node.ID, name); delErr != nil {
|
||||
if errors.Is(delErr, nats.ErrNoResponders) {
|
||||
// Node's NATS subscription is gone — likely restarted with a new ID.
|
||||
// Mark it unhealthy so future fan-outs skip it.
|
||||
xlog.Warn("No NATS responders for node, marking unhealthy", "node", node.Name, "nodeID", node.ID)
|
||||
d.registry.MarkUnhealthy(context.Background(), node.ID)
|
||||
continue
|
||||
}
|
||||
xlog.Warn("Failed to propagate backend deletion to worker", "node", node.Name, "backend", name, "error", delErr)
|
||||
errs = append(errs, fmt.Errorf("node %s: %w", node.Name, delErr))
|
||||
}
|
||||
@@ -105,6 +113,11 @@ func (d *DistributedBackendManager) ListBackends() (gallery.SystemBackends, erro
|
||||
}
|
||||
reply, err := d.adapter.ListBackends(node.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, nats.ErrNoResponders) {
|
||||
xlog.Warn("No NATS responders for node, marking unhealthy", "node", node.Name, "nodeID", node.ID)
|
||||
d.registry.MarkUnhealthy(context.Background(), node.ID)
|
||||
continue
|
||||
}
|
||||
xlog.Warn("Failed to list backends on worker", "node", node.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
@@ -145,6 +158,11 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
|
||||
}
|
||||
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON))
|
||||
if err != nil {
|
||||
if errors.Is(err, nats.ErrNoResponders) {
|
||||
xlog.Warn("No NATS responders for node, marking unhealthy", "node", node.Name, "nodeID", node.ID)
|
||||
d.registry.MarkUnhealthy(context.Background(), node.ID)
|
||||
continue
|
||||
}
|
||||
xlog.Warn("Failed to install backend on worker", "node", node.Name, "backend", backendName, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -189,8 +189,8 @@ These settings apply to most LLM backends (llama.cpp, vLLM, etc.):
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `no_mulmatq` | bool | Disable matrix multiplication queuing |
|
||||
| `draft_model` | string | Draft model for speculative decoding |
|
||||
| `n_draft` | int32 | Number of draft tokens |
|
||||
| `draft_model` | string | Draft model GGUF file for speculative decoding (see [Speculative Decoding](#speculative-decoding)) |
|
||||
| `n_draft` | int32 | Maximum number of draft tokens per speculative step (default: 16) |
|
||||
| `quantization` | string | Quantization format |
|
||||
| `load_format` | string | Model load format |
|
||||
| `numa` | bool | Enable NUMA (Non-Uniform Memory Access) |
|
||||
@@ -211,6 +211,76 @@ YARN (Yet Another RoPE extensioN) settings for context extension:
|
||||
| `yarn_beta_fast` | float32 | YARN beta fast parameter |
|
||||
| `yarn_beta_slow` | float32 | YARN beta slow parameter |
|
||||
|
||||
### Speculative Decoding
|
||||
|
||||
Speculative decoding speeds up text generation by predicting multiple tokens ahead and verifying them in a single forward pass. The output is identical to normal decoding — only faster. This feature is only available with the `llama-cpp` backend.
|
||||
|
||||
There are two approaches:
|
||||
|
||||
#### Draft Model Speculative Decoding
|
||||
|
||||
Uses a smaller, faster model from the same model family to draft candidate tokens, which the main model then verifies. Requires a separate GGUF file for the draft model.
|
||||
|
||||
```yaml
|
||||
name: my-model
|
||||
backend: llama-cpp
|
||||
parameters:
|
||||
model: large-model.gguf
|
||||
draft_model: small-draft-model.gguf
|
||||
n_draft: 8
|
||||
options:
|
||||
- spec_p_min:0.8
|
||||
- draft_gpu_layers:99
|
||||
```
|
||||
|
||||
#### N-gram Self-Speculative Decoding
|
||||
|
||||
Uses patterns from the token history to predict future tokens — no extra model required. Works well for repetitive or structured output (code, JSON, lists).
|
||||
|
||||
```yaml
|
||||
name: my-model
|
||||
backend: llama-cpp
|
||||
parameters:
|
||||
model: my-model.gguf
|
||||
options:
|
||||
- spec_type:ngram_simple
|
||||
- spec_n_max:16
|
||||
```
|
||||
|
||||
#### Speculative Decoding Options
|
||||
|
||||
These are set via the `options:` array in the model configuration (format: `key:value`):
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `spec_type` | string | `none` | Speculative decoding type (see table below) |
|
||||
| `spec_n_max` / `draft_max` | int | 16 | Maximum number of tokens to draft per step |
|
||||
| `spec_n_min` / `draft_min` | int | 0 | Minimum draft tokens required to use speculation |
|
||||
| `spec_p_min` / `draft_p_min` | float | 0.75 | Minimum probability threshold for greedy acceptance |
|
||||
| `spec_p_split` | float | 0.1 | Split probability for tree-based branching |
|
||||
| `spec_ngram_size_n` / `ngram_size_n` | int | 12 | N-gram lookup size |
|
||||
| `spec_ngram_size_m` / `ngram_size_m` | int | 48 | M-gram proposal size |
|
||||
| `spec_ngram_min_hits` / `ngram_min_hits` | int | 1 | Minimum hits for accepting n-gram proposals |
|
||||
| `draft_gpu_layers` | int | -1 | GPU layers for the draft model (-1 = use default) |
|
||||
| `draft_ctx_size` | int | 0 | Context size for the draft model (0 = auto) |
|
||||
|
||||
#### Speculative Type Values
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `draft` | Draft model-based speculation (auto-set when `draft_model` is configured) |
|
||||
| `eagle3` | EAGLE3 draft model architecture |
|
||||
| `ngram_simple` | Simple self-speculative using token history |
|
||||
| `ngram_map_k` | N-gram with key-only map |
|
||||
| `ngram_map_k4v` | N-gram with keys and 4 m-gram values |
|
||||
| `ngram_mod` | Modified n-gram speculation |
|
||||
| `ngram_cache` | 3-level n-gram cache |
|
||||
|
||||
{{% notice note %}}
|
||||
Speculative decoding is automatically disabled when multimodal models (with `mmproj`) are active. The `n_draft` parameter can also be overridden per-request.
|
||||
{{% /notice %}}
|
||||
|
||||
### Prompt Caching
|
||||
|
||||
| Field | Type | Description |
|
||||
|
||||
223
docs/content/features/api-discovery.md
Normal file
223
docs/content/features/api-discovery.md
Normal file
@@ -0,0 +1,223 @@
|
||||
+++
|
||||
title = "API Discovery & Instructions"
|
||||
weight = 27
|
||||
toc = true
|
||||
description = "Programmatic API discovery for agents, tools, and automation"
|
||||
tags = ["API", "Agents", "Instructions", "Configuration", "Advanced"]
|
||||
categories = ["Features"]
|
||||
+++
|
||||
|
||||
LocalAI exposes a set of discovery endpoints that let external agents, coding assistants, and automation tools programmatically learn what the instance can do and how to control it — without reading documentation ahead of time.
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
# 1. Discover what's available
|
||||
curl http://localhost:8080/.well-known/localai.json
|
||||
|
||||
# 2. Browse instruction areas
|
||||
curl http://localhost:8080/api/instructions
|
||||
|
||||
# 3. Get an API guide for a specific instruction
|
||||
curl http://localhost:8080/api/instructions/config-management
|
||||
```
|
||||
|
||||
## Well-Known Discovery Endpoint
|
||||
|
||||
`GET /.well-known/localai.json`
|
||||
|
||||
Returns the instance version, all available endpoint URLs (flat and categorized), and runtime capabilities.
|
||||
|
||||
**Example response (abbreviated):**
|
||||
|
||||
```json
|
||||
{
|
||||
"version": "v2.28.0",
|
||||
"endpoints": {
|
||||
"chat_completions": "/v1/chat/completions",
|
||||
"models": "/v1/models",
|
||||
"config_metadata": "/api/models/config-metadata",
|
||||
"instructions": "/api/instructions",
|
||||
"swagger": "/swagger/index.html"
|
||||
},
|
||||
"endpoint_groups": {
|
||||
"openai_compatible": { "chat_completions": "/v1/chat/completions", "..." : "..." },
|
||||
"config_management": { "config_metadata": "/api/models/config-metadata", "..." : "..." },
|
||||
"model_management": { "..." : "..." },
|
||||
"monitoring": { "..." : "..." }
|
||||
},
|
||||
"capabilities": {
|
||||
"config_metadata": true,
|
||||
"config_patch": true,
|
||||
"vram_estimate": true,
|
||||
"mcp": true,
|
||||
"agents": false,
|
||||
"p2p": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `capabilities` object reflects the current runtime configuration — for example, `mcp` is only `true` if MCP is enabled, and `agents` is `true` only if the agent pool is running.
|
||||
|
||||
## Instructions API
|
||||
|
||||
Instructions are curated groups of related API endpoints. Each instruction maps to one or more Swagger tags and provides a focused, LLM-readable guide.
|
||||
|
||||
### List all instructions
|
||||
|
||||
`GET /api/instructions`
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/api/instructions
|
||||
```
|
||||
|
||||
Returns a compact list of instruction areas:
|
||||
|
||||
```json
|
||||
{
|
||||
"instructions": [
|
||||
{
|
||||
"name": "chat-inference",
|
||||
"description": "OpenAI-compatible chat completions, text completions, and embeddings",
|
||||
"tags": ["inference", "embeddings"],
|
||||
"url": "/api/instructions/chat-inference"
|
||||
},
|
||||
{
|
||||
"name": "config-management",
|
||||
"description": "Discover, read, and modify model configuration fields with VRAM estimation",
|
||||
"tags": ["config"],
|
||||
"url": "/api/instructions/config-management"
|
||||
}
|
||||
],
|
||||
"hint": "Fetch GET {url} for a markdown API guide. Add ?format=json for a raw OpenAPI fragment."
|
||||
}
|
||||
```
|
||||
|
||||
**Available instructions:**
|
||||
|
||||
| Instruction | Description |
|
||||
|-------------|-------------|
|
||||
| `chat-inference` | Chat completions, text completions, embeddings (OpenAI-compatible) |
|
||||
| `audio` | Text-to-speech, transcription, voice activity detection, sound generation |
|
||||
| `images` | Image generation and inpainting |
|
||||
| `model-management` | Browse gallery, install, delete, manage models and backends |
|
||||
| `config-management` | Discover, read, and modify model config fields with VRAM estimation |
|
||||
| `monitoring` | System metrics, backend status, system information |
|
||||
| `mcp` | Model Context Protocol — tool-augmented chat with MCP servers |
|
||||
| `agents` | Agent task and job management |
|
||||
| `video` | Video generation from text prompts |
|
||||
|
||||
### Get an instruction guide
|
||||
|
||||
`GET /api/instructions/:name`
|
||||
|
||||
By default, returns a **markdown guide** suitable for LLMs and humans:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/api/instructions/config-management
|
||||
```
|
||||
|
||||
Add `?format=json` to get a raw **OpenAPI fragment** (filtered Swagger spec with only the relevant paths and definitions):
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/api/instructions/config-management?format=json
|
||||
```
|
||||
|
||||
## Configuration Management APIs
|
||||
|
||||
These endpoints let agents discover model configuration fields, read current settings, modify them, and estimate VRAM usage.
|
||||
|
||||
### Config metadata
|
||||
|
||||
`GET /api/models/config-metadata`
|
||||
|
||||
Returns structured metadata for all model configuration fields, organized by section. Each field includes its YAML path, Go type, UI type, label, description, default value, validation constraints, and available options.
|
||||
|
||||
```bash
|
||||
# All fields
|
||||
curl http://localhost:8080/api/models/config-metadata
|
||||
|
||||
# Filter by section
|
||||
curl http://localhost:8080/api/models/config-metadata?section=parameters
|
||||
```
|
||||
|
||||
### Autocomplete values
|
||||
|
||||
`GET /api/models/config-metadata/autocomplete/:provider`
|
||||
|
||||
Returns runtime values for dynamic fields. Providers include `backends`, `models`, `models:chat`, `models:tts`, `models:transcript`, `models:vad`.
|
||||
|
||||
```bash
|
||||
# List available backends
|
||||
curl http://localhost:8080/api/models/config-metadata/autocomplete/backends
|
||||
|
||||
# List chat-capable models
|
||||
curl http://localhost:8080/api/models/config-metadata/autocomplete/models:chat
|
||||
```
|
||||
|
||||
### Read model config
|
||||
|
||||
`GET /api/models/config-json/:name`
|
||||
|
||||
Returns the full model configuration as JSON:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/api/models/config-json/my-model
|
||||
```
|
||||
|
||||
### Update model config
|
||||
|
||||
`PATCH /api/models/config-json/:name`
|
||||
|
||||
Deep-merges a JSON patch into the existing model configuration. Only include the fields you want to change:
|
||||
|
||||
```bash
|
||||
curl -X PATCH http://localhost:8080/api/models/config-json/my-model \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"context_size": 16384, "gpu_layers": 40}'
|
||||
```
|
||||
|
||||
The endpoint validates the merged config and writes it to disk as YAML.
|
||||
|
||||
{{% notice context="warning" %}}
|
||||
Config management endpoints require **admin authentication** when API keys are configured. The discovery and instructions endpoints are unauthenticated.
|
||||
{{% /notice %}}
|
||||
|
||||
### VRAM estimation
|
||||
|
||||
`POST /api/models/vram-estimate`
|
||||
|
||||
Estimates VRAM usage for an installed model based on its weight files, context size, and GPU layer offloading:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/models/vram-estimate \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "my-model", "context_size": 8192}'
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"sizeBytes": 4368438272,
|
||||
"sizeDisplay": "4.4 GB",
|
||||
"vramBytes": 6123456789,
|
||||
"vramDisplay": "6.1 GB",
|
||||
"context_note": "Estimate used default context_size=8192. The model's trained maximum context is 131072; VRAM usage will be higher at larger context sizes.",
|
||||
"model_max_context": 131072
|
||||
}
|
||||
```
|
||||
|
||||
Optional parameters: `gpu_layers` (number of layers to offload, 0 = all), `kv_quant_bits` (KV cache quantization, 0 = fp16).
|
||||
|
||||
## Integration guide
|
||||
|
||||
A recommended workflow for agent/tool builders:
|
||||
|
||||
1. **Discover**: Fetch `/.well-known/localai.json` to learn available endpoints and capabilities
|
||||
2. **Browse instructions**: Fetch `/api/instructions` for an overview of instruction areas
|
||||
3. **Deep dive**: Fetch `/api/instructions/:name` for a markdown API guide on a specific area
|
||||
4. **Explore config**: Use `/api/models/config-metadata` to understand configuration fields
|
||||
5. **Interact**: Use the standard OpenAI-compatible endpoints for inference, and the config management endpoints for runtime tuning
|
||||
|
||||
## Swagger UI
|
||||
|
||||
The full interactive API documentation is available at `/swagger/index.html`. All annotated endpoints can be explored and tested directly from the browser.
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"version": "v4.0.0"
|
||||
"version": "v4.1.1"
|
||||
}
|
||||
|
||||
@@ -1,4 +1,32 @@
|
||||
---
|
||||
- name: "gemma-4-26b-a4b-it-apex"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
urls:
|
||||
- https://huggingface.co/mudler/gemma-4-26B-A4B-it-APEX-GGUF
|
||||
description: |
|
||||
AI model: gemma-4-26b-a4b-it-apex
|
||||
overrides:
|
||||
backend: llama-cpp
|
||||
function:
|
||||
automatic_tool_parsing_fallback: true
|
||||
grammar:
|
||||
disable: true
|
||||
known_usecases:
|
||||
- chat
|
||||
mmproj: llama-cpp/mmproj/gemma-4-26B-A4B-it-APEX-GGUF/mmproj-F16.gguf
|
||||
options:
|
||||
- use_jinja:true
|
||||
parameters:
|
||||
model: llama-cpp/models/gemma-4-26B-A4B-it-APEX-GGUF/gemma-4-26B-A4B-APEX-Quality.gguf
|
||||
template:
|
||||
use_tokenizer_template: true
|
||||
files:
|
||||
- filename: llama-cpp/mmproj/gemma-4-26B-A4B-it-APEX-GGUF/mmproj-F16.gguf
|
||||
sha256: cfc8dc4e41ab1d0c4846ed63ba4a62186846b04eb25fb38e1f2555ce2d00cb26
|
||||
uri: https://huggingface.co/mudler/gemma-4-26B-A4B-it-APEX-GGUF/resolve/main/mmproj-F16.gguf
|
||||
- filename: llama-cpp/models/gemma-4-26B-A4B-it-APEX-GGUF/gemma-4-26B-A4B-APEX-Quality.gguf
|
||||
uri: https://huggingface.co/mudler/gemma-4-26B-A4B-it-APEX-GGUF/resolve/main/gemma-4-26B-A4B-APEX-Quality.gguf
|
||||
sha256: a6591d7b41978e6f465acd9d03e96286f70912402c695158fb267ccbfbb740ed
|
||||
- name: "qwen3.5-35b-a3b-apex"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
urls:
|
||||
@@ -1260,6 +1288,59 @@
|
||||
- filename: llama-cpp/mmproj/Qwen3-VL-Reranker-8B.mmproj-f16.gguf
|
||||
sha256: 15cd9bd4882dae771344f0ac204fce07de91b47c1438ada3861dfc817403c31e
|
||||
uri: https://huggingface.co/mradermacher/Qwen3-VL-Reranker-8B-GGUF/resolve/main/Qwen3-VL-Reranker-8B.mmproj-f16.gguf
|
||||
- name: "qwen3-vl-reranker-2b-i1"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
urls:
|
||||
- https://huggingface.co/mradermacher/Qwen3-VL-Reranker-2B-i1-GGUF
|
||||
description: |
|
||||
**Model Name:** Qwen3-VL-Reranker-2B-i1
|
||||
**Base Model:** Qwen/Qwen3-VL-Reranker-2B
|
||||
|
||||
**Description:**
|
||||
A high-performance multimodal reranking model for state-of-the-art cross-modal search. It supports 30+ languages and handles text, images, screenshots, videos, and mixed modalities. With 8B parameters and a 32K context length, it refines retrieval results by combining embedding vectors with precise relevance scores. Optimized for efficiency, it supports quantized versions (e.g., Q8_0, Q4_K_M) and is ideal for applications requiring accurate multimodal content matching.
|
||||
|
||||
**Key Features:**
|
||||
- **Multimodal**: Text, images, videos, and mixed content.
|
||||
- **Language Support**: 30+ languages.
|
||||
- **Quantization**: Available in Q8_0 (best quality), Q4_K_M (fast, recommended), and lower-precision options.
|
||||
- **Performance**: Outperforms base models in retrieval tasks (e.g., JinaVDR, ViDoRe v3).
|
||||
- **Use Case**: Enhances search pipelines by refining embeddings with precise relevance scores.
|
||||
|
||||
**Downloads:**
|
||||
- [GGUF Files](https://huggingface.co/mradermacher/Qwen3-VL-Reranker-2B-i1-GGUF) (e.g., `Qwen3-VL-Reranker-2B.i1-Q4_K_M.gguf`).
|
||||
|
||||
**Usage:**
|
||||
- Requires `transformers`, `qwen-vl-utils`, and `torch`.
|
||||
- Example: `from scripts.qwen3_vl_reranker import Qwen3VLReranker; model = Qwen3VLReranker(...)`
|
||||
|
||||
**Citation:**
|
||||
@article{qwen3vlembedding, ...}
|
||||
|
||||
This description emphasizes its capabilities, efficiency, and versatility for multimodal search tasks.
|
||||
overrides:
|
||||
reranking: true
|
||||
parameters:
|
||||
model: llama-cpp/models/Qwen3-VL-Reranker-2B.i1-Q4_K_M.gguf
|
||||
name: Qwen3-VL-Reranker-2B-i1-GGUF
|
||||
backend: llama-cpp
|
||||
template:
|
||||
use_tokenizer_template: true
|
||||
known_usecases:
|
||||
- chat
|
||||
function:
|
||||
grammar:
|
||||
disable: true
|
||||
mmproj: llama-cpp/mmproj/Qwen3-VL-Reranker-2B.mmproj-f16.gguf
|
||||
description: Imported from https://huggingface.co/mradermacher/Qwen3-VL-Reranker-2B-GGUF/
|
||||
options:
|
||||
- use_jinja:true
|
||||
files:
|
||||
- filename: llama-cpp/models/Qwen3-VL-Reranker-2B.i1-Q4_K_M.gguf
|
||||
sha256: f19dfbceeef9f6ee1f7d0ff536d66e9b1b90424a4b8aa1d1777db43d20afdbc5
|
||||
uri: https://huggingface.co/mradermacher/Qwen3-VL-Reranker-2B-i1-GGUF/resolve/main/Qwen3-VL-Reranker-2B.i1-Q4_K_M.gguf
|
||||
- filename: llama-cpp/mmproj/Qwen3-VL-Reranker-8B.mmproj-f16.gguf
|
||||
sha256: d38b7ae347fc3e51726bfb9cba1b04885f1f005a4087d8070933e46509db5a6e
|
||||
uri: https://huggingface.co/mradermacher/Qwen3-VL-Reranker-2B-GGUF/resolve/main/Qwen3-VL-Reranker-2B.mmproj-f16.gguf
|
||||
- name: "liquidai.lfm2-2.6b-transcript"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
urls:
|
||||
@@ -3067,6 +3148,35 @@
|
||||
- filename: Qwen_Qwen3-30B-A3B-Q4_K_M.gguf
|
||||
sha256: a015794bfb1d69cb03dbb86b185fb2b9b339f757df5f8f9dd9ebdab8f6ed5d32
|
||||
uri: huggingface://bartowski/Qwen_Qwen3-30B-A3B-GGUF/Qwen_Qwen3-30B-A3B-Q4_K_M.gguf
|
||||
- !!merge <<: *qwen3
|
||||
name: "qwen3-reranker-0.6b"
|
||||
tags:
|
||||
- qwen3
|
||||
- reranker
|
||||
- gguf
|
||||
- gpu
|
||||
- cpu
|
||||
urls:
|
||||
- https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
|
||||
description: |
|
||||
The Qwen3 Embedding model series is the latest proprietary model of the Qwen family, specifically designed for text embedding and ranking tasks. Building upon the dense foundational models of the Qwen3 series, it provides a comprehensive range of text embeddings and reranking models in various sizes (0.6B, 4B, and 8B). This series inherits the exceptional multilingual capabilities, long-text understanding, and reasoning skills of its foundational model. The Qwen3 Embedding series represents significant advancements in multiple text embedding and ranking tasks, including text retrieval, code retrieval, text classification, text clustering, and bitext mining.
|
||||
**Exceptional Versatility**: The embedding model has achieved state-of-the-art performance across a wide range of downstream application evaluations. The 8B size embedding model ranks No.1 in the MTEB multilingual leaderboard (as of June 5, 2025, score 70.58), while the reranking model excels in various text retrieval scenarios.
|
||||
**Comprehensive Flexibility**: The Qwen3 Embedding series offers a full spectrum of sizes (from 0.6B to 8B) for both embedding and reranking models, catering to diverse use cases that prioritize efficiency and effectiveness. Developers can seamlessly combine these two modules. Additionally, the embedding model allows for flexible vector definitions across all dimensions, and both embedding and reranking models support user-defined instructions to enhance performance for specific tasks, languages, or scenarios.
|
||||
**Multilingual Capability**: The Qwen3 Embedding series offer support for over 100 languages, thanks to the multilingual capabilites of Qwen3 models. This includes various programming languages, and provides robust multilingual, cross-lingual, and code retrieval capabilities.
|
||||
**Qwen3-Reranker-0.6B** has the following features:
|
||||
- Model Type: Text Reranking
|
||||
- Supported Languages: 100+ Languages
|
||||
- Number of Paramaters: 0.6B
|
||||
- Context Length: 32k
|
||||
- Quantization: q4_K_M, q5_0, q5_K_M, q6_K, q8_0, f16
|
||||
overrides:
|
||||
reranking: true
|
||||
parameters:
|
||||
model: Qwen3-Reranker-0.6B.Q8_0.gguf
|
||||
files:
|
||||
- filename: Qwen3-Reranker-0.6B.Q8_0.gguf
|
||||
uri: huggingface://mradermacher/Qwen3-Reranker-0.6B-GGUF/Qwen3-Reranker-0.6B.Q8_0.gguf
|
||||
sha256: c525a7449243f690a7062e6377d6cf5adbb289354bd4316312367cd20e187ab7
|
||||
- !!merge <<: *qwen3
|
||||
name: "qwen3-235b-a22b-instruct-2507"
|
||||
icon: https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
name: "qwen3"
|
||||
|
||||
config_file: |
|
||||
mmap: true
|
||||
parameters:
|
||||
context_size: 8192
|
||||
f16: true
|
||||
mmap: true
|
||||
backend: "llama-cpp"
|
||||
template:
|
||||
chat_message: |
|
||||
@@ -36,8 +39,6 @@ config_file: |
|
||||
<|im_start|>assistant
|
||||
completion: |
|
||||
{{.Input}}
|
||||
context_size: 8192
|
||||
f16: true
|
||||
stopwords:
|
||||
- '<|im_end|>'
|
||||
- '<dummy32000>'
|
||||
|
||||
@@ -85,4 +85,7 @@ type Backend interface {
|
||||
StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error)
|
||||
QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error
|
||||
StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
|
||||
// Free releases GPU/model resources (e.g. VRAM) without stopping the process.
|
||||
Free(ctx context.Context) error
|
||||
}
|
||||
|
||||
@@ -163,6 +163,11 @@ func (e *embedBackend) StopQuantization(ctx context.Context, in *pb.Quantization
|
||||
return e.s.StopQuantization(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) Free(ctx context.Context) error {
|
||||
_, err := e.s.Free(ctx, &pb.HealthMessage{})
|
||||
return err
|
||||
}
|
||||
|
||||
var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)
|
||||
|
||||
type embedBackendFineTuneProgressStream struct {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -52,11 +53,9 @@ func (ml *ModelLoader) deleteProcess(s string) error {
|
||||
}
|
||||
|
||||
// Free GPU resources before stopping the process to ensure VRAM is released
|
||||
if freeFunc, ok := model.GRPC(false, ml.wd).(interface{ Free() error }); ok {
|
||||
xlog.Debug("Calling Free() to release GPU resources", "model", s)
|
||||
if err := freeFunc.Free(); err != nil {
|
||||
xlog.Warn("Error freeing GPU resources", "error", err, "model", s)
|
||||
}
|
||||
xlog.Debug("Calling Free() to release GPU resources", "model", s)
|
||||
if err := model.GRPC(false, ml.wd).Free(context.Background()); err != nil {
|
||||
xlog.Warn("Error freeing GPU resources", "error", err, "model", s)
|
||||
}
|
||||
|
||||
process := model.Process()
|
||||
|
||||
@@ -21,6 +21,12 @@ type ReasoningExtractor struct {
|
||||
lastReasoning string
|
||||
lastCleaned string
|
||||
suppressReasoning bool
|
||||
|
||||
// ChatDelta reasoning accumulator — used by ProcessChatDeltaReasoning
|
||||
// to strip reasoning tags (e.g. <|channel>thought, <channel|>) that
|
||||
// the C++ autoparser includes in reasoning_content deltas.
|
||||
cdReasoningAccum string
|
||||
cdLastStrippedReasoning string
|
||||
}
|
||||
|
||||
// NewReasoningExtractor creates a new extractor for the given thinking token and config.
|
||||
@@ -64,6 +70,61 @@ func (e *ReasoningExtractor) ProcessToken(token string) (reasoningDelta, content
|
||||
return reasoningDelta, contentDelta
|
||||
}
|
||||
|
||||
// ProcessChatDeltaReasoning accumulates raw reasoning text from C++ autoparser
|
||||
// ChatDeltas, strips any embedded reasoning tags (e.g. <|channel>thought …
|
||||
// <channel|> for Gemma 4), and returns only the new stripped delta.
|
||||
// This prevents tag tokens from leaking into the reasoning field of SSE chunks.
|
||||
//
|
||||
// When the C++ autoparser already strips tags (e.g. <think> models), the text
|
||||
// passes through unchanged — ExtractReasoning finds no tags so we use the raw text.
|
||||
func (e *ReasoningExtractor) ProcessChatDeltaReasoning(rawDelta string) string {
|
||||
if rawDelta == "" {
|
||||
return ""
|
||||
}
|
||||
e.cdReasoningAccum += rawDelta
|
||||
|
||||
// Try to strip reasoning tags from accumulated ChatDelta reasoning.
|
||||
stripped, cleaned := ExtractReasoning(e.cdReasoningAccum, &e.config)
|
||||
|
||||
if stripped == "" {
|
||||
// ExtractReasoning found no reasoning content. This happens when:
|
||||
// a) A complete start tag was found but has no content after it yet
|
||||
// (cleaned == "" because everything is inside the unclosed tag)
|
||||
// → keep buffering
|
||||
// b) We're accumulating a partial multi-token start tag
|
||||
// (e.g. "<|channel>" before "thought" arrives)
|
||||
// → keep buffering
|
||||
// c) No tags at all — C++ already stripped them
|
||||
// → pass through the raw text as-is
|
||||
if cleaned == "" && strings.TrimSpace(e.cdReasoningAccum) != "" {
|
||||
// Case (a): tag found, unclosed, no content yet
|
||||
stripped = ""
|
||||
} else if e.thinkingStartToken != "" &&
|
||||
len(strings.TrimSpace(e.cdReasoningAccum)) < len(e.thinkingStartToken) &&
|
||||
strings.HasPrefix(e.thinkingStartToken, strings.TrimSpace(e.cdReasoningAccum)) {
|
||||
// Case (b): partial start tag prefix
|
||||
stripped = ""
|
||||
} else {
|
||||
// Case (c): no tags found — text is already clean from C++
|
||||
stripped = e.cdReasoningAccum
|
||||
}
|
||||
}
|
||||
|
||||
// Compute delta from stripped reasoning
|
||||
var delta string
|
||||
if len(stripped) > len(e.cdLastStrippedReasoning) && strings.HasPrefix(stripped, e.cdLastStrippedReasoning) {
|
||||
delta = stripped[len(e.cdLastStrippedReasoning):]
|
||||
} else if stripped != e.cdLastStrippedReasoning && stripped != "" {
|
||||
delta = stripped
|
||||
}
|
||||
e.cdLastStrippedReasoning = stripped
|
||||
|
||||
if e.suppressReasoning {
|
||||
return ""
|
||||
}
|
||||
return delta
|
||||
}
|
||||
|
||||
// Reasoning returns the total accumulated reasoning after streaming.
|
||||
func (e *ReasoningExtractor) Reasoning() string {
|
||||
return e.lastReasoning
|
||||
@@ -84,6 +145,8 @@ func (e *ReasoningExtractor) Reset() {
|
||||
e.accumulated = ""
|
||||
e.lastReasoning = ""
|
||||
e.lastCleaned = ""
|
||||
e.cdReasoningAccum = ""
|
||||
e.cdLastStrippedReasoning = ""
|
||||
}
|
||||
|
||||
// ResetAndSuppressReasoning clears state and suppresses future reasoning deltas.
|
||||
@@ -95,6 +158,8 @@ func (e *ReasoningExtractor) ResetAndSuppressReasoning() {
|
||||
e.accumulated = ""
|
||||
e.lastReasoning = ""
|
||||
e.lastCleaned = ""
|
||||
e.cdReasoningAccum = ""
|
||||
e.cdLastStrippedReasoning = ""
|
||||
e.suppressReasoning = true
|
||||
}
|
||||
|
||||
|
||||
@@ -195,4 +195,91 @@ var _ = Describe("ReasoningExtractor", func() {
|
||||
Expect(ext.CleanedContent()).To(Equal("visible content"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ProcessChatDeltaReasoning with Gemma 4 tags", func() {
|
||||
It("should strip <|channel>thought and <channel|> tags from streaming deltas", func() {
|
||||
ext := NewReasoningExtractor("<|channel>thought", Config{})
|
||||
|
||||
// Simulate C++ autoparser sending tag tokens as reasoning
|
||||
d1 := ext.ProcessChatDeltaReasoning("<|channel>")
|
||||
Expect(d1).To(BeEmpty(), "start tag prefix should be buffered, not emitted")
|
||||
|
||||
d2 := ext.ProcessChatDeltaReasoning("thought")
|
||||
Expect(d2).To(BeEmpty(), "start tag suffix should be buffered, not emitted")
|
||||
|
||||
d3 := ext.ProcessChatDeltaReasoning("\n")
|
||||
Expect(d3).To(BeEmpty(), "newline after start tag should not emit yet")
|
||||
|
||||
d4 := ext.ProcessChatDeltaReasoning("The")
|
||||
Expect(d4).To(Equal("The"))
|
||||
|
||||
d5 := ext.ProcessChatDeltaReasoning(" user")
|
||||
Expect(d5).To(Equal(" user"))
|
||||
|
||||
d6 := ext.ProcessChatDeltaReasoning(" asks")
|
||||
Expect(d6).To(Equal(" asks"))
|
||||
|
||||
// Trailing newline gets TrimSpaced by ExtractReasoning,
|
||||
// so it appears delayed with the next non-whitespace token
|
||||
d7 := ext.ProcessChatDeltaReasoning("\n")
|
||||
Expect(d7).To(BeEmpty(), "trailing newline is buffered by TrimSpace")
|
||||
|
||||
d8 := ext.ProcessChatDeltaReasoning("2+2=4")
|
||||
Expect(d8).To(Equal("\n2+2=4"), "delayed newline emitted with next content")
|
||||
|
||||
d9 := ext.ProcessChatDeltaReasoning("<channel|>")
|
||||
Expect(d9).To(BeEmpty(), "close tag should be consumed, not emitted")
|
||||
})
|
||||
|
||||
It("should handle empty deltas", func() {
|
||||
ext := NewReasoningExtractor("<|channel>thought", Config{})
|
||||
d := ext.ProcessChatDeltaReasoning("")
|
||||
Expect(d).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should pass through reasoning without tags unchanged", func() {
|
||||
ext := NewReasoningExtractor("<think>", Config{})
|
||||
|
||||
// When C++ autoparser already strips tags (e.g. <think> models),
|
||||
// reasoning arrives clean — just pass it through.
|
||||
d1 := ext.ProcessChatDeltaReasoning("I need to")
|
||||
Expect(d1).To(Equal("I need to"))
|
||||
|
||||
d2 := ext.ProcessChatDeltaReasoning(" think carefully")
|
||||
Expect(d2).To(Equal(" think carefully"))
|
||||
})
|
||||
|
||||
It("should strip <think> tags if C++ autoparser includes them", func() {
|
||||
ext := NewReasoningExtractor("<think>", Config{})
|
||||
|
||||
d1 := ext.ProcessChatDeltaReasoning("<think>")
|
||||
Expect(d1).To(BeEmpty())
|
||||
|
||||
d2 := ext.ProcessChatDeltaReasoning("reasoning")
|
||||
Expect(d2).To(Equal("reasoning"))
|
||||
|
||||
d3 := ext.ProcessChatDeltaReasoning("</think>")
|
||||
Expect(d3).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should respect suppressReasoning", func() {
|
||||
ext := NewReasoningExtractor("<|channel>thought", Config{})
|
||||
ext.ResetAndSuppressReasoning()
|
||||
|
||||
d := ext.ProcessChatDeltaReasoning("some reasoning")
|
||||
Expect(d).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should reset ChatDelta state on Reset", func() {
|
||||
ext := NewReasoningExtractor("<|channel>thought", Config{})
|
||||
|
||||
ext.ProcessChatDeltaReasoning("<|channel>thought")
|
||||
ext.ProcessChatDeltaReasoning("\nfirst reasoning")
|
||||
ext.Reset()
|
||||
|
||||
// After reset, should start fresh
|
||||
d := ext.ProcessChatDeltaReasoning("clean reasoning")
|
||||
Expect(d).To(Equal("clean reasoning"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
// - <|inner_prefix|> (Apertus models)
|
||||
// - <seed:think> (Seed models)
|
||||
// - <think> (DeepSeek, Granite, ExaOne models)
|
||||
// - <|channel>thought (Gemma 4 models)
|
||||
// - <|think|> (Solar Open models)
|
||||
// - <thinking> (General thinking tag)
|
||||
// - [THINK] (Magistral models)
|
||||
@@ -24,6 +25,7 @@ func DetectThinkingStartToken(prompt string, config *Config) string {
|
||||
// Based on llama.cpp's chat-parser.cpp implementations
|
||||
defaultTokens := []string{
|
||||
"<|START_THINKING|>", // Command-R models
|
||||
"<|channel>thought", // Gemma 4 models (before <|think|> — Gemma 4 templates contain both)
|
||||
"<|inner_prefix|>", // Apertus models
|
||||
"<seed:think>", // Seed models
|
||||
"<think>", // DeepSeek, Granite, ExaOne models
|
||||
@@ -100,11 +102,18 @@ func PrependThinkingTokenIfNeeded(content string, startToken string) string {
|
||||
return r == ' ' || r == '\t' || r == '\n' || r == '\r'
|
||||
})
|
||||
|
||||
// If content already starts with the token, don't prepend
|
||||
// If content already contains the token, don't prepend
|
||||
if strings.Contains(trimmed, startToken) {
|
||||
return content
|
||||
}
|
||||
|
||||
// If content is a non-empty prefix of the start token (e.g. "<|channel>"
|
||||
// accumulating toward "<|channel>thought"), don't prepend — we're still
|
||||
// receiving the tag token-by-token during streaming.
|
||||
if trimmed != "" && strings.HasPrefix(startToken, trimmed) {
|
||||
return content
|
||||
}
|
||||
|
||||
// Find where leading whitespace ends
|
||||
whitespaceEnd := 0
|
||||
for whitespaceEnd < len(content) {
|
||||
@@ -146,6 +155,7 @@ func ExtractReasoning(content string, config *Config) (reasoning string, cleaned
|
||||
{"<seed:think>", "</seed:think>"}, // Seed models
|
||||
{"<think>", "</think>"}, // DeepSeek, Granite, ExaOne models
|
||||
{"<|think|>", "<|end|><|begin|>assistant<|content|>"}, // Solar Open models (complex end)
|
||||
{"<|channel>thought", "<channel|>"}, // Gemma 4 models
|
||||
{"<thinking>", "</thinking>"}, // General thinking tag
|
||||
{"[THINK]", "[/THINK]"}, // Magistral models
|
||||
}
|
||||
|
||||
@@ -317,6 +317,29 @@ var _ = Describe("ExtractReasoning", func() {
|
||||
Expect(cleaned).To(Equal("Before "))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when content has <|channel>thought tags (Gemma 4)", func() {
|
||||
It("should extract reasoning from channel thought block", func() {
|
||||
content := "<|channel>thought\nThis is my reasoning\n<channel|>Hello! How can I help?"
|
||||
reasoning, cleaned := ExtractReasoning(content, nil)
|
||||
Expect(reasoning).To(Equal("This is my reasoning"))
|
||||
Expect(cleaned).To(Equal("Hello! How can I help?"))
|
||||
})
|
||||
|
||||
It("should handle unclosed channel thought block", func() {
|
||||
content := "<|channel>thought\nIncomplete reasoning"
|
||||
reasoning, cleaned := ExtractReasoning(content, nil)
|
||||
Expect(reasoning).To(Equal("Incomplete reasoning"))
|
||||
Expect(cleaned).To(Equal(""))
|
||||
})
|
||||
|
||||
It("should handle content before and after channel thought block", func() {
|
||||
content := "Before <|channel>thought\nGemma 4 reasoning\n<channel|> After"
|
||||
reasoning, cleaned := ExtractReasoning(content, nil)
|
||||
Expect(reasoning).To(Equal("Gemma 4 reasoning"))
|
||||
Expect(cleaned).To(Equal("Before After"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("DetectThinkingStartToken", func() {
|
||||
@@ -339,6 +362,12 @@ var _ = Describe("DetectThinkingStartToken", func() {
|
||||
Expect(token).To(Equal("<thinking>"))
|
||||
})
|
||||
|
||||
It("should detect <|channel>thought at the end (Gemma 4)", func() {
|
||||
prompt := "Prompt text <|channel>thought"
|
||||
token := DetectThinkingStartToken(prompt, nil)
|
||||
Expect(token).To(Equal("<|channel>thought"))
|
||||
})
|
||||
|
||||
It("should detect <|inner_prefix|> at the end", func() {
|
||||
prompt := "Prompt <|inner_prefix|>"
|
||||
token := DetectThinkingStartToken(prompt, nil)
|
||||
@@ -817,6 +846,14 @@ var _ = Describe("ExtractReasoningWithConfig", func() {
|
||||
Expect(cleaned).To(Equal("Text More"))
|
||||
})
|
||||
|
||||
It("should strip reasoning from Gemma 4 channel tags when StripReasoningOnly is true", func() {
|
||||
content := "<|channel>thought\nGemma 4 reasoning\n<channel|>Response text"
|
||||
config := Config{StripReasoningOnly: boolPtr(true)}
|
||||
reasoning, cleaned := ExtractReasoningWithConfig(content, "<|channel>thought", config)
|
||||
Expect(reasoning).To(BeEmpty())
|
||||
Expect(cleaned).To(Equal("Response text"))
|
||||
})
|
||||
|
||||
It("should strip reasoning with multiline content when StripReasoningOnly is true", func() {
|
||||
content := "Start <thinking>Line 1\nLine 2\nLine 3</thinking> End"
|
||||
config := Config{StripReasoningOnly: boolPtr(true)}
|
||||
|
||||
@@ -14,12 +14,12 @@ var weightExts = map[string]bool{
|
||||
".gguf": true, ".safetensors": true, ".bin": true, ".pt": true,
|
||||
}
|
||||
|
||||
func isWeightFile(nameOrURI string) bool {
|
||||
func IsWeightFile(nameOrURI string) bool {
|
||||
ext := strings.ToLower(path.Ext(path.Base(nameOrURI)))
|
||||
return weightExts[ext]
|
||||
}
|
||||
|
||||
func isGGUF(nameOrURI string) bool {
|
||||
func IsGGUF(nameOrURI string) bool {
|
||||
return strings.ToLower(path.Ext(path.Base(nameOrURI))) == ".gguf"
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, size
|
||||
var firstGGUFURI string
|
||||
for i := range files {
|
||||
f := &files[i]
|
||||
if !isWeightFile(f.URI) {
|
||||
if !IsWeightFile(f.URI) {
|
||||
continue
|
||||
}
|
||||
sz := f.Size
|
||||
@@ -48,7 +48,7 @@ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, size
|
||||
}
|
||||
}
|
||||
sizeBytes += uint64(sz)
|
||||
if isGGUF(f.URI) {
|
||||
if IsGGUF(f.URI) {
|
||||
ggufSize += uint64(sz)
|
||||
if firstGGUFURI == "" {
|
||||
firstGGUFURI = f.URI
|
||||
|
||||
@@ -34,10 +34,11 @@ func (defaultGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMet
|
||||
func ggufFileToMeta(f *gguf.GGUFFile) *GGUFMeta {
|
||||
arch := f.Architecture()
|
||||
meta := &GGUFMeta{
|
||||
BlockCount: uint32(arch.BlockCount),
|
||||
EmbeddingLength: uint32(arch.EmbeddingLength),
|
||||
HeadCount: uint32(arch.AttentionHeadCount),
|
||||
HeadCountKV: uint32(arch.AttentionHeadCountKV),
|
||||
BlockCount: uint32(arch.BlockCount),
|
||||
EmbeddingLength: uint32(arch.EmbeddingLength),
|
||||
HeadCount: uint32(arch.AttentionHeadCount),
|
||||
HeadCountKV: uint32(arch.AttentionHeadCountKV),
|
||||
MaximumContextLength: arch.MaximumContextLength,
|
||||
}
|
||||
if meta.HeadCountKV == 0 {
|
||||
meta.HeadCountKV = meta.HeadCount
|
||||
|
||||
@@ -15,10 +15,11 @@ type SizeResolver interface {
|
||||
|
||||
// GGUFMeta holds parsed GGUF metadata used for VRAM estimation.
|
||||
type GGUFMeta struct {
|
||||
BlockCount uint32
|
||||
EmbeddingLength uint32
|
||||
HeadCount uint32
|
||||
HeadCountKV uint32
|
||||
BlockCount uint32
|
||||
EmbeddingLength uint32
|
||||
HeadCount uint32
|
||||
HeadCountKV uint32
|
||||
MaximumContextLength uint64
|
||||
}
|
||||
|
||||
// GGUFMetadataReader reads GGUF metadata from a URI (e.g. via HTTP Range).
|
||||
@@ -35,8 +36,8 @@ type EstimateOptions struct {
|
||||
|
||||
// EstimateResult holds estimated download size and VRAM with display strings.
|
||||
type EstimateResult struct {
|
||||
SizeBytes uint64
|
||||
SizeDisplay string
|
||||
VRAMBytes uint64
|
||||
VRAMDisplay string
|
||||
SizeBytes uint64 `json:"sizeBytes"` // total model weight size in bytes
|
||||
SizeDisplay string `json:"sizeDisplay"` // human-readable size (e.g. "4.2 GB")
|
||||
VRAMBytes uint64 `json:"vramBytes"` // estimated VRAM usage in bytes
|
||||
VRAMDisplay string `json:"vramDisplay"` // human-readable VRAM (e.g. "6.1 GB")
|
||||
}
|
||||
|
||||
1410
swagger/docs.go
1410
swagger/docs.go
File diff suppressed because it is too large
Load Diff
6
swagger/embed.go
Normal file
6
swagger/embed.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package swagger
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed swagger.json
|
||||
var SwaggerJSON []byte
|
||||
1412
swagger/swagger.json
1412
swagger/swagger.json
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -101,6 +101,25 @@ var _ = BeforeSuite(func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(configPath, configYAML, 0644)).To(Succeed())
|
||||
|
||||
// Create model config for autoparser tests (NoGrammar so tool calls
|
||||
// are driven entirely by the backend's ChatDeltas, not grammar enforcement)
|
||||
autoparserConfig := map[string]any{
|
||||
"name": "mock-model-autoparser",
|
||||
"backend": "mock-backend",
|
||||
"parameters": map[string]any{
|
||||
"model": "mock-model.bin",
|
||||
},
|
||||
"function": map[string]any{
|
||||
"grammar": map[string]any{
|
||||
"disable": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
autoparserPath := filepath.Join(modelsPath, "mock-model-autoparser.yaml")
|
||||
autoparserYAML, err := yaml.Marshal(autoparserConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(autoparserPath, autoparserYAML, 0644)).To(Succeed())
|
||||
|
||||
// Start mock MCP server and create MCP-enabled model config
|
||||
mcpServerURL, mcpServerShutdown = startMockMCPServer()
|
||||
mcpConfig := mcpModelConfig(mcpServerURL)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user