mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 00:26:34 -04:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a48157a80 | ||
|
|
41c838b2df | ||
|
|
21e793ad2a | ||
|
|
7c190bb4b9 | ||
|
|
d77a9137d8 | ||
|
|
661a0c3b9d | ||
|
|
00b8989886 | ||
|
|
43e0d397ca | ||
|
|
a1a7a219ed | ||
|
|
3937ec6527 | ||
|
|
1355b55794 | ||
|
|
5a2626d465 | ||
|
|
a39591f144 | ||
|
|
8c785dbe4a | ||
|
|
4abf5befbb | ||
|
|
195b910260 | ||
|
|
ba21bf667c | ||
|
|
7bd1693ad0 | ||
|
|
b5ac3a7373 | ||
|
|
53de474ef5 | ||
|
|
c33d36b870 | ||
|
|
57fa178a64 | ||
|
|
745473cbe6 | ||
|
|
594c9fd92e | ||
|
|
8af963bdd9 | ||
|
|
6e1dbae256 | ||
|
|
53bdb18d10 |
@@ -61,6 +61,12 @@ Always check `llama.cpp` for new model configuration options that should be supp
|
|||||||
- `reasoning_format` - Reasoning format options
|
- `reasoning_format` - Reasoning format options
|
||||||
- Any new flags or parameters
|
- Any new flags or parameters
|
||||||
|
|
||||||
|
### Speculative Decoding Types
|
||||||
|
|
||||||
|
The `spec_type` option in `grpc-server.cpp` delegates to upstream's `common_speculative_types_from_names()`, so new speculative types added to the `common_speculative_type_from_name` map in `common/speculative.cpp` are picked up automatically with no code changes - only docs need an entry in `docs/content/advanced/model-configuration.md`. Current values: `none`, `draft-simple`, `draft-eagle3`, `draft-mtp`, `ngram-simple`, `ngram-map-k`, `ngram-map-k4v`, `ngram-mod`, `ngram-cache`.
|
||||||
|
|
||||||
|
`draft-mtp` (Multi-Token Prediction, [ggml-org/llama.cpp#22673](https://github.com/ggml-org/llama.cpp/pull/22673)) does not need a separate draft GGUF: when `spec_type` includes `draft-mtp` and `draftmodel` is empty, the upstream server creates an MTP context off the target model itself. LocalAI's gRPC layer needs no changes for this — it works through the existing `params.speculative.types` plumbing and the derived `cparams.n_rs_seq = params.speculative.need_n_rs_seq()` in `common_context_params_to_llama`.
|
||||||
|
|
||||||
### Implementation Guidelines
|
### Implementation Guidelines
|
||||||
|
|
||||||
1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation
|
1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# ds4 backend Makefile.
|
# ds4 backend Makefile.
|
||||||
#
|
#
|
||||||
# Upstream pin lives below as DS4_VERSION?=0cba357ca1bc0e7510421cc26888e420ea942123
|
# Upstream pin lives below as DS4_VERSION?=ef0a4905d05263df8e63689f2dd1efac618a752c
|
||||||
# (.github/bump_deps.sh) can find and update it - matches the
|
# (.github/bump_deps.sh) can find and update it - matches the
|
||||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||||
|
|
||||||
DS4_VERSION?=0cba357ca1bc0e7510421cc26888e420ea942123
|
DS4_VERSION?=ef0a4905d05263df8e63689f2dd1efac618a752c
|
||||||
DS4_REPO?=https://github.com/antirez/ds4
|
DS4_REPO?=https://github.com/antirez/ds4
|
||||||
|
|
||||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
IK_LLAMA_VERSION?=949bb8f1d660fc1264c137a6f3dbd619375f6134
|
IK_LLAMA_VERSION?=3e573cfea6e0a332eff822ffbdb1dd3b112e9051
|
||||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
LLAMA_VERSION?=a9883db8ee021cf16783016a60996d41820b5195
|
LLAMA_VERSION?=0253fb21f595246f54c192fe8332f34173be251b
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -32,6 +32,7 @@
|
|||||||
#include <grpcpp/health_check_service_interface.h>
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
#include <grpcpp/security/server_credentials.h>
|
#include <grpcpp/security/server_credentials.h>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
#include <algorithm>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@@ -450,6 +451,8 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
// vector; the turboquant fork still uses the legacy scalar. The
|
// vector; the turboquant fork still uses the legacy scalar. The
|
||||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
// LOCALAI_LEGACY_LLAMA_CPP_SPEC macro is injected by
|
||||||
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
// backend/cpp/turboquant/patch-grpc-server.sh for fork builds only.
|
||||||
|
// Upstream renamed COMMON_SPECULATIVE_TYPE_DRAFT -> ..._DRAFT_SIMPLE
|
||||||
|
// in ggml-org/llama.cpp#22964; the fork still uses the old name.
|
||||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||||
@@ -458,7 +461,7 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
const bool no_spec_type = params.speculative.types.empty() ||
|
const bool no_spec_type = params.speculative.types.empty() ||
|
||||||
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
(params.speculative.types.size() == 1 && params.speculative.types[0] == COMMON_SPECULATIVE_TYPE_NONE);
|
||||||
if (no_spec_type) {
|
if (no_spec_type) {
|
||||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT };
|
params.speculative.types = { COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE };
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@@ -685,6 +688,136 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
// If conversion fails, keep default value (8)
|
// If conversion fails, keep default value (8)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- physical batch size (upstream -ub / --ubatch-size) ---
|
||||||
|
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
|
||||||
|
// option lets users decouple the two (useful for embeddings/rerank).
|
||||||
|
} else if (!strcmp(optname, "n_ubatch") || !strcmp(optname, "ubatch")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.n_ubatch = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- main-model batch threads (upstream -tb / --threads-batch) ---
|
||||||
|
} else if (!strcmp(optname, "threads_batch") || !strcmp(optname, "n_threads_batch")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try {
|
||||||
|
int n = std::stoi(optval_str);
|
||||||
|
if (n <= 0) n = (int)std::thread::hardware_concurrency();
|
||||||
|
params.cpuparams_batch.n_threads = n;
|
||||||
|
} catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- pooling type for embeddings (upstream --pooling) ---
|
||||||
|
} else if (!strcmp(optname, "pooling_type") || !strcmp(optname, "pooling")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
if (optval_str == "none") params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||||
|
else if (optval_str == "mean") params.pooling_type = LLAMA_POOLING_TYPE_MEAN;
|
||||||
|
else if (optval_str == "cls") params.pooling_type = LLAMA_POOLING_TYPE_CLS;
|
||||||
|
else if (optval_str == "last") params.pooling_type = LLAMA_POOLING_TYPE_LAST;
|
||||||
|
else if (optval_str == "rank") params.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||||
|
// unknown values silently leave UNSPECIFIED (auto-detect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- llama log verbosity threshold (upstream -lv / --verbosity) ---
|
||||||
|
} else if (!strcmp(optname, "verbosity")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.verbosity = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- O_DIRECT model loading (upstream --direct-io) ---
|
||||||
|
} else if (!strcmp(optname, "direct_io") || !strcmp(optname, "use_direct_io")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.use_direct_io = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.use_direct_io = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- embedding normalization (upstream --embd-normalize) ---
|
||||||
|
// -1 none, 0 max-abs, 1 taxicab, 2 L2 (default), >2 p-norm
|
||||||
|
} else if (!strcmp(optname, "embd_normalize") || !strcmp(optname, "embedding_normalize")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.embd_normalize = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- reasoning parser (upstream --reasoning-format) ---
|
||||||
|
// Picks the parser for <think> blocks emitted by reasoning models.
|
||||||
|
// none / auto / deepseek / deepseek-legacy
|
||||||
|
} else if (!strcmp(optname, "reasoning_format")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
if (optval_str == "none") params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
|
else if (optval_str == "auto") params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||||
|
else if (optval_str == "deepseek") params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||||
|
else if (optval_str == "deepseek-legacy" || optval_str == "deepseek_legacy")
|
||||||
|
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
|
||||||
|
// unknown values silently keep the upstream default (DEEPSEEK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- reasoning budget (upstream --reasoning-budget) ---
|
||||||
|
// -1 unlimited, 0 disabled, >0 token budget for thinking blocks.
|
||||||
|
// Distinct from per-request `enable_thinking` (chat_template_kwargs).
|
||||||
|
} else if (!strcmp(optname, "enable_reasoning") || !strcmp(optname, "reasoning_budget")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.enable_reasoning = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- prefill assistant turn (upstream --no-prefill-assistant) ---
|
||||||
|
} else if (!strcmp(optname, "prefill_assistant")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.prefill_assistant = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.prefill_assistant = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- mmproj GPU offload (upstream --no-mmproj-offload, inverted) ---
|
||||||
|
} else if (!strcmp(optname, "mmproj_use_gpu") || !strcmp(optname, "mmproj_offload")) {
|
||||||
|
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||||
|
params.mmproj_use_gpu = true;
|
||||||
|
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||||
|
params.mmproj_use_gpu = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- per-image vision token budget (upstream --image-min/max-tokens) ---
|
||||||
|
} else if (!strcmp(optname, "image_min_tokens")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.image_min_tokens = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
} else if (!strcmp(optname, "image_max_tokens")) {
|
||||||
|
if (optval != NULL) {
|
||||||
|
try { params.image_max_tokens = std::stoi(optval_str); } catch (...) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- main-model tensor buffer overrides (upstream --override-tensor) ---
|
||||||
|
// Format: <tensor regex>=<buffer type>,<tensor regex>=<buffer type>,...
|
||||||
|
// Mirrors the existing `draft_override_tensor` parser below.
|
||||||
|
} else if (!strcmp(optname, "override_tensor") || !strcmp(optname, "tensor_buft_overrides")) {
|
||||||
|
ggml_backend_load_all();
|
||||||
|
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||||
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||||
|
auto * dev = ggml_backend_dev_get(i);
|
||||||
|
auto * buft = ggml_backend_dev_buffer_type(dev);
|
||||||
|
if (buft) {
|
||||||
|
buft_list[ggml_backend_buft_name(buft)] = buft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static std::list<std::string> override_names;
|
||||||
|
std::string cur;
|
||||||
|
auto flush = [&](const std::string & spec) {
|
||||||
|
auto pos = spec.find('=');
|
||||||
|
if (pos == std::string::npos) return;
|
||||||
|
const std::string name = spec.substr(0, pos);
|
||||||
|
const std::string type = spec.substr(pos + 1);
|
||||||
|
auto it = buft_list.find(type);
|
||||||
|
if (it == buft_list.end()) return; // unknown buffer type: ignore
|
||||||
|
override_names.push_back(name);
|
||||||
|
params.tensor_buft_overrides.push_back(
|
||||||
|
{override_names.back().c_str(), it->second});
|
||||||
|
};
|
||||||
|
for (char c : optval_str) {
|
||||||
|
if (c == ',') { if (!cur.empty()) { flush(cur); cur.clear(); } }
|
||||||
|
else { cur.push_back(c); }
|
||||||
|
}
|
||||||
|
if (!cur.empty()) flush(cur);
|
||||||
|
|
||||||
// Speculative decoding options
|
// Speculative decoding options
|
||||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||||
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
#ifdef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||||
@@ -701,16 +834,27 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
|||||||
// Upstream switched to a vector of types (comma-separated for multi-type
|
// Upstream switched to a vector of types (comma-separated for multi-type
|
||||||
// chaining via common_speculative_types_from_names). We keep accepting a
|
// chaining via common_speculative_types_from_names). We keep accepting a
|
||||||
// single value here, but also tolerate comma-separated lists.
|
// single value here, but also tolerate comma-separated lists.
|
||||||
|
//
|
||||||
|
// ggml-org/llama.cpp#22964 also renamed the registered names from
|
||||||
|
// underscore- to dash-separated form, and replaced the bare
|
||||||
|
// `draft`/`eagle3` aliases with `draft-simple`/`draft-eagle3`. We
|
||||||
|
// normalize each token here so existing model configs keep working.
|
||||||
|
auto normalize_spec_name = [](std::string s) -> std::string {
|
||||||
|
std::replace(s.begin(), s.end(), '_', '-');
|
||||||
|
if (s == "draft") return "draft-simple";
|
||||||
|
if (s == "eagle3") return "draft-eagle3";
|
||||||
|
return s;
|
||||||
|
};
|
||||||
std::vector<std::string> names;
|
std::vector<std::string> names;
|
||||||
std::string item;
|
std::string item;
|
||||||
for (char c : optval_str) {
|
for (char c : optval_str) {
|
||||||
if (c == ',') {
|
if (c == ',') {
|
||||||
if (!item.empty()) { names.push_back(item); item.clear(); }
|
if (!item.empty()) { names.push_back(normalize_spec_name(item)); item.clear(); }
|
||||||
} else {
|
} else {
|
||||||
item.push_back(c);
|
item.push_back(c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!item.empty()) names.push_back(item);
|
if (!item.empty()) names.push_back(normalize_spec_name(item));
|
||||||
auto parsed = common_speculative_types_from_names(names);
|
auto parsed = common_speculative_types_from_names(names);
|
||||||
if (!parsed.empty()) {
|
if (!parsed.empty()) {
|
||||||
params.speculative.types = parsed;
|
params.speculative.types = parsed;
|
||||||
@@ -2794,7 +2938,9 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
// Honor the load-time embd_normalize set via options:embd_normalize.
|
||||||
|
// -1 none, 0 max-abs, 1 taxicab, 2 L2 (default), >2 p-norm.
|
||||||
|
int embd_normalize = params_base.embd_normalize;
|
||||||
// create and queue the task
|
// create and queue the task
|
||||||
auto rd = ctx_server.get_response_reader();
|
auto rd = ctx_server.get_response_reader();
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# stablediffusion.cpp (ggml)
|
# stablediffusion.cpp (ggml)
|
||||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||||
STABLEDIFFUSION_GGML_VERSION?=90e87bc846f17059771efb8aaa31e9ef0cab6f78
|
STABLEDIFFUSION_GGML_VERSION?=bd17f53b7386fb5f60e8587b75e73c4b2fed3426
|
||||||
|
|
||||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||||
WHISPER_CPP_VERSION?=3e9b7d0fef3528ee2208da3cdb873a2c53d2ae2f
|
WHISPER_CPP_VERSION?=968eebe77225d25e57a3f981da7c696310f0e881
|
||||||
SO_TARGET?=libgowhisper.so
|
SO_TARGET?=libgowhisper.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|||||||
@@ -3,5 +3,5 @@
|
|||||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||||
# so uv consults this index alongside PyPI.
|
# so uv consults this index alongside PyPI.
|
||||||
--extra-index-url https://wheels.vllm.ai/0.20.2/cu130
|
--extra-index-url https://wheels.vllm.ai/0.21.0/cu130
|
||||||
vllm==0.20.2
|
vllm==0.21.0
|
||||||
|
|||||||
@@ -54,6 +54,13 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
|||||||
cfg.modelTemplate = chatTemplate.ValueString()
|
cfg.modelTemplate = chatTemplate.ValueString()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auto-enable Multi-Token Prediction (ggml-org/llama.cpp#22673) when the
|
||||||
|
// GGUF carries an embedded MTP head. Skipped silently for non-MTP models
|
||||||
|
// and when the user already configured a spec_type.
|
||||||
|
if n, ok := HasEmbeddedMTPHead(f); ok {
|
||||||
|
ApplyMTPDefaults(cfg, n)
|
||||||
|
}
|
||||||
|
|
||||||
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
||||||
|
|
||||||
// template estimations
|
// template estimations
|
||||||
|
|||||||
84
core/config/mtp.go
Normal file
84
core/config/mtp.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
gguf "github.com/gpustack/gguf-parser-go"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mtpSpecOptions lists the speculative-decoding option keys auto-applied when
|
||||||
|
// an MTP head is detected on a llama-cpp GGUF. Defaults track the upstream
|
||||||
|
// MTP PR (ggml-org/llama.cpp#22673):
|
||||||
|
//
|
||||||
|
// - spec_type:draft-mtp activates Multi-Token Prediction
|
||||||
|
// - spec_n_max:6 draft window
|
||||||
|
// - spec_p_min:0.75 pinned because upstream marked the 0.75 default
|
||||||
|
// with a "change to 0.0f" TODO; locking it here keeps acceptance
|
||||||
|
// thresholds stable across future bumps
|
||||||
|
var mtpSpecOptions = []string{
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}
|
||||||
|
|
||||||
|
// MTPSpecOptions returns a copy of the option keys auto-applied when an MTP
|
||||||
|
// head is detected. Exported for testing and for the importer.
|
||||||
|
func MTPSpecOptions() []string {
|
||||||
|
out := make([]string, len(mtpSpecOptions))
|
||||||
|
copy(out, mtpSpecOptions)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasEmbeddedMTPHead reports whether the parsed GGUF declares a Multi-Token
|
||||||
|
// Prediction head. Detection reads `<arch>.nextn_predict_layers`, which is
|
||||||
|
// what `gguf_writer.add_nextn_predict_layers(n)` emits in upstream's
|
||||||
|
// `conversion/qwen.py` MTP mixin. A positive layer count means the head is
|
||||||
|
// present in the same GGUF as the trunk.
|
||||||
|
func HasEmbeddedMTPHead(f *gguf.GGUFFile) (uint32, bool) {
|
||||||
|
if f == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
arch := f.Architecture().Architecture
|
||||||
|
if arch == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
v, ok := f.Header.MetadataKV.Get(arch + ".nextn_predict_layers")
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
n := gguf.ValueNumeric[uint32](v)
|
||||||
|
return n, n > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasSpecTypeOption returns true when the slice already contains a
|
||||||
|
// user-configured `spec_type:` / `speculative_type:` entry. Used to avoid
|
||||||
|
// clobbering an explicit choice with the MTP auto-defaults.
|
||||||
|
func hasSpecTypeOption(opts []string) bool {
|
||||||
|
for _, o := range opts {
|
||||||
|
if strings.HasPrefix(o, "spec_type:") || strings.HasPrefix(o, "speculative_type:") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyMTPDefaults appends the auto-MTP option keys to cfg.Options when none
|
||||||
|
// is already configured. It is a no-op when the user already picked a
|
||||||
|
// `spec_type` (either via YAML or via the importer's preferences flow).
|
||||||
|
//
|
||||||
|
// `layers` is the value read from `<arch>.nextn_predict_layers` and is only
|
||||||
|
// used for the diagnostic log line.
|
||||||
|
func ApplyMTPDefaults(cfg *ModelConfig, layers uint32) {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if hasSpecTypeOption(cfg.Options) {
|
||||||
|
xlog.Debug("[mtp] embedded MTP head detected but spec_type already configured; leaving user choice intact",
|
||||||
|
"name", cfg.Name, "nextn_layers", layers)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.Options = append(cfg.Options, mtpSpecOptions...)
|
||||||
|
xlog.Info("[mtp] embedded MTP head detected; enabling draft-mtp speculative decoding",
|
||||||
|
"name", cfg.Name, "nextn_layers", layers, "spec_n_max", 6, "spec_p_min", 0.75)
|
||||||
|
}
|
||||||
86
core/config/mtp_test.go
Normal file
86
core/config/mtp_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/mudler/LocalAI/core/config"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("MTP auto-defaults", func() {
|
||||||
|
Context("MTPSpecOptions", func() {
|
||||||
|
It("returns the upstream-recommended speculative tuple", func() {
|
||||||
|
Expect(MTPSpecOptions()).To(Equal([]string{
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns a defensive copy so callers cannot mutate the package default", func() {
|
||||||
|
opts := MTPSpecOptions()
|
||||||
|
opts[0] = "spec_type:none"
|
||||||
|
Expect(MTPSpecOptions()[0]).To(Equal("spec_type:draft-mtp"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("ApplyMTPDefaults", func() {
|
||||||
|
It("appends MTP options when nothing is configured", func() {
|
||||||
|
cfg := &ModelConfig{Name: "qwen-mtp"}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("preserves unrelated options already on the config", func() {
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
Name: "qwen-mtp",
|
||||||
|
Options: []string{"use_jinja:true", "cache_reuse:256"},
|
||||||
|
}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{
|
||||||
|
"use_jinja:true",
|
||||||
|
"cache_reuse:256",
|
||||||
|
"spec_type:draft-mtp",
|
||||||
|
"spec_n_max:6",
|
||||||
|
"spec_p_min:0.75",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("is a no-op when the user already configured spec_type", func() {
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
Name: "qwen-mtp",
|
||||||
|
Options: []string{"spec_type:ngram-simple", "use_jinja:true"},
|
||||||
|
}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{
|
||||||
|
"spec_type:ngram-simple",
|
||||||
|
"use_jinja:true",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("also respects the legacy speculative_type alias", func() {
|
||||||
|
cfg := &ModelConfig{
|
||||||
|
Name: "qwen-mtp",
|
||||||
|
Options: []string{"speculative_type:ngram-mod"},
|
||||||
|
}
|
||||||
|
ApplyMTPDefaults(cfg, 1)
|
||||||
|
Expect(cfg.Options).To(Equal([]string{"speculative_type:ngram-mod"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("tolerates a nil config", func() {
|
||||||
|
Expect(func() { ApplyMTPDefaults(nil, 1) }).ToNot(Panic())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("HasEmbeddedMTPHead", func() {
|
||||||
|
It("returns false on a nil GGUF file", func() {
|
||||||
|
n, ok := HasEmbeddedMTPHead(nil)
|
||||||
|
Expect(ok).To(BeFalse())
|
||||||
|
Expect(n).To(BeZero())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
package importers
|
package importers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gguf "github.com/gpustack/gguf-parser-go"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
@@ -261,6 +264,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
|||||||
// Apply per-model-family inference parameter defaults
|
// Apply per-model-family inference parameter defaults
|
||||||
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
config.ApplyInferenceDefaults(&modelConfig, details.URI)
|
||||||
|
|
||||||
|
// Auto-detect Multi-Token Prediction heads (ggml-org/llama.cpp#22673) and
|
||||||
|
// enable speculative decoding. Mirrors the load-time hook so freshly
|
||||||
|
// imported configs already carry spec_type:draft-mtp before the model is
|
||||||
|
// ever loaded - users see it in the YAML preview rather than discovering
|
||||||
|
// it after the first start.
|
||||||
|
maybeApplyMTPDefaults(&modelConfig, details, &cfg)
|
||||||
|
|
||||||
data, err := yaml.Marshal(modelConfig)
|
data, err := yaml.Marshal(modelConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return gallery.ModelConfig{}, err
|
return gallery.ModelConfig{}, err
|
||||||
@@ -291,6 +301,85 @@ func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardG
|
|||||||
return &groups[len(groups)-1]
|
return &groups[len(groups)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maybeApplyMTPDefaults parses the picked GGUF header (range-fetched over
|
||||||
|
// HTTP for HF/URL imports) and, if the file declares a Multi-Token Prediction
|
||||||
|
// head, appends the auto-MTP option keys to modelConfig.Options. Failures
|
||||||
|
// during the probe are non-fatal: the importer keeps the config without MTP
|
||||||
|
// so an unrelated network blip or weird header doesn't break the import.
|
||||||
|
//
|
||||||
|
// OCI/Ollama URIs are skipped because the artifact isn't directly fetchable
|
||||||
|
// as a GGUF byte stream - the load-time hook (core/config/gguf.go) covers
|
||||||
|
// those once the model is materialised on disk.
|
||||||
|
func maybeApplyMTPDefaults(modelConfig *config.ModelConfig, details Details, cfg *gallery.ModelConfig) {
|
||||||
|
probeURL := pickMTPProbeURL(details, cfg)
|
||||||
|
if probeURL == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
xlog.Debug("[mtp-importer] panic while probing GGUF header", "uri", probeURL, "recover", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
f, err := gguf.ParseGGUFFileRemote(ctx, probeURL)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Debug("[mtp-importer] failed to read remote GGUF header for MTP detection", "uri", probeURL, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n, ok := config.HasEmbeddedMTPHead(f)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.ApplyMTPDefaults(modelConfig, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickMTPProbeURL returns an HTTP(S) URL pointing at the main (non-mmproj)
|
||||||
|
// GGUF shard that should be inspected for an MTP head, or "" when no
|
||||||
|
// suitable URL is available. Custom URI schemes (`huggingface://`,
|
||||||
|
// `ollama://`, etc.) are run through `downloader.URI.ResolveURL` so the
|
||||||
|
// resulting URL is something `gguf.ParseGGUFFileRemote` can actually open.
|
||||||
|
// OCI/Ollama URIs are skipped because the artifact is not directly
|
||||||
|
// streamable as a GGUF byte range.
|
||||||
|
func pickMTPProbeURL(details Details, cfg *gallery.ModelConfig) string {
|
||||||
|
uri := downloader.URI(details.URI)
|
||||||
|
|
||||||
|
if uri.LooksLikeOCI() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(strings.ToLower(details.URI), ".gguf") {
|
||||||
|
return resolveHTTPProbe(details.URI)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range cfg.Files {
|
||||||
|
lower := strings.ToLower(f.Filename)
|
||||||
|
if strings.Contains(lower, "mmproj") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(lower, ".gguf") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resolveHTTPProbe(f.URI)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveHTTPProbe resolves an importer-side URI to the HTTP(S) URL that
|
||||||
|
// `gguf.ParseGGUFFileRemote` can range-fetch. Returns "" if the URI can't
|
||||||
|
// be reduced to an HTTP(S) endpoint (e.g. local path, unsupported scheme).
|
||||||
|
func resolveHTTPProbe(uri string) string {
|
||||||
|
resolved := downloader.URI(uri).ResolveURL()
|
||||||
|
if downloader.URI(resolved).LooksLikeHTTPURL() {
|
||||||
|
return resolved
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// appendShardGroup copies every shard of group into cfg.Files under dest,
|
// appendShardGroup copies every shard of group into cfg.Files under dest,
|
||||||
// skipping any entry whose target filename is already present so repeated
|
// skipping any entry whose target filename is already present so repeated
|
||||||
// calls (e.g. the rare case of mmproj + model picking the same group)
|
// calls (e.g. the rare case of mmproj + model picking the same group)
|
||||||
|
|||||||
@@ -22,12 +22,19 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/backend"
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
"github.com/mudler/xlog"
|
"github.com/mudler/xlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var videoDownloadClient = http.Client{Timeout: 30 * time.Second}
|
||||||
|
|
||||||
func downloadFile(url string) (string, error) {
|
func downloadFile(url string) (string, error) {
|
||||||
|
if err := utils.ValidateExternalURL(url); err != nil {
|
||||||
|
return "", fmt.Errorf("URL validation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Get the data
|
// Get the data
|
||||||
resp, err := http.Get(url)
|
resp, err := videoDownloadClient.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -131,13 +131,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
delta.Reasoning = &reasoningDelta
|
delta.Reasoning = &reasoningDelta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Usage rides as a struct field for the consumer to track the
|
||||||
|
// running cumulative — it is stripped before JSON marshal so the
|
||||||
|
// wire chunk stays spec-compliant (no `usage` on intermediate
|
||||||
|
// chunks). The dedicated trailer chunk (when include_usage=true)
|
||||||
|
// carries the final totals.
|
||||||
|
usageForChunk := usage
|
||||||
resp := schema.OpenAIResponse{
|
resp := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Usage: usage,
|
Usage: &usageForChunk,
|
||||||
}
|
}
|
||||||
|
|
||||||
responses <- resp
|
responses <- resp
|
||||||
@@ -164,7 +170,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
hasChatDeltaToolCalls := false
|
hasChatDeltaToolCalls := false
|
||||||
hasChatDeltaContent := false
|
hasChatDeltaContent := false
|
||||||
|
|
||||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
_, _, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
result += s
|
result += s
|
||||||
|
|
||||||
// Track whether ChatDeltas from the C++ autoparser contain
|
// Track whether ChatDeltas from the C++ autoparser contain
|
||||||
@@ -387,16 +393,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case noActionToRun:
|
case noActionToRun:
|
||||||
usage := schema.OpenAIUsage{
|
// Token-cumulative usage is communicated to the streaming
|
||||||
PromptTokens: tokenUsage.Prompt,
|
// consumer via the per-token callback's chunk struct (stripped
|
||||||
CompletionTokens: tokenUsage.Completion,
|
// before wire marshal). The final usage trailer — when the
|
||||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
// caller opted in with stream_options.include_usage — is built
|
||||||
}
|
// by the outer streaming loop, not here.
|
||||||
if extraUsage {
|
|
||||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
|
||||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
|
||||||
}
|
|
||||||
|
|
||||||
var result string
|
var result string
|
||||||
if !sentInitialRole {
|
if !sentInitialRole {
|
||||||
var hqErr error
|
var hqErr error
|
||||||
@@ -409,7 +410,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
for _, chunk := range buildNoActionFinalChunks(
|
for _, chunk := range buildNoActionFinalChunks(
|
||||||
id, req.Model, created,
|
id, req.Model, created,
|
||||||
sentInitialRole, sentReasoning,
|
sentInitialRole, sentReasoning,
|
||||||
result, reasoning, usage,
|
result, reasoning,
|
||||||
) {
|
) {
|
||||||
responses <- chunk
|
responses <- chunk
|
||||||
}
|
}
|
||||||
@@ -724,7 +725,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
xlog.Debug("No choices in the response, skipping")
|
xlog.Debug("No choices in the response, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
// Capture the running cumulative usage from this chunk
|
||||||
|
// (when present) so the include_usage trailer can carry
|
||||||
|
// the final totals. Usage is stripped before marshal
|
||||||
|
// below so the wire chunk stays spec-compliant.
|
||||||
|
if ev.Usage != nil {
|
||||||
|
usage = ev.Usage
|
||||||
|
}
|
||||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||||
toolsCalled = true
|
toolsCalled = true
|
||||||
// Collect and merge tool call deltas for MCP execution
|
// Collect and merge tool call deltas for MCP execution
|
||||||
@@ -740,6 +747,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
collectedContent += *sp
|
collectedContent += *sp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// OpenAI streaming spec: intermediate chunks must NOT
|
||||||
|
// carry a `usage` field. Strip the tracking copy
|
||||||
|
// before marshalling — usage is delivered via the
|
||||||
|
// dedicated trailer chunk when include_usage=true.
|
||||||
|
ev.Usage = nil
|
||||||
respData, err := json.Marshal(ev)
|
respData, err := json.Marshal(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Debug("Failed to marshal response", "error", err)
|
xlog.Debug("Failed to marshal response", "error", err)
|
||||||
@@ -888,6 +900,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
finishReason = FinishReasonFunctionCall
|
finishReason = FinishReasonFunctionCall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Final delta chunk: empty delta with finish_reason set. Per
|
||||||
|
// OpenAI streaming spec this chunk does NOT carry usage —
|
||||||
|
// the optional trailer (below) does, gated on include_usage.
|
||||||
resp := &schema.OpenAIResponse{
|
resp := &schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -899,11 +914,18 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
Delta: &schema.Message{},
|
Delta: &schema.Message{},
|
||||||
}},
|
}},
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Usage: *usage,
|
|
||||||
}
|
}
|
||||||
respData, _ := json.Marshal(resp)
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||||
|
|
||||||
|
// Trailing usage chunk per OpenAI spec: emit only when the
|
||||||
|
// caller opted in via stream_options.include_usage. Shape:
|
||||||
|
// {"choices":[],"usage":{...},"object":"chat.completion.chunk",...}
|
||||||
|
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage && usage != nil {
|
||||||
|
trailer := streamUsageTrailerJSON(id, input.Model, created, *usage)
|
||||||
|
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||||
c.Response().Flush()
|
c.Response().Flush()
|
||||||
xlog.Debug("Stream ended")
|
xlog.Debug("Stream ended")
|
||||||
@@ -1263,7 +1285,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: result,
|
Choices: result,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Usage: usage,
|
Usage: &usage,
|
||||||
}
|
}
|
||||||
respData, _ := json.Marshal(resp)
|
respData, _ := json.Marshal(resp)
|
||||||
xlog.Debug("Response", "response", string(respData))
|
xlog.Debug("Response", "response", string(respData))
|
||||||
|
|||||||
@@ -1,12 +1,45 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// streamUsageTrailerJSON returns the bytes of the OpenAI-spec trailing usage
|
||||||
|
// chunk emitted in streaming completions when the request opts in via
|
||||||
|
// `stream_options.include_usage: true`. The shape is:
|
||||||
|
//
|
||||||
|
// {"id":"...","object":"chat.completion.chunk","created":N,
|
||||||
|
// "model":"...","choices":[],"usage":{...}}
|
||||||
|
//
|
||||||
|
// `choices` is intentionally an empty array (not absent, not null) — that is
|
||||||
|
// what the OpenAI spec mandates, and what consumers like the official OpenAI
|
||||||
|
// SDK and Continue's openai-adapter look for to recognise this as the usage
|
||||||
|
// chunk rather than a content chunk. schema.OpenAIResponse has `omitempty`
|
||||||
|
// on Choices, so we cannot reuse it for the trailer.
|
||||||
|
func streamUsageTrailerJSON(id, model string, created int, usage schema.OpenAIUsage) []byte {
|
||||||
|
trailer := struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Choices []schema.Choice `json:"choices"`
|
||||||
|
Usage schema.OpenAIUsage `json:"usage"`
|
||||||
|
}{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: model,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Choices: []schema.Choice{},
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(trailer)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// hasRealCall reports whether functionResults contains at least one
|
// hasRealCall reports whether functionResults contains at least one
|
||||||
// entry whose Name is something other than the noAction sentinel.
|
// entry whose Name is something other than the noAction sentinel.
|
||||||
// Used by processTools to decide between the "answer the question"
|
// Used by processTools to decide between the "answer the question"
|
||||||
@@ -25,10 +58,10 @@ func hasRealCall(functionResults []functions.FuncCallResults, noAction string) b
|
|||||||
// pseudo-function or emitted no tool calls at all).
|
// pseudo-function or emitted no tool calls at all).
|
||||||
//
|
//
|
||||||
// When content was already streamed (contentAlreadyStreamed=true) the
|
// When content was already streamed (contentAlreadyStreamed=true) the
|
||||||
// helper emits a single trailing usage chunk, optionally carrying
|
// helper emits a trailing reasoning chunk if any non-streamed reasoning
|
||||||
// reasoning that was produced but not streamed incrementally. When
|
// remains, else nothing. When content was not streamed it emits a role
|
||||||
// content was not streamed it emits a role chunk followed by a
|
// chunk followed by a content (+reasoning) chunk — the "send everything
|
||||||
// content+reasoning+usage chunk — the "send everything at once" fallback.
|
// at once" fallback.
|
||||||
//
|
//
|
||||||
// Reasoning re-emission is guarded by reasoningAlreadyStreamed, not by
|
// Reasoning re-emission is guarded by reasoningAlreadyStreamed, not by
|
||||||
// probing the extractor's Go-side state: the C++ autoparser delivers
|
// probing the extractor's Go-side state: the C++ autoparser delivers
|
||||||
@@ -36,6 +69,10 @@ func hasRealCall(functionResults []functions.FuncCallResults, noAction string) b
|
|||||||
// separate accumulator that extractor.Reasoning() does not expose.
|
// separate accumulator that extractor.Reasoning() does not expose.
|
||||||
// Without this guard the callback would stream reasoning incrementally
|
// Without this guard the callback would stream reasoning incrementally
|
||||||
// and the final chunk would duplicate it.
|
// and the final chunk would duplicate it.
|
||||||
|
//
|
||||||
|
// The returned chunks intentionally do NOT carry a `usage` field. The
|
||||||
|
// usage trailer is emitted separately by the streaming handler when
|
||||||
|
// `stream_options.include_usage` is true, per OpenAI spec.
|
||||||
func buildNoActionFinalChunks(
|
func buildNoActionFinalChunks(
|
||||||
id, model string,
|
id, model string,
|
||||||
created int,
|
created int,
|
||||||
@@ -43,26 +80,26 @@ func buildNoActionFinalChunks(
|
|||||||
reasoningAlreadyStreamed bool,
|
reasoningAlreadyStreamed bool,
|
||||||
content string,
|
content string,
|
||||||
reasoning string,
|
reasoning string,
|
||||||
usage schema.OpenAIUsage,
|
|
||||||
) []schema.OpenAIResponse {
|
) []schema.OpenAIResponse {
|
||||||
var out []schema.OpenAIResponse
|
var out []schema.OpenAIResponse
|
||||||
|
|
||||||
if contentAlreadyStreamed {
|
if contentAlreadyStreamed {
|
||||||
delta := &schema.Message{}
|
if reasoning == "" || reasoningAlreadyStreamed {
|
||||||
if reasoning != "" && !reasoningAlreadyStreamed {
|
return nil
|
||||||
r := reasoning
|
|
||||||
delta.Reasoning = &r
|
|
||||||
}
|
}
|
||||||
|
r := reasoning
|
||||||
out = append(out, schema.OpenAIResponse{
|
out = append(out, schema.OpenAIResponse{
|
||||||
ID: id, Created: created, Model: model,
|
ID: id, Created: created, Model: model,
|
||||||
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
Choices: []schema.Choice{{
|
||||||
Object: "chat.completion.chunk",
|
Delta: &schema.Message{Reasoning: &r},
|
||||||
Usage: usage,
|
Index: 0,
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
})
|
})
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// Content was not streamed — send role, then content (+reasoning) + usage.
|
// Content was not streamed — send role, then content (+reasoning).
|
||||||
out = append(out, schema.OpenAIResponse{
|
out = append(out, schema.OpenAIResponse{
|
||||||
ID: id, Created: created, Model: model,
|
ID: id, Created: created, Model: model,
|
||||||
Choices: []schema.Choice{{
|
Choices: []schema.Choice{{
|
||||||
@@ -82,7 +119,6 @@ func buildNoActionFinalChunks(
|
|||||||
ID: id, Created: created, Model: model,
|
ID: id, Created: created, Model: model,
|
||||||
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Usage: usage,
|
|
||||||
})
|
})
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -609,54 +609,52 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
testModel = "test-model"
|
testModel = "test-model"
|
||||||
testCreated = 1700000000
|
testCreated = 1700000000
|
||||||
)
|
)
|
||||||
usage := schema.OpenAIUsage{PromptTokens: 5, CompletionTokens: 7, TotalTokens: 12}
|
|
||||||
|
|
||||||
Describe("Content streamed — trailing usage chunk", func() {
|
Describe("Content streamed — trailing reasoning only", func() {
|
||||||
It("emits just one chunk with usage, no content, no reasoning when reasoning was streamed", func() {
|
It("emits nothing when content and reasoning were already streamed", func() {
|
||||||
|
// Before the streaming-usage-spec fix this branch emitted a
|
||||||
|
// content-less chunk solely to carry `usage`. Per the OpenAI
|
||||||
|
// spec usage no longer rides on delta chunks; the dedicated
|
||||||
|
// trailer (when include_usage=true) carries it instead — so
|
||||||
|
// with nothing to deliver the helper returns no chunks.
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
true, true,
|
true, true,
|
||||||
"", "already-streamed-reasoning", usage,
|
"", "already-streamed-reasoning",
|
||||||
)
|
)
|
||||||
|
Expect(chunks).To(BeEmpty())
|
||||||
Expect(chunks).To(HaveLen(1))
|
|
||||||
Expect(chunks[0].Usage.TotalTokens).To(Equal(12))
|
|
||||||
Expect(contentOf(chunks[0])).To(BeEmpty())
|
|
||||||
Expect(reasoningOf(chunks[0])).To(BeEmpty(),
|
|
||||||
"reasoning must not be re-emitted once it was streamed via the callback")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("emits a trailing reasoning delivery when reasoning came only at end", func() {
|
It("emits a trailing reasoning delivery when reasoning came only at end", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
true, false,
|
true, false,
|
||||||
"", "autoparser final reasoning", usage,
|
"", "autoparser final reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(1))
|
Expect(chunks).To(HaveLen(1))
|
||||||
Expect(reasoningOf(chunks[0])).To(Equal("autoparser final reasoning"))
|
Expect(reasoningOf(chunks[0])).To(Equal("autoparser final reasoning"))
|
||||||
Expect(contentOf(chunks[0])).To(BeEmpty())
|
Expect(contentOf(chunks[0])).To(BeEmpty())
|
||||||
Expect(chunks[0].Usage.TotalTokens).To(Equal(12))
|
Expect(chunks[0].Usage).To(BeNil(),
|
||||||
|
"intermediate chunks must not carry usage per OpenAI spec")
|
||||||
})
|
})
|
||||||
|
|
||||||
It("omits reasoning when it's empty regardless of streamed flag", func() {
|
It("returns no chunks when reasoning is empty and content was streamed", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
true, false,
|
true, false,
|
||||||
"", "", usage,
|
"", "",
|
||||||
)
|
)
|
||||||
|
Expect(chunks).To(BeEmpty())
|
||||||
Expect(chunks).To(HaveLen(1))
|
|
||||||
Expect(reasoningOf(chunks[0])).To(BeEmpty())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Content not streamed — role, then content+usage", func() {
|
Describe("Content not streamed — role, then content", func() {
|
||||||
It("emits role chunk then content chunk without reasoning when reasoning was streamed", func() {
|
It("emits role chunk then content chunk without reasoning when reasoning was streamed", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, true,
|
false, true,
|
||||||
"the answer", "already-streamed-reasoning", usage,
|
"the answer", "already-streamed-reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(2))
|
Expect(chunks).To(HaveLen(2))
|
||||||
@@ -666,14 +664,14 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
||||||
Expect(reasoningOf(chunks[1])).To(BeEmpty(),
|
Expect(reasoningOf(chunks[1])).To(BeEmpty(),
|
||||||
"reasoning must not be re-emitted if it was streamed earlier")
|
"reasoning must not be re-emitted if it was streamed earlier")
|
||||||
Expect(chunks[1].Usage.TotalTokens).To(Equal(12))
|
Expect(chunks[1].Usage).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("emits role, then content+reasoning when reasoning was not streamed", func() {
|
It("emits role, then content+reasoning when reasoning was not streamed", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, false,
|
false, false,
|
||||||
"the answer", "autoparser final reasoning", usage,
|
"the answer", "autoparser final reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(2))
|
Expect(chunks).To(HaveLen(2))
|
||||||
@@ -681,14 +679,14 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
|
|
||||||
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
Expect(contentOf(chunks[1])).To(Equal("the answer"))
|
||||||
Expect(reasoningOf(chunks[1])).To(Equal("autoparser final reasoning"))
|
Expect(reasoningOf(chunks[1])).To(Equal("autoparser final reasoning"))
|
||||||
Expect(chunks[1].Usage.TotalTokens).To(Equal(12))
|
Expect(chunks[1].Usage).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("still emits content even when reasoning is empty", func() {
|
It("still emits content even when reasoning is empty", func() {
|
||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, false,
|
false, false,
|
||||||
"just an answer", "", usage,
|
"just an answer", "",
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(chunks).To(HaveLen(2))
|
Expect(chunks).To(HaveLen(2))
|
||||||
@@ -702,7 +700,7 @@ var _ = Describe("buildNoActionFinalChunks", func() {
|
|||||||
chunks := buildNoActionFinalChunks(
|
chunks := buildNoActionFinalChunks(
|
||||||
testID, testModel, testCreated,
|
testID, testModel, testCreated,
|
||||||
false, false,
|
false, false,
|
||||||
"hi", "reasoning", usage,
|
"hi", "reasoning",
|
||||||
)
|
)
|
||||||
for i, ch := range chunks {
|
for i, ch := range chunks {
|
||||||
Expect(ch.ID).To(Equal(testID), "chunk[%d] ID", i)
|
Expect(ch.ID).To(Equal(testID), "chunk[%d] ID", i)
|
||||||
|
|||||||
179
core/http/endpoints/openai/chat_stream_usage_test.go
Normal file
179
core/http/endpoints/openai/chat_stream_usage_test.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These tests pin LocalAI's streaming chunks to the OpenAI spec for the
|
||||||
|
// `usage` field. The regression that motivated them (issue #8546) was that
|
||||||
|
// LocalAI emitted `"usage":{...zeros...}` on every chunk, which made the
|
||||||
|
// official OpenAI Node SDK consumers (Continue, Kilo Code, Roo Code, Zed,
|
||||||
|
// IntelliJ Continue) drop every content chunk via the filter at
|
||||||
|
// continuedev/continue packages/openai-adapters/src/apis/OpenAI.ts:275-288.
|
||||||
|
//
|
||||||
|
// Per OpenAI's chat-completion streaming contract:
|
||||||
|
// - intermediate chunks MUST NOT carry a `usage` field
|
||||||
|
// - usage is only delivered when the request opts in via
|
||||||
|
// `stream_options.include_usage: true`, on a final extra chunk whose
|
||||||
|
// `choices` is an empty array.
|
||||||
|
|
||||||
|
var _ = Describe("streaming usage spec compliance", func() {
|
||||||
|
Describe("OpenAIResponse JSON shape", func() {
|
||||||
|
It("does not emit a 'usage' key when Usage is unset", func() {
|
||||||
|
// A typical intermediate token chunk: no Usage populated.
|
||||||
|
content := "hello"
|
||||||
|
resp := schema.OpenAIResponse{
|
||||||
|
ID: "req-1",
|
||||||
|
Created: 1,
|
||||||
|
Model: "m",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Index: 0,
|
||||||
|
Delta: &schema.Message{Content: &content},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
Expect(json.Unmarshal(data, &raw)).To(Succeed())
|
||||||
|
_, present := raw["usage"]
|
||||||
|
Expect(present).To(BeFalse(),
|
||||||
|
"intermediate chunk must not include a 'usage' key; got: %s", string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("emits the usage object when Usage is explicitly set", func() {
|
||||||
|
usage := &schema.OpenAIUsage{PromptTokens: 11, CompletionTokens: 22, TotalTokens: 33}
|
||||||
|
resp := schema.OpenAIResponse{
|
||||||
|
ID: "req-1",
|
||||||
|
Created: 1,
|
||||||
|
Model: "m",
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
Expect(json.Unmarshal(data, &raw)).To(Succeed())
|
||||||
|
u, ok := raw["usage"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue(), "expected 'usage' object, got: %s", string(data))
|
||||||
|
Expect(u["prompt_tokens"]).To(BeNumerically("==", 11))
|
||||||
|
Expect(u["completion_tokens"]).To(BeNumerically("==", 22))
|
||||||
|
Expect(u["total_tokens"]).To(BeNumerically("==", 33))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("buildNoActionFinalChunks", func() {
|
||||||
|
It("returns chunks with no Usage embedded", func() {
|
||||||
|
// Whatever the caller is doing, helpers must not bake usage
|
||||||
|
// into intermediate or final delta chunks. The usage trailer
|
||||||
|
// (when requested via include_usage) is emitted separately.
|
||||||
|
chunks := buildNoActionFinalChunks(
|
||||||
|
"req-1", "m", 1,
|
||||||
|
false, false,
|
||||||
|
"hi", "",
|
||||||
|
)
|
||||||
|
Expect(chunks).ToNot(BeEmpty())
|
||||||
|
for i, ch := range chunks {
|
||||||
|
Expect(ch.Usage).To(BeNil(),
|
||||||
|
"chunk[%d] must not carry Usage; got %+v", i, ch.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns chunks with no Usage when only trailing reasoning needs delivery", func() {
|
||||||
|
chunks := buildNoActionFinalChunks(
|
||||||
|
"req-1", "m", 1,
|
||||||
|
true, false,
|
||||||
|
"", "autoparser late reasoning",
|
||||||
|
)
|
||||||
|
Expect(chunks).ToNot(BeEmpty())
|
||||||
|
for i, ch := range chunks {
|
||||||
|
Expect(ch.Usage).To(BeNil(),
|
||||||
|
"chunk[%d] must not carry Usage; got %+v", i, ch.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("buildDeferredToolCallChunks", func() {
|
||||||
|
It("returns chunks with no Usage embedded", func() {
|
||||||
|
calls := []functions.FuncCallResults{{
|
||||||
|
Name: "do_thing", Arguments: `{"x":1}`,
|
||||||
|
}}
|
||||||
|
chunks := buildDeferredToolCallChunks(
|
||||||
|
"req-1", "m", 1, calls, 0,
|
||||||
|
false, "", false, "",
|
||||||
|
)
|
||||||
|
Expect(chunks).ToNot(BeEmpty())
|
||||||
|
for i, ch := range chunks {
|
||||||
|
Expect(ch.Usage).To(BeNil(),
|
||||||
|
"chunk[%d] must not carry Usage; got %+v", i, ch.Usage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("streamUsageTrailerJSON", func() {
|
||||||
|
It("produces JSON matching the OpenAI spec for the trailer chunk", func() {
|
||||||
|
// Trailing usage chunk shape (OpenAI streaming spec):
|
||||||
|
// {"id":"...","object":"chat.completion.chunk","created":...,
|
||||||
|
// "model":"...","choices":[],"usage":{...}}
|
||||||
|
usage := schema.OpenAIUsage{
|
||||||
|
PromptTokens: 18, CompletionTokens: 14, TotalTokens: 32,
|
||||||
|
}
|
||||||
|
data := streamUsageTrailerJSON("req-1", "m", 1, usage)
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
Expect(json.Unmarshal(data, &raw)).To(Succeed(),
|
||||||
|
"trailer must be valid JSON, got: %s", string(data))
|
||||||
|
|
||||||
|
Expect(raw["id"]).To(Equal("req-1"))
|
||||||
|
Expect(raw["model"]).To(Equal("m"))
|
||||||
|
Expect(raw["object"]).To(Equal("chat.completion.chunk"))
|
||||||
|
Expect(raw["created"]).To(BeNumerically("==", 1))
|
||||||
|
|
||||||
|
// `choices` MUST be present as an empty array (not absent, not null).
|
||||||
|
rawChoices, present := raw["choices"]
|
||||||
|
Expect(present).To(BeTrue(), "choices key must be present, got: %s", string(data))
|
||||||
|
choicesArr, ok := rawChoices.([]any)
|
||||||
|
Expect(ok).To(BeTrue(), "choices must serialize as an array, got: %s", string(data))
|
||||||
|
Expect(choicesArr).To(BeEmpty(), "choices must be empty in usage trailer, got: %s", string(data))
|
||||||
|
|
||||||
|
// `usage` MUST be present and non-null with the populated counts.
|
||||||
|
u, ok := raw["usage"].(map[string]any)
|
||||||
|
Expect(ok).To(BeTrue(), "usage object must be present, got: %s", string(data))
|
||||||
|
Expect(u["prompt_tokens"]).To(BeNumerically("==", 18))
|
||||||
|
Expect(u["completion_tokens"]).To(BeNumerically("==", 14))
|
||||||
|
Expect(u["total_tokens"]).To(BeNumerically("==", 32))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("OpenAIRequest.StreamOptions", func() {
|
||||||
|
It("parses stream_options.include_usage=true", func() {
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "m",
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"messages": []
|
||||||
|
}`)
|
||||||
|
var req schema.OpenAIRequest
|
||||||
|
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||||
|
Expect(req.StreamOptions).ToNot(BeNil())
|
||||||
|
Expect(req.StreamOptions.IncludeUsage).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("defaults IncludeUsage to false when stream_options is absent", func() {
|
||||||
|
body := []byte(`{"model":"m","stream":true,"messages":[]}`)
|
||||||
|
var req schema.OpenAIRequest
|
||||||
|
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||||
|
// Either a nil StreamOptions or one with IncludeUsage=false is acceptable.
|
||||||
|
if req.StreamOptions != nil {
|
||||||
|
Expect(req.StreamOptions.IncludeUsage).To(BeFalse())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -39,6 +39,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||||
}
|
}
|
||||||
|
// Usage rides on the struct for the consumer to track the
|
||||||
|
// running cumulative; the consumer strips it before marshalling
|
||||||
|
// so intermediate chunks stay OpenAI-spec compliant.
|
||||||
|
usageForChunk := usage
|
||||||
resp := schema.OpenAIResponse{
|
resp := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -51,7 +55,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
Usage: usage,
|
Usage: &usageForChunk,
|
||||||
}
|
}
|
||||||
xlog.Debug("Sending goroutine", "text", s)
|
xlog.Debug("Sending goroutine", "text", s)
|
||||||
|
|
||||||
@@ -127,6 +131,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var latestUsage *schema.OpenAIUsage
|
||||||
|
|
||||||
LOOP:
|
LOOP:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -135,6 +141,14 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
xlog.Debug("No choices in the response, skipping")
|
xlog.Debug("No choices in the response, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Capture running cumulative usage for the optional trailer
|
||||||
|
// emitted after the final stop chunk when include_usage=true.
|
||||||
|
if ev.Usage != nil {
|
||||||
|
latestUsage = ev.Usage
|
||||||
|
}
|
||||||
|
// OpenAI streaming spec: intermediate chunks must NOT
|
||||||
|
// carry a `usage` field. Strip the tracking copy now.
|
||||||
|
ev.Usage = nil
|
||||||
respData, err := json.Marshal(ev)
|
respData, err := json.Marshal(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Debug("Failed to marshal response", "error", err)
|
xlog.Debug("Failed to marshal response", "error", err)
|
||||||
@@ -194,8 +208,15 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
}
|
}
|
||||||
respData, _ := json.Marshal(resp)
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||||
|
|
||||||
|
// Trailing usage chunk per OpenAI spec: emit only when the caller
|
||||||
|
// opted in via stream_options.include_usage.
|
||||||
|
if input.StreamOptions != nil && input.StreamOptions.IncludeUsage && latestUsage != nil {
|
||||||
|
trailer := streamUsageTrailerJSON(id, input.Model, created, *latestUsage)
|
||||||
|
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", trailer)
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||||
c.Response().Flush()
|
c.Response().Flush()
|
||||||
return nil
|
return nil
|
||||||
@@ -247,7 +268,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
|||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: result,
|
Choices: result,
|
||||||
Object: "text_completion",
|
Object: "text_completion",
|
||||||
Usage: usage,
|
Usage: &usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
|||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
Choices: result,
|
Choices: result,
|
||||||
Object: "edit",
|
Object: "edit",
|
||||||
Usage: usage,
|
Usage: &usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
|||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
Data: result,
|
Data: result,
|
||||||
Usage: schema.OpenAIUsage{
|
Usage: &schema.OpenAIUsage{
|
||||||
PromptTokens: 0,
|
PromptTokens: 0,
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
TotalTokens: 0,
|
TotalTokens: 0,
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
|||||||
Data: []schema.Item{{
|
Data: []schema.Item{{
|
||||||
URL: imgPath,
|
URL: imgPath,
|
||||||
}},
|
}},
|
||||||
Usage: schema.OpenAIUsage{
|
Usage: &schema.OpenAIUsage{
|
||||||
PromptTokens: 0,
|
PromptTokens: 0,
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
TotalTokens: 0,
|
TotalTokens: 0,
|
||||||
|
|||||||
@@ -54,6 +54,30 @@ const (
|
|||||||
"Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized."
|
"Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// resolveOutputModalities returns the effective output modalities for a
|
||||||
|
// response: response-level overrides session-level, and the OpenAI Realtime
|
||||||
|
// spec default is ["audio"] when neither is set.
|
||||||
|
func resolveOutputModalities(session, response []types.Modality) []types.Modality {
|
||||||
|
if len(response) > 0 {
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
if len(session) > 0 {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
return []types.Modality{types.ModalityAudio}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modalitiesContainAudio reports whether the resolved modalities include audio
|
||||||
|
// output.
|
||||||
|
func modalitiesContainAudio(m []types.Modality) bool {
|
||||||
|
for _, x := range m {
|
||||||
|
if x == types.ModalityAudio {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
|
// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
|
||||||
// If the model support instead audio-to-audio, we will use the specific gRPC calls instead
|
// If the model support instead audio-to-audio, we will use the specific gRPC calls instead
|
||||||
|
|
||||||
@@ -82,6 +106,10 @@ type Session struct {
|
|||||||
InputSampleRate int
|
InputSampleRate int
|
||||||
OutputSampleRate int
|
OutputSampleRate int
|
||||||
MaxOutputTokens types.IntOrInf
|
MaxOutputTokens types.IntOrInf
|
||||||
|
// OutputModalities mirrors the OpenAI Realtime spec field of the same
|
||||||
|
// name. Empty means "use the spec default" (audio). ["text"] suppresses
|
||||||
|
// TTS so the client receives only response.output_text.* events.
|
||||||
|
OutputModalities []types.Modality
|
||||||
// MaxHistoryItems caps the number of MessageItems passed to the LLM each
|
// MaxHistoryItems caps the number of MessageItems passed to the LLM each
|
||||||
// turn (0 = unlimited). Small models — especially the LFM2.5-Audio 1.5B
|
// turn (0 = unlimited). Small models — especially the LFM2.5-Audio 1.5B
|
||||||
// served via the liquid-audio backend — degrade quickly past a handful
|
// served via the liquid-audio backend — degrade quickly past a handful
|
||||||
@@ -162,13 +190,14 @@ func (s *Session) ToServer() types.SessionUnion {
|
|||||||
} else {
|
} else {
|
||||||
return types.SessionUnion{
|
return types.SessionUnion{
|
||||||
Realtime: &types.RealtimeSession{
|
Realtime: &types.RealtimeSession{
|
||||||
ID: s.ID,
|
ID: s.ID,
|
||||||
Object: "realtime.session",
|
Object: "realtime.session",
|
||||||
Model: s.Model,
|
Model: s.Model,
|
||||||
Instructions: s.Instructions,
|
Instructions: s.Instructions,
|
||||||
Tools: s.Tools,
|
Tools: s.Tools,
|
||||||
ToolChoice: s.ToolChoice,
|
ToolChoice: s.ToolChoice,
|
||||||
MaxOutputTokens: s.MaxOutputTokens,
|
MaxOutputTokens: s.MaxOutputTokens,
|
||||||
|
OutputModalities: s.OutputModalities,
|
||||||
Audio: &types.RealtimeSessionAudio{
|
Audio: &types.RealtimeSessionAudio{
|
||||||
Input: &types.SessionAudioInput{
|
Input: &types.SessionAudioInput{
|
||||||
TurnDetection: s.TurnDetection,
|
TurnDetection: s.TurnDetection,
|
||||||
@@ -1015,6 +1044,10 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
|||||||
session.MaxOutputTokens = rt.MaxOutputTokens
|
session.MaxOutputTokens = rt.MaxOutputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(rt.OutputModalities) > 0 {
|
||||||
|
session.OutputModalities = rt.OutputModalities
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1654,106 +1687,130 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for cancellation before TTS
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
xlog.Debug("Response cancelled before TTS (barge-in)")
|
|
||||||
sendCancelledResponse()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
xlog.Debug("TTS cancelled (barge-in)")
|
|
||||||
sendCancelledResponse()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
xlog.Error("TTS failed", "error", err)
|
|
||||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !res.Success {
|
|
||||||
xlog.Error("TTS failed", "message", res.Message)
|
|
||||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer os.Remove(audioFilePath)
|
|
||||||
|
|
||||||
audioBytes, err := os.ReadFile(audioFilePath)
|
|
||||||
if err != nil {
|
|
||||||
xlog.Error("failed to read TTS file", "error", err)
|
|
||||||
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
|
|
||||||
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
|
|
||||||
if ttsSampleRate == 0 {
|
|
||||||
ttsSampleRate = localSampleRate
|
|
||||||
}
|
|
||||||
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
|
|
||||||
|
|
||||||
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
|
|
||||||
// Opus encoder, which resamples to 48kHz internally. This avoids a
|
|
||||||
// lossy intermediate resample through 16kHz.
|
|
||||||
// XXX: This is a noop in websocket mode; it's included in the JSON instead
|
|
||||||
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
xlog.Debug("Audio playback cancelled (barge-in)")
|
|
||||||
sendCancelledResponse()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
xlog.Error("failed to send audio via transport", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, isWebRTC := t.(*WebRTCTransport)
|
|
||||||
|
|
||||||
// For WebSocket clients, resample to the session's output rate and
|
|
||||||
// deliver audio as base64 in JSON events. WebRTC clients already
|
|
||||||
// received audio over the RTP track, so skip the base64 payload.
|
|
||||||
var audioString string
|
var audioString string
|
||||||
if !isWebRTC {
|
_, isWebRTC := t.(*WebRTCTransport)
|
||||||
wsPCM := pcmData
|
var respMods []types.Modality
|
||||||
if ttsSampleRate != session.OutputSampleRate {
|
if overrides != nil {
|
||||||
samples := sound.BytesToInt16sLE(pcmData)
|
respMods = overrides.OutputModalities
|
||||||
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
|
|
||||||
wsPCM = sound.Int16toBytesLE(resampled)
|
|
||||||
}
|
|
||||||
audioString = base64.StdEncoding.EncodeToString(wsPCM)
|
|
||||||
}
|
}
|
||||||
|
modalities := resolveOutputModalities(session.OutputModalities, respMods)
|
||||||
|
if modalitiesContainAudio(modalities) {
|
||||||
|
// Check for cancellation before TTS
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
xlog.Debug("Response cancelled before TTS (barge-in)")
|
||||||
|
sendCancelledResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
|
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
|
||||||
ServerEventBase: types.ServerEventBase{},
|
if err != nil {
|
||||||
ResponseID: responseID,
|
if ctx.Err() != nil {
|
||||||
ItemID: item.Assistant.ID,
|
xlog.Debug("TTS cancelled (barge-in)")
|
||||||
OutputIndex: 0,
|
sendCancelledResponse()
|
||||||
ContentIndex: 0,
|
return
|
||||||
Delta: finalSpeech,
|
}
|
||||||
})
|
xlog.Error("TTS failed", "error", err)
|
||||||
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
|
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||||
ServerEventBase: types.ServerEventBase{},
|
return
|
||||||
ResponseID: responseID,
|
}
|
||||||
ItemID: item.Assistant.ID,
|
if !res.Success {
|
||||||
OutputIndex: 0,
|
xlog.Error("TTS failed", "message", res.Message)
|
||||||
ContentIndex: 0,
|
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
|
||||||
Transcript: finalSpeech,
|
return
|
||||||
})
|
}
|
||||||
|
defer func() { _ = os.Remove(audioFilePath) }()
|
||||||
|
|
||||||
if !isWebRTC {
|
audioBytes, err := os.ReadFile(audioFilePath)
|
||||||
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
|
if err != nil {
|
||||||
|
xlog.Error("failed to read TTS file", "error", err)
|
||||||
|
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
|
||||||
|
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
|
||||||
|
if ttsSampleRate == 0 {
|
||||||
|
ttsSampleRate = localSampleRate
|
||||||
|
}
|
||||||
|
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
|
||||||
|
|
||||||
|
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
|
||||||
|
// Opus encoder, which resamples to 48kHz internally. This avoids a
|
||||||
|
// lossy intermediate resample through 16kHz.
|
||||||
|
// XXX: This is a noop in websocket mode; it's included in the JSON instead
|
||||||
|
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
xlog.Debug("Audio playback cancelled (barge-in)")
|
||||||
|
sendCancelledResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
xlog.Error("failed to send audio via transport", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For WebSocket clients, resample to the session's output rate and
|
||||||
|
// deliver audio as base64 in JSON events. WebRTC clients already
|
||||||
|
// received audio over the RTP track, so skip the base64 payload.
|
||||||
|
if !isWebRTC {
|
||||||
|
wsPCM := pcmData
|
||||||
|
if ttsSampleRate != session.OutputSampleRate {
|
||||||
|
samples := sound.BytesToInt16sLE(pcmData)
|
||||||
|
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
|
||||||
|
wsPCM = sound.Int16toBytesLE(resampled)
|
||||||
|
}
|
||||||
|
audioString = base64.StdEncoding.EncodeToString(wsPCM)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||||
ServerEventBase: types.ServerEventBase{},
|
ServerEventBase: types.ServerEventBase{},
|
||||||
ResponseID: responseID,
|
ResponseID: responseID,
|
||||||
ItemID: item.Assistant.ID,
|
ItemID: item.Assistant.ID,
|
||||||
OutputIndex: 0,
|
OutputIndex: 0,
|
||||||
ContentIndex: 0,
|
ContentIndex: 0,
|
||||||
Delta: audioString,
|
Delta: finalSpeech,
|
||||||
})
|
})
|
||||||
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
|
||||||
ServerEventBase: types.ServerEventBase{},
|
ServerEventBase: types.ServerEventBase{},
|
||||||
ResponseID: responseID,
|
ResponseID: responseID,
|
||||||
ItemID: item.Assistant.ID,
|
ItemID: item.Assistant.ID,
|
||||||
OutputIndex: 0,
|
OutputIndex: 0,
|
||||||
ContentIndex: 0,
|
ContentIndex: 0,
|
||||||
|
Transcript: finalSpeech,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !isWebRTC {
|
||||||
|
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Delta: audioString,
|
||||||
|
})
|
||||||
|
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Text-only mode: skip TTS, emit only the text events.
|
||||||
|
sendEvent(t, types.ResponseOutputTextDeltaEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Delta: finalSpeech,
|
||||||
|
})
|
||||||
|
sendEvent(t, types.ResponseOutputTextDoneEvent{
|
||||||
|
ServerEventBase: types.ServerEventBase{},
|
||||||
|
ResponseID: responseID,
|
||||||
|
ItemID: item.Assistant.ID,
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Text: finalSpeech,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
39
core/http/endpoints/openai/realtime_modality_test.go
Normal file
39
core/http/endpoints/openai/realtime_modality_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("resolveOutputModalities", func() {
|
||||||
|
It("defaults to audio when neither session nor response specify", func() {
|
||||||
|
got := resolveOutputModalities(nil, nil)
|
||||||
|
Expect(got).To(ConsistOf(types.ModalityAudio))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses session modalities when response omits them", func() {
|
||||||
|
sess := []types.Modality{types.ModalityText}
|
||||||
|
got := resolveOutputModalities(sess, nil)
|
||||||
|
Expect(got).To(ConsistOf(types.ModalityText))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("response modalities override session", func() {
|
||||||
|
sess := []types.Modality{types.ModalityAudio}
|
||||||
|
resp := []types.Modality{types.ModalityText}
|
||||||
|
got := resolveOutputModalities(sess, resp)
|
||||||
|
Expect(got).To(ConsistOf(types.ModalityText))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns false from modalitiesContainAudio for text-only", func() {
|
||||||
|
Expect(modalitiesContainAudio([]types.Modality{types.ModalityText})).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns true from modalitiesContainAudio for audio (default)", func() {
|
||||||
|
Expect(modalitiesContainAudio([]types.Modality{types.ModalityAudio})).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns true when both audio and text are present", func() {
|
||||||
|
Expect(modalitiesContainAudio([]types.Modality{types.ModalityText, types.ModalityAudio})).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
5
core/http/react-ui/src/hooks/useChat.js
vendored
5
core/http/react-ui/src/hooks/useChat.js
vendored
@@ -255,7 +255,10 @@ export function useChat(initialModel = '') {
|
|||||||
)
|
)
|
||||||
messages.push(...historyForApi, { role: 'user', content: messageContent })
|
messages.push(...historyForApi, { role: 'user', content: messageContent })
|
||||||
|
|
||||||
const requestBody = { model, messages, stream: true }
|
// include_usage tells LocalAI to emit a trailing chunk with token totals;
|
||||||
|
// without it the spec-compliant server drops `usage` from the stream and
|
||||||
|
// the token-count badge would never populate.
|
||||||
|
const requestBody = { model, messages, stream: true, stream_options: { include_usage: true } }
|
||||||
if (temperature !== null && temperature !== undefined) requestBody.temperature = temperature
|
if (temperature !== null && temperature !== undefined) requestBody.temperature = temperature
|
||||||
if (topP !== null && topP !== undefined) requestBody.top_p = topP
|
if (topP !== null && topP !== undefined) requestBody.top_p = topP
|
||||||
if (topK !== null && topK !== undefined) requestBody.top_k = topK
|
if (topK !== null && topK !== undefined) requestBody.top_k = topK
|
||||||
|
|||||||
@@ -1212,6 +1212,9 @@ async function promptGPT(systemPrompt, input) {
|
|||||||
|
|
||||||
// Add stream parameter for both regular chat and MCP (MCP now supports SSE streaming)
|
// Add stream parameter for both regular chat and MCP (MCP now supports SSE streaming)
|
||||||
requestBody.stream = true;
|
requestBody.stream = true;
|
||||||
|
// include_usage tells LocalAI to emit a trailing chunk with token totals;
|
||||||
|
// the spec-compliant server otherwise drops `usage` from the stream.
|
||||||
|
requestBody.stream_options = { include_usage: true };
|
||||||
|
|
||||||
// Add generation parameters if they are set (null means use default)
|
// Add generation parameters if they are set (null means use default)
|
||||||
if (activeChat.temperature !== null && activeChat.temperature !== undefined) {
|
if (activeChat.temperature !== null && activeChat.temperature !== undefined) {
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package schema
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,6 +20,79 @@ type OllamaOptions struct {
|
|||||||
NumCtx int `json:"num_ctx,omitempty"`
|
NumCtx int `json:"num_ctx,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON accepts integer parameters encoded as either JSON ints
|
||||||
|
// (`8192`) or JSON floats (`8192.0`). Some clients - notably Home Assistant's
|
||||||
|
// Ollama integration - serialize ints as floats, which stdlib json refuses
|
||||||
|
// to decode into int fields. See https://github.com/mudler/LocalAI/issues/9837.
|
||||||
|
func (o *OllamaOptions) UnmarshalJSON(data []byte) error {
|
||||||
|
type aux struct {
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
TopK *json.Number `json:"top_k,omitempty"`
|
||||||
|
NumPredict *json.Number `json:"num_predict,omitempty"`
|
||||||
|
RepeatPenalty float64 `json:"repeat_penalty,omitempty"`
|
||||||
|
RepeatLastN *json.Number `json:"repeat_last_n,omitempty"`
|
||||||
|
Seed *json.Number `json:"seed,omitempty"`
|
||||||
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
NumCtx *json.Number `json:"num_ctx,omitempty"`
|
||||||
|
}
|
||||||
|
var a aux
|
||||||
|
if err := json.Unmarshal(data, &a); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
o.Temperature = a.Temperature
|
||||||
|
o.TopP = a.TopP
|
||||||
|
o.RepeatPenalty = a.RepeatPenalty
|
||||||
|
o.Stop = a.Stop
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if o.TopK, err = jsonNumberToIntPtr(a.TopK); err != nil {
|
||||||
|
return fmt.Errorf("options.top_k: %w", err)
|
||||||
|
}
|
||||||
|
if o.NumPredict, err = jsonNumberToIntPtr(a.NumPredict); err != nil {
|
||||||
|
return fmt.Errorf("options.num_predict: %w", err)
|
||||||
|
}
|
||||||
|
if o.Seed, err = jsonNumberToIntPtr(a.Seed); err != nil {
|
||||||
|
return fmt.Errorf("options.seed: %w", err)
|
||||||
|
}
|
||||||
|
if o.RepeatLastN, err = jsonNumberToInt(a.RepeatLastN); err != nil {
|
||||||
|
return fmt.Errorf("options.repeat_last_n: %w", err)
|
||||||
|
}
|
||||||
|
if o.NumCtx, err = jsonNumberToInt(a.NumCtx); err != nil {
|
||||||
|
return fmt.Errorf("options.num_ctx: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jsonNumberToInt parses a json.Number literal as an int, tolerating both
|
||||||
|
// integer (`8192`) and float (`8192.0`) encodings. A nil pointer or empty
|
||||||
|
// string yields 0, matching the zero-value semantics of the int fields.
|
||||||
|
func jsonNumberToInt(n *json.Number) (int, error) {
|
||||||
|
if n == nil || *n == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if i, err := n.Int64(); err == nil {
|
||||||
|
return int(i), nil
|
||||||
|
}
|
||||||
|
f, err := n.Float64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int(f), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jsonNumberToIntPtr(n *json.Number) (*int, error) {
|
||||||
|
if n == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
i, err := jsonNumberToInt(n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &i, nil
|
||||||
|
}
|
||||||
|
|
||||||
// OllamaMessage represents a message in Ollama chat format
|
// OllamaMessage represents a message in Ollama chat format
|
||||||
type OllamaMessage struct {
|
type OllamaMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
|
|||||||
@@ -84,3 +84,94 @@ var _ = Describe("OllamaEmbedRequest", func() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Several Ollama clients (notably Home Assistant's Python client) encode
|
||||||
|
// integer parameters as JSON floats (`8192.0`). Stdlib json refuses to
|
||||||
|
// unmarshal those into `int` fields, so OllamaOptions has a custom
|
||||||
|
// UnmarshalJSON that accepts both forms. See
|
||||||
|
// https://github.com/mudler/LocalAI/issues/9837.
|
||||||
|
var _ = Describe("OllamaOptions JSON unmarshaling", func() {
|
||||||
|
It("accepts integer literals for int fields", func() {
|
||||||
|
body := []byte(`{"num_ctx": 8192, "num_predict": 256, "top_k": 40, "seed": 7, "repeat_last_n": 64}`)
|
||||||
|
|
||||||
|
var opts OllamaOptions
|
||||||
|
Expect(json.Unmarshal(body, &opts)).To(Succeed())
|
||||||
|
|
||||||
|
Expect(opts.NumCtx).To(Equal(8192))
|
||||||
|
Expect(opts.NumPredict).NotTo(BeNil())
|
||||||
|
Expect(*opts.NumPredict).To(Equal(256))
|
||||||
|
Expect(opts.TopK).NotTo(BeNil())
|
||||||
|
Expect(*opts.TopK).To(Equal(40))
|
||||||
|
Expect(opts.Seed).NotTo(BeNil())
|
||||||
|
Expect(*opts.Seed).To(Equal(7))
|
||||||
|
Expect(opts.RepeatLastN).To(Equal(64))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts float literals for int fields (Home Assistant Ollama client)", func() {
|
||||||
|
body := []byte(`{"num_ctx": 8192.0, "num_predict": 256.0, "top_k": 40.0, "seed": 7.0, "repeat_last_n": 64.0}`)
|
||||||
|
|
||||||
|
var opts OllamaOptions
|
||||||
|
Expect(json.Unmarshal(body, &opts)).To(Succeed())
|
||||||
|
|
||||||
|
Expect(opts.NumCtx).To(Equal(8192))
|
||||||
|
Expect(opts.NumPredict).NotTo(BeNil())
|
||||||
|
Expect(*opts.NumPredict).To(Equal(256))
|
||||||
|
Expect(opts.TopK).NotTo(BeNil())
|
||||||
|
Expect(*opts.TopK).To(Equal(40))
|
||||||
|
Expect(opts.Seed).NotTo(BeNil())
|
||||||
|
Expect(*opts.Seed).To(Equal(7))
|
||||||
|
Expect(opts.RepeatLastN).To(Equal(64))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("preserves float fields and stop list", func() {
|
||||||
|
body := []byte(`{"temperature": 0.7, "top_p": 0.9, "repeat_penalty": 1.1, "stop": ["<|end|>", "</s>"]}`)
|
||||||
|
|
||||||
|
var opts OllamaOptions
|
||||||
|
Expect(json.Unmarshal(body, &opts)).To(Succeed())
|
||||||
|
|
||||||
|
Expect(opts.Temperature).NotTo(BeNil())
|
||||||
|
Expect(*opts.Temperature).To(Equal(0.7))
|
||||||
|
Expect(opts.TopP).NotTo(BeNil())
|
||||||
|
Expect(*opts.TopP).To(Equal(0.9))
|
||||||
|
Expect(opts.RepeatPenalty).To(Equal(1.1))
|
||||||
|
Expect(opts.Stop).To(Equal([]string{"<|end|>", "</s>"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("leaves optional int fields nil when absent", func() {
|
||||||
|
body := []byte(`{}`)
|
||||||
|
|
||||||
|
var opts OllamaOptions
|
||||||
|
Expect(json.Unmarshal(body, &opts)).To(Succeed())
|
||||||
|
|
||||||
|
Expect(opts.NumPredict).To(BeNil())
|
||||||
|
Expect(opts.TopK).To(BeNil())
|
||||||
|
Expect(opts.Seed).To(BeNil())
|
||||||
|
Expect(opts.NumCtx).To(Equal(0))
|
||||||
|
Expect(opts.RepeatLastN).To(Equal(0))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts nested options on a chat request with float num_ctx", func() {
|
||||||
|
// Mirrors the payload Home Assistant sends; reproduces issue #9837.
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "qwen2",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"options": {"num_ctx": 8192.0, "top_k": 40.0}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
var req OllamaChatRequest
|
||||||
|
Expect(json.Unmarshal(body, &req)).To(Succeed())
|
||||||
|
|
||||||
|
Expect(req.Options).NotTo(BeNil())
|
||||||
|
Expect(req.Options.NumCtx).To(Equal(8192))
|
||||||
|
Expect(req.Options.TopK).NotTo(BeNil())
|
||||||
|
Expect(*req.Options.TopK).To(Equal(40))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects non-numeric values with a clear error", func() {
|
||||||
|
body := []byte(`{"num_ctx": "not-a-number"}`)
|
||||||
|
|
||||||
|
var opts OllamaOptions
|
||||||
|
err := json.Unmarshal(body, &opts)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -82,7 +82,21 @@ type OpenAIResponse struct {
|
|||||||
Choices []Choice `json:"choices,omitempty"`
|
Choices []Choice `json:"choices,omitempty"`
|
||||||
Data []Item `json:"data,omitempty"`
|
Data []Item `json:"data,omitempty"`
|
||||||
|
|
||||||
Usage OpenAIUsage `json:"usage"`
|
// Usage is intentionally a pointer with omitempty: per the OpenAI
|
||||||
|
// chat-completion streaming spec, intermediate chunks must not carry
|
||||||
|
// a `usage` field. Marshalling a value-typed usage would emit
|
||||||
|
// `"usage":{"prompt_tokens":0,...}` on every chunk and break
|
||||||
|
// OpenAI-SDK consumers that filter on a truthy `result.usage`
|
||||||
|
// (continuedev/continue, Kilo Code, Roo Code, etc.).
|
||||||
|
Usage *OpenAIUsage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamOptions mirrors OpenAI's `stream_options` request field. The only
|
||||||
|
// member currently honored is IncludeUsage; when true, the streaming
|
||||||
|
// chat-completion response emits a trailing chunk with `choices:[]` and a
|
||||||
|
// populated `usage` object.
|
||||||
|
type StreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage,omitempty" yaml:"include_usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
@@ -198,6 +212,9 @@ type OpenAIRequest struct {
|
|||||||
|
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
|
||||||
|
// StreamOptions opts into OpenAI streaming extensions, e.g. include_usage.
|
||||||
|
StreamOptions *StreamOptions `json:"stream_options,omitempty" yaml:"stream_options,omitempty"`
|
||||||
|
|
||||||
// Image (not supported by OpenAI)
|
// Image (not supported by OpenAI)
|
||||||
Quality string `json:"quality"`
|
Quality string `json:"quality"`
|
||||||
Step int `json:"step"`
|
Step int `json:"step"`
|
||||||
|
|||||||
@@ -16,6 +16,14 @@ const (
|
|||||||
|
|
||||||
func ListModels(bcl *config.ModelConfigLoader, ml *model.ModelLoader, filter config.ModelConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) {
|
func ListModels(bcl *config.ModelConfigLoader, ml *model.ModelLoader, filter config.ModelConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) {
|
||||||
|
|
||||||
|
// Callers (e.g. the Ollama /api/tags handler) pass nil to mean "no
|
||||||
|
// filtering". Without this guard the loose-file loop below dereferences
|
||||||
|
// filter and panics, which Echo surfaces to clients as a dropped
|
||||||
|
// connection (see issue #9817).
|
||||||
|
if filter == nil {
|
||||||
|
filter = config.NoFilterFn
|
||||||
|
}
|
||||||
|
|
||||||
skipMap := map[string]struct{}{}
|
skipMap := map[string]struct{}{}
|
||||||
|
|
||||||
dataModels := []string{}
|
dataModels := []string{}
|
||||||
|
|||||||
64
core/services/galleryop/list_models_test.go
Normal file
64
core/services/galleryop/list_models_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package galleryop_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Regression test for issue #9817: the Ollama /api/tags handler calls
|
||||||
|
// ListModels with a nil filter, which used to panic as soon as a loose file
|
||||||
|
// existed under ModelsPath. The panic surfaced to Ollama clients (e.g. Home
|
||||||
|
// Assistant) as "Server disconnected without sending a response".
|
||||||
|
var _ = Describe("ListModels", func() {
|
||||||
|
var (
|
||||||
|
tempDir string
|
||||||
|
bcl *config.ModelConfigLoader
|
||||||
|
ml *model.ModelLoader
|
||||||
|
systemState *system.SystemState
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
tempDir, err = os.MkdirTemp("", "list-models-test-*")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
systemState, err = system.GetSystemState(system.WithModelPath(tempDir))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
ml = model.NewModelLoader(systemState)
|
||||||
|
bcl = config.NewModelConfigLoader(tempDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
os.RemoveAll(tempDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not panic with a nil filter when loose files exist", func() {
|
||||||
|
// ListFilesInModelPath skips well-known weight-file extensions
|
||||||
|
// (.gguf, .bin, ...) so use an extension-less file to ensure the
|
||||||
|
// filter path is exercised.
|
||||||
|
Expect(os.WriteFile(filepath.Join(tempDir, "loose-model"), []byte("x"), 0o644)).To(Succeed())
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
var err error
|
||||||
|
Expect(func() {
|
||||||
|
names, err = galleryop.ListModels(bcl, ml, nil, galleryop.SKIP_IF_CONFIGURED)
|
||||||
|
}).ToNot(Panic())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(names).To(ContainElement("loose-model"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not panic with a nil filter when ModelsPath is empty", func() {
|
||||||
|
Expect(func() {
|
||||||
|
_, err := galleryop.ListModels(bcl, ml, nil, galleryop.SKIP_IF_CONFIGURED)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}).ToNot(Panic())
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -316,23 +316,132 @@ These are set via the `options:` array in the model configuration (format: `key:
|
|||||||
|
|
||||||
#### Speculative Type Values
|
#### Speculative Type Values
|
||||||
|
|
||||||
| Type | Description |
|
The canonical names match upstream llama.cpp (dash-separated). For backward compatibility LocalAI also accepts the underscore-separated forms and the bare `draft` / `eagle3` aliases.
|
||||||
|------|-------------|
|
|
||||||
| `none` | No speculative decoding (default) |
|
|
||||||
| `draft` | Draft model-based speculation (auto-set when `draft_model` is configured) |
|
|
||||||
| `eagle3` | EAGLE3 draft model architecture |
|
|
||||||
| `ngram_simple` | Simple self-speculative using token history |
|
|
||||||
| `ngram_map_k` | N-gram with key-only map |
|
|
||||||
| `ngram_map_k4v` | N-gram with keys and 4 m-gram values |
|
|
||||||
| `ngram_mod` | Modified n-gram speculation |
|
|
||||||
| `ngram_cache` | 3-level n-gram cache |
|
|
||||||
|
|
||||||
Multiple types can be chained by passing a comma-separated list to `spec_type` (e.g. `spec_type:ngram_simple,ngram_mod`). The runtime tries them in order and accepts the first proposal that meets the acceptance criteria.
|
| Type | Aliases accepted | Description |
|
||||||
|
|------|------------------|-------------|
|
||||||
|
| `none` | | No speculative decoding (default) |
|
||||||
|
| `draft-simple` | `draft`, `draft_simple` | Draft model-based speculation (auto-set when `draft_model` is configured) |
|
||||||
|
| `draft-eagle3` | `eagle3`, `draft_eagle3` | EAGLE3 draft model architecture |
|
||||||
|
| `draft-mtp` | `draft_mtp` | Multi-Token Prediction. Reuses the target model's embedded MTP head; no separate draft GGUF required (`draft_model` can be omitted). |
|
||||||
|
| `ngram-simple` | `ngram_simple` | Simple self-speculative using token history |
|
||||||
|
| `ngram-map-k` | `ngram_map_k` | N-gram with key-only map |
|
||||||
|
| `ngram-map-k4v` | `ngram_map_k4v` | N-gram with keys and 4 m-gram values |
|
||||||
|
| `ngram-mod` | `ngram_mod` | Modified n-gram speculation |
|
||||||
|
| `ngram-cache` | `ngram_cache` | 3-level n-gram cache |
|
||||||
|
|
||||||
|
Multiple types can be chained by passing a comma-separated list to `spec_type` (e.g. `spec_type:ngram-simple,ngram-mod`). The runtime tries them in order and accepts the first proposal that meets the acceptance criteria.
|
||||||
|
|
||||||
{{% notice note %}}
|
{{% notice note %}}
|
||||||
Speculative decoding is automatically disabled when multimodal models (with `mmproj`) are active. The `n_draft` parameter can also be overridden per-request.
|
Speculative decoding is automatically disabled when multimodal models (with `mmproj`) are active. The `n_draft` parameter can also be overridden per-request.
|
||||||
{{% /notice %}}
|
{{% /notice %}}
|
||||||
|
|
||||||
|
##### Multi-Token Prediction (MTP)
|
||||||
|
|
||||||
|
`draft-mtp` enables [Multi-Token Prediction](https://github.com/ggml-org/llama.cpp/pull/22673) (ggml-org/llama.cpp#22673). MTP uses a small prediction head trained into the target model: the head runs alongside the main forward pass and proposes the next few tokens, which the target then verifies in a single batched step. Upstream reports ~1.85x-2.1x token throughput at ~72-82% draft acceptance on Qwen3.6 27B / 35B A3B.
|
||||||
|
|
||||||
|
**Auto-detection (default).** When a GGUF declares an MTP head (the upstream `<arch>.nextn_predict_layers` metadata key, set by `convert_hf_to_gguf.py` for Qwen3.5/3.6 family models and similar), LocalAI auto-enables MTP with the following defaults:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
options:
|
||||||
|
- spec_type:draft-mtp
|
||||||
|
- spec_n_max:6
|
||||||
|
- spec_p_min:0.75
|
||||||
|
```
|
||||||
|
|
||||||
|
Detection runs both at **import time** (the `/import-model` UI / `POST /models/import-uri` flow range-fetches the GGUF header and writes the options into the generated YAML before you save it) and at **load time** (every llama-cpp model start re-checks the local header and appends the options if `spec_type` isn't already set). To opt out, set an explicit `spec_type:` / `speculative_type:` in your YAML - auto-detection always preserves the user value, including `spec_type:none`.
|
||||||
|
|
||||||
|
**Two ways to load the MTP head:**
|
||||||
|
|
||||||
|
1. **Embedded in the target GGUF** (the recommended path for LocalAI, and what auto-detection assumes). When `spec_type` includes `draft-mtp` and `draft_model` is empty, the backend builds the MTP draft context directly from the target model's weights. The GGUF must have been converted with the MTP tensors included.
|
||||||
|
2. **Separate `mtp-*.gguf` sibling file.** If you point `draft_model` at the separate MTP-head GGUF that ships next to the main weights on HuggingFace, the backend will load it as a draft model. Note: upstream's `-hf` auto-discovery of `mtp-*.gguf` siblings is **not** wired into LocalAI's gRPC layer - you need to download the sibling file and configure `draft_model` explicitly.
|
||||||
|
|
||||||
|
**Manual override knobs** (overlap with the auto-detect defaults above):
|
||||||
|
|
||||||
|
| Option | Recommended | Notes |
|
||||||
|
|--------|------------|-------|
|
||||||
|
| `spec_type` | `draft-mtp` | Activates MTP. Can be chained with other types (see below). |
|
||||||
|
| `spec_n_max` / `draft_max` | `2`-`6` | Number of draft tokens per step. Upstream's PR suggests 2-3 for the tightest acceptance window; LocalAI's auto-default is 6 to favour throughput on models with high acceptance. |
|
||||||
|
| `spec_p_min` | `0.75` | Pinned because upstream marks the current default with a "change to 0.0f" TODO; locking it here keeps acceptance thresholds stable across future llama.cpp bumps. |
|
||||||
|
| `mmproj_use_gpu` | `false` (or unset `mmproj`) | MTP has a prompt-processing overhead; if the model is non-vision, drop the mmproj entirely to save VRAM. |
|
||||||
|
|
||||||
|
**Minimal config** (override-only, since auto-detection already covers this for MTP-capable GGUFs):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: qwen3-mtp
|
||||||
|
backend: llama-cpp
|
||||||
|
parameters:
|
||||||
|
model: qwen3-27b-with-mtp.gguf
|
||||||
|
options:
|
||||||
|
- spec_type:draft-mtp
|
||||||
|
- spec_n_max:3
|
||||||
|
```
|
||||||
|
|
||||||
|
**With a separate MTP head file:**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: qwen3-mtp
|
||||||
|
backend: llama-cpp
|
||||||
|
parameters:
|
||||||
|
model: qwen3-27b.gguf
|
||||||
|
draft_model: qwen3-27b-mtp-head.gguf
|
||||||
|
options:
|
||||||
|
- spec_type:draft-mtp
|
||||||
|
- spec_n_max:3
|
||||||
|
```
|
||||||
|
|
||||||
|
**Chaining MTP with n-gram fallback** (experimental, from the PR's usage notes - useful when MTP acceptance drops on highly repetitive output):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
options:
|
||||||
|
- spec_type:draft-mtp,ngram-mod
|
||||||
|
- spec_n_max:3
|
||||||
|
- spec_ngram_mod_n_match:24
|
||||||
|
```
|
||||||
|
|
||||||
|
Pre-converted GGUFs with MTP heads are published on the [ggml-org HuggingFace org](https://huggingface.co/ggml-org) (initially Qwen3.6 27B and Qwen3.6 35B A3B).
|
||||||
|
|
||||||
|
### Reasoning Models (DeepSeek-R1, Qwen3, etc.)
|
||||||
|
|
||||||
|
These load-time options control how the backend parses `<think>` reasoning blocks and how much budget the model is allowed for thinking. They are set per model via the `options:` array.
|
||||||
|
|
||||||
|
| Option | Type | Default | Description |
|
||||||
|
|--------|------|---------|-------------|
|
||||||
|
| `reasoning_format` | string | `deepseek` | Parser for reasoning/thinking blocks. One of `none`, `auto`, `deepseek`, `deepseek-legacy` (alias `deepseek_legacy`). |
|
||||||
|
| `enable_reasoning` / `reasoning_budget` | int | `-1` | Reasoning budget in tokens: `-1` unlimited, `0` disabled, `>0` token cap for the thinking section. |
|
||||||
|
| `prefill_assistant` | bool | `true` | When `false`, the trailing assistant message is not pre-filled by the chat template. |
|
||||||
|
|
||||||
|
{{% notice note %}}
|
||||||
|
This is the load-time reasoning configuration. The orthogonal per-request `enable_thinking` chat-template kwarg (set via the YAML `reasoning.disable` field) toggles thinking on/off per call without restarting the model.
|
||||||
|
{{% /notice %}}
|
||||||
|
|
||||||
|
### Multimodal Backend Options
|
||||||
|
|
||||||
|
| Option | Type | Default | Description |
|
||||||
|
|--------|------|---------|-------------|
|
||||||
|
| `mmproj_use_gpu` / `mmproj_offload` | bool | `true` | Set `false` to keep the multimodal projector on CPU (saves VRAM at cost of speed). |
|
||||||
|
| `image_min_tokens` | int | `-1` | Minimum vision tokens per image. `-1` keeps the model default. |
|
||||||
|
| `image_max_tokens` | int | `-1` | Maximum vision tokens per image. `-1` keeps the model default. |
|
||||||
|
|
||||||
|
### Embedding & Reranking Backend Options
|
||||||
|
|
||||||
|
| Option | Type | Default | Description |
|
||||||
|
|--------|------|---------|-------------|
|
||||||
|
| `pooling_type` / `pooling` | string | auto | Pooling strategy for embeddings: `none`, `mean`, `cls`, `last`, `rank`. Reranking automatically uses `rank`. |
|
||||||
|
| `embd_normalize` / `embedding_normalize` | int | `2` | Normalization: `-1` none, `0` max-abs, `1` taxicab, `2` Euclidean (L2), `>2` p-norm. |
|
||||||
|
|
||||||
|
### Other Backend Tuning Options
|
||||||
|
|
||||||
|
These llama.cpp options are passed through the `options:` array.
|
||||||
|
|
||||||
|
| Option | Type | Default | Description |
|
||||||
|
|--------|------|---------|-------------|
|
||||||
|
| `n_ubatch` / `ubatch` | int | same as `batch` | Physical batch size. Decouple from `n_batch` when an embedding/rerank workload needs a different value. |
|
||||||
|
| `threads_batch` / `n_threads_batch` | int | same as `threads` | Threads used during prompt processing. `<= 0` means `hardware_concurrency()`. |
|
||||||
|
| `direct_io` / `use_direct_io` | bool | `false` | Open the model with `O_DIRECT` (faster cold loads on NVMe; ignored if not supported). |
|
||||||
|
| `verbosity` | int | `3` | llama.cpp internal log verbosity threshold. Higher = more verbose. |
|
||||||
|
| `override_tensor` / `tensor_buft_overrides` | string | "" | Per-tensor buffer-type overrides for the main model. Format: `<tensor regex>=<buffer type>,<tensor regex>=<buffer type>,...`. Mirrors the existing `draft_override_tensor` syntax for the draft model. |
|
||||||
|
|
||||||
### Prompt Caching
|
### Prompt Caching
|
||||||
|
|
||||||
| Field | Type | Description |
|
| Field | Type | Description |
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
{
|
{
|
||||||
"version": "v4.2.3"
|
"version": "v4.2.5"
|
||||||
}
|
}
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -163,7 +163,7 @@ require (
|
|||||||
github.com/gocolly/colly v1.2.0 // indirect
|
github.com/gocolly/colly v1.2.0 // indirect
|
||||||
github.com/gofiber/fiber/v2 v2.52.13 // indirect
|
github.com/gofiber/fiber/v2 v2.52.13 // indirect
|
||||||
github.com/golang/protobuf v1.5.4 // indirect
|
github.com/golang/protobuf v1.5.4 // indirect
|
||||||
github.com/gomarkdown/markdown v0.0.0-20250311123330-531bef5e742b // indirect
|
github.com/gomarkdown/markdown v0.0.0-20260411013819-759bbc3e3207 // indirect
|
||||||
github.com/google/go-github/v69 v69.2.0 // indirect
|
github.com/google/go-github/v69 v69.2.0 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
@@ -359,7 +359,7 @@ require (
|
|||||||
github.com/jaypipes/pcidb v1.1.1 // indirect
|
github.com/jaypipes/pcidb v1.1.1 // indirect
|
||||||
github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
|
github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
|
||||||
github.com/josharian/intern v1.0.0 // indirect
|
github.com/josharian/intern v1.0.0 // indirect
|
||||||
github.com/klauspost/compress v1.18.5 // indirect
|
github.com/klauspost/compress v1.18.5
|
||||||
github.com/klauspost/pgzip v1.2.5 // indirect
|
github.com/klauspost/pgzip v1.2.5 // indirect
|
||||||
github.com/koron/go-ssdp v0.0.6 // indirect
|
github.com/koron/go-ssdp v0.0.6 // indirect
|
||||||
github.com/libp2p/go-buffer-pool v0.1.0 // indirect
|
github.com/libp2p/go-buffer-pool v0.1.0 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -472,8 +472,8 @@ github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6
|
|||||||
github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/golang/snappy v0.0.5-0.20231225225746-43d5d4cd4e0e h1:4bw4WeyTYPp0smaXiJZCNnLrvVBqirQVreixayXezGc=
|
github.com/golang/snappy v0.0.5-0.20231225225746-43d5d4cd4e0e h1:4bw4WeyTYPp0smaXiJZCNnLrvVBqirQVreixayXezGc=
|
||||||
github.com/golang/snappy v0.0.5-0.20231225225746-43d5d4cd4e0e/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.5-0.20231225225746-43d5d4cd4e0e/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/gomarkdown/markdown v0.0.0-20250311123330-531bef5e742b h1:EY/KpStFl60qA17CptGXhwfZ+k1sFNJIUNR8DdbcuUk=
|
github.com/gomarkdown/markdown v0.0.0-20260411013819-759bbc3e3207 h1:p7t34F7K4OCRQblcDhNJnP46Uaarz3z2cLcvOZYxWn8=
|
||||||
github.com/gomarkdown/markdown v0.0.0-20250311123330-531bef5e742b/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA=
|
github.com/gomarkdown/markdown v0.0.0-20260411013819-759bbc3e3207/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA=
|
||||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer results.Body.Close()
|
||||||
if results.StatusCode != 200 {
|
if results.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
|
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"archive/tar"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/klauspost/compress/zip"
|
||||||
"github.com/mholt/archiver/v3"
|
"github.com/mholt/archiver/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,7 +58,15 @@ func ExtractArchive(archive, dst string) error {
|
|||||||
v.Tar = mytar
|
v.Tar = mytar
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extractRoot, err := filepath.Abs(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = archiver.Walk(archive, func(f archiver.File) error {
|
err = archiver.Walk(archive, func(f archiver.File) error {
|
||||||
|
if err := validateArchiveMemberPath(extractRoot, archiveMemberName(f)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if f.FileInfo.Mode()&os.ModeSymlink != 0 {
|
if f.FileInfo.Mode()&os.ModeSymlink != 0 {
|
||||||
return fmt.Errorf("archive contains a symlink")
|
return fmt.Errorf("archive contains a symlink")
|
||||||
}
|
}
|
||||||
@@ -67,3 +79,41 @@ func ExtractArchive(archive, dst string) error {
|
|||||||
|
|
||||||
return un.Unarchive(archive, dst)
|
return un.Unarchive(archive, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func archiveMemberName(f archiver.File) string {
|
||||||
|
switch h := f.Header.(type) {
|
||||||
|
case tar.Header:
|
||||||
|
return h.Name
|
||||||
|
case *tar.Header:
|
||||||
|
return h.Name
|
||||||
|
case zip.FileHeader:
|
||||||
|
return h.Name
|
||||||
|
case *zip.FileHeader:
|
||||||
|
return h.Name
|
||||||
|
default:
|
||||||
|
return f.Name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateArchiveMemberPath(root, name string) error {
|
||||||
|
if name == "" {
|
||||||
|
return fmt.Errorf("archive contains an empty path")
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedName := filepath.FromSlash(strings.ReplaceAll(name, "\\", "/"))
|
||||||
|
cleanedName := filepath.Clean(normalizedName)
|
||||||
|
if filepath.IsAbs(cleanedName) || cleanedName == ".." || strings.HasPrefix(cleanedName, ".."+string(os.PathSeparator)) {
|
||||||
|
return fmt.Errorf("archive contains an unsafe path: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetPath := filepath.Join(root, cleanedName)
|
||||||
|
relativePath, err := filepath.Rel(root, targetPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if relativePath == ".." || strings.HasPrefix(relativePath, ".."+string(os.PathSeparator)) || filepath.IsAbs(relativePath) {
|
||||||
|
return fmt.Errorf("archive contains an unsafe path: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
128
pkg/utils/untar_test.go
Normal file
128
pkg/utils/untar_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package utils_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"archive/zip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
. "github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("utils/archive tests", func() {
|
||||||
|
It("extracts regular nested zip members", func() {
|
||||||
|
tmpDir := GinkgoT().TempDir()
|
||||||
|
archivePath := filepath.Join(tmpDir, "model.zip")
|
||||||
|
extractPath := filepath.Join(tmpDir, "models")
|
||||||
|
|
||||||
|
Expect(writeZipArchive(archivePath, map[string]string{
|
||||||
|
"nested/model.yaml": "name: test",
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
Expect(ExtractArchive(archivePath, extractPath)).To(Succeed())
|
||||||
|
|
||||||
|
extracted, err := os.ReadFile(filepath.Join(extractPath, "nested", "model.yaml"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(string(extracted)).To(Equal("name: test"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects zip members that escape the destination", func() {
|
||||||
|
tmpDir := GinkgoT().TempDir()
|
||||||
|
archivePath := filepath.Join(tmpDir, "model.zip")
|
||||||
|
extractPath := filepath.Join(tmpDir, "models")
|
||||||
|
|
||||||
|
Expect(writeZipArchive(archivePath, map[string]string{
|
||||||
|
"../escaped.txt": "escaped",
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
err := ExtractArchive(archivePath, extractPath)
|
||||||
|
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("unsafe path"))
|
||||||
|
Expect(filepath.Join(tmpDir, "escaped.txt")).ToNot(BeAnExistingFile())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects tar members that escape the destination", func() {
|
||||||
|
tmpDir := GinkgoT().TempDir()
|
||||||
|
archivePath := filepath.Join(tmpDir, "model.tar")
|
||||||
|
extractPath := filepath.Join(tmpDir, "models")
|
||||||
|
|
||||||
|
Expect(writeTarArchive(archivePath, map[string]string{
|
||||||
|
"../escaped.txt": "escaped",
|
||||||
|
})).To(Succeed())
|
||||||
|
|
||||||
|
err := ExtractArchive(archivePath, extractPath)
|
||||||
|
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("unsafe path"))
|
||||||
|
Expect(filepath.Join(tmpDir, "escaped.txt")).ToNot(BeAnExistingFile())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
func writeZipArchive(path string, files map[string]string) (err error) {
|
||||||
|
out, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := out.Close(); err == nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
writer := zip.NewWriter(out)
|
||||||
|
defer func() {
|
||||||
|
if closeErr := writer.Close(); err == nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for name, contents := range files {
|
||||||
|
fileWriter, err := writer.Create(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := fileWriter.Write([]byte(contents)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTarArchive(path string, files map[string]string) (err error) {
|
||||||
|
out, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := out.Close(); err == nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
writer := tar.NewWriter(out)
|
||||||
|
defer func() {
|
||||||
|
if closeErr := writer.Close(); err == nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for name, contents := range files {
|
||||||
|
data := []byte(contents)
|
||||||
|
if err := writer.WriteHeader(&tar.Header{
|
||||||
|
Name: name,
|
||||||
|
Mode: 0o600,
|
||||||
|
Size: int64(len(data)),
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := writer.Write(data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -5347,6 +5347,14 @@ const docTemplate = `{
|
|||||||
"stream": {
|
"stream": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
|
"stream_options": {
|
||||||
|
"description": "StreamOptions opts into OpenAI streaming extensions, e.g. include_usage.",
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/definitions/schema.StreamOptions"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"temperature": {
|
"temperature": {
|
||||||
"type": "number"
|
"type": "number"
|
||||||
},
|
},
|
||||||
@@ -5412,7 +5420,12 @@ const docTemplate = `{
|
|||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"usage": {
|
"usage": {
|
||||||
"$ref": "#/definitions/schema.OpenAIUsage"
|
"description": "Usage is intentionally a pointer with omitempty: per the OpenAI\nchat-completion streaming spec, intermediate chunks must not carry\na ` + "`" + `usage` + "`" + ` field. Marshalling a value-typed usage would emit\n` + "`" + `\"usage\":{\"prompt_tokens\":0,...}` + "`" + ` on every chunk and break\nOpenAI-SDK consumers that filter on a truthy ` + "`" + `result.usage` + "`" + `\n(continuedev/continue, Kilo Code, Roo Code, etc.).",
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/definitions/schema.OpenAIUsage"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -5578,6 +5591,14 @@ const docTemplate = `{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"schema.StreamOptions": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"include_usage": {
|
||||||
|
"type": "boolean"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"schema.SysInfoModel": {
|
"schema.SysInfoModel": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|||||||
@@ -5344,6 +5344,14 @@
|
|||||||
"stream": {
|
"stream": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
|
"stream_options": {
|
||||||
|
"description": "StreamOptions opts into OpenAI streaming extensions, e.g. include_usage.",
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/definitions/schema.StreamOptions"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"temperature": {
|
"temperature": {
|
||||||
"type": "number"
|
"type": "number"
|
||||||
},
|
},
|
||||||
@@ -5409,7 +5417,12 @@
|
|||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"usage": {
|
"usage": {
|
||||||
"$ref": "#/definitions/schema.OpenAIUsage"
|
"description": "Usage is intentionally a pointer with omitempty: per the OpenAI\nchat-completion streaming spec, intermediate chunks must not carry\na `usage` field. Marshalling a value-typed usage would emit\n`\"usage\":{\"prompt_tokens\":0,...}` on every chunk and break\nOpenAI-SDK consumers that filter on a truthy `result.usage`\n(continuedev/continue, Kilo Code, Roo Code, etc.).",
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/definitions/schema.OpenAIUsage"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -5575,6 +5588,14 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"schema.StreamOptions": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"include_usage": {
|
||||||
|
"type": "boolean"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"schema.SysInfoModel": {
|
"schema.SysInfoModel": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|||||||
@@ -1650,6 +1650,10 @@ definitions:
|
|||||||
stop: {}
|
stop: {}
|
||||||
stream:
|
stream:
|
||||||
type: boolean
|
type: boolean
|
||||||
|
stream_options:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/definitions/schema.StreamOptions'
|
||||||
|
description: StreamOptions opts into OpenAI streaming extensions, e.g. include_usage.
|
||||||
temperature:
|
temperature:
|
||||||
type: number
|
type: number
|
||||||
tfz:
|
tfz:
|
||||||
@@ -1698,7 +1702,15 @@ definitions:
|
|||||||
object:
|
object:
|
||||||
type: string
|
type: string
|
||||||
usage:
|
usage:
|
||||||
$ref: '#/definitions/schema.OpenAIUsage'
|
allOf:
|
||||||
|
- $ref: '#/definitions/schema.OpenAIUsage'
|
||||||
|
description: |-
|
||||||
|
Usage is intentionally a pointer with omitempty: per the OpenAI
|
||||||
|
chat-completion streaming spec, intermediate chunks must not carry
|
||||||
|
a `usage` field. Marshalling a value-typed usage would emit
|
||||||
|
`"usage":{"prompt_tokens":0,...}` on every chunk and break
|
||||||
|
OpenAI-SDK consumers that filter on a truthy `result.usage`
|
||||||
|
(continuedev/continue, Kilo Code, Roo Code, etc.).
|
||||||
type: object
|
type: object
|
||||||
schema.OpenAIUsage:
|
schema.OpenAIUsage:
|
||||||
properties:
|
properties:
|
||||||
@@ -1813,6 +1825,11 @@ definitions:
|
|||||||
$ref: '#/definitions/schema.NodeData'
|
$ref: '#/definitions/schema.NodeData'
|
||||||
type: array
|
type: array
|
||||||
type: object
|
type: object
|
||||||
|
schema.StreamOptions:
|
||||||
|
properties:
|
||||||
|
include_usage:
|
||||||
|
type: boolean
|
||||||
|
type: object
|
||||||
schema.SysInfoModel:
|
schema.SysInfoModel:
|
||||||
properties:
|
properties:
|
||||||
id:
|
id:
|
||||||
|
|||||||
Reference in New Issue
Block a user