Compare commits

..

2 Commits

Author SHA1 Message Date
ParthSareen
23e8ac9428 wip? 2025-05-07 19:00:44 -07:00
ParthSareen
611d3a17ed server: add python tool parsing logic 2025-05-02 16:23:54 -07:00
333 changed files with 15004 additions and 24680 deletions

View File

@@ -432,22 +432,6 @@ jobs:
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }}
# Trigger downstream release process
trigger:
runs-on: ubuntu-latest
environment: release
needs: [darwin-build, windows-build, windows-depends]
steps:
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
# Aggregate all the assets and ship a release
release:
needs: [darwin-sign, windows-sign, linux-build]

View File

@@ -19,8 +19,8 @@ linters:
- nolintlint
- nosprintfhostport
- staticcheck
- tenv
- unconvert
- usetesting
- wastedassign
- whitespace
disable:

View File

@@ -51,8 +51,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
add_compile_definitions(NDEBUG)
set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
FETCH_HEAD=2016f07bd106c73699ecbaace80f55db5ed95dac
.PHONY: help
help:
@@ -15,13 +15,11 @@ help:
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
.PHONY: sync
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml
llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
go generate ./$(@D)
.PHONY: llama/build-info.cpp
llama/build-info.cpp: llama/build-info.cpp.in
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/
@@ -32,13 +30,12 @@ ml/backend/ggml/ggml: llama/vendor/ggml/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
.PHONY: apply-patches
.NOTPARALLEL:
apply-patches: $(PATCHED)
apply-patches: $(addsuffix ed, $(PATCHES))
llama/patches/.%.patched: llama/patches/%.patch
%.patched: %.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
.PHONY: checkout
@@ -60,4 +57,4 @@ format-patches: llama/patches
.PHONE: clean
clean: checkout
$(RM) llama/patches/.*.patched
$(RM) $(addsuffix ed, $(PATCHES))

View File

@@ -61,8 +61,6 @@ Here are some example models that can be downloaded:
| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
@@ -79,7 +77,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -287,7 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
@@ -314,8 +312,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
@@ -329,14 +325,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
@@ -345,16 +341,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
@@ -372,7 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
@@ -390,7 +386,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
@@ -398,15 +394,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
### Cloud
@@ -448,9 +440,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
### Apple Vision Pro
@@ -477,7 +468,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Libraries
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
@@ -524,21 +515,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
@@ -562,7 +552,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
@@ -572,8 +562,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
@@ -587,8 +577,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
### Supported backends

View File

@@ -24,10 +24,7 @@ import (
"net/http"
"net/url"
"runtime"
"strconv"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
@@ -79,14 +76,6 @@ func NewClient(base *url.URL, http *http.Client) *Client {
}
}
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
token, err := auth.Sign(ctx, []byte(challenge))
if err != nil {
return "", err
}
return token, nil
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader
var data []byte
@@ -108,21 +97,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
}
requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil {
return err
@@ -132,10 +106,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
respObj, err := c.http.Do(request)
if err != nil {
return err
@@ -173,22 +143,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
var err error
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil {
return err
@@ -198,10 +152,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
request.Header.Set("Accept", "application/x-ndjson")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
response, err := c.http.Do(request)
if err != nil {
return err

View File

@@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -136,7 +137,7 @@ func TestClientStream(t *testing.T) {
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
@@ -222,7 +223,7 @@ func TestClientDo(t *testing.T) {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp)
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {

View File

@@ -83,12 +83,6 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding. Needs to be a pointer so we can distinguish between false
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *bool `json:"think,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -114,10 +108,6 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding
Think *bool `json:"think,omitempty"`
}
type Tools []Tool
@@ -136,11 +126,8 @@ func (t Tool) String() string {
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
@@ -284,6 +271,9 @@ type Options struct {
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Mirostat int `json:"mirostat,omitempty"`
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"`
}
@@ -293,7 +283,12 @@ type Runner struct {
NumBatch int `json:"num_batch,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"` // Deprecated: This option is ignored
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap *bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
@@ -457,13 +452,12 @@ type ProcessResponse struct {
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Name string `json:"name"`
Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
}
// ProcessModelResponse is a single model description in [ProcessResponse].
@@ -477,6 +471,13 @@ type ProcessModelResponse struct {
SizeVRAM int64 `json:"size_vram"`
}
type RetrieveModelResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type TokenResponse struct {
Token string `json:"token"`
}
@@ -492,10 +493,6 @@ type GenerateResponse struct {
// Response is the textual response itself.
Response string `json:"response"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
// Done specifies if the response is complete.
Done bool `json:"done"`
@@ -663,6 +660,9 @@ func DefaultOptions() Options {
RepeatPenalty: 1.1,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
Mirostat: 0,
MirostatTau: 5.0,
MirostatEta: 0.1,
Seed: -1,
Runner: Runner{
@@ -671,6 +671,8 @@ func DefaultOptions() Options {
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumThread: 0, // let the runtime decide
LowVRAM: false,
UseMLock: false,
UseMMap: nil,
},
}

View File

@@ -372,50 +372,3 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
})
}
}
func TestThinking_UnmarshalJSON(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
input string
expectedThinking *bool
expectedError bool
}{
{
name: "true",
input: `{ "think": true }`,
expectedThinking: &trueVal,
},
{
name: "false",
input: `{ "think": false }`,
expectedThinking: &falseVal,
},
{
name: "unset",
input: `{ }`,
expectedThinking: nil,
},
{
name: "invalid",
input: `{ "think": "true" }`,
expectedThinking: nil,
expectedError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req GenerateRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expectedThinking, req.Think)
}
})
}
}

View File

@@ -4,14 +4,20 @@ import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
)
func InitLogging() {
level := slog.LevelInfo
if envconfig.Debug() {
level = slog.LevelDebug
}
var logFile *os.File
var err error
// Detect if we're a GUI app on windows, and if not, send logs to console
@@ -27,8 +33,20 @@ func InitLogging() {
return
}
}
handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
slog.Info("ollama app started")
}

View File

@@ -78,7 +78,7 @@ func BenchmarkColdStart(b *testing.B) {
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
@@ -113,7 +113,7 @@ func BenchmarkWarmStart(b *testing.B) {
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
@@ -140,7 +140,7 @@ func setup(b *testing.B) *api.Client {
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}

View File

@@ -31,7 +31,6 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
@@ -39,31 +38,12 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
)
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" {
return
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -126,7 +106,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = args[0]
req.Name = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@@ -137,54 +117,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if len(req.Files) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Files {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: this is incorrect since the file might be in a subdirectory
// instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
})
fileMap[filepath.Base(f)] = digest
}
req.Files = fileMap
}
adapters := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Adapters {
g.Go(func() error {
if len(req.Adapters) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Adapters {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: same here
adapters.Store(filepath.Base(f), digest)
return nil
})
fileMap[filepath.Base(f)] = digest
}
req.Adapters = fileMap
}
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -253,7 +213,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
@@ -283,9 +243,6 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
req := &api.GenerateRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
Think: opts.Think,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -320,22 +277,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.Format = format
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag.Changed {
think, err := cmd.Flags().GetBool("think")
if err != nil {
return err
}
opts.Think = &think
} else {
opts.Think = nil
}
hidethinking, err := cmd.Flags().GetBool("hidethinking")
if err != nil {
return err
}
opts.HideThinking = hidethinking
keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil {
return err
@@ -399,11 +340,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below,
@@ -789,38 +725,11 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
targetWidth := 10 // Small width where we are displaying the data in a column
var itemsToShow int
totalWidth := 1 // Start with 1 for opening bracket
// Find how many we can fit
for i := range vData {
itemStr := fmt.Sprintf("%v", vData[i])
width := runewidth.StringWidth(itemStr)
// Add separator width (", ") for all items except the first
if i > 0 {
width += 2
}
// Check if adding this item would exceed our width limit
if totalWidth+width > targetWidth && i > 0 {
break
}
totalWidth += width
itemsToShow++
}
// Format the output
if itemsToShow < len(vData) {
v = fmt.Sprintf("%v", vData[:itemsToShow])
v = strings.TrimSuffix(v, "]")
v += fmt.Sprintf(" ...+%d more]", len(vData)-itemsToShow)
} else {
v = fmt.Sprintf("%v", vData)
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
@@ -841,19 +750,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s))
count := 0
for scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if text == "" {
continue
for scanner.Scan() && (len(rows) < n || n < 0) {
if text := scanner.Text(); text != "" {
rows = append(rows, []string{"", strings.TrimSpace(text)})
}
count++
if n < 0 || count <= n {
rows = append(rows, []string{"", text})
}
}
if n >= 0 && count > n {
rows = append(rows, []string{"", "..."})
}
return
}
@@ -965,19 +865,17 @@ func PullHandler(cmd *cobra.Command, args []string) error {
type generateContextKey string
type runOptions struct {
Model string
ParentModel string
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Think *bool
HideThinking bool
Model string
ParentModel string
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
}
type displayResponseState struct {
@@ -1033,26 +931,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
}
}
func thinkingOutputOpeningText(plainText bool) string {
text := "Thinking...\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
}
func thinkingOutputClosingText(plainText bool) string {
text := "...done thinking.\n\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
}
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -1080,34 +958,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
fn := func(response api.ChatResponse) error {
if response.Message.Content != "" || !opts.HideThinking {
p.StopAndClear()
}
p.StopAndClear()
latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(false))
thinkTagOpened = true
}
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(false))
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state)
@@ -1124,7 +982,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
if opts.KeepAlive != nil {
@@ -1186,32 +1043,13 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}()
var state *displayResponseState = &displayResponseState{}
var thinkTagOpened bool = false
var thinkTagClosed bool = false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response
content := response.Response
if response.Response != "" || !opts.HideThinking {
p.StopAndClear()
}
if response.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(plainText))
thinkTagOpened = true
}
displayResponse(response.Thinking, opts.WordWrap, state)
}
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(plainText))
thinkTagClosed = true
}
displayResponse(content, opts.WordWrap, state)
return nil
@@ -1237,7 +1075,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
Think: opts.Think,
}
if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1341,11 +1178,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := client.Heartbeat(cmd.Context()); err != nil {
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
if !strings.Contains(err.Error(), " refused") {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("ollama server not responding - %w", err)
return errors.New("could not connect to ollama app, is it running?")
}
}
return nil
@@ -1423,7 +1260,7 @@ func NewCLI() *cobra.Command {
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
showCmd := &cobra.Command{
Use: "show MODEL",
@@ -1453,8 +1290,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
@@ -1506,6 +1341,7 @@ func NewCLI() *cobra.Command {
PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler,
}
copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION",
Short: "Copy a model",
@@ -1571,6 +1407,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_GPU_OVERHEAD"],
envVars["OLLAMA_LOAD_TIMEOUT"],
envVars["OLLAMA_CONTEXT_LENGTH"],
})
default:
appendEnvDocs(cmd, envs)
@@ -1594,45 +1431,3 @@ func NewCLI() *cobra.Command {
return rootCmd
}
// If the user has explicitly set thinking options, either through the CLI or
// through the `/set think` or `set nothink` interactive options, then we
// respect them. Otherwise, we check model capabilities to see if the model
// supports thinking. If the model does support thinking, we enable it.
// Otherwise, we unset the thinking option (which is different than setting it
// to false).
//
// If capabilities are not provided, we fetch them from the server.
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
if explicitlySetByUser {
return runOpts.Think, nil
}
if caps == nil {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
ret, err := client.Show(context.Background(), &api.ShowRequest{
Model: runOpts.Model,
})
if err != nil {
return nil, err
}
caps = &ret.Capabilities
}
thinkingSupported := false
for _, cap := range *caps {
if cap == model.CapabilityThinking {
thinkingSupported = true
}
}
if thinkingSupported {
thinking := true
return &thinking, nil
}
return nil, nil
}

View File

@@ -2,6 +2,7 @@ package cmd
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -225,7 +226,6 @@ Weigh anchor!
System
You are a pirate!
Ahoy, matey!
...
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
@@ -337,7 +337,7 @@ func TestDeleteHandler(t *testing.T) {
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
t.Fatalf("DeleteHandler failed: %v", err)
}
@@ -399,6 +399,11 @@ func TestGetModelfileName(t *testing.T) {
var expectedFilename string
if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string
if tt.modelfileName != "" {
fn = tt.modelfileName
@@ -406,7 +411,7 @@ func TestGetModelfileName(t *testing.T) {
fn = "Modelfile"
}
tempFile, err := os.CreateTemp(t.TempDir(), fn)
tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}
@@ -525,7 +530,7 @@ func TestPushHandler(t *testing.T) {
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
@@ -630,7 +635,7 @@ func TestListHandler(t *testing.T) {
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
// Capture stdout
oldStdout := os.Stdout
@@ -685,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
return
}
if req.Model != "test-model" {
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}
@@ -725,7 +730,7 @@ func TestCreateHandler(t *testing.T) {
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
tempFile, err := os.CreateTemp("", "modelfile")
if err != nil {
t.Fatal(err)
}
@@ -745,7 +750,7 @@ func TestCreateHandler(t *testing.T) {
}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr

View File

@@ -44,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
@@ -62,8 +62,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "")
}
@@ -130,7 +128,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
for {
line, err := scanner.Readline()
@@ -198,19 +195,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
continue
}
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
continue
@@ -271,22 +260,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
think := true
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'think' mode.")
case "nothink":
think := false
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
@@ -475,11 +448,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
assistant, err := chat(cmd, opts)
if err != nil {
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
sb.Reset()
continue
}
return err
}
if assistant != nil {
@@ -543,7 +511,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
@@ -563,8 +531,6 @@ func extractFileData(input string) (string, []api.ImageData, error) {
return "", imgs, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
@@ -585,7 +551,7 @@ func getImageData(filePath string) ([]byte, error) {
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}

View File

@@ -1,8 +1,6 @@
package cmd
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@@ -12,17 +10,14 @@ func TestExtractFilenames(t *testing.T) {
// Unix style paths
input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
res := extractFileNames(input)
assert.Len(t, res, 7)
assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.webp")
assert.Contains(t, res[6], "seven.WEBP")
assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
@@ -33,12 +28,10 @@ func TestExtractFilenames(t *testing.T) {
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
d:\path with\spaces\thirteen.WEBP some ending
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
`
res = extractFileNames(input)
assert.Len(t, res, 13)
assert.Len(t, res, 10)
assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:")
@@ -56,31 +49,4 @@ d:\path with\spaces\thirteen.WEBP some ending
assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:")
assert.Contains(t, res[10], "eleven.webp")
assert.Contains(t, res[10], "c:")
assert.Contains(t, res[11], "twelve.WebP")
assert.Contains(t, res[11], "c:")
assert.Contains(t, res[12], "thirteen.WEBP")
assert.Contains(t, res[12], "d:")
}
// Ensure that file paths wrapped in single quotes are removed with the quotes.
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "img.jpg")
data := make([]byte, 600)
copy(data, []byte{
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xff, 0xd9,
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test image: %v", err)
}
input := "before '" + fp + "' after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
}

View File

@@ -23,7 +23,7 @@ func startApp(ctx context.Context, client *api.Client) error {
return errors.New("could not find ollama app")
}
path := strings.Split(link, "Ollama.app")
if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil {
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
return err
}
return waitForServer(ctx, client)

View File

@@ -4,27 +4,17 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"syscall"
"unsafe"
"github.com/ollama/ollama/api"
"golang.org/x/sys/windows"
)
const (
Installer = "OllamaSetup.exe"
)
func startApp(ctx context.Context, client *api.Client) error {
if len(isProcRunning(Installer)) > 0 {
return fmt.Errorf("upgrade in progress...")
}
// log.Printf("XXX Attempting to find and start ollama app")
AppName := "ollama app.exe"
exe, err := os.Executable()
if err != nil {
@@ -45,11 +35,14 @@ func startApp(ctx context.Context, client *api.Client) error {
}
}
}
// log.Printf("XXX attempting to start app %s", appExe)
cmd_path := "c:\\Windows\\system32\\cmd.exe"
cmd := exec.Command(cmd_path, "/c", appExe, "hidden")
cmd := exec.Command(cmd_path, "/c", appExe)
// TODO - these hide flags aren't working - still pops up a command window for some reason
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
// TODO this didn't help either...
cmd.Stdin = strings.NewReader("")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -63,50 +56,3 @@ func startApp(ctx context.Context, client *api.Client) error {
}
return waitForServer(ctx, client)
}
func isProcRunning(procName string) []uint32 {
pids := make([]uint32, 2048)
var ret uint32
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
if ret > uint32(len(pids)) {
pids = make([]uint32, ret+10)
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
}
if ret < uint32(len(pids)) {
pids = pids[:ret]
}
var matches []uint32
for _, pid := range pids {
if pid == 0 {
continue
}
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
if err != nil {
continue
}
defer windows.CloseHandle(hProcess)
var module windows.Handle
var cbNeeded uint32
cb := (uint32)(unsafe.Sizeof(module))
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
continue
}
var sz uint32 = 1024 * 8
moduleName := make([]uint16, sz)
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
continue
}
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
if strings.EqualFold(exeFile, procName) {
matches = append(matches, pid)
}
}
return matches
}

View File

@@ -1,63 +0,0 @@
package cmd
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
ensureThinkingSupport(t.Context(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}

View File

@@ -1,13 +1,12 @@
package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"slices"
"strings"
@@ -15,12 +14,13 @@ import (
)
type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
}
type AdapterParameters struct {
@@ -53,11 +53,8 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
}
for _, sv := range t.SpecialVocabulary {
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
if len(sv.IDs) > 0 {
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
}
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
}
return kv
@@ -92,7 +89,7 @@ type ModelConverter interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
Tensors([]Tensor) []ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
@@ -109,13 +106,13 @@ type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
Tensors([]Tensor) []ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -150,14 +147,14 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
@@ -176,8 +173,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch p.Architectures[0] {
case "LlamaForCausalLM":
conv = &llamaModel{}
case "MllamaForConditionalGeneration":
conv = &mllamaModel{}
case "Llama4ForConditionalGeneration":
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
@@ -194,8 +189,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &phi3Model{}
case "Qwen2ForCausalLM":
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":
@@ -219,22 +212,24 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch {
case vocabSize == 0:
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
}
case vocabSize < len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
@@ -244,13 +239,13 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err
}
return writeFile(f, conv.KV(t), conv.Tensors(ts))
return writeFile(ws, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
}
return ggml.WriteGGUF(f, kv, ts)
return ggml.WriteGGUF(ws, kv, ts)
}

