Compare commits

..

2 Commits

Author SHA1 Message Date
Roy Han
c79fd5c168 Reincluding Numbers 2024-05-29 12:22:36 -07:00
Roy Han
73fb9ea36e Draft for Multi-Language Modelfile Creation 2024-05-29 11:51:57 -07:00
24 changed files with 399 additions and 903 deletions

View File

@@ -34,13 +34,13 @@ jobs:
git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
}
{
echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE=$(changed llm/)
echo GENERATE_CUDA=$(changed llm/)
echo GENERATE_ROCM=$(changed llm/)
} >>$GITHUB_OUTPUT
generate:
@@ -287,8 +287,6 @@ jobs:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: '1'
OLLAMA_CPU_TARGET: 'static'
OLLAMA_SKIP_CPU_GENERATE: '1'
OLLAMA_SKIP_METAL_GENERATE: '1'
steps:
- uses: actions/checkout@v4
with:

View File

@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
write-host ""
write-host "Run your first model:"
write-host ""
write-host "`tollama run llama3"
write-host "`tollama run llama2"
write-host ""

View File

@@ -755,11 +755,7 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
}
// backtrack the length of the last word and clear to the end of the line
a := runewidth.StringWidth(state.wordBuffer)
if a > 0 {
fmt.Printf("\x1b[%dD", a)
}
fmt.Printf("\x1b[K\n")
fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
fmt.Printf("%s%c", state.wordBuffer, ch)
chWidth := runewidth.RuneWidth(ch)
@@ -1255,9 +1251,6 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
})
default:
appendEnvDocs(cmd, envs)

View File

@@ -76,7 +76,6 @@ Make sure you've set up the container runtime first as described in [docker.md](
Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
- Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
- Try rebooting

View File

@@ -51,16 +51,16 @@ func AsMap() map[string]EnvVar {
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_ORIGINS", LLMLibrary, ""},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, ""},
"OLLAMA_MODELS": {"OLLAMA_MODELS", "", "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, ""},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
}
}

View File

@@ -2,41 +2,32 @@ package format
import (
"fmt"
"math"
)
const (
Thousand = 1000
Million = Thousand * 1000
Billion = Million * 1000
Trillion = Billion * 1000
)
func HumanNumber(b uint64) string {
switch {
case b >= Trillion:
number := float64(b) / Trillion
return fmt.Sprintf("%sT", DecimalPlace(number))
case b >= Billion:
number := float64(b) / Billion
return fmt.Sprintf("%sB", DecimalPlace(number))
if number == math.Floor(number) {
return fmt.Sprintf("%.0fB", number) // no decimals if whole number
}
return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
case b >= Million:
number := float64(b) / Million
return fmt.Sprintf("%sM", DecimalPlace(number))
if number == math.Floor(number) {
return fmt.Sprintf("%.0fM", number) // no decimals if whole number
}
return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
case b >= Thousand:
number := float64(b) / Thousand
return fmt.Sprintf("%sK", DecimalPlace(number))
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
default:
return fmt.Sprintf("%d", b)
}
}
func DecimalPlace(number float64) string {
switch {
case number >= 100:
return fmt.Sprintf("%.0f", number)
case number >= 10:
return fmt.Sprintf("%.1f", number)
default:
return fmt.Sprintf("%.2f", number)
}
}

View File

