Mirror of https://github.com/ollama/ollama.git, synced 2026-02-23 10:45:08 -05:00.

Compare commits: 43 commits, v0.16.0...jmorganca/

README.md

<p align="center">
  <a href="https://ollama.com">
    <img src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7" alt="ollama" width="200"/>
  </a>
</p>

# Ollama

Start building with open models.

## Download

### macOS

```shell
curl -fsSL https://ollama.com/install.sh | sh
```

or [download manually](https://ollama.com/download/Ollama.dmg)

### Windows

```shell
irm https://ollama.com/install.ps1 | iex
```

or [download manually](https://ollama.com/download/OllamaSetup.exe)

### Linux

```shell
curl -fsSL https://ollama.com/install.sh | sh
```

### Docker

The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
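
For example, to start the server in a container (a minimal CPU-only sketch; see the Docker Hub page for GPU variants):

```shell
# Persist models in the "ollama" volume and expose the default API port
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```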

### Community

- [Discord](https://discord.gg/ollama)
- [𝕏 (Twitter)](https://x.com/ollama)
- [Reddit](https://reddit.com/r/ollama)

## Get started

```
ollama
```

You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw`, and more.

### Coding

To launch a specific integration:

```
ollama launch claude
```

Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).

### AI assistant

Use [OpenClaw](https://docs.ollama.com/integrations/openclaw) to turn Ollama into a personal AI assistant across WhatsApp, Telegram, Slack, Discord, and more:

```
ollama launch openclaw
```

### Chat with a model

Run and chat with [Gemma 3](https://ollama.com/library/gemma3):

```
ollama run gemma3
```

## Model library

See [ollama.com/library](https://ollama.com/library "ollama model library") for the full list of available models. Here are some example models that can be downloaded:

| Model              | Parameters | Size  | Download                         |
| ------------------ | ---------- | ----- | -------------------------------- |
| Gemma 3            | 1B         | 815MB | `ollama run gemma3:1b`           |
| Gemma 3            | 4B         | 3.3GB | `ollama run gemma3`              |
| Gemma 3            | 12B        | 8.1GB | `ollama run gemma3:12b`          |
| Gemma 3            | 27B        | 17GB  | `ollama run gemma3:27b`          |
| QwQ                | 32B        | 20GB  | `ollama run qwq`                 |
| DeepSeek-R1        | 7B         | 4.7GB | `ollama run deepseek-r1`         |
| DeepSeek-R1        | 671B       | 404GB | `ollama run deepseek-r1:671b`    |
| Llama 4            | 109B       | 67GB  | `ollama run llama4:scout`        |
| Llama 4            | 400B       | 245GB | `ollama run llama4:maverick`     |
| Llama 3.3          | 70B        | 43GB  | `ollama run llama3.3`            |
| Llama 3.2          | 3B         | 2.0GB | `ollama run llama3.2`            |
| Llama 3.2          | 1B         | 1.3GB | `ollama run llama3.2:1b`         |
| Llama 3.2 Vision   | 11B        | 7.9GB | `ollama run llama3.2-vision`     |
| Llama 3.2 Vision   | 90B        | 55GB  | `ollama run llama3.2-vision:90b` |
| Llama 3.1          | 8B         | 4.7GB | `ollama run llama3.1`            |
| Llama 3.1          | 405B       | 231GB | `ollama run llama3.1:405b`       |
| Phi 4              | 14B        | 9.1GB | `ollama run phi4`                |
| Phi 4 Mini         | 3.8B       | 2.5GB | `ollama run phi4-mini`           |
| Mistral            | 7B         | 4.1GB | `ollama run mistral`             |
| Moondream 2        | 1.4B       | 829MB | `ollama run moondream`           |
| Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`         |
| Starling           | 7B         | 4.1GB | `ollama run starling-lm`         |
| Code Llama         | 7B         | 3.8GB | `ollama run codellama`           |
| Llama 2 Uncensored | 7B         | 3.8GB | `ollama run llama2-uncensored`   |
| LLaVA              | 7B         | 4.5GB | `ollama run llava`               |
| Granite-3.3        | 8B         | 4.9GB | `ollama run granite3.3`          |

> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

## Customize a model

### Import from GGUF

Ollama supports importing GGUF models via a Modelfile:

1. Create a file named `Modelfile`, with a `FROM` instruction pointing to the local filepath of the model you want to import.

   ```
   FROM ./vicuna-33b.Q4_0.gguf
   ```

2. Create the model in Ollama:

   ```shell
   ollama create example -f Modelfile
   ```

3. Run the model:

   ```shell
   ollama run example
   ```

### Import from Safetensors

See the [guide](https://docs.ollama.com/import) on importing models for more information.

### Customize a prompt

Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model:

```shell
ollama pull llama3.2
```

Create a `Modelfile`:

```
FROM llama3.2

# set the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1

# set the system message
SYSTEM """
You are Mario from Super Mario Bros. Answer as Mario, the assistant, only.
"""
```

Next, create and run the model:

```
ollama create mario -f ./Modelfile
ollama run mario
>>> hi
Hello! It's your friend Mario.
```

For more information on working with a Modelfile, see the [Modelfile](https://docs.ollama.com/modelfile) documentation.
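
To double-check what was created, you can print a local model's Modelfile back; a quick sketch, assuming the `--modelfile` flag of `ollama show`:

```shell
ollama show mario --modelfile
```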

## CLI Reference

### Create a model

`ollama create` is used to create a model from a Modelfile.

```shell
ollama create mymodel -f ./Modelfile
```

### Pull a model

```shell
ollama pull llama3.2
```

> This command can also be used to update a local model. Only the diff will be pulled.

### Remove a model

```shell
ollama rm llama3.2
```

### Copy a model

```shell
ollama cp llama3.2 my-model
```

### Multiline input

For multiline input, you can wrap text with `"""`:

```
>>> """Hello,
... world!
... """
I'm a basic program that prints the famous "Hello, world!" message to the console.
```

### Multimodal models

```
ollama run llava "What's in this image? /Users/jmorgan/Desktop/smile.png"
```

> **Output**: The image features a yellow smiley face, which is likely the central focus of the picture.

### Pass the prompt as an argument

```shell
ollama run llama3.2 "Summarize this file: $(cat README.md)"
```

> **Output**: Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.

### Show model information

```shell
ollama show llama3.2
```

### List models on your computer

```shell
ollama list
```

### List which models are currently loaded

```shell
ollama ps
```

### Stop a model which is currently running

```shell
ollama stop llama3.2
```

### Generate embeddings from the CLI

```shell
ollama run embeddinggemma "Your text to embed"
```

You can also pipe text for scripted workflows:

```shell
echo "Your text to embed" | ollama run embeddinggemma
```
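
Embeddings can also be requested over the REST API; a minimal sketch using the `/api/embed` endpoint, assuming the default server address:

```shell
curl http://localhost:11434/api/embed -d '{
  "model": "embeddinggemma",
  "input": "Your text to embed"
}'
```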

### Start Ollama

`ollama serve` is used when you want to start Ollama without running the desktop application.
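
Once the server is running, you can check that it is reachable; a quick sketch assuming the default address `http://localhost:11434`:

```shell
curl http://localhost:11434/api/version
```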

## Building

See the [developer guide](https://github.com/ollama/ollama/blob/main/docs/development.md).

### Running local builds

After building, start the server:

```shell
./ollama serve
```

Finally, in a separate shell, run a model:

```shell
./ollama run llama3.2
```
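
To confirm which build you are invoking, you can print its version; a small sketch assuming the standard `--version` flag:

```shell
./ollama --version
```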

## Building with MLX (experimental)

First build the MLX libraries:

```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```

When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:

```shell
go build -tags mlx .
```

Finally, start the server:

```
./ollama serve
```

### Building MLX with CUDA

When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:

```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```

See the [quickstart guide](https://docs.ollama.com/quickstart) for more details.

## REST API

Ollama has a REST API for running and managing models.

### Generate a response

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama3.2",
  "prompt": "Why is the sky blue?"
}'
```
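
By default, `/api/generate` streams the response as a series of JSON objects. To receive a single JSON reply instead, set `"stream": false`; a sketch of the same request:

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama3.2",
  "prompt": "Why is the sky blue?",
  "stream": false
}'
```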

### Chat with a model

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "gemma3",
  "messages": [{
    "role": "user",
    "content": "Why is the sky blue?"
  }],
  "stream": false
}'
```
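
The `messages` array carries the conversation history, so a multi-turn chat simply appends earlier replies before the new question; a sketch with an illustrative assistant turn:

```shell
curl http://localhost:11434/api/chat -d '{
  "model": "gemma3",
  "messages": [
    { "role": "user", "content": "Why is the sky blue?" },
    { "role": "assistant", "content": "Because air molecules scatter blue light more than red." },
    { "role": "user", "content": "How is that different at sunset?" }
  ],
  "stream": false
}'
```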

See the [API documentation](https://docs.ollama.com/api) for all endpoints.

### Python

```
pip install ollama
```

```python
from ollama import chat

response = chat(model='gemma3', messages=[
  {
    'role': 'user',
    'content': 'Why is the sky blue?',
  },
])
print(response.message.content)
```

### JavaScript

```
npm i ollama
```

```javascript
import ollama from "ollama";

const response = await ollama.chat({
  model: "gemma3",
  messages: [{ role: "user", content: "Why is the sky blue?" }],
});
console.log(response.message.content);
```

## Supported backends

- [llama.cpp](https://github.com/ggml-org/llama.cpp), a project founded by Georgi Gerganov

## Documentation

- [CLI reference](https://docs.ollama.com/cli)
- [REST API reference](https://docs.ollama.com/api)
- [Importing models](https://docs.ollama.com/import)
- [Modelfile reference](https://docs.ollama.com/modelfile)
- [Building from source](https://github.com/ollama/ollama/blob/main/docs/development.md)

## Community Integrations

### Web & Desktop

> Want to add your project? Open a pull request.
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms WebUI (Single user)](https://github.com/ParisNeo/lollms-webui)
- [Lollms (Multi users)](https://github.com/ParisNeo/lollms)
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)
- [Dify.AI](https://github.com/langgenius/dify)
- [MindMac](https://mindmac.app)
- [NextJS Web Interface for Ollama](https://github.com/jakobhoeg/nextjs-ollama-llm-ui)
- [Msty](https://msty.app)
- [Chatbox](https://github.com/Bin-Huang/Chatbox)
- [WinForm Ollama Copilot](https://github.com/tgraupmann/WinForm_Ollama_Copilot)
- [NextChat](https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web) with [Get Started Doc](https://docs.nextchat.dev/models/ollama)
- [Alpaca WebUI](https://github.com/mmo80/alpaca-webui)
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
- [OpenAOE](https://github.com/InternLM/OpenAOE)
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + macOS/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search; uses Ollama for local LLM features and for building agents powered by your screen history)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira integration to generate issues, tasks, and epics)
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira Chrome plugin to easily generate descriptions for tasks)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open source chatbot based on Ollama with knowledge bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple web search with corrective RAG)
- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM application scaffold)
- [chat](https://github.com/swuecho/chat) (Chat web app for teams)
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local chat with multiple PDFs using Ollama and RAG)
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (App to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter web app for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, Ollama support, and multiple large language models)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy-to-use Electron desktop client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two-click install of local AI using Ollama + Files + RAG)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord user app that allows you to interact with Ollama anywhere in Discord)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop chat client implementation with Ollama)
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple, easy-to-use GUI with a sample custom LLM for driver education)
- [OpenGPA](https://opengpa.org) (Open-source, offline-first enterprise agentic application)
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
- [Sidellama](https://github.com/gyopak/sidellama) (Browser-based LLM client)
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
- [BoltAI for Mac](https://boltai.com) (AI chat client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) (Java-based web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j)
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) (macOS application capable of chatting with both Ollama and Apple MLX models)
- [Cline](https://github.com/cline/cline) (Formerly known as Claude Dev; a VS Code extension for multi-file/whole-repo coding)
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop web interface to run crewAI with Ollama)
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based client for Ollama)
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy-focused, 100% local, intuitive all-in-one chat interface)
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI)
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Hugging Face models with RAG and deep research on Mac/Windows/Linux)
- [OrionChat](https://github.com/EliasPereirah/OrionChat) (Web interface for chatting with different AI providers)
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains)
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
- [Promptery](https://github.com/promptery/promptery) (Desktop client for Ollama)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [chat-ollama](https://github.com/annilq/chat-ollama) (A React Native client for Ollama)
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with Ollama in a sidebar)
- [YouLama](https://github.com/tcsenpai/youlama) (Web app to quickly summarize any YouTube video, supporting Invidious as well)
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assistant personalized by what you have seen on your screen, heard, and said in meetings)
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and rate Reddit topics with a weighted summation)
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
- [VT](https://github.com/vinhnx/vt.ai) (A minimal multimodal AI chat app with dynamic conversation routing; supports local models via Ollama)
- [Nosia](https://github.com/nosia-ai/nosia) (Easy-to-install-and-use RAG platform based on Ollama)
- [Witsy](https://github.com/nbonamy/witsy) (An AI desktop application available for Mac/Windows/Linux)
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
- [Ollama Chat WebUI for Docker](https://github.com/oslook/ollama-webui) (Support for local Docker deployment; lightweight Ollama web UI)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, and evaluate models with Ollama support, and use them in your AI applications)
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal web UI for chat and model control)
- [Chipper](https://github.com/TilmanGriesel/chipper) (AI interface for tinkerers: Ollama, Haystack RAG, Python)
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal web app to run Ollama models with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & adaptable RAG chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bot platform with agents and RAG features, supporting multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux server management tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugin integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support)
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API, built with React, TypeScript, and Material-UI)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet feature-rich, cross-platform, built with Flutter and designed for Ollama; try the [web demo](https://hengkysteen.github.io/demo/ollamb/))
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with Ollama integration)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with Ollama, available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast, lightweight, AI-powered desktop translation application written with Rust and Tauri; features real-time translation with OpenAI/Azure/Ollama)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama providing convenient functions such as server launching, management, and configuration)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and chat support via the Ollama API)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama-driven workflows)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner-friendly, open source AI roleplaying app for Windows, macOS, and Linux; search, download, and use models with Ollama all inside the app)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [KDeps](https://github.com/kdeps/kdeps) (An offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL, integrating APIs with Ollama on the backend)
- [Clueless](https://github.com/KashyapTan/clueless) (Open source & local Cluely: a desktop LLM assistant to help you talk to anything on your screen using locally served Ollama models; also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow; create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more, all while keeping your work on your device)
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor-neutral DevOps agent that works with any model and any stack, for teams who just want to ship)

### Code Editors & Development

- [Cline](https://github.com/cline/cline) - VS Code extension for multi-file/whole-repo coding
- [Continue](https://github.com/continuedev/continue) - Open-source AI code assistant for any IDE
- [Void](https://github.com/voideditor/void) - Open source AI code editor, Cursor alternative
- [Copilot for Obsidian](https://github.com/logancyang/obsidian-copilot) - AI assistant for Obsidian
- [twinny](https://github.com/rjmacarthy/twinny) - Copilot and Copilot chat alternative
- [gptel Emacs client](https://github.com/karthink/gptel) - LLM client for Emacs
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) - Use Ollama as GitHub Copilot
- [Obsidian Local GPT](https://github.com/pfrankov/obsidian-local-gpt) - Local AI for Obsidian
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama) - LLM tool for Emacs
- [orbiton](https://github.com/xyproto/orbiton) - Config-free text editor with Ollama tab completion
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) - Sublime Text 4 AI assistant
- [VT Code](https://github.com/vinhnx/vtcode) - Rust-based terminal coding agent with Tree-sitter
- [QodeAssist](https://github.com/Palm1r/QodeAssist) - AI coding assistant for Qt Creator
- [AI Toolkit for VS Code](https://aka.ms/ai-tooklit/ollama-docs) - Microsoft-official VS Code extension
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) - Natural language interface for computers

### Bots & Messaging

- [LangBot](https://github.com/RockChinQ/LangBot) - Multi-platform messaging bots with agents and RAG
- [AstrBot](https://github.com/Soulter/AstrBot/) - Multi-platform chatbot with RAG and plugins
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) - TypeScript Discord bot
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram) - Telegram bot
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) - Telegram bot for roleplay

### Productivity & Apps

- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) - AI collaborative workspace, self-hostable Notion alternative
- [Screenpipe](https://github.com/mediar-ai/screenpipe) - 24/7 screen and mic recording with AI-powered search
- [Vibe](https://github.com/thewh1teagle/vibe) - Transcribe and analyze meetings
- [Page Assist](https://github.com/n4ze3m/page-assist) - Chrome extension for AI-powered browsing
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) - Private, on-device browser AI assistant
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server) - Security proxy for Ollama
- [1Panel](https://github.com/1Panel-dev/1Panel/) - Web-based Linux server management
- [Writeopia](https://github.com/Writeopia/Writeopia) - Text editor with Ollama integration
- [QA-Pilot](https://github.com/reid41/QA-Pilot) - GitHub code repository understanding
- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama) - Ollama in Raycast
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) - Painting app with AI integrations
- [Serene Pub](https://github.com/doolijb/serene-pub) - AI roleplaying app
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) - Document management with Ollama workflows
- [TagSpaces](https://www.tagspaces.org) - File management with [AI tagging](https://docs.tagspaces.org/ai/)

### Observability & Monitoring

- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) - Debug, evaluate, and monitor LLM applications
- [OpenLIT](https://github.com/openlit/openlit) - OpenTelemetry-native monitoring for Ollama and GPUs
- [Lunary](https://lunary.ai/docs/integrations/ollama) - LLM observability with analytics and PII masking
- [Langfuse](https://langfuse.com/docs/integrations/ollama) - Open source LLM observability
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) - AI observability and evaluation for agents
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) - Open source LLM observability

### Infrastructure & Deployment

#### Cloud

- [Google Cloud](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama)
- [Fly.io](https://fly.io/docs/python/do-more/add-ollama/)
- [Koyeb](https://www.koyeb.com/deploy/ollama)
- [Harbor](https://github.com/av/harbor) - Containerized LLM toolkit with Ollama as default backend

### Tutorial

- [handy-ollama](https://github.com/datawhalechina/handy-ollama) (Chinese tutorial for Ollama by [Datawhale](https://github.com/datawhalechina), China's largest open source AI learning community)

### Terminal

- [oterm](https://github.com/ggozad/oterm)
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
- [Emacs client](https://github.com/zweifisch/ollama)
- [neollama](https://github.com/paradoxical-dev/neollama) (UI client for interacting with models from within Neovim)
- [gen.nvim](https://github.com/David-Kunz/gen.nvim)
- [ollama.nvim](https://github.com/nomnivore/ollama.nvim)
- [ollero.nvim](https://github.com/marco-souza/ollero.nvim)
- [ollama-chat.nvim](https://github.com/gerazov/ollama-chat.nvim)
- [ogpt.nvim](https://github.com/huynle/ogpt.nvim)
- [gptel Emacs client](https://github.com/karthink/gptel)
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
- [cmdh](https://github.com/pgibler/cmdh)
- [ooo](https://github.com/npahlfer/ooo)
- [shell-pilot](https://github.com/reid41/shell-pilot) (Interact with models via pure shell scripts on Linux or macOS)
- [tenere](https://github.com/pythops/tenere)
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/)
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
- [ShellOracle](https://github.com/djcopley/ShellOracle)
- [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
- [gollama](https://github.com/sammcj/gollama)
- [ParLlama](https://github.com/paulrobello/parllama)
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
- [Ollama Mixture of Experts (MOE) in 50 lines of code](https://github.com/rapidarchitect/ollama_moe)
- [vim-intelligence-bridge](https://github.com/pepo-ec/vim-intelligence-bridge) (Simple interaction of "Ollama" with the Vim editor)
- [x-cmd ollama](https://x-cmd.com/mod/ollama)
- [bb7](https://github.com/drunkwcodes/bb7)
- [SwollamaCLI](https://github.com/marcusziade/Swollama) (Bundled with the Swollama Swift package; [demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage))
- [aichat](https://github.com/sigoden/aichat) (All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more)
- [PowershAI](https://github.com/rrg92/powershai) (PowerShell module that brings AI to the terminal on Windows, including support for Ollama)
- [DeepShell](https://github.com/Abyss-c0re/deepshell) (Self-hosted AI assistant with an interactive shell and file/folder analysis)
- [orbiton](https://github.com/xyproto/orbiton) (Configuration-free text editor and IDE with support for tab completion with Ollama)
- [orca-cli](https://github.com/molbal/orca-cli) (Ollama Registry CLI application: browse, pull, and download models from the Ollama Registry in your terminal)
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) (Importing GGUF to Ollama made easy, multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) (AWS Strands Agents with Ollama examples)
- [ollama-multirun](https://github.com/attogram/ollama-multirun) (A bash shell script to run a single prompt against any or all of your locally installed Ollama models, saving the output and performance statistics as easily navigable web pages; [demo](https://attogram.github.io/ai_test_zone/))
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) (Bash scripts to chat with tool-using models; add new tools to your shed with ease; runs on Ollama)
- [hle-eval-ollama](https://github.com/mags0ft/hle-eval-ollama) (Runs benchmarks like "Humanity's Last Exam" (HLE) on your favorite local Ollama models and evaluates the quality of their responses)
- [VT Code](https://github.com/vinhnx/vtcode) (Rust-based terminal coding agent with semantic code intelligence via Tree-sitter; Ollama integration for running local/cloud models with configurable endpoints)

### Apple Vision Pro

- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
- [Enchanted](https://github.com/AugustDev/enchanted)

### Database

- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (create and search embeddings from Ollama models using pgvector)
  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)

### Package managers

- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
- [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
- [Homebrew](https://formulae.brew.sh/formula/ollama)
- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
- [Flox](https://flox.dev/blog/ollama-part-one)

### Libraries

- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
- [Strands Agents](https://github.com/strands-agents/sdk-python) (A model-driven approach to building AI agents in just a few lines of code)
- [Spring AI](https://github.com/spring-projects/spring-ai) with [reference](https://docs.spring.io/spring-ai/reference/api/chat/ollama-chat.html) and [example](https://github.com/tzolov/ollama-tools)
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [LangChain for .NET](https://github.com/tryAGI/LangChain) with [example](https://github.com/tryAGI/LangChain/blob/main/examples/LangChain.Samples.OpenAI/Program.cs)
- [LLPhant](https://github.com/theodo-group/LLPhant?tab=readme-ov-file#ollama)
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/llm/ollama/) and [LlamaIndexTS](https://ts.llamaindex.ai/modules/llms/available_llms/ollama)
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
- [ModelFusion TypeScript Library](https://modelfusion.dev/integration/model-provider/ollama)
- [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
- [Ollama for Dart](https://github.com/breitburg/dart-ollama)
- [Ollama for Laravel](https://github.com/cloudstudio/ollama-laravel)
- [LangChainDart](https://github.com/davidmigloz/langchain_dart)
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
- [Elixir LangChain](https://github.com/brainlid/langchain)
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
- [Testcontainers](https://testcontainers.com/modules/ollama/)
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
- [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
- [LlamaScript](https://github.com/Project-Llama/llamascript)
- [llm-axe](https://github.com/emirsahin1/llm-axe) (Python toolkit for building LLM-powered apps)
- [Gollm](https://docs.gollm.co/examples/ollama-example)
- [Gollama for Golang](https://github.com/jonathanhecl/gollama)
- [Ollamaclient for Golang](https://github.com/xyproto/ollamaclient)
- [High-level function abstraction in Go](https://gitlab.com/tozd/go/fun)
- [Ollama PHP](https://github.com/ArdaGnsrn/ollama-php)
- [Agents-Flex for Java](https://github.com/agents-flex/agents-flex) with [example](https://github.com/agents-flex/agents-flex/tree/main/agents-flex-llm/agents-flex-llm-ollama/src/test/java/com/agentsflex/llm/ollama)
- [Parakeet](https://github.com/parakeet-nest/parakeet) (GoLang library made to simplify the development of small generative AI applications with Ollama)
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A TypeScript/JavaScript library allowing access to different LLMs in a unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) (Python package for generating custom wikis for your research topic)
- [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different LLM providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) (A multimodal (text/audio/image) chatbot)
- [Ollama Bash Lib](https://github.com/attogram/ollama-bash-lib) (A Bash library for Ollama: run LLM prompts straight from your shell, and more)
|
||||
|
||||
### Mobile

- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast cross-platform AI chat app with native UI for Android, iOS, and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support)

### Extensions & Plugins

- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama)
- [Discollama](https://github.com/mxyng/discollama) (Discord bot inside the Ollama discord channel)
- [Continue](https://github.com/continuedev/continue)
- [Vibe](https://github.com/thewh1teagle/vibe) (Transcribe and analyze meetings with Ollama)
- [Obsidian Ollama plugin](https://github.com/hinterdupfinger/obsidian-ollama)
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
- [NotesOllama](https://github.com/andersrex/notesollama) (Apple Notes Ollama plugin)
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot, like GitHub Copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome extension)
- [Plasmoid Ollama Control](https://github.com/imoize/plasmoid-ollamacontrol) (KDE Plasma extension that allows you to quickly manage/control Ollama models)
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama as the backend)
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord bot with tuning documentation)
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [integration tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) (Chat/moderation bot written in Python that uses Ollama to create personalities)
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install the Ollama client & models on any OS for apps that depend on the Ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy a ready-to-use Ollama service on AWS, together with its Open WebUI front end)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot-in-Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) (A Chrome extension that helps you write emails, correct grammar, and translate into any language)
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (Telegram bot, primarily for roleplay; Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration, etc.)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP server that allows LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model-deployment-aware routing which auto-manages multiple Ollama instances in a given network)

### Supported backends

- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.

### Observability

- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik offers a native integration with Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is an open-source LLM observability platform. It provides enterprise-grade features such as real-time analytics, prompt template management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open-source LLM observability platform that enables teams to collaboratively monitor, evaluate, and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open-source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

### Security

- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix)

@@ -1,17 +1,25 @@
package anthropic

import (
    "bytes"
    "context"
    "crypto/rand"
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "net/http"
    "net/url"
    "strconv"
    "strings"
    "time"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/auth"
    internalcloud "github.com/ollama/ollama/internal/cloud"
    "github.com/ollama/ollama/logutil"
)

// Error types matching Anthropic API

@@ -82,22 +90,25 @@ type MessageParam struct {
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
    Type string `json:"type"` // text, image, tool_use, tool_result, thinking
    Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result

    // For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
    Text *string `json:"text,omitempty"`

    // For text blocks with citations
    Citations []Citation `json:"citations,omitempty"`

    // For image blocks
    Source *ImageSource `json:"source,omitempty"`

    // For tool_use blocks
    // For tool_use and server_tool_use blocks
    ID    string `json:"id,omitempty"`
    Name  string `json:"name,omitempty"`
    Input any    `json:"input,omitempty"`

    // For tool_result blocks
    // For tool_result and web_search_tool_result blocks
    ToolUseID string `json:"tool_use_id,omitempty"`
    Content   any    `json:"content,omitempty"` // string or []ContentBlock
    Content   any    `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError
    IsError   bool   `json:"is_error,omitempty"`

    // For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
@@ -105,6 +116,30 @@ type ContentBlock struct {
    Signature string `json:"signature,omitempty"`
}

// Citation represents a citation in a text block
type Citation struct {
    Type           string `json:"type"` // "web_search_result_location"
    URL            string `json:"url"`
    Title          string `json:"title"`
    EncryptedIndex string `json:"encrypted_index,omitempty"`
    CitedText      string `json:"cited_text,omitempty"`
}

// WebSearchResult represents a single web search result
type WebSearchResult struct {
    Type             string `json:"type"` // "web_search_result"
    URL              string `json:"url"`
    Title            string `json:"title"`
    EncryptedContent string `json:"encrypted_content,omitempty"`
    PageAge          string `json:"page_age,omitempty"`
}

// WebSearchToolResultError represents an error from web search
type WebSearchToolResultError struct {
    Type      string `json:"type"` // "web_search_tool_result_error"
    ErrorCode string `json:"error_code"`
}

// ImageSource represents the source of an image
type ImageSource struct {
    Type string `json:"type"` // "base64" or "url"
@@ -115,10 +150,13 @@ type ImageSource struct {

// Tool represents a tool definition
type Tool struct {
    Type string `json:"type,omitempty"` // "custom" for user-defined tools
    Type        string          `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search
    Name        string          `json:"name"`
    Description string          `json:"description,omitempty"`
    InputSchema json.RawMessage `json:"input_schema,omitempty"`

    // Web search specific fields
    MaxUses int `json:"max_uses,omitempty"`
}

// ToolChoice controls how the model uses tools
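The pointer-versus-value distinction in `ContentBlock` is load-bearing: with `omitempty`, a plain `string` field disappears from the JSON whenever it is empty, while a `*string` field is emitted whenever the pointer is non-nil, even when it points at `""`. A minimal standalone sketch of that behavior (the struct and field names here are illustrative, not part of the package):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// block is a toy struct demonstrating the omitempty behavior that
// ContentBlock relies on for SDK streaming accumulation.
type block struct {
	Value  string  `json:"value,omitempty"`   // dropped when ""
	PtrVal *string `json:"ptr_val,omitempty"` // kept whenever the pointer is non-nil
}

func main() {
	empty := ""
	a, _ := json.Marshal(block{})               // both fields omitted
	b, _ := json.Marshal(block{PtrVal: &empty}) // ptr_val present as ""
	fmt.Println(string(a)) // {}
	fmt.Println(string(b)) // {"ptr_val":""}
}
```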
@@ -233,6 +271,8 @@ type StreamErrorEvent struct {

// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
    logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r))

    var messages []api.Message

    if r.System != nil {
@@ -259,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
        }
    }

    for _, msg := range r.Messages {
    for i, msg := range r.Messages {
        converted, err := convertMessage(msg)
        if err != nil {
            logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err)
            return nil, err
        }
        messages = append(messages, converted...)
@@ -288,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
    }

    var tools api.Tools
    hasBuiltinWebSearch := false
    for _, t := range r.Tools {
        tool, err := convertTool(t)
        if strings.HasPrefix(t.Type, "web_search") {
            hasBuiltinWebSearch = true
            break
        }
    }

    for _, t := range r.Tools {
        // Anthropic built-in web_search maps to Ollama function name "web_search".
        // If a user-defined tool also uses that name in the same request, drop the
        // user-defined one to avoid ambiguous tool-call routing.
        if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" {
            logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t))
            continue
        }

        tool, _, err := convertTool(t)
        if err != nil {
            return nil, err
        }
@@ -302,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
    }

    stream := r.Stream

    return &api.ChatRequest{
    convertedRequest := &api.ChatRequest{
        Model:    r.Model,
        Messages: messages,
        Options:  options,
        Stream:   &stream,
        Tools:    tools,
        Think:    think,
    }, nil
    }
    logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest))

    return convertedRequest, nil
}
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
@@ -328,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
    var toolCalls []api.ToolCall
    var thinking string
    var toolResults []api.Message
    textBlocks := 0
    imageBlocks := 0
    toolUseBlocks := 0
    toolResultBlocks := 0
    serverToolUseBlocks := 0
    webSearchToolResultBlocks := 0
    thinkingBlocks := 0
    unknownBlocks := 0

    for _, block := range content {
        blockMap, ok := block.(map[string]any)
        if !ok {
            logutil.Trace("anthropic: invalid content block format", "role", role)
            return nil, errors.New("invalid content block format")
        }

@@ -339,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {

        switch blockType {
        case "text":
            textBlocks++
            if text, ok := blockMap["text"].(string); ok {
                textContent.WriteString(text)
            }

        case "image":
            imageBlocks++
            source, ok := blockMap["source"].(map[string]any)
            if !ok {
                logutil.Trace("anthropic: invalid image source", "role", role)
                return nil, errors.New("invalid image source")
            }

@@ -354,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
                data, _ := source["data"].(string)
                decoded, err := base64.StdEncoding.DecodeString(data)
                if err != nil {
                    logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
                    return nil, fmt.Errorf("invalid base64 image data: %w", err)
                }
                images = append(images, decoded)
            } else {
                logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
                return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
            }
            // URL images would need to be fetched - skip for now

        case "tool_use":
            toolUseBlocks++
            id, ok := blockMap["id"].(string)
            if !ok {
                logutil.Trace("anthropic: tool_use block missing id", "role", role)
                return nil, errors.New("tool_use block missing required 'id' field")
            }
            name, ok := blockMap["name"].(string)
            if !ok {
                logutil.Trace("anthropic: tool_use block missing name", "role", role)
                return nil, errors.New("tool_use block missing required 'name' field")
            }
            tc := api.ToolCall{
@@ -383,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
            toolCalls = append(toolCalls, tc)

        case "tool_result":
            toolResultBlocks++
            toolUseID, _ := blockMap["tool_use_id"].(string)
            var resultContent string

@@ -408,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
            })

        case "thinking":
            thinkingBlocks++
            if t, ok := blockMap["thinking"].(string); ok {
                thinking = t
            }

        case "server_tool_use":
            serverToolUseBlocks++
            id, _ := blockMap["id"].(string)
            name, _ := blockMap["name"].(string)
            tc := api.ToolCall{
                ID: id,
                Function: api.ToolCallFunction{
                    Name: name,
                },
            }
            if input, ok := blockMap["input"].(map[string]any); ok {
                tc.Function.Arguments = mapToArgs(input)
            }
            toolCalls = append(toolCalls, tc)

        case "web_search_tool_result":
            webSearchToolResultBlocks++
            toolUseID, _ := blockMap["tool_use_id"].(string)
            toolResults = append(toolResults, api.Message{
                Role:       "tool",
                Content:    formatWebSearchToolResultContent(blockMap["content"]),
                ToolCallID: toolUseID,
            })
        default:
            unknownBlocks++
        }
    }
@@ -427,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {

        // Add tool results as separate messages
        messages = append(messages, toolResults...)
        logutil.Trace("anthropic: converted block message",
            "role", role,
            "blocks", len(content),
            "text", textBlocks,
            "image", imageBlocks,
            "tool_use", toolUseBlocks,
            "tool_result", toolResultBlocks,
            "server_tool_use", serverToolUseBlocks,
            "web_search_result", webSearchToolResultBlocks,
            "thinking", thinkingBlocks,
            "unknown", unknownBlocks,
            "messages", TraceAPIMessages(messages),
        )

    default:
        return nil, fmt.Errorf("invalid message content type: %T", content)
@@ -435,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
    return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
func formatWebSearchToolResultContent(content any) string {
    switch c := content.(type) {
    case string:
        return c
    case []WebSearchResult:
        var resultContent strings.Builder
        for _, item := range c {
            if item.Type != "web_search_result" {
                continue
            }
            fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL)
        }
        return resultContent.String()
    case []any:
        var resultContent strings.Builder
        for _, item := range c {
            itemMap, ok := item.(map[string]any)
            if !ok {
                continue
            }
            switch itemMap["type"] {
            case "web_search_result":
                title, _ := itemMap["title"].(string)
                url, _ := itemMap["url"].(string)
                fmt.Fprintf(&resultContent, "- %s: %s\n", title, url)
            case "web_search_tool_result_error":
                errorCode, _ := itemMap["error_code"].(string)
                if errorCode == "" {
                    return "web_search_tool_result_error"
                }
                return "web_search_tool_result_error: " + errorCode
            }
        }
        return resultContent.String()
    case map[string]any:
        if c["type"] == "web_search_tool_result_error" {
            errorCode, _ := c["error_code"].(string)
            if errorCode == "" {
                return "web_search_tool_result_error"
            }
            return "web_search_tool_result_error: " + errorCode
        }
        data, err := json.Marshal(c)
        if err != nil {
            return ""
        }
        return string(data)
    case WebSearchToolResultError:
        if c.ErrorCode == "" {
            return "web_search_tool_result_error"
        }
        return "web_search_tool_result_error: " + c.ErrorCode
    default:
        data, err := json.Marshal(c)
        if err != nil {
            return ""
        }
        return string(data)
    }
}
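A quick illustration of that flattening, written as a test in the same package (the helper is unexported, so this only compiles alongside it; the input block is a made-up example):

```go
package anthropic

import "testing"

// Illustrative only: a single web_search_result block flattens to one
// "- Title: URL" line.
func TestFormatWebSearchToolResultContentSketch(t *testing.T) {
	content := []any{
		map[string]any{"type": "web_search_result", "title": "Ollama", "url": "https://ollama.com"},
	}
	if got, want := formatWebSearchToolResultContent(content), "- Ollama: https://ollama.com\n"; got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}
```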
// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool
func convertTool(t Tool) (api.Tool, bool, error) {
    if strings.HasPrefix(t.Type, "web_search") {
        props := api.NewToolPropertiesMap()
        props.Set("query", api.ToolProperty{
            Type:        api.PropertyType{"string"},
            Description: "The search query to look up on the web",
        })
        return api.Tool{
            Type: "function",
            Function: api.ToolFunction{
                Name:        "web_search",
                Description: "Search the web for current information. Use this to find up-to-date information about any topic.",
                Parameters: api.ToolFunctionParameters{
                    Type:       "object",
                    Required:   []string{"query"},
                    Properties: props,
                },
            },
        }, true, nil
    }

    var params api.ToolFunctionParameters
    if len(t.InputSchema) > 0 {
        if err := json.Unmarshal(t.InputSchema, &params); err != nil {
            return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
            logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err)
            return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
        }
    }

@@ -451,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) {
            Description: t.Description,
            Parameters:  params,
        },
    }, nil
    }, false, nil
}
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
@@ -899,3 +1098,113 @@ func countContentBlock(block any) int {

    return total
}
// OllamaWebSearchRequest represents a request to the Ollama web search API
type OllamaWebSearchRequest struct {
    Query      string `json:"query"`
    MaxResults int    `json:"max_results,omitempty"`
}

// OllamaWebSearchResult represents a single search result from Ollama API
type OllamaWebSearchResult struct {
    Title   string `json:"title"`
    URL     string `json:"url"`
    Content string `json:"content"`
}

// OllamaWebSearchResponse represents the response from the Ollama web search API
type OllamaWebSearchResponse struct {
    Results []OllamaWebSearchResult `json:"results"`
}

var WebSearchEndpoint = "https://ollama.com/api/web_search"

func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) {
    if internalcloud.Disabled() {
        logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled")
        return nil, errors.New(internalcloud.DisabledError("web search is unavailable"))
    }

    if maxResults <= 0 {
        maxResults = 5
    }
    if maxResults > 10 {
        maxResults = 10
    }

    reqBody := OllamaWebSearchRequest{
        Query:      query,
        MaxResults: maxResults,
    }

    body, err := json.Marshal(reqBody)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal web search request: %w", err)
    }

    searchURL, err := url.Parse(WebSearchEndpoint)
    if err != nil {
        return nil, fmt.Errorf("failed to parse web search URL: %w", err)
    }
    logutil.TraceContext(ctx, "anthropic: web search request",
        "query", TraceTruncateString(query),
        "max_results", maxResults,
        "url", searchURL.String(),
    )

    q := searchURL.Query()
    q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10))
    searchURL.RawQuery = q.Encode()

    signature := ""
    if strings.EqualFold(searchURL.Hostname(), "ollama.com") {
        challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI())
        signature, err = auth.Sign(ctx, []byte(challenge))
        if err != nil {
            return nil, fmt.Errorf("failed to sign web search request: %w", err)
        }
    }
    logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "")

    req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body))
    if err != nil {
        return nil, fmt.Errorf("failed to create web search request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")
    if signature != "" {
        req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("web search request failed: %w", err)
    }
    defer resp.Body.Close()
    logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode)

    if resp.StatusCode != http.StatusOK {
        respBody, _ := io.ReadAll(resp.Body)
        return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody))
    }

    var searchResp OllamaWebSearchResponse
    if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
        return nil, fmt.Errorf("failed to decode web search response: %w", err)
    }
    logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results))

    return &searchResp, nil
}

func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult {
    var results []WebSearchResult
    for _, r := range ollamaResults.Results {
        results = append(results, WebSearchResult{
            Type:  "web_search_result",
            URL:   r.URL,
            Title: r.Title,
        })
    }
    return results
}
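Taken together, `WebSearch` and `ConvertOllamaToAnthropicResults` form a small client-side pipeline: query the hosted endpoint, then reshape the hits into Anthropic-style result blocks. A hedged usage sketch (it assumes the import path `github.com/ollama/ollama/anthropic` and a signed-in Ollama account, since requests to ollama.com are signed):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/anthropic"
)

func main() {
	// Query the web search endpoint (capped at 3 results), then convert the
	// hits into Anthropic web_search_result blocks.
	resp, err := anthropic.WebSearch(context.Background(), "ollama release notes", 3)
	if err != nil {
		log.Fatal(err)
	}
	for _, r := range anthropic.ConvertOllamaToAnthropicResults(resp) {
		fmt.Printf("%s: %s\n", r.Title, r.URL)
	}
}
```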
@@ -3,6 +3,7 @@ package anthropic
import (
    "encoding/base64"
    "encoding/json"
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
@@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
    }
}
func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) {
    req := MessagesRequest{
        Model:     "test-model",
        MaxTokens: 1024,
        Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
        Tools: []Tool{
            {
                Type: "web_search_20250305",
                Name: "web_search",
            },
            {
                Type:        "custom",
                Name:        "web_search",
                Description: "User-defined web search that should be dropped",
                InputSchema: json.RawMessage(`{"type":"invalid"}`),
            },
            {
                Type:        "custom",
                Name:        "get_weather",
                Description: "Get current weather",
                InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
            },
        },
    }

    result, err := FromMessagesRequest(req)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if len(result.Tools) != 2 {
        t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools))
    }
    if result.Tools[0].Function.Name != "web_search" {
        t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name)
    }
    if result.Tools[1].Function.Name != "get_weather" {
        t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name)
    }
}

func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) {
    req := MessagesRequest{
        Model:     "test-model",
        MaxTokens: 1024,
        Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
        Tools: []Tool{
            {
                Type:        "custom",
                Name:        "web_search",
                Description: "User-defined web search",
                InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`),
            },
        },
    }

    result, err := FromMessagesRequest(req)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if len(result.Tools) != 1 {
        t.Fatalf("expected 1 custom tool, got %d", len(result.Tools))
    }
    if result.Tools[0].Function.Name != "web_search" {
        t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name)
    }
    if result.Tools[0].Function.Description != "User-defined web search" {
        t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description)
    }
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
    req := MessagesRequest{
        Model: "test-model",
@@ -1063,3 +1136,320 @@ func TestEstimateTokens_EmptyContent(t *testing.T) {
        t.Errorf("expected 0 tokens for empty content, got %d", tokens)
    }
}

// Web Search Tests
func TestConvertTool_WebSearch(t *testing.T) {
    tool := Tool{
        Type:    "web_search_20250305",
        Name:    "web_search",
        MaxUses: 5,
    }

    result, isServerTool, err := convertTool(tool)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if !isServerTool {
        t.Error("expected isServerTool to be true for web_search tool")
    }

    if result.Type != "function" {
        t.Errorf("expected type 'function', got %q", result.Type)
    }

    if result.Function.Name != "web_search" {
        t.Errorf("expected name 'web_search', got %q", result.Function.Name)
    }

    if result.Function.Description == "" {
        t.Error("expected non-empty description for web_search tool")
    }

    // Check that query parameter is defined
    if result.Function.Parameters.Properties == nil {
        t.Fatal("expected properties to be defined")
    }

    queryProp, ok := result.Function.Parameters.Properties.Get("query")
    if !ok {
        t.Error("expected 'query' property to be defined")
    }

    if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" {
        t.Errorf("expected query type to be 'string', got %v", queryProp.Type)
    }
}

func TestConvertTool_RegularTool(t *testing.T) {
    tool := Tool{
        Type:        "custom",
        Name:        "get_weather",
        Description: "Get the weather",
        InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
    }

    result, isServerTool, err := convertTool(tool)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if isServerTool {
        t.Error("expected isServerTool to be false for regular tool")
    }

    if result.Function.Name != "get_weather" {
        t.Errorf("expected name 'get_weather', got %q", result.Function.Name)
    }
}

func TestConvertMessage_ServerToolUse(t *testing.T) {
    msg := MessageParam{
        Role: "assistant",
        Content: []any{
            map[string]any{
                "type":  "server_tool_use",
                "id":    "srvtoolu_123",
                "name":  "web_search",
                "input": map[string]any{"query": "test query"},
            },
        },
    }

    messages, err := convertMessage(msg)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if len(messages) != 1 {
        t.Fatalf("expected 1 message, got %d", len(messages))
    }

    if len(messages[0].ToolCalls) != 1 {
        t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls))
    }

    tc := messages[0].ToolCalls[0]
    if tc.ID != "srvtoolu_123" {
        t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID)
    }

    if tc.Function.Name != "web_search" {
        t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name)
    }
}

func TestConvertMessage_WebSearchToolResult(t *testing.T) {
    msg := MessageParam{
        Role: "user",
        Content: []any{
            map[string]any{
                "type":        "web_search_tool_result",
                "tool_use_id": "srvtoolu_123",
                "content": []any{
                    map[string]any{
                        "type":  "web_search_result",
                        "title": "Test Result",
                        "url":   "https://example.com",
                    },
                },
            },
        },
    }

    messages, err := convertMessage(msg)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    // Should have a tool result message
    if len(messages) != 1 {
        t.Fatalf("expected 1 message, got %d", len(messages))
    }

    if messages[0].Role != "tool" {
        t.Errorf("expected role 'tool', got %q", messages[0].Role)
    }

    if messages[0].ToolCallID != "srvtoolu_123" {
        t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID)
    }

    if messages[0].Content == "" {
        t.Error("expected non-empty content from web search results")
    }
}

func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
    msg := MessageParam{
        Role: "user",
        Content: []any{
            map[string]any{
                "type":        "web_search_tool_result",
                "tool_use_id": "srvtoolu_empty",
                "content":     []any{},
            },
        },
    }

    messages, err := convertMessage(msg)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if len(messages) != 1 {
        t.Fatalf("expected 1 message, got %d", len(messages))
    }
    if messages[0].Role != "tool" {
        t.Fatalf("expected role tool, got %q", messages[0].Role)
    }
    if messages[0].ToolCallID != "srvtoolu_empty" {
        t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID)
    }
    if messages[0].Content != "" {
        t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content)
    }
}

func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
    msg := MessageParam{
        Role: "user",
        Content: []any{
            map[string]any{
                "type":        "web_search_tool_result",
                "tool_use_id": "srvtoolu_error",
                "content": map[string]any{
                    "type":       "web_search_tool_result_error",
                    "error_code": "max_uses_exceeded",
                },
            },
        },
    }

    messages, err := convertMessage(msg)
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if len(messages) != 1 {
        t.Fatalf("expected 1 message, got %d", len(messages))
    }
    if messages[0].Role != "tool" {
        t.Fatalf("expected role tool, got %q", messages[0].Role)
    }
    if messages[0].ToolCallID != "srvtoolu_error" {
        t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID)
    }
    if !strings.Contains(messages[0].Content, "max_uses_exceeded") {
        t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content)
    }
}

func TestConvertOllamaToAnthropicResults(t *testing.T) {
    ollamaResp := &OllamaWebSearchResponse{
        Results: []OllamaWebSearchResult{
            {
                Title:   "Test Title",
                URL:     "https://example.com",
                Content: "Test content",
            },
            {
                Title:   "Another Result",
                URL:     "https://example.org",
                Content: "More content",
            },
        },
    }

    results := ConvertOllamaToAnthropicResults(ollamaResp)

    if len(results) != 2 {
        t.Fatalf("expected 2 results, got %d", len(results))
    }

    if results[0].Type != "web_search_result" {
        t.Errorf("expected type 'web_search_result', got %q", results[0].Type)
    }

    if results[0].Title != "Test Title" {
        t.Errorf("expected title 'Test Title', got %q", results[0].Title)
    }

    if results[0].URL != "https://example.com" {
        t.Errorf("expected URL 'https://example.com', got %q", results[0].URL)
    }
}

func TestWebSearchTypes(t *testing.T) {
    // Test that WebSearchResult serializes correctly
    result := WebSearchResult{
        Type:             "web_search_result",
        URL:              "https://example.com",
        Title:            "Test",
        EncryptedContent: "abc123",
        PageAge:          "2025-01-01",
    }

    data, err := json.Marshal(result)
    if err != nil {
        t.Fatalf("failed to marshal WebSearchResult: %v", err)
    }

    var unmarshaled WebSearchResult
    if err := json.Unmarshal(data, &unmarshaled); err != nil {
        t.Fatalf("failed to unmarshal WebSearchResult: %v", err)
    }

    if unmarshaled.Type != result.Type {
        t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type)
    }

    // Test WebSearchToolResultError
    errResult := WebSearchToolResultError{
        Type:      "web_search_tool_result_error",
        ErrorCode: "max_uses_exceeded",
    }

    data, err = json.Marshal(errResult)
    if err != nil {
        t.Fatalf("failed to marshal WebSearchToolResultError: %v", err)
    }

    var unmarshaledErr WebSearchToolResultError
    if err := json.Unmarshal(data, &unmarshaledErr); err != nil {
        t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err)
    }

    if unmarshaledErr.ErrorCode != "max_uses_exceeded" {
        t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode)
    }
}

func TestCitation(t *testing.T) {
    citation := Citation{
        Type:           "web_search_result_location",
        URL:            "https://example.com",
        Title:          "Example",
        EncryptedIndex: "enc123",
        CitedText:      "Some cited text...",
    }

    data, err := json.Marshal(citation)
    if err != nil {
        t.Fatalf("failed to marshal Citation: %v", err)
    }

    var unmarshaled Citation
    if err := json.Unmarshal(data, &unmarshaled); err != nil {
        t.Fatalf("failed to unmarshal Citation: %v", err)
    }

    if unmarshaled.Type != "web_search_result_location" {
        t.Errorf("type mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type)
    }

    if unmarshaled.CitedText != "Some cited text..." {
        t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText)
    }
}
anthropic/trace.go (new file, 352 lines)
@@ -0,0 +1,352 @@
package anthropic

import (
    "encoding/json"
    "fmt"
    "sort"

    "github.com/ollama/ollama/api"
)

// Trace truncation limits.
const (
    TraceMaxStringRunes = 240
    TraceMaxSliceItems  = 8
    TraceMaxMapEntries  = 16
    TraceMaxDepth       = 4
)

// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of
// omitted characters when truncated.
func TraceTruncateString(s string) string {
    if len(s) == 0 {
        return s
    }
    runes := []rune(s)
    if len(runes) <= TraceMaxStringRunes {
        return s
    }
    return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes)
}
// TraceJSON round-trips v through JSON and returns a compacted representation.
func TraceJSON(v any) any {
    if v == nil {
        return nil
    }
    data, err := json.Marshal(v)
    if err != nil {
        return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)}
    }
    var out any
    if err := json.Unmarshal(data, &out); err != nil {
        return TraceTruncateString(string(data))
    }
    return TraceCompactValue(out, 0)
}

// TraceCompactValue recursively truncates strings, slices, and maps for trace
// output. depth tracks recursion to enforce TraceMaxDepth.
func TraceCompactValue(v any, depth int) any {
    if v == nil {
        return nil
    }
    if depth >= TraceMaxDepth {
        switch t := v.(type) {
        case string:
            return TraceTruncateString(t)
        case []any:
            return fmt.Sprintf("<array len=%d>", len(t))
        case map[string]any:
            return fmt.Sprintf("<object keys=%d>", len(t))
        default:
            return fmt.Sprintf("<%T>", v)
        }
    }
    switch t := v.(type) {
    case string:
        return TraceTruncateString(t)
    case []any:
        limit := min(len(t), TraceMaxSliceItems)
        out := make([]any, 0, limit+1)
        for i := range limit {
            out = append(out, TraceCompactValue(t[i], depth+1))
        }
        if len(t) > limit {
            out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit))
        }
        return out
    case map[string]any:
        keys := make([]string, 0, len(t))
        for k := range t {
            keys = append(keys, k)
        }
        sort.Strings(keys)
        limit := min(len(keys), TraceMaxMapEntries)
        out := make(map[string]any, limit+1)
        for i := range limit {
            out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1)
        }
        if len(keys) > limit {
            out["__truncated_keys"] = len(keys) - limit
        }
        return out
    default:
        return t
    }
}
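In combination, `TraceJSON` and `TraceCompactValue` make arbitrary payloads safe to log: long strings are clipped, wide collections are capped, and anything nested past `TraceMaxDepth` collapses to a summary token. A small sketch (assuming the import path `github.com/ollama/ollama/anthropic`):

```go
package main

import (
	"fmt"
	"strings"

	"github.com/ollama/ollama/anthropic"
)

func main() {
	payload := map[string]any{
		"query": strings.Repeat("x", 1000), // clipped to 240 runes plus an omission suffix
		"deep": map[string]any{
			"a": map[string]any{"b": map[string]any{"c": map[string]any{"d": 1}}},
		},
	}
	// The innermost map sits at depth 4 and is summarized as "<object keys=1>".
	fmt.Println(anthropic.TraceJSON(payload))
}
```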
// ---------------------------------------------------------------------------
// Anthropic request/response tracing
// ---------------------------------------------------------------------------

// TraceMessagesRequest returns a compact trace representation of a MessagesRequest.
func TraceMessagesRequest(r MessagesRequest) map[string]any {
    return map[string]any{
        "model":          r.Model,
        "max_tokens":     r.MaxTokens,
        "messages":       traceMessageParams(r.Messages),
        "system":         traceAnthropicContent(r.System),
        "stream":         r.Stream,
        "tools":          traceTools(r.Tools),
        "tool_choice":    TraceJSON(r.ToolChoice),
        "thinking":       TraceJSON(r.Thinking),
        "stop_sequences": r.StopSequences,
        "temperature":    ptrVal(r.Temperature),
        "top_p":          ptrVal(r.TopP),
        "top_k":          ptrVal(r.TopK),
    }
}

// TraceMessagesResponse returns a compact trace representation of a MessagesResponse.
func TraceMessagesResponse(r MessagesResponse) map[string]any {
    return map[string]any{
        "id":          r.ID,
        "model":       r.Model,
        "content":     TraceJSON(r.Content),
        "stop_reason": r.StopReason,
        "usage":       r.Usage,
    }
}

func traceMessageParams(msgs []MessageParam) []map[string]any {
    out := make([]map[string]any, 0, len(msgs))
    for _, m := range msgs {
        out = append(out, map[string]any{
            "role":    m.Role,
            "content": traceAnthropicContent(m.Content),
        })
    }
    return out
}

func traceAnthropicContent(content any) any {
    switch c := content.(type) {
    case nil:
        return nil
    case string:
        return TraceTruncateString(c)
    case []any:
        blocks := make([]any, 0, len(c))
        for _, block := range c {
            blockMap, ok := block.(map[string]any)
            if !ok {
                blocks = append(blocks, TraceCompactValue(block, 0))
                continue
            }
            blocks = append(blocks, traceAnthropicBlock(blockMap))
        }
        return blocks
    default:
        return TraceJSON(c)
    }
}

func traceAnthropicBlock(block map[string]any) map[string]any {
    blockType, _ := block["type"].(string)
    out := map[string]any{"type": blockType}
    switch blockType {
    case "text":
        if text, ok := block["text"].(string); ok {
            out["text"] = TraceTruncateString(text)
        } else {
            out["text"] = TraceCompactValue(block["text"], 0)
        }
    case "thinking":
        if thinking, ok := block["thinking"].(string); ok {
            out["thinking"] = TraceTruncateString(thinking)
        } else {
            out["thinking"] = TraceCompactValue(block["thinking"], 0)
        }
    case "tool_use", "server_tool_use":
        out["id"] = block["id"]
        out["name"] = block["name"]
        out["input"] = TraceCompactValue(block["input"], 0)
    case "tool_result", "web_search_tool_result":
        out["tool_use_id"] = block["tool_use_id"]
        out["content"] = TraceCompactValue(block["content"], 0)
    case "image":
        if source, ok := block["source"].(map[string]any); ok {
            out["source"] = map[string]any{
                "type":       source["type"],
                "media_type": source["media_type"],
                "url":        source["url"],
                "data_len":   len(fmt.Sprint(source["data"])),
            }
        }
    default:
        out["block"] = TraceCompactValue(block, 0)
    }
    return out
}

func traceTools(tools []Tool) []map[string]any {
    out := make([]map[string]any, 0, len(tools))
    for _, t := range tools {
        out = append(out, TraceTool(t))
    }
    return out
}

// TraceTool returns a compact trace representation of an Anthropic Tool.
func TraceTool(t Tool) map[string]any {
    return map[string]any{
        "type":         t.Type,
        "name":         t.Name,
        "description":  TraceTruncateString(t.Description),
        "input_schema": TraceJSON(t.InputSchema),
        "max_uses":     t.MaxUses,
    }
}

// ContentBlockTypes returns the type strings from content (when it's []any blocks).
func ContentBlockTypes(content any) []string {
    blocks, ok := content.([]any)
    if !ok {
        return nil
    }
    types := make([]string, 0, len(blocks))
    for _, block := range blocks {
        blockMap, ok := block.(map[string]any)
        if !ok {
            types = append(types, fmt.Sprintf("%T", block))
            continue
        }
        t, _ := blockMap["type"].(string)
        types = append(types, t)
    }
    return types
}

func ptrVal[T any](v *T) any {
    if v == nil {
        return nil
    }
    return *v
}

// ---------------------------------------------------------------------------
// Ollama api.* tracing (shared between anthropic and middleware packages)
// ---------------------------------------------------------------------------
// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest.
func TraceChatRequest(req *api.ChatRequest) map[string]any {
    if req == nil {
        return nil
    }
    stream := false
    if req.Stream != nil {
        stream = *req.Stream
    }
    return map[string]any{
        "model":    req.Model,
        "messages": TraceAPIMessages(req.Messages),
        "tools":    TraceAPITools(req.Tools),
        "stream":   stream,
        "options":  req.Options,
        "think":    TraceJSON(req.Think),
    }
}

// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse.
func TraceChatResponse(resp api.ChatResponse) map[string]any {
    return map[string]any{
        "model":       resp.Model,
        "done":        resp.Done,
        "done_reason": resp.DoneReason,
        "message":     TraceAPIMessage(resp.Message),
        "metrics":     TraceJSON(resp.Metrics),
    }
}

// TraceAPIMessages returns compact trace representations for a slice of api.Message.
func TraceAPIMessages(msgs []api.Message) []map[string]any {
    out := make([]map[string]any, 0, len(msgs))
    for _, m := range msgs {
        out = append(out, TraceAPIMessage(m))
    }
    return out
}

// TraceAPIMessage returns a compact trace representation of a single api.Message.
func TraceAPIMessage(m api.Message) map[string]any {
    return map[string]any{
        "role":         m.Role,
        "content":      TraceTruncateString(m.Content),
        "thinking":     TraceTruncateString(m.Thinking),
        "images":       traceImageSizes(m.Images),
        "tool_calls":   traceToolCalls(m.ToolCalls),
        "tool_name":    m.ToolName,
        "tool_call_id": m.ToolCallID,
    }
}

func traceImageSizes(images []api.ImageData) []int {
    if len(images) == 0 {
        return nil
    }
    sizes := make([]int, 0, len(images))
    for _, img := range images {
        sizes = append(sizes, len(img))
    }
    return sizes
}

// TraceAPITools returns compact trace representations for a slice of api.Tool.
func TraceAPITools(tools api.Tools) []map[string]any {
    out := make([]map[string]any, 0, len(tools))
    for _, t := range tools {
        out = append(out, TraceAPITool(t))
    }
    return out
}

// TraceAPITool returns a compact trace representation of a single api.Tool.
func TraceAPITool(t api.Tool) map[string]any {
    return map[string]any{
        "type":        t.Type,
        "name":        t.Function.Name,
        "description": TraceTruncateString(t.Function.Description),
        "parameters":  TraceJSON(t.Function.Parameters),
    }
}

// TraceToolCall returns a compact trace representation of an api.ToolCall.
func TraceToolCall(tc api.ToolCall) map[string]any {
    return map[string]any{
        "id":   tc.ID,
        "name": tc.Function.Name,
        "args": TraceJSON(tc.Function.Arguments),
    }
}

func traceToolCalls(tcs []api.ToolCall) []map[string]any {
    if len(tcs) == 0 {
        return nil
    }
    out := make([]map[string]any, 0, len(tcs))
    for _, tc := range tcs {
        out = append(out, TraceToolCall(tc))
    }
    return out
}
@@ -449,6 +449,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
    return version.Version, nil
}

// CloudStatusExperimental returns whether cloud features are disabled on the server.
func (c *Client) CloudStatusExperimental(ctx context.Context) (*StatusResponse, error) {
    var status StatusResponse
    if err := c.do(ctx, http.MethodGet, "/api/status", nil, &status); err != nil {
        return nil, err
    }

    return &status, nil
}

// Signout signs a client out of a local Ollama server.
func (c *Client) Signout(ctx context.Context) error {
    return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
api/types.go (+10)
@@ -834,6 +834,16 @@ type TokenResponse struct {
    Token string `json:"token"`
}

type CloudStatus struct {
    Disabled bool   `json:"disabled"`
    Source   string `json:"source"`
}

// StatusResponse is the response from [Client.CloudStatusExperimental].
type StatusResponse struct {
    Cloud CloudStatus `json:"cloud"`
}

// GenerateResponse is the response passed into [GenerateResponseFunc].
type GenerateResponse struct {
    // Model is the model name that generated the response.
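A hedged sketch of the new status round-trip from the client side, pairing `CloudStatusExperimental` with the `StatusResponse`/`CloudStatus` types above (it assumes a local server new enough to serve `/api/status`):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// ClientFromEnvironment picks up OLLAMA_HOST, defaulting to localhost.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	status, err := client.CloudStatusExperimental(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("cloud disabled: %v (source: %s)\n", status.Cloud.Disabled, status.Cloud.Source)
}
```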
@@ -41,6 +41,11 @@ type InferenceCompute struct {
    VRAM string
}

type InferenceInfo struct {
    Computes             []InferenceCompute
    DefaultContextLength int
}

func New(s *store.Store, devMode bool) *Server {
    p := resolvePath("ollama")
    return &Server{store: s, bin: p, dev: devMode}
@@ -205,6 +210,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
        return nil, err
    }

    cloudDisabled, err := s.store.CloudDisabled()
    if err != nil {
        return nil, err
    }

    cmd := commandContext(ctx, s.bin, "serve")
    cmd.Stdout, cmd.Stderr = s.log, s.log

@@ -230,6 +240,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
    if settings.ContextLength > 0 {
        env["OLLAMA_CONTEXT_LENGTH"] = strconv.Itoa(settings.ContextLength)
    }
    if cloudDisabled {
        env["OLLAMA_NO_CLOUD"] = "1"
    } else {
        env["OLLAMA_NO_CLOUD"] = "0"
    }
    cmd.Env = []string{}
    for k, v := range env {
        cmd.Env = append(cmd.Env, k+"="+v)
@@ -262,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
|
||||
|
||||
// Attempt to retrieve inference compute information from the server
|
||||
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
||||
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
inference := []InferenceCompute{}
|
||||
marker := regexp.MustCompile(`inference compute.*library=`)
|
||||
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||
info := &InferenceInfo{}
|
||||
computeMarker := regexp.MustCompile(`inference compute.*library=`)
|
||||
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
|
||||
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
|
||||
|
||||
q := `inference compute.*%s=["]([^"]*)["]`
|
||||
nq := `inference compute.*%s=(\S+)\s`
|
||||
type regex struct {
|
||||
@@ -330,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
match := marker.FindStringSubmatch(line)
|
||||
if len(match) > 0 {
|
||||
// Check for inference compute lines
|
||||
if computeMarker.MatchString(line) {
|
||||
ic := InferenceCompute{
|
||||
Library: get("library", line),
|
||||
Variant: get("variant", line),
|
||||
@@ -342,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
}
|
||||
|
||||
slog.Info("Matched", "inference compute", ic)
|
||||
inference = append(inference, ic)
|
||||
} else {
|
||||
// Break out on first non matching line after we start matching
|
||||
if len(inference) > 0 {
|
||||
return inference, nil
|
||||
info.Computes = append(info.Computes, ic)
|
||||
continue
|
||||
}
|
||||
// Check for default context length line
|
||||
if defaultCtxMarker.MatchString(line) {
|
||||
match := defaultCtxRegex.FindStringSubmatch(line)
|
||||
if len(match) > 1 {
|
||||
numCtx, err := strconv.Atoi(match[1])
|
||||
if err == nil {
|
||||
info.DefaultContextLength = numCtx
|
||||
slog.Info("Matched default context length", "default_num_ctx", numCtx)
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
// If we've found compute info but hit a non-matching line, return what we have
|
||||
// This handles older server versions that don't log the default context line
|
||||
if len(info.Computes) > 0 {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
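To make the log parsing above concrete, the self-contained sketch below (not from the diff) runs the two default-context patterns against a sample server log line borrowed from the test fixtures that follow.

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

func main() {
	// Sample line taken from the test fixtures below.
	line := `time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144`

	// Same patterns as GetInferenceInfo above: one marks the line, one extracts the value.
	defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
	defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)

	if defaultCtxMarker.MatchString(line) {
		if m := defaultCtxRegex.FindStringSubmatch(line); len(m) > 1 {
			if numCtx, err := strconv.Atoi(m[1]); err == nil {
				fmt.Println(numCtx) // 262144
			}
		}
	}
}
```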
@@ -111,7 +111,7 @@ func TestServerCmd(t *testing.T) {
	for _, want := range tt.want {
		found := false
		for _, env := range cmd.Env {
			if strings.Contains(env, want) {
			if strings.HasPrefix(env, want) {
				found = true
				break
			}
@@ -123,7 +123,7 @@ func TestServerCmd(t *testing.T) {

	for _, dont := range tt.dont {
		for _, env := range cmd.Env {
			if strings.Contains(env, dont) {
			if strings.HasPrefix(env, dont) {
				t.Errorf("unexpected environment variable: %s", env)
			}
		}
@@ -136,44 +136,119 @@ func TestServerCmd(t *testing.T) {
	}
}

func TestGetInferenceComputer(t *testing.T) {
func TestServerCmdCloudSettingEnv(t *testing.T) {
	tests := []struct {
		name string
		log  string
		exp  []InferenceCompute
		name          string
		envValue      string
		configContent string
		want          string
	}{
		{
			name: "default cloud enabled",
			want: "OLLAMA_NO_CLOUD=0",
		},
		{
			name:     "env disables cloud",
			envValue: "1",
			want:     "OLLAMA_NO_CLOUD=1",
		},
		{
			name:          "config disables cloud",
			configContent: `{"disable_ollama_cloud": true}`,
			want:          "OLLAMA_NO_CLOUD=1",
		},
		{
			name:     "invalid env disables cloud",
			envValue: "invalid",
			want:     "OLLAMA_NO_CLOUD=1",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tmpHome := t.TempDir()
			t.Setenv("HOME", tmpHome)
			t.Setenv("USERPROFILE", tmpHome)
			t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)

			if tt.configContent != "" {
				configDir := filepath.Join(tmpHome, ".ollama")
				if err := os.MkdirAll(configDir, 0o755); err != nil {
					t.Fatalf("mkdir config dir: %v", err)
				}
				configPath := filepath.Join(configDir, "server.json")
				if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
					t.Fatalf("write config: %v", err)
				}
			}

			st := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
			defer st.Close()

			s := &Server{store: st}
			cmd, err := s.cmd(t.Context())
			if err != nil {
				t.Fatalf("s.cmd() error = %v", err)
			}

			found := false
			for _, env := range cmd.Env {
				if env == tt.want {
					found = true
					break
				}
			}
			if !found {
				t.Fatalf("expected environment variable %q in command env", tt.want)
			}
		})
	}
}

func TestGetInferenceInfo(t *testing.T) {
	tests := []struct {
		name             string
		log              string
		expComputes      []InferenceCompute
		expDefaultCtxLen int
	}{
		{
			name: "metal",
			log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
			exp: []InferenceCompute{{
			expComputes: []InferenceCompute{{
				Library: "metal",
				Driver:  "0.0",
				VRAM:    "96.0 GiB",
			}},
			expDefaultCtxLen: 262144,
		},
		{
			name: "cpu",
			log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
			exp: []InferenceCompute{{
			expComputes: []InferenceCompute{{
				Library: "cpu",
				Driver:  "0.0",
				VRAM:    "31.3 GiB",
			}},
			expDefaultCtxLen: 32768,
		},
		{
			name: "cuda1",
			log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
releasing cuda driver library
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
			exp: []InferenceCompute{{
			expComputes: []InferenceCompute{{
				Library: "cuda",
				Variant: "v12",
				Compute: "6.1",
@@ -181,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
				Name:    "NVIDIA GeForce GT 1030",
				VRAM:    "3.9 GiB",
			}},
			expDefaultCtxLen: 4096,
		},
		{
			name: "frank",
@@ -188,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
releasing cuda driver library
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
			exp: []InferenceCompute{
			expComputes: []InferenceCompute{
				{
					Library: "cuda",
					Variant: "v12",
@@ -207,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
					VRAM: "16.0 GiB",
				},
			},
			expDefaultCtxLen: 32768,
		},
		{
			name: "missing_default_context",
			log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
			expComputes: []InferenceCompute{{
				Library: "metal",
				Driver:  "0.0",
				VRAM:    "96.0 GiB",
			}},
			expDefaultCtxLen: 0, // No default context line, should return 0
		},
	}
	for _, tt := range tests {
@@ -219,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
			}
			ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
			defer cancel()
			ics, err := GetInferenceComputer(ctx)
			info, err := GetInferenceInfo(ctx)
			if err != nil {
				t.Fatalf(" failed to get inference compute: %v", err)
				t.Fatalf("failed to get inference info: %v", err)
			}
			if !reflect.DeepEqual(ics, tt.exp) {
				t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
			if !reflect.DeepEqual(info.Computes, tt.expComputes) {
				t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
			}
			if info.DefaultContextLength != tt.expDefaultCtxLen {
				t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
			}
		})
	}
}

func TestGetInferenceComputerTimeout(t *testing.T) {
func TestGetInferenceInfoTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
	defer cancel()
	tmpDir := t.TempDir()
@@ -239,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
	if err != nil {
		t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
	}
	_, err = GetInferenceComputer(ctx)
	_, err = GetInferenceInfo(ctx)
	if err == nil {
		t.Fatal("expected timeout")
	}
app/store/cloud_config.go (new file, 128 lines)
@@ -0,0 +1,128 @@
//go:build windows || darwin

package store

import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/envconfig"
)

const serverConfigFilename = "server.json"

type serverConfig struct {
	DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
}

// CloudDisabled returns whether cloud features should be disabled.
// The source of truth is: OLLAMA_NO_CLOUD OR ~/.ollama/server.json:disable_ollama_cloud.
func (s *Store) CloudDisabled() (bool, error) {
	disabled, _, err := s.CloudStatus()
	return disabled, err
}

// CloudStatus returns whether cloud is disabled and the source of that decision.
// Source is one of: "none", "env", "config", "both".
func (s *Store) CloudStatus() (bool, string, error) {
	if err := s.ensureDB(); err != nil {
		return false, "", err
	}

	configDisabled, err := readServerConfigCloudDisabled()
	if err != nil {
		return false, "", err
	}

	envDisabled := envconfig.NoCloudEnv()
	return envDisabled || configDisabled, cloudStatusSource(envDisabled, configDisabled), nil
}

// SetCloudEnabled writes the cloud setting to ~/.ollama/server.json.
func (s *Store) SetCloudEnabled(enabled bool) error {
	if err := s.ensureDB(); err != nil {
		return err
	}
	return setCloudEnabled(enabled)
}

func setCloudEnabled(enabled bool) error {
	configPath, err := serverConfigPath()
	if err != nil {
		return err
	}

	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return fmt.Errorf("create server config directory: %w", err)
	}

	configMap := map[string]any{}
	if data, err := os.ReadFile(configPath); err == nil {
		if err := json.Unmarshal(data, &configMap); err != nil {
			// If the existing file is invalid JSON, overwrite with a fresh object.
			configMap = map[string]any{}
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("read server config: %w", err)
	}

	configMap["disable_ollama_cloud"] = !enabled

	data, err := json.MarshalIndent(configMap, "", " ")
	if err != nil {
		return fmt.Errorf("marshal server config: %w", err)
	}
	data = append(data, '\n')

	if err := os.WriteFile(configPath, data, 0o644); err != nil {
		return fmt.Errorf("write server config: %w", err)
	}

	return nil
}

func readServerConfigCloudDisabled() (bool, error) {
	configPath, err := serverConfigPath()
	if err != nil {
		return false, err
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return false, nil
		}
		return false, fmt.Errorf("read server config: %w", err)
	}

	var cfg serverConfig
	// Invalid or unexpected JSON should not block startup; treat as default.
	if json.Unmarshal(data, &cfg) == nil {
		return cfg.DisableOllamaCloud, nil
	}
	return false, nil
}

func serverConfigPath() (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("resolve home directory: %w", err)
	}
	return filepath.Join(home, ".ollama", serverConfigFilename), nil
}

func cloudStatusSource(envDisabled bool, configDisabled bool) string {
	switch {
	case envDisabled && configDisabled:
		return "both"
	case envDisabled:
		return "env"
	case configDisabled:
		return "config"
	default:
		return "none"
	}
}
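For reference, here is a minimal sketch (again, not part of the diff) of the `server.json` payload this file reads and writes; the struct is a local copy of `serverConfig` above, and the JSON literal matches the fixtures used in the tests below.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the serverConfig type defined above, for illustration only.
type serverConfig struct {
	DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
}

func main() {
	// What ~/.ollama/server.json looks like when cloud is disabled.
	raw := []byte(`{"disable_ollama_cloud": true}`)

	var cfg serverConfig
	if err := json.Unmarshal(raw, &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.DisableOllamaCloud) // true
}
```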
app/store/cloud_config_test.go (new file, 130 lines)
@@ -0,0 +1,130 @@
//go:build windows || darwin

package store

import (
	"encoding/json"
	"os"
	"path/filepath"
	"testing"
)

func TestCloudDisabled(t *testing.T) {
	tests := []struct {
		name          string
		envValue      string
		configContent string
		wantDisabled  bool
		wantSource    string
	}{
		{
			name:         "default enabled",
			wantDisabled: false,
			wantSource:   "none",
		},
		{
			name:         "env disables cloud",
			envValue:     "1",
			wantDisabled: true,
			wantSource:   "env",
		},
		{
			name:          "config disables cloud",
			configContent: `{"disable_ollama_cloud": true}`,
			wantDisabled:  true,
			wantSource:    "config",
		},
		{
			name:          "env and config",
			envValue:      "1",
			configContent: `{"disable_ollama_cloud": false}`,
			wantDisabled:  true,
			wantSource:    "env",
		},
		{
			name:          "invalid config is ignored",
			configContent: `{bad`,
			wantDisabled:  false,
			wantSource:    "none",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tmpHome := t.TempDir()
			setTestHome(t, tmpHome)
			t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)

			if tt.configContent != "" {
				configDir := filepath.Join(tmpHome, ".ollama")
				if err := os.MkdirAll(configDir, 0o755); err != nil {
					t.Fatalf("mkdir config dir: %v", err)
				}
				configPath := filepath.Join(configDir, serverConfigFilename)
				if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
					t.Fatalf("write config: %v", err)
				}
			}

			s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
			defer s.Close()

			disabled, err := s.CloudDisabled()
			if err != nil {
				t.Fatalf("CloudDisabled() error = %v", err)
			}
			if disabled != tt.wantDisabled {
				t.Fatalf("CloudDisabled() = %v, want %v", disabled, tt.wantDisabled)
			}

			statusDisabled, source, err := s.CloudStatus()
			if err != nil {
				t.Fatalf("CloudStatus() error = %v", err)
			}
			if statusDisabled != tt.wantDisabled {
				t.Fatalf("CloudStatus() disabled = %v, want %v", statusDisabled, tt.wantDisabled)
			}
			if source != tt.wantSource {
				t.Fatalf("CloudStatus() source = %v, want %v", source, tt.wantSource)
			}
		})
	}
}

func TestSetCloudEnabled(t *testing.T) {
	tmpHome := t.TempDir()
	setTestHome(t, tmpHome)

	configDir := filepath.Join(tmpHome, ".ollama")
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatalf("mkdir config dir: %v", err)
	}
	configPath := filepath.Join(configDir, serverConfigFilename)
	if err := os.WriteFile(configPath, []byte(`{"another_key":"value","disable_ollama_cloud":true}`), 0o644); err != nil {
		t.Fatalf("seed config: %v", err)
	}

	s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
	defer s.Close()

	if err := s.SetCloudEnabled(true); err != nil {
		t.Fatalf("SetCloudEnabled(true) error = %v", err)
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatalf("read config: %v", err)
	}

	var got map[string]any
	if err := json.Unmarshal(data, &got); err != nil {
		t.Fatalf("unmarshal config: %v", err)
	}

	if got["disable_ollama_cloud"] != false {
		t.Fatalf("disable_ollama_cloud = %v, want false", got["disable_ollama_cloud"])
	}
	if got["another_key"] != "value" {
		t.Fatalf("another_key = %v, want value", got["another_key"])
	}
}
@@ -14,7 +14,7 @@ import (

// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 12
const currentSchemaVersion = 14

// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -73,7 +73,7 @@ func (db *database) init() error {
	agent BOOLEAN NOT NULL DEFAULT 0,
	tools BOOLEAN NOT NULL DEFAULT 0,
	working_dir TEXT NOT NULL DEFAULT '',
	context_length INTEGER NOT NULL DEFAULT 4096,
	context_length INTEGER NOT NULL DEFAULT 0,
	window_width INTEGER NOT NULL DEFAULT 0,
	window_height INTEGER NOT NULL DEFAULT 0,
	config_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -84,6 +84,7 @@ func (db *database) init() error {
	sidebar_open BOOLEAN NOT NULL DEFAULT 0,
	think_enabled BOOLEAN NOT NULL DEFAULT 0,
	think_level TEXT NOT NULL DEFAULT '',
	cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
	remote TEXT NOT NULL DEFAULT '', -- deprecated
	schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -244,6 +245,18 @@ func (db *database) migrate() error {
			return fmt.Errorf("migrate v11 to v12: %w", err)
		}
		version = 12
	case 12:
		// add cloud_setting_migrated column to settings table
		if err := db.migrateV12ToV13(); err != nil {
			return fmt.Errorf("migrate v12 to v13: %w", err)
		}
		version = 13
	case 13:
		// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
		if err := db.migrateV13ToV14(); err != nil {
			return fmt.Errorf("migrate v13 to v14: %w", err)
		}
		version = 14
	default:
		// If we have a version we don't recognize, just set it to current
		// This might happen during development
@@ -452,6 +465,37 @@ func (db *database) migrateV11ToV12() error {
	return nil
}

// migrateV12ToV13 adds cloud_setting_migrated to settings.
func (db *database) migrateV12ToV13() error {
	_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0`)
	if err != nil && !duplicateColumnError(err) {
		return fmt.Errorf("add cloud_setting_migrated column: %w", err)
	}

	_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
	if err != nil {
		return fmt.Errorf("update schema version: %w", err)
	}

	return nil
}

// migrateV13ToV14 changes the default context_length from 4096 to 0.
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
func (db *database) migrateV13ToV14() error {
	_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
	if err != nil {
		return fmt.Errorf("update context_length default: %w", err)
	}

	_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
	if err != nil {
		return fmt.Errorf("update schema version: %w", err)
	}

	return nil
}

// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
	_, err := db.conn.Exec(`
@@ -1108,9 +1152,9 @@ func (db *database) getSettings() (Settings, error) {
	var s Settings

	err := db.conn.QueryRow(`
		SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
		SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
		FROM settings
	`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
	`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
	if err != nil {
		return Settings{}, fmt.Errorf("get settings: %w", err)
	}
@@ -1121,14 +1165,40 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
	_, err := db.conn.Exec(`
		UPDATE settings
		SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
	`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
		SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
	`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
	if err != nil {
		return fmt.Errorf("set settings: %w", err)
	}
	return nil
}

func (db *database) isCloudSettingMigrated() (bool, error) {
	var migrated bool
	err := db.conn.QueryRow("SELECT cloud_setting_migrated FROM settings").Scan(&migrated)
	if err != nil {
		return false, fmt.Errorf("get cloud setting migration status: %w", err)
	}
	return migrated, nil
}

func (db *database) setCloudSettingMigrated(migrated bool) error {
	_, err := db.conn.Exec("UPDATE settings SET cloud_setting_migrated = ?", migrated)
	if err != nil {
		return fmt.Errorf("set cloud setting migration status: %w", err)
	}
	return nil
}

func (db *database) getAirplaneMode() (bool, error) {
	var airplaneMode bool
	err := db.conn.QueryRow("SELECT airplane_mode FROM settings").Scan(&airplaneMode)
	if err != nil {
		return false, fmt.Errorf("get airplane_mode: %w", err)
	}
	return airplaneMode, nil
}

func (db *database) getWindowSize() (int, int, error) {
	var width, height int
	err := db.conn.QueryRow("SELECT window_width, window_height FROM settings").Scan(&width, &height)
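To see why the migration's `WHERE` clause matters, here is an illustrative, self-contained sketch; `modernc.org/sqlite` is assumed purely for the demo and is not necessarily the driver the app itself uses, and the two-row table exists only to contrast the old default against an explicit user choice.

```go
package main

import (
	"database/sql"
	"fmt"

	_ "modernc.org/sqlite"
)

func main() {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	if _, err := db.Exec(`CREATE TABLE settings (context_length INTEGER NOT NULL DEFAULT 0)`); err != nil {
		panic(err)
	}
	// One row still at the old default, one with an explicit user choice.
	if _, err := db.Exec(`INSERT INTO settings (context_length) VALUES (4096), (8192)`); err != nil {
		panic(err)
	}

	// The WHERE clause only rewrites rows left at the old 4096 default,
	// so explicit user selections survive the migration.
	if _, err := db.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`); err != nil {
		panic(err)
	}

	rows, err := db.Query(`SELECT context_length FROM settings`)
	if err != nil {
		panic(err)
	}
	defer rows.Close()
	for rows.Next() {
		var n int
		if err := rows.Scan(&n); err != nil {
			panic(err)
		}
		fmt.Println(n) // prints 0, then 8192
	}
}
```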
@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
	})
}

func TestMigrationV13ToV14ContextLength(t *testing.T) {
	tmpDir := t.TempDir()
	dbPath := filepath.Join(tmpDir, "test.db")

	db, err := newDatabase(dbPath)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}
	defer db.Close()

	_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
	if err != nil {
		t.Fatalf("failed to seed v13 settings row: %v", err)
	}

	if err := db.migrate(); err != nil {
		t.Fatalf("migration from v13 to v14 failed: %v", err)
	}

	var contextLength int
	if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
		t.Fatalf("failed to read context_length: %v", err)
	}

	if contextLength != 0 {
		t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
	}

	version, err := db.getSchemaVersion()
	if err != nil {
		t.Fatalf("failed to get schema version: %v", err)
	}
	if version != currentSchemaVersion {
		t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
	}
}

func TestChatDeletionWithCascade(t *testing.T) {
	t.Run("chat deletion cascades to related messages", func(t *testing.T) {
		tmpDir := t.TempDir()
@@ -127,6 +127,65 @@ func TestNoConfigToMigrate(t *testing.T) {
	}
}

func TestCloudMigrationFromAirplaneMode(t *testing.T) {
	tmpHome := t.TempDir()
	setTestHome(t, tmpHome)
	t.Setenv("OLLAMA_NO_CLOUD", "")

	dbPath := filepath.Join(tmpHome, "db.sqlite")
	db, err := newDatabase(dbPath)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	if _, err := db.conn.Exec("UPDATE settings SET airplane_mode = 1, cloud_setting_migrated = 0"); err != nil {
		db.Close()
		t.Fatalf("failed to seed airplane migration state: %v", err)
	}
	db.Close()

	s := Store{DBPath: dbPath}
	defer s.Close()

	// Trigger DB initialization + one-time cloud migration.
	if _, err := s.ID(); err != nil {
		t.Fatalf("failed to initialize store: %v", err)
	}

	disabled, err := s.CloudDisabled()
	if err != nil {
		t.Fatalf("CloudDisabled() error: %v", err)
	}
	if !disabled {
		t.Fatal("expected cloud to be disabled after migrating airplane_mode=true")
	}

	configPath := filepath.Join(tmpHome, ".ollama", serverConfigFilename)
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatalf("failed to read migrated server config: %v", err)
	}

	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("failed to parse migrated server config: %v", err)
	}
	if cfg["disable_ollama_cloud"] != true {
		t.Fatalf("disable_ollama_cloud = %v, want true", cfg["disable_ollama_cloud"])
	}

	var airplaneMode, migrated bool
	if err := s.db.conn.QueryRow("SELECT airplane_mode, cloud_setting_migrated FROM settings").Scan(&airplaneMode, &migrated); err != nil {
		t.Fatalf("failed to read migration flags from DB: %v", err)
	}
	if !airplaneMode {
		t.Fatal("expected legacy airplane_mode value to remain unchanged")
	}
	if !migrated {
		t.Fatal("expected cloud_setting_migrated to be true")
	}
}

const (
	v1Schema = `
CREATE TABLE IF NOT EXISTS settings (
@@ -149,9 +149,6 @@ type Settings struct {
	// ContextLength specifies the context length for the ollama server (using OLLAMA_CONTEXT_LENGTH)
	ContextLength int

	// AirplaneMode when true, turns off Ollama Turbo features and only uses local models
	AirplaneMode bool

	// TurboEnabled indicates if Ollama Turbo features are enabled
	TurboEnabled bool

@@ -259,6 +256,40 @@ func (s *Store) ensureDB() error {
		}
	}

	// Run one-time migration from legacy airplane_mode behavior.
	if err := s.migrateCloudSetting(database); err != nil {
		return fmt.Errorf("migrate cloud setting: %w", err)
	}

	return nil
}

// migrateCloudSetting migrates legacy airplane_mode into server.json exactly once.
// After this, cloud state is sourced from server.json OR OLLAMA_NO_CLOUD.
func (s *Store) migrateCloudSetting(database *database) error {
	migrated, err := database.isCloudSettingMigrated()
	if err != nil {
		return err
	}
	if migrated {
		return nil
	}

	airplaneMode, err := database.getAirplaneMode()
	if err != nil {
		return err
	}

	if airplaneMode {
		if err := setCloudEnabled(false); err != nil {
			return fmt.Errorf("migrate airplane_mode to cloud disabled: %w", err)
		}
	}

	if err := database.setCloudSettingMigrated(true); err != nil {
		return err
	}

	return nil
}
app/store/test_home_test.go (new file, 11 lines)
@@ -0,0 +1,11 @@
//go:build windows || darwin

package store

import "testing"

func setTestHome(t *testing.T, home string) {
	t.Helper()
	t.Setenv("HOME", home)
	t.Setenv("USERPROFILE", home)
}
app/store/testdata/schema.sql (vendored, 2 changes)
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
	agent BOOLEAN NOT NULL DEFAULT 0,
	tools BOOLEAN NOT NULL DEFAULT 0,
	working_dir TEXT NOT NULL DEFAULT '',
	context_length INTEGER NOT NULL DEFAULT 4096,
	context_length INTEGER NOT NULL DEFAULT 0,
	window_width INTEGER NOT NULL DEFAULT 0,
	window_height INTEGER NOT NULL DEFAULT 0,
	config_migrated BOOLEAN NOT NULL DEFAULT 0,
app/tools/cloud_policy.go (new file, 35 lines)
@@ -0,0 +1,35 @@
//go:build windows || darwin

package tools

import (
	"context"
	"errors"

	"github.com/ollama/ollama/api"
	internalcloud "github.com/ollama/ollama/internal/cloud"
)

// ensureCloudEnabledForTool checks cloud policy from the connected Ollama server.
// If policy cannot be determined, this fails closed and blocks the operation.
func ensureCloudEnabledForTool(ctx context.Context, operation string) error {
	// Reuse shared message formatting; policy evaluation is still done via
	// the connected server's /api/status endpoint below.
	disabledMessage := internalcloud.DisabledError(operation)

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return errors.New(disabledMessage + " (unable to verify server cloud policy)")
	}

	status, err := client.CloudStatusExperimental(ctx)
	if err != nil {
		return errors.New(disabledMessage + " (unable to verify server cloud policy)")
	}

	if status.Cloud.Disabled {
		return errors.New(disabledMessage)
	}

	return nil
}
app/tools/cloud_policy_test.go (new file, 73 lines)
@@ -0,0 +1,73 @@
//go:build windows || darwin

package tools

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

func TestEnsureCloudEnabledForTool(t *testing.T) {
	const op = "web search is unavailable"
	const disabledPrefix = "ollama cloud is disabled: web search is unavailable"

	t.Run("enabled allows tool execution", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/api/status" {
				http.NotFound(w, r)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cloud":{"disabled":false,"source":"none"}}`))
		}))
		t.Cleanup(ts.Close)
		t.Setenv("OLLAMA_HOST", ts.URL)

		if err := ensureCloudEnabledForTool(context.Background(), op); err != nil {
			t.Fatalf("expected nil error, got %v", err)
		}
	})

	t.Run("disabled blocks tool execution", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/api/status" {
				http.NotFound(w, r)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cloud":{"disabled":true,"source":"config"}}`))
		}))
		t.Cleanup(ts.Close)
		t.Setenv("OLLAMA_HOST", ts.URL)

		err := ensureCloudEnabledForTool(context.Background(), op)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if got := err.Error(); got != disabledPrefix {
			t.Fatalf("unexpected error: %q", got)
		}
	})

	t.Run("status unavailable fails closed", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.NotFound(w, r)
		}))
		t.Cleanup(ts.Close)
		t.Setenv("OLLAMA_HOST", ts.URL)

		err := ensureCloudEnabledForTool(context.Background(), op)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if got := err.Error(); !strings.Contains(got, disabledPrefix) {
			t.Fatalf("expected disabled prefix, got %q", got)
		}
		if got := err.Error(); !strings.Contains(got, "unable to verify server cloud policy") {
			t.Fatalf("expected verification failure detail, got %q", got)
		}
	})
}
@@ -77,6 +77,10 @@ func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, strin
}

func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, error) {
	if err := ensureCloudEnabledForTool(ctx, "web fetch is unavailable"); err != nil {
		return nil, err
	}

	reqBody := FetchRequest{URL: targetURL}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
@@ -93,6 +93,10 @@ func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, stri
}

func performWebSearch(ctx context.Context, query string, maxResults int) (*SearchResponse, error) {
	if err := ensureCloudEnabledForTool(ctx, "web search is unavailable"); err != nil {
		return nil, err
	}

	reqBody := SearchRequest{Query: query, MaxResults: maxResults}

	jsonBody, err := json.Marshal(reqBody)
@@ -289,10 +289,12 @@ export class InferenceCompute {
}
export class InferenceComputeResponse {
  inferenceComputes: InferenceCompute[];
  defaultContextLength: number;

  constructor(source: any = {}) {
    if ('string' === typeof source) source = JSON.parse(source);
    this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
    this.defaultContextLength = source["defaultContextLength"];
  }

  convertValues(a: any, classs: any, asMap: boolean = false): any {
@@ -406,7 +408,6 @@ export class Settings {
  Tools: boolean;
  WorkingDir: string;
  ContextLength: number;
  AirplaneMode: boolean;
  TurboEnabled: boolean;
  WebSearchEnabled: boolean;
  ThinkEnabled: boolean;
@@ -424,7 +425,6 @@ export class Settings {
    this.Tools = source["Tools"];
    this.WorkingDir = source["WorkingDir"];
    this.ContextLength = source["ContextLength"];
    this.AirplaneMode = source["AirplaneMode"];
    this.TurboEnabled = source["TurboEnabled"];
    this.WebSearchEnabled = source["WebSearchEnabled"];
    this.ThinkEnabled = source["ThinkEnabled"];
@@ -4,7 +4,6 @@ import {
  ChatEvent,
  DownloadEvent,
  ErrorEvent,
  InferenceCompute,
  InferenceComputeResponse,
  ModelCapabilitiesResponse,
  Model,
@@ -27,6 +26,12 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean {
  return this.model.endsWith("cloud");
};

export type CloudStatusSource = "env" | "config" | "both" | "none";
export interface CloudStatusResponse {
  disabled: boolean;
  source: CloudStatusSource;
}
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
  const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -285,6 +290,28 @@ export async function updateSettings(settings: Settings): Promise<{
  };
}

export async function updateCloudSetting(
  enabled: boolean,
): Promise<CloudStatusResponse> {
  const response = await fetch(`${API_BASE}/api/v1/cloud`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({ enabled }),
  });
  if (!response.ok) {
    const error = await response.text();
    throw new Error(error || "Failed to update cloud setting");
  }

  const data = await response.json();
  return {
    disabled: Boolean(data.disabled),
    source: (data.source as CloudStatusSource) || "none",
  };
}

export async function renameChat(chatId: string, title: string): Promise<void> {
  const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}/rename`, {
    method: "PUT",
@@ -379,7 +406,7 @@ export async function* pullModel(
  }
}

export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
  const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
  if (!response.ok) {
    throw new Error(
@@ -388,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
  }

  const data = await response.json();
  const inferenceComputeResponse = new InferenceComputeResponse(data);
  return inferenceComputeResponse.inferenceComputes || [];
  return new InferenceComputeResponse(data);
}

export async function fetchHealth(): Promise<boolean> {
@@ -414,3 +440,16 @@ export async function fetchHealth(): Promise<boolean> {
    return false;
  }
}

export async function getCloudStatus(): Promise<CloudStatusResponse | null> {
  const response = await fetch(`${API_BASE}/api/v1/cloud`);
  if (!response.ok) {
    throw new Error(`Failed to fetch cloud status: ${response.status}`);
  }

  const data = await response.json();
  return {
    disabled: Boolean(data.disabled),
    source: (data.source as CloudStatusSource) || "none",
  };
}
@@ -22,6 +22,7 @@ import { useUser } from "@/hooks/useUser";
import { DisplayLogin } from "@/components/DisplayLogin";
import { ErrorEvent, Message } from "@/gotypes";
import { useSettings } from "@/hooks/useSettings";
import { useCloudStatus } from "@/hooks/useCloudStatus";
import { ThinkButton } from "./ThinkButton";
import { ErrorMessage } from "./ErrorMessage";
import { processFiles } from "@/utils/fileValidation";
@@ -141,12 +142,12 @@ function ChatForm({
  const {
    settings: {
      webSearchEnabled,
      airplaneMode,
      thinkEnabled,
      thinkLevel: settingsThinkLevel,
    },
    setSettings,
  } = useSettings();
  const { cloudDisabled } = useCloudStatus();

  // current supported models for web search
  const modelLower = selectedModel?.model.toLowerCase() || "";
@@ -180,6 +181,12 @@ function ChatForm({
    setSettings,
  ]);

  useEffect(() => {
    if (cloudDisabled && webSearchEnabled) {
      setSettings({ WebSearchEnabled: false });
    }
  }, [cloudDisabled, webSearchEnabled, setSettings]);

  const removeFile = (index: number) => {
    setMessage((prev) => ({
      ...prev,
@@ -234,19 +241,19 @@ function ChatForm({

  // Determine if login banner should be shown
  const shouldShowLoginBanner =
    !cloudDisabled &&
    !isLoadingUser &&
    !isAuthenticated &&
    ((webSearchEnabled && supportsWebSearch) ||
      (selectedModel?.isCloud() && !airplaneMode));
    ((webSearchEnabled && supportsWebSearch) || selectedModel?.isCloud());

  // Determine which feature to highlight in the banner
  const getActiveFeatureForBanner = () => {
    if (cloudDisabled) return null;
    if (!isAuthenticated) {
      if (loginPromptFeature) return loginPromptFeature;
      if (webSearchEnabled && selectedModel?.isCloud() && !airplaneMode)
        return "webSearch";
      if (webSearchEnabled && selectedModel?.isCloud()) return "webSearch";
      if (webSearchEnabled) return "webSearch";
      if (selectedModel?.isCloud() && !airplaneMode) return "turbo";
      if (selectedModel?.isCloud()) return "turbo";
    }
    return null;
  };
@@ -269,11 +276,12 @@ function ChatForm({
  useEffect(() => {
    if (
      isAuthenticated ||
      (!webSearchEnabled && !!selectedModel?.isCloud() && !airplaneMode)
      cloudDisabled ||
      (!webSearchEnabled && !!selectedModel?.isCloud())
    ) {
      setLoginPromptFeature(null);
    }
  }, [isAuthenticated, webSearchEnabled, selectedModel, airplaneMode]);
  }, [isAuthenticated, webSearchEnabled, selectedModel, cloudDisabled]);

  // When entering edit mode, populate the composition with existing data
  useEffect(() => {
@@ -465,6 +473,10 @@ function ChatForm({
  const handleSubmit = async () => {
    if (!message.content.trim() || isStreaming || isDownloading) return;

    if (cloudDisabled && selectedModel?.isCloud()) {
      return;
    }

    // Check if cloud mode is enabled but user is not authenticated
    if (shouldShowLoginBanner) {
      return;
@@ -478,7 +490,8 @@ function ChatForm({
      }),
    );

    const useWebSearch = supportsWebSearch && webSearchEnabled && !airplaneMode;
    const useWebSearch =
      supportsWebSearch && webSearchEnabled && !cloudDisabled;
    const useThink = modelSupportsThinkingLevels
      ? thinkLevel
      : supportsThinkToggling
@@ -899,7 +912,7 @@ function ChatForm({
          )}
          <WebSearchButton
            ref={webSearchButtonRef}
            isVisible={supportsWebSearch && airplaneMode === false}
            isVisible={supportsWebSearch && cloudDisabled === false}
            isActive={webSearchEnabled}
            onToggle={() => {
              if (!webSearchEnabled && !isAuthenticated) {
@@ -940,6 +953,7 @@ function ChatForm({
              !isDownloading &&
              (!message.content.trim() ||
                shouldShowLoginBanner ||
                (cloudDisabled && selectedModel?.isCloud()) ||
                message.fileErrors.length > 0)
            }
            className={`flex items-center justify-center h-9 w-9 rounded-full disabled:cursor-default cursor-pointer bg-black text-white dark:bg-white dark:text-black disabled:opacity-10 focus:outline-none focus:ring-2 focus:ring-blue-500`}
@@ -8,7 +8,7 @@ import {
} from "react";
import { Model } from "@/gotypes";
import { useSelectedModel } from "@/hooks/useSelectedModel";
import { useSettings } from "@/hooks/useSettings";
import { useCloudStatus } from "@/hooks/useCloudStatus";
import { useQueryClient } from "@tanstack/react-query";
import { getModelUpstreamInfo } from "@/api";
import { ArrowDownTrayIcon } from "@heroicons/react/24/outline";
@@ -34,7 +34,7 @@ export const ModelPicker = forwardRef<
    chatId,
    searchQuery,
  );
  const { settings } = useSettings();
  const { cloudDisabled } = useCloudStatus();
  const dropdownRef = useRef<HTMLDivElement>(null);
  const searchInputRef = useRef<HTMLInputElement>(null);
  const queryClient = useQueryClient();
@@ -219,7 +219,7 @@ export const ModelPicker = forwardRef<
          models={models}
          selectedModel={selectedModel}
          onModelSelect={handleModelSelect}
          airplaneMode={settings.airplaneMode}
          cloudDisabled={cloudDisabled}
          isOpen={isOpen}
        />
      </div>
@@ -233,13 +233,13 @@ export const ModelList = forwardRef(function ModelList(
    models,
    selectedModel,
    onModelSelect,
    airplaneMode,
    cloudDisabled,
    isOpen,
  }: {
    models: Model[];
    selectedModel: Model | null;
    onModelSelect: (model: Model) => void;
    airplaneMode: boolean;
    cloudDisabled: boolean;
    isOpen: boolean;
  },
  ref,
@@ -348,7 +348,7 @@ export const ModelList = forwardRef(function ModelList(
            </svg>
          )}
          {model.digest === undefined &&
            (airplaneMode || !model.isCloud()) && (
            (cloudDisabled || !model.isCloud()) && (
              <ArrowDownTrayIcon
                className="h-4 w-4 text-neutral-500 dark:text-neutral-400"
                strokeWidth={1.75}
@@ -11,6 +11,7 @@ import {
  FolderIcon,
  BoltIcon,
  WrenchIcon,
  CloudIcon,
  XMarkIcon,
  CogIcon,
  ArrowLeftIcon,
@@ -18,8 +19,15 @@ import {
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser";
import { useCloudStatus } from "@/hooks/useCloudStatus";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { getSettings, updateSettings } from "@/api";
import {
  getSettings,
  type CloudStatusResponse,
  updateCloudSetting,
  updateSettings,
  getInferenceCompute,
} from "@/api";

function AnimatedDots() {
  return (
@@ -53,6 +61,11 @@ export default function Settings() {
  const [connectionError, setConnectionError] = useState<string | null>(null);
  const [pollingInterval, setPollingInterval] = useState<number | null>(null);
  const navigate = useNavigate();
  const {
    cloudDisabled,
    cloudStatus,
    isLoading: cloudStatusLoading,
  } = useCloudStatus();

  const {
    data: settingsData,
@@ -65,6 +78,13 @@ export default function Settings() {

  const settings = settingsData?.settings || null;

  const { data: inferenceComputeResponse } = useQuery({
    queryKey: ["inferenceCompute"],
    queryFn: getInferenceCompute,
  });

  const defaultContextLength = inferenceComputeResponse?.defaultContextLength;

  const updateSettingsMutation = useMutation({
    mutationFn: updateSettings,
    onSuccess: () => {
@@ -74,6 +94,50 @@ export default function Settings() {
    },
  });

  const updateCloudMutation = useMutation({
    mutationFn: (enabled: boolean) => updateCloudSetting(enabled),
    onMutate: async (enabled: boolean) => {
      await queryClient.cancelQueries({ queryKey: ["cloudStatus"] });

      const previous = queryClient.getQueryData<CloudStatusResponse | null>([
        "cloudStatus",
      ]);
      const envForcesDisabled =
        previous?.source === "env" || previous?.source === "both";

      queryClient.setQueryData<CloudStatusResponse | null>(
        ["cloudStatus"],
        previous
          ? {
              ...previous,
              disabled: !enabled || envForcesDisabled,
            }
          : {
              disabled: !enabled,
              source: "config",
            },
      );

      return { previous };
    },
    onError: (_error, _enabled, context) => {
      if (context?.previous !== undefined) {
        queryClient.setQueryData(["cloudStatus"], context.previous);
      }
    },
    onSuccess: (status) => {
      queryClient.setQueryData<CloudStatusResponse | null>(
        ["cloudStatus"],
        status,
      );
      queryClient.invalidateQueries({ queryKey: ["models"] });
      queryClient.invalidateQueries({ queryKey: ["cloudStatus"] });

      setShowSaved(true);
      setTimeout(() => setShowSaved(false), 1500);
    },
  });

  useEffect(() => {
    refetchUser();
  }, []); // eslint-disable-line react-hooks/exhaustive-deps
@@ -148,13 +212,17 @@ export default function Settings() {
      Models: "",
      Agent: false,
      Tools: false,
      ContextLength: 4096,
      AirplaneMode: false,
      ContextLength: 0,
    });
    updateSettingsMutation.mutate(defaultSettings);
    }
  };

  const cloudOverriddenByEnv =
    cloudStatus?.source === "env" || cloudStatus?.source === "both";
  const cloudToggleDisabled =
    cloudStatusLoading || updateCloudMutation.isPending || cloudOverriddenByEnv;

  const handleConnectOllamaAccount = async () => {
    setConnectionError(null);

@@ -237,7 +305,7 @@ export default function Settings() {
      <div className="space-y-4 max-w-2xl mx-auto">
        {/* Connect Ollama Account */}
        <div className="overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
          <div className="p-4 border-b border-neutral-200 dark:border-neutral-800">
          <div className="p-4">
            <Field>
              {isLoading ? (
                // Loading skeleton, this will only happen if the app started recently
@@ -344,6 +412,34 @@ export default function Settings() {
        {/* Local Configuration */}
        <div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
          <div className="space-y-4 p-4">
            <Field>
              <div className="flex items-start justify-between gap-4">
                <div className="flex items-start space-x-3 flex-1">
                  <CloudIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
                  <div>
                    <Label>Cloud</Label>
                    <Description>
                      {cloudOverriddenByEnv
                        ? "The OLLAMA_NO_CLOUD environment variable is currently forcing cloud off."
                        : "Enable cloud models and web search."}
                    </Description>
                  </div>
                </div>
                <div className="flex-shrink-0">
                  <Switch
                    checked={!cloudDisabled}
                    disabled={cloudToggleDisabled}
                    onChange={(checked) => {
                      if (cloudOverriddenByEnv) {
                        return;
                      }
                      updateCloudMutation.mutate(checked);
                    }}
                  />
                </div>
              </div>
            </Field>

            {/* Expose Ollama */}
            <Field>
              <div className="flex items-start justify-between gap-4">
@@ -419,13 +515,11 @@ export default function Settings() {
                  </Description>
                  <div className="mt-3">
                    <Slider
                      value={(() => {
                        // Otherwise use the settings value
                        return settings.ContextLength || 4096;
                      })()}
                      value={settings.ContextLength || defaultContextLength || 0}
                      onChange={(value) => {
                        handleChange("ContextLength", value);
                      }}
                      disabled={!defaultContextLength}
                      options={[
                        { value: 4096, label: "4k" },
                        { value: 8192, label: "8k" },
@@ -440,35 +534,6 @@ export default function Settings() {
                  </div>
                </div>
              </div>
            </Field>
            {/* Airplane Mode */}
            <Field>
              <div className="flex items-start justify-between gap-4">
                <div className="flex items-start space-x-3 flex-1">
                  <svg
                    className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100"
                    viewBox="0 0 21.5508 17.9033"
                    fill="currentColor"
                  >
                    <path d="M21.5508 8.94727C21.542 7.91895 20.1445 7.17188 18.4658 7.17188L14.9238 7.17188C14.4316 7.17188 14.2471 7.09277 13.957 6.75879L8.05078 0.316406C7.86621 0.105469 7.6377 0 7.37402 0L6.35449 0C6.12598 0 5.99414 0.202148 6.1084 0.448242L9.14941 7.17188L4.68457 7.68164L3.09375 4.76367C2.97949 4.54395 2.78613 4.44727 2.49609 4.44727L2.11816 4.44727C1.88965 4.44727 1.74023 4.59668 1.74023 4.8252L1.74023 13.0693C1.74023 13.2979 1.88965 13.4385 2.11816 13.4385L2.49609 13.4385C2.78613 13.4385 2.97949 13.3418 3.09375 13.1309L4.68457 10.2129L9.14941 10.7227L6.1084 17.4463C5.99414 17.6836 6.12598 17.8945 6.35449 17.8945L7.37402 17.8945C7.6377 17.8945 7.86621 17.7803 8.05078 17.5781L13.957 11.127C14.2471 10.8018 14.4316 10.7227 14.9238 10.7227L18.4658 10.7227C20.1445 10.7227 21.542 9.9668 21.5508 8.94727Z" />
                  </svg>
                  <div>
                    <Label>Airplane mode</Label>
                    <Description>
                      Airplane mode keeps data local, disabling cloud models
                      and web search.
                    </Description>
                  </div>
                </div>
                <div className="flex-shrink-0">
                  <Switch
                    checked={settings.AirplaneMode}
                    onChange={(checked) =>
                      handleChange("AirplaneMode", checked)
                    }
                  />
                </div>
              </div>
            </Field>
          </div>
        </div>
@@ -6,10 +6,11 @@ export interface SliderProps {
  value?: number;
  onChange?: (value: number) => void;
  className?: string;
  disabled?: boolean;
}

const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
  ({ label, options, value = 0, onChange }, ref) => {
  ({ label, options, value = 0, onChange, disabled = false }, ref) => {
    const [selectedValue, setSelectedValue] = React.useState(value);
    const [isDragging, setIsDragging] = React.useState(false);
    const containerRef = React.useRef<HTMLDivElement>(null);
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
    }, [value]);

    const handleClick = (optionValue: number) => {
      if (disabled) return;
      setSelectedValue(optionValue);
      onChange?.(optionValue);
    };
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
    };

    const handleMouseDown = (e: React.MouseEvent) => {
      if (disabled) return;
      setIsDragging(true);
      e.preventDefault();
    };
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
    }

    return (
      <div className="space-y-2" ref={ref}>
      <div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
        {label && <label className="text-sm font-medium">{label}</label>}
        <div className="relative">
          <div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
            <button
              onClick={() => handleClick(option.value)}
              onMouseDown={handleMouseDown}
              className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
              disabled={disabled}
              className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
            >
              <div className="relative w-5 h-5 flex items-center justify-center">
                {selectedValue === option.value && (
                {selectedValue === option.value && !disabled && (
                  <div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
                )}
              </div>
@@ -6,8 +6,8 @@ import { useSelectedModel } from "./useSelectedModel";
import { createQueryBatcher } from "./useQueryBatcher";
import { useRefetchModels } from "./useModels";
import { useStreamingContext } from "@/contexts/StreamingContext";
-import { useSettings } from "./useSettings";
import { getModelCapabilities } from "@/api";
+import { useCloudStatus } from "./useCloudStatus";

export const useChats = () => {
  return useQuery({
@@ -116,11 +116,9 @@ export const useIsModelStale = (modelName: string) => {
export const useShouldShowStaleDisplay = (model: Model | null) => {
  const isStale = useIsModelStale(model?.model || "");
  const { data: dismissedModels } = useDismissedStaleModels();
- const {
-   settings: { airplaneMode },
- } = useSettings();
+ const { cloudDisabled } = useCloudStatus();

- if (model?.isCloud() && !airplaneMode) {
+ if (model?.isCloud() && !cloudDisabled) {
    return false;
  }

app/ui/app/src/hooks/useCloudStatus.ts (new file, 20 lines)
@@ -0,0 +1,20 @@
+import { useQuery } from "@tanstack/react-query";
+import { getCloudStatus, type CloudStatusResponse } from "@/api";
+
+export function useCloudStatus() {
+  const cloudQuery = useQuery<CloudStatusResponse | null>({
+    queryKey: ["cloudStatus"],
+    queryFn: getCloudStatus,
+    retry: false,
+    staleTime: 60 * 1000,
+  });
+
+  return {
+    cloudStatus: cloudQuery.data,
+    cloudDisabled: cloudQuery.data?.disabled ?? false,
+    isKnown: cloudQuery.data !== null && cloudQuery.data !== undefined,
+    isLoading: cloudQuery.isLoading,
+    isError: cloudQuery.isError,
+    error: cloudQuery.error,
+  };
+}

@@ -2,11 +2,11 @@ import { useQuery } from "@tanstack/react-query";
import { Model } from "@/gotypes";
import { getModels } from "@/api";
import { mergeModels } from "@/utils/mergeModels";
-import { useSettings } from "./useSettings";
import { useMemo } from "react";
+import { useCloudStatus } from "./useCloudStatus";

export function useModels(searchQuery = "") {
- const { settings } = useSettings();
+ const { cloudDisabled } = useCloudStatus();
  const localQuery = useQuery<Model[], Error>({
    queryKey: ["models", searchQuery],
    queryFn: () => getModels(searchQuery),
@@ -20,7 +20,7 @@ export function useModels(searchQuery = "") {
  });

  const allModels = useMemo(() => {
-   const models = mergeModels(localQuery.data || [], settings.airplaneMode);
+   const models = mergeModels(localQuery.data || [], cloudDisabled);

    if (searchQuery && searchQuery.trim()) {
      const query = searchQuery.toLowerCase().trim();
@@ -40,7 +40,7 @@ export function useModels(searchQuery = "") {
    }

    return models;
- }, [localQuery.data, searchQuery, settings.airplaneMode]);
+ }, [localQuery.data, searchQuery, cloudDisabled]);

  return {
    ...localQuery,

@@ -7,6 +7,7 @@ import { Model } from "@/gotypes";
import { FEATURED_MODELS } from "@/utils/mergeModels";
import { getTotalVRAM } from "@/utils/vram.ts";
import { getInferenceCompute } from "@/api";
+import { useCloudStatus } from "./useCloudStatus";

export function recommendDefaultModel(totalVRAM: number): string {
  const vram = Math.max(0, Number(totalVRAM) || 0);
@@ -22,16 +23,19 @@ export function recommendDefaultModel(totalVRAM: number): string {
export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
  const { settings, setSettings } = useSettings();
  const { data: models = [], isLoading } = useModels(searchQuery || "");
+ const { cloudDisabled } = useCloudStatus();
  const { data: chatData, isLoading: isChatLoading } = useChat(
    currentChatId && currentChatId !== "new" ? currentChatId : "",
  );

- const { data: inferenceComputes = [] } = useQuery({
-   queryKey: ["inference-compute"],
+ const { data: inferenceComputeResponse } = useQuery({
+   queryKey: ["inferenceCompute"],
    queryFn: getInferenceCompute,
    enabled: !settings.selectedModel, // Only fetch if no model is selected
  });

+ const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
+
  const totalVRAM = useMemo(
    () => getTotalVRAM(inferenceComputes),
    [inferenceComputes],
@@ -46,12 +50,11 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
  const restoredChatRef = useRef<string | null>(null);

  const selectedModel: Model | null = useMemo(() => {
-   // if airplane mode is on and selected model ends with cloud,
-   // switch to recommended default model
-   if (settings.airplaneMode && settings.selectedModel?.endsWith("cloud")) {
+   // If cloud is disabled and selected model ends with cloud, switch to a local default.
+   if (cloudDisabled && settings.selectedModel?.endsWith("cloud")) {
      return (
        models.find((m) => m.model === recommendedModel) ||
-       models.find((m) => m.isCloud) ||
+       models.find((m) => !m.isCloud()) ||
        models.find((m) => m.digest === undefined || m.digest === "") ||
        models[0] ||
        null
@@ -68,7 +71,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
      "qwen3-coder:480b",
    ];
    const shouldMigrate =
-     !settings.airplaneMode &&
+     !cloudDisabled &&
      settings.turboEnabled &&
      baseModelsToMigrate.includes(settings.selectedModel);

@@ -96,13 +99,18 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
      })) ||
      null
    );
- }, [models, settings.selectedModel, settings.airplaneMode, recommendedModel]);
+ }, [
+   models,
+   settings.selectedModel,
+   cloudDisabled,
+   recommendedModel,
+ ]);

  useEffect(() => {
    if (!selectedModel) return;

    if (
-     settings.airplaneMode &&
+     cloudDisabled &&
      settings.selectedModel?.endsWith("cloud") &&
      selectedModel.model !== settings.selectedModel
    ) {
@@ -110,13 +118,17 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
    }

    if (
-     !settings.airplaneMode &&
+     !cloudDisabled &&
      settings.turboEnabled &&
      selectedModel.model !== settings.selectedModel
    ) {
      setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
    }
- }, [selectedModel, settings.airplaneMode, settings.selectedModel]);
+ }, [
+   selectedModel,
+   cloudDisabled,
+   settings.selectedModel,
+ ]);

  // Set model from chat history when chat data loads
  useEffect(() => {
@@ -169,7 +181,9 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {

    const defaultModel =
      models.find((m) => m.model === recommendedModel) ||
-     models.find((m) => m.isCloud()) ||
+     (cloudDisabled
+       ? models.find((m) => !m.isCloud())
+       : models.find((m) => m.isCloud())) ||
      models.find((m) => m.digest === undefined || m.digest === "") ||
      models[0];

@@ -181,6 +195,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
    inferenceComputes.length,
    models.length,
    settings.selectedModel,
+   cloudDisabled,
  ]);

  // Add the selected model to the models list if it's not already there

@@ -9,7 +9,6 @@ interface SettingsState {
  webSearchEnabled: boolean;
  selectedModel: string;
  sidebarOpen: boolean;
- airplaneMode: boolean;
  thinkEnabled: boolean;
  thinkLevel: string;
}
@@ -51,7 +50,6 @@ export function useSettings() {
      thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
      selectedModel: settingsData?.settings?.SelectedModel ?? "",
      sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
-     airplaneMode: settingsData?.settings?.AirplaneMode ?? false,
    }),
    [settingsData?.settings],
  );

@@ -2,6 +2,7 @@ import type { QueryClient } from "@tanstack/react-query";
import { createRootRouteWithContext, Outlet } from "@tanstack/react-router";
import { getSettings } from "@/api";
import { useQuery } from "@tanstack/react-query";
+import { useCloudStatus } from "@/hooks/useCloudStatus";

function RootComponent() {
  // This hook ensures settings are fetched on app startup
@@ -9,6 +10,8 @@ function RootComponent() {
    queryKey: ["settings"],
    queryFn: getSettings,
  });
+ // Fetch cloud status on startup (best-effort)
+ useCloudStatus();

  return (
    <div>

@@ -41,14 +41,14 @@ describe("Model merging logic", () => {
    expect(merged.length).toBe(FEATURED_MODELS.length + 2);
  });

- it("should hide cloud models in airplane mode", () => {
+ it("should hide cloud models when cloud is disabled", () => {
    const localModels: Model[] = [
      new Model({ model: "gpt-oss:120b-cloud" }),
      new Model({ model: "llama3:latest" }),
      new Model({ model: "mistral:latest" }),
    ];

-   const merged = mergeModels(localModels, true); // airplane mode = true
+   const merged = mergeModels(localModels, true); // cloud disabled = true

    // No cloud models should be present
    const cloudModels = merged.filter((m) => m.isCloud());

@@ -32,7 +32,7 @@ function alphabeticalSort(a: Model, b: Model): number {
//Merges models, sorting cloud models first, then other models
export function mergeModels(
  localModels: Model[],
- airplaneMode: boolean = false,
+ hideCloudModels: boolean = false,
): Model[] {
  const allModels = (localModels || []).map((model) => model);

@@ -95,7 +95,7 @@ export function mergeModels(

  remainingModels.sort(alphabeticalSort);

- return airplaneMode
+ return hideCloudModels
    ? [...featuredModels, ...remainingModels]
    : [...cloudModels, ...featuredModels, ...remainingModels];
}

@@ -45,7 +45,8 @@ type InferenceCompute struct {
}

type InferenceComputeResponse struct {
-    InferenceComputes []InferenceCompute `json:"inferenceComputes"`
+    InferenceComputes    []InferenceCompute `json:"inferenceComputes"`
+    DefaultContextLength int                `json:"defaultContextLength"`
}

type ModelCapabilitiesResponse struct {

app/ui/ui.go (55 lines changed)
@@ -284,12 +284,15 @@ func (s *Server) Handler() http.Handler {
    mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
    mux.Handle("GET /api/v1/settings", handle(s.getSettings))
    mux.Handle("POST /api/v1/settings", handle(s.settings))
+   mux.Handle("GET /api/v1/cloud", handle(s.getCloudSetting))
+   mux.Handle("POST /api/v1/cloud", handle(s.cloudSetting))

    // Ollama proxy endpoints
    ollamaProxy := s.ollamaProxy()
    mux.Handle("GET /api/tags", ollamaProxy)
    mux.Handle("POST /api/show", ollamaProxy)
    mux.Handle("GET /api/version", ollamaProxy)
+   mux.Handle("GET /api/status", ollamaProxy)
    mux.Handle("HEAD /api/version", ollamaProxy)
    mux.Handle("POST /api/me", ollamaProxy)
    mux.Handle("POST /api/signout", ollamaProxy)
@@ -1417,11 +1420,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
        settings.Models = envconfig.Models()
    }

-   // set default context length if not set
-   if settings.ContextLength == 0 {
-       settings.ContextLength = 4096
-   }
-
    // Include current runtime settings
    settings.Agent = s.Agent
    settings.Tools = s.Tools
@@ -1460,17 +1458,51 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
    })
}

+func (s *Server) cloudSetting(w http.ResponseWriter, r *http.Request) error {
+   var req struct {
+       Enabled bool `json:"enabled"`
+   }
+   if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+       return fmt.Errorf("invalid request body: %w", err)
+   }
+
+   if err := s.Store.SetCloudEnabled(req.Enabled); err != nil {
+       return fmt.Errorf("failed to persist cloud setting: %w", err)
+   }
+
+   s.Restart()
+
+   return s.writeCloudStatus(w)
+}
+
+func (s *Server) getCloudSetting(w http.ResponseWriter, r *http.Request) error {
+   return s.writeCloudStatus(w)
+}
+
+func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
+   disabled, source, err := s.Store.CloudStatus()
+   if err != nil {
+       return fmt.Errorf("failed to load cloud status: %w", err)
+   }
+
+   w.Header().Set("Content-Type", "application/json")
+   return json.NewEncoder(w).Encode(map[string]any{
+       "disabled": disabled,
+       "source":   source,
+   })
+}
+
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
    ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
    defer cancel()
-   serverInferenceComputes, err := server.GetInferenceComputer(ctx)
+   info, err := server.GetInferenceInfo(ctx)
    if err != nil {
-       s.log().Error("failed to get inference compute", "error", err)
-       return fmt.Errorf("failed to get inference compute: %w", err)
+       s.log().Error("failed to get inference info", "error", err)
+       return fmt.Errorf("failed to get inference info: %w", err)
    }

-   inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
-   for i, ic := range serverInferenceComputes {
+   inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
+   for i, ic := range info.Computes {
        inferenceComputes[i] = responses.InferenceCompute{
            Library: ic.Library,
            Variant: ic.Variant,
@@ -1482,7 +1514,8 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
    }

    response := responses.InferenceComputeResponse{
-       InferenceComputes: inferenceComputes,
+       InferenceComputes:    inferenceComputes,
+       DefaultContextLength: info.DefaultContextLength,
    }

    w.Header().Set("Content-Type", "application/json")
@@ -115,6 +115,107 @@ func TestHandlePostApiSettings(t *testing.T) {
    }
}

+func TestHandlePostApiCloudSetting(t *testing.T) {
+   tmpHome := t.TempDir()
+   t.Setenv("HOME", tmpHome)
+   t.Setenv("OLLAMA_NO_CLOUD", "")
+
+   testStore := &store.Store{
+       DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
+   }
+   defer testStore.Close()
+
+   restartCount := 0
+   server := &Server{
+       Store: testStore,
+       Restart: func() {
+           restartCount++
+       },
+   }
+
+   for _, tc := range []struct {
+       name        string
+       body        string
+       wantEnabled bool
+   }{
+       {name: "disable cloud", body: `{"enabled": false}`, wantEnabled: false},
+       {name: "enable cloud", body: `{"enabled": true}`, wantEnabled: true},
+   } {
+       t.Run(tc.name, func(t *testing.T) {
+           req := httptest.NewRequest("POST", "/api/v1/cloud", bytes.NewBufferString(tc.body))
+           req.Header.Set("Content-Type", "application/json")
+           rr := httptest.NewRecorder()
+
+           if err := server.cloudSetting(rr, req); err != nil {
+               t.Fatalf("cloudSetting() error = %v", err)
+           }
+           if rr.Code != http.StatusOK {
+               t.Fatalf("cloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
+           }
+
+           var got map[string]any
+           if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
+               t.Fatalf("cloudSetting() invalid response JSON: %v", err)
+           }
+           if got["disabled"] != !tc.wantEnabled {
+               t.Fatalf("response disabled = %v, want %v", got["disabled"], !tc.wantEnabled)
+           }
+
+           disabled, err := testStore.CloudDisabled()
+           if err != nil {
+               t.Fatalf("CloudDisabled() error = %v", err)
+           }
+           if gotEnabled := !disabled; gotEnabled != tc.wantEnabled {
+               t.Fatalf("cloud enabled = %v, want %v", gotEnabled, tc.wantEnabled)
+           }
+       })
+   }
+
+   if restartCount != 2 {
+       t.Fatalf("Restart called %d times, want 2", restartCount)
+   }
+}
+
+func TestHandleGetApiCloudSetting(t *testing.T) {
+   tmpHome := t.TempDir()
+   t.Setenv("HOME", tmpHome)
+   t.Setenv("OLLAMA_NO_CLOUD", "")
+
+   testStore := &store.Store{
+       DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
+   }
+   defer testStore.Close()
+
+   if err := testStore.SetCloudEnabled(false); err != nil {
+       t.Fatalf("SetCloudEnabled(false) error = %v", err)
+   }
+
+   server := &Server{
+       Store:   testStore,
+       Restart: func() {},
+   }
+
+   req := httptest.NewRequest("GET", "/api/v1/cloud", nil)
+   rr := httptest.NewRecorder()
+   if err := server.getCloudSetting(rr, req); err != nil {
+       t.Fatalf("getCloudSetting() error = %v", err)
+   }
+   if rr.Code != http.StatusOK {
+       t.Fatalf("getCloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
+   }
+
+   var got map[string]any
+   if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
+       t.Fatalf("getCloudSetting() invalid response JSON: %v", err)
+   }
+   if got["disabled"] != true {
+       t.Fatalf("response disabled = %v, want true", got["disabled"])
+   }
+   if got["source"] != "config" {
+       t.Fatalf("response source = %v, want config", got["source"])
+   }
+}
+
func TestAuthenticationMiddleware(t *testing.T) {
    tests := []struct {
        name string

cmd/cmd.go (64 lines changed)
@@ -57,9 +57,9 @@ import (

func init() {
    // Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
-   config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
+   config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
        tuiItems := tui.ReorderItems(tui.ConvertItems(items))
-       result, err := tui.SelectSingle(title, tuiItems)
+       result, err := tui.SelectSingle(title, tuiItems, current)
        if errors.Is(err, tui.ErrCancelled) {
            return "", config.ErrCancelled
        }
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
            mfConfig.System = cmd.Args
        case "license":
            mfConfig.License = cmd.Args
+       case "parser":
+           mfConfig.Parser = cmd.Args
+       case "renderer":
+           mfConfig.Renderer = cmd.Args
        }
    }

@@ -581,6 +585,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
    }
    opts.WordWrap = !nowrap

+   useImagegen := false
+   if cmd.Flags().Lookup("imagegen") != nil {
+       useImagegen, err = cmd.Flags().GetBool("imagegen")
+       if err != nil {
+           return err
+       }
+   }
+   if useImagegen {
+       opts.Options["use_imagegen_runner"] = true
+   }
+
    // Fill out the rest of the options based on information about the
    // model.
    client, err := api.ClientFromEnvironment()
@@ -1886,12 +1901,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
    }

    // Selector adapters for tui
-   singleSelector := func(title string, items []config.ModelItem) (string, error) {
-       tuiItems := make([]tui.SelectItem, len(items))
-       for i, item := range items {
-           tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
-       }
-       result, err := tui.SelectSingle(title, tuiItems)
+   singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
+       tuiItems := tui.ReorderItems(tui.ConvertItems(items))
+       result, err := tui.SelectSingle(title, tuiItems, current)
        if errors.Is(err, tui.ErrCancelled) {
            return "", config.ErrCancelled
        }
@@ -1899,10 +1911,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
    }

    multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
-       tuiItems := make([]tui.SelectItem, len(items))
-       for i, item := range items {
-           tuiItems[i] = tui.SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
-       }
+       tuiItems := tui.ReorderItems(tui.ConvertItems(items))
        result, err := tui.SelectMultiple(title, tuiItems, preChecked)
        if errors.Is(err, tui.ErrCancelled) {
            return nil, config.ErrCancelled
@@ -1947,9 +1956,13 @@ func runInteractiveTUI(cmd *cobra.Command) {
    }

    launchIntegration := func(name string) bool {
+       if err := config.EnsureInstalled(name); err != nil {
+           fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+           return true
+       }
        // If not configured or model no longer exists, prompt for model selection
        configuredModel := config.IntegrationModel(name)
-       if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) {
+       if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
            err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
            if errors.Is(err, config.ErrCancelled) {
                return false // Return to main menu
@@ -1971,7 +1984,7 @@ func runInteractiveTUI(cmd *cobra.Command) {
        return
    case tui.SelectionRunModel:
        _ = config.SetLastSelection("run")
-       if modelName := config.LastModel(); modelName != "" {
+       if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
            runModel(modelName)
        } else {
            modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
@@ -1999,6 +2012,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
                continue
            }
        }
+       if config.IsCloudModelDisabled(cmd.Context(), modelName) {
+           continue // Return to main menu
+       }
        runModel(modelName)
    case tui.SelectionIntegration:
        _ = config.SetLastSelection(result.Integration)
@@ -2008,6 +2024,17 @@ func runInteractiveTUI(cmd *cobra.Command) {
    case tui.SelectionChangeIntegration:
        _ = config.SetLastSelection(result.Integration)
        if len(result.Models) > 0 {
+           // Filter out cloud-disabled models
+           var filtered []string
+           for _, m := range result.Models {
+               if !config.IsCloudModelDisabled(cmd.Context(), m) {
+                   filtered = append(filtered, m)
+               }
+           }
+           if len(filtered) == 0 {
+               continue
+           }
+           result.Models = filtered
            // Multi-select from modal (Editor integrations)
            if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
                fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
@@ -2017,8 +2044,11 @@ func runInteractiveTUI(cmd *cobra.Command) {
                fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
            }
        } else if result.Model != "" {
+           if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
+               continue
+           }
            // Single-select from modal - save and launch
-           if err := config.SaveIntegrationModel(result.Integration, result.Model); err != nil {
+           if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
                fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
                continue
            }
@@ -2130,6 +2160,9 @@ func NewCLI() *cobra.Command {
    // Image generation flags (width, height, steps, seed, etc.)
    imagegen.RegisterFlags(runCmd)

+   runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
+   runCmd.Flags().MarkHidden("imagegen")
+
    stopCmd := &cobra.Command{
        Use:   "stop MODEL",
        Short: "Stop a running model",
@@ -2273,6 +2306,7 @@ func NewCLI() *cobra.Command {
        envVars["OLLAMA_MAX_QUEUE"],
        envVars["OLLAMA_MODELS"],
        envVars["OLLAMA_NUM_PARALLEL"],
+       envVars["OLLAMA_NO_CLOUD"],
        envVars["OLLAMA_NOPRUNE"],
        envVars["OLLAMA_ORIGINS"],
        envVars["OLLAMA_SCHED_SPREAD"],

@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
    fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)

    if aliases["primary"] == "" || force {
-       primary, err := DefaultSingleSelector("Select model:", items)
+       primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
        if err != nil {
            return nil, false, err
        }

@@ -140,7 +140,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
    tmpDir := t.TempDir()
    setTestHome(t, tmpDir)

-   saveIntegration("claude", []string{"qwen3:8b"})
+   SaveIntegration("claude", []string{"qwen3:8b"})
    saveAliases("claude", map[string]string{"primary": "qwen3:8b"})

    got := envMap(c.modelEnvVars("qwen3:8b"))
@@ -162,7 +162,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
    tmpDir := t.TempDir()
    setTestHome(t, tmpDir)

-   saveIntegration("claude", []string{"llama3.2:70b"})
+   SaveIntegration("claude", []string{"llama3.2:70b"})
    saveAliases("claude", map[string]string{
        "primary": "llama3.2:70b",
        "fast":    "llama3.2:8b",
@@ -187,7 +187,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
    tmpDir := t.TempDir()
    setTestHome(t, tmpDir)

-   saveIntegration("claude", []string{"saved-model"})
+   SaveIntegration("claude", []string{"saved-model"})
    saveAliases("claude", map[string]string{"primary": "saved-model"})

    got := envMap(c.modelEnvVars("different-model"))

cmd/config/cline.go (new file, 123 lines)
@@ -0,0 +1,123 @@
+package config
+
+import (
+   "context"
+   "encoding/json"
+   "errors"
+   "fmt"
+   "os"
+   "os/exec"
+   "path/filepath"
+
+   "github.com/ollama/ollama/envconfig"
+)
+
+// Cline implements Runner and Editor for the Cline CLI integration
+type Cline struct{}
+
+func (c *Cline) String() string { return "Cline" }
+
+func (c *Cline) Run(model string, args []string) error {
+   if _, err := exec.LookPath("cline"); err != nil {
+       return fmt.Errorf("cline is not installed, install with: npm install -g cline")
+   }
+
+   models := []string{model}
+   if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
+       models = config.Models
+   }
+   var err error
+   models, err = resolveEditorModels("cline", models, func() ([]string, error) {
+       return selectModels(context.Background(), "cline", "")
+   })
+   if errors.Is(err, errCancelled) {
+       return nil
+   }
+   if err != nil {
+       return err
+   }
+   if err := c.Edit(models); err != nil {
+       return fmt.Errorf("setup failed: %w", err)
+   }
+
+   cmd := exec.Command("cline", args...)
+   cmd.Stdin = os.Stdin
+   cmd.Stdout = os.Stdout
+   cmd.Stderr = os.Stderr
+   return cmd.Run()
+}
+
+func (c *Cline) Paths() []string {
+   home, err := os.UserHomeDir()
+   if err != nil {
+       return nil
+   }
+   p := filepath.Join(home, ".cline", "data", "globalState.json")
+   if _, err := os.Stat(p); err == nil {
+       return []string{p}
+   }
+   return nil
+}
+
+func (c *Cline) Edit(models []string) error {
+   if len(models) == 0 {
+       return nil
+   }
+
+   home, err := os.UserHomeDir()
+   if err != nil {
+       return err
+   }
+
+   configPath := filepath.Join(home, ".cline", "data", "globalState.json")
+   if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+       return err
+   }
+
+   config := make(map[string]any)
+   if data, err := os.ReadFile(configPath); err == nil {
+       if err := json.Unmarshal(data, &config); err != nil {
+           return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
+       }
+   }
+
+   // Set Ollama as the provider for both act and plan modes
+   baseURL := envconfig.Host().String()
+   config["ollamaBaseUrl"] = baseURL
+   config["actModeApiProvider"] = "ollama"
+   config["actModeOllamaModelId"] = models[0]
+   config["actModeOllamaBaseUrl"] = baseURL
+   config["planModeApiProvider"] = "ollama"
+   config["planModeOllamaModelId"] = models[0]
+   config["planModeOllamaBaseUrl"] = baseURL
+
+   config["welcomeViewCompleted"] = true
+
+   data, err := json.MarshalIndent(config, "", "  ")
+   if err != nil {
+       return err
+   }
+   return writeWithBackup(configPath, data)
+}
+
+func (c *Cline) Models() []string {
+   home, err := os.UserHomeDir()
+   if err != nil {
+       return nil
+   }
+
+   config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
+   if err != nil {
+       return nil
+   }
+
+   if config["actModeApiProvider"] != "ollama" {
+       return nil
+   }
+
+   modelID, _ := config["actModeOllamaModelId"].(string)
+   if modelID == "" {
+       return nil
+   }
+   return []string{modelID}
+}

cmd/config/cline_test.go (new file, 204 lines)
@@ -0,0 +1,204 @@
+package config
+
+import (
+   "encoding/json"
+   "os"
+   "path/filepath"
+   "testing"
+)
+
+func TestClineIntegration(t *testing.T) {
+   c := &Cline{}
+
+   t.Run("String", func(t *testing.T) {
+       if got := c.String(); got != "Cline" {
+           t.Errorf("String() = %q, want %q", got, "Cline")
+       }
+   })
+
+   t.Run("implements Runner", func(t *testing.T) {
+       var _ Runner = c
+   })
+
+   t.Run("implements Editor", func(t *testing.T) {
+       var _ Editor = c
+   })
+}
+
+func TestClineEdit(t *testing.T) {
+   c := &Cline{}
+   tmpDir := t.TempDir()
+   setTestHome(t, tmpDir)
+
+   configDir := filepath.Join(tmpDir, ".cline", "data")
+   configPath := filepath.Join(configDir, "globalState.json")
+
+   readConfig := func() map[string]any {
+       data, _ := os.ReadFile(configPath)
+       var config map[string]any
+       json.Unmarshal(data, &config)
+       return config
+   }
+
+   t.Run("creates config from scratch", func(t *testing.T) {
+       os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+       if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
+           t.Fatal(err)
+       }
+
+       config := readConfig()
+       if config["actModeApiProvider"] != "ollama" {
+           t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
+       }
+       if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
+           t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
+       }
+       if config["planModeApiProvider"] != "ollama" {
+           t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
+       }
+       if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
+           t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
+       }
+       if config["welcomeViewCompleted"] != true {
+           t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
+       }
+   })
+
+   t.Run("preserves existing fields", func(t *testing.T) {
+       os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+       os.MkdirAll(configDir, 0o755)
+
+       existing := map[string]any{
+           "remoteRulesToggles":    map[string]any{},
+           "remoteWorkflowToggles": map[string]any{},
+           "customSetting":         "keep-me",
+       }
+       data, _ := json.Marshal(existing)
+       os.WriteFile(configPath, data, 0o644)
+
+       if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
+           t.Fatal(err)
+       }
+
+       config := readConfig()
+       if config["customSetting"] != "keep-me" {
+           t.Errorf("customSetting was not preserved")
+       }
+       if config["actModeOllamaModelId"] != "glm-5:cloud" {
+           t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
+       }
+   })
+
+   t.Run("updates model on re-edit", func(t *testing.T) {
+       os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+       if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
+           t.Fatal(err)
+       }
+       if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
+           t.Fatal(err)
+       }
+
+       config := readConfig()
+       if config["actModeOllamaModelId"] != "glm-5:cloud" {
+           t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
+       }
+       if config["planModeOllamaModelId"] != "glm-5:cloud" {
+           t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
+       }
+   })
+
+   t.Run("empty models is no-op", func(t *testing.T) {
+       os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+       if err := c.Edit(nil); err != nil {
+           t.Fatal(err)
+       }
+
+       if _, err := os.Stat(configPath); !os.IsNotExist(err) {
+           t.Error("expected no config file to be created for empty models")
+       }
+   })
+
+   t.Run("uses first model as primary", func(t *testing.T) {
+       os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+       if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
+           t.Fatal(err)
+       }
+
+       config := readConfig()
+       if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
+           t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
+       }
+   })
+}
+
+func TestClineModels(t *testing.T) {
+   c := &Cline{}
+   tmpDir := t.TempDir()
+   setTestHome(t, tmpDir)
+
+   configDir := filepath.Join(tmpDir, ".cline", "data")
+   configPath := filepath.Join(configDir, "globalState.json")
+
+   t.Run("returns nil when no config", func(t *testing.T) {
+       if models := c.Models(); models != nil {
+           t.Errorf("Models() = %v, want nil", models)
+       }
+   })
+
+   t.Run("returns nil when provider is not ollama", func(t *testing.T) {
+       os.MkdirAll(configDir, 0o755)
+       config := map[string]any{
+           "actModeApiProvider":   "anthropic",
+           "actModeOllamaModelId": "some-model",
+       }
+       data, _ := json.Marshal(config)
+       os.WriteFile(configPath, data, 0o644)
+
+       if models := c.Models(); models != nil {
+           t.Errorf("Models() = %v, want nil", models)
+       }
+   })
+
+   t.Run("returns model when ollama is configured", func(t *testing.T) {
+       os.MkdirAll(configDir, 0o755)
+       config := map[string]any{
+           "actModeApiProvider":   "ollama",
+           "actModeOllamaModelId": "kimi-k2.5:cloud",
+       }
+       data, _ := json.Marshal(config)
+       os.WriteFile(configPath, data, 0o644)
+
+       models := c.Models()
+       if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
+           t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
+       }
+   })
+}
+
+func TestClinePaths(t *testing.T) {
+   c := &Cline{}
+   tmpDir := t.TempDir()
+   setTestHome(t, tmpDir)
+
+   t.Run("returns nil when no config exists", func(t *testing.T) {
+       if paths := c.Paths(); paths != nil {
+           t.Errorf("Paths() = %v, want nil", paths)
+       }
+   })
+
+   t.Run("returns path when config exists", func(t *testing.T) {
+       configDir := filepath.Join(tmpDir, ".cline", "data")
+       os.MkdirAll(configDir, 0o755)
+       configPath := filepath.Join(configDir, "globalState.json")
+       os.WriteFile(configPath, []byte("{}"), 0o644)
+
+       paths := c.Paths()
+       if len(paths) != 1 || paths[0] != configPath {
+           t.Errorf("Paths() = %v, want [%s]", paths, configPath)
+       }
+   })
+}

@@ -6,6 +6,7 @@ import (
    "os/exec"
    "strings"

+   "github.com/ollama/ollama/envconfig"
    "golang.org/x/mod/semver"
)

@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
    cmd.Stdin = os.Stdin
    cmd.Stdout = os.Stdout
    cmd.Stderr = os.Stderr
+   cmd.Env = append(os.Environ(),
+       "OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
+       "OPENAI_API_KEY=ollama",
+   )
    return cmd.Run()
}

@@ -15,8 +15,9 @@ import (
)

type integration struct {
-   Models  []string          `json:"models"`
-   Aliases map[string]string `json:"aliases,omitempty"`
+   Models    []string          `json:"models"`
+   Aliases   map[string]string `json:"aliases,omitempty"`
+   Onboarded bool              `json:"onboarded,omitempty"`
}

type config struct {
@@ -56,8 +57,8 @@ func migrateConfig() (bool, error) {
        return false, err
    }

-   var js json.RawMessage
-   if err := json.Unmarshal(oldData, &js); err != nil {
+   // Ignore legacy files with invalid JSON and continue startup.
+   if !json.Valid(oldData) {
        return false, nil
    }

@@ -126,7 +127,7 @@ func save(cfg *config) error {
    return writeWithBackup(path, data)
}

-func saveIntegration(appName string, models []string) error {
+func SaveIntegration(appName string, models []string) error {
    if appName == "" {
        return errors.New("app name cannot be empty")
    }
@@ -139,34 +140,54 @@ func saveIntegration(appName string, models []string) error {
    key := strings.ToLower(appName)
    existing := cfg.Integrations[key]
    var aliases map[string]string
-   if existing != nil && existing.Aliases != nil {
+   var onboarded bool
+   if existing != nil {
        aliases = existing.Aliases
+       onboarded = existing.Onboarded
    }

    cfg.Integrations[key] = &integration{
-       Models:  models,
-       Aliases: aliases,
+       Models:    models,
+       Aliases:   aliases,
+       Onboarded: onboarded,
    }

    return save(cfg)
}

+// integrationOnboarded marks an integration as onboarded in ollama's config.
+func integrationOnboarded(appName string) error {
+   cfg, err := load()
+   if err != nil {
+       return err
+   }
+
+   key := strings.ToLower(appName)
+   existing := cfg.Integrations[key]
+   if existing == nil {
+       existing = &integration{}
+   }
+   existing.Onboarded = true
+   cfg.Integrations[key] = existing
+   return save(cfg)
+}
+
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
-   ic, err := loadIntegration(appName)
-   if err != nil || len(ic.Models) == 0 {
+   integrationConfig, err := loadIntegration(appName)
+   if err != nil || len(integrationConfig.Models) == 0 {
        return ""
    }
-   return ic.Models[0]
+   return integrationConfig.Models[0]
}

// IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string {
-   ic, err := loadIntegration(appName)
-   if err != nil || len(ic.Models) == 0 {
+   integrationConfig, err := loadIntegration(appName)
+   if err != nil || len(integrationConfig.Models) == 0 {
        return nil
    }
-   return ic.Models
+   return integrationConfig.Models
}

// LastModel returns the last model that was run, or empty string if none.
@@ -234,12 +255,12 @@ func loadIntegration(appName string) (*integration, error) {
        return nil, err
    }

-   ic, ok := cfg.Integrations[strings.ToLower(appName)]
+   integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
    if !ok {
        return nil, os.ErrNotExist
    }

-   return ic, nil
+   return integrationConfig, nil
}

func saveAliases(appName string, aliases map[string]string) error {
@@ -272,8 +293,8 @@ func listIntegrations() ([]integration, error) {
    }

    result := make([]integration, 0, len(cfg.Integrations))
-   for _, ic := range cfg.Integrations {
-       result = append(result, *ic)
+   for _, integrationConfig := range cfg.Integrations {
+       result = append(result, *integrationConfig)
    }

    return result, nil

@@ -85,7 +85,7 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
    setTestHome(t, tmpDir)

    // First save integration with models
-   if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
+   if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
        t.Fatalf("failed to save integration: %v", err)
    }

@@ -604,7 +604,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
    }

    // Save integration with same model (this is the pattern we use)
-   if err := saveIntegration("claude", []string{"model-a"}); err != nil {
+   if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
        t.Fatal(err)
    }

@@ -619,7 +619,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
    setTestHome(t, tmpDir)

    // Simulate out-of-sync state (like manual edit or bug)
-   if err := saveIntegration("claude", []string{"old-model"}); err != nil {
+   if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
        t.Fatal(err)
    }
    if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
@@ -634,7 +634,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
    }

    // The fix: when updating aliases, also update models
-   if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
+   if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
        t.Fatal(err)
    }

@@ -650,7 +650,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
    setTestHome(t, tmpDir)

    // Initial state
-   if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
+   if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
        t.Fatal(err)
    }
    if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
@@ -662,7 +662,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
    if err := saveAliases("claude", newAliases); err != nil {
        t.Fatal(err)
    }
-   if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
+   if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
        t.Fatal(err)
    }

@@ -27,7 +27,7 @@ func TestIntegrationConfig(t *testing.T) {

    t.Run("save and load round-trip", func(t *testing.T) {
        models := []string{"llama3.2", "mistral", "qwen2.5"}
-       if err := saveIntegration("claude", models); err != nil {
+       if err := SaveIntegration("claude", models); err != nil {
            t.Fatal(err)
        }

@@ -48,7 +48,7 @@ func TestIntegrationConfig(t *testing.T) {

    t.Run("save and load aliases", func(t *testing.T) {
        models := []string{"llama3.2"}
-       if err := saveIntegration("claude", models); err != nil {
+       if err := SaveIntegration("claude", models); err != nil {
            t.Fatal(err)
        }
        aliases := map[string]string{
@@ -74,14 +74,14 @@ func TestIntegrationConfig(t *testing.T) {
    })

    t.Run("saveIntegration preserves aliases", func(t *testing.T) {
-       if err := saveIntegration("claude", []string{"model-a"}); err != nil {
+       if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
            t.Fatal(err)
        }
        if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
            t.Fatal(err)
        }

-       if err := saveIntegration("claude", []string{"model-b"}); err != nil {
+       if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
            t.Fatal(err)
        }
        config, err := loadIntegration("claude")
@@ -94,7 +94,7 @@ func TestIntegrationConfig(t *testing.T) {
    })

    t.Run("defaultModel returns first model", func(t *testing.T) {
-       saveIntegration("codex", []string{"model-a", "model-b"})
+       SaveIntegration("codex", []string{"model-a", "model-b"})

        config, _ := loadIntegration("codex")
        defaultModel := ""
@@ -118,7 +118,7 @@ func TestIntegrationConfig(t *testing.T) {
    })

    t.Run("app name is case-insensitive", func(t *testing.T) {
-       saveIntegration("Claude", []string{"model-x"})
+       SaveIntegration("Claude", []string{"model-x"})

        config, err := loadIntegration("claude")
        if err != nil {
@@ -134,8 +134,8 @@ func TestIntegrationConfig(t *testing.T) {
    })

    t.Run("multiple integrations in single file", func(t *testing.T) {
-       saveIntegration("app1", []string{"model-1"})
-       saveIntegration("app2", []string{"model-2"})
+       SaveIntegration("app1", []string{"model-1"})
+       SaveIntegration("app2", []string{"model-2"})

        config1, _ := loadIntegration("app1")
        config2, _ := loadIntegration("app2")
@@ -172,8 +172,8 @@ func TestListIntegrations(t *testing.T) {
    })

    t.Run("returns all saved integrations", func(t *testing.T) {
-       saveIntegration("claude", []string{"model-1"})
-       saveIntegration("droid", []string{"model-2"})
+       SaveIntegration("claude", []string{"model-1"})
+       SaveIntegration("droid", []string{"model-2"})

        configs, err := listIntegrations()
        if err != nil {
@@ -261,7 +261,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
    tmpDir := t.TempDir()
    setTestHome(t, tmpDir)

-   if err := saveIntegration("test", nil); err != nil {
+   if err := SaveIntegration("test", nil); err != nil {
        t.Fatalf("saveIntegration with nil models failed: %v", err)
    }

@@ -281,7 +281,7 @@ func TestSaveIntegration_EmptyAppName(t *testing.T) {
    tmpDir := t.TempDir()
    setTestHome(t, tmpDir)

-   err := saveIntegration("", []string{"model"})
+   err := SaveIntegration("", []string{"model"})
    if err == nil {
        t.Error("expected error for empty app name, got nil")
    }
@@ -511,7 +511,7 @@ func TestMigrateConfig(t *testing.T) {
    os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)

    // load triggers migration, then save should write to new path
-   if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
+   if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
        t.Fatal(err)
    }

@@ -3,6 +3,7 @@ package config
import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "os"
    "os/exec"
@@ -51,6 +52,16 @@ func (d *Droid) Run(model string, args []string) error {
    if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
        models = config.Models
    }
+   var err error
+   models, err = resolveEditorModels("droid", models, func() ([]string, error) {
+       return selectModels(context.Background(), "droid", "")
+   })
+   if errors.Is(err, errCancelled) {
+       return nil
+   }
+   if err != nil {
+       return err
+   }
    if err := d.Edit(models); err != nil {
        return fmt.Errorf("setup failed: %w", err)
    }

@@ -4,7 +4,7 @@ import (
    "context"
    "errors"
    "fmt"
    "maps"
    "net/http"
    "os"
    "os/exec"
    "runtime"
@@ -13,6 +13,7 @@ import (
    "time"

    "github.com/ollama/ollama/api"
+   internalcloud "github.com/ollama/ollama/internal/cloud"
    "github.com/ollama/ollama/progress"
    "github.com/spf13/cobra"
)
@@ -52,6 +53,7 @@ type AliasConfigurer interface {
var integrations = map[string]Runner{
    "claude":   &Claude{},
    "clawdbot": &Openclaw{},
+   "cline":    &Cline{},
    "codex":    &Codex{},
    "moltbot":  &Openclaw{},
    "droid":    &Droid{},
@@ -63,12 +65,33 @@ var integrations = map[string]Runner{
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []ModelItem{
+   {Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
    {Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
    {Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
    {Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
    {Name: "qwen3:8b", Description: "Efficient all-purpose assistant", Recommended: true},
}

+// cloudModelLimits maps cloud model base names to their token limits.
+// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
+var cloudModelLimits = map[string]cloudModelLimit{
+   "minimax-m2.5":        {Context: 204_800, Output: 128_000},
+   "cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
+   "deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
+   "deepseek-v3.2":       {Context: 163_840, Output: 65_536},
+   "glm-4.6":             {Context: 202_752, Output: 131_072},
+   "glm-4.7":             {Context: 202_752, Output: 131_072},
+   "gpt-oss:120b":        {Context: 131_072, Output: 131_072},
+   "gpt-oss:20b":         {Context: 131_072, Output: 131_072},
+   "kimi-k2:1t":          {Context: 262_144, Output: 262_144},
+   "kimi-k2.5":           {Context: 262_144, Output: 262_144},
+   "kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
+   "nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
+   "qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
+   "qwen3-coder-next":    {Context: 262_144, Output: 32_768},
+   "qwen3-next:80b":      {Context: 262_144, Output: 32_768},
+}
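
Most keys in `cloudModelLimits` are base names without the `:cloud` tag; how the lookup normalizes a model reference is not visible in this hunk, so the helper below is only a sketch of one plausible approach (it assumes the package's `cloudModelLimit` type and a `strings` import):

```go
// Hypothetical lookup against the cloudModelLimits map added above; the real
// normalization used elsewhere in this change is not shown in this hunk.
func lookupCloudLimit(model string) (cloudModelLimit, bool) {
	// Exact match first: some keys keep a size tag, e.g. "gpt-oss:120b".
	if lim, ok := cloudModelLimits[model]; ok {
		return lim, true
	}
	// Otherwise try the base name with the ":cloud" tag stripped,
	// so "kimi-k2.5:cloud" falls back to the "kimi-k2.5" entry.
	lim, ok := cloudModelLimits[strings.TrimSuffix(model, ":cloud")]
	return lim, ok
}
```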

// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
var recommendedVRAM = map[string]string{
    "glm-4.7-flash": "~25GB",
@@ -79,16 +102,17 @@ var recommendedVRAM = map[string]string{
var integrationAliases = map[string]bool{
    "clawdbot": true,
    "moltbot":  true,
    "pi":       true,
}

// integrationInstallHints maps integration names to install URLs.
var integrationInstallHints = map[string]string{
    "claude":   "https://code.claude.com/docs/en/quickstart",
    "cline":    "https://cline.bot/cli",
    "openclaw": "https://docs.openclaw.ai",
    "codex":    "https://developers.openai.com/codex/cli/",
    "droid":    "https://docs.factory.ai/cli/getting-started/quickstart",
    "opencode": "https://opencode.ai",
    "pi":       "https://github.com/badlogic/pi-mono",
}

// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
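
The body of `hyperlink` falls outside the lines shown here; for reference, a typical OSC 8 wrapper looks like the sketch below (the actual implementation in this change may differ):

```go
// Plausible shape of the hyperlink helper described above: the first escape
// sequence opens a hyperlink to url, the empty OSC 8 sequence closes it.
// Terminals without OSC 8 support simply render the plain text.
func hyperlink(url, text string) string {
	return "\x1b]8;;" + url + "\x07" + text + "\x1b]8;;\x07"
}
```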
@@ -106,13 +130,21 @@ type IntegrationInfo struct {
// integrationDescriptions maps integration names to short descriptions.
var integrationDescriptions = map[string]string{
    "claude":   "Anthropic's coding tool with subagents",
+   "cline":    "Autonomous coding agent with parallel execution",
    "codex":    "OpenAI's open-source coding agent",
    "openclaw": "Personal AI with 100+ skills",
    "droid":    "Factory's coding agent across terminal and IDEs",
    "opencode": "Anomaly's open-source coding agent",
+   "pi":       "Minimal AI agent toolkit with plugin support",
}

-// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
+// integrationOrder defines a custom display order for integrations.
+// Integrations listed here are placed at the end in the given order;
+// all others appear first, sorted alphabetically.
+var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
+
+// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
+// with integrationOrder entries placed at the end.
func ListIntegrationInfos() []IntegrationInfo {
    var result []IntegrationInfo
    for name, r := range integrations {
@@ -125,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
            Description: integrationDescriptions[name],
        })
    }

+   orderRank := make(map[string]int, len(integrationOrder))
+   for i, name := range integrationOrder {
+       orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
+   }
+
    slices.SortFunc(result, func(a, b IntegrationInfo) int {
+       aRank, bRank := orderRank[a.Name], orderRank[b.Name]
+       // Both have custom order: sort by their rank
+       if aRank > 0 && bRank > 0 {
+           return aRank - bRank
+       }
+       // Only one has custom order: it goes last
+       if aRank > 0 {
+           return 1
+       }
+       if bRank > 0 {
+           return -1
+       }
+       // Neither has custom order: alphabetical
        return strings.Compare(a.Name, b.Name)
    })
    return result
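
To make the comparator concrete: names absent from `integrationOrder` sort alphabetically and come first, while pinned names go last in their fixed order, so a registry of claude, codex, opencode, droid, pi, and cline lists as claude, codex, opencode, droid, pi, cline. A standalone sketch of the same rule:

```go
// Standalone sketch of the ordering rule used by ListIntegrationInfos above.
package main

import (
	"fmt"
	"slices"
	"strings"
)

func main() {
	integrationOrder := []string{"opencode", "droid", "pi", "cline"}
	names := []string{"pi", "claude", "cline", "codex", "droid", "opencode"}

	orderRank := make(map[string]int, len(integrationOrder))
	for i, name := range integrationOrder {
		orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
	}

	slices.SortFunc(names, func(a, b string) int {
		aRank, bRank := orderRank[a], orderRank[b]
		if aRank > 0 && bRank > 0 {
			return aRank - bRank // both pinned: fixed order
		}
		if aRank > 0 {
			return 1 // pinned names go last
		}
		if bRank > 0 {
			return -1
		}
		return strings.Compare(a, b) // neither pinned: alphabetical
	})

	fmt.Println(names) // [claude codex opencode droid pi cline]
}
```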
@@ -163,14 +214,45 @@ func IsIntegrationInstalled(name string) bool {
	case "droid":
		_, err := exec.LookPath("droid")
		return err == nil
	case "cline":
		_, err := exec.LookPath("cline")
		return err == nil
	case "opencode":
		_, err := exec.LookPath("opencode")
		return err == nil
	case "pi":
		_, err := exec.LookPath("pi")
		return err == nil
	default:
		return true // Assume installed for unknown integrations
	}
}

// AutoInstallable returns true if the integration can be automatically
// installed when not found (e.g. via npm).
func AutoInstallable(name string) bool {
	switch strings.ToLower(name) {
	case "openclaw", "clawdbot", "moltbot":
		return true
	default:
		return false
	}
}

// EnsureInstalled checks if an auto-installable integration is present and
// offers to install it if missing. Returns nil for non-auto-installable
// integrations or when the binary is already on PATH.
func EnsureInstalled(name string) error {
	if !AutoInstallable(name) {
		return nil
	}
	if IsIntegrationInstalled(name) {
		return nil
	}
	_, err := ensureOpenclawInstalled()
	return err
}

// IsEditorIntegration returns true if the named integration uses multi-model
// selection (implements the Editor interface).
func IsEditorIntegration(name string) bool {
@@ -191,7 +273,8 @@ type ModelItem struct {
}

// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)

// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
@@ -213,6 +296,11 @@ func SelectModelWithSelectors(ctx context.Context, selector SingleSelector) (stri
		existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, client)
	if cloudDisabled {
		existing = filterCloudModels(existing)
	}

	lastModel := LastModel()
	var preChecked []string
	if lastModel != "" {
@@ -221,11 +309,15 @@ func SelectModelWithSelectors(ctx context.Context, selector SingleSelector) (stri

	items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)

	if cloudDisabled {
		items = filterCloudItems(items)
	}

	if len(items) == 0 {
		return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
	}

	selected, err := selector("Select model to run:", items)
	selected, err := selector("Select model to run:", items, "")
	if err != nil {
		return "", err
	}
@@ -335,13 +427,11 @@ func selectIntegration() (string, error) {
		return "", fmt.Errorf("no integrations available")
	}

	names := slices.Sorted(maps.Keys(integrations))
	var items []ModelItem
	for _, name := range names {
	for name, r := range integrations {
		if integrationAliases[name] {
			continue
		}
		r := integrations[name]
		description := r.String()
		if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
			description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
@@ -349,7 +439,25 @@ func selectIntegration() (string, error) {
		items = append(items, ModelItem{Name: name, Description: description})
	}

	return DefaultSingleSelector("Select integration:", items)
	orderRank := make(map[string]int, len(integrationOrder))
	for i, name := range integrationOrder {
		orderRank[name] = i + 1
	}
	slices.SortFunc(items, func(a, b ModelItem) int {
		aRank, bRank := orderRank[a.Name], orderRank[b.Name]
		if aRank > 0 && bRank > 0 {
			return aRank - bRank
		}
		if aRank > 0 {
			return 1
		}
		if bRank > 0 {
			return -1
		}
		return strings.Compare(a.Name, b.Name)
	})

	return DefaultSingleSelector("Select integration:", items, "")
}

// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
@@ -374,6 +482,11 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
		existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, client)
	if cloudDisabled {
		existing = filterCloudModels(existing)
	}

	var preChecked []string
	if saved, err := loadIntegration(name); err == nil {
		preChecked = saved.Models
@@ -383,6 +496,10 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single

	items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)

	if cloudDisabled {
		items = filterCloudItems(items)
	}

	if len(items) == 0 {
		return nil, fmt.Errorf("no models available")
	}
@@ -398,7 +515,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
	if _, ok := r.(AliasConfigurer); ok {
		prompt = fmt.Sprintf("Select Primary model for %s:", r)
	}
	model, err := single(prompt, items)
	model, err := single(prompt, items, current)
	if err != nil {
		return nil, err
	}
@@ -489,8 +606,17 @@ func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]b
		})
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, client)
	if cloudDisabled {
		existing = filterCloudModels(existing)
	}

	items, _, existingModels, cloudModels := buildModelList(existing, nil, "")

	if cloudDisabled {
		items = filterCloudItems(items)
	}

	if len(items) == 0 {
		return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
	}
@@ -519,6 +645,9 @@ func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]
	if len(selectedCloudModels) == 0 {
		return nil
	}
	if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
		return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
	}

	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
@@ -651,25 +780,6 @@ func LaunchIntegrationWithModel(name, modelName string) error {
	return runIntegration(name, modelName, nil)
}

// SaveIntegrationModel saves the model for an integration.
func SaveIntegrationModel(name, modelName string) error {
	// Load existing models and prepend the new one
	var models []string
	if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
		models = existing.Models
		// Remove the model if it already exists
		for i, m := range models {
			if m == modelName {
				models = append(models[:i], models[i+1:]...)
				break
			}
		}
	}
	// Prepend the new model
	models = append([]string{modelName}, models...)
	return saveIntegration(name, models)
}
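`SaveIntegrationModel` keeps the saved list in most-recently-used order: the model is removed if already present, then pushed to the front. A self-contained sketch of that reordering on its own:

```go
package main

import "fmt"

// prependMRU reproduces the list manipulation above: drop the name if it
// already appears, then prepend it so index 0 is always the latest choice.
func prependMRU(models []string, name string) []string {
	for i, m := range models {
		if m == name {
			models = append(models[:i], models[i+1:]...)
			break
		}
	}
	return append([]string{name}, models...)
}

func main() {
	fmt.Println(prependMRU([]string{"a", "b", "c"}, "b")) // [b a c]
	fmt.Println(prependMRU([]string{"a", "b"}, "z"))      // [z a b]
}
```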
// SaveAndEditIntegration saves the models for an Editor integration and runs its Edit method
// to write the integration's config files.
func SaveAndEditIntegration(name string, models []string) error {
@@ -677,7 +787,7 @@ func SaveAndEditIntegration(name string, models []string) error {
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}
	if err := saveIntegration(name, models); err != nil {
	if err := SaveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}
	if editor, isEditor := r.(Editor); isEditor {
@@ -688,6 +798,29 @@ func SaveAndEditIntegration(name string, models []string) error {
	return nil
}

// resolveEditorModels filters out cloud-disabled models before editor launch.
// If no models remain, it invokes picker to collect a valid replacement list.
func resolveEditorModels(name string, models []string, picker func() ([]string, error)) ([]string, error) {
	filtered := filterDisabledCloudModels(models)
	if len(filtered) != len(models) {
		if err := SaveIntegration(name, filtered); err != nil {
			return nil, fmt.Errorf("failed to save: %w", err)
		}
	}
	if len(filtered) > 0 {
		return filtered, nil
	}

	selected, err := picker()
	if err != nil {
		return nil, err
	}
	if err := SaveIntegration(name, selected); err != nil {
		return nil, fmt.Errorf("failed to save: %w", err)
	}
	return selected, nil
}

// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
	r, ok := integrations[name]
@@ -722,7 +855,7 @@ func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single
		}
	}

	if err := saveIntegration(name, models); err != nil {
	if err := SaveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}

@@ -755,10 +888,12 @@ Without arguments, this is equivalent to running 'ollama' directly.

Supported integrations:
  claude     Claude Code
  cline      Cline
  codex      Codex
  droid      Droid
  opencode   OpenCode
  openclaw   OpenClaw (aliases: clawdbot, moltbot)
  pi         Pi

Examples:
  ollama launch
@@ -816,6 +951,14 @@ Examples:
		return fmt.Errorf("unknown integration: %s", name)
	}

	if err := EnsureInstalled(name); err != nil {
		return err
	}

	if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
		modelFlag = ""
	}

	// Handle AliasConfigurer integrations (claude, codex)
	if ac, ok := r.(AliasConfigurer); ok {
		client, err := api.ClientFromEnvironment()
@@ -843,7 +986,7 @@ Examples:
			model = cfg.Models[0]
			// AliasConfigurer integrations use single model; sanitize if multiple
			if len(cfg.Models) > 1 {
				_ = saveIntegration(name, []string{model})
				_ = SaveIntegration(name, []string{model})
			}
		}
	}
@@ -855,7 +998,9 @@ Examples:

	// Validate saved model still exists
	if model != "" && modelFlag == "" {
		if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
		if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
			model = ""
		} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
			fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
			if err := ShowOrPull(cmd.Context(), client, model); err != nil {
				model = ""
@@ -863,18 +1008,16 @@ Examples:
		}
	}

	// If no valid model or --config flag, show picker
	if model == "" || configFlag {
		aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
		if errors.Is(err, errCancelled) {
			return nil
		}
		if err != nil {
			return err
		}
		model = aliases["primary"]
		existingAliases = aliases
	// Show picker so user can change model (skip when --model flag provided)
	aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
	if errors.Is(err, errCancelled) {
		return nil
	}
	if err != nil {
		return err
	}
	model = aliases["primary"]
	existingAliases = aliases

	// Ensure cloud models are authenticated
	if isCloudModel(cmd.Context(), client, model) {
@@ -887,7 +1030,7 @@ Examples:
	if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
		fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
	}
	if err := saveIntegration(name, []string{model}); err != nil {
	if err := SaveIntegration(name, []string{model}); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}

@@ -925,11 +1068,24 @@ Examples:
			}
		}
	}
	} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
		return runIntegration(name, saved.Models[0], passArgs)
		models = filterDisabledCloudModels(models)
		if len(models) == 0 {
			var err error
			models, err = selectModels(cmd.Context(), name, "")
			if errors.Is(err, errCancelled) {
				return nil
			}
			if err != nil {
				return err
			}
		}
	} else {
		current := ""
		if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
			current = saved.Models[0]
		}
		var err error
		models, err = selectModels(cmd.Context(), name, "")
		models, err = selectModels(cmd.Context(), name, current)
		if errors.Is(err, errCancelled) {
			return nil
		}
@@ -953,7 +1109,7 @@ Examples:
		}
	}

	if err := saveIntegration(name, models); err != nil {
	if err := SaveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}

@@ -1027,7 +1183,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
			continue
		}
		items = append(items, rec)
		if strings.HasSuffix(rec.Name, ":cloud") {
		if isCloudModelName(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}
@@ -1062,7 +1218,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "install?")
			parts = append(parts, "(not downloaded)")
			items[i].Description = strings.Join(parts, ", ")
		}
	}
@@ -1132,7 +1288,55 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
	return items, preChecked, existingModels, cloudModels
}

// isCloudModel checks if a model is a cloud model using the Show API.
// IsCloudModelDisabled reports whether the given model name looks like a cloud
// model and cloud features are currently disabled on the server.
func IsCloudModelDisabled(ctx context.Context, name string) bool {
	if !isCloudModelName(name) {
		return false
	}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return false
	}
	disabled, _ := cloudStatusDisabled(ctx, client)
	return disabled
}

func isCloudModelName(name string) bool {
	return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}

func filterCloudModels(existing []modelInfo) []modelInfo {
	filtered := existing[:0]
	for _, m := range existing {
		if !m.Remote {
			filtered = append(filtered, m)
		}
	}
	return filtered
}

// filterDisabledCloudModels removes cloud models from a list when cloud is disabled.
func filterDisabledCloudModels(models []string) []string {
	var filtered []string
	for _, m := range models {
		if !IsCloudModelDisabled(context.Background(), m) {
			filtered = append(filtered, m)
		}
	}
	return filtered
}

func filterCloudItems(items []ModelItem) []ModelItem {
	filtered := items[:0]
	for _, item := range items {
		if !isCloudModelName(item.Name) {
			filtered = append(filtered, item)
		}
	}
	return filtered
}
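All three filters above hinge on the same suffix test. A runnable sketch of the check plus one filter pass; `isCloud` is a stand-in for the diff's `isCloudModelName`:

```go
package main

import (
	"fmt"
	"strings"
)

// isCloud mirrors isCloudModelName above: anything ending in ":cloud"
// or "-cloud" is treated as a cloud (remote) model.
func isCloud(name string) bool {
	return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}

func main() {
	models := []string{"llama3.2", "glm-5:cloud", "qwen3-coder-cloud", "qwen3:8b"}
	var local []string
	for _, m := range models {
		if !isCloud(m) {
			local = append(local, m)
		}
	}
	fmt.Println(local) // [llama3.2 qwen3:8b]
}
```

Note the design split in the originals: `filterCloudModels` and `filterCloudItems` reuse the input's backing array via the `existing[:0]` idiom, which avoids an allocation but overwrites the original slice, while `filterDisabledCloudModels` builds a fresh slice, presumably because each `IsCloudModelDisabled` check may hit the server.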
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
	if client == nil {
		return false
@@ -1162,6 +1366,11 @@ func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
		existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, client)
	if cloudDisabled {
		existing = filterCloudModels(existing)
	}

	lastModel := LastModel()
	var preChecked []string
	if lastModel != "" {
@@ -1170,9 +1379,25 @@ func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {

	items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)

	if cloudDisabled {
		items = filterCloudItems(items)
	}

	return items, existingModels
}

func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
	status, err := client.CloudStatusExperimental(ctx)
	if err != nil {
		var statusErr api.StatusError
		if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
			return false, false
		}
		return false, false
	}
	return status.Cloud.Disabled, true
}
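Both error paths above report `known=false`, so callers that gate on `known && disabled` treat an unreachable server, or an older one without the status endpoint, as cloud-enabled rather than failing. A sketch of decoding the payload this helper reads; the `status` struct below is a stand-in for illustration, not the real `api` response type, and the JSON matches the fixtures used by the tests later in this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// status is a hypothetical mirror of the /api/status response shape.
type status struct {
	Cloud struct {
		Disabled bool   `json:"disabled"`
		Source   string `json:"source"`
	} `json:"cloud"`
}

func main() {
	raw := `{"cloud":{"disabled":true,"source":"config"}}`
	var s status
	if err := json.Unmarshal([]byte(raw), &s); err != nil {
		panic(err)
	}
	fmt.Println(s.Cloud.Disabled, s.Cloud.Source) // true config
}
```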
func pullModel(ctx context.Context, client *api.Client, model string) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

@@ -16,6 +16,28 @@ import (
	"github.com/spf13/cobra"
)

type stubEditorRunner struct {
	edited   [][]string
	ranModel string
}

func (s *stubEditorRunner) Run(model string, args []string) error {
	s.ranModel = model
	return nil
}

func (s *stubEditorRunner) String() string { return "StubEditor" }

func (s *stubEditorRunner) Paths() []string { return nil }

func (s *stubEditorRunner) Edit(models []string) error {
	cloned := append([]string(nil), models...)
	s.edited = append(s.edited, cloned)
	return nil
}

func (s *stubEditorRunner) Models() []string { return nil }

func TestIntegrationLookup(t *testing.T) {
	tests := []struct {
		name string
@@ -149,6 +171,10 @@ func TestLaunchCmd_TUICallback(t *testing.T) {
	})

	t.Run("integration arg bypasses TUI", func(t *testing.T) {
		srv := httptest.NewServer(http.NotFoundHandler())
		defer srv.Close()
		t.Setenv("OLLAMA_HOST", srv.URL)

		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
@@ -394,14 +420,14 @@ func names(items []ModelItem) []string {
func TestBuildModelList_NoExistingModels(t *testing.T) {
	items, _, _, _ := buildModelList(nil, nil, "")

	want := []string{"glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
	want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
	if diff := cmp.Diff(want, names(items)); diff != "" {
		t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
	}

	for _, item := range items {
		if !strings.HasSuffix(item.Description, "install?") {
			t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
		if !strings.HasSuffix(item.Description, "(not downloaded)") {
			t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
		}
	}
}
@@ -416,7 +442,7 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
	got := names(items)

	// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
	want := []string{"glm-4.7-flash", "qwen3:8b", "glm-5:cloud", "kimi-k2.5:cloud", "llama3.2", "qwen2.5"}
	want := []string{"glm-4.7-flash", "qwen3:8b", "minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "llama3.2", "qwen2.5"}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
	}
@@ -432,7 +458,7 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
	got := names(items)

	// All recs pinned at top (cloud before local in mixed case), then non-recs
	want := []string{"glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
	want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
	}
@@ -463,12 +489,12 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
	for _, item := range items {
		switch item.Name {
		case "glm-4.7-flash", "glm-5:cloud":
			if strings.HasSuffix(item.Description, "install?") {
				t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
			if strings.HasSuffix(item.Description, "(not downloaded)") {
				t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
			}
		case "kimi-k2.5:cloud", "qwen3:8b":
			if !strings.HasSuffix(item.Description, "install?") {
				t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
		case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
			if !strings.HasSuffix(item.Description, "(not downloaded)") {
				t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
			}
		}
	}
@@ -486,7 +512,7 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
	// glm-4.7-flash and glm-5:cloud are installed so they sort normally;
	// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
	// All recs: cloud first in mixed case, then local, in rec order within each
	want := []string{"glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
	want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
	}
@@ -504,15 +530,15 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
	// kimi-k2.5:cloud is installed so it sorts normally;
	// the rest of the recommendations are not installed so they go to the bottom
	// All recs pinned at top (cloud first in mixed case), then non-recs
	want := []string{"glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
	want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
	}

	for _, item := range items {
		if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
			if !strings.HasSuffix(item.Description, "install?") {
				t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
			if !strings.HasSuffix(item.Description, "(not downloaded)") {
				t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
			}
		}
	}
@@ -648,7 +674,7 @@ func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
	lastRecIdx := -1
	firstNonRecIdx := len(got)
	for i, name := range got {
		isRec := name == "glm-4.7-flash" || name == "qwen3:8b" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud"
		isRec := name == "glm-4.7-flash" || name == "qwen3:8b" || name == "minimax-m2.5:cloud" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud"
		if isRec && i > lastRecIdx {
			lastRecIdx = i
		}
@@ -680,7 +706,7 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
	setTestHome(t, tmpDir)

	// Save a config for opencode so it looks like a previous launch
	if err := saveIntegration("opencode", []string{"llama3.2"}); err != nil {
	if err := SaveIntegration("opencode", []string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

@@ -697,6 +723,137 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
	}
}

func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) {
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/status":
			fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)

	pickerCalled := false
	models, err := resolveEditorModels("opencode", []string{"glm-5:cloud"}, func() ([]string, error) {
		pickerCalled = true
		return []string{"llama3.2"}, nil
	})
	if err != nil {
		t.Fatalf("resolveEditorModels returned error: %v", err)
	}
	if !pickerCalled {
		t.Fatal("expected model picker to be called when all models are filtered")
	}
	if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
		t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
	}

	saved, err := loadIntegration("opencode")
	if err != nil {
		t.Fatalf("failed to reload integration config: %v", err)
	}
	if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
		t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
	}
}

func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *testing.T) {
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/status":
			fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)

	pickerCalled := false
	models, err := resolveEditorModels("droid", []string{"llama3.2", "glm-5:cloud"}, func() ([]string, error) {
		pickerCalled = true
		return []string{"qwen3:8b"}, nil
	})
	if err != nil {
		t.Fatalf("resolveEditorModels returned error: %v", err)
	}
	if pickerCalled {
		t.Fatal("picker should not be called when a local model remains")
	}
	if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
		t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
	}

	saved, err := loadIntegration("droid")
	if err != nil {
		t.Fatalf("failed to reload integration config: %v", err)
	}
	if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
		t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
	}
}

func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
		t.Fatalf("failed to seed saved config: %v", err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/status":
			fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
		case "/api/show":
			fmt.Fprintf(w, `{"model":"llama3.2"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)

	stub := &stubEditorRunner{}
	old, existed := integrations["stubeditor"]
	integrations["stubeditor"] = stub
	defer func() {
		if existed {
			integrations["stubeditor"] = old
		} else {
			delete(integrations, "stubeditor")
		}
	}()

	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
	if err := cmd.Execute(); err != nil {
		t.Fatalf("launch command failed: %v", err)
	}

	saved, err := loadIntegration("stubeditor")
	if err != nil {
		t.Fatalf("failed to reload integration config: %v", err)
	}
	if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
		t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
	}
	if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
		t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
	}
	if stub.ranModel != "llama3.2" {
		t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
	}
}

func TestAliasConfigurerInterface(t *testing.T) {
	t.Run("claude implements AliasConfigurer", func(t *testing.T) {
		claude := &Claude{}
@@ -1091,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
		}
	})

	t.Run("sorted by name", func(t *testing.T) {
	t.Run("sorted with custom order at end", func(t *testing.T) {
		// integrationOrder entries (opencode, droid, pi, cline) should appear last, in that order.
		// All other entries should be sorted alphabetically before them.
		orderRank := make(map[string]int)
		for i, name := range integrationOrder {
			orderRank[name] = i + 1
		}
		for i := 1; i < len(infos); i++ {
			if infos[i-1].Name >= infos[i].Name {
				t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
			aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
			switch {
			case aRank == 0 && bRank == 0:
				if infos[i-1].Name >= infos[i].Name {
					t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
				}
			case aRank > 0 && bRank == 0:
				t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
			case aRank > 0 && bRank > 0:
				if aRank >= bRank {
					t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
				}
			}
		}
	})
@@ -1234,7 +1407,7 @@ func TestIntegrationModels(t *testing.T) {
	})

	t.Run("returns all saved models", func(t *testing.T) {
		if err := saveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
		if err := SaveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
			t.Fatal(err)
		}
		got := IntegrationModels("droid")

@@ -1,69 +1,287 @@
package config

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"slices"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/types/model"
)

const defaultGatewayPort = 18789

// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second

type Openclaw struct{}

func (c *Openclaw) String() string { return "OpenClaw" }

func (c *Openclaw) Run(model string, args []string) error {
	bin := "openclaw"
	if _, err := exec.LookPath(bin); err != nil {
		bin = "clawdbot"
		if _, err := exec.LookPath(bin); err != nil {
			return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
	bin, err := ensureOpenclawInstalled()
	if err != nil {
		return err
	}

	firstLaunch := true
	if integrationConfig, err := loadIntegration("openclaw"); err == nil {
		firstLaunch = !integrationConfig.Onboarded
	}

	if firstLaunch {
		fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
		fmt.Fprintf(os.Stderr, "  OpenClaw can read files and run actions when tools are enabled.\n")
		fmt.Fprintf(os.Stderr, "  A bad prompt can trick it into doing unsafe things.\n\n")
		fmt.Fprintf(os.Stderr, "%s  Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)

		ok, err := confirmPrompt("I understand the risks. Continue?")
		if err != nil {
			return err
		}
		if !ok {
			return nil
		}
	}

	models := []string{model}
	if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
		models = config.Models
	} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
		models = config.Models
	}
	if err := c.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}

	if !c.onboarded() {
		// Onboarding not completed: run it (model already set via Edit)
		// Use "ollama" as gateway token for simple local access
		fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
		fmt.Fprintf(os.Stderr, "%s  Model: %s%s\n\n", ansiGray, model, ansiReset)

		cmd := exec.Command(bin, "onboard",
			"--non-interactive",
			"--accept-risk",
			"--auth-choice", "skip",
			"--gateway-token", "ollama",
			"--install-daemon",
			"--skip-channels",
			"--skip-skills",
		)
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		return cmd.Run()
		if err := cmd.Run(); err != nil {
			return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
		}

		patchDeviceScopes()

		// Onboarding overwrites openclaw.json, so re-apply the model config
		// that Edit() wrote before Run() was called.
		if err := c.Edit([]string{model}); err != nil {
			fmt.Fprintf(os.Stderr, "%s  Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
		}
	}

	// Onboarding completed: run gateway
	cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
	cmd.Stdin = os.Stdin
	if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
		if ensureWebSearchPlugin() {
			registerWebSearchPlugin()
		}
	}

	// Capture output to detect "already running" message
	var outputBuf bytes.Buffer
	cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
	cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
	if firstLaunch {
		fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
	} else {
		fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
	}

	err := cmd.Run()
	if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
		fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
	// When extra args are passed through, run exactly what the user asked for
	// after setup and skip the built-in gateway+TUI convenience flow.
	if len(args) > 0 {
		cmd := exec.Command(bin, args...)
		cmd.Env = openclawEnv()
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return windowsHint(err)
		}
		if firstLaunch {
			if err := integrationOnboarded("openclaw"); err != nil {
				return fmt.Errorf("failed to save onboarding state: %w", err)
			}
		}
		return nil
	}
	return err

	token, port := c.gatewayInfo()
	addr := fmt.Sprintf("localhost:%d", port)

	// If the gateway is already running (e.g. via the daemon), restart it
	// so it picks up any config changes from Edit() above (model, provider, etc.).
	if portOpen(addr) {
		restart := exec.Command(bin, "daemon", "restart")
		restart.Env = openclawEnv()
		if err := restart.Run(); err != nil {
			fmt.Fprintf(os.Stderr, "%s  Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
		}
		if !waitForPort(addr, 10*time.Second) {
			fmt.Fprintf(os.Stderr, "%s  Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
		}
	}

	// If the gateway isn't running, start it as a background child process.
	if !portOpen(addr) {
		gw := exec.Command(bin, "gateway", "run", "--force")
		gw.Env = openclawEnv()
		if err := gw.Start(); err != nil {
			return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
		}
		defer func() {
			if gw.Process != nil {
				_ = gw.Process.Kill()
				_ = gw.Wait()
			}
		}()
	}

	fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
	if !waitForPort(addr, 30*time.Second) {
		return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
	}

	printOpenclawReady(bin, token, port, firstLaunch)

	tuiArgs := []string{"tui"}
	if firstLaunch {
		tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
	}
	tui := exec.Command(bin, tuiArgs...)
	tui.Env = openclawEnv()
	tui.Stdin = os.Stdin
	tui.Stdout = os.Stdout
	tui.Stderr = os.Stderr
	if err := tui.Run(); err != nil {
		return windowsHint(err)
	}

	if firstLaunch {
		if err := integrationOnboarded("openclaw"); err != nil {
			return fmt.Errorf("failed to save onboarding state: %w", err)
		}
	}
	return nil
}

// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) {
	port = defaultGatewayPort
	home, err := os.UserHomeDir()
	if err != nil {
		return "", port
	}

	for _, path := range []string{
		filepath.Join(home, ".openclaw", "openclaw.json"),
		filepath.Join(home, ".clawdbot", "clawdbot.json"),
	} {
		data, err := os.ReadFile(path)
		if err != nil {
			continue
		}
		var config map[string]any
		if json.Unmarshal(data, &config) != nil {
			continue
		}
		gw, _ := config["gateway"].(map[string]any)
		if p, ok := gw["port"].(float64); ok && p > 0 {
			port = int(p)
		}
		auth, _ := gw["auth"].(map[string]any)
		if t, _ := auth["token"].(string); t != "" {
			token = t
		}
		return token, port
	}
	return "", port
}
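`gatewayInfo` walks untyped maps because the config file is user-editable JSON. A self-contained sketch against the same shape, using the defaults that appear elsewhere in this diff (port 18789 from `defaultGatewayPort`, token "ollama" from the onboarding flags):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	raw := []byte(`{"gateway":{"port":18789,"auth":{"token":"ollama"}}}`)
	var config map[string]any
	if err := json.Unmarshal(raw, &config); err != nil {
		panic(err)
	}
	gw, _ := config["gateway"].(map[string]any)
	port := 0
	if p, ok := gw["port"].(float64); ok {
		port = int(p) // JSON numbers decode into any as float64
	}
	auth, _ := gw["auth"].(map[string]any)
	token, _ := auth["token"].(string)
	fmt.Println(port, token) // 18789 ollama
}
```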
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
	u := fmt.Sprintf("http://localhost:%d", port)
	if token != "" {
		u += "/#token=" + url.QueryEscape(token)
	}

	fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
	fmt.Fprintf(os.Stderr, "  Open the Web UI:\n")
	fmt.Fprintf(os.Stderr, "  %s\n\n", hyperlink(u, u))

	if firstLaunch {
		fmt.Fprintf(os.Stderr, "%s  Quick start:%s\n", ansiBold, ansiReset)
		fmt.Fprintf(os.Stderr, "%s    /help  see all commands%s\n", ansiGray, ansiReset)
		fmt.Fprintf(os.Stderr, "%s    %s configure --section channels  connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
		fmt.Fprintf(os.Stderr, "%s    %s skills  browse and install skills%s\n\n", ansiGray, bin, ansiReset)
		fmt.Fprintf(os.Stderr, "%s  The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
		fmt.Fprintf(os.Stderr, "%s  Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
	} else {
		fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
	}
}

// openclawEnv returns the current environment with provider API keys cleared
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
func openclawEnv() []string {
	clear := map[string]bool{
		"ANTHROPIC_API_KEY": true,
		"ANTHROPIC_OAUTH_TOKEN": true,
		"OPENAI_API_KEY": true,
		"GEMINI_API_KEY": true,
		"MISTRAL_API_KEY": true,
		"GROQ_API_KEY": true,
		"XAI_API_KEY": true,
		"OPENROUTER_API_KEY": true,
	}
	var env []string
	for _, e := range os.Environ() {
		key, _, _ := strings.Cut(e, "=")
		if !clear[key] {
			env = append(env, e)
		}
	}
	return env
}

// portOpen checks if a TCP port is currently accepting connections.
func portOpen(addr string) bool {
	conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
	if err != nil {
		return false
	}
	conn.Close()
	return true
}

func waitForPort(addr string, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
		if err == nil {
			conn.Close()
			return true
		}
		time.Sleep(250 * time.Millisecond)
	}
	return false
}

func windowsHint(err error) error {
	if runtime.GOOS != "windows" {
		return err
	}
	return fmt.Errorf("%w\n\n"+
		"OpenClaw runs best on WSL2.\n"+
		"Quick setup: wsl --install\n"+
		"Guide: https://docs.openclaw.ai/windows", err)
}

// onboarded checks if OpenClaw onboarding wizard was completed
@@ -95,6 +313,144 @@ func (c *Openclaw) onboarded() bool {
	return lastRunAt != ""
}

// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}

	deviceID := readLocalDeviceID(home)
	if deviceID == "" {
		return
	}

	path := filepath.Join(home, ".openclaw", "devices", "paired.json")
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}

	var devices map[string]map[string]any
	if err := json.Unmarshal(data, &devices); err != nil {
		return
	}

	dev, ok := devices[deviceID]
	if !ok {
		return
	}

	required := []string{
		"operator.read",
		"operator.admin",
		"operator.approvals",
		"operator.pairing",
	}

	changed := patchScopes(dev, "scopes", required)
	if tokens, ok := dev["tokens"].(map[string]any); ok {
		for _, tok := range tokens {
			if tokenMap, ok := tok.(map[string]any); ok {
				if patchScopes(tokenMap, "scopes", required) {
					changed = true
				}
			}
		}
	}

	if !changed {
		return
	}

	out, err := json.MarshalIndent(devices, "", " ")
	if err != nil {
		return
	}
	_ = os.WriteFile(path, out, 0o600)
}

// readLocalDeviceID reads the local device ID from openclaw's identity file.
func readLocalDeviceID(home string) string {
	data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
	if err != nil {
		return ""
	}
	var auth map[string]any
	if err := json.Unmarshal(data, &auth); err != nil {
		return ""
	}
	id, _ := auth["deviceId"].(string)
	return id
}

// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
	existing, _ := obj[key].([]any)
	have := make(map[string]bool, len(existing))
	for _, s := range existing {
		if str, ok := s.(string); ok {
			have[str] = true
		}
	}
	added := false
	for _, s := range required {
		if !have[s] {
			existing = append(existing, s)
			added = true
		}
	}
	if added {
		obj[key] = existing
	}
	return added
}
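A runnable demonstration of `patchScopes` (duplicated verbatim from above) exercised on a small device map:

```go
package main

import "fmt"

// patchScopes, copied from the diff above: it appends any missing required
// scopes into obj[key] and reports whether anything changed.
func patchScopes(obj map[string]any, key string, required []string) bool {
	existing, _ := obj[key].([]any)
	have := make(map[string]bool, len(existing))
	for _, s := range existing {
		if str, ok := s.(string); ok {
			have[str] = true
		}
	}
	added := false
	for _, s := range required {
		if !have[s] {
			existing = append(existing, s)
			added = true
		}
	}
	if added {
		obj[key] = existing
	}
	return added
}

func main() {
	dev := map[string]any{"scopes": []any{"operator.read"}}
	changed := patchScopes(dev, "scopes", []string{"operator.read", "operator.admin"})
	fmt.Println(changed, dev["scopes"]) // true [operator.read operator.admin]
}
```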
func ensureOpenclawInstalled() (string, error) {
	if _, err := exec.LookPath("openclaw"); err == nil {
		return "openclaw", nil
	}
	if _, err := exec.LookPath("clawdbot"); err == nil {
		return "clawdbot", nil
	}

	if _, err := exec.LookPath("npm"); err != nil {
		return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
			"Install Node.js first:\n" +
			"  https://nodejs.org/\n\n" +
			"Then rerun:\n" +
			"  ollama launch\n" +
			"and select OpenClaw")
	}

	ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
	if err != nil {
		return "", err
	}
	if !ok {
		return "", fmt.Errorf("openclaw installation cancelled")
	}

	fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
	cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return "", fmt.Errorf("failed to install openclaw: %w", err)
	}

	if _, err := exec.LookPath("openclaw"); err != nil {
		return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
	}

	fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
	return "openclaw", nil
}

func (c *Openclaw) Paths() []string {
	home, _ := os.UserHomeDir()
	p := filepath.Join(home, ".openclaw", "openclaw.json")
@@ -149,8 +505,7 @@ func (c *Openclaw) Edit(models []string) error {
	ollama["baseUrl"] = envconfig.Host().String() + "/v1"
	// needed to register provider
	ollama["apiKey"] = "ollama-local"
	// TODO(parthsareen): potentially move to responses
	ollama["api"] = "openai-completions"
	ollama["api"] = "ollama"

	// Build map of existing models to preserve user customizations
	existingModels, _ := ollama["models"].([]any)
@@ -163,25 +518,13 @@ func (c *Openclaw) Edit(models []string) error {
		}
	}

	client, _ := api.ClientFromEnvironment()

	var newModels []any
	for _, model := range models {
		entry := map[string]any{
			"id": model,
			"name": model,
			"reasoning": false,
			"input": []any{"text"},
			"cost": map[string]any{
				"input": 0,
				"output": 0,
				"cacheRead": 0,
				"cacheWrite": 0,
			},
			// TODO(parthsareen): get these values from API
			"contextWindow": 131072,
			"maxTokens": 16384,
		}
	for _, m := range models {
		entry, _ := openclawModelConfig(context.Background(), client, m)
		// Merge existing fields (user customizations)
		if existing, ok := existingByID[model]; ok {
		if existing, ok := existingByID[m]; ok {
			for k, v := range existing {
				if _, isNew := entry[k]; !isNew {
					entry[k] = v
@@ -218,7 +561,213 @@ func (c *Openclaw) Edit(models []string) error {
	if err != nil {
		return err
	}
	return writeWithBackup(configPath, data)
	if err := writeWithBackup(configPath, data); err != nil {
		return err
	}

	// Clear any per-session model overrides so the new primary takes effect
	// immediately rather than being shadowed by a cached modelOverride.
	clearSessionModelOverride(models[0])
	return nil
}

// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
func clearSessionModelOverride(primary string) {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}
	path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}
	var sessions map[string]map[string]any
	if json.Unmarshal(data, &sessions) != nil {
		return
	}
	changed := false
	for _, sess := range sessions {
		if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
			delete(sess, "modelOverride")
			delete(sess, "providerOverride")
			sess["model"] = primary
			changed = true
		}
	}
	if !changed {
		return
	}
	out, err := json.MarshalIndent(sessions, "", " ")
	if err != nil {
		return
	}
	_ = os.WriteFile(path, out, 0o600)
}

const webSearchNpmPackage = "@ollama/openclaw-web-search"

// ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present. Returns true if the extension is available.
func ensureWebSearchPlugin() bool {
	home, err := os.UserHomeDir()
	if err != nil {
		return false
	}

	pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
	if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
		return true // already installed
	}

	npmBin, err := exec.LookPath("npm")
	if err != nil {
		return false
	}

	if err := os.MkdirAll(pluginDir, 0o755); err != nil {
		return false
	}

	// Download the tarball via `npm pack`, extract it flat into the plugin dir.
	pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
	out, err := pack.Output()
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s  Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
		return false
	}

	tgzName := strings.TrimSpace(string(out))
	tgzPath := filepath.Join(pluginDir, tgzName)
	defer os.Remove(tgzPath)

	tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
	if err := tar.Run(); err != nil {
		fmt.Fprintf(os.Stderr, "%s  Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
		return false
	}

	fmt.Fprintf(os.Stderr, "%s  ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
	return true
}

// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}
	configPath := filepath.Join(home, ".openclaw", "openclaw.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return
	}
	var config map[string]any
	if json.Unmarshal(data, &config) != nil {
		return
	}

	plugins, _ := config["plugins"].(map[string]any)
	if plugins == nil {
		plugins = make(map[string]any)
	}
	entries, _ := plugins["entries"].(map[string]any)
	if entries == nil {
		entries = make(map[string]any)
	}
	if _, ok := entries["openclaw-web-search"]; ok {
		return // already registered
	}
	entries["openclaw-web-search"] = map[string]any{"enabled": true}
	plugins["entries"] = entries
	config["plugins"] = plugins

	// Disable the built-in web search since our plugin replaces it.
	tools, _ := config["tools"].(map[string]any)
	if tools == nil {
		tools = make(map[string]any)
	}
	web, _ := tools["web"].(map[string]any)
	if web == nil {
		web = make(map[string]any)
	}
	web["search"] = map[string]any{"enabled": false}
	tools["web"] = web
	config["tools"] = tools

	out, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return
	}
	_ = os.WriteFile(configPath, out, 0o600)
}

// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
	entry := map[string]any{
		"id": modelID,
		"name": modelID,
		"input": []any{"text"},
		"cost": map[string]any{
			"input": 0,
			"output": 0,
			"cacheRead": 0,
			"cacheWrite": 0,
		},
	}

	if client == nil {
		return entry, false
	}

	showCtx := ctx
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		var cancel context.CancelFunc
		showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
		defer cancel()
	}

	resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
	if err != nil {
		return entry, false
	}

	// Set input types based on vision capability
	if slices.Contains(resp.Capabilities, model.CapabilityVision) {
		entry["input"] = []any{"text", "image"}
	}

	// Set reasoning based on thinking capability
	if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
		entry["reasoning"] = true
	}

	// Cloud models: use hardcoded limits for context/output tokens.
	// Capability detection above still applies (vision, thinking).
	if resp.RemoteModel != "" {
		if l, ok := lookupCloudModelLimit(modelID); ok {
			entry["contextWindow"] = l.Context
			entry["maxTokens"] = l.Output
		}
		return entry, true
	}

	// Extract context window from ModelInfo (local models only)
	for key, val := range resp.ModelInfo {
		if strings.HasSuffix(key, ".context_length") {
			if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
				entry["contextWindow"] = int(ctxLen)
			}
			break
		}
	}

	return entry, false
}
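For illustration only, the entry this function would plausibly build for a cloud model that reports the thinking capability and has an entry in the limits table near the top of this diff; the exact field values depend on what `/api/show` returns, so treat the numbers as an assumption:

```go
package main

import (
	"encoding/json"
	"fmt"
	"os"
)

func main() {
	// Hypothetical result for "kimi-k2-thinking" (cloud, thinking-capable,
	// no vision): limits come from the cloud limits table above.
	entry := map[string]any{
		"id":        "kimi-k2-thinking",
		"name":      "kimi-k2-thinking",
		"input":     []any{"text"},
		"reasoning": true,
		"cost": map[string]any{
			"input": 0, "output": 0, "cacheRead": 0, "cacheWrite": 0,
		},
		"contextWindow": 262_144,
		"maxTokens":     262_144,
	}
	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", "  ")
	if err := enc.Encode(entry); err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}
```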
|
||||
|
||||
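// Illustration (hypothetical values, not from a real /api/show response): a
// local model with vision and thinking capabilities and a 131,072-token
// context would yield an entry like
//
//	map[string]any{"id": "m", "name": "m", "input": []any{"text", "image"},
//		"reasoning": true, "contextWindow": 131072, "cost": map[string]any{...}}
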
func (c *Openclaw) Models() []string {

@@ -1,11 +1,21 @@
package config

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestOpenclawIntegration(t *testing.T) {
@@ -26,6 +36,124 @@ func TestOpenclawIntegration(t *testing.T) {
	})
}

func TestOpenclawRunPassthroughArgs(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("uses a POSIX shell test binary")
	}

	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	t.Setenv("PATH", tmpDir)

	if err := integrationOnboarded("openclaw"); err != nil {
		t.Fatal(err)
	}

	configDir := filepath.Join(tmpDir, ".openclaw")
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
		"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
	}`), 0o644); err != nil {
		t.Fatal(err)
	}

	bin := filepath.Join(tmpDir, "openclaw")
	if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
		t.Fatal(err)
	}

	c := &Openclaw{}
	if err := c.Run("llama3.2", []string{"gateway", "--someflag"}); err != nil {
		t.Fatalf("Run() error = %v", err)
	}

	data, err := os.ReadFile(filepath.Join(tmpDir, "invocations.log"))
	if err != nil {
		t.Fatal(err)
	}
	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
	if len(lines) != 1 {
		t.Fatalf("expected exactly 1 invocation, got %d: %v", len(lines), lines)
	}
	if lines[0] != "gateway --someflag" {
		t.Fatalf("invocation = %q, want %q", lines[0], "gateway --someflag")
	}
}

func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("uses a POSIX shell test binary")
	}

	oldHook := DefaultConfirmPrompt
	DefaultConfirmPrompt = func(prompt string) (bool, error) {
		return true, nil
	}
	defer func() { DefaultConfirmPrompt = oldHook }()

	t.Run("success persists onboarding flag", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)

		configDir := filepath.Join(tmpDir, ".openclaw")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		// Mark OpenClaw onboarding complete so Run takes the passthrough path directly.
		if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
			"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
		}`), 0o644); err != nil {
			t.Fatal(err)
		}
		if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
			t.Fatal(err)
		}

		c := &Openclaw{}
		if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
			t.Fatalf("Run() error = %v", err)
		}
		integrationConfig, err := loadIntegration("openclaw")
		if err != nil {
			t.Fatalf("loadIntegration() error = %v", err)
		}
		if !integrationConfig.Onboarded {
			t.Fatal("expected onboarding flag to be persisted after successful run")
		}
	})

	t.Run("failure does not persist onboarding flag", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)

		configDir := filepath.Join(tmpDir, ".openclaw")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
			"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
		}`), 0o644); err != nil {
			t.Fatal(err)
		}
		if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
			t.Fatal(err)
		}

		c := &Openclaw{}
		if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
			t.Fatal("expected run failure")
		}
		integrationConfig, err := loadIntegration("openclaw")
		if err == nil && integrationConfig.Onboarded {
			t.Fatal("expected onboarding flag to remain unset after failed run")
		}
	})
}

func TestOpenclawEdit(t *testing.T) {
	c := &Openclaw{}
	tmpDir := t.TempDir()
@@ -359,19 +487,16 @@ func TestOpenclawEditSchemaFields(t *testing.T) {
	modelList := ollama["models"].([]any)
	entry := modelList[0].(map[string]any)

	// Verify required schema fields
	if entry["reasoning"] != false {
		t.Error("reasoning should be false")
	// Verify base schema fields (always set regardless of API availability)
	if entry["id"] != "llama3.2" {
		t.Errorf("id = %v, want llama3.2", entry["id"])
	}
	if entry["name"] != "llama3.2" {
		t.Errorf("name = %v, want llama3.2", entry["name"])
	}
	if entry["input"] == nil {
		t.Error("input should be set")
	}
	if entry["contextWindow"] == nil {
		t.Error("contextWindow should be set")
	}
	if entry["maxTokens"] == nil {
		t.Error("maxTokens should be set")
	}
	cost := entry["cost"].(map[string]any)
	if cost["cacheRead"] == nil {
		t.Error("cost.cacheRead should be set")
@@ -876,3 +1001,589 @@ func TestOpenclawOnboarded(t *testing.T) {
		}
	})
}

func TestOpenclawGatewayInfo(t *testing.T) {
	c := &Openclaw{}

	t.Run("returns defaults when no config exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		token, port := c.gatewayInfo()
		if token != "" {
			t.Errorf("expected empty token, got %q", token)
		}
		if port != defaultGatewayPort {
			t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
		}
	})

	t.Run("reads token and port from config", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		configDir := filepath.Join(tmpDir, ".openclaw")
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
			"gateway": {
				"port": 9999,
				"auth": {"mode": "token", "token": "my-secret"}
			}
		}`), 0o644)

		token, port := c.gatewayInfo()
		if token != "my-secret" {
			t.Errorf("expected token %q, got %q", "my-secret", token)
		}
		if port != 9999 {
			t.Errorf("expected port 9999, got %d", port)
		}
	})

	t.Run("uses default port when not in config", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		configDir := filepath.Join(tmpDir, ".openclaw")
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
			"gateway": {"auth": {"token": "tok"}}
		}`), 0o644)

		token, port := c.gatewayInfo()
		if token != "tok" {
			t.Errorf("expected token %q, got %q", "tok", token)
		}
		if port != defaultGatewayPort {
			t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
		}
	})

	t.Run("falls back to legacy clawdbot config", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		legacyDir := filepath.Join(tmpDir, ".clawdbot")
		os.MkdirAll(legacyDir, 0o755)
		os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
			"gateway": {"port": 12345, "auth": {"token": "legacy-token"}}
		}`), 0o644)

		token, port := c.gatewayInfo()
		if token != "legacy-token" {
			t.Errorf("expected token %q, got %q", "legacy-token", token)
		}
		if port != 12345 {
			t.Errorf("expected port 12345, got %d", port)
		}
	})

	t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		configDir := filepath.Join(tmpDir, ".openclaw")
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)

		token, port := c.gatewayInfo()
		if token != "" {
			t.Errorf("expected empty token, got %q", token)
		}
		if port != defaultGatewayPort {
			t.Errorf("expected default port, got %d", port)
		}
	})

	t.Run("handles missing gateway section", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		configDir := filepath.Join(tmpDir, ".openclaw")
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)

		token, port := c.gatewayInfo()
		if token != "" {
			t.Errorf("expected empty token, got %q", token)
		}
		if port != defaultGatewayPort {
			t.Errorf("expected default port, got %d", port)
		}
	})
}

func TestPrintOpenclawReady(t *testing.T) {
	t.Run("includes port in URL", func(t *testing.T) {
		var buf bytes.Buffer
		old := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w

		printOpenclawReady("openclaw", "", 9999, false)

		w.Close()
		os.Stderr = old
		buf.ReadFrom(r)

		output := buf.String()
		if !strings.Contains(output, "localhost:9999") {
			t.Errorf("expected port 9999 in output, got:\n%s", output)
		}
		if strings.Contains(output, "#token=") {
			t.Error("should not include token fragment when token is empty")
		}
	})

	t.Run("URL-escapes token", func(t *testing.T) {
		var buf bytes.Buffer
		old := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w

		printOpenclawReady("openclaw", "my token&special=chars", defaultGatewayPort, false)

		w.Close()
		os.Stderr = old
		buf.ReadFrom(r)

		output := buf.String()
		escaped := url.QueryEscape("my token&special=chars")
		if !strings.Contains(output, "#token="+escaped) {
			t.Errorf("expected URL-escaped token %q in output, got:\n%s", escaped, output)
		}
	})

	t.Run("simple token is not mangled", func(t *testing.T) {
		var buf bytes.Buffer
		old := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w

		printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)

		w.Close()
		os.Stderr = old
		buf.ReadFrom(r)

		output := buf.String()
		if !strings.Contains(output, "#token=ollama") {
			t.Errorf("expected #token=ollama in output, got:\n%s", output)
		}
	})

	t.Run("includes web UI hint", func(t *testing.T) {
		var buf bytes.Buffer
		old := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w

		printOpenclawReady("openclaw", "", defaultGatewayPort, false)

		w.Close()
		os.Stderr = old
		buf.ReadFrom(r)

		output := buf.String()
		if !strings.Contains(output, "Open the Web UI") {
			t.Errorf("expected web UI hint in output, got:\n%s", output)
		}
	})

	t.Run("first launch shows quick start tips", func(t *testing.T) {
		var buf bytes.Buffer
		old := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w

		printOpenclawReady("openclaw", "ollama", defaultGatewayPort, true)

		w.Close()
		os.Stderr = old
		buf.ReadFrom(r)

		output := buf.String()
		for _, want := range []string{"/help", "channels", "skills", "gateway"} {
			if !strings.Contains(output, want) {
				t.Errorf("expected %q in first-launch output, got:\n%s", want, output)
			}
		}
	})

	t.Run("subsequent launch shows single tip", func(t *testing.T) {
		var buf bytes.Buffer
		old := os.Stderr
		r, w, _ := os.Pipe()
		os.Stderr = w

		printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)

		w.Close()
		os.Stderr = old
		buf.ReadFrom(r)

		output := buf.String()
		if !strings.Contains(output, "Tip:") {
			t.Errorf("expected single tip line, got:\n%s", output)
		}
		if strings.Contains(output, "Quick start") {
			t.Errorf("should not show quick start on subsequent launch")
		}
	})
}

func TestOpenclawModelConfig(t *testing.T) {
	t.Run("nil client returns base config", func(t *testing.T) {
		cfg, _ := openclawModelConfig(context.Background(), nil, "llama3.2")

		if cfg["id"] != "llama3.2" {
			t.Errorf("id = %v, want llama3.2", cfg["id"])
		}
		if cfg["name"] != "llama3.2" {
			t.Errorf("name = %v, want llama3.2", cfg["name"])
		}
		if cfg["cost"] == nil {
			t.Error("cost should be set")
		}
		// Should not have capability fields without API
		if _, ok := cfg["reasoning"]; ok {
			t.Error("reasoning should not be set without API")
		}
		if _, ok := cfg["contextWindow"]; ok {
			t.Error("contextWindow should not be set without API")
		}
	})

	t.Run("sets vision input when model has vision capability", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{"llama.context_length":4096}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "llava:7b")

		input, ok := cfg["input"].([]any)
		if !ok || len(input) != 2 {
			t.Errorf("input = %v, want [text image]", cfg["input"])
		}
	})

	t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")

		input, ok := cfg["input"].([]any)
		if !ok || len(input) != 1 {
			t.Errorf("input = %v, want [text]", cfg["input"])
		}
		if _, ok := cfg["reasoning"]; ok {
			t.Error("reasoning should not be set for non-thinking model")
		}
	})

	t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "qwq")

		if cfg["reasoning"] != true {
			t.Error("expected reasoning = true for thinking model")
		}
	})

	t.Run("extracts context window from model info", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")

		if cfg["contextWindow"] != 131072 {
			t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
		}
	})

	t.Run("handles all capabilities together", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "qwen3-vision")

		input, ok := cfg["input"].([]any)
		if !ok || len(input) != 2 {
			t.Errorf("input = %v, want [text image]", cfg["input"])
		}
		if cfg["reasoning"] != true {
			t.Error("expected reasoning = true")
		}
		if cfg["contextWindow"] != 32768 {
			t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
		}
	})

	t.Run("returns base config when show fails", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, `{"error":"model not found"}`)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "missing-model")

		if cfg["id"] != "missing-model" {
			t.Errorf("id = %v, want missing-model", cfg["id"])
		}
		// Should still have input (default)
		if cfg["input"] == nil {
			t.Error("input should always be set")
		}
		if _, ok := cfg["reasoning"]; ok {
			t.Error("reasoning should not be set when show fails")
		}
		if _, ok := cfg["contextWindow"]; ok {
			t.Error("contextWindow should not be set when show fails")
		}
	})

	t.Run("times out slow show and returns base config", func(t *testing.T) {
		oldTimeout := openclawModelShowTimeout
		openclawModelShowTimeout = 50 * time.Millisecond
		t.Cleanup(func() { openclawModelShowTimeout = oldTimeout })

		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				time.Sleep(300 * time.Millisecond)
				fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{"llama.context_length":4096}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		start := time.Now()
		cfg, _ := openclawModelConfig(context.Background(), client, "slow-model")
		elapsed := time.Since(start)
		if elapsed >= 250*time.Millisecond {
			t.Fatalf("openclawModelConfig took too long: %v", elapsed)
		}
		if cfg["id"] != "slow-model" {
			t.Errorf("id = %v, want slow-model", cfg["id"])
		}
		if _, ok := cfg["reasoning"]; ok {
			t.Error("reasoning should not be set on timeout")
		}
		if _, ok := cfg["contextWindow"]; ok {
			t.Error("contextWindow should not be set on timeout")
		}
	})

	t.Run("skips zero context length", func(t *testing.T) {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, _ := openclawModelConfig(context.Background(), client, "test-model")

		if _, ok := cfg["contextWindow"]; ok {
			t.Error("contextWindow should not be set for zero value")
		}
	})

	t.Run("cloud model uses hardcoded limits", func(t *testing.T) {
		// Use a model name that's in cloudModelLimits and make the server
		// report it as a remote/cloud model
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")

		if !isCloud {
			t.Error("expected isCloud = true for cloud model")
		}
		if cfg["contextWindow"] != 204_800 {
			t.Errorf("contextWindow = %v, want 204800", cfg["contextWindow"])
		}
		if cfg["maxTokens"] != 128_000 {
			t.Errorf("maxTokens = %v, want 128000", cfg["maxTokens"])
		}
	})

	t.Run("cloud model with vision capability gets image input", func(t *testing.T) {
		// Regression test: cloud models must not skip capability detection.
		// A cloud model that reports vision capability should have input: [text, image].
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{},"remote_model":"qwen3-vl"}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, isCloud := openclawModelConfig(context.Background(), client, "qwen3-vl:235b-cloud")

		if !isCloud {
			t.Error("expected isCloud = true for cloud vision model")
		}
		input, ok := cfg["input"].([]any)
		if !ok || len(input) != 2 {
			t.Errorf("input = %v, want [text image] for cloud vision model", cfg["input"])
		}
	})

	t.Run("cloud model with thinking capability gets reasoning flag", func(t *testing.T) {
		// Regression test: cloud models must not skip capability detection.
		// A cloud model that reports thinking capability should have reasoning: true.
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/show" {
				fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{},"remote_model":"qwq-cloud"}`)
				return
			}
			w.WriteHeader(http.StatusNotFound)
		}))
		defer srv.Close()

		u, _ := url.Parse(srv.URL)
		client := api.NewClient(u, srv.Client())

		cfg, isCloud := openclawModelConfig(context.Background(), client, "qwq:cloud")

		if !isCloud {
			t.Error("expected isCloud = true for cloud thinking model")
		}
		if cfg["reasoning"] != true {
			t.Error("expected reasoning = true for cloud thinking model")
		}
	})
}

func TestIntegrationOnboarded(t *testing.T) {
	t.Run("returns false when not set", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		integrationConfig, err := loadIntegration("openclaw")
		if err == nil && integrationConfig.Onboarded {
			t.Error("expected false for fresh config")
		}
	})

	t.Run("returns true after integrationOnboarded", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)

		if err := integrationOnboarded("openclaw"); err != nil {
			t.Fatal(err)
		}
		integrationConfig, err := loadIntegration("openclaw")
		if err != nil || !integrationConfig.Onboarded {
			t.Error("expected true after integrationOnboarded")
		}
	})

	t.Run("is case insensitive", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)

		if err := integrationOnboarded("OpenClaw"); err != nil {
			t.Fatal(err)
		}
		integrationConfig, err := loadIntegration("openclaw")
		if err != nil || !integrationConfig.Onboarded {
			t.Error("expected true when set with different case")
		}
	})

	t.Run("preserves existing integration data", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)

		if err := SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil {
			t.Fatal(err)
		}
		if err := integrationOnboarded("openclaw"); err != nil {
			t.Fatal(err)
		}

		// Verify onboarded is set
		integrationConfig, err := loadIntegration("openclaw")
		if err != nil || !integrationConfig.Onboarded {
			t.Error("expected true after integrationOnboarded")
		}

		// Verify models are preserved
		model := IntegrationModel("openclaw")
		if model != "llama3.2" {
			t.Errorf("expected first model llama3.2, got %q", model)
		}
	})
}

@@ -3,6 +3,7 @@ package config
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"os"
@@ -24,25 +25,6 @@ type cloudModelLimit struct {
	Output int
}

// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-coder-next":    {Context: 262_144, Output: 32_768},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
}

// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
@@ -70,6 +52,16 @@ func (o *OpenCode) Run(model string, args []string) error {
	if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
		models = config.Models
	}
	var err error
	models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
		return selectModels(context.Background(), "opencode", "")
	})
	if errors.Is(err, errCancelled) {
		return nil
	}
	if err != nil {
		return err
	}
	if err := o.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}

@@ -10,10 +10,11 @@ import (

// ANSI escape sequences for terminal formatting.
const (
	ansiBold = "\033[1m"
	ansiReset = "\033[0m"
	ansiGray = "\033[37m"
	ansiGreen = "\033[32m"
	ansiBold   = "\033[1m"
	ansiReset  = "\033[0m"
	ansiGray   = "\033[37m"
	ansiGreen  = "\033[32m"
	ansiYellow = "\033[33m"
)

// ErrCancelled is returned when the user cancels a selection.

@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
	return s
}

func SelectSingle(title string, items []SelectItem) (string, error) {
// cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int {
	if current != "" {
		for i, item := range items {
			if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
				return i
			}
		}
	}
	return 0
}

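// Illustration (mirroring TestCursorForCurrent further below): with items
// ["llama3.2", "qwen3:8b", "gemma3:latest"], cursorForCurrent(items, "gemma3")
// returns 2, since the bare name matches "gemma3:latest" via the prefix rules.
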
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
	if len(items) == 0 {
		return "", fmt.Errorf("no items to select from")
	}

	m := selectorModel{
		title: title,
		items: items,
		title:  title,
		items:  items,
		cursor: cursorForCurrent(items, current),
	}

	p := tea.NewProgram(m)
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
	cancelled bool
	confirmed bool
	width     int

	// multi enables full multi-select editing mode. The zero value (false)
	// shows a single-select picker where Enter adds the chosen model to
	// the existing list. Tab toggles between modes.
	multi     bool
	singleAdd string // model picked in single mode
}

func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
		m.itemIndex[item.Name] = i
	}

	for _, name := range preChecked {
		if idx, ok := m.itemIndex[name]; ok {
	// Reverse order so preChecked[0] (the current default) ends up last
	// in checkOrder, matching the "last checked = default" convention.
	for i := len(preChecked) - 1; i >= 0; i-- {
		if idx, ok := m.itemIndex[preChecked[i]]; ok {
			m.checked[idx] = true
			m.checkOrder = append(m.checkOrder, idx)
		}
	}

	// Position cursor on the current default model
	if len(preChecked) > 0 {
		if idx, ok := m.itemIndex[preChecked[0]]; ok {
			m.cursor = idx
			m.updateScroll(m.otherStart())
		}
	}

	return m
}

@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			m.cancelled = true
			return m, tea.Quit

		case tea.KeyTab:
			m.multi = !m.multi

		case tea.KeyEnter:
			if len(m.checkOrder) > 0 {
			if !m.multi {
				if len(filtered) > 0 && m.cursor < len(filtered) {
					m.singleAdd = filtered[m.cursor].Name
					m.confirmed = true
					return m, tea.Quit
				}
			} else if len(m.checkOrder) > 0 {
				m.confirmed = true
				return m, tea.Quit
			}

		case tea.KeySpace:
			m.toggleItem()
			if m.multi {
				m.toggleItem()
			}

		case tea.KeyUp:
			if m.cursor > 0 {
@@ -576,15 +616,36 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			}

		case tea.KeyRunes:
			m.filter += string(msg.Runes)
			m.cursor = 0
			m.scrollOffset = 0
			// On some terminals (e.g. Windows PowerShell), space arrives as
			// KeyRunes instead of KeySpace. Intercept it so toggle still works.
			if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
				if m.multi {
					m.toggleItem()
				}
			} else {
				m.filter += string(msg.Runes)
				m.cursor = 0
				m.scrollOffset = 0
			}
		}
	}

	return m, nil
}

func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
	if idx == m.cursor {
		s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
	} else {
		s.WriteString(selectorItemStyle.Render(item.Name))
	}
	s.WriteString("\n")
	if item.Description != "" {
		s.WriteString(selectorDescLineStyle.Render(item.Description))
		s.WriteString("\n")
	}
}

func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
	origIdx := m.itemIndex[item.Name]

@@ -596,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
	}

	suffix := ""
	if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
	if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
		suffix = " " + selectorDefaultTagStyle.Render("(default)")
	}

@@ -618,6 +679,11 @@ func (m multiSelectorModel) View() string {
		return ""
	}

	renderItem := m.renderSingleItem
	if m.multi {
		renderItem = m.renderMultiItem
	}

	var s strings.Builder

	s.WriteString(selectorTitleStyle.Render(m.title))
@@ -642,7 +708,7 @@ func (m multiSelectorModel) View() string {
			if idx >= len(filtered) {
				break
			}
			m.renderMultiItem(&s, filtered[idx], idx)
			renderItem(&s, filtered[idx], idx)
		}

		if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
@@ -665,7 +731,7 @@ func (m multiSelectorModel) View() string {
		s.WriteString(sectionHeaderStyle.Render("Recommended"))
		s.WriteString("\n")
		for _, idx := range recItems {
			m.renderMultiItem(&s, filtered[idx], idx)
			renderItem(&s, filtered[idx], idx)
		}
	}

@@ -685,7 +751,7 @@ func (m multiSelectorModel) View() string {
			if idx >= len(otherItems) {
				break
			}
			m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
			renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
		}

		if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
@@ -697,15 +763,18 @@ func (m multiSelectorModel) View() string {

	s.WriteString("\n")

	count := m.selectedCount()
	if count == 0 {
		s.WriteString(selectorDescStyle.Render(" Select at least one model."))
	if !m.multi {
		s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
	} else {
		s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
		count := m.selectedCount()
		if count == 0 {
			s.WriteString(selectorDescStyle.Render(" Select at least one model."))
		} else {
			s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
		}
		s.WriteString("\n\n")
		s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
	}
	s.WriteString("\n\n")

	s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))

	result := s.String()
	if m.width > 0 {
@@ -728,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
	}

	fm := finalModel.(multiSelectorModel)
	if fm.cancelled {
	if fm.cancelled || !fm.confirmed {
		return nil, ErrCancelled
	}

	if !fm.confirmed {
		return nil, ErrCancelled
	// Single-add mode: prepend the picked model, keep existing models deduped
	if fm.singleAdd != "" {
		result := []string{fm.singleAdd}
		for _, name := range preChecked {
			if name != fm.singleAdd {
				result = append(result, name)
			}
		}
		return result, nil
	}

	var result []string
	// Multi-edit mode: last checked is default (first in result)
	last := fm.checkOrder[len(fm.checkOrder)-1]
	result := []string{fm.items[last].Name}
	for _, idx := range fm.checkOrder {
		result = append(result, fm.items[idx].Name)
		if idx != last {
			result = append(result, fm.items[idx].Name)
		}
	}

	return result, nil
}

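// Worked example (mirroring TestMulti_LastCheckedIsDefault further below):
// checking "alpha" and then "gamma" leaves checkOrder as [alpha, gamma], so
// SelectMultiple returns ["gamma", "alpha"] with "gamma" as the new default.
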
@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
	}
}

// --- cursorForCurrent ---

func TestCursorForCurrent(t *testing.T) {
	testItems := []SelectItem{
		{Name: "llama3.2", Recommended: true},
		{Name: "qwen3:8b", Recommended: true},
		{Name: "gemma3:latest"},
		{Name: "deepseek-r1"},
		{Name: "glm-5:cloud"},
	}

	tests := []struct {
		name    string
		current string
		want    int
	}{
		{"empty current", "", 0},
		{"exact match", "qwen3:8b", 1},
		{"no match returns 0", "nonexistent", 0},
		{"bare name matches with :latest suffix", "gemma3", 2},
		{"full tag matches bare item", "llama3.2:latest", 0},
		{"cloud model exact match", "glm-5:cloud", 4},
		{"cloud model bare name", "glm-5", 4},
		{"recommended item exact match", "llama3.2", 0},
		{"recommended item with tag", "qwen3", 1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := cursorForCurrent(testItems, tt.current); got != tt.want {
				t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
			}
		})
	}
}

// --- ReorderItems ---

func TestReorderItems(t *testing.T) {
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {

func TestMultiView_CheckedItemShowsX(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
	m.multi = true
	content := m.View()

	if !strings.Contains(content, "[x]") {
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
}

func TestMultiView_DefaultTag(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
	m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
	m.multi = true
	content := m.View()

	if !strings.Contains(content, "(default)") {
		t.Error("first checked item should have (default) tag")
		t.Error("should have (default) tag")
	}
	// preChecked[0] ("a") should be the default (last in checkOrder)
	aIdx := strings.Index(content, "a")
	defaultIdx := strings.Index(content, "(default)")
	if defaultIdx < aIdx {
		t.Error("(default) tag should appear after 'a' (the current default)")
	}
}

@@ -545,6 +589,200 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
	}
}

// --- Multi-select space toggle (including KeyRunes fallback for Windows PowerShell) ---

func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
	m.multi = true
	m.cursor = 1

	// Simulate space delivered as tea.KeySpace
	updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
	m = updated.(multiSelectorModel)

	if !m.checked[1] {
		t.Error("space (KeySpace) should toggle the item at cursor")
	}
	if m.filter != "" {
		t.Error("space should not modify filter")
	}
}

func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
	m.multi = true
	m.cursor = 1

	// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
	updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
	m = updated.(multiSelectorModel)

	if !m.checked[1] {
		t.Error("space (KeyRunes) should toggle the item at cursor")
	}
	if m.filter != "" {
		t.Error("space rune should not be added to filter")
	}
	if m.cursor != 1 {
		t.Errorf("cursor should stay at 1, got %d", m.cursor)
	}
}

// --- Single-add mode ---

func TestMulti_StartsInSingleMode(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
	if m.multi {
		t.Error("should start in single mode (multi=false)")
	}
}

func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
	content := m.View()
	if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
		t.Error("single mode should not show checkboxes")
	}
	if !strings.Contains(content, "▸") {
		t.Error("single mode should show cursor indicator")
	}
}

func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
	m.cursor = 1

	updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
	m = updated.(multiSelectorModel)

	if m.singleAdd != "b" {
		t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
	}
	if !m.confirmed {
		t.Error("should set confirmed")
	}
}

func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
	m.cursor = 0

	updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
	m = updated.(multiSelectorModel)

	if len(m.checked) != 0 {
		t.Error("space in single mode should not toggle items")
	}
}

func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
	m.cursor = 0

	updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
	m = updated.(multiSelectorModel)

	if len(m.checked) != 0 {
		t.Error("space rune in single mode should not toggle items")
	}
	if m.filter != "" {
		t.Error("space rune in single mode should not add to filter")
	}
}

func TestMulti_TabTogglesMode(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a", "b"), nil)

	if m.multi {
		t.Fatal("should start in single mode")
	}

	updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
	m = updated.(multiSelectorModel)
	if !m.multi {
		t.Error("tab should switch to multi mode")
	}

	updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
	m = updated.(multiSelectorModel)
	if m.multi {
		t.Error("tab should switch back to single mode")
	}
}

func TestMulti_SingleModeHelpText(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a"), nil)
	content := m.View()
	if !strings.Contains(content, "tab add multiple") {
		t.Error("single mode should show 'tab add multiple' in help")
	}
}

func TestMulti_MultiModeHelpText(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("a"), nil)
	m.multi = true
	content := m.View()
	if !strings.Contains(content, "tab select single") {
		t.Error("multi mode should show 'tab select single' in help")
	}
}

// --- preChecked initialization order ---

func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
	// preChecked[0] ("a") is the current default and should end up
	// last in checkOrder so it gets the (default) tag.
	m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})

	if len(m.checkOrder) != 3 {
		t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
	}
	lastIdx := m.checkOrder[len(m.checkOrder)-1]
	if m.items[lastIdx].Name != "a" {
		t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
	}
}

func TestMulti_CursorOnDefaultModel(t *testing.T) {
	// preChecked[0] ("b") is the default; cursor should start on it
	m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})

	if m.cursor != 1 {
		t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
	}
}

// --- Multi-mode last-checked is default ---

func TestMulti_LastCheckedIsDefault(t *testing.T) {
	m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
	m.multi = true

	// Check "alpha" then "gamma"
	m.cursor = 0
	m.toggleItem()
	m.cursor = 2
	m.toggleItem()

	// Last checked ("gamma") should be at the end of checkOrder
	lastIdx := m.checkOrder[len(m.checkOrder)-1]
	if m.items[lastIdx].Name != "gamma" {
		t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
	}

	// The (default) tag renders based on checkOrder[len-1]
	content := m.View()
	if !strings.Contains(content, "(default)") {
		t.Fatal("should show (default) tag")
	}
	// "alpha" line should NOT have the default tag
	for _, line := range strings.Split(content, "\n") {
		if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
			t.Error("'alpha' (first checked) should not have (default) tag")
		}
	}
}

// Key message helpers for testing

type keyType = int

@@ -131,7 +131,7 @@ type model struct {
	signInURL       string
	signInModel     string
	signInSpinner   int
	signInFromModal bool // true if sign-in was triggered from modal (not main menu)

	width     int    // terminal width from WindowSizeMsg
	statusMsg string // temporary status message shown near help text
@@ -209,7 +209,26 @@ func (m *model) openMultiModelModal(integration string) {
}

func isCloudModel(name string) bool {
	return strings.HasSuffix(name, ":cloud")
	return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}

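// Illustration: "minimax-m2.5:cloud" matches the ":cloud" tag form, while a
// name like "qwen3-vl:235b-cloud" matches the new "-cloud" suffix form.
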
func cloudStatusDisabled(client *api.Client) bool {
	status, err := client.CloudStatusExperimental(context.Background())
	if err != nil {
		return false
	}
	return status.Cloud.Disabled
}

func cloudModelDisabled(name string) bool {
	if !isCloudModel(name) {
		return false
	}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return false
	}
	return cloudStatusDisabled(client)
}

// checkCloudSignIn checks if a cloud model needs sign-in.
@@ -222,6 +241,9 @@ func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
	if err != nil {
		return nil
	}
	if cloudStatusDisabled(client) {
		return nil
	}
	user, err := client.Whoami(context.Background())
	if err == nil && user != nil && user.Name != "" {
		return nil
@@ -272,7 +294,11 @@ func (m *model) loadAvailableModels() {
	if err != nil {
		return
	}
	cloudDisabled := cloudStatusDisabled(client)
	for _, mdl := range models.Models {
		if cloudDisabled && mdl.RemoteModel != "" {
			continue
		}
		m.availableModels[mdl.Name] = true
	}
}
@@ -403,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			}
			if m.multiModalSelector.confirmed {
				var selected []string
				for _, idx := range m.multiModalSelector.checkOrder {
					selected = append(selected, m.multiModalSelector.items[idx].Name)
				if m.multiModalSelector.singleAdd != "" {
					// Single-add mode: prepend picked model, keep existing deduped
					selected = []string{m.multiModalSelector.singleAdd}
					for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
						if name != m.multiModalSelector.singleAdd {
							selected = append(selected, name)
						}
					}
				} else {
					// Last checked is default (first in result)
					co := m.multiModalSelector.checkOrder
					last := co[len(co)-1]
					selected = []string{m.multiModalSelector.items[last].Name}
					for _, idx := range co {
						if idx != last {
							selected = append(selected, m.multiModalSelector.items[idx].Name)
						}
					}
				}
				if len(selected) > 0 {
					m.changeModels = selected
@@ -482,7 +524,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
		case "enter", " ":
			item := m.items[m.cursor]

			if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
			if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
				return m, nil
			}

@@ -496,6 +538,15 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
				return m, cmd
			}

			if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
				if item.integration != "" && config.IsEditorIntegration(item.integration) {
					m.openMultiModelModal(item.integration)
				} else {
					m.openModelModal(configuredModel)
				}
				return m, nil
			}

			m.selected = true
			m.quitting = true
			return m, tea.Quit
@@ -504,6 +555,12 @@
			item := m.items[m.cursor]
			if item.integration != "" || item.isRunModel {
				if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
					if config.AutoInstallable(item.integration) {
						// Auto-installable: select to trigger install flow
						m.selected = true
						m.quitting = true
						return m, tea.Quit
					}
					return m, nil
				}
				if item.integration != "" && config.IsEditorIntegration(item.integration) {
@@ -567,7 +624,11 @@ func (m model) View() string {
		var modelSuffix string
		if item.integration != "" {
			if !isInstalled {
				title += " " + notInstalledStyle.Render("(not installed)")
				if config.AutoInstallable(item.integration) {
					title += " " + notInstalledStyle.Render("(install)")
				} else {
					title += " " + notInstalledStyle.Render("(not installed)")
				}
			} else if m.cursor == i {
				if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
					modelSuffix = " " + modelStyle.Render("("+mdl+")")
@@ -583,7 +644,9 @@ func (m model) View() string {

		desc := item.description
		if !isInstalled && item.integration != "" && m.cursor == i {
			if hint := config.IntegrationInstallHint(item.integration); hint != "" {
			if config.AutoInstallable(item.integration) {
				desc = "Press enter to install"
			} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
				desc = hint
			} else {
				desc = "not installed"

@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
	if err != nil {
		return nil, nil, err
	}
	bts = sanitizeNonFiniteJSON(bts)

	var p ModelParameters
	if err := json.Unmarshal(bts, &p); err != nil {
		return nil, nil, err
		return nil, nil, fmt.Errorf("parse config.json: %w", err)
	}

	if len(p.Architectures) < 1 {
@@ -315,16 +316,20 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
		conv = &glm4MoeLiteModel{}
	case "GlmOcrForConditionalGeneration":
		conv = &glmOcrModel{}
	case "Lfm2ForCausalLM":
	case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
		conv = &lfm2Model{}
	case "Lfm2VlForConditionalGeneration":
		conv = &lfm2VLTextModel{}
	case "Qwen3NextForCausalLM":
		conv = &qwen3NextModel{}
	case "NemotronHForCausalLM":
		conv = &nemotronHModel{}
	default:
		return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
	}

	if err := json.Unmarshal(bts, conv); err != nil {
		return nil, nil, err
		return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
	}

	if t, ok := conv.(moreParser); ok {

@@ -1,6 +1,8 @@
package convert

import (
	"cmp"
	"fmt"
	"slices"
	"strings"

@@ -13,42 +15,149 @@ type lfm2Model struct {
	NumHiddenLayers       uint32   `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32   `json:"max_position_embeddings"`
	IntermediateSize      uint32   `json:"intermediate_size"`
	BlockFFDim            uint32   `json:"block_ff_dim"`
	BlockMultipleOf       uint32   `json:"block_multiple_of"`
	BlockAutoAdjustFFDim  bool     `json:"block_auto_adjust_ff_dim"`
	BlockFFNDimMultiplier float32  `json:"block_ffn_dim_multiplier"`
	NumAttentionHeads     uint32   `json:"num_attention_heads"`
	NumKeyValueHeads      uint32   `json:"num_key_value_heads"`
	RopeTheta             float32  `json:"rope_theta"`
	NormEps               float32  `json:"norm_eps"`
	ConvLCache            uint32   `json:"conv_L_cache"`
	MoEIntermediateSize   uint32   `json:"moe_intermediate_size"`
	NumExperts            uint32   `json:"num_experts"`
	NumLocalExperts       uint32   `json:"num_local_experts"`
	NumExpertsPerToken    uint32   `json:"num_experts_per_tok"`
	NumDenseLayers        uint32   `json:"num_dense_layers"`
	RoutedScalingFactor   float32  `json:"routed_scaling_factor"`
	LayerTypes            []string `json:"layer_types"`
	TieEmbedding          bool     `json:"tie_embedding"`
	RopeParameters        struct {
		RopeTheta float32 `json:"rope_theta"`
	} `json:"rope_parameters"`
}

var _ ModelConverter = (*lfm2Model)(nil)

const (
	defaultMaxPositionEmbeddings = uint32(128_000)
	fallbackContextLength        = uint32(32_768)
)

func (p *lfm2Model) isMoE() bool {
	return p.ModelType == "lfm2_moe" || p.expertCount() > 0
}

func (p *lfm2Model) ropeFreqBase() float32 {
	if p.RopeTheta != 0 {
		return p.RopeTheta
	}

	return p.RopeParameters.RopeTheta
}

func (p *lfm2Model) expertCount() uint32 {
	if p.NumLocalExperts > 0 {
		return p.NumLocalExperts
	}
	return p.NumExperts
}

func (p *lfm2Model) feedForwardLength() uint32 {
	ff := p.IntermediateSize
	if p.BlockFFDim != 0 {
		ff = p.BlockFFDim
	}

	if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 {
		return ff
	}

	ff = (2 * ff) / 3

	// Keep default multiplier behavior consistent with llama.cpp conversion.
	if p.BlockFFNDimMultiplier != 0 {
		ff = uint32(float32(ff) * p.BlockFFNDimMultiplier)
	}

	m := p.BlockMultipleOf
	return m * ((ff + m - 1) / m)
}

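// Worked example (illustrative numbers, not from a real config): with
// block_ff_dim=8192, block_auto_adjust_ff_dim=true, block_multiple_of=256,
// and no multiplier set, ff becomes 2*8192/3 = 5461, which rounds up to the
// next multiple of 256: 22*256 = 5632.
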
func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool {
	return p.isMoE() &&
		p.VocabSize == 65536 &&
		p.HiddenSize == 2048 &&
		p.NumHiddenLayers == 40 &&
		p.IntermediateSize == 11776 &&
		p.NumAttentionHeads == 32 &&
		p.NumKeyValueHeads == 8 &&
		p.NumDenseLayers == 2 &&
		p.expertCount() == 64 &&
		p.NumExpertsPerToken == 4 &&
		p.MoEIntermediateSize == 1536
}

func (p *lfm2Model) contextLength() uint32 {
	if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() {
		return fallbackContextLength
	}

	return p.MaxPositionEmbeddings
}

func (p *lfm2Model) KV(t *Tokenizer) KV {
|
||||
architecture := "lfm2"
|
||||
if p.isMoE() {
|
||||
architecture = "lfm2moe"
|
||||
}
|
||||
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "lfm2"
|
||||
kv["lfm2.vocab_size"] = p.VocabSize
|
||||
kv["lfm2.block_count"] = p.NumHiddenLayers
|
||||
kv["lfm2.embedding_length"] = p.HiddenSize
|
||||
kv["lfm2.feed_forward_length"] = p.IntermediateSize
|
||||
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["general.architecture"] = architecture
|
||||
kv["tokenizer.ggml.pre"] = "lfm2"
|
||||
kv["vocab_size"] = p.VocabSize
|
||||
kv["block_count"] = p.NumHiddenLayers
|
||||
kv["embedding_length"] = p.HiddenSize
|
||||
kv["feed_forward_length"] = p.feedForwardLength()
|
||||
kv["context_length"] = p.contextLength()
|
||||
|
||||
// Build per-layer KV head count array based on layer_types
|
||||
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
|
||||
// (0 = shortconv layer, non-zero = attention layer with that many KV heads).
|
||||
//
|
||||
// Dense LFM2 in HF defaults to all attention layers when layer_types is absent.
|
||||
// Preserve that behavior to avoid accidentally emitting all-conv metadata.
|
||||
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
|
||||
for i := range p.NumHiddenLayers {
|
||||
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
|
||||
if len(p.LayerTypes) == 0 {
|
||||
for i := range p.NumHiddenLayers {
|
||||
kvHeadCounts[i] = p.NumKeyValueHeads
|
||||
}
|
||||
} else {
|
||||
for i := range p.NumHiddenLayers {
|
||||
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
|
||||
kvHeadCounts[i] = p.NumKeyValueHeads
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
|
||||
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
|
||||
kv["lfm2.rope.freq_base"] = p.RopeTheta
|
||||
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
|
||||
kv["attention.head_count"] = p.NumAttentionHeads
|
||||
kv["attention.head_count_kv"] = kvHeadCounts
|
||||
kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||
kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||
kv["attention.layer_norm_rms_epsilon"] = p.NormEps
|
||||
kv["shortconv.l_cache"] = p.ConvLCache
|
||||
|
||||
if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 {
|
||||
kv["rope.freq_base"] = ropeFreqBase
|
||||
}
|
||||
|
||||
if p.isMoE() {
|
||||
kv["expert_count"] = p.expertCount()
|
||||
kv["expert_used_count"] = p.NumExpertsPerToken
|
||||
kv["expert_feed_forward_length"] = p.MoEIntermediateSize
|
||||
kv["leading_dense_block_count"] = p.NumDenseLayers
|
||||
kv["expert_gating_func"] = uint32(2) // sigmoid
|
||||
kv["expert_weights_scale"] = cmp.Or(p.RoutedScalingFactor, float32(1.0))
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
@@ -56,6 +165,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
|
||||
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
if p.isMoE() {
|
||||
merges := make([]merge, 0, p.NumHiddenLayers*3)
|
||||
for i := range p.NumHiddenLayers {
|
||||
if i < p.NumDenseLayers {
|
||||
continue
|
||||
}
|
||||
|
||||
merges = append(merges, merge{
|
||||
fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
})
|
||||
}
|
||||
|
||||
merged, remaining := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
ts = remaining
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
shape := t.Shape()
|
||||
|
||||
@@ -80,7 +213,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
func (p *lfm2Model) Replacements() []string {
|
||||
return []string{
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.embedding_norm", "output_norm",
|
||||
"model.embedding_norm", "token_embd_norm",
|
||||
"model.layers", "blk",
|
||||
"operator_norm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
@@ -92,6 +225,8 @@ func (p *lfm2Model) Replacements() []string {
|
||||
"conv.conv", "shortconv.conv",
|
||||
"conv.in_proj", "shortconv.in_proj",
|
||||
"conv.out_proj", "shortconv.out_proj",
|
||||
"feed_forward.gate", "ffn_gate_inp",
|
||||
"feed_forward.expert_bias", "exp_probs_b.bias",
|
||||
"feed_forward.w1", "ffn_gate",
|
||||
"feed_forward.w2", "ffn_down",
|
||||
"feed_forward.w3", "ffn_up",
|
||||
|
||||
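The load-bearing detail in this converter is the per-layer `attention.head_count_kv` array: a zero entry marks a shortconv layer, a non-zero entry an attention layer with that many KV heads, and an empty `layer_types` falls back to all-attention. A minimal standalone sketch of that mapping (the `kvHeadCounts` helper here is hypothetical, written only to trace the converter's loop; the diff above is the authority):

```go
package main

import "fmt"

// kvHeadCounts mirrors the lfm2 converter's per-layer mapping: zero for
// shortconv layers, numKVHeads for "full_attention" layers. An empty
// layerTypes list is treated as all-attention, matching the fallback above.
func kvHeadCounts(layerTypes []string, numLayers int, numKVHeads uint32) []uint32 {
	counts := make([]uint32, numLayers)
	for i := range counts {
		if len(layerTypes) == 0 || (i < len(layerTypes) && layerTypes[i] == "full_attention") {
			counts[i] = numKVHeads
		}
	}
	return counts
}

func main() {
	// Same layout as TestLFM2MoEKV below: expect [0 8 0 8].
	fmt.Println(kvHeadCounts([]string{"conv", "full_attention", "conv", "full_attention"}, 4, 8))
}
```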
convert/convert_lfm2_test.go (new file, 271 additions)

@@ -0,0 +1,271 @@
package convert

import (
	"io"
	"slices"
	"strings"
	"testing"
)

type lfm2StubTensor struct {
	tensorBase
}

func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor {
	return &lfm2StubTensor{
		tensorBase: tensorBase{
			name:  name,
			shape: shape,
		},
	}
}

func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) {
	return 0, nil
}

func (t *lfm2StubTensor) Clone() Tensor {
	return &lfm2StubTensor{
		tensorBase: tensorBase{
			name:  t.name,
			shape: slices.Clone(t.shape),
		},
	}
}

func TestLFM2MoEKV(t *testing.T) {
	var p lfm2Model
	p.ModelParameters.ModelType = "lfm2_moe"
	p.VocabSize = 65536
	p.HiddenSize = 2048
	p.NumHiddenLayers = 4
	p.MaxPositionEmbeddings = 128000
	p.IntermediateSize = 11776
	p.NumAttentionHeads = 32
	p.NumKeyValueHeads = 8
	p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"}
	p.NormEps = 1e-5
	p.ConvLCache = 3
	p.MoEIntermediateSize = 1536
	p.NumExperts = 64
	p.NumExpertsPerToken = 4
	p.NumDenseLayers = 2
	p.RopeParameters.RopeTheta = 1_000_000

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["general.architecture"], "lfm2moe"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
		t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
	}

	if got, want := kv["expert_count"], uint32(64); got != want {
		t.Fatalf("expert_count = %v, want %v", got, want)
	}

	if got, want := kv["expert_used_count"], uint32(4); got != want {
		t.Fatalf("expert_used_count = %v, want %v", got, want)
	}

	if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want {
		t.Fatalf("expert_feed_forward_length = %v, want %v", got, want)
	}

	if got, want := kv["leading_dense_block_count"], uint32(2); got != want {
		t.Fatalf("leading_dense_block_count = %v, want %v", got, want)
	}

	if got, want := kv["expert_gating_func"], uint32(2); got != want {
		t.Fatalf("expert_gating_func = %v, want %v", got, want)
	}

	gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"])
	}

	wantHeadCounts := []uint32{0, 8, 0, 8}
	if !slices.Equal(gotHeadCounts, wantHeadCounts) {
		t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts)
	}

	if got, want := kv["rope.freq_base"], float32(1_000_000); got != want {
		t.Fatalf("rope.freq_base = %v, want %v", got, want)
	}
}

func TestLFM2DenseKV(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 32000},
		HiddenSize:            1024,
		NumHiddenLayers:       2,
		MaxPositionEmbeddings: 32768,
		IntermediateSize:      4096,
		NumAttentionHeads:     16,
		NumKeyValueHeads:      4,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
		RopeTheta:             10000,
	}

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
		t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
	}

	if _, ok := kv["expert_count"]; ok {
		t.Fatalf("expert_count should not be set for dense lfm2")
	}
}

func TestLFM2MoETensors(t *testing.T) {
	p := lfm2Model{
		ModelParameters: ModelParameters{ModelType: "lfm2_moe"},
		NumHiddenLayers: 4,
		NumDenseLayers:  2,
	}

	in := []Tensor{
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}),
		newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}),
		newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}),
	}

	out := p.Tensors(in)

	byName := make(map[string][]uint64, len(out))
	for _, tns := range out {
		byName[tns.Name] = tns.Shape
	}

	if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
		t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got)
	}

	if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 2048, 1536}) {
		t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got)
	}

	if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok {
		t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight")
	} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
		t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got)
	}

	if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok {
		t.Fatalf("missing shortconv tensor")
	} else if !slices.Equal(got, []uint64{2048, 3}) {
		t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got)
	}

	if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok {
		t.Fatalf("unmerged expert tensor should not be present")
	}
}

func TestLFM2MoEReplacements(t *testing.T) {
	p := lfm2Model{}
	replacer := strings.NewReplacer(p.Replacements()...)

	if got, want := replacer.Replace("model.layers.2.feed_forward.expert_bias"), "blk.2.exp_probs_b.bias"; got != want {
		t.Fatalf("expert bias replacement = %q, want %q", got, want)
	}

	if got, want := replacer.Replace("model.layers.2.feed_forward.gate.weight"), "blk.2.ffn_gate_inp.weight"; got != want {
		t.Fatalf("gate replacement = %q, want %q", got, want)
	}
}

func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       40,
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      11776,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            make([]string, 40),
		NormEps:               1e-5,
		ConvLCache:            3,
		MoEIntermediateSize:   1536,
		NumExperts:            64,
		NumExpertsPerToken:    4,
		NumDenseLayers:        2,
	}
	for i := 0; i < len(p.LayerTypes); i++ {
		p.LayerTypes[i] = "conv"
	}
	p.LayerTypes[2] = "full_attention"

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["context_length"], uint32(32768); got != want {
		t.Fatalf("context_length = %v, want %v", got, want)
	}
}

func TestLFM2KVContextLengthNoOverride(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       39, // mismatch: should not trigger edge case
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      11776,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
		MoEIntermediateSize:   1536,
		NumExperts:            64,
		NumExpertsPerToken:    4,
		NumDenseLayers:        2,
	}

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["context_length"], uint32(128000); got != want {
		t.Fatalf("context_length = %v, want %v", got, want)
	}
}

func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) {
	p := lfm2Model{
		ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
		HiddenSize:            2048,
		NumHiddenLayers:       16,
		MaxPositionEmbeddings: 128000,
		IntermediateSize:      12288, // should be ignored when block_ff_dim is set
		BlockFFDim:            12288,
		BlockAutoAdjustFFDim:  true,
		BlockMultipleOf:       256,
		BlockFFNDimMultiplier: 1.0,
		NumAttentionHeads:     32,
		NumKeyValueHeads:      8,
		LayerTypes:            []string{"conv", "full_attention"},
		NormEps:               1e-5,
		ConvLCache:            3,
	}

	kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})

	if got, want := kv["feed_forward_length"], uint32(8192); got != want {
		t.Fatalf("feed_forward_length = %v, want %v", got, want)
	}
}
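The last test above pins down the `block_auto_adjust_ff_dim` arithmetic: start from `block_ff_dim` = 12288, take two thirds, apply the multiplier, then round up to a multiple of `block_multiple_of`. A standalone trace of the same steps (a sketch only; `feedForwardLength` in the diff above is the real implementation):

```go
package main

import "fmt"

func main() {
	// Trace of feedForwardLength for the test's inputs
	// (block_ff_dim=12288, block_multiple_of=256, multiplier=1.0).
	ff := uint32(12288)            // block_ff_dim overrides intermediate_size
	ff = (2 * ff) / 3              // 8192
	ff = uint32(float32(ff) * 1.0) // a multiplier of 1.0 changes nothing
	m := uint32(256)
	ff = m * ((ff + m - 1) / m) // round up to a multiple of 256: still 8192
	fmt.Println(ff)             // 8192, the value the test asserts
}
```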
convert/convert_lfm2_vl.go (new file, 417 additions)

@@ -0,0 +1,417 @@
package convert

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints.
type lfm2VLTextModel struct {
	TextConfig            lfm2Model `json:"text_config"`
	DoImageSplitting      *bool     `json:"do_image_splitting"`
	DownsampleFactor      uint32    `json:"downsample_factor"`
	EncoderPatchSize      uint32    `json:"encoder_patch_size"`
	ImageTokenID          uint32    `json:"image_token_id"`
	MaxImageTokens        uint32    `json:"max_image_tokens"`
	MinImageTokens        uint32    `json:"min_image_tokens"`
	MaxTiles              uint32    `json:"max_tiles"`
	MinTiles              uint32    `json:"min_tiles"`
	TileSize              uint32    `json:"tile_size"`
	MaxPixelsTolerance    float32   `json:"max_pixels_tolerance"`
	ProjectorUseLayernorm bool      `json:"projector_use_layernorm"`
	ProjectorHiddenSize   uint32    `json:"projector_hidden_size"`
	ProjectorHiddenAct    string    `json:"projector_hidden_act"`
	UseImageSpecialTokens *bool     `json:"use_image_special_tokens"`
	UseThumbnail          *bool     `json:"use_thumbnail"`
	VisionConfig          struct {
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumChannels       uint32  `json:"num_channels"`
		PatchSize         uint32  `json:"patch_size"`
		LayerNormEpsilon  float32 `json:"layer_norm_eps"`
	} `json:"vision_config"`
	Processor struct {
		ImageProcessor struct {
			DoImageSplitting *bool     `json:"do_image_splitting"`
			DownsampleFactor uint32    `json:"downsample_factor"`
			MaxImageTokens   uint32    `json:"max_image_tokens"`
			MinImageTokens   uint32    `json:"min_image_tokens"`
			MaxTiles         uint32    `json:"max_tiles"`
			MinTiles         uint32    `json:"min_tiles"`
			MaxPixelsTol     float32   `json:"max_pixels_tolerance"`
			TileSize         uint32    `json:"tile_size"`
			UseThumbnail     *bool     `json:"use_thumbnail"`
			ImageMean        []float32 `json:"image_mean"`
			ImageStd         []float32 `json:"image_std"`
			Size             struct {
				Height uint32 `json:"height"`
				Width  uint32 `json:"width"`
			} `json:"size"`
		} `json:"image_processor"`
	}
}

func (p *lfm2VLTextModel) textModel() *lfm2Model {
	return &p.TextConfig
}

func (p *lfm2VLTextModel) specialTokenTypes() []string {
	return p.textModel().specialTokenTypes()
}

func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error {
	bts, err := fs.ReadFile(fsys, "processor_config.json")
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return nil
		}
		return err
	}

	return json.Unmarshal(bts, &p.Processor)
}

func (p *lfm2VLTextModel) visionImageSize() uint32 {
	// LFM2-VL image processor operates on 512 tiles and downsamples by factor 2
	// before projection. Keep a fixed square image size compatible with position
	// embeddings and the simplified runtime image pipeline.
	tile := cmp.Or(
		p.Processor.ImageProcessor.TileSize,
		p.Processor.ImageProcessor.Size.Height,
		p.Processor.ImageProcessor.Size.Width,
		uint32(512),
	)
	downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	if downsample == 0 {
		return tile
	}

	return max(uint32(1), tile/downsample)
}

func (p *lfm2VLTextModel) KV(t *Tokenizer) KV {
	kv := p.textModel().KV(t)

	boolOr := func(defaultValue bool, values ...*bool) bool {
		for _, v := range values {
			if v != nil {
				return *v
			}
		}
		return defaultValue
	}

	kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27))
	kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152))
	kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304))
	kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16))
	kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6))
	kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))
	kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3))
	kv["vision.image_size"] = p.visionImageSize()
	kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm
	kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting)
	kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2))
	kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10))
	kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512))
	kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64))
	kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256))
	kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0))
	kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail)
	kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens)
	kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
	kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
	kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396))

	setVisionTokenID := func(k, token string) {
		if t == nil || t.Vocabulary == nil {
			return
		}
		for i, v := range t.Vocabulary.Tokens {
			if v == token {
				kv[k] = uint32(i)
				return
			}
		}
	}
	setVisionTokenID("vision.image_start_token_id", "<|image_start|>")
	setVisionTokenID("vision.image_end_token_id", "<|image_end|>")
	setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>")

	return kv
}

func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor {
	patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)))
	numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3)))

	for _, t := range ts {
		if t.Name() == "v.patch_embd.weight" {
			shape := t.Shape()
			if len(shape) == 2 {
				inputDim := uint64(numChannels * patchSize * patchSize)
				if shape[1] == inputDim {
					channels := numChannels
					patch := patchSize
					t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
						return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
					})
				}
			}
		}
	}

	out := p.textModel().Tensors(ts)
	for _, t := range out {
		if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 {
			t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
		}
	}
	return out
}

func (p *lfm2VLTextModel) Replacements() []string {
	out := make([]string, 0, 96)

	addText := func(from, to string) {
		out = append(out, from, to)
		if strings.HasPrefix(from, "model.") {
			suffix := strings.TrimPrefix(from, "model.")
			out = append(out,
				"model.language_model."+suffix, to,
				"model.language_model.model."+suffix, to,
			)
		}
	}

	base := p.textModel().Replacements()
	for i := 0; i+1 < len(base); i += 2 {
		addText(base[i], base[i+1])
	}

	// Vision tower + multimodal projector tensors (single-file conversion).
	out = append(out,
		"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
		"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
		"model.vision_tower.vision_model.encoder.layers", "v.blk",
		"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
		"model.multi_modal_projector.layer_norm", "mm.layer_norm",
		"model.multi_modal_projector.linear_1", "mm.1",
		"model.multi_modal_projector.linear_2", "mm.2",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_out",
		"layer_norm1", "ln1",
		"layer_norm2", "ln2",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
	)

	return out
}

// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints.
type lfm2VLProjectorModel struct {
	ModelParameters
	DownsampleFactor   uint32 `json:"downsample_factor"`
	ProjectorHiddenDim uint32 `json:"projector_hidden_size"`
	VisionModel        struct {
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumChannels       uint32  `json:"num_channels"`
		PatchSize         uint32  `json:"patch_size"`
		LayerNormEpsilon  float32 `json:"layer_norm_eps"`
		ImageSize         uint32  `json:"image_size"`
	} `json:"vision_config"`
	Processor struct {
		ImageProcessor struct {
			DownsampleFactor uint32    `json:"downsample_factor"`
			TileSize         uint32    `json:"tile_size"`
			ImageMean        []float32 `json:"image_mean"`
			ImageStd         []float32 `json:"image_std"`
			Size             struct {
				Height uint32 `json:"height"`
				Width  uint32 `json:"width"`
			} `json:"size"`
		} `json:"image_processor"`
	}
}

var (
	_ ModelConverter = (*lfm2VLTextModel)(nil)
	_ ModelConverter = (*lfm2VLProjectorModel)(nil)
	_ moreParser     = (*lfm2VLTextModel)(nil)
	_ moreParser     = (*lfm2VLProjectorModel)(nil)
)

func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error {
	bts, err := fs.ReadFile(fsys, "processor_config.json")
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return nil
		}
		return err
	}

	return json.Unmarshal(bts, &p.Processor)
}

func (p *lfm2VLProjectorModel) imageSize() uint32 {
	if p.VisionModel.ImageSize > 0 {
		return p.VisionModel.ImageSize
	}

	downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	baseSize := cmp.Or(
		p.Processor.ImageProcessor.TileSize,
		p.Processor.ImageProcessor.Size.Height,
		p.Processor.ImageProcessor.Size.Width,
		uint32(256),
	)
	if downsample == 0 {
		return baseSize
	}

	return max(uint32(1), baseSize/downsample)
}

func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV {
	kv := KV{
		"general.architecture":         "clip",
		"general.type":                 "mmproj",
		"general.file_type":            uint32(1),
		"general.quantization_version": uint32(2),
		"clip.has_vision_encoder":      true,
		"clip.projector_type":          "lfm2",
		"clip.use_gelu":                true,
	}

	kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27))
	kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152))
	kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304))
	kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16))
	kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6))
	kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16))
	kv["clip.vision.image_size"] = p.imageSize()
	kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048))
	kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
	kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
	kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))

	return kv
}

func defaultFloat32Slice(v, fallback []float32) []float32 {
	if len(v) > 0 {
		return v
	}

	return fallback
}

func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3))
	patchSize := cmp.Or(p.VisionModel.PatchSize, uint32(16))

	for _, t := range ts {
		name := t.Name()
		if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
			continue
		}

		shape := t.Shape()
		if name == "v.patch_embd.weight" && len(shape) == 2 {
			inputDim := uint64(numChannels * patchSize * patchSize)
			if shape[1] == inputDim {
				shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
				channels := int(numChannels)
				patch := int(patchSize)
				t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
					return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
				})
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    slices.Clone(shape),
			WriterTo: t,
		})
	}

	return out
}

func (p *lfm2VLProjectorModel) Replacements() []string {
	return []string{
		"model.multi_modal_projector.linear_1", "mm.1",
		"model.multi_modal_projector.linear_2", "mm.2",
		"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
		"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
		"model.vision_tower.vision_model.encoder.layers", "v.blk",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_out",
		"layer_norm1", "ln1",
		"layer_norm2", "ln2",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
		"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
	}
}

func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) {
	if len(srcShape) != 2 {
		return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape))
	}

	outDim := int(srcShape[0])
	flatInputDim := int(srcShape[1])
	expectedInputDim := channels * patch * patch
	if flatInputDim != expectedInputDim {
		return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", flatInputDim, expectedInputDim)
	}

	expectedSize := outDim * flatInputDim
	if len(data) != expectedSize {
		return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), expectedSize)
	}

	repacked := make([]float32, len(data))
	perChannel := patch * patch

	for o := range outDim {
		inBase := o * flatInputDim
		outBase := o * flatInputDim

		for y := range patch {
			for x := range patch {
				inPixelBase := inBase + (y*patch+x)*channels
				for c := range channels {
					src := inPixelBase + c
					dst := outBase + c*perChannel + y*patch + x
					repacked[dst] = data[src]
				}
			}
		}
	}

	return repacked, nil
}
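`repackPatchEmbeddingWeight` reorders each flattened patch-embedding row from pixel-major HWC layout (all channels of pixel 0, then pixel 1, and so on) to channel-planar CHW layout, so the 2-D weight can be declared as `[out, channels, patch, patch]`. A self-contained restatement of the index math for one output row with 2 channels and a 2x2 patch (the test file below exercises exactly this case):

```go
package main

import "fmt"

func main() {
	const channels, patch = 2, 2
	src := []float32{0, 1, 2, 3, 4, 5, 6, 7} // HWC: channel varies fastest
	dst := make([]float32, len(src))
	for y := 0; y < patch; y++ {
		for x := 0; x < patch; x++ {
			for c := 0; c < channels; c++ {
				// src = (y*patch+x)*channels + c, dst = c*patch*patch + y*patch + x
				dst[c*patch*patch+y*patch+x] = src[(y*patch+x)*channels+c]
			}
		}
	}
	fmt.Println(dst) // [0 2 4 6 1 3 5 7]: CHW, channel varies slowest
}
```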
convert/convert_lfm2_vl_test.go (new file, 249 additions)

@@ -0,0 +1,249 @@
package convert

import (
	"slices"
	"strings"
	"testing"
)

func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) {
	p := lfm2VLTextModel{
		TextConfig: lfm2Model{
			ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
			HiddenSize:            2048,
			NumHiddenLayers:       16,
			MaxPositionEmbeddings: 128000,
			IntermediateSize:      12288,
			BlockFFDim:            12288,
			BlockAutoAdjustFFDim:  true,
			BlockMultipleOf:       256,
			BlockFFNDimMultiplier: 1.0,
			NumAttentionHeads:     32,
			NumKeyValueHeads:      8,
			LayerTypes:            []string{"conv", "full_attention"},
			NormEps:               1e-5,
			ConvLCache:            3,
		},
		DownsampleFactor: 2,
		VisionConfig: struct {
			HiddenSize        uint32  `json:"hidden_size"`
			IntermediateSize  uint32  `json:"intermediate_size"`
			NumAttentionHeads uint32  `json:"num_attention_heads"`
			NumHiddenLayers   uint32  `json:"num_hidden_layers"`
			NumChannels       uint32  `json:"num_channels"`
			PatchSize         uint32  `json:"patch_size"`
			LayerNormEpsilon  float32 `json:"layer_norm_eps"`
		}{
			HiddenSize:        1152,
			IntermediateSize:  4304,
			NumAttentionHeads: 16,
			NumHiddenLayers:   27,
			NumChannels:       3,
			PatchSize:         16,
			LayerNormEpsilon:  1e-6,
		},
	}
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}

	kv := p.KV(&Tokenizer{
		Vocabulary: &Vocabulary{
			Model:  "gpt2",
			Tokens: []string{"<|pad|>", "<image>", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"},
		},
	})

	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}

	if got, want := kv["feed_forward_length"], uint32(8192); got != want {
		t.Fatalf("feed_forward_length = %v, want %v", got, want)
	}

	if got, want := kv["vision.block_count"], uint32(27); got != want {
		t.Fatalf("vision.block_count = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_size"], uint32(256); got != want {
		t.Fatalf("vision.image_size = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_token_id"], uint32(396); got != want {
		t.Fatalf("vision.image_token_id = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_start_token_id"], uint32(2); got != want {
		t.Fatalf("vision.image_start_token_id = %v, want %v", got, want)
	}

	if got, want := kv["vision.do_image_splitting"], true; got != want {
		t.Fatalf("vision.do_image_splitting = %v, want %v", got, want)
	}
	if got, want := kv["vision.min_tiles"], uint32(2); got != want {
		t.Fatalf("vision.min_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.max_tiles"], uint32(10); got != want {
		t.Fatalf("vision.max_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.tile_size"], uint32(512); got != want {
		t.Fatalf("vision.tile_size = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_thumbnail"], true; got != want {
		t.Fatalf("vision.use_thumbnail = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_image_special_tokens"], true; got != want {
		t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want)
	}
}

func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) {
	p := lfm2VLTextModel{}
	p.VisionConfig.PatchSize = 16
	p.VisionConfig.NumChannels = 3
	input := []Tensor{
		newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
		newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}),
		newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
		newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}),
		newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
	}

	out := p.Tensors(input)
	if len(out) == 0 {
		t.Fatal("expected non-empty tensor list")
	}

	foundPatch := false
	foundVision := false
	for _, tns := range out {
		if tns.Name == "v.patch_embd.weight" {
			foundPatch = true
			if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) {
				t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape)
			}
		}
		if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") {
			foundVision = true
		}
	}

	if !foundPatch {
		t.Fatal("expected v.patch_embd.weight in output tensors")
	}
	if !foundVision {
		t.Fatal("expected at least one vision/projector tensor in output")
	}
}

func TestLFM2VLTextModelReplacements(t *testing.T) {
	p := lfm2VLTextModel{}
	r := strings.NewReplacer(p.Replacements()...)

	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "language_model_embed_tokens",
			in:   "model.language_model.embed_tokens.weight",
			want: "token_embd.weight",
		},
		{
			name: "language_model_layers",
			in:   "model.language_model.layers.2.self_attn.q_proj.weight",
			want: "blk.2.attn_q.weight",
		},
		{
			name: "nested_language_model_prefix",
			in:   "model.language_model.model.embedding_norm.weight",
			want: "token_embd_norm.weight",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := r.Replace(tt.in); got != tt.want {
				t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}

func TestLFM2VLProjectorKV(t *testing.T) {
	p := lfm2VLProjectorModel{
		DownsampleFactor:   2,
		ProjectorHiddenDim: 2048,
	}
	p.VisionModel.NumHiddenLayers = 27
	p.VisionModel.HiddenSize = 1152
	p.VisionModel.IntermediateSize = 4304
	p.VisionModel.NumAttentionHeads = 16
	p.VisionModel.PatchSize = 16
	p.VisionModel.LayerNormEpsilon = 1e-6
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}

	kv := p.KV(nil)

	if got, want := kv["general.architecture"], "clip"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["clip.projector_type"], "lfm2"; got != want {
		t.Fatalf("clip.projector_type = %v, want %v", got, want)
	}
	if got, want := kv["clip.vision.image_size"], uint32(256); got != want {
		t.Fatalf("clip.vision.image_size = %v, want %v", got, want)
	}
}

func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) {
	p := lfm2VLProjectorModel{}
	p.VisionModel.NumChannels = 3
	p.VisionModel.PatchSize = 16

	input := []Tensor{
		newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
		newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
		newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
	}

	out := p.Tensors(input)
	if len(out) != 2 {
		t.Fatalf("expected 2 tensors, got %d", len(out))
	}

	var patchShape []uint64
	for _, tns := range out {
		if tns.Name == "v.patch_embd.weight" {
			patchShape = tns.Shape
			break
		}
	}

	if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) {
		t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape)
	}
}

func TestRepackPatchEmbeddingWeight(t *testing.T) {
	data := []float32{
		0, 1, // y=0,x=0
		2, 3, // y=0,x=1
		4, 5, // y=1,x=0
		6, 7, // y=1,x=1
	}

	got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	want := []float32{0, 2, 4, 6, 1, 3, 5, 7}
	if !slices.Equal(got, want) {
		t.Fatalf("repacked data = %v, want %v", got, want)
	}
}
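A recurring pattern in both VL converters above is layered defaults via `cmp.Or` from the standard library (Go 1.22+), which returns the first non-zero argument: a value from `config.json` wins over one from `processor_config.json`, which wins over the hard-coded fallback. A minimal illustration of the semantics (variable names here are illustrative, not from the PR):

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	var fromConfig, fromProcessor uint32 // zero: neither file set a tile size
	fmt.Println(cmp.Or(fromConfig, fromProcessor, uint32(512))) // 512 (fallback)

	fromProcessor = 384
	fmt.Println(cmp.Or(fromConfig, fromProcessor, uint32(512))) // 384 (processor config wins)
}
```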
convert/convert_nemotron_h.go (new file, 385 additions)

@@ -0,0 +1,385 @@
package convert

import (
	"cmp"
	"encoding/json"
	"fmt"
	"io/fs"
	"math"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

type hybridPattern string

func (p *hybridPattern) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		*p = ""
		return nil
	}

	var single string
	if err := json.Unmarshal(data, &single); err == nil {
		*p = hybridPattern(strings.TrimSpace(single))
		return nil
	}

	var parts []string
	if err := json.Unmarshal(data, &parts); err == nil {
		*p = hybridPattern(strings.Join(parts, ""))
		return nil
	}

	return fmt.Errorf("hybrid_override_pattern must be a string or string array")
}

type nemotronHModel struct {
	ModelParameters
	MaxPositionEmbeddings uint32        `json:"max_position_embeddings"`
	HiddenSize            uint32        `json:"hidden_size"`
	NumHiddenLayers       uint32        `json:"num_hidden_layers"`
	NumAttentionHeads     uint32        `json:"num_attention_heads"`
	NumKeyValueHeads      uint32        `json:"num_key_value_heads"`
	HeadDim               uint32        `json:"head_dim"`
	LayerNormEpsilon      float32       `json:"layer_norm_epsilon"`
	NormEpsilon           float32       `json:"norm_eps"`
	RopeTheta             float32       `json:"rope_theta"`
	PartialRotaryFactor   float32       `json:"partial_rotary_factor"`
	ConvKernel            uint32        `json:"conv_kernel"`
	SSMStateSize          uint32        `json:"ssm_state_size"`
	MambaNumHeads         uint32        `json:"mamba_num_heads"`
	MambaHeadDim          uint32        `json:"mamba_head_dim"`
	NGroups               uint32        `json:"n_groups"`
	IntermediateSize      uint32        `json:"intermediate_size"`
	HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`

	// MoE
	NumExperts                  uint32  `json:"num_experts"`
	NumSharedExperts            uint32  `json:"num_shared_experts"`
	NRoutedExperts              uint32  `json:"n_routed_experts"`
	NSharedExperts              uint32  `json:"n_shared_experts"`
	NumExpertsPerTok            uint32  `json:"num_experts_per_tok"`
	MoEIntermediateSize         uint32  `json:"moe_intermediate_size"`
	MoESharedExpertIntermediate uint32  `json:"moe_shared_expert_intermediate_size"`
	NormTopKProb                bool    `json:"norm_topk_prob"`
	RoutedScalingFactor         float32 `json:"routed_scaling_factor"`
	ExpertGroupCount            uint32  `json:"n_group"`
	ExpertGroupUsedCount        uint32  `json:"topk_group"`
}

var _ ModelConverter = (*nemotronHModel)(nil)

func (n *nemotronHModel) parseMore(_ fs.FS) error {
	if n.NumHiddenLayers == 0 {
		return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
	}
	if n.HiddenSize == 0 {
		return fmt.Errorf("nemotron_h: hidden_size must be set")
	}
	if n.NumAttentionHeads == 0 {
		return fmt.Errorf("nemotron_h: num_attention_heads must be set")
	}
	if n.HeadDim == 0 {
		if n.HiddenSize%n.NumAttentionHeads != 0 {
			return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
		}
		n.HeadDim = n.HiddenSize / n.NumAttentionHeads
	}
	if n.NumKeyValueHeads == 0 {
		n.NumKeyValueHeads = n.NumAttentionHeads
	}
	if n.ConvKernel == 0 {
		return fmt.Errorf("nemotron_h: conv_kernel must be set")
	}
	if n.SSMStateSize == 0 {
		return fmt.Errorf("nemotron_h: ssm_state_size must be set")
	}
	if n.ssmHeadCount() == 0 {
		return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
	}
	if n.MambaHeadDim == 0 {
		return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
	}
	if n.NGroups == 0 {
		n.NGroups = 1
	}

	if _, _, err := n.layerArrays(); err != nil {
		return err
	}

	if n.isMoE() {
		if n.routedExpertCount() == 0 {
			return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
		}
		if n.NumExpertsPerTok == 0 {
			return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
		}
		if n.NumExpertsPerTok > n.routedExpertCount() {
			return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
		}
		if n.moeIntermediateSize() == 0 {
			return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
		}
	}

	return nil
}

func (n *nemotronHModel) isMoE() bool {
	return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
}

func (n *nemotronHModel) routedExpertCount() uint32 {
	return cmp.Or(n.NRoutedExperts, n.NumExperts)
}

func (n *nemotronHModel) sharedExpertCount() uint32 {
	return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
}

func (n *nemotronHModel) ssmHeadCount() uint32 {
	return n.MambaNumHeads
}

func (n *nemotronHModel) ssmInnerSize() uint32 {
	return n.MambaHeadDim * n.ssmHeadCount()
}

func (n *nemotronHModel) epsilon() float32 {
	return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
}

func (n *nemotronHModel) moeIntermediateSize() uint32 {
	return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
}

func (n *nemotronHModel) denseIntermediateSize() uint32 {
	return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
}

func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
	pattern := strings.TrimSpace(string(n.HybridOverridePattern))
	if pattern == "" {
		return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
	}

	runes := []rune(pattern)
	if len(runes) != int(n.NumHiddenLayers) {
		return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
	}

	headCountKV = make([]uint32, n.NumHiddenLayers)
	ffnLengths = make([]uint32, n.NumHiddenLayers)

	attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
	moeFFN := n.moeIntermediateSize()
	denseFFN := n.denseIntermediateSize()

	for i, layerType := range runes {
		switch layerType {
		case 'M':
			// Recurrent layer: no KV heads and no FFN.
		case '*', 'A':
			// Attention-only layer.
			headCountKV[i] = attnKVHeads
		case 'E':
			// MoE layer.
			if moeFFN == 0 {
				return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
			}
			ffnLengths[i] = moeFFN
		case '-':
			// Dense FFN layer.
			if denseFFN == 0 {
				return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
			}
			ffnLengths[i] = denseFFN
		default:
			return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
		}
	}

	return headCountKV, ffnLengths, nil
}

func (n *nemotronHModel) KV(t *Tokenizer) KV {
	kv := n.ModelParameters.KV(t)

	arch := "nemotron_h"
	if n.isMoE() {
		arch = "nemotron_h_moe"
	}
	kv["general.architecture"] = arch
	kv["block_count"] = n.NumHiddenLayers
	kv["context_length"] = n.MaxPositionEmbeddings
	kv["embedding_length"] = n.HiddenSize
	kv["attention.head_count"] = n.NumAttentionHeads
	kv["attention.key_length"] = n.HeadDim
	kv["attention.value_length"] = n.HeadDim
	kv["attention.layer_norm_epsilon"] = n.epsilon()
	kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
	kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
	if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
		kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
	}

	if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
		kv["attention.head_count_kv"] = headCountKV
		kv["feed_forward_length"] = ffnLengths
	}

	kv["ssm.conv_kernel"] = n.ConvKernel
	kv["ssm.inner_size"] = n.ssmInnerSize()
	kv["ssm.state_size"] = n.SSMStateSize
	kv["ssm.group_count"] = n.NGroups
	kv["ssm.time_step_rank"] = n.ssmHeadCount()

	if n.isMoE() {
		kv["expert_count"] = n.routedExpertCount()
		kv["expert_used_count"] = n.NumExpertsPerTok
		kv["expert_feed_forward_length"] = n.moeIntermediateSize()
		if n.sharedExpertCount() > 0 {
			kv["expert_shared_count"] = n.sharedExpertCount()
		}
		if n.MoESharedExpertIntermediate > 0 {
			kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
		}
		kv["expert_weights_norm"] = n.NormTopKProb
		kv["expert_weights_scale"] = n.RoutedScalingFactor
		if n.ExpertGroupCount > 0 {
			kv["expert_group_count"] = n.ExpertGroupCount
		}
		if n.ExpertGroupUsedCount > 0 {
			kv["expert_group_used_count"] = n.ExpertGroupUsedCount
		}
	}

	return kv
}

func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
	switch len(shape) {
	case 1:
		return []uint64{shape[0], 1}
	case 2:
		if shape[0] == 1 && shape[1] > 1 {
			return []uint64{shape[1], 1}
		}
		if shape[1] == 1 && shape[0] > 1 {
			return []uint64{shape[0], 1}
		}
	}

	return slices.Clone(shape)
}

func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	remaining := ts
	if n.isMoE() {
		merges := make([]merge, 0, n.NumHiddenLayers*2)
		for i := range n.NumHiddenLayers {
			merges = append(merges, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
			}, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
			})
		}

		merged, rest := mergeTensors(ts, merges...)
		out = append(out, merged...)
		remaining = rest
	}

	nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
	for _, t := range remaining {
		name := t.Name()
		shape := slices.Clone(t.Shape())

		switch {
		case strings.HasSuffix(name, ".ssm_a"):
			shape = normalizeVectorShapeToColumn(shape)
			t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
				out := make([]float32, len(data))
				for i, v := range data {
					out[i] = -float32(math.Exp(float64(v)))
				}
				return out, nil
			})
		case strings.HasSuffix(name, ".ssm_d"):
			shape = normalizeVectorShapeToColumn(shape)
		case strings.HasSuffix(name, ".ssm_norm.weight"):
			switch len(shape) {
			case 1:
				if nGroups > 0 && shape[0]%nGroups == 0 {
					shape = []uint64{nGroups, shape[0] / nGroups}
				}
			case 2:
				if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
					shape = []uint64{nGroups, shape[1] / nGroups}
				}
			}
		case strings.HasSuffix(name, ".ssm_conv1d.weight"):
			if len(shape) == 3 {
				if shape[0] == 1 {
					shape = []uint64{shape[1], shape[2]}
				} else if shape[1] == 1 {
					shape = []uint64{shape[0], shape[2]}
				}
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    shape,
			WriterTo: t,
		})
	}

	return out
}

func (n *nemotronHModel) Replacements() []string {
	return []string{
		// Embedding and output
		"lm_head", "output",
		"backbone.embeddings", "token_embd",
		"backbone.norm_f", "output_norm",
		"backbone.layers", "blk",

		// Recurrent (Mamba2) tensors
		"mixer.in_proj", "ssm_in",
		"mixer.out_proj", "ssm_out",
		"mixer.dt_bias", "ssm_dt.bias",
		"mixer.A_log", "ssm_a",
		"mixer.D", "ssm_d",
		"mixer.conv1d", "ssm_conv1d",
		"mixer.norm.weight", "ssm_norm.weight",

		// Attention tensors
		"mixer.q_proj", "attn_q",
		"mixer.k_proj", "attn_k",
		"mixer.v_proj", "attn_v",
		"mixer.o_proj", "attn_output",

		// FFN / MoE tensors
		"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mixer.gate", "ffn_gate_inp",
		"mixer.fc1_latent_proj", "ffn_latent_in",
		"mixer.fc2_latent_proj", "ffn_latent_out",
		"mixer.shared_experts.up_proj", "ffn_up_shexp",
		"mixer.shared_experts.down_proj", "ffn_down_shexp",
		"mixer.up_proj", "ffn_up",
		"mixer.down_proj", "ffn_down",

		// Per-layer pre-norm
		".norm.weight", ".attn_norm.weight",
	}
}
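The `hybrid_override_pattern` string drives both per-layer arrays: `M` marks a recurrent (Mamba) layer with no KV heads and no FFN, `*` or `A` an attention layer, `E` a MoE FFN layer, and `-` a dense FFN layer. A sketch of how the pattern `MEM*E` from the tests below expands, assuming 8 KV heads and a 1856-wide MoE FFN (this standalone loop only mirrors `layerArrays`, without its validation):

```go
package main

import "fmt"

func main() {
	pattern := "MEM*E"
	attnKVHeads, moeFFN := uint32(8), uint32(1856)
	headsKV := make([]uint32, len(pattern))
	ffn := make([]uint32, len(pattern))
	for i, lt := range pattern {
		switch lt {
		case '*', 'A': // attention layer
			headsKV[i] = attnKVHeads
		case 'E': // MoE FFN layer
			ffn[i] = moeFFN
		}
		// 'M' (recurrent) leaves both entries zero.
	}
	fmt.Println(headsKV) // [0 0 0 8 0]
	fmt.Println(ffn)     // [0 1856 0 0 1856]
}
```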
230
convert/convert_nemotron_h_test.go
Normal file
230
convert/convert_nemotron_h_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHybridPatternUnmarshal(t *testing.T) {
|
||||
t.Run("string", func(t *testing.T) {
|
||||
var p hybridPattern
|
||||
if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := string(p), "MEM*"; got != want {
|
||||
t.Fatalf("unexpected pattern: got %q want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("array", func(t *testing.T) {
|
||||
var p hybridPattern
|
||||
if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := string(p), "MEM*"; got != want {
|
||||
t.Fatalf("unexpected pattern: got %q want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNemotronHLayerArrays(t *testing.T) {
    m := &nemotronHModel{
        NumHiddenLayers:       5,
        NumAttentionHeads:     32,
        NumKeyValueHeads:      8,
        HybridOverridePattern: "MEM*E",
        NRoutedExperts:        128,
        NumExpertsPerTok:      6,
        MoEIntermediateSize:   1856,
    }

    headsKV, ffn, err := m.layerArrays()
    if err != nil {
        t.Fatal(err)
    }

    if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
        t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
    }
    if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
        t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
    }
}

func TestNemotronHKV(t *testing.T) {
    m := &nemotronHModel{
        MaxPositionEmbeddings:       1048576,
        HiddenSize:                  2688,
        NumHiddenLayers:             5,
        NumAttentionHeads:           32,
        NumKeyValueHeads:            2,
        HeadDim:                     128,
        LayerNormEpsilon:            1e-5,
        RopeTheta:                   10000,
        PartialRotaryFactor:         0.5,
        ConvKernel:                  4,
        SSMStateSize:                128,
        MambaNumHeads:               64,
        MambaHeadDim:                64,
        NGroups:                     8,
        HybridOverridePattern:       "MEM*E",
        NRoutedExperts:              128,
        NSharedExperts:              1,
        NumExpertsPerTok:            6,
        MoEIntermediateSize:         1856,
        MoESharedExpertIntermediate: 3712,
        NormTopKProb:                true,
        RoutedScalingFactor:         2.5,
    }
    if err := m.parseMore(nil); err != nil {
        t.Fatal(err)
    }

    kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
    if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
        t.Fatalf("unexpected architecture: got %v want %v", got, want)
    }

    headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
    if !ok {
        t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
    }
    if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
        t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
    }

    ffnLength, ok := kv["feed_forward_length"].([]uint32)
    if !ok {
        t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
    }
    if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
        t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
    }
}

func TestNemotronHTensorsTransforms(t *testing.T) {
    m := &nemotronHModel{NGroups: 8}
    in := []Tensor{
        &fakeTensor{
            name:  "blk.0.ssm_a",
            shape: []uint64{4},
            data:  []float32{0, 1, 2, 3},
        },
        &fakeTensor{
            name:  "blk.0.ssm_d",
            shape: []uint64{4},
            data:  []float32{0, 1, 2, 3},
        },
        &fakeTensor{
            name:  "blk.0.ssm_norm.weight",
            shape: []uint64{16},
            data:  make([]float32, 16),
        },
        &fakeTensor{
            name:  "blk.0.ssm_conv1d.weight",
            shape: []uint64{10, 1, 4},
            data:  make([]float32, 40),
        },
    }

    out := m.Tensors(in)
    if len(out) != len(in) {
        t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
    }

    got := map[string]struct {
        shape  []uint64
        writer io.WriterTo
    }{}
    for _, t := range out {
        got[t.Name] = struct {
            shape  []uint64
            writer io.WriterTo
        }{shape: t.Shape, writer: t.WriterTo}
    }

    if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
        t.Fatalf("unexpected ssm_a shape: %v", shape)
    }
    if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
        t.Fatalf("unexpected ssm_d shape: %v", shape)
    }
    if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
        t.Fatalf("unexpected ssm_norm shape: %v", shape)
    }
    if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
        t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
    }

    var b bytes.Buffer
    if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
        t.Fatal(err)
    }
    values := make([]float32, 4)
    if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
        t.Fatal(err)
    }
    // 0 -> -exp(0) == -1
    if values[0] != -1 {
        t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
    }
}

func TestNemotronHLoadModelMetadata(t *testing.T) {
    tempDir := t.TempDir()

    config := `{
        "architectures": ["NemotronHForCausalLM"],
        "model_type": "nemotron_h",
        "num_hidden_layers": 4,
        "hidden_size": 512,
        "max_position_embeddings": 32768,
        "num_attention_heads": 8,
        "num_key_value_heads": 2,
        "head_dim": 64,
        "layer_norm_epsilon": 1e-5,
        "conv_kernel": 4,
        "ssm_state_size": 128,
        "mamba_num_heads": 16,
        "mamba_head_dim": 32,
        "n_groups": 8,
        "hybrid_override_pattern": "ME*M",
        "n_routed_experts": 16,
        "num_experts_per_tok": 4,
        "moe_intermediate_size": 256
    }`

    if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
        t.Fatal(err)
    }
    if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
        t.Fatal(err)
    }

    kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
    if err != nil {
        t.Fatal(err)
    }
    if _, ok := kv.(*nemotronHModel); !ok {
        t.Fatalf("unexpected converter type: %T", kv)
    }
}

func TestNemotronHReplacementsLatentProjections(t *testing.T) {
    m := &nemotronHModel{}
    r := strings.NewReplacer(m.Replacements()...)

    if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
        t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
    }
    if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
        t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
    }
}
convert/json_compat.go (new file, 97 lines)
@@ -0,0 +1,97 @@
package convert

// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
//
// This is intentionally conservative:
//   - only runs outside quoted strings
//   - only rewrites full tokens
//
// We map these values to 0 because encoding/json rejects non-finite values,
// and these fields are typically model-side metadata not consumed by the
// converter.
func sanitizeNonFiniteJSON(in []byte) []byte {
    if len(in) == 0 {
        return in
    }

    out := make([]byte, 0, len(in))
    inString := false
    escape := false

    for i := 0; i < len(in); {
        c := in[i]

        if inString {
            out = append(out, c)
            if escape {
                escape = false
            } else if c == '\\' {
                escape = true
            } else if c == '"' {
                inString = false
            }
            i++
            continue
        }

        if c == '"' {
            inString = true
            out = append(out, c)
            i++
            continue
        }

        if hasToken(in, i, "-Infinity") {
            out = append(out, '0')
            i += len("-Infinity")
            continue
        }

        if hasToken(in, i, "Infinity") {
            out = append(out, '0')
            i += len("Infinity")
            continue
        }

        if hasToken(in, i, "NaN") {
            out = append(out, '0')
            i += len("NaN")
            continue
        }

        out = append(out, c)
        i++
    }

    return out
}

func hasToken(in []byte, at int, tok string) bool {
    end := at + len(tok)
    if at < 0 || end > len(in) {
        return false
    }
    if string(in[at:end]) != tok {
        return false
    }
    if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
        return false
    }
    if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
        return false
    }
    return true
}

func isJSONWhitespace(b byte) bool {
    return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}

func isJSONValuePrefixBoundary(b byte) bool {
    return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
}

func isJSONValueSuffixBoundary(b byte) bool {
    return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
}
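A short sketch of how a caller inside this package might use `sanitizeNonFiniteJSON` before handing a config to `encoding/json`; the wrapper function and config fragment here are illustrative, not part of the diff:

```go
package convert

import "encoding/json"

// decodeLogitScale is an illustrative helper: encoding/json rejects the bare
// Infinity token, so the bytes are sanitized (Infinity -> 0) before decoding.
func decodeLogitScale(raw []byte) (float64, error) {
    var cfg struct {
        LogitScale float64 `json:"logit_scale"`
    }
    if err := json.Unmarshal(sanitizeNonFiniteJSON(raw), &cfg); err != nil {
        return 0, err
    }
    return cfg.LogitScale, nil // 0 when the source said Infinity
}
```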
convert/json_compat_test.go (new file, 46 lines)
@@ -0,0 +1,46 @@
package convert

import "testing"

func TestSanitizeNonFiniteJSON(t *testing.T) {
    tests := []struct {
        name string
        in   string
        want string
    }{
        {
            name: "infinity token",
            in:   `{"a":[0,Infinity,1]}`,
            want: `{"a":[0,0,1]}`,
        },
        {
            name: "negative infinity token",
            in:   `{"a":-Infinity}`,
            want: `{"a":0}`,
        },
        {
            name: "nan token",
            in:   `{"a":NaN}`,
            want: `{"a":0}`,
        },
        {
            name: "tokens inside strings untouched",
            in:   `{"a":"Infinity -Infinity NaN","b":Infinity}`,
            want: `{"a":"Infinity -Infinity NaN","b":0}`,
        },
        {
            name: "identifier-like token untouched",
            in:   `{"a":InfinityValue}`,
            want: `{"a":InfinityValue}`,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
            if got != tt.want {
                t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
            }
        })
    }
}
@@ -212,8 +212,13 @@ type tokenizer struct {
 
 	PreTokenizer struct {
 		PreTokenizers []struct {
-			Type    string `json:"type"`
-			Pattern struct {
+			Type           string `json:"type"`
+			Behavior       string `json:"behavior"`
+			Invert         bool   `json:"invert"`
+			AddPrefixSpace bool   `json:"add_prefix_space"`
+			TrimOffsets    bool   `json:"trim_offsets"`
+			UseRegex       bool   `json:"use_regex"`
+			Pattern        struct {
 				Regex string `json:"Regex"`
 			} `json:"pattern"`
 		} `json:"pretokenizers"`
@@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) {
                Pre: "default",
            },
        },
        {
            name: "llama-bpe pretokenizer and control tokens",
            fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
                "tokenizer.json": strings.NewReader(`{
                    "added_tokens": [
                        {"id": 1, "content": "<|startoftext|>", "special": true},
                        {"id": 6, "content": "<|im_start|>", "special": true},
                        {"id": 7, "content": "<|im_end|>", "special": true},
                        {"id": 8, "content": "<|tool_list_start|>", "special": true},
                        {"id": 9, "content": "<|tool_list_end|>", "special": true},
                        {"id": 10, "content": "<|tool_call_start|>", "special": true},
                        {"id": 11, "content": "<|tool_call_end|>", "special": true},
                        {"id": 12, "content": "<|tool_response_start|>", "special": true},
                        {"id": 13, "content": "<|tool_response_end|>", "special": true},
                        {"id": 396, "content": "<image>", "special": true},
                        {"id": 64400, "content": "<think>", "special": true},
                        {"id": 64401, "content": "</think>", "special": true}
                    ],
                    "model": {
                        "vocab": {
                            "<|startoftext|>": 1,
                            "<|im_start|>": 6,
                            "<|im_end|>": 7,
                            "<|tool_list_start|>": 8,
                            "<|tool_list_end|>": 9,
                            "<|tool_call_start|>": 10,
                            "<|tool_call_end|>": 11,
                            "<|tool_response_start|>": 12,
                            "<|tool_response_end|>": 13,
                            "<image>": 396,
                            "<think>": 64400,
                            "</think>": 64401
                        }
                    },
                    "pre_tokenizer": {
                        "type": "Sequence",
                        "pretokenizers": [
                            {
                                "type": "Split",
                                "pattern": {
                                    "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
                                },
                                "behavior": "Isolated",
                                "invert": false
                            },
                            {
                                "type": "ByteLevel",
                                "add_prefix_space": false,
                                "trim_offsets": true,
                                "use_regex": false
                            }
                        ]
                    }
                }`),
            }),
            want: &Tokenizer{
                Vocabulary: &Vocabulary{
                    Model: "gpt2",
                    Tokens: []string{
                        "<|startoftext|>",
                        "<|im_start|>",
                        "<|im_end|>",
                        "<|tool_list_start|>",
                        "<|tool_list_end|>",
                        "<|tool_call_start|>",
                        "<|tool_call_end|>",
                        "<|tool_response_start|>",
                        "<|tool_response_end|>",
                        "<image>",
                        "<think>",
                        "</think>",
                    },
                    Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401},
                    Types:  []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
                },
                Pre: "llama-bpe",
            },
        },
        {
            name: "list string merges",
            fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
@@ -226,3 +226,7 @@ curl https://ollama.com/api/chat \

</Tab>
</Tabs>

## Local only

Ollama can run in local-only mode by [disabling Ollama's cloud features](./faq#how-do-i-disable-ollama-cloud).
@@ -106,20 +106,23 @@
         "group": "Integrations",
         "pages": [
           "/integrations/index",
+          {
+            "group": "Assistants",
+            "expanded": true,
+            "pages": [
+              "/integrations/openclaw"
+            ]
+          },
           {
             "group": "Coding",
+            "expanded": true,
             "pages": [
               "/integrations/claude-code",
               "/integrations/codex",
               "/integrations/opencode",
               "/integrations/droid",
-              "/integrations/goose"
+              "/integrations/goose",
+              "/integrations/pi"
             ]
           },
-          {
-            "group": "Assistants",
-            "pages": [
-              "/integrations/openclaw"
-            ]
-          },
           {
docs/faq.mdx
@@ -160,6 +160,26 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-

Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service, but we do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service; this metadata does not include prompt or response content. We don't sell your data. You can delete your account at any time.

## How do I disable Ollama's cloud features?

Ollama can run in local-only mode by disabling Ollama's cloud features. By turning off Ollama's cloud features, you lose the ability to use Ollama's cloud models and web search.

Set `disable_ollama_cloud` in `~/.ollama/server.json`:

```json
{
  "disable_ollama_cloud": true
}
```

You can also set the environment variable:

```shell
OLLAMA_NO_CLOUD=1
```

Restart Ollama after changing the configuration. Once disabled, Ollama's logs will show `Ollama cloud disabled: true`.

## How can I expose Ollama on my network?

Ollama binds to 127.0.0.1, port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.

- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
- [Pi](/integrations/pi)

## Assistants
@@ -4,47 +4,65 @@ title: OpenClaw
 
 OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
 
-## Install
-
-Install [OpenClaw](https://openclaw.ai/)
-
-```bash
-npm install -g openclaw@latest
-```
-
-Then run the onboarding wizard:
-
-```bash
-openclaw onboard --install-daemon
-```
-
-<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
-
-## Usage with Ollama
-
-### Quick setup
+## Quick start
 
 ```bash
 ollama launch openclaw
 ```
 
-This configures OpenClaw to use Ollama and starts the gateway.
-If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
+Ollama handles everything automatically:
+
+1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
+2. **Security** — On the first launch, a security notice explains the risks of tool access
+3. **Model** — Pick a model from the selector (local or cloud)
+4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
+5. **Gateway** — Starts in the background and opens the OpenClaw TUI
+
+<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
+
+<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
+
+## Configure without launching
 
-To configure without launching:
+To change the model without starting the gateway and TUI:
 
-```shell
+```bash
 ollama launch openclaw --config
 ```
 
-## Recommended Models
+To use a specific model directly:
 
-- `qwen3-coder`
-- `glm-4.7`
-- `gpt-oss:20b`
-- `gpt-oss:120b`
+```bash
+ollama launch openclaw --model kimi-k2.5:cloud
+```
 
-Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
+If the gateway is already running, it restarts automatically to pick up the new model.
+
+## Recommended models
+
+**Cloud models:**
+
+- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
+- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
+- `glm-5:cloud` — Reasoning and code generation
+
+**Local models:**
+
+- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
+
+More models at [ollama.com/search](https://ollama.com/search?c=cloud).
+
+## Connect messaging apps
+
+```bash
+openclaw configure --section channels
+```
+
+Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
+
+## Stopping the gateway
+
+```bash
+openclaw gateway stop
+```
docs/integrations/pi.mdx (new file, 57 lines)
@@ -0,0 +1,57 @@
---
title: Pi
---

Pi is a minimal AI agent toolkit with plugin support.

## Install

Install [Pi](https://github.com/badlogic/pi-mono):

```bash
npm install -g @mariozechner/pi-coding-agent
```

## Usage with Ollama

### Quick setup

```bash
ollama launch pi
```

To configure without launching:

```shell
ollama launch pi --config
```

### Manual setup

Add a configuration block to `~/.pi/agent/models.json`:

```json
{
  "providers": {
    "ollama": {
      "baseUrl": "http://localhost:11434/v1",
      "api": "openai-completions",
      "apiKey": "ollama",
      "models": [
        {
          "id": "qwen3-coder"
        }
      ]
    }
  }
}
```

Update `~/.pi/agent/settings.json` to set the default provider:

```json
{
  "defaultProvider": "ollama",
  "defaultModel": "qwen3-coder"
}
```
@@ -2,7 +2,7 @@
 title: Quickstart
 ---
 
-This quickstart will walk your through running your first model with Ollama. To get started, download Ollama on macOS, Windows or Linux.
+Ollama is available on macOS, Windows, and Linux.
 
 <a
   href="https://ollama.com/download"
@@ -12,131 +12,56 @@
   Download Ollama
 </a>
 
-## Run a model
+## Get Started
 
-<Tabs>
-  <Tab title="CLI">
-    Open a terminal and run the command:
-
-    ```sh
-    ollama run gemma3
-    ```
-
-  </Tab>
-  <Tab title="cURL">
-    ```sh
-    ollama pull gemma3
-    ```
-
-    Lastly, chat with the model:
-
-    ```shell
-    curl http://localhost:11434/api/chat -d '{
-      "model": "gemma3",
-      "messages": [{
-        "role": "user",
-        "content": "Hello there!"
-      }],
-      "stream": false
-    }'
-    ```
-
-  </Tab>
-  <Tab title="Python">
-    Start by downloading a model:
-
-    ```sh
-    ollama pull gemma3
-    ```
-
-    Then install Ollama's Python library:
-
-    ```sh
-    pip install ollama
-    ```
-
-    Lastly, chat with the model:
-
-    ```python
-    from ollama import chat
-    from ollama import ChatResponse
-
-    response: ChatResponse = chat(model='gemma3', messages=[
-      {
-        'role': 'user',
-        'content': 'Why is the sky blue?',
-      },
-    ])
-    print(response['message']['content'])
-    # or access fields directly from the response object
-    print(response.message.content)
-    ```
-
-  </Tab>
-  <Tab title="JavaScript">
-    Start by downloading a model:
-
-    ```
-    ollama pull gemma3
-    ```
-
-    Then install the Ollama JavaScript library:
-
-    ```
-    npm i ollama
-    ```
-
-    Lastly, chat with the model:
-
-    ```shell
-    import ollama from 'ollama'
-
-    const response = await ollama.chat({
-      model: 'gemma3',
-      messages: [{ role: 'user', content: 'Why is the sky blue?' }],
-    })
-    console.log(response.message.content)
-    ```
-
-  </Tab>
-</Tabs>
-
-See a full list of available models [here](https://ollama.com/models).
-
-## Coding
-
-For coding use cases, we recommend using the `glm-4.7-flash` model.
-
-Note: this model requires 23 GB of VRAM with 64000 tokens context length.
-
-```sh
-ollama pull glm-4.7-flash
-```
-
-Alternatively, you can use a more powerful cloud model (with full context length):
-
-```sh
-ollama pull glm-4.7:cloud
-```
-
-Use `ollama launch` to quickly set up a coding tool with Ollama models:
+Run `ollama` in your terminal to open the interactive menu:
 
 ```sh
-ollama launch
+ollama
 ```
 
-### Supported integrations
+Navigate with `↑/↓`, press `enter` to launch, `→` to change model, and `esc` to quit.
 
-- [OpenCode](/integrations/opencode) - Open-source coding assistant
-- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
-- [Codex](/integrations/codex) - OpenAI's coding assistant
-- [Droid](/integrations/droid) - Factory's AI coding agent
+The menu provides quick access to:
+
+- **Run a model** - Start an interactive chat
+- **Launch tools** - Claude Code, Codex, OpenClaw, and more
+- **Additional integrations** - Available under "More..."
 
-### Launch with a specific model
+## Assistants
+
+Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
 
 ```sh
-ollama launch claude --model glm-4.7-flash
+ollama launch openclaw
 ```
 
-### Configure without launching
+## Coding
+
+Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
 
 ```sh
-ollama launch claude --config
+ollama launch claude
 ```
+
+```sh
+ollama launch codex
+```
+
+```sh
+ollama launch opencode
+```
+
+See [integrations](/integrations) for all supported tools.
+
+## API
+
+Use the [API](/api) to integrate Ollama into your applications:
+
+```sh
+curl http://localhost:11434/api/chat -d '{
+  "model": "gemma3",
+  "messages": [{ "role": "user", "content": "Hello!" }]
+}'
+```
+
+See the [API documentation](/api) for Python, JavaScript, and other integrations.
@@ -1,6 +1,8 @@
package envconfig

import (
    "encoding/json"
    "errors"
    "fmt"
    "log/slog"
    "math"

@@ -11,6 +13,7 @@ import (
    "runtime"
    "strconv"
    "strings"
    "sync"
    "time"
)

@@ -206,6 +209,8 @@ var (
    UseAuth = Bool("OLLAMA_AUTH")
    // Enable Vulkan backend
    EnableVulkan = Bool("OLLAMA_VULKAN")
    // NoCloudEnv checks the OLLAMA_NO_CLOUD environment variable.
    NoCloudEnv = Bool("OLLAMA_NO_CLOUD")
)

func String(s string) func() string {

@@ -285,6 +290,7 @@ func AsMap() map[string]EnvVar {
    "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
    "OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
    "OLLAMA_MODELS":            {"OLLAMA_MODELS", Models(), "The path to the models directory"},
    "OLLAMA_NO_CLOUD":          {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
    "OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
    "OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
    "OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},

@@ -334,3 +340,91 @@ func Values() map[string]string {
func Var(key string) string {
    return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}

// serverConfigData holds the parsed fields from ~/.ollama/server.json.
type serverConfigData struct {
    DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
}

var (
    serverCfgMu     sync.RWMutex
    serverCfgLoaded bool
    serverCfg       serverConfigData
)

func loadServerConfig() {
    serverCfgMu.RLock()
    if serverCfgLoaded {
        serverCfgMu.RUnlock()
        return
    }
    serverCfgMu.RUnlock()

    cfg := serverConfigData{}
    home, err := os.UserHomeDir()
    if err == nil {
        path := filepath.Join(home, ".ollama", "server.json")
        data, err := os.ReadFile(path)
        if err != nil {
            if !errors.Is(err, os.ErrNotExist) {
                slog.Debug("envconfig: could not read server config", "error", err)
            }
        } else if err := json.Unmarshal(data, &cfg); err != nil {
            slog.Debug("envconfig: could not parse server config", "error", err)
        }
    }

    serverCfgMu.Lock()
    defer serverCfgMu.Unlock()
    if serverCfgLoaded {
        return
    }
    serverCfg = cfg
    serverCfgLoaded = true
}

func cachedServerConfig() serverConfigData {
    serverCfgMu.RLock()
    defer serverCfgMu.RUnlock()
    return serverCfg
}

// ReloadServerConfig refreshes the cached ~/.ollama/server.json settings.
func ReloadServerConfig() {
    serverCfgMu.Lock()
    serverCfgLoaded = false
    serverCfg = serverConfigData{}
    serverCfgMu.Unlock()

    loadServerConfig()
}

// NoCloud returns true if Ollama cloud features are disabled,
// checking both the OLLAMA_NO_CLOUD environment variable and
// the disable_ollama_cloud field in ~/.ollama/server.json.
func NoCloud() bool {
    if NoCloudEnv() {
        return true
    }
    loadServerConfig()
    return cachedServerConfig().DisableOllamaCloud
}

// NoCloudSource returns the source of the cloud-disabled decision.
// Returns "none", "env", "config", or "both".
func NoCloudSource() string {
    envDisabled := NoCloudEnv()
    loadServerConfig()
    configDisabled := cachedServerConfig().DisableOllamaCloud

    switch {
    case envDisabled && configDisabled:
        return "both"
    case envDisabled:
        return "env"
    case configDisabled:
        return "config"
    default:
        return "none"
    }
}
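A sketch of how server code might consult these helpers when gating a cloud-only request; the middleware shape and log wording are illustrative, only `NoCloud` and `NoCloudSource` come from the change above:

```go
package main

import (
    "log/slog"
    "net/http"

    "github.com/ollama/ollama/envconfig"
)

// requireCloud rejects requests that need cloud features when they are
// disabled by OLLAMA_NO_CLOUD or by disable_ollama_cloud in server.json.
func requireCloud(next http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        if envconfig.NoCloud() {
            // NoCloudSource distinguishes "env", "config", or "both".
            slog.Info("rejecting cloud request", "source", envconfig.NoCloudSource())
            http.Error(w, "ollama cloud is disabled", http.StatusForbidden)
            return
        }
        next(w, r)
    }
}
```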
@@ -3,6 +3,8 @@ package envconfig

import (
    "log/slog"
    "math"
    "os"
    "path/filepath"
    "testing"
    "time"

@@ -326,3 +328,81 @@ func TestLogLevel(t *testing.T) {
        })
    }
}

func TestNoCloud(t *testing.T) {
    tests := []struct {
        name          string
        envValue      string
        configContent string
        wantDisabled  bool
        wantSource    string
    }{
        {
            name:         "neither env nor config",
            wantDisabled: false,
            wantSource:   "none",
        },
        {
            name:         "env only",
            envValue:     "1",
            wantDisabled: true,
            wantSource:   "env",
        },
        {
            name:          "config only",
            configContent: `{"disable_ollama_cloud": true}`,
            wantDisabled:  true,
            wantSource:    "config",
        },
        {
            name:          "both env and config",
            envValue:      "1",
            configContent: `{"disable_ollama_cloud": true}`,
            wantDisabled:  true,
            wantSource:    "both",
        },
        {
            name:          "config false",
            configContent: `{"disable_ollama_cloud": false}`,
            wantDisabled:  false,
            wantSource:    "none",
        },
        {
            name:          "invalid config ignored",
            configContent: `{invalid json`,
            wantDisabled:  false,
            wantSource:    "none",
        },
        {
            name:         "no config file",
            wantDisabled: false,
            wantSource:   "none",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            home := t.TempDir()
            if tt.configContent != "" {
                configDir := filepath.Join(home, ".ollama")
                if err := os.MkdirAll(configDir, 0o755); err != nil {
                    t.Fatal(err)
                }
                if err := os.WriteFile(filepath.Join(configDir, "server.json"), []byte(tt.configContent), 0o644); err != nil {
                    t.Fatal(err)
                }
            }

            setTestHome(t, home)
            t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)

            if got := NoCloud(); got != tt.wantDisabled {
                t.Errorf("NoCloud() = %v, want %v", got, tt.wantDisabled)
            }

            if got := NoCloudSource(); got != tt.wantSource {
                t.Errorf("NoCloudSource() = %q, want %q", got, tt.wantSource)
            }
        })
    }
}
envconfig/test_home_test.go (new file, 10 lines)
@@ -0,0 +1,10 @@
package envconfig

import "testing"

func setTestHome(t *testing.T, home string) {
    t.Helper()
    t.Setenv("HOME", home)
    t.Setenv("USERPROFILE", home)
    ReloadServerConfig()
}
@@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 {
    return uint64(kv.Uint("ssm.group_count"))
}

func (kv KV) FFNLength() []uint64 {
    ffnLengthDefault := uint32(0)
    ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault)
    if len(ffnLength) == 1 {
        ffnLengthDefault = ffnLength[0]
    }
    nLayers := int(kv.BlockCount())
    if len(ffnLength) > nLayers {
        slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers)
    }
    out := make([]uint64, nLayers)
    for i := range nLayers {
        if i >= len(ffnLength) {
            out[i] = uint64(ffnLengthDefault)
        } else {
            out[i] = uint64(ffnLength[i])
        }
    }
    return out
}

// general types

func (kv KV) String(key string, defaultValue ...string) string {
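The defaulting rule in `FFNLength` (a scalar `feed_forward_length` fans out to every layer, and an array shorter than the layer count is padded with the default) can be checked in isolation; a standalone sketch with plain slices standing in for the KV metadata:

```go
package main

import "fmt"

// broadcastFFN mirrors FFNLength's defaulting: a single value applies to all
// layers, while a per-layer array is padded with the default (0 unless the
// array had exactly one element).
func broadcastFFN(ffnLength []uint32, nLayers int) []uint64 {
    def := uint32(0)
    if len(ffnLength) == 1 {
        def = ffnLength[0]
    }
    out := make([]uint64, nLayers)
    for i := range out {
        if i < len(ffnLength) {
            out[i] = uint64(ffnLength[i])
        } else {
            out[i] = uint64(def)
        }
    }
    return out
}

func main() {
    fmt.Println(broadcastFFN([]uint32{1856}, 3))       // [1856 1856 1856]
    fmt.Println(broadcastFFN([]uint32{0, 1856, 0}, 5)) // [0 1856 0 0 0]
}
```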
@@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool {
        "llama4",
        "mistral3",
        "mllama",
        "nemotron_h", "nemotron_h_moe",
        "nomic-bert",
        "olmo3",
        "qwen25vl",

@@ -273,6 +295,7 @@
        "glm4moelite",
        "glmocr",
        "lfm2",
        "lfm2moe",
    }, kv.Architecture())
}

@@ -864,7 +887,9 @@ func (f GGML) FlashAttention() bool {
        "glmocr",
        "gptoss", "gpt-oss",
        "lfm2",
        "lfm2moe",
        "mistral3",
        "nemotron_h", "nemotron_h_moe",
        "olmo3",
        "qwen3", "qwen3moe",
        "qwen3next",
go.mod
@@ -26,6 +26,7 @@ require (
    github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
    github.com/dlclark/regexp2 v1.11.4
    github.com/emirpasic/gods/v2 v2.0.0-alpha
    github.com/klauspost/compress v1.18.3
    github.com/mattn/go-runewidth v0.0.16
    github.com/nlpodyssey/gopickle v0.3.0
    github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

go.sum
@@ -122,7 +122,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
 github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
 github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
 github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
-github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
 github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
 github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -150,8 +149,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
 github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
 github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
+github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
+github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
 github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
 github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
 github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
internal/cloud/policy.go (new file, 25 lines)
@@ -0,0 +1,25 @@
package cloud

import (
    "github.com/ollama/ollama/envconfig"
)

const DisabledMessagePrefix = "ollama cloud is disabled"

// Status returns whether cloud is disabled and the source of the decision.
// Source is one of: "none", "env", "config", "both".
func Status() (disabled bool, source string) {
    return envconfig.NoCloud(), envconfig.NoCloudSource()
}

func Disabled() bool {
    return envconfig.NoCloud()
}

func DisabledError(operation string) string {
    if operation == "" {
        return DisabledMessagePrefix
    }

    return DisabledMessagePrefix + ": " + operation
}
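A sketch of how a startup path could surface this policy, mirroring the FAQ's `Ollama cloud disabled: true` log line; the exact log wording is illustrative, while `Status` and `DisabledError` come from the file above:

```go
package main

import (
    "log/slog"

    "github.com/ollama/ollama/internal/cloud"
)

func main() {
    disabled, source := cloud.Status()
    slog.Info("Ollama cloud disabled", "disabled", disabled, "source", source)

    if disabled {
        // Cloud-dependent operations can all fail with a uniform message.
        slog.Warn(cloud.DisabledError("remote inference is unavailable"))
    }
}
```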
internal/cloud/policy_test.go (new file, 85 lines)
@@ -0,0 +1,85 @@
package cloud

import (
    "os"
    "path/filepath"
    "testing"
)

func TestStatus(t *testing.T) {
    tests := []struct {
        name          string
        envValue      string
        configContent string
        disabled      bool
        source        string
    }{
        {
            name:     "none",
            disabled: false,
            source:   "none",
        },
        {
            name:     "env only",
            envValue: "1",
            disabled: true,
            source:   "env",
        },
        {
            name:          "config only",
            configContent: `{"disable_ollama_cloud": true}`,
            disabled:      true,
            source:        "config",
        },
        {
            name:          "both",
            envValue:      "1",
            configContent: `{"disable_ollama_cloud": true}`,
            disabled:      true,
            source:        "both",
        },
        {
            name:          "invalid config ignored",
            configContent: `{invalid json`,
            disabled:      false,
            source:        "none",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            home := t.TempDir()
            if tt.configContent != "" {
                configPath := filepath.Join(home, ".ollama", "server.json")
                if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
                    t.Fatal(err)
                }
                if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
                    t.Fatal(err)
                }
            }

            setTestHome(t, home)
            t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)

            disabled, source := Status()
            if disabled != tt.disabled {
                t.Fatalf("disabled: expected %v, got %v", tt.disabled, disabled)
            }
            if source != tt.source {
                t.Fatalf("source: expected %q, got %q", tt.source, source)
            }
        })
    }
}

func TestDisabledError(t *testing.T) {
    if got := DisabledError(""); got != DisabledMessagePrefix {
        t.Fatalf("expected %q, got %q", DisabledMessagePrefix, got)
    }

    want := DisabledMessagePrefix + ": remote inference is unavailable"
    if got := DisabledError("remote inference is unavailable"); got != want {
        t.Fatalf("expected %q, got %q", want, got)
    }
}
internal/cloud/test_home_test.go (new file, 14 lines)
@@ -0,0 +1,14 @@
package cloud

import (
    "testing"

    "github.com/ollama/ollama/envconfig"
)

func setTestHome(t *testing.T, home string) {
    t.Helper()
    t.Setenv("HOME", home)
    t.Setenv("USERPROFILE", home)
    envconfig.ReloadServerConfig()
}
kvcache/recurrent.go (new file, 752 lines)
@@ -0,0 +1,752 @@
package kvcache

import (
    "errors"
    "fmt"
    "math"
    "slices"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model/input"
)

const (
    DefaultCheckpointCount    = 32
    DefaultCheckpointMinPos   = int32(16)
    DefaultCheckpointInterval = int32(1280)
)

var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")

// RecurrentConfig configures a shared hybrid recurrent cache.
type RecurrentConfig struct {
    Shift               func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
    ConvDim             int
    ConvChannels        int
    RecurrentStateSize  int
    CheckpointLogPrefix string
}

var (
    _ Cache           = (*Recurrent)(nil)
    _ CheckpointCache = (*Recurrent)(nil)
)

// Recurrent stores:
//   - a standard causal KV cache
//   - per-sequence conv state for recurrent operators
//   - per-sequence recurrent state for recurrent operators
//
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
type Recurrent struct {
    kv *Causal

    backend      ml.Backend
    dtype        ml.DType
    maxSequences int

    // Conv state dimensions
    convDim      int
    convChannels int

    // Recurrent state dimensions
    recurrentStateSize int

    logPrefix string

    // slot mapping for recurrent state (copy-on-write)
    slotForSeq  map[int]int
    refCount    []int
    freeSlots   []int
    seqCounts   map[int]int
    slotScratch [1]int32

    // per-layer conv state buffers (allocated lazily)
    convCtxs   map[int]ml.Context
    convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]

    // per-layer recurrent state buffers (allocated lazily)
    recurrentCtxs   map[int]ml.Context
    recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]

    // recurrent checkpoints (per slot)
    checkpointCount     int
    checkpointMinPos    int32
    checkpointInterval  int32
    checkpointCtxSize   int
    checkpoints         map[int]*slotCheckpointStore
    pendingRestore      map[int]checkpointRestore
    curCheckpointPos    []int32
    curCheckpointSlots  map[int]int
    reserveCheckpoints  bool
    checkpointConvCtxs  map[int]ml.Context
    checkpointRecurCtxs map[int]ml.Context
    checkpointReserved  map[int]struct{}

    // current forward batch (derived in StartForward)
    curSeqs       []int
    curSlots      []int
    curSlotsInput ml.Tensor
    curSeqTokens  int

    // track if EnsureWritable has been called for this forward pass
    writableEnsured bool
    writableError   error
}

func NewRecurrentCache(config RecurrentConfig) *Recurrent {
    return &Recurrent{
        kv:                  NewCausalCache(config.Shift),
        convDim:             config.ConvDim,
        convChannels:        config.ConvChannels,
        recurrentStateSize:  config.RecurrentStateSize,
        logPrefix:           config.CheckpointLogPrefix,
        slotForSeq:          make(map[int]int),
        seqCounts:           make(map[int]int),
        convCtxs:            make(map[int]ml.Context),
        convStates:          make(map[int]ml.Tensor),
        recurrentCtxs:       make(map[int]ml.Context),
        recurrentStates:     make(map[int]ml.Tensor),
        checkpointCount:     DefaultCheckpointCount,
        checkpointMinPos:    DefaultCheckpointMinPos,
        checkpointInterval:  DefaultCheckpointInterval,
        checkpoints:         make(map[int]*slotCheckpointStore),
        pendingRestore:      make(map[int]checkpointRestore),
        curCheckpointSlots:  make(map[int]int),
        checkpointConvCtxs:  make(map[int]ml.Context),
        checkpointRecurCtxs: make(map[int]ml.Context),
        checkpointReserved:  make(map[int]struct{}),
    }
}

func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
    c.backend = backend
    c.dtype = dtype
    c.maxSequences = maxSequences
    c.checkpoints = make(map[int]*slotCheckpointStore)
    c.pendingRestore = make(map[int]checkpointRestore)
    c.curCheckpointPos = c.curCheckpointPos[:0]
    c.curCheckpointSlots = make(map[int]int)
    c.checkpointReserved = make(map[int]struct{})
    c.checkpointCtxSize = c.checkpointCount * c.maxSequences
    if c.checkpointCtxSize < 8 {
        c.checkpointCtxSize = 8
    }

    // initialize slot allocator
    c.refCount = make([]int, maxSequences)
    c.freeSlots = c.freeSlots[:0]
    for i := maxSequences - 1; i >= 0; i-- {
        c.freeSlots = append(c.freeSlots, i)
    }

    c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

func (c *Recurrent) Close() {
    for _, ctx := range c.convCtxs {
        ctx.Close()
    }
    for _, ctx := range c.recurrentCtxs {
        ctx.Close()
    }
    for _, ctx := range c.checkpointConvCtxs {
        ctx.Close()
    }
    for _, ctx := range c.checkpointRecurCtxs {
        ctx.Close()
    }
    c.kv.Close()
}

func (c *Recurrent) SetConfig(config ml.CacheConfig) {
    c.kv.SetConfig(config)
}

func (c *Recurrent) SetLayer(layer int) {
    c.kv.SetLayer(layer)
}

func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
    return c.kv.Get(ctx)
}

func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
    c.kv.Put(ctx, key, value)
}

func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
    if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
        return err
    }

    nTokens := len(batch.Sequences)
    if nTokens == 0 {
        c.curSeqs = c.curSeqs[:0]
        c.curSlots = c.curSlots[:0]
        c.curSlotsInput = nil
        c.curSeqTokens = 0
        c.reserveCheckpoints = false
        c.writableEnsured = false
        c.writableError = nil
        return nil
    }

    // Fast path for single-sequence batches (common during decode and prefill).
    firstSeq := batch.Sequences[0]
    singleSeq := true
    for _, s := range batch.Sequences[1:] {
        if s != firstSeq {
            singleSeq = false
            break
        }
    }
    if singleSeq {
        return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
    }

    // Derive equal-length sequence layout for recurrent layers.
    seqCounts := c.seqCounts
    for s := range seqCounts {
        delete(seqCounts, s)
    }

    c.curSeqs = c.curSeqs[:0]
    for _, s := range batch.Sequences {
        if seqCounts[s] == 0 {
            c.curSeqs = append(c.curSeqs, s)
        }
        seqCounts[s]++
    }

    nSeqs := len(c.curSeqs)
    want := nTokens / nSeqs
    for _, s := range c.curSeqs {
        if seqCounts[s] != want {
            return ErrNotSupported
        }
    }

    c.curSeqTokens = want

    if reserve {
        c.curSlots = c.curSlots[:0]
        for i := range nSeqs {
            c.curSlots = append(c.curSlots, i)
        }
        c.finalizeStartForward(ctx, batch, true)
        return nil
    }

    // Ensure slots exist for sequences in this batch.
    c.curSlots = c.curSlots[:0]
    var newSlots []int
    for _, s := range c.curSeqs {
        slot, ok := c.slotForSeq[s]
        if !ok {
            var err error
            slot, err = c.allocSlot()
            if err != nil {
                return err
            }
            c.slotForSeq[s] = slot
            c.refCount[slot] = 1
            newSlots = append(newSlots, slot)
        }
        c.curSlots = append(c.curSlots, slot)
    }

    if len(newSlots) > 0 {
        c.zeroSlots(ctx, newSlots)
    }

    c.finalizeStartForward(ctx, batch, false)

    return nil
}

func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
    c.curSeqs = append(c.curSeqs[:0], seq)
    c.curSeqTokens = seqTokens

    if reserve {
        c.curSlots = append(c.curSlots[:0], 0)
        c.finalizeStartForward(ctx, batch, true)
        return nil
    }

    slot, ok := c.slotForSeq[seq]
    if !ok {
        var err error
        slot, err = c.allocSlot()
        if err != nil {
            return err
        }

        c.slotForSeq[seq] = slot
        c.refCount[slot] = 1
        slotList := [1]int{slot}
        c.zeroSlots(ctx, slotList[:])
    }

    c.curSlots = append(c.curSlots[:0], slot)
    c.finalizeStartForward(ctx, batch, false)

    return nil
}

func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
    c.setCurSlotsInput(ctx)
    c.writableEnsured = false
    c.writableError = nil
    c.reserveCheckpoints = reserve
    c.planCheckpoints(batch)
}

func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
    c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
}

func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
    switch len(slots) {
    case 0:
        return nil
    case 1:
        c.slotScratch[0] = int32(slots[0])
        return ctx.Input().FromInts(c.slotScratch[:], 1)
    default:
        slotIndices := make([]int32, len(slots))
        for i, v := range slots {
            slotIndices[i] = int32(v)
        }
        return ctx.Input().FromInts(slotIndices, len(slotIndices))
    }
}

func (c *Recurrent) allocSlot() (int, error) {
    if len(c.freeSlots) == 0 {
        return 0, ErrKvCacheFull
    }
    slot := c.freeSlots[len(c.freeSlots)-1]
    c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
    return slot, nil
}

func (c *Recurrent) freeSlot(slot int) {
    if slot >= 0 && slot < c.maxSequences {
        c.freeSlots = append(c.freeSlots, slot)
    }
}

// zeroSlots zeros recurrent state for the given slots across all cached layers.
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
    if len(slots) == 0 {
        return
    }

    inputCtx := ctx.Input()
    slotsTensor := c.slotsInput(ctx, slots)

    if len(c.convStates) > 0 {
        zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
        for _, buf := range c.convStates {
            ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
        }
    }

    if len(c.recurrentStates) > 0 {
        zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
        for _, buf := range c.recurrentStates {
            ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
        }
    }
}

// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
    for i, seq := range c.curSeqs {
        slot, ok := c.slotForSeq[seq]
        if !ok {
            continue
        }

        if slot < 0 || slot >= len(c.refCount) {
            continue
        }

        if c.refCount[slot] <= 1 {
            continue
        }

        newSlot, err := c.allocSlot()
        if err != nil {
            return err
        }
        c.refCount[slot]--
        c.refCount[newSlot] = 1
        c.slotForSeq[seq] = newSlot
        c.curSlots[i] = newSlot

        c.copyRecurrentState(ctx, slot, newSlot)
        c.copyCheckpoints(ctx, slot, newSlot)
    }

    c.setCurSlotsInput(ctx)

    return nil
}
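The copy-on-write discipline in `EnsureWritable` is easier to see with the tensors stripped away; a toy sketch of the same refcounted-slot idea, where plain float slices stand in for the per-slot conv and recurrent state (every name below is illustrative, not part of the cache API):

```go
package main

import "fmt"

// cowSlots shares one slot between forked sequences until a writer appears,
// then gives the writer a private copy, mirroring the EnsureWritable pattern.
type cowSlots struct {
    slotForSeq map[int]int
    refCount   []int
    state      [][]float32
}

func (c *cowSlots) fork(src, dst int) {
    slot := c.slotForSeq[src]
    c.slotForSeq[dst] = slot
    c.refCount[slot]++ // shared until either sequence writes
}

func (c *cowSlots) ensureWritable(seq int) int {
    slot := c.slotForSeq[seq]
    if c.refCount[slot] <= 1 {
        return slot // already private
    }
    // Shared: allocate a fresh slot and copy the state before writing.
    newSlot := len(c.state)
    c.state = append(c.state, append([]float32(nil), c.state[slot]...))
    c.refCount = append(c.refCount, 1)
    c.refCount[slot]--
    c.slotForSeq[seq] = newSlot
    return newSlot
}

func main() {
    c := &cowSlots{slotForSeq: map[int]int{0: 0}, refCount: []int{1}, state: [][]float32{{1, 2}}}
    c.fork(0, 1)             // seq 1 shares seq 0's state
    s := c.ensureWritable(1) // first write by seq 1 triggers the copy
    c.state[s][0] = 9
    fmt.Println(c.state[0], c.state[s]) // [1 2] [9 2]
}
```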
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.recurrentStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.pendingRestore, seq)
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok || !c.validSlot(slot) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
|
||||
if c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
slot = newSlot
|
||||
}
|
||||
|
||||
c.shiftCheckpoints(slot, beginIndex, endIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SlotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *Recurrent) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) SeqTokens() int {
	return c.curSeqTokens
}

func (c *Recurrent) NumSeqs() int {
	return len(c.curSeqs)
}

func (c *Recurrent) convBuffer(layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}

	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}

func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
	if buf, ok := c.recurrentStates[layer]; ok {
		return buf
	}

	if _, ok := c.recurrentCtxs[layer]; !ok {
		c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
	c.recurrentStates[layer] = buf
	return buf
}

func (c *Recurrent) ensureWritable(ctx ml.Context) error {
	c.ensureWritableOnce(ctx)
	return c.writableError
}

func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
	if start, ok := c.contiguousSlots(); ok {
		offset := start * buf.Stride(1)
		return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
	}

	return buf.Rows(ctx, c.SlotsTensor())
}

func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
	if start, ok := c.contiguousSlots(); ok {
		offset := start * buf.Stride(1)
		view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
		ctx.Forward(src.Copy(ctx, view))
		return
	}

	ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
}

func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
	if !c.writableEnsured {
		needsWritable := false
		for _, seq := range c.curSeqs {
			slot, ok := c.slotForSeq[seq]
			if !ok {
				continue
			}
			if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
				needsWritable = true
				break
			}
		}

		if needsWritable {
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
		}
		c.writableEnsured = true
	}
}

// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	if err := c.ensureWritable(ctx); err != nil {
		return nil, err
	}

	buf := c.convBuffer(layer)
	cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
	return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
}

// UpdateConvState writes new conv state for current batch sequences.
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(layer)
	src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
	srcF32 := src
	if src.DType() != ml.DTypeF32 {
		srcF32 = src.Cast(ctx, ml.DTypeF32)
	}
	c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)

	c.captureConvCheckpoint(ctx, layer, srcF32)
}

// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
	if err := c.ensureWritable(ctx); err != nil {
		return nil, err
	}
	if len(dims) == 0 {
		return nil, ErrInvalidRecurrentShape
	}

	size := 1
	for _, d := range dims {
		if d <= 0 {
			return nil, ErrInvalidRecurrentShape
		}
		size *= d
	}
	if size != c.recurrentStateSize {
		return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
	}

	buf := c.recurrentBuffer(layer)
	cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
	shape := make([]int, 0, len(dims)+1)
	shape = append(shape, dims...)
	shape = append(shape, c.NumSeqs())
	return cur.Reshape(ctx, shape...), nil
}

// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
	if err := c.ensureWritable(ctx); err != nil {
		return nil, err
	}
	if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
		return nil, ErrInvalidRecurrentShape
	}

	size := dim0 * dim1 * dim2
	if size != c.recurrentStateSize {
		return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
	}

	buf := c.recurrentBuffer(layer)
	cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
	return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
}

// UpdateRecurrentState writes new recurrent state for current batch sequences.
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.recurrentBuffer(layer)
	src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
	srcF32 := src
	if src.DType() != ml.DTypeF32 {
		srcF32 = src.Cast(ctx, ml.DTypeF32)
	}
	c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)

	c.captureRecurrentCheckpoint(ctx, layer, srcF32)
}

// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *Recurrent) IsSupportedForBatch() bool {
	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *Recurrent) Seqs() []int {
	return slices.Clone(c.curSeqs)
}
kvcache/recurrent_checkpoints.go (new file)
@@ -0,0 +1,561 @@
package kvcache

import (
	"log/slog"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
// memory usage while preserving prefix reuse for recurrent state.

type checkpointEntry struct {
	pos       int32
	conv      map[int]ml.Tensor
	recurrent map[int]ml.Tensor
}

type slotCheckpointStore struct {
	entries []checkpointEntry
	size    int
	next    int
	lastPos int32
}

type checkpointRestore struct {
	slot int
	idx  int
	pos  int32
}

func newSlotCheckpointStore(n int) *slotCheckpointStore {
	entries := make([]checkpointEntry, n)
	for i := range entries {
		entries[i].pos = -1
	}
	return &slotCheckpointStore{
		entries: entries,
		lastPos: -1,
	}
}

func (s *slotCheckpointStore) reset() {
	s.size = 0
	s.next = 0
	s.lastPos = -1
	for i := range s.entries {
		s.entries[i].pos = -1
	}
}

func (s *slotCheckpointStore) record(pos int32) int {
	if len(s.entries) == 0 {
		return -1
	}
	idx := s.next
	s.next = (s.next + 1) % len(s.entries)
	if s.size < len(s.entries) {
		s.size++
	}
	s.entries[idx].pos = pos
	s.lastPos = pos
	return idx
}

func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
	bestIdx := -1
	bestPos := int32(-1)
	for i := range s.entries {
		pos := s.entries[i].pos
		if pos < 0 || pos >= targetPos {
			continue
		}
		if pos > bestPos {
			bestPos = pos
			bestIdx = i
		}
	}
	if bestIdx < 0 {
		return -1, -1, false
	}
	return bestIdx, bestPos, true
}
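`record` and `bestIndex` together implement a bounded rewind window: the ring keeps the most recent checkpoints, and a restore targets the newest entry strictly before the requested position. Tracing the capacity-2 case (mirrored by `TestSlotCheckpointStoreBestIndex` further down):

```go
// Worked example of the ring semantics with capacity 2:
//
//	record(10)    -> entries: [10  -]   lastPos = 10
//	record(20)    -> entries: [10 20]   lastPos = 20
//	bestIndex(15) == (index of 10, 10, true)   // newest pos strictly < 15
//	record(30)    -> entries: [30 20]   // ring full: oldest (10) overwritten
//	bestIndex(15) == (-1, -1, false)    // nothing before 15 survives
```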
func (s *slotCheckpointStore) pruneAfter(pos int32) {
	if len(s.entries) == 0 {
		s.size = 0
		s.next = 0
		s.lastPos = -1
		return
	}

	size := 0
	next := -1
	minPos := int32(math.MaxInt32)
	minIdx := 0
	for i := range s.entries {
		if s.entries[i].pos > pos {
			s.entries[i].pos = -1
		}
		if s.entries[i].pos >= 0 {
			size++
			if s.entries[i].pos < minPos {
				minPos = s.entries[i].pos
				minIdx = i
			}
		} else if next == -1 {
			next = i
		}
	}

	s.size = size
	if size == 0 {
		s.next = 0
		s.lastPos = -1
		return
	}
	if next != -1 {
		s.next = next
	} else {
		// Full ring: overwrite the oldest checkpoint next.
		s.next = minIdx
	}
	s.lastPos = pos
}

func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
	if len(s.entries) == 0 {
		s.size = 0
		s.next = 0
		s.lastPos = -1
		return
	}

	offset := beginIndex - endIndex

	size := 0
	next := -1
	minPos := int32(math.MaxInt32)
	maxPos := int32(-1)
	minIdx := 0

	for i := range s.entries {
		pos := s.entries[i].pos
		if pos >= 0 {
			if pos >= beginIndex && pos < endIndex {
				s.entries[i].pos = -1
			} else if pos >= endIndex {
				s.entries[i].pos = pos + offset
			}
		}

		pos = s.entries[i].pos
		if pos >= 0 {
			size++
			if pos < minPos {
				minPos = pos
				minIdx = i
			}
			if pos > maxPos {
				maxPos = pos
			}
		} else if next == -1 {
			next = i
		}
	}

	s.size = size
	if size == 0 {
		s.next = 0
		s.lastPos = -1
		return
	}

	if next != -1 {
		s.next = next
	} else {
		// Full ring: overwrite the oldest checkpoint next.
		s.next = minIdx
	}
	s.lastPos = maxPos
}

func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
	minPos = int32(math.MaxInt32)
	maxPos = int32(-1)
	for i := range s.entries {
		pos := s.entries[i].pos
		if pos < 0 {
			continue
		}
		size++
		if pos < minPos {
			minPos = pos
		}
		if pos > maxPos {
			maxPos = pos
		}
	}
	if size == 0 {
		minPos = -1
		maxPos = -1
	}
	return size, minPos, maxPos, s.lastPos
}

func (c *Recurrent) checkpointTag() string {
	if c.logPrefix == "" {
		return "kvcache.recurrent"
	}
	return c.logPrefix
}

func (c *Recurrent) planCheckpoints(batch input.Batch) {
	if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
		c.curCheckpointPos = c.curCheckpointPos[:0]
		for k := range c.curCheckpointSlots {
			delete(c.curCheckpointSlots, k)
		}
		return
	}

	if cap(c.curCheckpointPos) < len(c.curSeqs) {
		c.curCheckpointPos = make([]int32, len(c.curSeqs))
	} else {
		c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
	}
	for i := range c.curCheckpointPos {
		c.curCheckpointPos[i] = -1
	}
	for k := range c.curCheckpointSlots {
		delete(c.curCheckpointSlots, k)
	}

	posMax := make(map[int]int32, len(c.curSeqs))
	for i, seq := range batch.Sequences {
		pos := batch.Positions[i]
		if cur, ok := posMax[seq]; !ok || pos > cur {
			posMax[seq] = pos
		}
	}

	for i, seq := range c.curSeqs {
		pos, ok := posMax[seq]
		if !ok {
			continue
		}
		if pos < c.checkpointMinPos {
			continue
		}
		slot := c.curSlots[i]
		store := c.checkpointStore(slot)
		lastPos := store.lastPos
		if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
			c.curCheckpointPos[i] = pos
		}
	}
}
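Per sequence, `planCheckpoints` takes the maximum position in the batch, skips anything below `checkpointMinPos`, and plans a checkpoint when there is no prior checkpoint or the gap since `lastPos` has reached `checkpointInterval`. A worked trace (the interval and minimum values here are illustrative, not necessarily the configured defaults):

```go
// Illustrative values: checkpointInterval = 256, checkpointMinPos = 64
//
//	batch max pos   store.lastPos   checkpoint planned?
//	40              -1              no  (below checkpointMinPos)
//	300             -1              yes (no prior checkpoint)
//	400             300             no  (400-300 < 256)
//	560             300             yes (560-300 >= 256)
```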
func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
	store, ok := c.checkpoints[slot]
	if ok {
		return store
	}
	store = newSlotCheckpointStore(c.checkpointCount)
	c.checkpoints[slot] = store
	return store
}

func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
	if c.checkpointCount == 0 {
		return -1
	}
	if idx, ok := c.curCheckpointSlots[slot]; ok {
		return idx
	}
	store := c.checkpointStore(slot)
	idx := store.record(pos)
	if idx >= 0 {
		c.curCheckpointSlots[slot] = idx
	}
	return idx
}

func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
	if pos <= 0 {
		return false
	}
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return false
	}
	store, ok := c.checkpoints[slot]
	if !ok {
		return false
	}
	_, _, ok = store.bestIndex(pos)
	return ok
}

func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
	if targetPos <= 0 {
		return 0, false
	}
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return 0, false
	}
	store, ok := c.checkpoints[slot]
	if !ok {
		slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
		return 0, false
	}
	idx, pos, ok := store.bestIndex(targetPos)
	if !ok {
		size, minPos, maxPos, lastPos := store.window()
		slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
			"min", minPos, "max", maxPos, "last", lastPos)
		return 0, false
	}
	c.pendingRestore[seq] = checkpointRestore{
		slot: slot,
		idx:  idx,
		pos:  pos,
	}
	return pos + 1, true
}
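`PrepareRestore` is the caller-facing entry point: it stages the best checkpoint strictly before `targetPos` and returns `pos + 1`, the first position that still has to be recomputed. A hypothetical call-site sketch (the runner integration itself is outside this hunk; `cache` is an illustrative variable):

```go
// Hypothetical call site: rewind sequence 1 to reuse as much cached prefix
// as possible before reprocessing a prompt that diverges at position 512.
if resumeAt, ok := cache.PrepareRestore(1, 512); ok {
	// A checkpoint at some pos < 512 was staged; only tokens in
	// [resumeAt, 512) need to be recomputed.
	_ = resumeAt
} else {
	// No usable checkpoint: fall back to recomputing from position 0.
}
```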
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
	entry, ok := c.restoreEntry(restore)
	if !ok {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
	for layer, src := range entry.conv {
		buf := c.convBuffer(layer)
		ctx.Forward(buf.SetRows(ctx, src, slotIdx))
	}
	for layer, src := range entry.recurrent {
		buf := c.recurrentBuffer(layer)
		ctx.Forward(buf.SetRows(ctx, src, slotIdx))
	}

	if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
		ctx.Compute()
	}
	store := c.checkpoints[restore.slot]
	store.pruneAfter(restore.pos)
	return nil
}

func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
	_, ok := c.restoreEntry(restore)
	return ok
}

func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
	store, ok := c.checkpoints[restore.slot]
	if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
		return nil, false
	}
	entry := &store.entries[restore.idx]
	if entry.pos < 0 {
		return nil, false
	}
	if !c.entryComplete(entry) {
		return nil, false
	}
	return entry, true
}

func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
	for layer := range c.convStates {
		if entry.conv == nil || entry.conv[layer] == nil {
			return false
		}
	}
	for layer := range c.recurrentStates {
		if entry.recurrent == nil || entry.recurrent[layer] == nil {
			return false
		}
	}
	return true
}

func (c *Recurrent) clearCheckpoints(slot int) {
	if store, ok := c.checkpoints[slot]; ok {
		store.reset()
	}
}

func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
	if store, ok := c.checkpoints[slot]; ok {
		store.shiftRange(beginIndex, endIndex)
	}
}

func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
	if c.checkpointCount == 0 {
		return
	}
	srcStore, ok := c.checkpoints[srcSlot]
	if !ok || srcStore.size == 0 {
		return
	}
	dstStore := c.checkpointStore(dstSlot)
	dstStore.size = srcStore.size
	dstStore.next = srcStore.next
	dstStore.lastPos = srcStore.lastPos

	for i := range srcStore.entries {
		srcEntry := &srcStore.entries[i]
		dstEntry := &dstStore.entries[i]
		dstEntry.pos = srcEntry.pos
		if srcEntry.conv != nil {
			if dstEntry.conv == nil {
				dstEntry.conv = make(map[int]ml.Tensor)
			}
			for layer, src := range srcEntry.conv {
				dst := c.ensureCheckpointConv(layer, dstEntry)
				ctx.Forward(src.Copy(ctx, dst))
			}
		}
		if srcEntry.recurrent != nil {
			if dstEntry.recurrent == nil {
				dstEntry.recurrent = make(map[int]ml.Tensor)
			}
			for layer, src := range srcEntry.recurrent {
				dst := c.ensureCheckpointRecurrent(layer, dstEntry)
				ctx.Forward(src.Copy(ctx, dst))
			}
		}
	}
}

func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointConv(layer)
		return
	}
	if len(c.curCheckpointPos) == 0 {
		return
	}
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := c.ensureCheckpointConv(layer, entry)
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}

func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
	if c.checkpointCount == 0 {
		return
	}
	if c.reserveCheckpoints {
		c.reserveCheckpointRecurrent(layer)
		return
	}
	if len(c.curCheckpointPos) == 0 {
		return
	}
	for i, pos := range c.curCheckpointPos {
		if pos < 0 {
			continue
		}
		slot := c.curSlots[i]
		idx := c.checkpointIndexForSlot(slot, pos)
		if idx < 0 {
			continue
		}
		entry := &c.checkpoints[slot].entries[idx]
		dst := c.ensureCheckpointRecurrent(layer, entry)
		seqSlice := src.Slice(ctx, 1, i, i+1, 1)
		ctx.Forward(seqSlice.Copy(ctx, dst))
	}
}

func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.conv == nil {
		entry.conv = make(map[int]ml.Tensor)
	}
	if t, ok := entry.conv[layer]; ok {
		return t
	}
	ctx, ok := c.checkpointConvCtxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		c.checkpointConvCtxs[layer] = ctx
	}
	t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
	entry.conv[layer] = t
	return t
}

func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
	if entry.recurrent == nil {
		entry.recurrent = make(map[int]ml.Tensor)
	}
	if t, ok := entry.recurrent[layer]; ok {
		return t
	}
	ctx, ok := c.checkpointRecurCtxs[layer]
	if !ok {
		ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
		c.checkpointRecurCtxs[layer] = ctx
	}
	t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1)
	entry.recurrent[layer] = t
	return t
}

func (c *Recurrent) reserveCheckpointConv(layer int) {
	key := checkpointReserveKey(layer, 0)
	if _, ok := c.checkpointReserved[key]; ok {
		return
	}
	for slot := range c.maxSequences {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			entry := &store.entries[i]
			_ = c.ensureCheckpointConv(layer, entry)
		}
	}
	c.checkpointReserved[key] = struct{}{}
}

func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
	key := checkpointReserveKey(layer, 1)
	if _, ok := c.checkpointReserved[key]; ok {
		return
	}
	for slot := range c.maxSequences {
		store := c.checkpointStore(slot)
		for i := range store.entries {
			entry := &store.entries[i]
			_ = c.ensureCheckpointRecurrent(layer, entry)
		}
	}
	c.checkpointReserved[key] = struct{}{}
}

func checkpointReserveKey(layer int, kind int) int {
	return layer*2 + kind
}
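The reserve key packs a (layer, kind) pair into a single map key, with kind 0 for conv reservations and 1 for recurrent ones:

```go
checkpointReserveKey(3, 0) // 6: conv reservation marker for layer 3
checkpointReserveKey(3, 1) // 7: recurrent reservation marker for layer 3
```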
kvcache/recurrent_checkpoints_test.go (new file)
@@ -0,0 +1,288 @@
package kvcache

import (
	"errors"
	"math"
	"slices"
	"testing"

	"github.com/ollama/ollama/ml"
)

func newTestCache() *Recurrent {
	return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
}

func TestSlotCheckpointStoreBestIndex(t *testing.T) {
	store := newSlotCheckpointStore(2)
	store.record(10)
	store.record(20)

	_, pos, ok := store.bestIndex(15)
	if !ok || pos != 10 {
		t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
	}

	store.record(30) // overwrite oldest (10)

	if _, _, ok := store.bestIndex(15); ok {
		t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
	}

	_, pos, ok = store.bestIndex(40)
	if !ok || pos != 30 {
		t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
	}
}

func TestCachePrepareRestore(t *testing.T) {
	cache := newTestCache()
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)

	cache.slotForSeq[1] = 0
	store := cache.checkpointStore(0)
	store.record(5)
	store.record(9)
	store.record(15)

	restorePos, ok := cache.PrepareRestore(1, 12)
	if !ok {
		t.Fatalf("expected restore ok")
	}
	if restorePos != 10 {
		t.Fatalf("expected restorePos 10, got %d", restorePos)
	}
	rest, ok := cache.pendingRestore[1]
	if !ok {
		t.Fatalf("expected pending restore entry")
	}
	if rest.pos != 9 {
		t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
	}
}

func TestSlotCheckpointStorePruneAfter(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)

	store.pruneAfter(20)

	if store.lastPos != 20 {
		t.Fatalf("expected lastPos 20, got %d", store.lastPos)
	}

	_, pos, ok := store.bestIndex(25)
	if !ok || pos != 20 {
		t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
	}

	_, pos, ok = store.bestIndex(35)
	if !ok || pos != 20 {
		t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
	}
}

func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
	cache := newTestCache()
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)

	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.freeSlots = nil

	// Simulate a layer 0 that requires both conv and recurrent checkpoints.
	cache.convStates[0] = nil
	cache.recurrentStates[0] = nil

	store := cache.checkpointStore(0)
	idx := store.record(9)
	entry := &store.entries[idx]
	entry.conv = map[int]ml.Tensor{0: nil}
	// entry.recurrent intentionally missing

	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}

	err := cache.Remove(1, 10, math.MaxInt32)
	if !errors.Is(err, ErrNotSupported) {
		t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
	}
}

func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
	cache := newTestCache()
	cache.checkpointCount = 3
	cache.checkpoints = make(map[int]*slotCheckpointStore)
	cache.pendingRestore = make(map[int]checkpointRestore)

	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.freeSlots = nil

	store := cache.checkpointStore(0)
	idx := store.record(9)

	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}

	restore := cache.pendingRestore[1]
	if !cache.restoreComplete(restore) {
		t.Fatalf("expected restoreComplete to return true for complete checkpoint")
	}
}

func TestCacheRecurrentStateShapeValidation(t *testing.T) {
	cache := newTestCache()
	_, err := cache.RecurrentState(nil, 0, 3)
	if !errors.Is(err, ErrInvalidRecurrentShape) {
		t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
	}
}

func TestSlotCheckpointStoreShiftRange(t *testing.T) {
	store := newSlotCheckpointStore(5)
	store.record(1)
	store.record(4)
	store.record(7)
	store.record(10)

	store.shiftRange(2, 6)

	var positions []int32
	for i := range store.entries {
		if store.entries[i].pos >= 0 {
			positions = append(positions, store.entries[i].pos)
		}
	}
	slices.Sort(positions)

	want := []int32{1, 3, 6}
	if !slices.Equal(positions, want) {
		t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
	}
	if store.lastPos != 6 {
		t.Fatalf("expected lastPos 6, got %d", store.lastPos)
	}
}

func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
	cache := newTestCache()
	cache.slotForSeq[1] = 0
	cache.refCount = []int{1}
	cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}

	store := cache.checkpointStore(0)
	store.record(1)
	store.record(4)
	store.record(7)
	store.record(10)

	if err := cache.Remove(1, 2, 6); err != nil {
		t.Fatalf("expected middle remove to succeed, got %v", err)
	}

	if _, ok := cache.pendingRestore[1]; ok {
		t.Fatalf("expected pending restore to be cleared after middle remove")
	}

	var positions []int32
	for i := range store.entries {
		if store.entries[i].pos >= 0 {
			positions = append(positions, store.entries[i].pos)
		}
	}
	slices.Sort(positions)

	want := []int32{1, 3, 6}
	if !slices.Equal(positions, want) {
		t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
	}
}

func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
	store := newSlotCheckpointStore(3)

	store.record(10)
	store.record(20)
	store.record(30)

	store.entries[0].conv = make(map[int]ml.Tensor)
	store.entries[0].conv[0] = nil
	store.entries[0].recurrent = make(map[int]ml.Tensor)
	store.entries[0].recurrent[0] = nil

	store.record(40)

	if store.entries[0].conv == nil {
		t.Fatalf("expected conv map to be preserved on reuse")
	}
	if store.entries[0].recurrent == nil {
		t.Fatalf("expected recurrent map to be preserved on reuse")
	}
	if store.entries[0].pos != 40 {
		t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
	}
}

func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
	store := newSlotCheckpointStore(2)

	idx1 := store.record(10)
	idx2 := store.record(20)

	if idx1 != 0 || idx2 != 1 {
		t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
	}
	if store.size != 2 {
		t.Fatalf("expected size 2, got %d", store.size)
	}

	_, pos1, ok1 := store.bestIndex(15)
	_, pos2, ok2 := store.bestIndex(25)

	if !ok1 || pos1 != 10 {
		t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
	}
	if !ok2 || pos2 != 20 {
		t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
	}
}

func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
	store := newSlotCheckpointStore(0)

	idx := store.record(10)
	if idx != -1 {
		t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
	}

	_, _, ok := store.bestIndex(15)
	if ok {
		t.Fatalf("expected no checkpoint for empty buffer")
	}
}

func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
	store := newSlotCheckpointStore(3)
	store.record(10)
	store.record(20)
	store.record(30)

	store.pruneAfter(5)

	if store.size != 0 {
		t.Fatalf("expected size 0 after pruning all, got %d", store.size)
	}
	if store.lastPos != -1 {
		t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
	}

	_, _, ok := store.bestIndex(100)
	if ok {
		t.Fatalf("expected no checkpoint after pruning all")
	}
}
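These store-level tests don't load a model, so they should be runnable on their own with something like:

```shell
go test ./kvcache -run 'SlotCheckpointStore|Cache' -v
```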
@@ -0,0 +1,37 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 22 Feb 2026 14:12:30 -0800
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
 specialization

---
 ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
 ggml/src/ggml-metal/ggml-metal.metal   | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 4ac135603..ac5ad53db 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
     // ne21 = n_rows (batch size)
     const int ne21_mm_id_min = 32;

-    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+        (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
         //switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
 template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;

 template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
 kernel void kernel_mul_mm_id(
@@ -2,15 +2,22 @@ package middleware

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/anthropic"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	internalcloud "github.com/ollama/ollama/internal/cloud"
	"github.com/ollama/ollama/logutil"
)

// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
@@ -18,7 +25,6 @@ type AnthropicWriter struct {
	BaseWriter
	stream    bool
	id        string
	model     string
	converter *anthropic.StreamConverter
}

@@ -31,7 +37,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
+	err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error))
	if err != nil {
		return 0, err
	}
@@ -40,18 +46,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
}

func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
-	d, err := json.Marshal(data)
-	if err != nil {
-		return err
-	}
-	_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
-	if err != nil {
-		return err
-	}
-	if f, ok := w.ResponseWriter.(http.Flusher); ok {
-		f.Flush()
-	}
-	return nil
+	return writeSSE(w.ResponseWriter, eventType, data)
}
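`writeEvent` now delegates to a shared `writeSSE` helper that is defined elsewhere in the package and not shown in this hunk. A minimal sketch of the shape it presumably has, reconstructed from the inlined body it replaces:

```go
// Sketch only: the real writeSSE lives elsewhere in the middleware package.
// This mirrors the removed body above: marshal the payload, emit one
// "event:/data:" SSE frame, and flush so the client sees it promptly.
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
	d, err := json.Marshal(data)
	if err != nil {
		return err
	}
	if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}
```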
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
@@ -65,6 +60,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
	w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

	events := w.converter.Process(chatResponse)
+	logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
	for _, event := range events {
		if err := w.writeEvent(event.Event, event.Data); err != nil {
			return 0, err
@@ -75,6 +71,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	response := anthropic.ToMessagesResponse(w.id, chatResponse)
+	logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}

@@ -87,9 +84,743 @@ func (w *AnthropicWriter) Write(data []byte) (int, error) {
	return w.writeResponse(data)
}

// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
// executes the search, re-invokes the model with results, and assembles the
// Anthropic-format response (server_tool_use + web_search_tool_result + text).
type WebSearchAnthropicWriter struct {
	BaseWriter
	newLoopContext func() (context.Context, context.CancelFunc)
	inner          *AnthropicWriter
	req            anthropic.MessagesRequest // original Anthropic request
	chatReq        *api.ChatRequest          // converted Ollama request (for followup calls)
	stream         bool

	estimatedInputTokens int

	terminalSent bool

	observedPromptEvalCount int
	observedEvalCount       int

	loopInFlight      bool
	loopBaseInputTok  int
	loopBaseOutputTok int
	loopResultCh      chan webSearchLoopResult

	streamMessageStarted bool
	streamHasOpenBlock   bool
	streamOpenBlockIndex int
	streamNextIndex      int
}

const maxWebSearchLoops = 3

type webSearchLoopResult struct {
	response anthropic.MessagesResponse
	loopErr  *webSearchLoopError
}

type webSearchLoopError struct {
	code  string
	query string
	usage anthropic.Usage
	err   error
}

func (e *webSearchLoopError) Error() string {
	if e.err == nil {
		return e.code
	}
	return fmt.Sprintf("%s: %v", e.code, e.err)
}

func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
	if w.terminalSent {
		return len(data), nil
	}

	code := w.Status()
	if code != http.StatusOK {
		return w.inner.writeError(data)
	}

	var chatResponse api.ChatResponse
	if err := json.Unmarshal(data, &chatResponse); err != nil {
		return 0, err
	}
	w.recordObservedUsage(chatResponse.Metrics)

	if w.stream && w.loopInFlight {
		if !chatResponse.Done {
			return len(data), nil
		}
		if err := w.writeLoopResult(); err != nil {
			return len(data), err
		}
		return len(data), nil
	}

	webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
	logutil.Trace("anthropic middleware: upstream chunk",
		"resp", anthropic.TraceChatResponse(chatResponse),
		"web_search", hasWebSearch,
		"other_tools", hasOtherTools,
	)
	if hasWebSearch && hasOtherTools {
		// Prefer web_search if both server and client tools are present in one chunk.
		slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
	}

	if !hasWebSearch {
		if w.stream {
			if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
				return 0, err
			}
			return len(data), nil
		}
		return w.inner.writeResponse(data)
	}

	if w.stream {
		// Let the original generation continue to completion while web search runs in parallel.
		logutil.Trace("anthropic middleware: starting async web_search loop",
			"tool_call", anthropic.TraceToolCall(webSearchCall),
			"resp", anthropic.TraceChatResponse(chatResponse),
		)
		w.startLoopWorker(chatResponse, webSearchCall)
		if chatResponse.Done {
			if err := w.writeLoopResult(); err != nil {
				return len(data), err
			}
		}
		return len(data), nil
	}

	loopCtx, cancel := w.startLoopContext()
	defer cancel()

	initialUsage := anthropic.Usage{
		InputTokens:  max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
		OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
	}
	logutil.Trace("anthropic middleware: starting sync web_search loop",
		"tool_call", anthropic.TraceToolCall(webSearchCall),
		"resp", anthropic.TraceChatResponse(chatResponse),
		"usage", initialUsage,
	)
	response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
	if loopErr != nil {
		return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
	}

	if err := w.writeTerminalResponse(response); err != nil {
		return 0, err
	}

	return len(data), nil
}

func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
	followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
	followUpMessages = append(followUpMessages, w.chatReq.Messages...)

	followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
	usage := initialUsage
	logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
		"model", w.req.Model,
		"tool_call", anthropic.TraceToolCall(initialToolCall),
		"messages", len(followUpMessages),
		"tools", len(followUpTools),
		"max_loops", maxWebSearchLoops,
	)

	currentResponse := initialResponse
	currentToolCall := initialToolCall

	var serverContent []anthropic.ContentBlock

	if !isCloudModelName(w.req.Model) {
		logutil.TraceContext(ctx, "anthropic middleware: web_search execution blocked", "reason", "non_cloud_model")
		return anthropic.MessagesResponse{}, &webSearchLoopError{
			code:  "web_search_not_supported_for_local_models",
			query: extractQueryFromToolCall(&initialToolCall),
			usage: usage,
		}
	}

	for loop := 1; loop <= maxWebSearchLoops; loop++ {
		query := extractQueryFromToolCall(&currentToolCall)
		logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
			"loop", loop,
			"query", anthropic.TraceTruncateString(query),
			"messages", len(followUpMessages),
		)
		if query == "" {
			return anthropic.MessagesResponse{}, &webSearchLoopError{
				code:  "invalid_request",
				query: "",
				usage: usage,
			}
		}

		const defaultMaxResults = 5
		searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
		if err != nil {
			logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
				"loop", loop,
				"query", query,
				"error", err,
			)
			return anthropic.MessagesResponse{}, &webSearchLoopError{
				code:  "unavailable",
				query: query,
				usage: usage,
				err:   err,
			}
		}
		logutil.TraceContext(ctx, "anthropic middleware: web_search results",
			"loop", loop,
			"results", len(searchResp.Results),
		)

		toolUseID := loopServerToolUseID(w.inner.id, loop)
		searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
		serverContent = append(serverContent,
			anthropic.ContentBlock{
				Type:  "server_tool_use",
				ID:    toolUseID,
				Name:  "web_search",
				Input: map[string]any{"query": query},
			},
			anthropic.ContentBlock{
				Type:      "web_search_tool_result",
				ToolUseID: toolUseID,
				Content:   searchResults,
			},
		)

		assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
		toolResultMsg := api.Message{
			Role:       "tool",
			Content:    formatWebSearchResultsForToolMessage(searchResp.Results),
			ToolCallID: currentToolCall.ID,
		}
		followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)

		followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
		if err != nil {
			logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
				"loop", loop,
				"query", query,
				"error", err,
			)
			return anthropic.MessagesResponse{}, &webSearchLoopError{
				code:  "api_error",
				query: query,
				usage: usage,
				err:   err,
			}
		}
		logutil.TraceContext(ctx, "anthropic middleware: followup response",
			"loop", loop,
			"resp", anthropic.TraceChatResponse(followUpResponse),
		)

		usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
		usage.OutputTokens += followUpResponse.Metrics.EvalCount

		nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
		if hasWebSearch && hasOtherTools {
			// Prefer web_search if both server and client tools are present in one chunk.
			slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
		}

		if !hasWebSearch {
			finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
			logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
				"loop", loop,
				"resp", anthropic.TraceMessagesResponse(finalResponse),
			)
			return finalResponse, nil
		}

		currentResponse = followUpResponse
		currentToolCall = nextToolCall
	}

	maxLoopQuery := extractQueryFromToolCall(&currentToolCall)
	maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
	serverContent = append(serverContent,
		anthropic.ContentBlock{
			Type:  "server_tool_use",
			ID:    maxLoopToolUseID,
			Name:  "web_search",
			Input: map[string]any{"query": maxLoopQuery},
		},
		anthropic.ContentBlock{
			Type:      "web_search_tool_result",
			ToolUseID: maxLoopToolUseID,
			Content: anthropic.WebSearchToolResultError{
				Type:      "web_search_tool_result_error",
				ErrorCode: "max_uses_exceeded",
			},
		},
	)

	maxResponse := anthropic.MessagesResponse{
		ID:         w.inner.id,
		Type:       "message",
		Role:       "assistant",
		Model:      w.req.Model,
		Content:    serverContent,
		StopReason: "end_turn",
		Usage:      usage,
	}
	logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
		"resp", anthropic.TraceMessagesResponse(maxResponse),
	)
	return maxResponse, nil
}

func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
	if w.loopInFlight {
		return
	}

	initialUsage := anthropic.Usage{
		InputTokens:  max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
		OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
	}
	w.loopBaseInputTok = initialUsage.InputTokens
	w.loopBaseOutputTok = initialUsage.OutputTokens
	w.loopResultCh = make(chan webSearchLoopResult, 1)
	w.loopInFlight = true
	logutil.Trace("anthropic middleware: loop worker started",
		"usage", initialUsage,
		"tool_call", anthropic.TraceToolCall(initialToolCall),
	)

	go func() {
		ctx, cancel := w.startLoopContext()
		defer cancel()

		response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
		w.loopResultCh <- webSearchLoopResult{
			response: response,
			loopErr:  loopErr,
		}
	}()
}

func (w *WebSearchAnthropicWriter) writeLoopResult() error {
	if w.loopResultCh == nil {
		return w.sendError("api_error", "", w.currentObservedUsage())
	}

	result := <-w.loopResultCh
	w.loopResultCh = nil
	w.loopInFlight = false
	if result.loopErr != nil {
		logutil.Trace("anthropic middleware: loop worker returned error",
			"code", result.loopErr.code,
			"query", result.loopErr.query,
			"usage", result.loopErr.usage,
			"error", result.loopErr.err,
		)
		usage := result.loopErr.usage
		w.applyObservedUsageDeltaToUsage(&usage)
		return w.sendError(result.loopErr.code, result.loopErr.query, usage)
	}
	logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))

	w.applyObservedUsageDelta(&result.response)
	return w.writeTerminalResponse(result.response)
}

func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
	w.applyObservedUsageDeltaToUsage(&response.Usage)
}

func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
	if metrics.PromptEvalCount > w.observedPromptEvalCount {
		w.observedPromptEvalCount = metrics.PromptEvalCount
	}
	if metrics.EvalCount > w.observedEvalCount {
		w.observedEvalCount = metrics.EvalCount
	}
}

func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
	if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
		usage.InputTokens += deltaIn
	}
	if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
		usage.OutputTokens += deltaOut
	}
}
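The delta accounting guards against double counting: the loop's usage already includes the first generation's tokens (captured as the base when the worker starts), so only growth observed after that point is added back. A worked example with illustrative numbers:

```go
// The loop worker starts with base input/output of 100/20 tokens. The
// original stream keeps generating while the search loop runs and is
// last observed at 100/35; the loop itself accumulates usage of 300/80.
//
//	deltaIn  = 100 - 100 = 0  -> usage.InputTokens stays 300
//	deltaOut =  35 -  20 = 15 -> usage.OutputTokens = 80 + 15 = 95
```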
func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
	return anthropic.Usage{
		InputTokens:  w.observedPromptEvalCount,
		OutputTokens: w.observedEvalCount,
	}
}

func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
	if w.newLoopContext != nil {
		return w.newLoopContext()
	}
	return context.WithTimeout(context.Background(), 5*time.Minute)
}

func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
	converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)

	content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
	content = append(content, serverContent...)
	content = append(content, converted.Content...)

	return anthropic.MessagesResponse{
		ID:           w.inner.id,
		Type:         "message",
		Role:         "assistant",
		Model:        w.req.Model,
		Content:      content,
		StopReason:   converted.StopReason,
		StopSequence: converted.StopSequence,
		Usage:        usage,
	}
}

func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
	assistantMsg := api.Message{
		Role:      "assistant",
		ToolCalls: []api.ToolCall{webSearchCall},
	}
	if response.Message.Content != "" {
		assistantMsg.Content = response.Message.Content
	}
	if response.Message.Thinking != "" {
		assistantMsg.Thinking = response.Message.Thinking
	}
	return assistantMsg
}

func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
	var resultText strings.Builder
	for _, r := range results {
		fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
		if r.Content != "" {
			fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
		}
		resultText.WriteString("\n")
	}
	return resultText.String()
}
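The tool message is plain text, one block per result, separated by blank lines. For a single hypothetical result it would look like:

```go
// Hypothetical output for one search result (values are made up):
//
//	Title: Example Domain
//	URL: https://example.com
//	Content: This domain is for use in illustrative examples...
//
// The Content line is omitted when the result has no content.
```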
func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
	var webSearchCall api.ToolCall
	hasWebSearch := false
	hasOtherTools := false

	for _, toolCall := range toolCalls {
		if toolCall.Function.Name == "web_search" {
			if !hasWebSearch {
				webSearchCall = toolCall
				hasWebSearch = true
			}
			continue
		}
		hasOtherTools = true
	}

	return webSearchCall, hasWebSearch, hasOtherTools
}
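Only the first `web_search` call in a response is kept; later ones are ignored, and any other tool call just flags `hasOtherTools` so callers can log the mixed response. A usage sketch (field names follow the `api` package as used above):

```go
// Sketch: mixed tool calls in one response.
calls := []api.ToolCall{
	{Function: api.ToolCallFunction{Name: "web_search"}},
	{Function: api.ToolCallFunction{Name: "get_weather"}}, // hypothetical client tool
}
call, hasWebSearch, hasOtherTools := findWebSearchToolCall(calls)
// call is the web_search call; hasWebSearch == true; hasOtherTools == true,
// so the middleware prefers web_search and logs the mixed response.
_ = call
```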
func loopServerToolUseID(messageID string, loop int) string {
	base := serverToolUseID(messageID)
	if loop <= 1 {
		return base
	}
	return fmt.Sprintf("%s_%d", base, loop)
}

func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
	streaming := false
	followUp := api.ChatRequest{
		Model:    w.chatReq.Model,
		Messages: messages,
		Stream:   &streaming,
		Tools:    tools,
		Options:  w.chatReq.Options,
	}

	body, err := json.Marshal(followUp)
	if err != nil {
		return api.ChatResponse{}, err
	}

	chatURL := envconfig.Host().String() + "/api/chat"
	logutil.TraceContext(ctx, "anthropic middleware: followup request",
		"url", chatURL,
		"req", anthropic.TraceChatRequest(&followUp),
	)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
	if err != nil {
		return api.ChatResponse{}, err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return api.ChatResponse{}, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
			"status", resp.StatusCode,
			"response", strings.TrimSpace(string(respBody)),
		)
		return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
	}

	var chatResp api.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return api.ChatResponse{}, err
	}
	logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))

	return chatResp, nil
}

func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
	events := w.inner.converter.Process(chatResponse)
	for _, event := range events {
		switch e := event.Data.(type) {
		case anthropic.MessageStartEvent:
			w.streamMessageStarted = true
		case anthropic.ContentBlockStartEvent:
			w.streamHasOpenBlock = true
			w.streamOpenBlockIndex = e.Index
			if e.Index+1 > w.streamNextIndex {
				w.streamNextIndex = e.Index + 1
			}
		case anthropic.ContentBlockStopEvent:
			if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
				w.streamHasOpenBlock = false
			}
			if e.Index+1 > w.streamNextIndex {
				w.streamNextIndex = e.Index + 1
			}
		case anthropic.MessageStopEvent:
			w.terminalSent = true
		}

		if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
			return err
		}
	}

	return nil
}

func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
	if w.streamMessageStarted {
		return nil
	}

	inputTokens := usage.InputTokens
	if inputTokens == 0 {
		inputTokens = w.estimatedInputTokens
	}

	if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
		Type: "message_start",
		Message: anthropic.MessagesResponse{
			ID:      w.inner.id,
			Type:    "message",
			Role:    "assistant",
			Model:   w.req.Model,
			Content: []anthropic.ContentBlock{},
			Usage: anthropic.Usage{
				InputTokens: inputTokens,
			},
		},
	}); err != nil {
		return err
	}

	w.streamMessageStarted = true
	return nil
}

func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
	if !w.streamHasOpenBlock {
		return nil
	}

	if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
		Type:  "content_block_stop",
		Index: w.streamOpenBlockIndex,
	}); err != nil {
		return err
	}

	if w.streamOpenBlockIndex+1 > w.streamNextIndex {
		w.streamNextIndex = w.streamOpenBlockIndex + 1
	}
	w.streamHasOpenBlock = false
	return nil
}

func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
	for _, block := range content {
		index := w.streamNextIndex
		if block.Type == "text" {
			emptyText := ""
			if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
				Type:  "content_block_start",
				Index: index,
				ContentBlock: anthropic.ContentBlock{
					Type: "text",
					Text: &emptyText,
				},
			}); err != nil {
				return err
			}

			text := ""
			if block.Text != nil {
				text = *block.Text
			}
			if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
				Type:  "content_block_delta",
				Index: index,
				Delta: anthropic.Delta{
					Type: "text_delta",
					Text: text,
				},
			}); err != nil {
				return err
			}
		} else {
			if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
				Type:         "content_block_start",
				Index:        index,
				ContentBlock: block,
			}); err != nil {
				return err
			}
		}

		if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
			Type:  "content_block_stop",
			Index: index,
		}); err != nil {
			return err
		}

		w.streamNextIndex++
	}

	return nil
}

func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
	if w.terminalSent {
		return nil
	}

	if !w.stream {
		w.ResponseWriter.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
			return err
		}
		w.terminalSent = true
		return nil
	}

	if err := w.ensureStreamMessageStart(response.Usage); err != nil {
		return err
	}
	if err := w.closeOpenStreamBlock(); err != nil {
		return err
	}
	if err := w.writeStreamContentBlocks(response.Content); err != nil {
		return err
	}

	if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
		Type: "message_delta",
		Delta: anthropic.MessageDelta{
			StopReason: response.StopReason,
		},
		Usage: anthropic.DeltaUsage{
			InputTokens:  response.Usage.InputTokens,
			OutputTokens: response.Usage.OutputTokens,
		},
	}); err != nil {
		return err
	}

	if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
		Type: "message_stop",
	}); err != nil {
		return err
	}

	w.terminalSent = true
	return nil
}
// streamResponse emits a complete MessagesResponse as SSE events.
|
||||
func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
toolUseID := serverToolUseID(w.inner.id)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{
|
||||
{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
},
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: errorCode,
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// sendError sends a web search error response.
|
||||
func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
|
||||
response := w.webSearchErrorResponse(errorCode, query, usage)
|
||||
logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestCtx := c.Request.Context()
|
||||
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
@@ -134,11 +865,10 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
innerWriter := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
@@ -148,8 +878,78 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
if hasWebSearchTool(req.Tools) {
|
||||
// Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
|
||||
// for cloud models. Local models may still receive web_search tool definitions;
|
||||
// execution is validated when the model actually emits a web_search tool call.
|
||||
if isCloudModelName(req.Model) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer = &WebSearchAnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
newLoopContext: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(requestCtx, 5*time.Minute)
|
||||
},
|
||||
inner: innerWriter,
|
||||
req: req,
|
||||
chatReq: chatReq,
|
||||
stream: req.Stream,
|
||||
estimatedInputTokens: estimatedTokens,
|
||||
}
|
||||
} else {
|
||||
c.Writer = innerWriter
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// hasWebSearchTool checks if the request tools include a web_search tool
|
||||
func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
for _, tool := range tools {
|
||||
if strings.HasPrefix(tool.Type, "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
func extractQueryFromToolCall(tc *api.ToolCall) string {
|
||||
q, ok := tc.Function.Arguments.Get("query")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if s, ok := q.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeSSE writes a Server-Sent Event
|
||||
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverToolUseID derives a server tool use ID from a message ID
|
||||
func serverToolUseID(messageID string) string {
|
||||
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
|
||||
}
|
||||
|
||||
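For reference, a minimal sketch of the SSE frame `writeSSE` emits, written as a test in the same `middleware` package; the test name is hypothetical and not part of the diff:

```go
package middleware

import (
	"net/http/httptest"
	"testing"
)

// TestWriteSSEWireFormat is a hypothetical sketch: it checks that writeSSE
// produces an "event:" line, a "data:" line with the JSON payload, and a
// blank separator line, which is the standard Server-Sent Events framing.
func TestWriteSSEWireFormat(t *testing.T) {
	rec := httptest.NewRecorder() // satisfies http.ResponseWriter and http.Flusher
	if err := writeSSE(rec, "message_stop", map[string]string{"type": "message_stop"}); err != nil {
		t.Fatal(err)
	}
	want := "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
	if got := rec.Body.String(); got != want {
		t.Fatalf("unexpected SSE frame: %q", got)
	}
}
```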
File diff suppressed because it is too large
@@ -11,6 +11,7 @@ import (
	"time"

	"github.com/gin-gonic/gin"
	"github.com/klauspost/compress/zstd"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/openai"
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {

func ResponsesMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if c.GetHeader("Content-Encoding") == "zstd" {
			reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
			if err != nil {
				c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
				return
			}
			defer reader.Close()
			c.Request.Body = io.NopCloser(reader)
			c.Request.Header.Del("Content-Encoding")
		}

		var req openai.ResponsesRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
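The hunk above accepts request bodies compressed with zstd and caps the decompressed size at 8 MiB. A minimal client-side sketch of what this enables; the endpoint path is taken from the tests below, and the default port is an assumption:

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"

	"github.com/klauspost/compress/zstd"
)

func main() {
	// Compress the JSON body with zstd before sending it.
	var buf bytes.Buffer
	zw, err := zstd.NewWriter(&buf)
	if err != nil {
		panic(err)
	}
	if _, err := zw.Write([]byte(`{"model": "test-model", "input": "Hello"}`)); err != nil {
		panic(err)
	}
	if err := zw.Close(); err != nil {
		panic(err)
	}

	// Content-Encoding: zstd tells the middleware to decompress the body
	// (subject to the 8 MiB decompressed-size limit above).
	req, err := http.NewRequest(http.MethodPost, "http://localhost:11434/v1/responses", &buf)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Content-Encoding", "zstd")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```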
@@ -14,6 +14,7 @@ import (

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"
	"github.com/klauspost/compress/zstd"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/openai"
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
		})
	}
}

func zstdCompress(t *testing.T, data []byte) []byte {
	t.Helper()
	var buf bytes.Buffer
	w, err := zstd.NewWriter(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(data); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	return buf.Bytes()
}

func TestResponsesMiddlewareZstd(t *testing.T) {
	tests := []struct {
		name        string
		body        string
		useZstd     bool
		oversized   bool
		wantCode    int
		wantModel   string
		wantMessage string
	}{
		{
			name:        "plain JSON",
			body:        `{"model": "test-model", "input": "Hello"}`,
			wantCode:    http.StatusOK,
			wantModel:   "test-model",
			wantMessage: "Hello",
		},
		{
			name:        "zstd compressed",
			body:        `{"model": "test-model", "input": "Hello"}`,
			useZstd:     true,
			wantCode:    http.StatusOK,
			wantModel:   "test-model",
			wantMessage: "Hello",
		},
		{
			name:      "zstd over max decompressed size",
			oversized: true,
			useZstd:   true,
			wantCode:  http.StatusBadRequest,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var capturedRequest *api.ChatRequest

			gin.SetMode(gin.TestMode)
			router := gin.New()
			router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
			router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
				c.Status(http.StatusOK)
			})

			var bodyReader io.Reader
			if tt.oversized {
				bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
			} else if tt.useZstd {
				bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
			} else {
				bodyReader = strings.NewReader(tt.body)
			}

			req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
			req.Header.Set("Content-Type", "application/json")
			if tt.useZstd || tt.oversized {
				req.Header.Set("Content-Encoding", "zstd")
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			if resp.Code != tt.wantCode {
				t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
			}

			if tt.wantCode != http.StatusOK {
				return
			}

			if capturedRequest == nil {
				t.Fatal("expected captured request, got nil")
			}
			if capturedRequest.Model != tt.wantModel {
				t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
			}
			if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
				t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
			}
		})
	}
}
middleware/test_home_test.go (new file, 22 lines)
@@ -0,0 +1,22 @@
package middleware

import (
	"testing"

	"github.com/ollama/ollama/envconfig"
)

func setTestHome(t *testing.T, home string) {
	t.Helper()
	t.Setenv("HOME", home)
	t.Setenv("USERPROFILE", home)
	envconfig.ReloadServerConfig()
}

// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD
// so that cloud features are enabled for the duration of the test.
func enableCloudForTest(t *testing.T) {
	t.Helper()
	t.Setenv("OLLAMA_NO_CLOUD", "")
	setTestHome(t, t.TempDir())
}
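As a usage sketch (the test name is hypothetical), a test that depends on cloud features being enabled would call the helper first:

```go
func TestCloudDependentBehavior(t *testing.T) {
	enableCloudForTest(t) // clean HOME/USERPROFILE, OLLAMA_NO_CLOUD cleared
	// ... exercise a code path that requires cloud features ...
}
```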
@@ -163,6 +163,7 @@ type Tensor interface {
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
	SSMConv(ctx Context, kernel Tensor) Tensor
	SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor

	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
@@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
	}
}

func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
	}
}

func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	return &Tensor{
		b: t.b,
@@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;

template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
    // ne21 = n_rows (batch size)
    const int ne21_mm_id_min = 32;

    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
        (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
        // some Metal matrix data types require aligned pointers
        // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
        //switch (op->src[0]->type) {

@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;

template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor                      { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor             { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor  { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor  { return f }

func (m *fakeBackend) Get(name string) ml.Tensor {
	if slices.Contains(m.names, name) {
@@ -1,410 +1,44 @@
package lfm2

import (
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

var _ kvcache.Cache = (*HybridCache)(nil)
var (
	_ kvcache.Cache           = (*HybridCache)(nil)
	_ kvcache.CheckpointCache = (*HybridCache)(nil)
)

// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
// HybridCache adapts the shared recurrent cache for LFM2:
// - KV attention cache is handled by the embedded causal cache
// - shortconv recurrent state uses conv slots [dConv, hiddenSize]
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
// This reuses shared checkpoint/restore logic for prefix mismatch recovery.
type HybridCache struct {
	kv *kvcache.Causal

	backend      ml.Backend
	dtype        ml.DType
	maxSequences int

	hiddenSize int
	dConv      int

	// slot mapping for recurrent state
	slotForSeq map[int]int
	refCount   []int
	freeSlots  []int

	// per-layer conv state buffers (allocated lazily)
	convCtxs   map[int]ml.Context
	convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]

	// current forward batch (derived in StartForward)
	curSeqs       []int
	curSlots      []int
	curSlotsInput ml.Tensor
	curSeqTokens  int

	// track if EnsureWritable has been called for this forward pass
	writableEnsured bool
	// track any error from EnsureWritable to propagate later
	writableError error
	*kvcache.Recurrent
}

func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
	return &HybridCache{
		kv:         kvcache.NewCausalCache(shift),
		hiddenSize: hiddenSize,
		dConv:      dConv,
		slotForSeq: make(map[int]int),
		convCtxs:   make(map[int]ml.Context),
		convStates: make(map[int]ml.Tensor),
	}
}
	base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
		Shift:               shift,
		ConvDim:             dConv,
		ConvChannels:        hiddenSize,
		RecurrentStateSize:  1, // LFM2 uses only conv state; keep a minimal recurrent buffer size.
		CheckpointLogPrefix: "lfm2",
	})

func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.backend = backend
	c.dtype = dtype
	c.maxSequences = maxSequences

	// initialize slot allocator
	c.refCount = make([]int, maxSequences)
	c.freeSlots = c.freeSlots[:0]
	for i := maxSequences - 1; i >= 0; i-- {
		c.freeSlots = append(c.freeSlots, i)
	}

	c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

func (c *HybridCache) Close() {
	for _, ctx := range c.convCtxs {
		ctx.Close()
	}
	c.kv.Close()
}

func (c *HybridCache) SetConfig(config ml.CacheConfig) {
	c.kv.SetConfig(config)
}

func (c *HybridCache) SetLayer(layer int) {
	c.kv.SetLayer(layer)
}

func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.kv.Get(ctx)
}

func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.kv.Put(ctx, key, value)
}

func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
		return err
	}

	// Derive equal-length sequence layout for shortconv.
	// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
	seqCounts := make(map[int]int)
	c.curSeqs = c.curSeqs[:0]
	for _, s := range batch.Sequences {
		if _, ok := seqCounts[s]; !ok {
			c.curSeqs = append(c.curSeqs, s)
		}
		seqCounts[s]++
	}

	if len(c.curSeqs) == 0 {
		return nil
	}

	nTokens := len(batch.Sequences)
	nSeqs := len(c.curSeqs)
	want := nTokens / nSeqs
	for _, s := range c.curSeqs {
		if seqCounts[s] != want {
			return kvcache.ErrNotSupported
		}
	}

	c.curSeqTokens = want

	// When reserving memory for estimation, use fake slot assignments
	// without modifying permanent state (slotForSeq, refCount)
	if reserve {
		c.curSlots = c.curSlots[:0]
		slots := make([]int32, nSeqs)
		for i := range nSeqs {
			c.curSlots = append(c.curSlots, i)
			slots[i] = int32(i)
		}
		c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
		return nil
	}

	// Ensure slots exist for sequences in this batch
	c.curSlots = c.curSlots[:0]
	var newSlots []int // track newly allocated slots that need zeroing
	for _, s := range c.curSeqs {
		slot, ok := c.slotForSeq[s]
		if !ok {
			var err error
			slot, err = c.allocSlot()
			if err != nil {
				return err
			}
			c.slotForSeq[s] = slot
			c.refCount[slot] = 1
			newSlots = append(newSlots, slot)
		}
		c.curSlots = append(c.curSlots, slot)
	}

	// Zero conv state for newly allocated slots to clear stale data from previous sequences
	if len(newSlots) > 0 {
		c.zeroConvSlots(ctx, newSlots)
	}

	// Create a tensor for the current slots
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	// Reset writable state for new forward pass
	c.writableEnsured = false
	c.writableError = nil

	return nil
}

func (c *HybridCache) allocSlot() (int, error) {
	if len(c.freeSlots) == 0 {
		return 0, kvcache.ErrKvCacheFull
	}
	slot := c.freeSlots[len(c.freeSlots)-1]
	c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
	return slot, nil
}

func (c *HybridCache) freeSlot(slot int) {
	// Bounds check before freeing
	if slot >= 0 && slot < c.maxSequences {
		c.freeSlots = append(c.freeSlots, slot)
	}
}

// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
	if len(slots) == 0 || len(c.convStates) == 0 {
		return
	}

	// Use input context for creating tensors
	inputCtx := ctx.Input()

	// Create slot indices tensor
	slotIndices := make([]int32, len(slots))
	for i, s := range slots {
		slotIndices[i] = int32(s)
	}
	slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))

	// Create zero tensor for the slots (SetRows requires F32 source)
	zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))

	// Zero each layer's conv state for these slots
	for _, buf := range c.convStates {
		ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
	}
}

// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
	for i, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok {
			continue
		}

		// Bounds check
		if slot < 0 || slot >= len(c.refCount) {
			continue
		}

		if c.refCount[slot] <= 1 {
			continue
		}

		newSlot, err := c.allocSlot()
		if err != nil {
			return err
		}
		c.refCount[slot]--
		c.refCount[newSlot] = 1
		c.slotForSeq[seq] = newSlot
		c.curSlots[i] = newSlot

		// Copy existing conv state for all initialized layers
		for _, buf := range c.convStates {
			// buf: [dConv*hiddenSize, maxSlots]
			src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
			// SetRows requires F32 source
			srcF32 := src.Cast(ctx, ml.DTypeF32)
			ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
		}
	}

	// Rebuild current slots tensor
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	return nil
}

func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

	// For shortconv state we implement copy-on-write: dst shares the same slot as src.
	// On the first write to dst, EnsureWritable will create a private slot.
	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		// Bounds check before decrementing
		if dstSlot >= 0 && dstSlot < len(c.refCount) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}

	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		// src may not have a slot yet; dst will allocate on demand
		return
	}

	// Bounds check before incrementing
	if srcSlot >= 0 && srcSlot < len(c.refCount) {
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}

func (c *HybridCache) CanResume(seq int, pos int32) bool {
	return c.kv.CanResume(seq, pos)
}

func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
	if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
		return err
	}

	// For recurrent state, any removal invalidates the state because
	// the state at position N depends on all previous positions.
	// Drop the slot mapping so it resets on next use.
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return nil
	}

	// Bounds check
	if slot < 0 || slot >= len(c.refCount) {
		delete(c.slotForSeq, seq)
		return nil
	}

	c.refCount[slot]--
	if c.refCount[slot] <= 0 {
		c.refCount[slot] = 0
		c.freeSlot(slot)
	}
	delete(c.slotForSeq, seq)

	return nil
	return &HybridCache{Recurrent: base}
}

func (c *HybridCache) slotsTensor() ml.Tensor {
	return c.curSlotsInput
	return c.SlotsTensor()
}

func (c *HybridCache) seqTokens() int {
	return c.curSeqTokens
	return c.SeqTokens()
}

func (c *HybridCache) numSeqs() int {
	return len(c.curSeqs)
}

func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}

	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}

// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	if !c.writableEnsured {
		needsWritable := false
		for _, seq := range c.curSeqs {
			slot, ok := c.slotForSeq[seq]
			if !ok {
				continue
			}
			if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
				needsWritable = true
				break
			}
		}

		if needsWritable {
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
		}
		c.writableEnsured = true
	}

	if c.writableError != nil {
		return nil, c.writableError
	}

	buf := c.convBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}

// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(ctx, layer)
	src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
	// SetRows requires F32 source
	srcF32 := src.Cast(ctx, ml.DTypeF32)
	ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}

// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
	return slices.Clone(c.curSeqs)
	return c.NumSeqs()
}
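The copy-on-write slot semantics described in the comments above (forking a sequence bumps a refcount; the first write to a shared slot allocates a private copy) can be illustrated with a self-contained sketch that models only the bookkeeping, not the tensor copies:

```go
package main

import "fmt"

// slots models the copy-on-write slot sharing used for recurrent conv state:
// forking a sequence shares the source slot and bumps its refcount; the first
// write to a shared slot allocates a private one.
type slots struct {
	refCount   map[int]int
	slotForSeq map[int]int
	next       int
}

func (s *slots) fork(src, dst int) {
	slot := s.slotForSeq[src]
	s.slotForSeq[dst] = slot
	s.refCount[slot]++
}

func (s *slots) ensureWritable(seq int) int {
	slot := s.slotForSeq[seq]
	if s.refCount[slot] <= 1 {
		return slot // already private
	}
	s.refCount[slot]--
	s.next++
	private := s.next
	s.refCount[private] = 1
	s.slotForSeq[seq] = private
	// the real cache also copies the conv state rows into the new slot here
	return private
}

func main() {
	s := &slots{refCount: map[int]int{0: 1}, slotForSeq: map[int]int{1: 0}}
	s.fork(1, 2)                     // seq 2 shares seq 1's slot
	fmt.Println(s.refCount[0])       // 2
	fmt.Println(s.ensureWritable(2)) // 1: first write gives seq 2 a private slot
	fmt.Println(s.refCount[0])       // 1
}
```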
@@ -4,441 +4,39 @@ import (
	"testing"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
func TestHybridCache_New(t *testing.T) {
	cache := NewHybridCache(nil, 512, 2)
	if cache == nil {
		t.Fatal("expected cache to be created")
	}

// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
	return &HybridCache{
		hiddenSize:   256,
		dConv:        3,
		maxSequences: maxSequences,
		refCount:     make([]int, maxSequences),
		freeSlots:    initFreeSlots(maxSequences),
		slotForSeq:   make(map[int]int),
		convCtxs:     make(map[int]ml.Context),
		convStates:   make(map[int]ml.Tensor),
	if cache.Recurrent == nil {
		t.Fatal("expected embedded recurrent cache to be created")
	}
}

func initFreeSlots(n int) []int {
	slots := make([]int, 0, n)
	for i := n - 1; i >= 0; i-- {
		slots = append(slots, i)
	}
	return slots
}
func TestHybridCache_ImplementsCheckpointCache(t *testing.T) {
	cache := NewHybridCache(nil, 512, 2)

func TestHybridCache_SlotAllocation(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Verify initial state
	if len(cache.freeSlots) != 4 {
		t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
	}

	// Allocate all slots
	for range 4 {
		slot, err := cache.allocSlot()
		if err != nil {
			t.Fatalf("allocSlot failed: %v", err)
		}
		cache.refCount[slot] = 1
	}

	// Should be full now
	if len(cache.freeSlots) != 0 {
		t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
	}

	// Trying to allocate another should fail
	_, err := cache.allocSlot()
	if err != kvcache.ErrKvCacheFull {
		t.Errorf("expected ErrKvCacheFull, got %v", err)
	if _, ok := any(cache).(kvcache.CheckpointCache); !ok {
		t.Fatal("expected HybridCache to implement CheckpointCache")
	}
}

func TestHybridCache_SlotReuse(t *testing.T) {
	cache := createSlotOnlyCache(4)
func TestHybridCache_DefaultBatchState(t *testing.T) {
	cache := NewHybridCache(nil, 512, 2)

	// Allocate a slot
	slot1, _ := cache.allocSlot()
	cache.refCount[slot1] = 1

	// Free it
	cache.refCount[slot1] = 0
	cache.freeSlot(slot1)

	// Allocate again - should get the same slot back (LIFO)
	slot2, _ := cache.allocSlot()
	if slot2 != slot1 {
		t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
	}
}

func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Simulate sharing slot with seq 2 (copy-on-write style)
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Should share the same slot
	if cache.slotForSeq[2] != slot1 {
		t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
	if got := cache.numSeqs(); got != 0 {
		t.Fatalf("expected 0 sequences before StartForward, got %d", got)
	}

	// Ref count should be 2
	if cache.refCount[slot1] != 2 {
		t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Share with seq 2
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Unshare seq 2
	cache.refCount[slot1]--
	delete(cache.slotForSeq, 2)

	// Ref count should be back to 1
	if cache.refCount[slot1] != 1 {
		t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
	if got := cache.seqTokens(); got != 0 {
		t.Fatalf("expected 0 sequence tokens before StartForward, got %d", got)
	}

	// Seq 2 should no longer have a slot
	if _, ok := cache.slotForSeq[2]; ok {
		t.Error("seq 2 should not have a slot after unshare")
	}
}

func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
	cache := createSlotOnlyCache(4)

	initialFreeSlots := len(cache.freeSlots)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Free the slot when refCount drops to 0
	cache.refCount[slot1]--
	if cache.refCount[slot1] <= 0 {
		cache.refCount[slot1] = 0
		cache.freeSlot(slot1)
	}
	delete(cache.slotForSeq, 1)

	// Slot should be freed
	if len(cache.freeSlots) != initialFreeSlots {
		t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
	}

	// Ref count should be 0
	if cache.refCount[slot1] != 0 {
		t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_SlotOverwrite(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slots for seq 1 and seq 2
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	slot2, _ := cache.allocSlot()
	cache.slotForSeq[2] = slot2
	cache.refCount[slot2] = 1

	initialFreeSlots := len(cache.freeSlots)

	// Simulate overwriting seq 2's slot with slot1 (sharing)
	// First free the old slot
	cache.refCount[slot2]--
	if cache.refCount[slot2] <= 0 {
		cache.refCount[slot2] = 0
		cache.freeSlot(slot2)
	}
	// Then share slot1
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Seq 2 should now share slot1
	if cache.slotForSeq[2] != slot1 {
		t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
	}

	// Old slot2 should be freed
	if len(cache.freeSlots) != initialFreeSlots+1 {
		t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
	}
}

func TestHybridCache_BoundsChecking(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Test freeing invalid slot (should not panic)
	cache.freeSlot(-1)
	cache.freeSlot(100) // out of bounds

	// freeSlot does bounds checking, so invalid slots should be ignored
	if len(cache.freeSlots) != 4 {
		t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
	}
}

func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
	cache := createSlotOnlyCache(8)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Fork to seq 2, 3, 4 (all share slot1)
	for _, seq := range []int{2, 3, 4} {
		cache.slotForSeq[seq] = slot1
		cache.refCount[slot1]++
	}

	// Ref count should be 4
	if cache.refCount[slot1] != 4 {
		t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
	}

	// Remove seq 2, 3
	for _, seq := range []int{2, 3} {
		delete(cache.slotForSeq, seq)
		cache.refCount[slot1]--
	}

	if cache.refCount[slot1] != 2 {
		t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
	}

	// Slot should still be allocated (not in free list)
	found := false
	for _, s := range cache.freeSlots {
		if s == slot1 {
			found = true
			break
		}
	}
	if found {
		t.Error("slot1 should not be in free list yet")
	}

	// Remove remaining sequences
	for _, seq := range []int{1, 4} {
		delete(cache.slotForSeq, seq)
		cache.refCount[slot1]--
	}

	if cache.refCount[slot1] != 0 {
		t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_ChainedSharing(t *testing.T) {
	cache := createSlotOnlyCache(8)

	// Create seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Share 1 -> 2
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Share 2 -> 3 (should still share slot1)
	cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
	cache.refCount[slot1]++

	// All should share slot1
	if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
		t.Error("all sequences should share slot1")
	}

	if cache.refCount[slot1] != 3 {
		t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_CacheParameters(t *testing.T) {
	cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5

	if cache.hiddenSize != 512 {
		t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
	}
	if cache.dConv != 5 {
		t.Errorf("expected dConv 5, got %d", cache.dConv)
	}
}

func TestHybridCache_NumSeqs(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially no sequences
	if cache.numSeqs() != 0 {
		t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
	}

	// Manually set up current batch state
	cache.curSeqs = []int{1, 2, 3}

	if cache.numSeqs() != 3 {
		t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
	}
}

func TestHybridCache_SeqTokens(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially 0
	if cache.seqTokens() != 0 {
		t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
	}

	// Manually set up current batch state
	cache.curSeqTokens = 16

	if cache.seqTokens() != 16 {
		t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
	}
}

// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
	cache := createSlotOnlyCache(4)

	cache.curSeqs = []int{1, 2, 3}

	seqs := cache.Seqs()

	// Modify returned slice
	seqs[0] = 999

	// Original should be unchanged
	if cache.curSeqs[0] != 1 {
		t.Error("Seqs should return a clone, not the original slice")
	}
}

func TestHybridCache_IsSupportedForBatch(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially not supported (no batch set up)
	if cache.IsSupportedForBatch() {
		t.Error("expected IsSupportedForBatch to be false initially")
	}

	// Set up a valid batch
	cache.curSeqTokens = 1
	cache.curSeqs = []int{1}

	if !cache.IsSupportedForBatch() {
		t.Error("expected IsSupportedForBatch to be true with valid batch")
	}
}

func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// zeroConvSlots should handle empty slots without panicking
	cache.zeroConvSlots(nil, nil)
	cache.zeroConvSlots(nil, []int{})

	// zeroConvSlots should handle empty convStates without panicking
	cache.zeroConvSlots(nil, []int{0, 1, 2})
}

func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Free the slot (simulating sequence removal)
	cache.refCount[slot1]--
	cache.freeSlot(slot1)
	delete(cache.slotForSeq, 1)

	// Verify slot is in free list
	if len(cache.freeSlots) != 4 {
		t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
	}

	// Allocate for new seq 2 - should get recycled slot
	slot2, _ := cache.allocSlot()
	if slot2 != slot1 {
		t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
	}

	// This recycled slot would need zeroing in the real implementation
	// The actual zeroing is tested via integration tests since it requires ML context
}

func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Simulate the slot allocation flow from StartForward
	// When a sequence doesn't have a slot, it gets allocated and tracked as "new"

	newSlots := []int{}

	// Seq 1 doesn't have a slot - allocate and track
	seq := 1
	if _, ok := cache.slotForSeq[seq]; !ok {
		slot, err := cache.allocSlot()
		if err != nil {
			t.Fatalf("allocSlot failed: %v", err)
		}
		cache.slotForSeq[seq] = slot
		cache.refCount[slot] = 1
		newSlots = append(newSlots, slot)
	}

	// Verify newSlots contains the allocated slot
	if len(newSlots) != 1 {
		t.Errorf("expected 1 new slot, got %d", len(newSlots))
	}

	// Seq 1 already has a slot - should NOT be tracked as new
	newSlots2 := []int{}
	if _, ok := cache.slotForSeq[seq]; !ok {
		slot, _ := cache.allocSlot()
		cache.slotForSeq[seq] = slot
		cache.refCount[slot] = 1
		newSlots2 = append(newSlots2, slot)
	}

	// Verify no new slots for existing sequence
	if len(newSlots2) != 0 {
		t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
		t.Fatal("expected unsupported batch layout before StartForward")
	}
}
Some files were not shown because too many files have changed in this diff