View File

@@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
@@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
continue
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -21,8 +21,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
return kv
}
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -126,11 +126,11 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
if p.RopeScaling.factors != nil {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
@@ -139,14 +139,13 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") ||
strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
if !p.skipRepack {
t.SetRepacker(p.repack)
}
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
@@ -182,9 +181,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") {
if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") {
} else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)

View File

@@ -88,13 +88,13 @@ func (p *llama4Model) Replacements() []string {
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
@@ -112,7 +112,7 @@ func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
@@ -125,7 +125,7 @@ func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,

View File

@@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
return kv
}
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,

View File

@@ -89,8 +89,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
@@ -100,7 +100,7 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
}
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
oldnew := []string{
"model.layers", "blk",
"w1", "ffn_gate_exps",
@@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
return true
})
var out []*ggml.Tensor
var out []ggml.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),

View File

@@ -1,179 +0,0 @@
package convert
import (
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type mllamaModel struct {
ModelParameters
TextModel struct {
llamaModel
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumGlobalLayers uint32 `json:"num_global_layers"`
IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
NumChannels uint32 `json:"num_channels"`
MaxNumTiles uint32 `json:"max_num_tiles"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope.freq_base"`
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"
for k, v := range m.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
}
}
kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers
kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices
kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize
kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon
kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels
return kv
}
func (m *mllamaModel) Replacements() []string {
return append(
m.TextModel.Replacements(),
"language_model.", "",
"gate_attn", "attn_gate",
"gate_ffn", "ffn_gate",
"cross_attn.", "cross_attn_",
"vision_model", "v",
"class_embedding", "class_embd",
"patch_embedding", "patch_embd",
"gated_positional_embedding.tile_embedding", "tile_position_embd",
"gated_positional_embedding.embedding", "position_embd.weight",
"gated_positional_embedding", "position_embd",
"embedding.weight", "weight",
"pre_tile_positional_embedding", "pre_tile_position_embd",
"post_tile_positional_embedding", "post_tile_position_embd",
"layernorm_pre", "pre_ln",
"layernorm_post", "post_ln",
"global_transformer.layers", "global.blk",
"transformer.layers", "blk",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"multi_modal_projector", "mm.0",
)
}
func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var text []Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
text = append(text, t)
} else if t.Name() == "v.position_embd.gate" {
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
tt := t.Clone()
tt.SetRepacker(m.repack(name))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: tt,
})
}
} else {
if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
t.SetRepacker(m.repack(t.Name()))
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return append(out, m.TextModel.Tensors(text)...)
}
func (m *mllamaModel) repack(name string) Repacker {
return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
heads := m.VisionModel.AttentionHeads
if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := t.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := t.Reshape(dims...); err != nil {
return nil, err
}
if err := t.Transpose(); err != nil {
return nil, err
}
} else {
t, err = tensor.Tanh(t)
if err != nil {
return nil, err
}
if name == "v.position_embd.gate" {
t, err = tensor.Sub(float32(1), t)
if err != nil {
return nil, err
}
}
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
var addRopeFactors sync.Once
out := make([]*ggml.Tensor, 0, len(ts)+2)
out := make([]ggml.Tensor, 0, len(ts)+2)
for _, t := range ts {
if strings.HasPrefix(t.Name(), "blk.0.") {
addRopeFactors.Do(func() {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: "rope_factors_long.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
WriterTo: p.RopeScaling.LongFactor,
}, &ggml.Tensor{
}, ggml.Tensor{
Name: "rope_factors_short.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
})
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -15,7 +15,6 @@ type qwen2Model struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
}
@@ -40,18 +39,16 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
case "yarn":
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
default:
panic("unknown rope scaling type")
}
return kv
}
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@@ -1,102 +0,0 @@
package convert
import (
"cmp"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
for k, v := range q.qwen2Model.KV(t) {
if strings.HasPrefix(k, "qwen2.") {
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
}
}
if q.VisionModel.FullAttentionBlocks == nil {
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
}
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
return kv
}
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
strings.NewReplacer("attn.qkv", "attn_q"),
strings.NewReplacer("attn.qkv", "attn_k"),
strings.NewReplacer("attn.qkv", "attn_v"),
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",
"norm1", "ln1",
"norm2", "ln2",
)
}

View File

@@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}
@@ -130,7 +130,6 @@ func TestConvertModel(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer expectFile.Close()
var expect map[string]string
if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
@@ -332,7 +331,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}

58
convert/fs.go Normal file
View File

@@ -0,0 +1,58 @@
package convert
import (
"archive/zip"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
)
type ZipReader struct {
r *zip.Reader
p string
// limit is the maximum size of a file that can be read directly
// from the zip archive. Files larger than this size will be extracted
limit int64
}
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
return &ZipReader{r, p, limit}
}
func (z *ZipReader) Open(name string) (fs.File, error) {
r, err := z.r.Open(name)
if err != nil {
return nil, err
}
defer r.Close()
if fi, err := r.Stat(); err != nil {
return nil, err
} else if fi.Size() < z.limit {
return r, nil
}
if !filepath.IsLocal(name) {
return nil, zip.ErrInsecurePath
}
n := filepath.Join(z.p, name)
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
w, err := os.Create(n)
if err != nil {
return nil, err
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return os.Open(n)
}

View File

@@ -38,10 +38,7 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
t.name == "v.positional_embedding_vlm" {
// these tensors are always F32
return 0
}

View File

@@ -1,56 +0,0 @@
package convert
import (
"iter"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
for i, replacer := range replacers {
shape := slices.Clone(t.Shape())
shape[dim] = shape[dim] / uint64(len(replacers))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
tt := t.Clone()
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be written as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: replacer.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: tt,
}) {
break
}
}
}
}

View File