@@ -13,15 +13,14 @@ func TestHumanNumber(t *testing.T) {
testCases := []testCase{
{0, "0"},
{1000000, "1.00M"},
{1000000, "1M"},
{125000000, "125M"},
{500500000, "500M"},
{500550000, "501M"},
{1000000000, "1.00B"},
{2800000000, "2.80B"},
{2850000000, "2.85B"},
{28550000000, "28.6B"},
{1000000000000, "1.00T"},
{500500000, "500.50M"},
{500550000, "500.55M"},
{1000000000, "1B"},
{2800000000, "2.8B"},
{2850000000, "2.9B"},
{1000000000000, "1000B"},
}
for _, tc := range testCases {

View File

@@ -16,12 +16,13 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/envconfig"
)
type handles struct {
@@ -104,6 +105,8 @@ func initGPUHandles() *handles {
var cudartMgmtPatterns []string
var nvcudaMgmtName string
var nvcudaMgmtPatterns []string
var oneapiMgmtName string
var oneapiMgmtPatterns []string
tmpDir, _ := PayloadsDir()
switch runtime.GOOS {
@@ -115,6 +118,8 @@ func initGPUHandles() *handles {
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "nvcuda.dll"
nvcudaMgmtPatterns = NvcudaWindowsGlobs
oneapiMgmtName = "ze_intel_gpu64.dll"
oneapiMgmtPatterns = OneapiWindowsGlobs
case "linux":
cudartMgmtName = "libcudart.so*"
if tmpDir != "" {
@@ -125,6 +130,8 @@ func initGPUHandles() *handles {
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "libcuda.so*"
nvcudaMgmtPatterns = NvcudaLinuxGlobs
oneapiMgmtName = "libze_intel_gpu.so"
oneapiMgmtPatterns = OneapiLinuxGlobs
default:
return gpuHandles
}
@@ -152,6 +159,17 @@ func initGPUHandles() *handles {
}
}
oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
if len(oneapiLibPaths) > 0 {
deviceCount, oneapi, libPath := LoadOneapiMgmt(oneapiLibPaths)
if oneapi != nil {
slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount)
gpuHandles.oneapi = oneapi
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
return gpuHandles
}
@@ -227,6 +245,18 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
resp = append(resp, gpuInfo)
}
if gpuHandles.oneapi != nil {
gpuInfo := GpuInfo{
Library: "oneapi",
}
C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo)
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = strconv.Itoa(i)
resp = append(resp, gpuInfo)
}
}
// Then AMD

View File

@@ -140,6 +140,7 @@ struct server_slot {
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
bool infill = false;
bool embedding = false;
bool has_next_token = true;
bool truncated = false;
@@ -186,6 +187,7 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
n_past_se = 0;
@@ -598,6 +600,16 @@ struct llama_server_context
slot->params.n_predict = slot->n_predict;
}
// infill
if (data.count("input_prefix") != 0)
{
slot->params.input_prefix = data["input_prefix"];
}
else
{
slot->params.input_prefix = "";
}
if (data.count("input_suffix") != 0)
{
slot->params.input_suffix = data["input_suffix"];
@@ -885,6 +897,15 @@ struct llama_server_context
system_need_update = true;
}
void system_prompt_process(const json &sys_props) {
system_prompt = sys_props.value("prompt", "");
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
system_prompt_notify();
}
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
const stop_type type, server_slot &slot)
{
@@ -1242,12 +1263,13 @@ struct llama_server_context
queue_results.send(res);
}
void request_completion(int task_id, json data, bool embedding, int multitask_id)
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{
task_server task;
task.id = task_id;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = TASK_TYPE_COMPLETION;
task.multitask_id = multitask_id;
@@ -1393,8 +1415,8 @@ struct llama_server_context
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id);
// subtasks inherit everything else (infill mode, embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
}
}
@@ -1412,8 +1434,26 @@ struct llama_server_context
break;
}
if (task.data.contains("system_prompt"))
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
system_prompt_process(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (server_slot &slot : slots)
{
slot.cache_tokens.clear();
slot.n_past = 0;
slot.n_past_se = 0;
}
}
slot->reset();
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
@@ -1639,7 +1679,8 @@ struct llama_server_context
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
// empty prompt passed -> release the slot and send empty response
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
// note: infill mode allows empty prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
{
slot.release();
slot.print_timings();
@@ -1656,7 +1697,33 @@ struct llama_server_context
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
if (slot.infill)
{
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
{
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}
else
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
slot.n_prompt_tokens = prompt_tokens.size();
@@ -2063,7 +2130,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n");
}
static void server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context& llama)
{
gpt_params default_params;
server_params default_sparams;
@@ -2478,6 +2546,27 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
}
params.n_predict = std::stoi(argv[i]);
}
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::string systm_content;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(systm_content)
);
llama.system_prompt_process(json::parse(systm_content));
}
else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
}
@@ -2729,7 +2818,7 @@ int main(int argc, char **argv) {
// struct that contains llama context and inference
llama_server_context llama;
server_params_parse(argc, argv, sparams, params);
server_params_parse(argc, argv, sparams, params, llama);
if (params.model_alias == "unknown")
{
@@ -3061,7 +3150,7 @@ int main(int argc, char **argv) {
json data = json::parse(req.body);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, -1);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@@ -3183,7 +3272,7 @@ int main(int argc, char **argv) {
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);

View File

@@ -32,43 +32,42 @@ case "${GOARCH}" in
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_CPU_GENERATE" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
fi
#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;;
"arm64")
@@ -80,15 +79,13 @@ case "${GOARCH}" in
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
fi
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;;
*)
echo "GOARCH must be set"