@@ -110,7 +110,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
return nil, err
} else {
@@ -172,34 +171,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
}
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
defer f.Close()
var p map[string]json.RawMessage
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, err
}
for _, st := range specialTokenTypes {
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
var ids []int32
if err := json.Unmarshal(bts, &ids); err != nil {
// value is not a list so the existing ID is used
continue
}
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
return sv.Type == st
}); i >= 0 {
t.SpecialVocabulary[i].IDs = ids
}
}
}
}
return t, nil
}
@@ -309,9 +280,6 @@ type SpecialVocabulary struct {
ID int
Content string
AddToken bool
// IDs is populated by generation_config.json
IDs []int32
}
func (sv SpecialVocabulary) Key() string {

View File

@@ -247,67 +247,6 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "generation config eos token ids",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 0,
"content": "<bos>",
"special": true
},
{
"id": 1,
"content": "<eos>",
"special": true
},
{
"id": 2,
"content": "<eot>",
"special": true
},
{
"id": 3,
"content": "<eom>",
"special": true
}
],
"model": {
"vocab": {
"<bos>": 0,
"<eos>": 1,
"<eot>": 2,
"<eom>": 3
}
}
}`),
"tokenizer_config.json": strings.NewReader(`{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": "<bos>",
"eos_token": "<eos>"
}`),
"generation_config.json": strings.NewReader(`{
"bos_token_id": 0,
"eos_token_id": [1, 2, 3]
}`),
}),
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
Scores: []float32{0, 1, 2, 3},
Types: []int32{3, 3, 3, 3},
},
SpecialVocabulary: []*SpecialVocabulary{
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
},
Pre: "default",
},
},
}
for _, tt := range cases {

View File

@@ -670,7 +670,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
}
func getVerboseState() C.uint16_t {
if envconfig.LogLevel() < slog.LevelInfo {
if envconfig.Debug() {
return C.uint16_t(1)
}
return C.uint16_t(0)

View File

@@ -27,14 +27,12 @@
#endif
#ifndef LOG
#define LOG(verbose, ...) \
do { \
if (verbose) { \
fprintf(stderr, __VA_ARGS__); \
} \
} while (0)
#endif
#ifdef __cplusplus
extern "C" {

View File

@@ -1,7 +1,6 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include <inttypes.h>
#include "gpu_info_cudart.h"
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
@@ -59,7 +58,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) {
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
@@ -169,9 +168,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
resp->free = memInfo.free;
resp->used = memInfo.used;
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used);
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
}
@@ -181,4 +180,4 @@ void cudart_release(cudart_handle_t h) {
h.handle = NULL;
}
#endif // __APPLE__
#endif // __APPLE__

View File

@@ -1,7 +1,6 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include <inttypes.h>
#include "gpu_info_nvcuda.h"
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
@@ -194,8 +193,8 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
resp->total = memInfo.total;
resp->free = memInfo.free;
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
@@ -248,4 +247,4 @@ void nvcuda_release(nvcuda_handle_t h) {
h.handle = NULL;
}
#endif // __APPLE__
#endif // __APPLE__

View File

@@ -19,7 +19,7 @@
### Model names
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
### Durations
@@ -43,7 +43,6 @@ Generate a response for a given prompt with a provided model. This is a streamin
- `prompt`: the prompt to generate a response for
- `suffix`: the text after the model response
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
- `think`: (for thinking models) should the model think before responding?
Advanced parameters (optional):
@@ -395,6 +394,9 @@ curl http://localhost:11434/api/generate -d '{
"repeat_penalty": 1.2,
"presence_penalty": 1.5,
"frequency_penalty": 1.0,
"mirostat": 1,
"mirostat_tau": 0.8,
"mirostat_eta": 0.6,
"penalize_newline": true,
"stop": ["\n", "user:"],
"numa": false,
@@ -402,7 +404,10 @@ curl http://localhost:11434/api/generate -d '{
"num_batch": 2,
"num_gpu": 1,
"main_gpu": 0,
"low_vram": false,
"vocab_only": false,
"use_mmap": true,
"use_mlock": false,
"num_thread": 8
}
}'
@@ -491,13 +496,11 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported
- `think`: (for thinking models) should the model think before responding?
The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
- `content`: the content of the message
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
@@ -955,8 +958,19 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
| Type | Recommended |
| --- | :-: |
| q2_K | |
| q3_K_L | |
| q3_K_M | |
| q3_K_S | |
| q4_0 | |
| q4_1 | |
| q4_K_M | * |
| q4_K_S | |
| q5_0 | |
| q5_1 | |
| q5_K_M | |
| q5_K_S | |
| q6_K | |
| q8_0 | * |
### Examples
@@ -1001,8 +1015,8 @@ Quantize a non-quantized model.
```shell
curl http://localhost:11434/api/create -d '{
"model": "llama3.2:quantized",
"from": "llama3.2:3b-instruct-fp16",
"model": "llama3.1:quantized",
"from": "llama3.1:8b-instruct-fp16",
"quantize": "q4_K_M"
}'
```
@@ -1012,14 +1026,12 @@ curl http://localhost:11434/api/create -d '{
A stream of JSON objects is returned:
```json
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302}
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552}
{"status":"verifying conversion"}
{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"}
{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
{"status":"quantizing F16 model to Q4_K_M"}
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
{"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
{"status":"writing manifest"}
{"status":"success"}
```
@@ -1157,46 +1169,29 @@ A single JSON object will be returned.
{
"models": [
{
"model": "codellama:13b",
"name": "codellama:13b",
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
"size": 7365960935,
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
"capabilities": [
"completion"
],
"details": {
"parent_model": "",
"format": "gguf",
"family": "qwen2",
"families": [
"qwen2"
],
"parameter_size": "7.6B",
"quantization_level": "Q4_K_M"
"family": "llama",
"families": null,
"parameter_size": "13B",
"quantization_level": "Q4_0"
}
},
{
"model": "llama4:latest",
"name": "llama3:latest",
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
"size": 3825819519,
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
"capabilities": [
"completion",
"vision"
],
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
"families": null,
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]

View File

@@ -118,7 +118,7 @@ To run tests, use `go test`:
go test ./...
```
> NOTE: In rare cirumstances, you may need to change a package using the new
> NOTE: In rare cirumstances, you may nedd to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or

View File

@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens.
By default, Ollama uses a context window size of 4096 tokens, unless you have a single GPU with <= 4 GB of VRAM, in which case it will default to 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
@@ -31,7 +31,7 @@ OLLAMA_CONTEXT_LENGTH=8192 ollama serve
To change this when using `ollama run`, use `/set parameter`:
```shell
/set parameter num_ctx 4096
/set parameter num_ctx 8192
```
When using the API, specify the `num_ctx` parameter:
@@ -41,7 +41,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama3.2",
"prompt": "Why is the sky blue?",
"options": {
"num_ctx": 4096
"num_ctx": 8192
}
}'
```

View File

@@ -132,12 +132,22 @@ success
### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0`
#### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com

View File

@@ -150,6 +150,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -149,22 +149,9 @@ func Bool(k string) func() bool {
}
}
// LogLevel returns the log level for the application.
// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
func LogLevel() slog.Level {
level := slog.LevelInfo
if s := Var("OLLAMA_DEBUG"); s != "" {
if b, _ := strconv.ParseBool(s); b {
level = slog.LevelDebug
} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
level = slog.Level(i * -4)
}
}
return level
}
var (
// Debug enabled additional debug information.
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
// KvCacheType is the quantization type for the K/V cache.
@@ -182,9 +169,7 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
ContextLength = Int64("OLLAMA_CONTEXT_LENGTH", -1)
)
func String(s string) func() string {
@@ -224,6 +209,8 @@ var (
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
)
func Uint64(key string, defaultValue uint64) func() uint64 {
@@ -240,6 +227,20 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
}
}
func Int64(key string, defaultValue int64) func() int64 {
return func() int64 {
if s := Var(key); s != "" {
if n, err := strconv.ParseInt(s, 10, 64); err != nil {
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
} else {
return n
}
}
return defaultValue
}
}
// Set aside VRAM per GPU
var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
@@ -251,7 +252,7 @@ type EnvVar struct {
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
@@ -268,7 +269,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default 4096 or 2048 with low VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
// Informational

View File

@@ -1,13 +1,11 @@
package envconfig
import (
"log/slog"
"math"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/logutil"
)
func TestHost(t *testing.T) {
@@ -280,9 +278,9 @@ func TestVar(t *testing.T) {
}
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 4096,
"2048": 2048,
cases := map[string]int64{
"": -1,
"4096": 4096,
}
for k, v := range cases {
@@ -294,34 +292,3 @@ func TestContextLength(t *testing.T) {
})
}
}
func TestLogLevel(t *testing.T) {
cases := map[string]slog.Level{
// Default to INFO
"": slog.LevelInfo,
"false": slog.LevelInfo,
"f": slog.LevelInfo,
"0": slog.LevelInfo,
// True values enable Debug
"true": slog.LevelDebug,
"t": slog.LevelDebug,
// Positive values increase verbosity
"1": slog.LevelDebug,
"2": logutil.LevelTrace,
// Negative values decrease verbosity
"-1": slog.LevelWarn,
"-2": slog.LevelError,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", k)
if i := LogLevel(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@@ -15,7 +15,6 @@ import (
type GGML struct {
container
model
Length int64
}
type model interface {
@@ -37,12 +36,12 @@ func (kv KV) ParameterCount() uint64 {
return keyValue(kv, "general.parameter_count", uint64(0))
}
func (kv KV) FileType() FileType {
func (kv KV) FileType() fileType {
if t := kv.Uint("general.file_type"); t > 0 {
return FileType(t)
return fileType(t)
}
return FileTypeUnknown
return fileTypeUnknown
}
func (kv KV) BlockCount() uint64 {
@@ -126,8 +125,6 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3",
"mistral3",
"llama4",
"mllama",
"qwen25vl",
}, kv.Architecture())
}
@@ -152,7 +149,7 @@ func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ..
return val.(T)
}
slog.Debug("key not found", "key", key, "default", defaultValue[0])
slog.Warn("key not found", "key", key, "default", defaultValue[0])
return defaultValue[0]
}
@@ -229,11 +226,7 @@ func (t Tensor) block() (n int) {
}
func (t Tensor) blockSize() uint64 {
return (TensorType)(t.Kind).BlockSize()
}
func (t TensorType) BlockSize() uint64 {
switch t {
switch t.Kind {
case
0, // F32
1, // F16
@@ -259,77 +252,73 @@ func (t TensorType) BlockSize() uint64 {
}
func (t Tensor) typeSize() uint64 {
return TensorType(t.Kind).TypeSize()
}
blockSize := t.blockSize()
func (t TensorType) TypeSize() uint64 {
blockSize := t.BlockSize()
switch t {
case TensorTypeF32:
switch t.Kind {
case 0: // FP32
return 4
case TensorTypeF16:
case 1: // FP16
return 2
case TensorTypeQ4_0:
case 2: // Q4_0
return 2 + blockSize/2
case TensorTypeQ4_1:
case 3: // Q4_1
return 2 + 2 + blockSize/2
case TensorTypeQ5_0:
case 6: // Q5_0
return 2 + 4 + blockSize/2
case TensorTypeQ5_1:
case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2
case TensorTypeQ8_0:
case 8: // Q8_0
return 2 + blockSize
case TensorTypeQ8_1:
case 9: // Q8_1
return 2 + 2 + blockSize
case TensorTypeQ2_K:
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case TensorTypeQ3_K:
case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2
case TensorTypeQ4_K:
case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2
case TensorTypeQ5_K:
case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case TensorTypeQ6_K:
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
case TensorTypeQ8_K:
case 15: // Q8_K
return 4 + blockSize + 2*blockSize/16
case tensorTypeIQ2_XXS:
case 16: // IQ2_XXS
return 2 + 2*blockSize/8
case tensorTypeIQ2_XS:
case 17: // IQ2_XS
return 2 + 2*blockSize/8 + blockSize/32
case tensorTypeIQ3_XXS:
case 18: // IQ3_XXS
return 2 + blockSize/4 + blockSize/8
case tensorTypeIQ1_S:
case 19: // IQ1_S
return 2 + blockSize/8 + blockSize/16
case tensorTypeIQ4_NL:
case 20: // IQ4_NL
return 2 + blockSize/2
case tensorTypeIQ3_S:
case 21: // IQ3_S
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
case tensorTypeIQ2_S:
case 22: // IQ2_S
return 2 + blockSize/4 + blockSize/16
case tensorTypeIQ4_XS:
case 23: // IQ4_XS
return 2 + 2 + blockSize/2 + blockSize/64
case TensorTypeI8:
case 24: // I8
return 1
case TensorTypeI16:
case 25: // I16
return 2
case TensorTypeI32:
case 26: // I32
return 4
case TensorTypeI64:
case 27: // I64
return 8
case TensorTypeF64:
case 28: // F64
return 8
case tensorTypeIQ1_M:
case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32
case TensorTypeBF16:
case 30: // BF16
return 2
default:
return 0
}
}
func (t Tensor) Elements() uint64 {
func (t Tensor) parameters() uint64 {
var count uint64 = 1
for _, n := range t.Shape {
count *= n
@@ -338,11 +327,11 @@ func (t Tensor) Elements() uint64 {
}
func (t Tensor) Size() uint64 {
return t.Elements() * t.typeSize() / t.blockSize()
return t.parameters() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return TensorType(t.Kind).String()
return fileType(t.Kind).String()
}
type container interface {
@@ -387,12 +376,12 @@ func DetectContentType(b []byte) string {
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, err
return nil, 0, err
}
var c container
@@ -402,25 +391,24 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default:
return nil, errors.New("invalid file magic")
return nil, 0, errors.New("invalid file magic")
}
model, err := c.Decode(rs)
if err != nil {
return nil, err
return nil, 0, err
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
return nil, 0, err
}
// final model type
return &GGML{
container: c,
model: model,
Length: offset,
}, nil
}, offset, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
@@ -492,7 +480,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
var ropeFreqsCount uint64
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
ropeFreqsCount = ropeFreqsWeights.Elements()
ropeFreqsCount = ropeFreqsWeights.parameters()
}
}
@@ -652,20 +640,6 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph

View File

@@ -9,12 +9,8 @@ import (
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"golang.org/x/sync/errgroup"
)
type containerGGUF struct {
@@ -229,7 +225,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
}
llm.tensors = append(llm.tensors, &tensor)
llm.parameters += tensor.Elements()
llm.parameters += tensor.parameters()
}
// patch KV with parameter count
@@ -492,38 +488,25 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return err
}
if t == ggufTypeString {
for _, e := range any(s).([]string) {
if err := binary.Write(w, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
return nil
}
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
alignment := kv.Uint("general.alignment", 32)
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
@@ -531,12 +514,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil {
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
return err
}
}
slices.SortStableFunc(ts, func(a, b *Tensor) int {
slices.SortStableFunc(ts, func(a, b Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
@@ -547,34 +530,21 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
})
var s uint64
for i := range ts {
ts[i].Offset = s
if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
for _, t := range ts {
t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
if err := ggufWriteTensorInfo(ws, t); err != nil {
return err
}
s += ts[i].Size()
s += uint64(ggufPadding(int64(s), int64(alignment)))
s += t.Size()
}
offset, err := f.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
offset += ggufPadding(offset, int64(alignment))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
for _, t := range ts {
t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)
if err := ggufWriteTensor(ws, t, int64(alignment)); err != nil {
return err
})
}
}
return g.Wait()
return nil
}
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
@@ -589,10 +559,8 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
var err error
switch v := v.(type) {
case uint32, FileType:
case uint32:
err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:
err = writeGGUF(ws, ggufTypeUint64, v)
case float32:
err = writeGGUF(ws, ggufTypeFloat32, v)
case bool:
@@ -601,20 +569,32 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFString(ws, v)
case []int32:
err = writeGGUFArray(ws, ggufTypeInt32, v)
case *array[int32]:
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v)
case *array[uint32]:
err = writeGGUFArray(ws, ggufTypeUint32, v.values)
case []float32:
err = writeGGUFArray(ws, ggufTypeFloat32, v)
case *array[float32]:
err = writeGGUFArray(ws, ggufTypeFloat32, v.values)
case []string:
err = writeGGUFArray(ws, ggufTypeString, v)
case *array[string]:
err = writeGGUFArray(ws, ggufTypeString, v.values)
if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
for _, e := range v {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
default:
return fmt.Errorf("improper type for '%s'", k)
}
@@ -622,7 +602,7 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
return err
}
func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
return err
@@ -649,6 +629,20 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
return binary.Write(ws, binary.LittleEndian, t.Offset)
}
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
return err
}
_, err = t.WriteTo(ws)
return err
}
func ggufPadding(offset, align int64) int64 {
return (align - offset%align) % align
}

View File

@@ -1,63 +0,0 @@
package ggml
import (
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWriteGGUF(t *testing.T) {
w, err := os.CreateTemp(t.TempDir(), "*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, []*Tensor{
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
}); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(ff.KV(), KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(36),
}); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(ff.Tensors(), Tensors{
Offset: 336,
items: []*Tensor{
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
},
}, cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
}

View File

@@ -1,31 +1,26 @@
package ggml
import (
"fmt"
"log/slog"
"strings"
)
import "fmt"
// FileType is the Go equivalent to llama_ftype used for gguf file typing
type FileType uint32
type fileType uint32
const (
FileTypeF32 FileType = iota
FileTypeF16
fileTypeF32 fileType = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16 // unused by GGML
fileTypeQ4_2 // unused by GGML
fileTypeQ4_3 // unused by GGML
FileTypeQ8_0
fileTypeQ4_1_F16
fileTypeQ4_2 // unused
fileTypeQ4_3 // unused
fileTypeQ8_0
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
FileTypeQ4_K_S
FileTypeQ4_K_M
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
@@ -42,62 +37,93 @@ const (
fileTypeIQ2_M
fileTypeIQ4_XS
fileTypeIQ1_M
FileTypeBF16
fileTypeQ4_0_4_4 // unused by GGML
fileTypeQ4_0_4_8 // unused by GGML
fileTypeQ4_0_8_8 // unused by GGML
fileTypeTQ1_0
fileTypeTQ2_0
fileTypeBF16
FileTypeUnknown = 1024
fileTypeUnknown
)
// ParseFileType parses the provided GGUF file type
// Only Ollama supported types are considered valid
func ParseFileType(s string) (FileType, error) {
func ParseFileType(s string) (fileType, error) {
switch s {
case "F32":
return FileTypeF32, nil
return fileTypeF32, nil
case "F16":
return FileTypeF16, nil
return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0":
return FileTypeQ8_0, nil
return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S":
return FileTypeQ4_K_S, nil
case "Q4_K_M", "Q4_K":
return FileTypeQ4_K_M, nil
return fileTypeQ4_K_S, nil
case "Q4_K_M":
return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "IQ3_XS":
return fileTypeIQ3_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
case "IQ1_S":
return fileTypeIQ1_S, nil
case "IQ4_NL":
return fileTypeIQ4_NL, nil
case "IQ3_S":
return fileTypeIQ3_S, nil
case "IQ3_M":
return fileTypeIQ3_M, nil
case "IQ2_S":
return fileTypeIQ2_S, nil
case "IQ2_M":
return fileTypeIQ2_M, nil
case "IQ4_XS":
return fileTypeIQ4_XS, nil
case "IQ1_M":
return fileTypeIQ1_M, nil
case "BF16":
return FileTypeBF16, nil
return fileTypeBF16, nil
default:
supportedFileTypes := []FileType{
FileTypeF32,
FileTypeF16,
FileTypeQ4_K_S,
FileTypeQ4_K_M,
FileTypeQ8_0,
// fsggml.FileTypeBF16, // TODO
}
strs := make([]string, len(supportedFileTypes))
for i := range supportedFileTypes {
strs[i] = supportedFileTypes[i].String()
}
return FileTypeUnknown, fmt.Errorf("unsupported quantization type %s - supported types are %s", s, strings.Join(strs, ", "))
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}
}
func (t FileType) String() string {
// Note: this routine will return a broader set of file types for existing models
func (t fileType) String() string {
switch t {
case FileTypeF32:
case fileTypeF32:
return "F32"
case FileTypeF16:
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case FileTypeQ8_0:
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
@@ -111,9 +137,9 @@ func (t FileType) String() string {
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case FileTypeQ4_K_S:
case fileTypeQ4_K_S:
return "Q4_K_S"
case FileTypeQ4_K_M:
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
@@ -121,198 +147,39 @@ func (t FileType) String() string {
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case FileTypeBF16:
case fileTypeIQ3_XS:
return "IQ3_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
case fileTypeIQ1_S:
return "IQ1_S"
case fileTypeIQ4_NL:
return "IQ4_NL"
case fileTypeIQ3_S:
return "IQ3_S"
case fileTypeIQ3_M:
return "IQ3_M"
case fileTypeIQ2_S:
return "IQ2_S"
case fileTypeIQ4_XS:
return "IQ4_XS"
case fileTypeIQ2_M:
return "IQ2_M"
case fileTypeIQ1_M:
return "IQ1_M"
case fileTypeBF16:
return "BF16"
default:
return "unknown"
}
}
func (t FileType) Value() uint32 {
func (t fileType) Value() uint32 {
return uint32(t)
}
func (ftype FileType) ToTensorType() TensorType {
switch ftype {
case FileTypeF32:
return TensorTypeF32
case FileTypeF16:
return TensorTypeF16
case fileTypeQ4_0:
return TensorTypeQ4_0
case fileTypeQ4_1:
return TensorTypeQ4_1
case FileTypeQ8_0:
return TensorTypeQ8_0
case fileTypeQ5_0:
return TensorTypeQ5_0
case fileTypeQ5_1:
return TensorTypeQ5_1
case fileTypeQ2_K:
return TensorTypeQ2_K
case fileTypeQ3_K_S:
return TensorTypeQ3_K
case fileTypeQ3_K_M:
return TensorTypeQ3_K
case fileTypeQ3_K_L:
return TensorTypeQ3_K
case FileTypeQ4_K_S:
return TensorTypeQ4_K
case FileTypeQ4_K_M:
return TensorTypeQ4_K
case fileTypeQ5_K_S:
return TensorTypeQ5_K
case fileTypeQ5_K_M:
return TensorTypeQ5_K
case fileTypeQ6_K:
return TensorTypeQ6_K
case fileTypeQ2_K_S:
return TensorTypeQ2_K
case FileTypeBF16:
return TensorTypeBF16
default:
slog.Warn("unsupported file type", "type", ftype)
return 0 // F32
}
}
// TensorType is equivalent to ggml_type for individual tensor types
// Note: these are not the same as FileType
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
tensorTypeQ4_2 // unused by GGML
tensorTypeQ4_3 // unused by GGML
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
tensorTypeIQ2_XXS // not supported by ollama
tensorTypeIQ2_XS // not supported by ollama
tensorTypeIQ3_XXS // not supported by ollama
tensorTypeIQ1_S // not supported by ollama
tensorTypeIQ4_NL // not supported by ollama
tensorTypeIQ3_S // not supported by ollama
tensorTypeIQ2_S // not supported by ollama
tensorTypeIQ4_XS // not supported by ollama
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
tensorTypeIQ1_M // not supported by ollama
TensorTypeBF16
tensorTypeQ4_0_4_4 // unused by GGML
tensorTypeQ4_0_4_8 // unused by GGML
tensorTypeQ4_0_8_8 // unused by GGML
tensorTypeTQ1_0 // not supported by ollama
tensorTypeTQ2_0 // not supported by ollama
tensorTypeIQ4_NL_4_4 // unused by GGML
tensorTypeIQ4_NL_4_8 // unused by GGML
tensorTypeIQ4_NL_8_8 // unused by GGML
)
// ParseFileType parses the provided GGUF file type
// Only Ollama supported types are considered valid
func ParseTensorType(s string) (TensorType, error) {
switch s {
case "F32":
return TensorTypeF32, nil
case "F16":
return TensorTypeF16, nil
case "Q4_0":
return TensorTypeQ4_0, nil
case "Q4_1":
return TensorTypeQ4_1, nil
case "Q5_0":
return TensorTypeQ5_0, nil
case "Q5_1":
return TensorTypeQ5_1, nil
case "Q8_0":
return TensorTypeQ8_0, nil
case "Q8_1":
return TensorTypeQ8_1, nil
case "Q2_K":
return TensorTypeQ2_K, nil
case "Q3_K":
return TensorTypeQ3_K, nil
case "Q4_K":
return TensorTypeQ4_K, nil
case "Q5_K":
return TensorTypeQ5_K, nil
case "Q6_K":
return TensorTypeQ6_K, nil
case "Q8_K":
return TensorTypeQ8_K, nil
case "F64":
return TensorTypeF64, nil
case "BF16":
return TensorTypeBF16, nil
default:
return 0, fmt.Errorf("unsupported quantization type %s", s)
}
}
func (t TensorType) IsQuantized() bool {
switch t {
case TensorTypeF32, TensorTypeF16, TensorTypeBF16:
return false
default:
return true
}
}
func (t TensorType) RowSize(ne uint64) uint64 {
return t.TypeSize() * ne / t.BlockSize()
}
func (t TensorType) String() string {
switch t {
case TensorTypeF32:
return "F32"
case TensorTypeF16:
return "F16"
case TensorTypeQ4_0:
return "Q4_0"
case TensorTypeQ4_1:
return "Q4_1"
case TensorTypeQ5_0:
return "Q5_0"
case TensorTypeQ5_1:
return "Q5_1"
case TensorTypeQ8_0:
return "Q8_0"
case TensorTypeQ8_1:
return "Q8_1"
case TensorTypeQ2_K:
return "Q2_K"
case TensorTypeQ3_K:
return "Q3_K"
case TensorTypeQ4_K:
return "Q4_K"
case TensorTypeQ5_K:
return "Q5_K"
case TensorTypeQ6_K:
return "Q6_K"
case TensorTypeQ8_K:
return "Q8_K"
case TensorTypeF64:
return "F64"
case TensorTypeBF16:
return "BF16"
default:
return "unknown"
}
}

View File

@@ -1,350 +0,0 @@
package gguf
import (
"bytes"
"cmp"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
"os"
"slices"
"strings"
)
const (
typeUint8 uint32 = iota
typeInt8
typeUint16
typeInt16
typeUint32
typeInt32
typeFloat32
typeBool
typeString
typeArray
typeUint64
typeInt64
typeFloat64
)
var ErrUnsupported = errors.New("unsupported")
type File struct {
Magic [4]byte
Version uint32
keyValues *lazy[KeyValue]
tensors *lazy[TensorInfo]
offset int64
file *os.File
reader *readSeeker
bts []byte
}
func Open(path string) (f *File, err error) {
f = &File{bts: make([]byte, 4096)}
f.file, err = os.Open(path)
if err != nil {
return nil, err
}
f.reader = newReadSeeker(f.file, 32<<10)
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
return nil, err
}
if bytes.Equal(f.Magic[:], []byte("gguf")) {
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
}
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
return nil, err
}
if f.Version != 3 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}
f.tensors, err = newLazy(f, f.readTensor)
if err != nil {
return nil, err
}
f.tensors.doneFunc = func() error {
offset, err := f.reader.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
f.offset = offset + (alignment-offset%alignment)%alignment
return nil
}
f.keyValues, err = newLazy(f, f.readKeyValue)
if err != nil {
return nil, err
}
return f, nil
}
func (f *File) readTensor() (TensorInfo, error) {
name, err := readString(f)
if err != nil {
return TensorInfo{}, err
}
dims, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
shape := make([]uint64, dims)
for i := range dims {
shape[i], err = read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
}
type_, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
offset, err := read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
return TensorInfo{
Name: name,
Offset: offset,
Shape: shape,
Type: TensorType(type_),
}, nil
}
func (f *File) readKeyValue() (KeyValue, error) {
key, err := readString(f)
if err != nil {
return KeyValue{}, err
}
t, err := read[uint32](f)
if err != nil {
return KeyValue{}, err
}
value, err := func() (any, error) {
switch t {
case typeUint8:
return read[uint8](f)
case typeInt8:
return read[int8](f)
case typeUint16:
return read[uint16](f)
case typeInt16:
return read[int16](f)
case typeUint32:
return read[uint32](f)
case typeInt32:
return read[int32](f)
case typeUint64:
return read[uint64](f)
case typeInt64:
return read[int64](f)
case typeFloat32:
return read[float32](f)
case typeFloat64:
return read[float64](f)
case typeBool:
return read[bool](f)
case typeString:
return readString(f)
case typeArray:
return readArray(f)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}()
if err != nil {
return KeyValue{}, err
}
return KeyValue{
Key: key,
Value: Value{value},
}, nil
}
func read[T any](f *File) (t T, err error) {
err = binary.Read(f.reader, binary.LittleEndian, &t)
return t, err
}
func readString(f *File) (string, error) {
n, err := read[uint64](f)
if err != nil {
return "", err
}
if int(n) > len(f.bts) {
f.bts = make([]byte, n)
}
bts := f.bts[:n]
if _, err := io.ReadFull(f.reader, bts); err != nil {
return "", err
}
defer clear(bts)
return string(bts), nil
}
func readArray(f *File) (any, error) {
t, err := read[uint32](f)
if err != nil {
return nil, err
}
n, err := read[uint64](f)
if err != nil {
return nil, err
}
switch t {
case typeUint8:
return readArrayData[uint8](f, n)
case typeInt8:
return readArrayData[int8](f, n)
case typeUint16:
return readArrayData[uint16](f, n)
case typeInt16:
return readArrayData[int16](f, n)
case typeUint32:
return readArrayData[uint32](f, n)
case typeInt32:
return readArrayData[int32](f, n)
case typeUint64:
return readArrayData[uint64](f, n)
case typeInt64:
return readArrayData[int64](f, n)
case typeFloat32:
return readArrayData[float32](f, n)
case typeFloat64:
return readArrayData[float64](f, n)
case typeBool:
return readArrayData[bool](f, n)
case typeString:
return readArrayString(f, n)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
s = make([]T, n)
for i := range n {
e, err := read[T](f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func readArrayString(f *File, n uint64) (s []string, err error) {
s = make([]string, n)
for i := range n {
e, err := readString(f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func (f *File) Close() error {
f.keyValues.stop()
f.tensors.stop()
return f.file.Close()
}
func (f *File) KeyValue(key string) KeyValue {
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
key = f.KeyValue("general.architecture").String() + "." + key
}
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
return kv.Key == key
}); index >= 0 {
return f.keyValues.values[index]
}
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
if keyValue.Key == key {
return keyValue
}
}
return KeyValue{}
}
func (f *File) NumKeyValues() int {
return int(f.keyValues.count)
}
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
return f.keyValues.All()
}
func (f *File) TensorInfo(name string) TensorInfo {
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
return t.Name == name
}); index >= 0 {
return f.tensors.values[index]
}
// fast-forward through key values if we haven't already
_ = f.keyValues.rest()
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
if tensor.Name == name {
return tensor
}
}
return TensorInfo{}
}
func (f *File) NumTensors() int {
return int(f.tensors.count)
}
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
// fast forward through key values if we haven't already
f.keyValues.rest()
return f.tensors.All()
}
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
t := f.TensorInfo(name)
if t.NumBytes() == 0 {
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
}
// fast forward through tensor info if we haven't already
_ = f.tensors.rest()
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
}