View File

@@ -1,32 +1,35 @@
From d02a06f3f45a09255ace8684a66590e06ce44605 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 23 May 2024 11:33:20 -0700
Subject: [PATCH] default pretokenizer on unrecognized type
---
llama.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 40d2ec2c..74f3ee9c 100644
index 15c66077..af1aede3 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4642,16 +4642,7 @@ static void llm_load_vocab(
// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
- if (tokenizer_pre.empty()) {
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__);
- LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
+ if (
tokenizer_pre == "default") {
@@ -4504,9 +4504,6 @@ static void llm_load_vocab(
LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
LLAMA_LOG_WARN("%s: \n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
- tokenizer_pre == "default") {
- vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
@@ -4703,7 +4694,8 @@ static void llm_load_vocab(
tokenizer_pre == "smaug-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
tokenizer_pre == "llama3" ||
tokenizer_pre == "llama-v3" ||
@@ -4553,7 +4550,7 @@ static void llm_load_vocab(
tokenizer_pre == "dbrx") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
} else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--
2.45.1

View File

@@ -189,38 +189,35 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--memory-f32")
}
flashAttnEnabled := envconfig.FlashAttention
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
// mmap has issues with partial offloading on metal
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
opts.UseMMap = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
if opts.UseMLock {
params = append(params, "--mlock")
}
if !opts.UseMMap {
params = append(params, "--no-mmap")
}
if opts.UseMLock {
params = append(params, "--mlock")
}
if opts.UseNUMA {
params = append(params, "--numa")
}
flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnEnabled = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet

View File

@@ -771,6 +771,37 @@ func PruneDirectory(path string) error {
return nil
}
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
return err
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})

View File

@@ -88,26 +88,3 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return os.Open(blob)
}
func (l *Layer) Remove() error {
ms, err := Manifests()
if err != nil {
return err
}
for _, m := range ms {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == l.Digest {
// something is using this layer
return nil
}
}
}
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return err
}
return os.Remove(blob)
}

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
@@ -15,10 +14,7 @@ import (
type Manifest struct {
ManifestV2
filepath string
fi os.FileInfo
digest string
Digest string `json:"-"`
}
func (m *Manifest) Size() (size int64) {
@@ -29,28 +25,9 @@ func (m *Manifest) Size() (size int64) {
return
}
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
return PruneDirectory(manifests)
}
func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
func ParseNamedManifest(name model.Name) (*Manifest, error) {
if !name.IsFullyQualified() {
return nil, model.Unqualified(name)
}
manifests, err := GetManifestPath()
@@ -58,30 +35,20 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
}
p := filepath.Join(manifests, n.Filepath())
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
var manifest ManifestV2
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
if err != nil {
return nil, err
}
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
return nil, err
}
return &Manifest{
ManifestV2: m,
filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
ManifestV2: manifest,
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
}
@@ -110,48 +77,3 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
}
func Manifests() (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
}
ms := make(map[model.Name]*Manifest)
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
return ms, nil
}

View File