View File

@@ -1,320 +0,0 @@
package gguf
import (
"encoding/binary"
"fmt"
"os"
"path/filepath"
"slices"
"testing"
)
func TestRead(t *testing.T) {
// Setup
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test.gguf")
if err := createTestGGUFFile(tempFile, map[string]any{
"general.architecture": "llama",
"general.alignment": int64(32),
}, []testTensorInfo{
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
}); err != nil {
t.Fatal(err)
}
f, err := Open(tempFile)
if err != nil {
t.Fatal(err)
}
defer f.Close()
// Test
if got := f.NumKeyValues(); got != 2 {
t.Errorf("NumKeyValues() = %d, want %d", got, 2)
}
if got := f.NumTensors(); got != 2 {
t.Errorf("NumTensors() = %d, want %d", got, 2)
}
archKV := f.KeyValue("general.architecture")
if archKV.Key == "" {
t.Error("KeyValue(\"general.architecture\") not found")
}
if got := archKV.String(); got != "llama" {
t.Errorf("KeyValue(\"general.architecture\").String() = %q, want %q", got, "llama")
}
alignKV := f.KeyValue("general.alignment")
if alignKV.Key == "" {
t.Error("KeyValue(\"general.alignment\") not found")
}
if got := alignKV.Int(); got != 32 {
t.Errorf("KeyValue(\"general.alignment\").Int() = %d, want %d", got, 32)
}
expectedTensorNames := []string{"token_embd.weight", "output.weight"}
var gotTensorNames []string
for _, tensor := range f.TensorInfos() {
gotTensorNames = append(gotTensorNames, tensor.Name)
}
if !slices.Equal(gotTensorNames, expectedTensorNames) {
t.Errorf("tensor names = %v, want %v", gotTensorNames, expectedTensorNames)
}
tokenTensor := f.TensorInfo("token_embd.weight")
if tokenTensor.Name != "token_embd.weight" {
t.Error("TensorInfo(\"token_embd.weight\") not found")
}
if len(tokenTensor.Shape) == 0 {
t.Error("TensorInfo(\"token_embd.weight\") has empty shape")
}
outputTensor := f.TensorInfo("output.weight")
if outputTensor.Name != "output.weight" {
t.Error("TensorInfo(\"output.weight\") not found")
}
if len(outputTensor.Shape) == 0 {
t.Error("TensorInfo(\"output.weight\") has empty shape")
}
var gotKeyCount int
for _, kv := range f.KeyValues() {
gotKeyCount++
if kv.Key == "" {
t.Error("found key value with empty key")
}
}
if gotKeyCount != 2 {
t.Errorf("iterated key count = %d, want %d", gotKeyCount, 2)
}
tensorInfo, reader, err := f.TensorReader("token_embd.weight")
if err != nil {
t.Errorf("TensorReader(\"token_embd.weight\") error: %v", err)
}
if tensorInfo.Name != "token_embd.weight" {
t.Errorf("TensorReader returned wrong tensor: %q", tensorInfo.Name)
}
if reader == nil {
t.Error("TensorReader returned nil reader")
}
}
func BenchmarkRead(b *testing.B) {
// Create benchmark test file
tempDir := b.TempDir()
tempFile := filepath.Join(tempDir, "benchmark.gguf")
if err := createTestGGUFFile(tempFile, map[string]any{
"general.architecture": "llama",
"general.alignment": int64(32),
}, []testTensorInfo{
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
}); err != nil {
b.Fatal(err)
}
// Get file info for reporting
info, err := os.Stat(tempFile)
if err != nil {
b.Fatal(err)
}
b.Logf("Benchmark file size: %d bytes", info.Size())
b.ReportAllocs()
for b.Loop() {
f, err := Open(tempFile)
if err != nil {
b.Fatal(err)
}
// Access some data to ensure it's actually being read
_ = f.KeyValue("general.architecture").String()
_ = f.KeyValue("general.alignment").Int()
_ = f.NumTensors()
_ = f.NumKeyValues()
// Iterate through some tensors
count := 0
for _, tensor := range f.TensorInfos() {
_ = tensor.Name
count++
if count >= 2 {
break
}
}
f.Close()
}
}
// Helper function to create test GGUF files
func createTestGGUFFile(path string, keyValues map[string]any, tensors []testTensorInfo) error {
file, err := os.Create(path)
if err != nil {
return err
}
defer file.Close()
// Write GGUF magic
if _, err := file.Write([]byte("GGUF")); err != nil {
return err
}
// Write version
if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil {
return err
}
// Write tensor count
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensors))); err != nil {
return err
}
// Write metadata count
if err := binary.Write(file, binary.LittleEndian, uint64(len(keyValues))); err != nil {
return err
}
// Write metadata
for key, value := range keyValues {
if err := writeKeyValue(file, key, value); err != nil {
return err
}
}
// Write tensor info
for _, tensor := range tensors {
if err := writeTensorInfo(file, tensor); err != nil {
return err
}
}
// Write some dummy tensor data
dummyData := make([]byte, 1024)
file.Write(dummyData)
return nil
}
type testTensorInfo struct {
Name string
Shape []uint64
Type uint32
}
func writeKeyValue(file *os.File, key string, value any) error {
// Write key length and key
if err := binary.Write(file, binary.LittleEndian, uint64(len(key))); err != nil {
return err
}
if _, err := file.Write([]byte(key)); err != nil {
return err
}
// Write value based on type
switch v := value.(type) {
case string:
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
_, err := file.Write([]byte(v))
return err
case int64:
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
return err
}
return binary.Write(file, binary.LittleEndian, v)
case bool:
if err := binary.Write(file, binary.LittleEndian, typeBool); err != nil {
return err
}
return binary.Write(file, binary.LittleEndian, v)
case float64:
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
return err
}
return binary.Write(file, binary.LittleEndian, v)
case []string:
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
for _, s := range v {
if err := binary.Write(file, binary.LittleEndian, uint64(len(s))); err != nil {
return err
}
if _, err := file.Write([]byte(s)); err != nil {
return err
}
}
return nil
case []int64:
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
for _, i := range v {
if err := binary.Write(file, binary.LittleEndian, i); err != nil {
return err
}
}
return nil
case []float64:
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
return err
}
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
for _, f := range v {
if err := binary.Write(file, binary.LittleEndian, f); err != nil {
return err
}
}
return nil
default:
return fmt.Errorf("unsupported value type: %T", value)
}
}
func writeTensorInfo(file *os.File, tensor testTensorInfo) error {
// Write tensor name
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensor.Name))); err != nil {
return err
}
if _, err := file.Write([]byte(tensor.Name)); err != nil {
return err
}
// Write dimensions
if err := binary.Write(file, binary.LittleEndian, uint32(len(tensor.Shape))); err != nil {
return err
}
for _, dim := range tensor.Shape {
if err := binary.Write(file, binary.LittleEndian, dim); err != nil {
return err
}
}
// Write type
if err := binary.Write(file, binary.LittleEndian, tensor.Type); err != nil {
return err
}
// Write offset (dummy value)
return binary.Write(file, binary.LittleEndian, uint64(0))
}

View File

@@ -1,102 +0,0 @@
package gguf
import (
"reflect"
"slices"
)
type KeyValue struct {
Key string
Value
}
type Value struct {
value any
}
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
vv := reflect.ValueOf(v.value)
if slices.Contains(kinds, vv.Kind()) {
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
}
return
}
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
switch vv := reflect.ValueOf(v.value); vv.Kind() {
case reflect.Slice:
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
ts = make([]T, vv.Len())
for i := range vv.Len() {
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
}
}
}
return
}
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
func (v Value) Int() int64 {
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
func (v Value) Ints() (i64s []int64) {
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
func (v Value) Uint() uint64 {
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
func (v Value) Uints() (u64s []uint64) {
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Float returns Value as a float. If it is not a float, it returns 0.
func (v Value) Float() float64 {
return value[float64](v, reflect.Float32, reflect.Float64)
}
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
func (v Value) Floats() (f64s []float64) {
return values[float64](v, reflect.Float32, reflect.Float64)
}
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
func (v Value) Bool() bool {
return value[bool](v, reflect.Bool)
}
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
func (v Value) Bools() (bools []bool) {
return values[bool](v, reflect.Bool)
}
// String returns Value as a string. If it is not a string, it returns an empty string.
func (v Value) String() string {
return value[string](v, reflect.String)
}
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
func (v Value) Strings() (strings []string) {
return values[string](v, reflect.String)
}
// IsNil checks if the Value is nil. It returns true if the value is nil or if it is a nil pointer, interface, slice, map, channel, or function.
func (v Value) IsNil() bool {
if v.value == nil {
return true
}
// Check for nil pointers, interfaces, slices, maps, channels, and functions
rv := reflect.ValueOf(v.value)
switch rv.Kind() {
case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
return rv.IsNil()
}
return false
}

View File

@@ -1,208 +0,0 @@
package gguf
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
for key, value := range values {
if key == name {
matched = value
} else {
unmatched = append(unmatched, value...)
}
}
return
}
func TestValue(t *testing.T) {
values := map[string][]any{
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
"float64": {float32(42), float64(42)},
"string": {"42", "hello"},
"bool": {true, false},
}
t.Run("int64", func(t *testing.T) {
matched, unmatched := split("int64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 42 {
t.Errorf("expected 42, got %d", i64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 0 {
t.Errorf("expected 42, got %d", i64)
}
}
})
t.Run("uint64", func(t *testing.T) {
matched, unmatched := split("uint64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 42 {
t.Errorf("expected 42, got %d", u64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 0 {
t.Errorf("expected 42, got %d", u64)
}
}
})
t.Run("float64", func(t *testing.T) {
matched, unmatched := split("float64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 42 {
t.Errorf("expected 42, got %f", f64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 0 {
t.Errorf("expected 42, got %f", f64)
}
}
})
t.Run("string", func(t *testing.T) {
matched, unmatched := split("string", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != v {
t.Errorf("expected 42, got %s", s)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != "" {
t.Errorf("expected 42, got %s", s)
}
}
})
t.Run("bool", func(t *testing.T) {
matched, unmatched := split("bool", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != v {
t.Errorf("expected true, got %v", b)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != false {
t.Errorf("expected false, got %v", b)
}
}
})
}
func TestValues(t *testing.T) {
values := map[string][]any{
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
"float64s": {[]float32{42}, []float64{42}},
"strings": {[]string{"42"}, []string{"hello"}},
"bools": {[]bool{true}, []bool{false}},
}
t.Run("int64s", func(t *testing.T) {
matched, unmatched := split("int64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64s := kv.Ints(); i64s != nil {
t.Errorf("expected nil, got %v", i64s)
}
}
})
t.Run("uint64s", func(t *testing.T) {
matched, unmatched := split("uint64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64s := kv.Uints(); u64s != nil {
t.Errorf("expected nil, got %v", u64s)
}
}
})
t.Run("float64s", func(t *testing.T) {
matched, unmatched := split("float64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64s := kv.Floats(); f64s != nil {
t.Errorf("expected nil, got %v", f64s)
}
}
})
t.Run("strings", func(t *testing.T) {
matched, unmatched := split("strings", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.Strings(); s != nil {
t.Errorf("expected nil, got %v", s)
}
}
})
t.Run("bools", func(t *testing.T) {
matched, unmatched := split("bools", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bools(); b != nil {
t.Errorf("expected nil, got %v", b)
}
}
})
}

View File

@@ -1,88 +0,0 @@
package gguf
import (
"encoding/binary"
"iter"
"log/slog"
)
type lazy[T any] struct {
count uint64
next func() (T, bool)
stop func()
values []T
doneFunc func() error
}
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
it := lazy[T]{}
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
return nil, err
}
it.values = make([]T, 0)
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
for i := range it.count {
t, err := fn()
if err != nil {
slog.Error("error reading tensor", "index", i, "error", err)
return
}
it.values = append(it.values, t)
if !yield(t) {
break
}
}
if it.doneFunc != nil {
it.doneFunc()
}
})
return &it, nil
}
func (g *lazy[T]) Values() iter.Seq[T] {
return func(yield func(T) bool) {
for _, v := range g.All() {
if !yield(v) {
break
}
}
}
}
func (g *lazy[T]) All() iter.Seq2[int, T] {
return func(yield func(int, T) bool) {
for i := range int(g.count) {
if i < len(g.values) {
if !yield(i, g.values[i]) {
break
}
} else {
t, ok := g.next()
if !ok {
break
}
if !yield(i, t) {
break
}
}
}
}
}
func (g *lazy[T]) rest() (collected bool) {
for {
_, ok := g.next()
collected = collected || ok
if !ok {
break
}
}
return collected
}

View File

@@ -1,34 +0,0 @@
package gguf
import (
"bufio"
"io"
)
type readSeeker struct {
rs io.ReadSeeker
br *bufio.Reader
}
func newReadSeeker(rs io.ReadSeeker, size int) *readSeeker {
return &readSeeker{
rs: rs,
br: bufio.NewReaderSize(rs, size),
}
}
func (b *readSeeker) Read(p []byte) (int, error) {
return b.br.Read(p)
}
func (b *readSeeker) Seek(offset int64, whence int) (int64, error) {
if whence == io.SeekCurrent {
offset -= int64(b.br.Buffered())
}
n, err := b.rs.Seek(offset, whence)
if err != nil {
return 0, err
}
b.br.Reset(b.rs)
return n, nil
}

View File

@@ -1,284 +0,0 @@
package gguf
import (
"log/slog"
"strings"
)
type TensorInfo struct {
Name string
Offset uint64
Shape []uint64
Type TensorType
}
func (t TensorInfo) NumValues() int64 {
var numItems int64 = 1
for _, dim := range t.Shape {
numItems *= int64(dim)
}
return numItems
}
// NumBytes returns the number of bytes in the tensor.
func (t TensorInfo) NumBytes() int64 {
return int64(float64(t.NumValues()) * t.Type.NumBytes())
}
func (t TensorInfo) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", t.Name),
slog.Int64("offset", int64(t.Offset)),
slog.Any("shape", t.Shape),
slog.Int64("num_values", t.NumValues()),
slog.Int64("num_bytes", t.NumBytes()),
slog.Any("type", t.Type),
)
}
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
// unexported // unused in gguf
tensorTypeQ4_2
tensorTypeQ4_3
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
// unexported // unquantizable by ollama
tensorTypeIQ2_XXS
tensorTypeIQ2_XS
tensorTypeIQ3_XXS
tensorTypeIQ1_S
tensorTypeIQ4_NL
tensorTypeIQ3_S
tensorTypeIQ2_S
tensorTypeIQ4_XS
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
// unexported // unquantizable by ollama
tensorTypeIQ1_M
TensorTypeBF16
// unexported // unused in gguf
tensorTypeQ4_0_4_4
tensorTypeQ4_0_4_8
tensorTypeQ4_0_8_8
// unexported // unquantizable by ollama
tensorTypeTQ1_0
tensorTypeTQ2_0
// unexported // unused in gguf
tensorTypeIQ4_NL_4_4
tensorTypeIQ4_NL_4_8
tensorTypeIQ4_NL_8_8
)
func (t TensorType) NumBytes() float64 {
return float64(t.typeSize()) / float64(t.blockSize())
}
func (t TensorType) typeSize() int64 {
switch t {
case TensorTypeF32:
return 4
case TensorTypeF16:
return 2
case TensorTypeQ4_0:
return 2 + t.blockSize()/2
case TensorTypeQ4_1:
return 2 + 2 + t.blockSize()/2
case TensorTypeQ5_0:
return 2 + 4 + t.blockSize()/2
case TensorTypeQ5_1:
return 2 + 2 + 4 + t.blockSize()/2
case TensorTypeQ8_0:
return 2 + t.blockSize()
case TensorTypeQ8_1:
return 2 + 2 + t.blockSize()
case TensorTypeQ2_K:
return t.blockSize()/16 + t.blockSize()/4 + 2 + 2
case TensorTypeQ3_K:
return t.blockSize()/8 + t.blockSize()/4 + 12 + 2
case TensorTypeQ4_K:
return 2 + 2 + 12 + t.blockSize()/2
case TensorTypeQ5_K:
return 2 + 2 + 12 + t.blockSize()/8 + t.blockSize()/2
case TensorTypeQ6_K:
return t.blockSize()/2 + t.blockSize()/4 + t.blockSize()/16 + 2
case TensorTypeQ8_K:
return 4 + t.blockSize() + 2*t.blockSize()/16
case tensorTypeIQ2_XXS:
return 2 + 2*t.blockSize()/8
case tensorTypeIQ2_XS:
return 2 + 2*t.blockSize()/8 + t.blockSize()/32
case tensorTypeIQ3_XXS:
return 2 + t.blockSize()/4 + t.blockSize()/8
case tensorTypeIQ1_S:
return 2 + t.blockSize()/8 + t.blockSize()/16
case tensorTypeIQ4_NL:
return 2 + t.blockSize()/2
case tensorTypeIQ3_S:
return 2 + t.blockSize()/4 + t.blockSize()/8 + t.blockSize()/32 + 4
case tensorTypeIQ2_S:
return 2 + t.blockSize()/4 + t.blockSize()/16
case tensorTypeIQ4_XS:
return 2 + 2 + t.blockSize()/2 + t.blockSize()/64
case TensorTypeI8:
return 1
case TensorTypeI16:
return 2
case TensorTypeI32:
return 4
case TensorTypeI64:
return 8
case TensorTypeF64:
return 8
case tensorTypeIQ1_M:
return t.blockSize()/8 + t.blockSize()/16 + t.blockSize()/32
case TensorTypeBF16:
return 2
default:
return 0
}
}
func (t TensorType) blockSize() int64 {
switch t {
case TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
return 1
case TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL:
return 32
default:
return 256
}
}
func (t TensorType) String() string {
switch t {
case TensorTypeF32:
return "f32"
case TensorTypeF16:
return "f16"
case TensorTypeQ4_0:
return "q4_0"
case TensorTypeQ4_1:
return "q4_1"
case tensorTypeQ4_2:
return "q4_2"
case tensorTypeQ4_3:
return "q4_3"
case TensorTypeQ5_0:
return "q5_0"
case TensorTypeQ5_1:
return "q5_1"
case TensorTypeQ8_0:
return "q8_0"
case TensorTypeQ8_1:
return "q8_1"
case TensorTypeQ2_K:
return "q2_k"
case TensorTypeQ3_K:
return "q3_k"
case TensorTypeQ4_K:
return "q4_k"
case TensorTypeQ5_K:
return "q5_k"
case TensorTypeQ6_K:
return "q6_k"
case TensorTypeQ8_K:
return "q8_k"
case tensorTypeIQ2_XXS:
return "iq2_xxs"
case tensorTypeIQ2_XS:
return "iq2_xs"
case tensorTypeIQ3_XXS:
return "iq3_xxs"
case tensorTypeIQ1_S:
return "iq1_s"
case tensorTypeIQ4_NL:
return "iq4_nl"
case tensorTypeIQ3_S:
return "iq3_s"
case tensorTypeIQ2_S:
return "iq2_s"
case tensorTypeIQ4_XS:
return "iq4_xs"
case TensorTypeI8:
return "i8"
case TensorTypeI16:
return "i16"
case TensorTypeI32:
return "i32"
case TensorTypeI64:
return "i64"
case TensorTypeF64:
return "f64"
case tensorTypeIQ1_M:
return "iq1_m"
case TensorTypeBF16:
return "bf16"
case tensorTypeQ4_0_4_4:
return "q4_0_4_4"
case tensorTypeQ4_0_4_8:
return "q4_0_4_8"
case tensorTypeQ4_0_8_8:
return "q4_0_8_8"
case tensorTypeTQ1_0:
return "tq1_0"
case tensorTypeTQ2_0:
return "tq2_0"
case tensorTypeIQ4_NL_4_4:
return "iq4_nl_4_4"
case tensorTypeIQ4_NL_4_8:
return "iq4_nl_4_8"
case tensorTypeIQ4_NL_8_8:
return "iq4_nl_8_8"
default:
return "unknown"
}
}
func (t TensorType) LogValue() slog.Value {
return slog.GroupValue(
slog.Uint64("value", uint64(t)),
slog.String("name", strings.ToUpper(t.String())),
slog.Int64("size", t.typeSize()),
slog.Int64("block_size", t.blockSize()),
slog.Float64("num_bytes", t.NumBytes()),
)
}

12
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0
golang.org/x/sync v0.11.0
)
require (
@@ -70,12 +70,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0
golang.org/x/text v0.23.0
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

24
go.sum
View File

@@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -34,15 +34,13 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, client, t, req)
res, err := embeddingTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@@ -64,15 +62,13 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, client, t, req)
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@@ -102,15 +98,13 @@ func TestAllMiniLMEmbed(t *testing.T) {
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, client, t, req)
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@@ -150,8 +144,6 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
truncTrue, truncFalse := true, false
@@ -190,7 +182,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, client, t, req.Request)
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
@@ -206,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
@@ -218,7 +210,9 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
}
}
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
@@ -232,7 +226,9 @@ func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T,
return response, nil
}
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}

View File

@@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) {
}
testCases := []testCase{
{
model: "qwen2.5vl",
model: "llava:7b",
},
{
model: "llama3.2-vision",
@@ -60,7 +60,6 @@ func TestVisionModels(t *testing.T) {
}
func TestIntegrationSplitBatch(t *testing.T) {
skipUnderMinVRAM(t, 6)
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{

View File

@@ -48,6 +48,17 @@ var (
}
)
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {
return 8 * time.Minute, 10 * time.Minute
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
t.Skip("too little time")
return time.Duration(0), time.Duration(0)
}
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)

View File

@@ -1,130 +0,0 @@
//go:build integration && models
package integration
import (
"bytes"
"context"
"fmt"
"log/slog"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestQuantization(t *testing.T) {
sourceModels := []string{
"qwen2.5:0.5b-instruct-fp16",
}
quantizations := []string{
"Q8_0",
"Q4_K_S",
"Q4_K_M",
"Q4_K",
}
softTimeout, hardTimeout := getTimeouts(t)
started := time.Now()
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, base := range sourceModels {
if err := PullIfMissing(ctx, client, base); err != nil {
t.Fatalf("pull failed %s", err)
}
for _, quant := range quantizations {
newName := fmt.Sprintf("%s__%s", base, quant)
t.Run(newName, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
req := &api.CreateRequest{
Model: newName,
Quantization: quant,
From: base,
}
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
return nil
}
t.Logf("quantizing: %s -> %s", base, quant)
if err := client.Create(ctx, req, fn); err != nil {
t.Fatalf("create failed %s", err)
}
defer func() {
req := &api.DeleteRequest{
Model: newName,
}
t.Logf("deleting: %s -> %s", base, quant)
if err := client.Delete(ctx, req); err != nil {
t.Logf("failed to clean up %s: %s", req.Model, err)
}
}()
// Check metadata on the model
resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
if !strings.Contains(resp.Details.QuantizationLevel, quant) {
t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
}
stream := true
genReq := api.GenerateRequest{
Model: newName,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Stream: &stream,
}
t.Logf("verifying: %s -> %s", base, quant)
// Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String())
for _, resp := range anyResp {
if strings.Contains(fullResp, resp) {
atLeastOne = true
t.Log(fullResp)
reqCancel()
break
}
}
return nil
}
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(reqCtx, &genReq, genfn)
done <- 0
}()
select {
case <-done:
if genErr != nil && !atLeastOne {
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
t.Logf("passed")
})
}
}
}

View File

File diff suppressed because one or more lines are too long

View File

@@ -217,7 +217,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
@@ -359,14 +358,3 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
}
}
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {
return 8 * time.Minute, 10 * time.Minute
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
t.Skip("too little time")
return time.Duration(0), time.Duration(0)
}
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}