@@ -1,150 +0,0 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
t.Fatal(err)
}
f, err := os.Create(p)
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}
func TestManifests(t *testing.T) {
cases := map[string]struct {
ps []string
wantValidCount int
wantInvalidCount int
}{
"empty": {},
"single": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
"multiple": {
ps: []string{
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
},
wantValidCount: 15,
},
"hidden": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
filepath.Join("host", "namespace", "model", ".hidden"),
},
wantValidCount: 1,
wantInvalidCount: 1,
},
"subdir": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag", "one"),
filepath.Join("host", "namespace", "model", "tag", "another", "one"),
},
wantInvalidCount: 2,
},
"upper tag": {
ps: []string{
filepath.Join("host", "namespace", "model", "TAG"),
},
wantValidCount: 1,
},
"upper model": {
ps: []string{
filepath.Join("host", "namespace", "MODEL", "tag"),
},
wantValidCount: 1,
},
"upper namespace": {
ps: []string{
filepath.Join("host", "NAMESPACE", "model", "tag"),
},
wantValidCount: 1,
},
"upper host": {
ps: []string{
filepath.Join("HOST", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
}
for n, wants := range cases {
t.Run(n, func(t *testing.T) {
d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d)
for _, p := range wants.ps {
createManifest(t, d, p)
}
ms, err := Manifests()
if err != nil {
t.Fatal(err)
}
var ns []model.Name
for k := range ms {
ns = append(ns, k)
}
var gotValidCount, gotInvalidCount int
for _, p := range wants.ps {
n := model.ParseNameFromFilepath(p)
if n.IsValid() {
gotValidCount++
} else {
gotInvalidCount++
}
if !n.IsValid() && slices.Contains(ns, n) {
t.Errorf("unexpected invalid name: %s", p)
} else if n.IsValid() && !slices.Contains(ns, n) {
t.Errorf("missing valid name: %s", p)
}
}
if gotValidCount != wants.wantValidCount {
t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount)
}
if gotInvalidCount != wants.wantInvalidCount {
t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount)
}
})
}
}

View File

@@ -421,14 +421,13 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return
}
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
@@ -446,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
if err := PullModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -508,21 +507,6 @@ func (s *Server) PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func checkNameExists(name model.Name) error {
names, err := Manifests()
if err != nil {
return err
}
for n := range names {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
return fmt.Errorf("a model with that name already exists")
}
}
return nil
}
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -539,11 +523,6 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return
}
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
@@ -596,31 +575,48 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
}
func (s *Server) DeleteModelHandler(c *gin.Context) {
var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := ParseNamedManifest(n)
if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
manifestsPath, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := m.Remove(); err != nil {
if err := PruneDirectory(manifestsPath); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
}
func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -724,42 +720,72 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
func (s *Server) ListModelsHandler(c *gin.Context) {
ms, err := Manifests()
manifests, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
models := []api.ModelResponse{}
for n, m := range ms {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
rel, err := filepath.Rel(manifests, path)
if err != nil {
return err
}
var cf ConfigV2
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
continue
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
return err
} else if hidden {
return nil
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err)
return nil
}
defer f.Close()
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
return nil
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
},
})
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
return nil
}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
@@ -792,11 +818,6 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
return
}
if err := checkNameExists(dst); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {

View File

@@ -1,160 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
var stream bool = false
func createBinFile(t *testing.T) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
t.Fatal(err)
}
return f.Name()
}
type responseRecorder struct {
*httptest.ResponseRecorder
http.CloseNotifier
}
func NewRecorder() *responseRecorder {
return &responseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}
func (t *responseRecorder) CloseNotify() <-chan bool {
return make(chan bool)
}
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
t.Helper()
w := NewRecorder()
c, _ := gin.CreateTestContext(w)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(body); err != nil {
t.Fatal(err)
}
c.Request = &http.Request{
Body: io.NopCloser(&b),
}
fn(c)
return w.ResponseRecorder
}
func checkFileExists(t *testing.T, p string, expect []string) {
t.Helper()
actual, err := filepath.Glob(p)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(actual, expect) {
t.Fatalf("expected slices to be equal %v", actual)
}
}
func TestCreateFromBin(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}
func TestCreateFromModel(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}

View File

@@ -1,71 +0,0 @@
package server
import (
"fmt"
"net/http"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
)
func TestDelete(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
}

View File