View File

@@ -30,11 +30,6 @@ type Causal struct {
// ** current forward pass **
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// the active layer for Get and Put
curLayer int
@@ -164,13 +159,12 @@ func (c *Causal) Close() {
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curReserve = reserve
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
if !c.curReserve {
if !reserve {
c.updateSlidingWindow()
var err error
@@ -217,9 +211,10 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curCellRange.max = len(c.cells) - 1
}
c.curMask = c.buildMask(ctx)
var err error
c.curMask, err = c.buildMask(ctx)
return nil
return err
}
func newRange() cellRange {
@@ -244,7 +239,7 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
@@ -302,7 +297,7 @@ func roundUp(length, pad int) int {
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@@ -310,11 +305,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
if c.curReserve {
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
}
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
@@ -335,7 +325,10 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
mask[i] = float32(math.Inf(-1))
}
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
@@ -343,7 +336,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
maskTensor = out
}
return maskTensor
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
@@ -498,7 +491,12 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
if ctx != nil {
c.curMask = c.buildMask(ctx)
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
}
}
@@ -654,7 +652,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
}
}
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
for i, key := range c.keys {
if key == nil {

View File

@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice(test.in, test.inShape...)
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context)
@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
return t
return t, nil
}
func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
f := make([]float32, len(s))
for i := range f {
f[i] = float32(s[i])
}
out := c.FromFloatSlice(f, shape...)
out, _ := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32
return out
return out, nil
}
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
s = append(s, i)
}
out := c.FromFloatSlice(s, len(s))
out, _ := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype
return out
}
@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() {}
func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int {
return 10

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
char const *LLAMA_COMMIT = "2016f07bd106c73699ecbaace80f55db5ed95dac";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -10,11 +10,11 @@ include common/stb_image.*
include include/
include include/llama.*
include include/llama-*.*
include tools/
include tools/mtmd/
include tools/mtmd/clip.*
include tools/mtmd/clip-impl.*
include tools/mtmd/llava.*
include examples/
include examples/llava/
include examples/llava/clip.*
include examples/llava/clip-impl.*
include examples/llava/llava.*
include src/
include src/llama.*
include src/llama-*.*

View File

@@ -1096,6 +1096,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
@@ -1113,7 +1114,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) {
cparams.embeddings = true;
@@ -1565,20 +1565,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result;
}
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}

View File

@@ -66,6 +66,7 @@ enum llama_example {
LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_INFILL,
LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL,
@@ -95,7 +96,6 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
};
// dimensionality reduction methods, used by cvector-generator
@@ -161,7 +161,6 @@ struct common_params_sampling {
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,
@@ -324,6 +323,7 @@ struct common_params {
bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
@@ -332,7 +332,6 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation
@@ -341,10 +340,8 @@ struct common_params {
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see tools/mtmd)
// multimodal models (see examples/llava)
struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)
// embedding
@@ -410,14 +407,13 @@ struct common_params {
bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params
int n_pca_batch = 100;
int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill
@@ -666,9 +662,3 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
}
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

View File

@@ -16,9 +16,6 @@ using json = nlohmann::ordered_json;
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max();
if (max_items == 0) {
return "";
}
if (min_items == 0 && max_items == 1) {
return item_rule + "?";
}

View File

@@ -1,7 +1,6 @@
#include "sampling.h"
#include "common.h"
#include "log.h"
#include <cmath>
#include <unordered_map>
@@ -230,48 +229,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data()));
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
if (params.top_n_sigma >= 0) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
} else {
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -473,7 +475,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
@@ -489,7 +490,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
@@ -504,7 +504,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -518,7 +517,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -535,16 +533,14 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) {
samplers.push_back(sampler->second);
continue;
}
if (allow_alt_names) {
sampler = sampler_alt_name_map.find(name);
if (sampler != sampler_alt_name_map.end()) {
samplers.push_back(sampler->second);
continue;
} else {
if (allow_alt_names) {
sampler = sampler_alt_name_map.find(name);
if (sampler != sampler_alt_name_map.end()) {
samplers.push_back(sampler->second);
}
}
}
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
}
return samplers;
@@ -556,7 +552,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
@@ -571,8 +566,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
const auto sampler = sampler_name_map.find(c);
if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second);
} else {
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
}
}

View File

@@ -2,6 +2,8 @@
#include "gguf.h"
#include "clip.h"
#include "clip.h"
#include <climits>
#include <cstdarg>
#include <string>
@@ -15,29 +17,33 @@
#define KEY_FTYPE "general.file_type"
#define KEY_NAME "general.name"
#define KEY_DESCRIPTION "general.description"
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.vision.embedding_length"
#define KEY_N_FF "clip.vision.feed_forward_length"
#define KEY_N_BLOCK "clip.vision.block_count"
#define KEY_N_HEAD "clip.vision.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM "clip.vision.projection_dim"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM "clip.%s.projection_dim"
#define KEY_TOKENS "tokenizer.ggml.tokens"
#define KEY_N_POSITIONS "clip.text.context_length"
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
//
@@ -53,16 +59,10 @@
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s"
#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
#define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s"
#define TN_LLAVA_PROJ "mm.%d.%s"
@@ -70,14 +70,8 @@
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
// mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -93,23 +87,18 @@
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
#define TN_GLM_BOI_W "adapter.boi"
#define TN_GLM_EOI_W "adapter.eoi"
enum projector_type {
PROJECTOR_TYPE_MLP,
PROJECTOR_TYPE_MLP_NORM,
PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_MINICPMV,
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_QWEN2VL,
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -117,14 +106,10 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MLP, "mlp" },
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -239,15 +224,6 @@ struct clip_image_u8_batch {
struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch;
new_batch.entries.reserve(entries.size());
for (const auto & entry : entries) {
new_batch.entries.emplace_back(new clip_image_f32(*entry));
}
return new_batch;
}
};
//

2927
llama/llama.cpp/examples/llava/clip.cpp vendored Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -47,7 +47,7 @@ CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_par
CLIP_API void clip_free(struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
@@ -59,29 +59,18 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
"use clip_n_output_tokens instead");
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
"use clip_n_output_tokens instead");
CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
// for M-RoPE, this will be the number of token positions in X and Y directions
// for other models, X will be the total number of tokens and Y will be 1
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
// this should be equal to the embedding dimension of the text model
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
CLIP_API struct clip_image_size * clip_image_size_init(void);
CLIP_API struct clip_image_u8 * clip_image_u8_init (void);
CLIP_API struct clip_image_f32 * clip_image_f32_init(void);
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
CLIP_API struct clip_image_size * clip_image_size_init();
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
CLIP_API struct clip_image_f32 * clip_image_f32_init();
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
// nx, ny are the output image dimensions
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
@@ -125,6 +114,8 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

View File

@@ -2,7 +2,6 @@
#include "llava.h"
#include "llama.h"
#include "ggml-cpp.h"
#include <algorithm>
#include <cerrno>
@@ -113,7 +112,7 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<
}
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) {
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
struct {
struct ggml_context * ctx;
} model;
@@ -176,7 +175,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
model.ctx = ggml_init(params);
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
// fill it with the image embeddings, ignoring the base
for (size_t i = 1; i < num_images; i++) {
@@ -210,17 +209,13 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
// ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
ggml_build_forward_expand(gf, flatten);
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend");
ggml_backend_graph_compute(backend.get(), gf);
ggml_graph_compute_with_ctx(model.ctx, gf, 1);
struct ggml_tensor* result = ggml_graph_node(gf, -1);
memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
// append without newline tokens (default behavior in llava_arch when not using unpad ):
memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input));
memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
// Debug: Test single segments
// Current findings: sending base image, sending a segment embedding all works similar to python
@@ -318,7 +313,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
image_embd_v[i],
clip_embd_nbytes_by_img(ctx_clip, nx, ny));
n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res);
n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
}
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
@@ -347,8 +342,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
}
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding
*n_img_pos = clip_n_patches(ctx_clip);
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
*n_img_pos = clip_n_output_tokens(ctx_clip, img_res);
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
if (!encoded) {
LOG_ERR("Unable to encode image\n");
@@ -386,8 +381,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
int n_img_pos_out;
clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0);
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input);
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
@@ -462,7 +456,7 @@ struct llava_embd_batch {
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
@@ -474,6 +468,7 @@ struct llava_embd_batch {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*n_embd =*/ n_embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
@@ -497,7 +492,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;

View File

@@ -1,4 +1,4 @@
package mtmd
package llava
// #cgo CXXFLAGS: -std=c++11
// #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../common

View File

@@ -4,7 +4,6 @@
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h>
#include <stdint.h>
@@ -112,8 +111,6 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
};
enum llama_rope_type {
@@ -258,6 +255,7 @@ extern "C" {
llama_token * token;
float * embd;
int32_t n_embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
@@ -353,18 +351,20 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool cross_attn; // whether to use cross attention
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
};
// model quantization parameters
@@ -446,10 +446,6 @@ extern "C" {
size_t n_paths,
struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead");
@@ -464,6 +460,10 @@ extern "C" {
struct llama_context_params params),
"use llama_init_from_model instead");
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
@@ -929,19 +929,14 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
// Process a batch of tokens.
// In contrast to llama_decode() - this call does not use KV cache.
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
// Process a batch of tokens.
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@@ -1242,7 +1237,6 @@ extern "C" {
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
@@ -1438,37 +1432,6 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor contains trainable parameters
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus
}
#endif

View File

@@ -253,9 +253,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
std::vector<ggml_backend_buffer_type_t> buft_extra;
{
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
@@ -294,9 +291,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
buft = ggml_backend_dev_buffer_type(cpu_dev);
break;

View File

@@ -6,6 +6,7 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_MLLAMA, "mllama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
@@ -19,7 +20,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
@@ -73,6 +73,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -108,7 +109,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -144,6 +144,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
@@ -273,6 +274,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_MLLAMA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
{ LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
{ LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
{ LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
{ LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
{ LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
{ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
{ LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
},
},
{
LLM_ARCH_DECI,
{
@@ -476,24 +511,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NOMIC_BERT_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{
@@ -1570,6 +1587,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_MISTRAL3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
}
},
{
LLM_ARCH_UNKNOWN,
{
@@ -1701,6 +1734,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@@ -11,6 +11,7 @@
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_MLLAMA,
LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
@@ -23,7 +24,6 @@ enum llm_arch {
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
@@ -75,6 +75,7 @@ enum llm_arch {
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_MISTRAL3,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_UNKNOWN,
@@ -112,7 +113,6 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
@@ -148,6 +148,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -349,6 +350,14 @@ enum llm_tensor {
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_BSKCN_TV,
LLM_TENSOR_CROSS_ATTN_K_NORM,
LLM_TENSOR_CROSS_ATTN_K_PROJ,
LLM_TENSOR_CROSS_ATTN_O_PROJ,
LLM_TENSOR_CROSS_ATTN_Q_NORM,
LLM_TENSOR_CROSS_ATTN_Q_PROJ,
LLM_TENSOR_CROSS_ATTN_V_PROJ,
LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
LLM_TENSOR_CROSS_ATTN_MLP_GATE,
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW,
LLM_TENSOR_CONVNEXT_NORM,

View File

@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch;
}
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
@@ -203,7 +203,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i;
}
if (simple_split) {
seq.resize(1);
llama_sbatch_seq & s = seq[0];
@@ -213,7 +212,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
s.length = n_tokens;
return;
}
std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -241,7 +239,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
return n_seq_a > n_seq_b;
}
);
// init seq
llama_sbatch_seq * last_seq = nullptr;
@@ -265,7 +262,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
seq.push_back(new_seq);
last_seq = &seq.back();
}
// keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(),
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
@@ -320,6 +316,7 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
@@ -332,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
@@ -340,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch.n_embd = embd;
} else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}

View File

@@ -70,8 +70,7 @@ struct llama_sbatch {
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
// temporary allocate memory for the input batch if needed

View File

@@ -35,7 +35,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
@@ -51,8 +50,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
@@ -63,7 +62,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -83,9 +81,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
if (tmpl_contains("<|im_start|>")) {
return tmpl_contains("<|im_sep|>")
? LLM_CHAT_TEMPLATE_PHI_4
: tmpl_contains("<end_of_utterance>")
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
: LLM_CHAT_TEMPLATE_CHATML;
: LLM_CHAT_TEMPLATE_CHATML;
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
if (tmpl_contains("[SYSTEM_PROMPT]")) {
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@@ -123,12 +119,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
}
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) {
@@ -157,7 +149,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA_3;
} else if (tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
return LLM_CHAT_TEMPLATE_CHATGLM_3;
return LLM_CHAT_TEMPLATE_CHATGML_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGML_4;
} else if (tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
return LLM_CHAT_TEMPLATE_MINICPM;
@@ -203,20 +197,19 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
// Official mistral 'v7' template
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
for (auto message : chat) {
std::string role(message->role);
std::string content(message->content);
if (role == "system") {
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
} else if (role == "user") {
ss << "[INST]" << trailing_space << content << "[/INST]";
} else {
ss << trailing_space << content << "</s>";
ss << "[INST] " << content << "[/INST]";
}
else {
ss << " " << content << "</s>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
@@ -439,7 +432,7 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
@@ -449,14 +442,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>\n";
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
@@ -627,23 +620,7 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
// SmolVLM
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "<end_of_utterance>\n";
} else {
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
} else {
// template not supported
return -1;
}

View File

@@ -14,7 +14,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3,
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3,
@@ -30,8 +29,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
LLM_CHAT_TEMPLATE_COMMAND_R,
LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGLM_3,
LLM_CHAT_TEMPLATE_CHATGLM_4,
LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
@@ -42,7 +41,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View File

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,6 @@
#include "llama-kv-cache.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
#include <map>
#include <vector>
@@ -29,12 +28,7 @@ struct llama_context {
void synchronize();
const llama_model & get_model() const;
const llama_cparams & get_cparams() const;
ggml_backend_sched_t get_sched() const;
ggml_context * get_ctx_compute() const;
const llama_model & get_model() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
@@ -72,6 +66,7 @@ struct llama_context {
void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);
void set_cross_attn(bool value);
void set_adapter_lora(
llama_adapter_lora * adapter,
@@ -135,32 +130,6 @@ struct llama_context {
llama_perf_context_data perf_get_data() const;
void perf_reset();
//
// training
//
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
void opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start);
private:
//
// output
@@ -170,30 +139,51 @@ private:
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
//
// graph
//
public:
int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale,
ggml_backend_buffer * bbuf) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf,
const std::vector<struct llama_kv_defrag_move> & moves) const;
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
@@ -210,10 +200,14 @@ private:
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_memory_i> memory;
std::unique_ptr<llama_kv_cache_unified> kv_self;
// TODO: remove
bool logits_all = false;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
@@ -240,9 +234,6 @@ private:
ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;

View File

@@ -29,8 +29,8 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
bool cross_attn;
bool warmup;
bool op_offload;
enum llama_pooling_type pooling_type;

View File

@@ -55,21 +55,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && pos) {
const int64_t n_tokens = ubatch->n_tokens;
if (ubatch->token && n_pos_per_embd == 4) {
// in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
// the 3 first dims are the same, and 4th dim is all 0
std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
// copy the first dimension
for (int i = 0; i < n_tokens; ++i) {
pos_data[ i] = ubatch->pos[i];
pos_data[ n_tokens + i] = ubatch->pos[i];
pos_data[2 * n_tokens + i] = ubatch->pos[i];
pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
} else {
ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
}
ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
}
}
@@ -85,7 +71,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
) * f_attn_temp_scale + 1.0;
}
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
}
}
@@ -284,7 +270,24 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
data[i] = kv_self->s_copy(i);
const uint32_t cell_id = i + kv_self->head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
kv_cell.src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
}
}
}
@@ -300,7 +303,18 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
// clear unused states
for (int i = 0; i < n_kv; ++i) {
data[i] = kv_self->s_mask(i);
const uint32_t cell_id = i + kv_self->head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
data[i] = (float) (kv_cell.src >= 0);
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
}
}
}
@@ -532,6 +546,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
if (ubatch->embd) {
ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
}
}
//
// llm_graph_context
//
@@ -578,7 +598,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
res (std::make_unique<llm_graph_result>()) {
}
int64_t llm_graph_context::n_pos_per_embd() const {
int64_t llm_graph_context::n_pos_per_token() const {
return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
}
@@ -782,17 +802,13 @@ ggml_tensor * llm_graph_context::build_ffn(
} break;
}
if (gate && type_gate == LLM_FFN_PAR) {
if (type_gate == LLM_FFN_PAR) {
cur = ggml_mul(ctx0, cur, tmp);
cb(cur, "ffn_gate_par", il);
}
if (down) {
cur = build_lora_mm(down, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (down_b) {
@@ -900,35 +916,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
ggml_tensor * experts = nullptr;
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);
} else {
cur = up;
}
ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
switch (type_op) {
case LLM_FFN_SILU:
{
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_moe_silu", il);
gate = ggml_silu(ctx0, gate);
cb(gate, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
{
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_moe_gelu", il);
gate = ggml_gelu(ctx0, gate);
cb(gate, "ffn_moe_gelu", il);
} break;
default:
GGML_ABORT("fatal error");
}
if (gate_exps) {
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate_par", il);
}
ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
cb(par, "ffn_moe_gate_par", il);
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
if (!weight_before_ffn) {
@@ -971,7 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
@@ -1012,11 +1020,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
}
ggml_tensor * llm_graph_context::build_inp_pos() const {
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
auto & cur = inp->pos;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
ggml_set_input(cur);
res->add_input(std::move(inp));
@@ -1025,12 +1033,11 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
}
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
auto & cur = inp->attn_scale;
// this need to be 1x1xN for broadcasting
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
ggml_set_input(cur);
res->add_input(std::move(inp));
@@ -1078,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
}
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
@@ -1095,7 +1102,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
}
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
@@ -1228,19 +1235,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
if (v_mla) {
#if 0
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
// However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
cur = ggml_mul_mat(ctx0, v_mla, cur);
#else
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
// The permutations are noops and only change how the tensor data is interpreted.
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_mul_mat(ctx0, v_mla, cur);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
#endif
}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
@@ -1420,6 +1416,8 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
GGML_ASSERT(!kv_self->recurrent);
const auto kv_head = kv_self->head;
GGML_ASSERT(kv_self->size == n_ctx);
@@ -1514,6 +1512,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
const int64_t n_embd = hparams.n_embd;
auto inp = std::make_unique<llm_graph_input_cross_attn_state>();
ggml_tensor * cur = nullptr;
inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
ggml_set_input(inp->cross_attn_state);
cur = inp->cross_attn_state;
cb(cur, "inp_cross_attn_state", -1);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
@@ -1569,7 +1586,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto n_kv = kv_self->n;
const auto kv_head = kv_self->head;
@@ -1601,7 +1618,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto token_shift_count = hparams.token_shift_count;
@@ -1622,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;

View File

@@ -19,7 +19,6 @@ struct llama_cparams;
class llama_memory_i;
class llama_kv_cache_unified;
class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -87,31 +86,34 @@ public:
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
};
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const int64_t n_pos_per_embd = 1;
const int64_t n_pos_per_token = 1;
};
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
virtual ~llm_graph_input_attn_temp() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
const int64_t n_pos_per_token = 1;
const uint32_t n_attn_temp_floor_scale;
const float f_attn_temp_scale;
};
@@ -187,26 +189,26 @@ public:
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_recurrent * kv_self;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_recurrent * kv_self;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -284,6 +286,16 @@ public:
const llama_cross * cross = nullptr;
};
class llm_graph_input_cross_attn_state : public llm_graph_input_i {
public:
llm_graph_input_cross_attn_state() = default;
virtual ~llm_graph_input_cross_attn_state() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
};
//
// llm_graph_result
//
@@ -298,7 +310,6 @@ class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
@@ -313,7 +324,6 @@ class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -330,7 +340,6 @@ public:
}
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
@@ -354,8 +363,8 @@ struct llm_graph_params {
const llama_cparams & cparams;
const llama_ubatch & ubatch;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
ggml_backend_sched * sched;
ggml_backend * backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
@@ -406,9 +415,9 @@ struct llm_graph_context {
ggml_context * ctx0 = nullptr;
ggml_backend_sched_t sched;
ggml_backend_sched * sched;
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
@@ -421,7 +430,7 @@ struct llm_graph_context {
llm_graph_context(const llm_graph_params & params);
int64_t n_pos_per_embd() const;
int64_t n_pos_per_token() const;
void cb(ggml_tensor * cur, const char * name, int il) const;
@@ -495,6 +504,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_attn_state() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;

View File

@@ -85,3 +85,7 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}
bool llama_hparams::cross_attention_layers(uint32_t il) const {
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
}

View File

@@ -2,6 +2,8 @@
#include "llama.h"
#include <algorithm>
#include <array>
// bump if necessary
@@ -42,6 +44,7 @@ struct llama_hparams {
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_rel_attn_bkts = 0;
uint32_t n_vocab = 0;
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla = 0;
@@ -56,6 +59,7 @@ struct llama_hparams {
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0;
@@ -68,7 +72,6 @@ struct llama_hparams {
float expert_weights_scale = 0.0;
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
float f_norm_eps;
float f_norm_rms_eps;
@@ -159,6 +162,9 @@ struct llama_hparams {
// Block skip connection
bool n_bskcn(uint32_t n, uint32_t il) const;
// cross attention layers
bool cross_attention_layers(uint32_t il) const;
bool is_swa(uint32_t il) const;
};

View File

File diff suppressed because it is too large Load Diff

View File

@@ -2,72 +2,32 @@
#include "llama.h"
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include <functional>
#include <set>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
using llama_memory_i::llama_memory_i;
// call if batch processing fails - restores the cache state
virtual void restore() = 0;
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state
// call after successful batch processing - clears any pending state
virtual void commit() = 0;
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
// process any pending defrag/shift/etc. operations
// optionally call once before processing a new batch
virtual bool update(llama_context & lctx) = 0;
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual void set_full() = 0;
//
// batch processing
//
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;
// getters
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual llama_pos get_pos_max() const = 0;
virtual bool get_can_shift() const = 0;
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
//
// llama_kv_cache_guard
//
struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
@@ -82,7 +42,7 @@ struct llama_kv_cache_guard {
private:
llama_kv_cache * kv;
};
// block of KV slots to move when defragging
struct llama_kv_defrag_move {
uint32_t src;
@@ -90,50 +50,65 @@ struct llama_kv_defrag_move {
uint32_t len;
};
//
// llama_kv_cache_unified
//
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
int32_t src = -1; // used by recurrent state models to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const llama_kv_cell & other) const {
return seq_id == other.seq_id;
}
};
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
// can be used to query data from the model if needed
struct callbacks {
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};
static uint32_t get_padding(const llama_cparams & cparams);
llama_kv_cache_unified(
const llama_model & model,
const llama_hparams & hparams,
callbacks cbs);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding);
bool offload);
~llama_kv_cache_unified() = default;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
//
// llama_memory_i
//
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
void clear() override;
void defrag() override;
virtual void restore() override;
virtual void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -143,76 +118,25 @@ public:
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & ctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;
bool find_slot(const llama_ubatch & batch);
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
// find how many cells are currently in use
uint32_t cell_max() const;
bool get_can_shift() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
const llama_model & model;
const llama_hparams & hparams;
bool has_shift = false;
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
// required padding
uint32_t padding = 1;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
// defrag
struct {
std::vector<llama_kv_defrag_move> moves;
} defrag_info;
@@ -221,6 +145,7 @@ private:
bool defrag_prepare(int32_t n_max_nodes);
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
@@ -231,125 +156,25 @@ private:
std::vector<slot_range> ranges;
} pending;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_graph_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf,
const std::vector<llama_kv_defrag_move> & moves) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
//
// llama_kv_cache_recurrent
//
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
int32_t src = -1; // used to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size);
~llama_kv_cache_recurrent() = default;
//
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
int32_t s_copy(int i) const;
float s_mask(int i) const;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
// members
const llama_hparams & hparams;
callbacks cbs;
bool has_shift = false;
bool do_defrag = false;
// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
@@ -361,41 +186,18 @@ public:
// computed before each graph build
uint32_t n = 0;
std::vector<kv_cell> cells;
std::vector<llama_kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
//const llama_model & model;
const llama_hparams & hparams;
// commit/restore cache
// TODO: rework for recurrent cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -403,6 +205,11 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
//
// kv cache view

View File

@@ -2,22 +2,12 @@
#include "llama.h"
struct llama_memory_params {
// kv cache
ggml_type type_k;
ggml_type type_v;
// parameters for other types of memory
// ...
};
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual ~llama_memory_i() = default;
virtual void clear() = 0;
virtual void defrag() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;

Some files were not shown because too many files have changed in this diff Show More