@@ -1,61 +0,0 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"slices"
"testing"
"github.com/ollama/ollama/api"
)
func TestList(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
expectNames := []string{
"mistral:7b-instruct-q4_0",
"zephyr:7b-beta-q5_K_M",
"apple/OpenELM:latest",
"boreas:2b-code-v1.5-q6_K",
"notus:7b-v1-IQ2_S",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
"mynamespace/apeliotes:latest",
"myhost/mynamespace/lips:code",
}
var s Server
for _, n := range expectNames {
createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: n,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
})
}
w := createRequest(t, s.ListModelsHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
var resp api.ListResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if len(resp.Models) != len(expectNames) {
t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
}
actualNames := make([]string, len(resp.Models))
for i, m := range resp.Models {
actualNames[i] = m.Name
}
slices.Sort(actualNames)
slices.Sort(expectNames)
if !slices.Equal(actualNames, expectNames) {
t.Fatalf("expected slices to be equal %v", actualNames)
}
}

View File

@@ -21,28 +21,6 @@ import (
"github.com/ollama/ollama/version"
)
func createTestFile(t *testing.T, name string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
func Test_Routes(t *testing.T) {
type testCase struct {
Name string
@@ -52,6 +30,28 @@ func Test_Routes(t *testing.T) {
Expected func(t *testing.T, resp *http.Response)
}
createTestFile := func(t *testing.T, name string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model")
@@ -237,82 +237,3 @@ func Test_Routes(t *testing.T) {
})
}
}
func TestCase(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
cases := []string{
"mistral",
"llama3:latest",
"library/phi3:q4_0",
"registry.ollama.ai/library/gemma:q5_K_M",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/alice/bob:latest",
}
var s Server
for _, tt := range cases {
t.Run(tt, func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: tt,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 got %d", w.Code)
}
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
if err != nil {
t.Fatal(err)
}
t.Run("create", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: strings.ToUpper(tt),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("pull", func(t *testing.T) {
w := createRequest(t, s.PullModelHandler, api.PullRequest{
Name: strings.ToUpper(tt),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("copy", func(t *testing.T) {
w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
Source: tt,
Destination: strings.ToUpper(tt),
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"log/slog"
"path/filepath"
"strings"
"unicode"
)
// Errors
@@ -251,10 +252,6 @@ func (n Name) DisplayShortest() string {
return sb.String()
}
func IsValidNamespace(namespace string) bool {
return isValidPart(kindNamespace, namespace)
}
// IsValid reports whether all parts of the name are present and valid. The
// digest is a special case, and is checked for validity only if present.
func (n Name) IsValid() bool {
@@ -322,14 +319,14 @@ func isValidPart(kind partKind, s string) bool {
if !isValidLen(kind, s) {
return false
}
for i := range s {
for i, c := range s {
if i == 0 {
if !isAlphanumericOrUnderscore(s[i]) {
if !isAlphanumericOrUnderscore(c) {
return false
}
continue
}
switch s[i] {
switch c {
case '_', '-':
case '.':
if kind == kindNamespace {
@@ -340,7 +337,7 @@ func isValidPart(kind partKind, s string) bool {
return false
}
default:
if !isAlphanumericOrUnderscore(s[i]) {
if !isAlphanumericOrUnderscore(c) {
return false
}
}
@@ -348,8 +345,8 @@ func isValidPart(kind partKind, s string) bool {
return true
}
func isAlphanumericOrUnderscore(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
func isAlphanumericOrUnderscore(c rune) bool {
return unicode.IsLetter(c) || unicode.IsDigit(c) || c == '_'
}
func cutLast(s, sep string) (before, after string, ok bool) {

View File

@@ -385,30 +385,3 @@ func FuzzName(f *testing.F) {
})
}
func TestIsValidNamespace(t *testing.T) {
cases := []struct {
username string
expected bool
}{
{"", false},
{"a", true},
{"a:b", false},
{"a/b", false},
{"a:b/c", false},
{"a/b:c", false},
{"a/b:c", false},
{"a/b:c/d", false},
{"a/b:c/d@e", false},
{"a/b:c/d@sha256-100", false},
{"himynameisjoe", true},
{"himynameisreallyreallyreallyreallylongbutitshouldstillbevalid", true},
}
for _, tt := range cases {
t.Run(tt.username, func(t *testing.T) {
if got := IsValidNamespace(tt.username); got != tt.expected {
t.Errorf("IsValidName(%q) = %v; want %v", tt.username, got, tt.expected)
}
})
}
}