mirror of
https://github.com/ollama/ollama.git
synced 2025-12-24 08:10:54 -05:00
Compare commits
2 Commits
brucemacd/
...
parth/pyth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23e8ac9428 | ||
|
|
611d3a17ed |
16
.github/workflows/release.yaml
vendored
16
.github/workflows/release.yaml
vendored
@@ -432,22 +432,6 @@ jobs:
|
||||
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
|
||||
working-directory: ${{ runner.temp }}
|
||||
|
||||
# Trigger downstream release process
|
||||
trigger:
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends]
|
||||
steps:
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
needs: [darwin-sign, windows-sign, linux-build]
|
||||
|
||||
@@ -19,8 +19,8 @@ linters:
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- staticcheck
|
||||
- tenv
|
||||
- unconvert
|
||||
- usetesting
|
||||
- wastedassign
|
||||
- whitespace
|
||||
disable:
|
||||
|
||||
@@ -51,8 +51,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
||||
|
||||
add_compile_definitions(NDEBUG)
|
||||
|
||||
set(GGML_CPU ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
|
||||
WORKDIR=llama/vendor
|
||||
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
|
||||
FETCH_HEAD=2016f07bd106c73699ecbaace80f55db5ed95dac
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@@ -15,13 +15,11 @@ help:
|
||||
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
|
||||
|
||||
.PHONY: sync
|
||||
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
|
||||
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml
|
||||
|
||||
llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp
|
||||
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@
|
||||
|
||||
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
|
||||
go generate ./$(@D)
|
||||
.PHONY: llama/build-info.cpp
|
||||
llama/build-info.cpp: llama/build-info.cpp.in
|
||||
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
|
||||
|
||||
.PHONY: llama/llama.cpp
|
||||
llama/llama.cpp: llama/vendor/
|
||||
@@ -32,13 +30,12 @@ ml/backend/ggml/ggml: llama/vendor/ggml/
|
||||
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
|
||||
|
||||
PATCHES=$(wildcard llama/patches/*.patch)
|
||||
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
|
||||
|
||||
.PHONY: apply-patches
|
||||
.NOTPARALLEL:
|
||||
apply-patches: $(PATCHED)
|
||||
apply-patches: $(addsuffix ed, $(PATCHES))
|
||||
|
||||
llama/patches/.%.patched: llama/patches/%.patch
|
||||
%.patched: %.patch
|
||||
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
|
||||
|
||||
.PHONY: checkout
|
||||
@@ -60,4 +57,4 @@ format-patches: llama/patches
|
||||
|
||||
.PHONE: clean
|
||||
clean: checkout
|
||||
$(RM) llama/patches/.*.patched
|
||||
$(RM) $(addsuffix ed, $(PATCHES))
|
||||
|
||||
50
README.md
50
README.md
@@ -61,8 +61,6 @@ Here are some example models that can be downloaded:
|
||||
| QwQ | 32B | 20GB | `ollama run qwq` |
|
||||
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
|
||||
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
|
||||
| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
|
||||
| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
|
||||
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
|
||||
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
|
||||
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
|
||||
@@ -79,7 +77,7 @@ Here are some example models that can be downloaded:
|
||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
|
||||
|
||||
> [!NOTE]
|
||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
@@ -287,7 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
@@ -314,8 +312,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
||||
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
|
||||
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
|
||||
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
|
||||
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
|
||||
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
|
||||
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
|
||||
@@ -329,14 +325,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
|
||||
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
|
||||
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
||||
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
|
||||
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
|
||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
|
||||
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
|
||||
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
|
||||
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
||||
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
|
||||
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
|
||||
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
|
||||
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
|
||||
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
|
||||
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
||||
@@ -345,16 +341,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
|
||||
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
|
||||
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
|
||||
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
|
||||
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
|
||||
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
|
||||
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
|
||||
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
|
||||
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
|
||||
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
|
||||
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
|
||||
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
|
||||
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
|
||||
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
|
||||
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
|
||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
|
||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
|
||||
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
|
||||
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
|
||||
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
|
||||
@@ -372,7 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
|
||||
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
|
||||
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
|
||||
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
|
||||
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
|
||||
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
|
||||
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
|
||||
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
|
||||
@@ -390,7 +386,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
|
||||
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
|
||||
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
|
||||
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
|
||||
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
|
||||
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
|
||||
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
|
||||
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
|
||||
@@ -398,15 +394,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
|
||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
||||
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -448,9 +440,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
||||
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
|
||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
@@ -477,7 +468,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Libraries
|
||||
|
||||
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
|
||||
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
|
||||
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
|
||||
- [crewAI](https://github.com/crewAIInc/crewAI)
|
||||
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
|
||||
@@ -524,21 +515,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
|
||||
- [GoLamify](https://github.com/prasad89/golamify)
|
||||
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
|
||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
|
||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
|
||||
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
|
||||
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
||||
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
|
||||
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
|
||||
- [Enchanted](https://github.com/AugustDev/enchanted)
|
||||
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
|
||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
|
||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
|
||||
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
|
||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||
|
||||
@@ -562,7 +552,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
|
||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
|
||||
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
|
||||
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
|
||||
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
||||
@@ -572,8 +562,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
||||
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
|
||||
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
|
||||
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
|
||||
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
|
||||
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
|
||||
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
|
||||
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
|
||||
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
|
||||
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
|
||||
@@ -587,8 +577,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
|
||||
### Supported backends
|
||||
|
||||
|
||||
@@ -24,10 +24,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -79,14 +76,6 @@ func NewClient(base *url.URL, http *http.Client) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
|
||||
token, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||
var reqBody io.Reader
|
||||
var data []byte
|
||||
@@ -108,21 +97,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
}
|
||||
|
||||
requestURL := c.base.JoinPath(path)
|
||||
|
||||
var token string
|
||||
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
|
||||
token, err = getAuthorizationToken(ctx, chal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := requestURL.Query()
|
||||
q.Set("ts", now)
|
||||
requestURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -132,10 +106,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
respObj, err := c.http.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -173,22 +143,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
|
||||
requestURL := c.base.JoinPath(path)
|
||||
|
||||
var token string
|
||||
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
|
||||
var err error
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
|
||||
token, err = getAuthorizationToken(ctx, chal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := requestURL.Query()
|
||||
q.Set("ts", now)
|
||||
requestURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -198,10 +152,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
request.Header.Set("Accept", "application/x-ndjson")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
response, err := c.http.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -136,7 +137,7 @@ func TestClientStream(t *testing.T) {
|
||||
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
|
||||
|
||||
var receivedChunks []ChatResponse
|
||||
err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
|
||||
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(chunk, &resp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal chunk: %w", err)
|
||||
@@ -222,7 +223,7 @@ func TestClientDo(t *testing.T) {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp)
|
||||
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
|
||||
|
||||
if tc.wantErr != "" {
|
||||
if err == nil {
|
||||
|
||||
54
api/types.go
54
api/types.go
@@ -83,12 +83,6 @@ type GenerateRequest struct {
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding. Needs to be a pointer so we can distinguish between false
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -114,10 +108,6 @@ type ChatRequest struct {
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -136,11 +126,8 @@ func (t Tool) String() string {
|
||||
// role ("system", "user", or "assistant"), the content and an optional list
|
||||
// of images.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
@@ -284,6 +271,9 @@ type Options struct {
|
||||
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
Mirostat int `json:"mirostat,omitempty"`
|
||||
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
||||
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
@@ -293,7 +283,12 @@ type Runner struct {
|
||||
NumBatch int `json:"num_batch,omitempty"`
|
||||
NumGPU int `json:"num_gpu,omitempty"`
|
||||
MainGPU int `json:"main_gpu,omitempty"`
|
||||
LowVRAM bool `json:"low_vram,omitempty"`
|
||||
F16KV bool `json:"f16_kv,omitempty"` // Deprecated: This option is ignored
|
||||
LogitsAll bool `json:"logits_all,omitempty"`
|
||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||
UseMMap *bool `json:"use_mmap,omitempty"`
|
||||
UseMLock bool `json:"use_mlock,omitempty"`
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
}
|
||||
|
||||
@@ -457,13 +452,12 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
@@ -477,6 +471,13 @@ type ProcessModelResponse struct {
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
}
|
||||
|
||||
type RetrieveModelResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
@@ -492,10 +493,6 @@ type GenerateResponse struct {
|
||||
// Response is the textual response itself.
|
||||
Response string `json:"response"`
|
||||
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
@@ -663,6 +660,9 @@ func DefaultOptions() Options {
|
||||
RepeatPenalty: 1.1,
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Mirostat: 0,
|
||||
MirostatTau: 5.0,
|
||||
MirostatEta: 0.1,
|
||||
Seed: -1,
|
||||
|
||||
Runner: Runner{
|
||||
@@ -671,6 +671,8 @@ func DefaultOptions() Options {
|
||||
NumBatch: 512,
|
||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||
NumThread: 0, // let the runtime decide
|
||||
LowVRAM: false,
|
||||
UseMLock: false,
|
||||
UseMMap: nil,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -372,50 +372,3 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedThinking *bool
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
input: `{ "think": true }`,
|
||||
expectedThinking: &trueVal,
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
input: `{ "think": false }`,
|
||||
expectedThinking: &falseVal,
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
input: `{ }`,
|
||||
expectedThinking: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: `{ "think": "true" }`,
|
||||
expectedThinking: nil,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var req GenerateRequest
|
||||
err := json.Unmarshal([]byte(test.input), &req)
|
||||
if test.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedThinking, req.Think)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,14 +4,20 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
func InitLogging() {
|
||||
level := slog.LevelInfo
|
||||
|
||||
if envconfig.Debug() {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
var logFile *os.File
|
||||
var err error
|
||||
// Detect if we're a GUI app on windows, and if not, send logs to console
|
||||
@@ -27,8 +33,20 @@ func InitLogging() {
|
||||
return
|
||||
}
|
||||
}
|
||||
handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: true,
|
||||
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
||||
if attr.Key == slog.SourceKey {
|
||||
source := attr.Value.Any().(*slog.Source)
|
||||
source.File = filepath.Base(source.File)
|
||||
}
|
||||
return attr
|
||||
},
|
||||
})
|
||||
|
||||
slog.SetDefault(slog.New(handler))
|
||||
|
||||
slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
|
||||
slog.Info("ollama app started")
|
||||
}
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ func BenchmarkColdStart(b *testing.B) {
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
ctx := context.Background()
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
@@ -113,7 +113,7 @@ func BenchmarkWarmStart(b *testing.B) {
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-warm the model
|
||||
warmup(client, m, tt.prompt, b)
|
||||
@@ -140,7 +140,7 @@ func setup(b *testing.B) *api.Client {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
b.Fatalf("Model unavailable: %v", err)
|
||||
}
|
||||
|
||||
|
||||
287
cmd/cmd.go
287
cmd/cmd.go
@@ -31,7 +31,6 @@ import (
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -39,31 +38,12 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityThinking {
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
|
||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
||||
|
||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
@@ -126,7 +106,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = args[0]
|
||||
req.Name = args[0]
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
@@ -137,54 +117,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
|
||||
files := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Files {
|
||||
g.Go(func() error {
|
||||
if len(req.Files) > 0 {
|
||||
fileMap := map[string]string{}
|
||||
for f, digest := range req.Files {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: this is incorrect since the file might be in a subdirectory
|
||||
// instead this should take the path relative to the model directory
|
||||
// but the current implementation does not allow this
|
||||
files.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
fileMap[filepath.Base(f)] = digest
|
||||
}
|
||||
req.Files = fileMap
|
||||
}
|
||||
|
||||
adapters := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Adapters {
|
||||
g.Go(func() error {
|
||||
if len(req.Adapters) > 0 {
|
||||
fileMap := map[string]string{}
|
||||
for f, digest := range req.Adapters {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: same here
|
||||
adapters.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
fileMap[filepath.Base(f)] = digest
|
||||
}
|
||||
req.Adapters = fileMap
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Files = files.Items()
|
||||
req.Adapters = adapters.Items()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
msg := resp.Status
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
|
||||
}
|
||||
bar = progress.NewBar(msg, resp.Total, resp.Completed)
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
@@ -253,7 +213,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
|
||||
}
|
||||
}()
|
||||
|
||||
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
@@ -283,9 +243,6 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
req := &api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
|
||||
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
@@ -320,22 +277,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
thinkFlag := cmd.Flags().Lookup("think")
|
||||
if thinkFlag.Changed {
|
||||
think, err := cmd.Flags().GetBool("think")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Think = &think
|
||||
} else {
|
||||
opts.Think = nil
|
||||
}
|
||||
hidethinking, err := cmd.Flags().GetBool("hidethinking")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.HideThinking = hidethinking
|
||||
|
||||
keepAlive, err := cmd.Flags().GetString("keepalive")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -399,11 +340,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
@@ -789,38 +725,11 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
case float64:
|
||||
v = fmt.Sprintf("%g", vData)
|
||||
case []any:
|
||||
targetWidth := 10 // Small width where we are displaying the data in a column
|
||||
|
||||
var itemsToShow int
|
||||
totalWidth := 1 // Start with 1 for opening bracket
|
||||
|
||||
// Find how many we can fit
|
||||
for i := range vData {
|
||||
itemStr := fmt.Sprintf("%v", vData[i])
|
||||
width := runewidth.StringWidth(itemStr)
|
||||
|
||||
// Add separator width (", ") for all items except the first
|
||||
if i > 0 {
|
||||
width += 2
|
||||
}
|
||||
|
||||
// Check if adding this item would exceed our width limit
|
||||
if totalWidth+width > targetWidth && i > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
totalWidth += width
|
||||
itemsToShow++
|
||||
}
|
||||
|
||||
// Format the output
|
||||
if itemsToShow < len(vData) {
|
||||
v = fmt.Sprintf("%v", vData[:itemsToShow])
|
||||
v = strings.TrimSuffix(v, "]")
|
||||
v += fmt.Sprintf(" ...+%d more]", len(vData)-itemsToShow)
|
||||
} else {
|
||||
v = fmt.Sprintf("%v", vData)
|
||||
n := 3
|
||||
if len(vData) < n {
|
||||
n = len(vData)
|
||||
}
|
||||
v = fmt.Sprintf("%v", vData[:n])
|
||||
default:
|
||||
v = fmt.Sprintf("%T", vData)
|
||||
}
|
||||
@@ -841,19 +750,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
|
||||
head := func(s string, n int) (rows [][]string) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
count := 0
|
||||
for scanner.Scan() {
|
||||
text := strings.TrimSpace(scanner.Text())
|
||||
if text == "" {
|
||||
continue
|
||||
for scanner.Scan() && (len(rows) < n || n < 0) {
|
||||
if text := scanner.Text(); text != "" {
|
||||
rows = append(rows, []string{"", strings.TrimSpace(text)})
|
||||
}
|
||||
count++
|
||||
if n < 0 || count <= n {
|
||||
rows = append(rows, []string{"", text})
|
||||
}
|
||||
}
|
||||
if n >= 0 && count > n {
|
||||
rows = append(rows, []string{"", "..."})
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -965,19 +865,17 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
type generateContextKey string
|
||||
|
||||
type runOptions struct {
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *bool
|
||||
HideThinking bool
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -1033,26 +931,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||
}
|
||||
}
|
||||
|
||||
func thinkingOutputOpeningText(plainText bool) string {
|
||||
text := "Thinking...\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||
}
|
||||
|
||||
func thinkingOutputClosingText(plainText bool) string {
|
||||
text := "...done thinking.\n\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||
}
|
||||
|
||||
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -1080,34 +958,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
p.StopAndClear()
|
||||
|
||||
latest = response
|
||||
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagClosed = true
|
||||
}
|
||||
// purposefully not putting thinking blocks in the response, which would
|
||||
// only be needed if we later added tool calling to the cli (they get
|
||||
// filtered out anyway since current models don't expect them unless you're
|
||||
// about to finish some tool calls)
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
@@ -1124,7 +982,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
@@ -1186,32 +1043,13 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
|
||||
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
p.StopAndClear()
|
||||
|
||||
latest = response
|
||||
content := response.Response
|
||||
|
||||
if response.Response != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
if response.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(plainText))
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
fmt.Print(thinkingOutputClosingText(plainText))
|
||||
thinkTagClosed = true
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
@@ -1237,7 +1075,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
System: opts.System,
|
||||
Options: opts.Options,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
@@ -1341,11 +1178,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := client.Heartbeat(cmd.Context()); err != nil {
|
||||
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
|
||||
if !strings.Contains(err.Error(), " refused") {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
return fmt.Errorf("ollama server not responding - %w", err)
|
||||
return errors.New("could not connect to ollama app, is it running?")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1423,7 +1260,7 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
@@ -1453,8 +1290,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
@@ -1506,6 +1341,7 @@ func NewCLI() *cobra.Command {
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: ListRunningHandler,
|
||||
}
|
||||
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
@@ -1571,6 +1407,7 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_LLM_LIBRARY"],
|
||||
envVars["OLLAMA_GPU_OVERHEAD"],
|
||||
envVars["OLLAMA_LOAD_TIMEOUT"],
|
||||
envVars["OLLAMA_CONTEXT_LENGTH"],
|
||||
})
|
||||
default:
|
||||
appendEnvDocs(cmd, envs)
|
||||
@@ -1594,45 +1431,3 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
// If the user has explicitly set thinking options, either through the CLI or
|
||||
// through the `/set think` or `set nothink` interactive options, then we
|
||||
// respect them. Otherwise, we check model capabilities to see if the model
|
||||
// supports thinking. If the model does support thinking, we enable it.
|
||||
// Otherwise, we unset the thinking option (which is different than setting it
|
||||
// to false).
|
||||
//
|
||||
// If capabilities are not provided, we fetch them from the server.
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
|
||||
if explicitlySetByUser {
|
||||
return runOpts.Think, nil
|
||||
}
|
||||
|
||||
if caps == nil {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := client.Show(context.Background(), &api.ShowRequest{
|
||||
Model: runOpts.Model,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caps = &ret.Capabilities
|
||||
}
|
||||
|
||||
thinkingSupported := false
|
||||
for _, cap := range *caps {
|
||||
if cap == model.CapabilityThinking {
|
||||
thinkingSupported = true
|
||||
}
|
||||
}
|
||||
|
||||
if thinkingSupported {
|
||||
thinking := true
|
||||
return &thinking, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -225,7 +226,6 @@ Weigh anchor!
|
||||
System
|
||||
You are a pirate!
|
||||
Ahoy, matey!
|
||||
...
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
@@ -337,7 +337,7 @@ func TestDeleteHandler(t *testing.T) {
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.SetContext(context.TODO())
|
||||
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
|
||||
t.Fatalf("DeleteHandler failed: %v", err)
|
||||
}
|
||||
@@ -399,6 +399,11 @@ func TestGetModelfileName(t *testing.T) {
|
||||
var expectedFilename string
|
||||
|
||||
if tt.fileExists {
|
||||
tempDir, err := os.MkdirTemp("", "modelfiledir")
|
||||
defer os.RemoveAll(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("temp modelfile dir creation failed: %v", err)
|
||||
}
|
||||
var fn string
|
||||
if tt.modelfileName != "" {
|
||||
fn = tt.modelfileName
|
||||
@@ -406,7 +411,7 @@ func TestGetModelfileName(t *testing.T) {
|
||||
fn = "Modelfile"
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), fn)
|
||||
tempFile, err := os.CreateTemp(tempDir, fn)
|
||||
if err != nil {
|
||||
t.Fatalf("temp modelfile creation failed: %v", err)
|
||||
}
|
||||
@@ -525,7 +530,7 @@ func TestPushHandler(t *testing.T) {
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.SetContext(context.TODO())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
@@ -630,7 +635,7 @@ func TestListHandler(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.SetContext(context.TODO())
|
||||
|
||||
// Capture stdout
|
||||
oldStdout := os.Stdout
|
||||
@@ -685,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model != "test-model" {
|
||||
if req.Name != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
}
|
||||
|
||||
@@ -725,7 +730,7 @@ func TestCreateHandler(t *testing.T) {
|
||||
}))
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
|
||||
tempFile, err := os.CreateTemp("", "modelfile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -745,7 +750,7 @@ func TestCreateHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.SetContext(context.TODO())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
|
||||
@@ -44,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@@ -62,8 +62,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
|
||||
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
@@ -130,7 +128,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
var thinkExplicitlySet bool = opts.Think != nil
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -198,19 +195,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
continue
|
||||
@@ -271,22 +260,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
think := true
|
||||
opts.Think = &think
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'think' mode.")
|
||||
case "nothink":
|
||||
think := false
|
||||
opts.Think = &think
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
@@ -475,11 +448,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
sb.Reset()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
@@ -543,7 +511,7 @@ func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
@@ -563,8 +531,6 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
@@ -585,7 +551,7 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,17 +10,14 @@ func TestExtractFilenames(t *testing.T) {
|
||||
// Unix style paths
|
||||
input := ` some preamble
|
||||
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
|
||||
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
|
||||
res := extractFileNames(input)
|
||||
assert.Len(t, res, 7)
|
||||
assert.Len(t, res, 5)
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.JPG")
|
||||
assert.Contains(t, res[5], "six.webp")
|
||||
assert.Contains(t, res[6], "seven.WEBP")
|
||||
assert.NotContains(t, res[4], '"')
|
||||
assert.NotContains(t, res, "inbetween1")
|
||||
assert.NotContains(t, res, "./1.svg")
|
||||
@@ -33,12 +28,10 @@ func TestExtractFilenames(t *testing.T) {
|
||||
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
|
||||
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
|
||||
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
|
||||
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
|
||||
d:\path with\spaces\thirteen.WEBP some ending
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
|
||||
`
|
||||
res = extractFileNames(input)
|
||||
assert.Len(t, res, 13)
|
||||
assert.Len(t, res, 10)
|
||||
assert.NotContains(t, res, "inbetween2")
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[0], "c:")
|
||||
@@ -56,31 +49,4 @@ d:\path with\spaces\thirteen.WEBP some ending
|
||||
assert.Contains(t, res[8], "d:")
|
||||
assert.Contains(t, res[9], "ten.PNG")
|
||||
assert.Contains(t, res[9], "E:")
|
||||
assert.Contains(t, res[10], "eleven.webp")
|
||||
assert.Contains(t, res[10], "c:")
|
||||
assert.Contains(t, res[11], "twelve.WebP")
|
||||
assert.Contains(t, res[11], "c:")
|
||||
assert.Contains(t, res[12], "thirteen.WEBP")
|
||||
assert.Contains(t, res[12], "d:")
|
||||
}
|
||||
|
||||
// Ensure that file paths wrapped in single quotes are removed with the quotes.
|
||||
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "img.jpg")
|
||||
data := make([]byte, 600)
|
||||
copy(data, []byte{
|
||||
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
|
||||
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0xff, 0xd9,
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test image: %v", err)
|
||||
}
|
||||
|
||||
input := "before '" + fp + "' after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
return errors.New("could not find ollama app")
|
||||
}
|
||||
path := strings.Split(link, "Ollama.app")
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
|
||||
@@ -4,27 +4,17 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
Installer = "OllamaSetup.exe"
|
||||
)
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
if len(isProcRunning(Installer)) > 0 {
|
||||
return fmt.Errorf("upgrade in progress...")
|
||||
}
|
||||
// log.Printf("XXX Attempting to find and start ollama app")
|
||||
AppName := "ollama app.exe"
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
@@ -45,11 +35,14 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
// log.Printf("XXX attempting to start app %s", appExe)
|
||||
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "hidden")
|
||||
cmd := exec.Command(cmd_path, "/c", appExe)
|
||||
// TODO - these hide flags aren't working - still pops up a command window for some reason
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
|
||||
|
||||
// TODO this didn't help either...
|
||||
cmd.Stdin = strings.NewReader("")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -63,50 +56,3 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
}
|
||||
|
||||
func isProcRunning(procName string) []uint32 {
|
||||
pids := make([]uint32, 2048)
|
||||
var ret uint32
|
||||
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
if ret > uint32(len(pids)) {
|
||||
pids = make([]uint32, ret+10)
|
||||
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if ret < uint32(len(pids)) {
|
||||
pids = pids[:ret]
|
||||
}
|
||||
var matches []uint32
|
||||
for _, pid := range pids {
|
||||
if pid == 0 {
|
||||
continue
|
||||
}
|
||||
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer windows.CloseHandle(hProcess)
|
||||
var module windows.Handle
|
||||
var cbNeeded uint32
|
||||
cb := (uint32)(unsafe.Sizeof(module))
|
||||
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
|
||||
continue
|
||||
}
|
||||
var sz uint32 = 1024 * 8
|
||||
moduleName := make([]uint16, sz)
|
||||
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
|
||||
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
|
||||
continue
|
||||
}
|
||||
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
|
||||
if strings.EqualFold(exeFile, procName) {
|
||||
matches = append(matches, pid)
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Test that a warning is printed when thinking is requested but not supported.
|
||||
func TestWarnMissingThinking(t *testing.T) {
|
||||
cases := []struct {
|
||||
capabilities []model.Capability
|
||||
expectWarn bool
|
||||
}{
|
||||
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
|
||||
{capabilities: []model.Capability{}, expectWarn: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
|
||||
}
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
resp := api.ShowResponse{Capabilities: tc.capabilities}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
ensureThinkingSupport(t.Context(), client, "m")
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
out, _ := io.ReadAll(r)
|
||||
|
||||
warned := strings.Contains(string(out), "warning:")
|
||||
if tc.expectWarn && !warned {
|
||||
t.Errorf("expected warning, got none")
|
||||
}
|
||||
if !tc.expectWarn && warned {
|
||||
t.Errorf("did not expect warning, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -15,12 +14,13 @@ import (
|
||||
)
|
||||
|
||||
type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
TextModel TextParameters `json:"text_config"`
|
||||
}
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
type TextParameters struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
}
|
||||
|
||||
type AdapterParameters struct {
|
||||
@@ -53,11 +53,8 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
for _, sv := range t.SpecialVocabulary {
|
||||
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
|
||||
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
|
||||
if len(sv.IDs) > 0 {
|
||||
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
|
||||
}
|
||||
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
|
||||
}
|
||||
|
||||
return kv
|
||||
@@ -92,7 +89,7 @@ type ModelConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
Tensors([]Tensor) []ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||
Replacements() []string
|
||||
@@ -109,13 +106,13 @@ type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ggml.KV) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
Tensors([]Tensor) []ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -150,14 +147,14 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -176,8 +173,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
switch p.Architectures[0] {
|
||||
case "LlamaForCausalLM":
|
||||
conv = &llamaModel{}
|
||||
case "MllamaForConditionalGeneration":
|
||||
conv = &mllamaModel{}
|
||||
case "Llama4ForConditionalGeneration":
|
||||
conv = &llama4Model{}
|
||||
case "Mistral3ForConditionalGeneration":
|
||||
@@ -194,8 +189,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
conv = &qwen2Model{}
|
||||
case "Qwen2_5_VLForConditionalGeneration":
|
||||
conv = &qwen25VLModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
@@ -219,22 +212,24 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
vocabSize := int(p.VocabSize)
|
||||
if vocabSize == 0 {
|
||||
tVocabSize := int(p.TextModel.VocabSize)
|
||||
vocabSize = tVocabSize
|
||||
}
|
||||
|
||||
switch {
|
||||
case vocabSize == 0:
|
||||
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||
case vocabSize > len(t.Vocabulary.Tokens):
|
||||
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
||||
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
||||
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
||||
}
|
||||
case vocabSize < len(t.Vocabulary.Tokens):
|
||||
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
|
||||
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
@@ -244,13 +239,13 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
return writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
}
|
||||
return ggml.WriteGGUF(f, kv, ts)
|
||||
return ggml.WriteGGUF(ws, kv, ts)
|
||||
}
|
||||
|
||||
@@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
for _, t := range ts {
|
||||
if slices.Contains([]string{
|
||||
"embeddings.position_ids",
|
||||
@@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
|
||||
t.SetRepacker(p.addOne)
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -21,8 +21,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
for _, t := range ts {
|
||||
shape := t.Shape()
|
||||
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
|
||||
@@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -126,11 +126,11 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
|
||||
if p.RopeScaling.factors != nil {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: "rope_freqs.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
|
||||
@@ -139,14 +139,13 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") ||
|
||||
strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") {
|
||||
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
|
||||
if !p.skipRepack {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
@@ -182,9 +181,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") {
|
||||
if strings.HasSuffix(name, "attn_q.weight") {
|
||||
heads = p.NumAttentionHeads
|
||||
} else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") {
|
||||
} else if strings.HasSuffix(name, "attn_k.weight") {
|
||||
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||
|
||||
@@ -88,13 +88,13 @@ func (p *llama4Model) Replacements() []string {
|
||||
}
|
||||
|
||||
// Tensors implements ModelConverter.
|
||||
func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
|
||||
var textTensors []Tensor
|
||||
for _, t := range ts {
|
||||
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
@@ -112,7 +112,7 @@ func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// clone tensor since we need separate repackers
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
|
||||
Kind: tt.Kind(),
|
||||
Shape: newShape,
|
||||
@@ -125,7 +125,7 @@ func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
t.SetRepacker(p.repack())
|
||||
newShape := slices.Clone(t.Shape())
|
||||
newShape[1], newShape[2] = newShape[2], newShape[1]
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
|
||||
@@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
for _, t := range ts {
|
||||
shape := t.Shape()
|
||||
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
|
||||
@@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
|
||||
@@ -89,8 +89,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
@@ -100,7 +100,7 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
oldnew := []string{
|
||||
"model.layers", "blk",
|
||||
"w1", "ffn_gate_exps",
|
||||
@@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
return true
|
||||
})
|
||||
|
||||
var out []*ggml.Tensor
|
||||
var out []ggml.Tensor
|
||||
for n, e := range experts {
|
||||
// TODO(mxyng): sanity check experts
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: n,
|
||||
Kind: e[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
type mllamaModel struct {
|
||||
ModelParameters
|
||||
TextModel struct {
|
||||
llamaModel
|
||||
|
||||
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumGlobalLayers uint32 `json:"num_global_layers"`
|
||||
IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
|
||||
AttentionHeads uint32 `json:"attention_heads"`
|
||||
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
MaxNumTiles uint32 `json:"max_num_tiles"`
|
||||
NormEpsilon float32 `json:"norm_eps"`
|
||||
RopeTheta float32 `json:"rope.freq_base"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
for k, v := range m.TextModel.KV(t) {
|
||||
if strings.HasPrefix(k, "llama.") {
|
||||
kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
|
||||
}
|
||||
}
|
||||
|
||||
kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers
|
||||
|
||||
kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
|
||||
kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
|
||||
kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices
|
||||
|
||||
kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
|
||||
kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize
|
||||
|
||||
kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
|
||||
kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon
|
||||
|
||||
kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
|
||||
kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
|
||||
kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
|
||||
kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *mllamaModel) Replacements() []string {
|
||||
return append(
|
||||
m.TextModel.Replacements(),
|
||||
"language_model.", "",
|
||||
"gate_attn", "attn_gate",
|
||||
"gate_ffn", "ffn_gate",
|
||||
"cross_attn.", "cross_attn_",
|
||||
"vision_model", "v",
|
||||
"class_embedding", "class_embd",
|
||||
"patch_embedding", "patch_embd",
|
||||
"gated_positional_embedding.tile_embedding", "tile_position_embd",
|
||||
"gated_positional_embedding.embedding", "position_embd.weight",
|
||||
"gated_positional_embedding", "position_embd",
|
||||
"embedding.weight", "weight",
|
||||
"pre_tile_positional_embedding", "pre_tile_position_embd",
|
||||
"post_tile_positional_embedding", "post_tile_position_embd",
|
||||
"layernorm_pre", "pre_ln",
|
||||
"layernorm_post", "post_ln",
|
||||
"global_transformer.layers", "global.blk",
|
||||
"transformer.layers", "blk",
|
||||
"mlp.fc1", "ffn_up",
|
||||
"mlp.fc2", "ffn_down",
|
||||
"multi_modal_projector", "mm.0",
|
||||
)
|
||||
}
|
||||
|
||||
func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
var text []Tensor
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
|
||||
text = append(text, t)
|
||||
} else if t.Name() == "v.position_embd.gate" {
|
||||
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(m.repack(name))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: tt,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return append(out, m.TextModel.Tensors(text)...)
|
||||
}
|
||||
|
||||
func (m *mllamaModel) repack(name string) Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i, dim := range shape {
|
||||
dims[i] = int(dim)
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
|
||||
heads := m.VisionModel.AttentionHeads
|
||||
if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
t, err = tensor.Tanh(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if name == "v.position_embd.gate" {
|
||||
t, err = tensor.Sub(float32(1), t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
// flatten tensor so it can be return as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
@@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var addRopeFactors sync.Once
|
||||
|
||||
out := make([]*ggml.Tensor, 0, len(ts)+2)
|
||||
out := make([]ggml.Tensor, 0, len(ts)+2)
|
||||
for _, t := range ts {
|
||||
if strings.HasPrefix(t.Name(), "blk.0.") {
|
||||
addRopeFactors.Do(func() {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: "rope_factors_long.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
|
||||
WriterTo: p.RopeScaling.LongFactor,
|
||||
}, &ggml.Tensor{
|
||||
}, ggml.Tensor{
|
||||
Name: "rope_factors_short.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
|
||||
@@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
})
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -15,7 +15,6 @@ type qwen2Model struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
MropeSection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_scaling"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
}
|
||||
@@ -40,18 +39,16 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
case "yarn":
|
||||
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
case "mrope", "default":
|
||||
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
|
||||
default:
|
||||
panic("unknown rope scaling type")
|
||||
}
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen25VLModel struct {
|
||||
qwen2Model
|
||||
|
||||
VisionModel struct {
|
||||
Depth uint32 `json:"depth"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
InChannels uint32 `json:"in_chans"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
SpatialPatchSize uint32 `json:"spatial_patch_size"`
|
||||
WindowSize uint32 `json:"window_size"`
|
||||
RMSNormEps float32 `json:"layer_norm_epsilon"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
for k, v := range q.qwen2Model.KV(t) {
|
||||
if strings.HasPrefix(k, "qwen2.") {
|
||||
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
|
||||
}
|
||||
}
|
||||
|
||||
if q.VisionModel.FullAttentionBlocks == nil {
|
||||
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
|
||||
}
|
||||
|
||||
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
|
||||
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
|
||||
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
|
||||
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
|
||||
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
|
||||
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
|
||||
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
|
||||
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
|
||||
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
|
||||
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
|
||||
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
|
||||
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if strings.Contains(t.Name(), "patch_embed.proj") {
|
||||
for t := range splitDim(t, 2,
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
|
||||
) {
|
||||
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
|
||||
out = append(out, t)
|
||||
}
|
||||
} else if strings.Contains(t.Name(), "attn.qkv") {
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
strings.NewReplacer("attn.qkv", "attn_q"),
|
||||
strings.NewReplacer("attn.qkv", "attn_k"),
|
||||
strings.NewReplacer("attn.qkv", "attn_v"),
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *qwen25VLModel) Replacements() []string {
|
||||
return append(
|
||||
p.qwen2Model.Replacements(),
|
||||
"visual", "v",
|
||||
"blocks", "blk",
|
||||
"attn.proj", "attn_out",
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
}
|
||||
t.Cleanup(func() { r.Close() })
|
||||
|
||||
m, err := ggml.Decode(r, -1)
|
||||
m, _, err := ggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -130,7 +130,6 @@ func TestConvertModel(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer expectFile.Close()
|
||||
|
||||
var expect map[string]string
|
||||
if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
|
||||
@@ -332,7 +331,7 @@ func TestConvertAdapter(t *testing.T) {
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
m, err := ggml.Decode(r, -1)
|
||||
m, _, err := ggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
58
convert/fs.go
Normal file
58
convert/fs.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type ZipReader struct {
|
||||
r *zip.Reader
|
||||
p string
|
||||
|
||||
// limit is the maximum size of a file that can be read directly
|
||||
// from the zip archive. Files larger than this size will be extracted
|
||||
limit int64
|
||||
}
|
||||
|
||||
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
|
||||
return &ZipReader{r, p, limit}
|
||||
}
|
||||
|
||||
func (z *ZipReader) Open(name string) (fs.File, error) {
|
||||
r, err := z.r.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if fi, err := r.Stat(); err != nil {
|
||||
return nil, err
|
||||
} else if fi.Size() < z.limit {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
if !filepath.IsLocal(name) {
|
||||
return nil, zip.ErrInsecurePath
|
||||
}
|
||||
|
||||
n := filepath.Join(z.p, name)
|
||||
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
|
||||
w, err := os.Create(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if _, err := io.Copy(w, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return os.Open(n)
|
||||
}
|
||||
@@ -38,10 +38,7 @@ const (
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" {
|
||||
t.name == "v.positional_embedding_vlm" {
|
||||
// these tensors are always F32
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
|
||||
// is split evenly based on the number of replacers provided.
|
||||
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
|
||||
return func(yield func(*ggml.Tensor) bool) {
|
||||
for i, replacer := range replacers {
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[dim] = shape[dim] / uint64(len(replacers))
|
||||
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
|
||||
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(slice...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
// flatten tensor so it can be written as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
})
|
||||
|
||||
if !yield(&ggml.Tensor{
|
||||
Name: replacer.Replace(t.Name()),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: tt,
|
||||
}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -110,7 +110,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
}
|
||||
|
||||
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||
// noop
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
@@ -172,34 +171,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
}
|
||||
}
|
||||
|
||||
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
defer f.Close()
|
||||
|
||||
var p map[string]json.RawMessage
|
||||
if err := json.NewDecoder(f).Decode(&p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, st := range specialTokenTypes {
|
||||
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
|
||||
var ids []int32
|
||||
if err := json.Unmarshal(bts, &ids); err != nil {
|
||||
// value is not a list so the existing ID is used
|
||||
continue
|
||||
}
|
||||
|
||||
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
|
||||
return sv.Type == st
|
||||
}); i >= 0 {
|
||||
t.SpecialVocabulary[i].IDs = ids
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -309,9 +280,6 @@ type SpecialVocabulary struct {
|
||||
ID int
|
||||
Content string
|
||||
AddToken bool
|
||||
|
||||
// IDs is populated by generation_config.json
|
||||
IDs []int32
|
||||
}
|
||||
|
||||
func (sv SpecialVocabulary) Key() string {
|
||||
|
||||
@@ -247,67 +247,6 @@ func TestParseTokenizer(t *testing.T) {
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generation config eos token ids",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{
|
||||
"added_tokens": [
|
||||
{
|
||||
"id": 0,
|
||||
"content": "<bos>",
|
||||
"special": true
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"content": "<eos>",
|
||||
"special": true
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"content": "<eot>",
|
||||
"special": true
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"content": "<eom>",
|
||||
"special": true
|
||||
}
|
||||
],
|
||||
"model": {
|
||||
"vocab": {
|
||||
"<bos>": 0,
|
||||
"<eos>": 1,
|
||||
"<eot>": 2,
|
||||
"<eom>": 3
|
||||
}
|
||||
}
|
||||
}`),
|
||||
"tokenizer_config.json": strings.NewReader(`{
|
||||
"add_bos_token": true,
|
||||
"add_eos_token": false,
|
||||
"bos_token": "<bos>",
|
||||
"eos_token": "<eos>"
|
||||
}`),
|
||||
"generation_config.json": strings.NewReader(`{
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": [1, 2, 3]
|
||||
}`),
|
||||
}),
|
||||
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{
|
||||
Model: "gpt2",
|
||||
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
|
||||
Scores: []float32{0, 1, 2, 3},
|
||||
Types: []int32{3, 3, 3, 3},
|
||||
},
|
||||
SpecialVocabulary: []*SpecialVocabulary{
|
||||
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
|
||||
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
|
||||
},
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
|
||||
@@ -670,7 +670,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
if envconfig.LogLevel() < slog.LevelInfo {
|
||||
if envconfig.Debug() {
|
||||
return C.uint16_t(1)
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
|
||||
@@ -27,14 +27,12 @@
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef LOG
|
||||
#define LOG(verbose, ...) \
|
||||
do { \
|
||||
if (verbose) { \
|
||||
fprintf(stderr, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
#include "gpu_info_cudart.h"
|
||||
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
@@ -59,7 +58,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) {
|
||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||
return;
|
||||
}
|
||||
@@ -169,9 +168,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->free = memInfo.free;
|
||||
resp->used = memInfo.used;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used);
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
}
|
||||
|
||||
@@ -181,4 +180,4 @@ void cudart_release(cudart_handle_t h) {
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
#endif // __APPLE__
|
||||
@@ -1,7 +1,6 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
#include "gpu_info_nvcuda.h"
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
|
||||
@@ -194,8 +193,8 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
|
||||
|
||||
@@ -248,4 +247,4 @@ void nvcuda_release(nvcuda_handle_t h) {
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
#endif // __APPLE__
|
||||
73
docs/api.md
73
docs/api.md
@@ -19,7 +19,7 @@
|
||||
|
||||
### Model names
|
||||
|
||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||
|
||||
### Durations
|
||||
|
||||
@@ -43,7 +43,6 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
- `prompt`: the prompt to generate a response for
|
||||
- `suffix`: the text after the model response
|
||||
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
@@ -395,6 +394,9 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"repeat_penalty": 1.2,
|
||||
"presence_penalty": 1.5,
|
||||
"frequency_penalty": 1.0,
|
||||
"mirostat": 1,
|
||||
"mirostat_tau": 0.8,
|
||||
"mirostat_eta": 0.6,
|
||||
"penalize_newline": true,
|
||||
"stop": ["\n", "user:"],
|
||||
"numa": false,
|
||||
@@ -402,7 +404,10 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"num_batch": 2,
|
||||
"num_gpu": 1,
|
||||
"main_gpu": 0,
|
||||
"low_vram": false,
|
||||
"vocab_only": false,
|
||||
"use_mmap": true,
|
||||
"use_mlock": false,
|
||||
"num_thread": 8
|
||||
}
|
||||
}'
|
||||
@@ -491,13 +496,11 @@ Generate the next message in a chat with a provided model. This is a streaming e
|
||||
- `model`: (required) the [model name](#model-names)
|
||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||
- `tools`: list of tools in JSON for the model to use if supported
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
The `message` object has the following fields:
|
||||
|
||||
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
|
||||
- `content`: the content of the message
|
||||
- `thinking`: (for thinking models) the model's thinking process
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
|
||||
@@ -955,8 +958,19 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
|
||||
|
||||
| Type | Recommended |
|
||||
| --- | :-: |
|
||||
| q2_K | |
|
||||
| q3_K_L | |
|
||||
| q3_K_M | |
|
||||
| q3_K_S | |
|
||||
| q4_0 | |
|
||||
| q4_1 | |
|
||||
| q4_K_M | * |
|
||||
| q4_K_S | |
|
||||
| q5_0 | |
|
||||
| q5_1 | |
|
||||
| q5_K_M | |
|
||||
| q5_K_S | |
|
||||
| q6_K | |
|
||||
| q8_0 | * |
|
||||
|
||||
### Examples
|
||||
@@ -1001,8 +1015,8 @@ Quantize a non-quantized model.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/create -d '{
|
||||
"model": "llama3.2:quantized",
|
||||
"from": "llama3.2:3b-instruct-fp16",
|
||||
"model": "llama3.1:quantized",
|
||||
"from": "llama3.1:8b-instruct-fp16",
|
||||
"quantize": "q4_K_M"
|
||||
}'
|
||||
```
|
||||
@@ -1012,14 +1026,12 @@ curl http://localhost:11434/api/create -d '{
|
||||
A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302}
|
||||
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552}
|
||||
{"status":"verifying conversion"}
|
||||
{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"}
|
||||
{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
|
||||
{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
|
||||
{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
|
||||
{"status":"quantizing F16 model to Q4_K_M"}
|
||||
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
|
||||
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
|
||||
{"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
|
||||
{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
|
||||
{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
|
||||
{"status":"writing manifest"}
|
||||
{"status":"success"}
|
||||
```
|
||||
@@ -1157,46 +1169,29 @@ A single JSON object will be returned.
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
|
||||
"model": "codellama:13b",
|
||||
"name": "codellama:13b",
|
||||
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
|
||||
"size": 7365960935,
|
||||
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
|
||||
"capabilities": [
|
||||
"completion"
|
||||
],
|
||||
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": [
|
||||
"qwen2"
|
||||
],
|
||||
"parameter_size": "7.6B",
|
||||
"quantization_level": "Q4_K_M"
|
||||
"family": "llama",
|
||||
"families": null,
|
||||
"parameter_size": "13B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
"model": "llama4:latest",
|
||||
"name": "llama3:latest",
|
||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||
"size": 3825819519,
|
||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": [
|
||||
"llama"
|
||||
],
|
||||
"parameter_size": "3.2B",
|
||||
"quantization_level": "Q4_K_M"
|
||||
"families": null,
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -118,7 +118,7 @@ To run tests, use `go test`:
|
||||
go test ./...
|
||||
```
|
||||
|
||||
> NOTE: In rare cirumstances, you may need to change a package using the new
|
||||
> NOTE: In rare cirumstances, you may nedd to change a package using the new
|
||||
> "synctest" package in go1.24.
|
||||
>
|
||||
> If you do not have the "synctest" package enabled, you will not see build or
|
||||
|
||||
@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens, unless you have a single GPU with <= 4 GB of VRAM, in which case it will default to 2048 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
@@ -31,7 +31,7 @@ OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||
To change this when using `ollama run`, use `/set parameter`:
|
||||
|
||||
```shell
|
||||
/set parameter num_ctx 4096
|
||||
/set parameter num_ctx 8192
|
||||
```
|
||||
|
||||
When using the API, specify the `num_ctx` parameter:
|
||||
@@ -41,7 +41,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama3.2",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"options": {
|
||||
"num_ctx": 4096
|
||||
"num_ctx": 8192
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -132,12 +132,22 @@ success
|
||||
|
||||
### Supported Quantizations
|
||||
|
||||
- `q4_0`
|
||||
- `q4_1`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q8_0`
|
||||
|
||||
#### K-means Quantizations
|
||||
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
|
||||
|
||||
## Sharing your model on ollama.com
|
||||
|
||||
@@ -150,6 +150,9 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
|
||||
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
|
||||
@@ -149,22 +149,9 @@ func Bool(k string) func() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// LogLevel returns the log level for the application.
|
||||
// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
|
||||
func LogLevel() slog.Level {
|
||||
level := slog.LevelInfo
|
||||
if s := Var("OLLAMA_DEBUG"); s != "" {
|
||||
if b, _ := strconv.ParseBool(s); b {
|
||||
level = slog.LevelDebug
|
||||
} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
|
||||
level = slog.Level(i * -4)
|
||||
}
|
||||
}
|
||||
|
||||
return level
|
||||
}
|
||||
|
||||
var (
|
||||
// Debug enabled additional debug information.
|
||||
Debug = Bool("OLLAMA_DEBUG")
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
@@ -182,9 +169,7 @@ var (
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
ContextLength = Int64("OLLAMA_CONTEXT_LENGTH", -1)
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -224,6 +209,8 @@ var (
|
||||
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
|
||||
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
|
||||
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
|
||||
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
|
||||
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
|
||||
)
|
||||
|
||||
func Uint64(key string, defaultValue uint64) func() uint64 {
|
||||
@@ -240,6 +227,20 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
|
||||
}
|
||||
}
|
||||
|
||||
func Int64(key string, defaultValue int64) func() int64 {
|
||||
return func() int64 {
|
||||
if s := Var(key); s != "" {
|
||||
if n, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
|
||||
} else {
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
// Set aside VRAM per GPU
|
||||
var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
|
||||
|
||||
@@ -251,7 +252,7 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
@@ -268,7 +269,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default 4096 or 2048 with low VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
|
||||
// Informational
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
func TestHost(t *testing.T) {
|
||||
@@ -280,9 +278,9 @@ func TestVar(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestContextLength(t *testing.T) {
|
||||
cases := map[string]uint{
|
||||
"": 4096,
|
||||
"2048": 2048,
|
||||
cases := map[string]int64{
|
||||
"": -1,
|
||||
"4096": 4096,
|
||||
}
|
||||
|
||||
for k, v := range cases {
|
||||
@@ -294,34 +292,3 @@ func TestContextLength(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevel(t *testing.T) {
|
||||
cases := map[string]slog.Level{
|
||||
// Default to INFO
|
||||
"": slog.LevelInfo,
|
||||
"false": slog.LevelInfo,
|
||||
"f": slog.LevelInfo,
|
||||
"0": slog.LevelInfo,
|
||||
|
||||
// True values enable Debug
|
||||
"true": slog.LevelDebug,
|
||||
"t": slog.LevelDebug,
|
||||
|
||||
// Positive values increase verbosity
|
||||
"1": slog.LevelDebug,
|
||||
"2": logutil.LevelTrace,
|
||||
|
||||
// Negative values decrease verbosity
|
||||
"-1": slog.LevelWarn,
|
||||
"-2": slog.LevelError,
|
||||
}
|
||||
|
||||
for k, v := range cases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_DEBUG", k)
|
||||
if i := LogLevel(); i != v {
|
||||
t.Errorf("%s: expected %d, got %d", k, v, i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
118
fs/ggml/ggml.go
118
fs/ggml/ggml.go
@@ -15,7 +15,6 @@ import (
|
||||
type GGML struct {
|
||||
container
|
||||
model
|
||||
Length int64
|
||||
}
|
||||
|
||||
type model interface {
|
||||
@@ -37,12 +36,12 @@ func (kv KV) ParameterCount() uint64 {
|
||||
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||
}
|
||||
|
||||
func (kv KV) FileType() FileType {
|
||||
func (kv KV) FileType() fileType {
|
||||
if t := kv.Uint("general.file_type"); t > 0 {
|
||||
return FileType(t)
|
||||
return fileType(t)
|
||||
}
|
||||
|
||||
return FileTypeUnknown
|
||||
return fileTypeUnknown
|
||||
}
|
||||
|
||||
func (kv KV) BlockCount() uint64 {
|
||||
@@ -126,8 +125,6 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"llama4",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -152,7 +149,7 @@ func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ..
|
||||
return val.(T)
|
||||
}
|
||||
|
||||
slog.Debug("key not found", "key", key, "default", defaultValue[0])
|
||||
slog.Warn("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
}
|
||||
|
||||
@@ -229,11 +226,7 @@ func (t Tensor) block() (n int) {
|
||||
}
|
||||
|
||||
func (t Tensor) blockSize() uint64 {
|
||||
return (TensorType)(t.Kind).BlockSize()
|
||||
}
|
||||
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
switch t.Kind {
|
||||
case
|
||||
0, // F32
|
||||
1, // F16
|
||||
@@ -259,77 +252,73 @@ func (t TensorType) BlockSize() uint64 {
|
||||
}
|
||||
|
||||
func (t Tensor) typeSize() uint64 {
|
||||
return TensorType(t.Kind).TypeSize()
|
||||
}
|
||||
blockSize := t.blockSize()
|
||||
|
||||
func (t TensorType) TypeSize() uint64 {
|
||||
blockSize := t.BlockSize()
|
||||
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
switch t.Kind {
|
||||
case 0: // FP32
|
||||
return 4
|
||||
case TensorTypeF16:
|
||||
case 1: // FP16
|
||||
return 2
|
||||
case TensorTypeQ4_0:
|
||||
case 2: // Q4_0
|
||||
return 2 + blockSize/2
|
||||
case TensorTypeQ4_1:
|
||||
case 3: // Q4_1
|
||||
return 2 + 2 + blockSize/2
|
||||
case TensorTypeQ5_0:
|
||||
case 6: // Q5_0
|
||||
return 2 + 4 + blockSize/2
|
||||
case TensorTypeQ5_1:
|
||||
case 7: // Q5_1
|
||||
return 2 + 2 + 4 + blockSize/2
|
||||
case TensorTypeQ8_0:
|
||||
case 8: // Q8_0
|
||||
return 2 + blockSize
|
||||
case TensorTypeQ8_1:
|
||||
case 9: // Q8_1
|
||||
return 2 + 2 + blockSize
|
||||
case TensorTypeQ2_K:
|
||||
case 10: // Q2_K
|
||||
return blockSize/16 + blockSize/4 + 2 + 2
|
||||
case TensorTypeQ3_K:
|
||||
case 11: // Q3_K
|
||||
return blockSize/8 + blockSize/4 + 12 + 2
|
||||
case TensorTypeQ4_K:
|
||||
case 12: // Q4_K
|
||||
return 2 + 2 + 12 + blockSize/2
|
||||
case TensorTypeQ5_K:
|
||||
case 13: // Q5_K
|
||||
return 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case TensorTypeQ6_K:
|
||||
case 14: // Q6_K
|
||||
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
case TensorTypeQ8_K:
|
||||
case 15: // Q8_K
|
||||
return 4 + blockSize + 2*blockSize/16
|
||||
case tensorTypeIQ2_XXS:
|
||||
case 16: // IQ2_XXS
|
||||
return 2 + 2*blockSize/8
|
||||
case tensorTypeIQ2_XS:
|
||||
case 17: // IQ2_XS
|
||||
return 2 + 2*blockSize/8 + blockSize/32
|
||||
case tensorTypeIQ3_XXS:
|
||||
case 18: // IQ3_XXS
|
||||
return 2 + blockSize/4 + blockSize/8
|
||||
case tensorTypeIQ1_S:
|
||||
case 19: // IQ1_S
|
||||
return 2 + blockSize/8 + blockSize/16
|
||||
case tensorTypeIQ4_NL:
|
||||
case 20: // IQ4_NL
|
||||
return 2 + blockSize/2
|
||||
case tensorTypeIQ3_S:
|
||||
case 21: // IQ3_S
|
||||
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
|
||||
case tensorTypeIQ2_S:
|
||||
case 22: // IQ2_S
|
||||
return 2 + blockSize/4 + blockSize/16
|
||||
case tensorTypeIQ4_XS:
|
||||
case 23: // IQ4_XS
|
||||
return 2 + 2 + blockSize/2 + blockSize/64
|
||||
case TensorTypeI8:
|
||||
case 24: // I8
|
||||
return 1
|
||||
case TensorTypeI16:
|
||||
case 25: // I16
|
||||
return 2
|
||||
case TensorTypeI32:
|
||||
case 26: // I32
|
||||
return 4
|
||||
case TensorTypeI64:
|
||||
case 27: // I64
|
||||
return 8
|
||||
case TensorTypeF64:
|
||||
case 28: // F64
|
||||
return 8
|
||||
case tensorTypeIQ1_M:
|
||||
case 29: // IQ1_M
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
case 30: // BF16
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t Tensor) Elements() uint64 {
|
||||
func (t Tensor) parameters() uint64 {
|
||||
var count uint64 = 1
|
||||
for _, n := range t.Shape {
|
||||
count *= n
|
||||
@@ -338,11 +327,11 @@ func (t Tensor) Elements() uint64 {
|
||||
}
|
||||
|
||||
func (t Tensor) Size() uint64 {
|
||||
return t.Elements() * t.typeSize() / t.blockSize()
|
||||
return t.parameters() * t.typeSize() / t.blockSize()
|
||||
}
|
||||
|
||||
func (t Tensor) Type() string {
|
||||
return TensorType(t.Kind).String()
|
||||
return fileType(t.Kind).String()
|
||||
}
|
||||
|
||||
type container interface {
|
||||
@@ -387,12 +376,12 @@ func DetectContentType(b []byte) string {
|
||||
//
|
||||
// It collects array values for arrays with a size less than or equal to
|
||||
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
|
||||
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
|
||||
|
||||
var magic uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var c container
|
||||
@@ -402,25 +391,24 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
case FILE_MAGIC_GGUF_BE:
|
||||
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
|
||||
default:
|
||||
return nil, errors.New("invalid file magic")
|
||||
return nil, 0, errors.New("invalid file magic")
|
||||
}
|
||||
|
||||
model, err := c.Decode(rs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// final model type
|
||||
return &GGML{
|
||||
container: c,
|
||||
model: model,
|
||||
Length: offset,
|
||||
}, nil
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
@@ -492,7 +480,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
var ropeFreqsCount uint64
|
||||
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
|
||||
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
|
||||
ropeFreqsCount = ropeFreqsWeights.Elements()
|
||||
ropeFreqsCount = ropeFreqsWeights.parameters()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,20 +640,6 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
graphSize = 4 * (imageSize*imageSize*numChannels +
|
||||
embeddingLength*patchSize +
|
||||
numPatches*numPatches*headCount)
|
||||
case "qwen25vl":
|
||||
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
|
||||
|
||||
numPatches := maxPixels / (patchSize * patchSize)
|
||||
|
||||
graphSize = 4 * (maxPixels*numChannels + // Original image storage
|
||||
// Normalized pixels
|
||||
maxPixels*numChannels +
|
||||
// Patches storage (numPatches * channels * patchSize^2)
|
||||
numPatches*numChannels*patchSize*patchSize +
|
||||
// Self-attention calculations
|
||||
numPatches*numPatches*headCount +
|
||||
// Additional buffer for processing
|
||||
embeddingLength*numPatches)
|
||||
case "llama4":
|
||||
// vision graph is computed independently in the same schedule
|
||||
// and is negligible compared to the worst case text graph
|
||||
|
||||
110
fs/ggml/gguf.go
110
fs/ggml/gguf.go
@@ -9,12 +9,8 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type containerGGUF struct {
|
||||
@@ -229,7 +225,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
}
|
||||
|
||||
llm.tensors = append(llm.tensors, &tensor)
|
||||
llm.parameters += tensor.Elements()
|
||||
llm.parameters += tensor.parameters()
|
||||
}
|
||||
|
||||
// patch KV with parameter count
|
||||
@@ -492,38 +488,25 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if t == ggufTypeString {
|
||||
for _, e := range any(s).([]string) {
|
||||
if err := binary.Write(w, binary.LittleEndian, uint64(len(e))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, []byte(e)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
|
||||
alignment := kv.Uint("general.alignment", 32)
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -531,12 +514,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
slices.SortStableFunc(ts, func(a, b Tensor) int {
|
||||
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||
return 1
|
||||
} else if i > 0 && j < 0 {
|
||||
@@ -547,34 +530,21 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
})
|
||||
|
||||
var s uint64
|
||||
for i := range ts {
|
||||
ts[i].Offset = s
|
||||
if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
|
||||
for _, t := range ts {
|
||||
t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
|
||||
if err := ggufWriteTensorInfo(ws, t); err != nil {
|
||||
return err
|
||||
}
|
||||
s += ts[i].Size()
|
||||
s += uint64(ggufPadding(int64(s), int64(alignment)))
|
||||
s += t.Size()
|
||||
}
|
||||
|
||||
offset, err := f.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offset += ggufPadding(offset, int64(alignment))
|
||||
|
||||
var g errgroup.Group
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
|
||||
for _, t := range ts {
|
||||
t := t
|
||||
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
|
||||
g.Go(func() error {
|
||||
_, err := t.WriteTo(w)
|
||||
if err := ggufWriteTensor(ws, t, int64(alignment)); err != nil {
|
||||
return err
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
@@ -589,10 +559,8 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
|
||||
var err error
|
||||
switch v := v.(type) {
|
||||
case uint32, FileType:
|
||||
case uint32:
|
||||
err = writeGGUF(ws, ggufTypeUint32, v)
|
||||
case uint64:
|
||||
err = writeGGUF(ws, ggufTypeUint64, v)
|
||||
case float32:
|
||||
err = writeGGUF(ws, ggufTypeFloat32, v)
|
||||
case bool:
|
||||
@@ -601,20 +569,32 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
err = writeGGUFString(ws, v)
|
||||
case []int32:
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v)
|
||||
case *array[int32]:
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
|
||||
case []uint32:
|
||||
err = writeGGUFArray(ws, ggufTypeUint32, v)
|
||||
case *array[uint32]:
|
||||
err = writeGGUFArray(ws, ggufTypeUint32, v.values)
|
||||
case []float32:
|
||||
err = writeGGUFArray(ws, ggufTypeFloat32, v)
|
||||
case *array[float32]:
|
||||
err = writeGGUFArray(ws, ggufTypeFloat32, v.values)
|
||||
case []string:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, e := range v {
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
|
||||
@@ -622,7 +602,7 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
|
||||
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
||||
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
|
||||
return err
|
||||
@@ -649,6 +629,20 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
|
||||
return binary.Write(ws, binary.LittleEndian, t.Offset)
|
||||
}
|
||||
|
||||
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
|
||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = t.WriteTo(ws)
|
||||
return err
|
||||
}
|
||||
|
||||
func ggufPadding(offset, align int64) int64 {
|
||||
return (align - offset%align) % align
|
||||
}
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
w, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, []*Tensor{
|
||||
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.KV(), KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(36),
|
||||
}); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.Tensors(), Tensors{
|
||||
Offset: 336,
|
||||
items: []*Tensor{
|
||||
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
343
fs/ggml/type.go
343
fs/ggml/type.go
@@ -1,31 +1,26 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
import "fmt"
|
||||
|
||||
// FileType is the Go equivalent to llama_ftype used for gguf file typing
|
||||
type FileType uint32
|
||||
type fileType uint32
|
||||
|
||||
const (
|
||||
FileTypeF32 FileType = iota
|
||||
FileTypeF16
|
||||
fileTypeF32 fileType = iota
|
||||
fileTypeF16
|
||||
fileTypeQ4_0
|
||||
fileTypeQ4_1
|
||||
fileTypeQ4_1_F16 // unused by GGML
|
||||
fileTypeQ4_2 // unused by GGML
|
||||
fileTypeQ4_3 // unused by GGML
|
||||
FileTypeQ8_0
|
||||
fileTypeQ4_1_F16
|
||||
fileTypeQ4_2 // unused
|
||||
fileTypeQ4_3 // unused
|
||||
fileTypeQ8_0
|
||||
fileTypeQ5_0
|
||||
fileTypeQ5_1
|
||||
fileTypeQ2_K
|
||||
fileTypeQ3_K_S
|
||||
fileTypeQ3_K_M
|
||||
fileTypeQ3_K_L
|
||||
FileTypeQ4_K_S
|
||||
FileTypeQ4_K_M
|
||||
fileTypeQ4_K_S
|
||||
fileTypeQ4_K_M
|
||||
fileTypeQ5_K_S
|
||||
fileTypeQ5_K_M
|
||||
fileTypeQ6_K
|
||||
@@ -42,62 +37,93 @@ const (
|
||||
fileTypeIQ2_M
|
||||
fileTypeIQ4_XS
|
||||
fileTypeIQ1_M
|
||||
FileTypeBF16
|
||||
fileTypeQ4_0_4_4 // unused by GGML
|
||||
fileTypeQ4_0_4_8 // unused by GGML
|
||||
fileTypeQ4_0_8_8 // unused by GGML
|
||||
fileTypeTQ1_0
|
||||
fileTypeTQ2_0
|
||||
fileTypeBF16
|
||||
|
||||
FileTypeUnknown = 1024
|
||||
fileTypeUnknown
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
// Only Ollama supported types are considered valid
|
||||
func ParseFileType(s string) (FileType, error) {
|
||||
func ParseFileType(s string) (fileType, error) {
|
||||
switch s {
|
||||
case "F32":
|
||||
return FileTypeF32, nil
|
||||
return fileTypeF32, nil
|
||||
case "F16":
|
||||
return FileTypeF16, nil
|
||||
return fileTypeF16, nil
|
||||
case "Q4_0":
|
||||
return fileTypeQ4_0, nil
|
||||
case "Q4_1":
|
||||
return fileTypeQ4_1, nil
|
||||
case "Q4_1_F16":
|
||||
return fileTypeQ4_1_F16, nil
|
||||
case "Q8_0":
|
||||
return FileTypeQ8_0, nil
|
||||
return fileTypeQ8_0, nil
|
||||
case "Q5_0":
|
||||
return fileTypeQ5_0, nil
|
||||
case "Q5_1":
|
||||
return fileTypeQ5_1, nil
|
||||
case "Q2_K":
|
||||
return fileTypeQ2_K, nil
|
||||
case "Q3_K_S":
|
||||
return fileTypeQ3_K_S, nil
|
||||
case "Q3_K_M":
|
||||
return fileTypeQ3_K_M, nil
|
||||
case "Q3_K_L":
|
||||
return fileTypeQ3_K_L, nil
|
||||
case "Q4_K_S":
|
||||
return FileTypeQ4_K_S, nil
|
||||
case "Q4_K_M", "Q4_K":
|
||||
return FileTypeQ4_K_M, nil
|
||||
return fileTypeQ4_K_S, nil
|
||||
case "Q4_K_M":
|
||||
return fileTypeQ4_K_M, nil
|
||||
case "Q5_K_S":
|
||||
return fileTypeQ5_K_S, nil
|
||||
case "Q5_K_M":
|
||||
return fileTypeQ5_K_M, nil
|
||||
case "Q6_K":
|
||||
return fileTypeQ6_K, nil
|
||||
case "IQ2_XXS":
|
||||
return fileTypeIQ2_XXS, nil
|
||||
case "IQ2_XS":
|
||||
return fileTypeIQ2_XS, nil
|
||||
case "Q2_K_S":
|
||||
return fileTypeQ2_K_S, nil
|
||||
case "IQ3_XS":
|
||||
return fileTypeIQ3_XS, nil
|
||||
case "IQ3_XXS":
|
||||
return fileTypeIQ3_XXS, nil
|
||||
case "IQ1_S":
|
||||
return fileTypeIQ1_S, nil
|
||||
case "IQ4_NL":
|
||||
return fileTypeIQ4_NL, nil
|
||||
case "IQ3_S":
|
||||
return fileTypeIQ3_S, nil
|
||||
case "IQ3_M":
|
||||
return fileTypeIQ3_M, nil
|
||||
case "IQ2_S":
|
||||
return fileTypeIQ2_S, nil
|
||||
case "IQ2_M":
|
||||
return fileTypeIQ2_M, nil
|
||||
case "IQ4_XS":
|
||||
return fileTypeIQ4_XS, nil
|
||||
case "IQ1_M":
|
||||
return fileTypeIQ1_M, nil
|
||||
case "BF16":
|
||||
return FileTypeBF16, nil
|
||||
return fileTypeBF16, nil
|
||||
default:
|
||||
supportedFileTypes := []FileType{
|
||||
FileTypeF32,
|
||||
FileTypeF16,
|
||||
FileTypeQ4_K_S,
|
||||
FileTypeQ4_K_M,
|
||||
FileTypeQ8_0,
|
||||
// fsggml.FileTypeBF16, // TODO
|
||||
}
|
||||
strs := make([]string, len(supportedFileTypes))
|
||||
for i := range supportedFileTypes {
|
||||
strs[i] = supportedFileTypes[i].String()
|
||||
}
|
||||
|
||||
return FileTypeUnknown, fmt.Errorf("unsupported quantization type %s - supported types are %s", s, strings.Join(strs, ", "))
|
||||
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (t FileType) String() string {
|
||||
// Note: this routine will return a broader set of file types for existing models
|
||||
func (t fileType) String() string {
|
||||
switch t {
|
||||
case FileTypeF32:
|
||||
case fileTypeF32:
|
||||
return "F32"
|
||||
case FileTypeF16:
|
||||
case fileTypeF16:
|
||||
return "F16"
|
||||
case fileTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case fileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case FileTypeQ8_0:
|
||||
case fileTypeQ4_1_F16:
|
||||
return "Q4_1_F16"
|
||||
case fileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case fileTypeQ5_0:
|
||||
return "Q5_0"
|
||||
@@ -111,9 +137,9 @@ func (t FileType) String() string {
|
||||
return "Q3_K_M"
|
||||
case fileTypeQ3_K_L:
|
||||
return "Q3_K_L"
|
||||
case FileTypeQ4_K_S:
|
||||
case fileTypeQ4_K_S:
|
||||
return "Q4_K_S"
|
||||
case FileTypeQ4_K_M:
|
||||
case fileTypeQ4_K_M:
|
||||
return "Q4_K_M"
|
||||
case fileTypeQ5_K_S:
|
||||
return "Q5_K_S"
|
||||
@@ -121,198 +147,39 @@ func (t FileType) String() string {
|
||||
return "Q5_K_M"
|
||||
case fileTypeQ6_K:
|
||||
return "Q6_K"
|
||||
case fileTypeIQ2_XXS:
|
||||
return "IQ2_XXS"
|
||||
case fileTypeIQ2_XS:
|
||||
return "IQ2_XS"
|
||||
case fileTypeQ2_K_S:
|
||||
return "Q2_K_S"
|
||||
case FileTypeBF16:
|
||||
case fileTypeIQ3_XS:
|
||||
return "IQ3_XS"
|
||||
case fileTypeIQ3_XXS:
|
||||
return "IQ3_XXS"
|
||||
case fileTypeIQ1_S:
|
||||
return "IQ1_S"
|
||||
case fileTypeIQ4_NL:
|
||||
return "IQ4_NL"
|
||||
case fileTypeIQ3_S:
|
||||
return "IQ3_S"
|
||||
case fileTypeIQ3_M:
|
||||
return "IQ3_M"
|
||||
case fileTypeIQ2_S:
|
||||
return "IQ2_S"
|
||||
case fileTypeIQ4_XS:
|
||||
return "IQ4_XS"
|
||||
case fileTypeIQ2_M:
|
||||
return "IQ2_M"
|
||||
case fileTypeIQ1_M:
|
||||
return "IQ1_M"
|
||||
case fileTypeBF16:
|
||||
return "BF16"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t FileType) Value() uint32 {
|
||||
func (t fileType) Value() uint32 {
|
||||
return uint32(t)
|
||||
}
|
||||
|
||||
func (ftype FileType) ToTensorType() TensorType {
|
||||
switch ftype {
|
||||
case FileTypeF32:
|
||||
return TensorTypeF32
|
||||
case FileTypeF16:
|
||||
return TensorTypeF16
|
||||
case fileTypeQ4_0:
|
||||
return TensorTypeQ4_0
|
||||
case fileTypeQ4_1:
|
||||
return TensorTypeQ4_1
|
||||
case FileTypeQ8_0:
|
||||
return TensorTypeQ8_0
|
||||
case fileTypeQ5_0:
|
||||
return TensorTypeQ5_0
|
||||
case fileTypeQ5_1:
|
||||
return TensorTypeQ5_1
|
||||
case fileTypeQ2_K:
|
||||
return TensorTypeQ2_K
|
||||
case fileTypeQ3_K_S:
|
||||
return TensorTypeQ3_K
|
||||
case fileTypeQ3_K_M:
|
||||
return TensorTypeQ3_K
|
||||
case fileTypeQ3_K_L:
|
||||
return TensorTypeQ3_K
|
||||
case FileTypeQ4_K_S:
|
||||
return TensorTypeQ4_K
|
||||
case FileTypeQ4_K_M:
|
||||
return TensorTypeQ4_K
|
||||
case fileTypeQ5_K_S:
|
||||
return TensorTypeQ5_K
|
||||
case fileTypeQ5_K_M:
|
||||
return TensorTypeQ5_K
|
||||
case fileTypeQ6_K:
|
||||
return TensorTypeQ6_K
|
||||
case fileTypeQ2_K_S:
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
}
|
||||
}
|
||||
|
||||
// TensorType is equivalent to ggml_type for individual tensor types
|
||||
// Note: these are not the same as FileType
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
tensorTypeQ4_2 // unused by GGML
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
tensorTypeIQ2_XXS // not supported by ollama
|
||||
tensorTypeIQ2_XS // not supported by ollama
|
||||
tensorTypeIQ3_XXS // not supported by ollama
|
||||
tensorTypeIQ1_S // not supported by ollama
|
||||
tensorTypeIQ4_NL // not supported by ollama
|
||||
tensorTypeIQ3_S // not supported by ollama
|
||||
tensorTypeIQ2_S // not supported by ollama
|
||||
tensorTypeIQ4_XS // not supported by ollama
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
tensorTypeIQ1_M // not supported by ollama
|
||||
TensorTypeBF16
|
||||
tensorTypeQ4_0_4_4 // unused by GGML
|
||||
tensorTypeQ4_0_4_8 // unused by GGML
|
||||
tensorTypeQ4_0_8_8 // unused by GGML
|
||||
tensorTypeTQ1_0 // not supported by ollama
|
||||
tensorTypeTQ2_0 // not supported by ollama
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
// Only Ollama supported types are considered valid
|
||||
func ParseTensorType(s string) (TensorType, error) {
|
||||
switch s {
|
||||
case "F32":
|
||||
return TensorTypeF32, nil
|
||||
case "F16":
|
||||
return TensorTypeF16, nil
|
||||
case "Q4_0":
|
||||
return TensorTypeQ4_0, nil
|
||||
case "Q4_1":
|
||||
return TensorTypeQ4_1, nil
|
||||
case "Q5_0":
|
||||
return TensorTypeQ5_0, nil
|
||||
case "Q5_1":
|
||||
return TensorTypeQ5_1, nil
|
||||
case "Q8_0":
|
||||
return TensorTypeQ8_0, nil
|
||||
case "Q8_1":
|
||||
return TensorTypeQ8_1, nil
|
||||
case "Q2_K":
|
||||
return TensorTypeQ2_K, nil
|
||||
case "Q3_K":
|
||||
return TensorTypeQ3_K, nil
|
||||
case "Q4_K":
|
||||
return TensorTypeQ4_K, nil
|
||||
case "Q5_K":
|
||||
return TensorTypeQ5_K, nil
|
||||
case "Q6_K":
|
||||
return TensorTypeQ6_K, nil
|
||||
case "Q8_K":
|
||||
return TensorTypeQ8_K, nil
|
||||
case "F64":
|
||||
return TensorTypeF64, nil
|
||||
case "BF16":
|
||||
return TensorTypeBF16, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported quantization type %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) IsQuantized() bool {
|
||||
switch t {
|
||||
case TensorTypeF32, TensorTypeF16, TensorTypeBF16:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) RowSize(ne uint64) uint64 {
|
||||
return t.TypeSize() * ne / t.BlockSize()
|
||||
}
|
||||
|
||||
func (t TensorType) String() string {
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return "F32"
|
||||
case TensorTypeF16:
|
||||
return "F16"
|
||||
case TensorTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case TensorTypeQ5_0:
|
||||
return "Q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "Q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "Q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "Q2_K"
|
||||
case TensorTypeQ3_K:
|
||||
return "Q3_K"
|
||||
case TensorTypeQ4_K:
|
||||
return "Q4_K"
|
||||
case TensorTypeQ5_K:
|
||||
return "Q5_K"
|
||||
case TensorTypeQ6_K:
|
||||
return "Q6_K"
|
||||
case TensorTypeQ8_K:
|
||||
return "Q8_K"
|
||||
case TensorTypeF64:
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
350
fs/gguf/gguf.go
350
fs/gguf/gguf.go
@@ -1,350 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
typeUint8 uint32 = iota
|
||||
typeInt8
|
||||
typeUint16
|
||||
typeInt16
|
||||
typeUint32
|
||||
typeInt32
|
||||
typeFloat32
|
||||
typeBool
|
||||
typeString
|
||||
typeArray
|
||||
typeUint64
|
||||
typeInt64
|
||||
typeFloat64
|
||||
)
|
||||
|
||||
var ErrUnsupported = errors.New("unsupported")
|
||||
|
||||
type File struct {
|
||||
Magic [4]byte
|
||||
Version uint32
|
||||
|
||||
keyValues *lazy[KeyValue]
|
||||
tensors *lazy[TensorInfo]
|
||||
offset int64
|
||||
|
||||
file *os.File
|
||||
reader *readSeeker
|
||||
bts []byte
|
||||
}
|
||||
|
||||
func Open(path string) (f *File, err error) {
|
||||
f = &File{bts: make([]byte, 4096)}
|
||||
f.file, err = os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.reader = newReadSeeker(f.file, 32<<10)
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bytes.Equal(f.Magic[:], []byte("gguf")) {
|
||||
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
|
||||
}
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.Version != 3 {
|
||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||
}
|
||||
|
||||
f.tensors, err = newLazy(f, f.readTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.tensors.doneFunc = func() error {
|
||||
offset, err := f.reader.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||
f.offset = offset + (alignment-offset%alignment)%alignment
|
||||
return nil
|
||||
}
|
||||
|
||||
f.keyValues, err = newLazy(f, f.readKeyValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) readTensor() (TensorInfo, error) {
|
||||
name, err := readString(f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
dims, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
shape := make([]uint64, dims)
|
||||
for i := range dims {
|
||||
shape[i], err = read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
type_, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
offset, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
return TensorInfo{
|
||||
Name: name,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Type: TensorType(type_),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *File) readKeyValue() (KeyValue, error) {
|
||||
key, err := readString(f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
value, err := func() (any, error) {
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return read[uint8](f)
|
||||
case typeInt8:
|
||||
return read[int8](f)
|
||||
case typeUint16:
|
||||
return read[uint16](f)
|
||||
case typeInt16:
|
||||
return read[int16](f)
|
||||
case typeUint32:
|
||||
return read[uint32](f)
|
||||
case typeInt32:
|
||||
return read[int32](f)
|
||||
case typeUint64:
|
||||
return read[uint64](f)
|
||||
case typeInt64:
|
||||
return read[int64](f)
|
||||
case typeFloat32:
|
||||
return read[float32](f)
|
||||
case typeFloat64:
|
||||
return read[float64](f)
|
||||
case typeBool:
|
||||
return read[bool](f)
|
||||
case typeString:
|
||||
return readString(f)
|
||||
case typeArray:
|
||||
return readArray(f)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
return KeyValue{
|
||||
Key: key,
|
||||
Value: Value{value},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func read[T any](f *File) (t T, err error) {
|
||||
err = binary.Read(f.reader, binary.LittleEndian, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func readString(f *File) (string, error) {
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if int(n) > len(f.bts) {
|
||||
f.bts = make([]byte, n)
|
||||
}
|
||||
|
||||
bts := f.bts[:n]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer clear(bts)
|
||||
|
||||
return string(bts), nil
|
||||
}
|
||||
|
||||
func readArray(f *File) (any, error) {
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return readArrayData[uint8](f, n)
|
||||
case typeInt8:
|
||||
return readArrayData[int8](f, n)
|
||||
case typeUint16:
|
||||
return readArrayData[uint16](f, n)
|
||||
case typeInt16:
|
||||
return readArrayData[int16](f, n)
|
||||
case typeUint32:
|
||||
return readArrayData[uint32](f, n)
|
||||
case typeInt32:
|
||||
return readArrayData[int32](f, n)
|
||||
case typeUint64:
|
||||
return readArrayData[uint64](f, n)
|
||||
case typeInt64:
|
||||
return readArrayData[int64](f, n)
|
||||
case typeFloat32:
|
||||
return readArrayData[float32](f, n)
|
||||
case typeFloat64:
|
||||
return readArrayData[float64](f, n)
|
||||
case typeBool:
|
||||
return readArrayData[bool](f, n)
|
||||
case typeString:
|
||||
return readArrayString(f, n)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}
|
||||
|
||||
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||
s = make([]T, n)
|
||||
for i := range n {
|
||||
e, err := read[T](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||
s = make([]string, n)
|
||||
for i := range n {
|
||||
e, err := readString(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.keyValues.stop()
|
||||
f.tensors.stop()
|
||||
return f.file.Close()
|
||||
}
|
||||
|
||||
func (f *File) KeyValue(key string) KeyValue {
|
||||
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
|
||||
key = f.KeyValue("general.architecture").String() + "." + key
|
||||
}
|
||||
|
||||
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
|
||||
return kv.Key == key
|
||||
}); index >= 0 {
|
||||
return f.keyValues.values[index]
|
||||
}
|
||||
|
||||
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
|
||||
if keyValue.Key == key {
|
||||
return keyValue
|
||||
}
|
||||
}
|
||||
|
||||
return KeyValue{}
|
||||
}
|
||||
|
||||
func (f *File) NumKeyValues() int {
|
||||
return int(f.keyValues.count)
|
||||
}
|
||||
|
||||
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||
return f.keyValues.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorInfo(name string) TensorInfo {
|
||||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||
return t.Name == name
|
||||
}); index >= 0 {
|
||||
return f.tensors.values[index]
|
||||
}
|
||||
|
||||
// fast-forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
|
||||
return TensorInfo{}
|
||||
}
|
||||
|
||||
func (f *File) NumTensors() int {
|
||||
return int(f.tensors.count)
|
||||
}
|
||||
|
||||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||
// fast forward through key values if we haven't already
|
||||
f.keyValues.rest()
|
||||
return f.tensors.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
t := f.TensorInfo(name)
|
||||
if t.NumBytes() == 0 {
|
||||
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
|
||||
}
|
||||
|
||||
// fast forward through tensor info if we haven't already
|
||||
_ = f.tensors.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
// Setup
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test.gguf")
|
||||
|
||||
if err := createTestGGUFFile(tempFile, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
"general.alignment": int64(32),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := Open(tempFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Test
|
||||
if got := f.NumKeyValues(); got != 2 {
|
||||
t.Errorf("NumKeyValues() = %d, want %d", got, 2)
|
||||
}
|
||||
if got := f.NumTensors(); got != 2 {
|
||||
t.Errorf("NumTensors() = %d, want %d", got, 2)
|
||||
}
|
||||
archKV := f.KeyValue("general.architecture")
|
||||
if archKV.Key == "" {
|
||||
t.Error("KeyValue(\"general.architecture\") not found")
|
||||
}
|
||||
if got := archKV.String(); got != "llama" {
|
||||
t.Errorf("KeyValue(\"general.architecture\").String() = %q, want %q", got, "llama")
|
||||
}
|
||||
alignKV := f.KeyValue("general.alignment")
|
||||
if alignKV.Key == "" {
|
||||
t.Error("KeyValue(\"general.alignment\") not found")
|
||||
}
|
||||
if got := alignKV.Int(); got != 32 {
|
||||
t.Errorf("KeyValue(\"general.alignment\").Int() = %d, want %d", got, 32)
|
||||
}
|
||||
expectedTensorNames := []string{"token_embd.weight", "output.weight"}
|
||||
var gotTensorNames []string
|
||||
for _, tensor := range f.TensorInfos() {
|
||||
gotTensorNames = append(gotTensorNames, tensor.Name)
|
||||
}
|
||||
if !slices.Equal(gotTensorNames, expectedTensorNames) {
|
||||
t.Errorf("tensor names = %v, want %v", gotTensorNames, expectedTensorNames)
|
||||
}
|
||||
tokenTensor := f.TensorInfo("token_embd.weight")
|
||||
if tokenTensor.Name != "token_embd.weight" {
|
||||
t.Error("TensorInfo(\"token_embd.weight\") not found")
|
||||
}
|
||||
if len(tokenTensor.Shape) == 0 {
|
||||
t.Error("TensorInfo(\"token_embd.weight\") has empty shape")
|
||||
}
|
||||
outputTensor := f.TensorInfo("output.weight")
|
||||
if outputTensor.Name != "output.weight" {
|
||||
t.Error("TensorInfo(\"output.weight\") not found")
|
||||
}
|
||||
if len(outputTensor.Shape) == 0 {
|
||||
t.Error("TensorInfo(\"output.weight\") has empty shape")
|
||||
}
|
||||
var gotKeyCount int
|
||||
for _, kv := range f.KeyValues() {
|
||||
gotKeyCount++
|
||||
if kv.Key == "" {
|
||||
t.Error("found key value with empty key")
|
||||
}
|
||||
}
|
||||
if gotKeyCount != 2 {
|
||||
t.Errorf("iterated key count = %d, want %d", gotKeyCount, 2)
|
||||
}
|
||||
tensorInfo, reader, err := f.TensorReader("token_embd.weight")
|
||||
if err != nil {
|
||||
t.Errorf("TensorReader(\"token_embd.weight\") error: %v", err)
|
||||
}
|
||||
if tensorInfo.Name != "token_embd.weight" {
|
||||
t.Errorf("TensorReader returned wrong tensor: %q", tensorInfo.Name)
|
||||
}
|
||||
if reader == nil {
|
||||
t.Error("TensorReader returned nil reader")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
// Create benchmark test file
|
||||
tempDir := b.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "benchmark.gguf")
|
||||
|
||||
if err := createTestGGUFFile(tempFile, map[string]any{
|
||||
"general.architecture": "llama",
|
||||
"general.alignment": int64(32),
|
||||
}, []testTensorInfo{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
|
||||
}); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Get file info for reporting
|
||||
info, err := os.Stat(tempFile)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Logf("Benchmark file size: %d bytes", info.Size())
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
f, err := Open(tempFile)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Access some data to ensure it's actually being read
|
||||
_ = f.KeyValue("general.architecture").String()
|
||||
_ = f.KeyValue("general.alignment").Int()
|
||||
_ = f.NumTensors()
|
||||
_ = f.NumKeyValues()
|
||||
|
||||
// Iterate through some tensors
|
||||
count := 0
|
||||
for _, tensor := range f.TensorInfos() {
|
||||
_ = tensor.Name
|
||||
count++
|
||||
if count >= 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create test GGUF files
|
||||
func createTestGGUFFile(path string, keyValues map[string]any, tensors []testTensorInfo) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write GGUF magic
|
||||
if _, err := file.Write([]byte("GGUF")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write version
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write tensor count
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensors))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write metadata count
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(keyValues))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write metadata
|
||||
for key, value := range keyValues {
|
||||
if err := writeKeyValue(file, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write tensor info
|
||||
for _, tensor := range tensors {
|
||||
if err := writeTensorInfo(file, tensor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write some dummy tensor data
|
||||
dummyData := make([]byte, 1024)
|
||||
file.Write(dummyData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type testTensorInfo struct {
|
||||
Name string
|
||||
Shape []uint64
|
||||
Type uint32
|
||||
}
|
||||
|
||||
func writeKeyValue(file *os.File, key string, value any) error {
|
||||
// Write key length and key
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(key))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(key)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write value based on type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := file.Write([]byte(v))
|
||||
return err
|
||||
case int64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case bool:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeBool); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case float64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(file, binary.LittleEndian, v)
|
||||
case []string:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, s := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(s))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []int64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, i := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []float64:
|
||||
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, f := range v {
|
||||
if err := binary.Write(file, binary.LittleEndian, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
func writeTensorInfo(file *os.File, tensor testTensorInfo) error {
|
||||
// Write tensor name
|
||||
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensor.Name))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.Write([]byte(tensor.Name)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write dimensions
|
||||
if err := binary.Write(file, binary.LittleEndian, uint32(len(tensor.Shape))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, dim := range tensor.Shape {
|
||||
if err := binary.Write(file, binary.LittleEndian, dim); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write type
|
||||
if err := binary.Write(file, binary.LittleEndian, tensor.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write offset (dummy value)
|
||||
return binary.Write(file, binary.LittleEndian, uint64(0))
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type KeyValue struct {
|
||||
Key string
|
||||
Value
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||
func (v Value) Int() int64 {
|
||||
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
||||
func (v Value) Ints() (i64s []int64) {
|
||||
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
||||
func (v Value) Uint() uint64 {
|
||||
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
||||
func (v Value) Uints() (u64s []uint64) {
|
||||
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Float returns Value as a float. If it is not a float, it returns 0.
|
||||
func (v Value) Float() float64 {
|
||||
return value[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
||||
func (v Value) Floats() (f64s []float64) {
|
||||
return values[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
||||
func (v Value) Bool() bool {
|
||||
return value[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
||||
func (v Value) Bools() (bools []bool) {
|
||||
return values[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// String returns Value as a string. If it is not a string, it returns an empty string.
|
||||
func (v Value) String() string {
|
||||
return value[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
||||
func (v Value) Strings() (strings []string) {
|
||||
return values[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// IsNil checks if the Value is nil. It returns true if the value is nil or if it is a nil pointer, interface, slice, map, channel, or function.
|
||||
func (v Value) IsNil() bool {
|
||||
if v.value == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for nil pointers, interfaces, slices, maps, channels, and functions
|
||||
rv := reflect.ValueOf(v.value)
|
||||
switch rv.Kind() {
|
||||
case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
|
||||
return rv.IsNil()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,208 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
|
||||
for key, value := range values {
|
||||
if key == name {
|
||||
matched = value
|
||||
} else {
|
||||
unmatched = append(unmatched, value...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
|
||||
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
|
||||
"float64": {float32(42), float64(42)},
|
||||
"string": {"42", "hello"},
|
||||
"bool": {true, false},
|
||||
}
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
matched, unmatched := split("int64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 42 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 0 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 42 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 0 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
matched, unmatched := split("float64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 42 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 0 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
matched, unmatched := split("string", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != v {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != "" {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
matched, unmatched := split("bool", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != v {
|
||||
t.Errorf("expected true, got %v", b)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != false {
|
||||
t.Errorf("expected false, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
|
||||
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
|
||||
"float64s": {[]float32{42}, []float64{42}},
|
||||
"strings": {[]string{"42"}, []string{"hello"}},
|
||||
"bools": {[]bool{true}, []bool{false}},
|
||||
}
|
||||
|
||||
t.Run("int64s", func(t *testing.T) {
|
||||
matched, unmatched := split("int64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64s := kv.Ints(); i64s != nil {
|
||||
t.Errorf("expected nil, got %v", i64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64s", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64s := kv.Uints(); u64s != nil {
|
||||
t.Errorf("expected nil, got %v", u64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64s", func(t *testing.T) {
|
||||
matched, unmatched := split("float64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64s := kv.Floats(); f64s != nil {
|
||||
t.Errorf("expected nil, got %v", f64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strings", func(t *testing.T) {
|
||||
matched, unmatched := split("strings", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.Strings(); s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bools", func(t *testing.T) {
|
||||
matched, unmatched := split("bools", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bools(); b != nil {
|
||||
t.Errorf("expected nil, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type lazy[T any] struct {
|
||||
count uint64
|
||||
next func() (T, bool)
|
||||
stop func()
|
||||
values []T
|
||||
|
||||
doneFunc func() error
|
||||
}
|
||||
|
||||
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
|
||||
it := lazy[T]{}
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
it.values = make([]T, 0)
|
||||
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
|
||||
for i := range it.count {
|
||||
t, err := fn()
|
||||
if err != nil {
|
||||
slog.Error("error reading tensor", "index", i, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
it.values = append(it.values, t)
|
||||
if !yield(t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if it.doneFunc != nil {
|
||||
it.doneFunc()
|
||||
}
|
||||
})
|
||||
|
||||
return &it, nil
|
||||
}
|
||||
|
||||
func (g *lazy[T]) Values() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
for _, v := range g.All() {
|
||||
if !yield(v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) All() iter.Seq2[int, T] {
|
||||
return func(yield func(int, T) bool) {
|
||||
for i := range int(g.count) {
|
||||
if i < len(g.values) {
|
||||
if !yield(i, g.values[i]) {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
t, ok := g.next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if !yield(i, t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) rest() (collected bool) {
|
||||
for {
|
||||
_, ok := g.next()
|
||||
collected = collected || ok
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return collected
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type readSeeker struct {
|
||||
rs io.ReadSeeker
|
||||
br *bufio.Reader
|
||||
}
|
||||
|
||||
func newReadSeeker(rs io.ReadSeeker, size int) *readSeeker {
|
||||
return &readSeeker{
|
||||
rs: rs,
|
||||
br: bufio.NewReaderSize(rs, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *readSeeker) Read(p []byte) (int, error) {
|
||||
return b.br.Read(p)
|
||||
}
|
||||
|
||||
func (b *readSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||
if whence == io.SeekCurrent {
|
||||
offset -= int64(b.br.Buffered())
|
||||
}
|
||||
n, err := b.rs.Seek(offset, whence)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b.br.Reset(b.rs)
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Offset uint64
|
||||
Shape []uint64
|
||||
Type TensorType
|
||||
}
|
||||
|
||||
func (t TensorInfo) NumValues() int64 {
|
||||
var numItems int64 = 1
|
||||
for _, dim := range t.Shape {
|
||||
numItems *= int64(dim)
|
||||
}
|
||||
return numItems
|
||||
}
|
||||
|
||||
// NumBytes returns the number of bytes in the tensor.
|
||||
func (t TensorInfo) NumBytes() int64 {
|
||||
return int64(float64(t.NumValues()) * t.Type.NumBytes())
|
||||
}
|
||||
|
||||
func (t TensorInfo) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", t.Name),
|
||||
slog.Int64("offset", int64(t.Offset)),
|
||||
slog.Any("shape", t.Shape),
|
||||
slog.Int64("num_values", t.NumValues()),
|
||||
slog.Int64("num_bytes", t.NumBytes()),
|
||||
slog.Any("type", t.Type),
|
||||
)
|
||||
}
|
||||
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3
|
||||
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ2_XXS
|
||||
tensorTypeIQ2_XS
|
||||
tensorTypeIQ3_XXS
|
||||
tensorTypeIQ1_S
|
||||
tensorTypeIQ4_NL
|
||||
tensorTypeIQ3_S
|
||||
tensorTypeIQ2_S
|
||||
tensorTypeIQ4_XS
|
||||
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ1_M
|
||||
|
||||
TensorTypeBF16
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_0_4_4
|
||||
tensorTypeQ4_0_4_8
|
||||
tensorTypeQ4_0_8_8
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeTQ1_0
|
||||
tensorTypeTQ2_0
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeIQ4_NL_4_4
|
||||
tensorTypeIQ4_NL_4_8
|
||||
tensorTypeIQ4_NL_8_8
|
||||
)
|
||||
|
||||
func (t TensorType) NumBytes() float64 {
|
||||
return float64(t.typeSize()) / float64(t.blockSize())
|
||||
}
|
||||
|
||||
func (t TensorType) typeSize() int64 {
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return 4
|
||||
case TensorTypeF16:
|
||||
return 2
|
||||
case TensorTypeQ4_0:
|
||||
return 2 + t.blockSize()/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + t.blockSize()/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + t.blockSize()/2
|
||||
case TensorTypeQ5_1:
|
||||
return 2 + 2 + 4 + t.blockSize()/2
|
||||
case TensorTypeQ8_0:
|
||||
return 2 + t.blockSize()
|
||||
case TensorTypeQ8_1:
|
||||
return 2 + 2 + t.blockSize()
|
||||
case TensorTypeQ2_K:
|
||||
return t.blockSize()/16 + t.blockSize()/4 + 2 + 2
|
||||
case TensorTypeQ3_K:
|
||||
return t.blockSize()/8 + t.blockSize()/4 + 12 + 2
|
||||
case TensorTypeQ4_K:
|
||||
return 2 + 2 + 12 + t.blockSize()/2
|
||||
case TensorTypeQ5_K:
|
||||
return 2 + 2 + 12 + t.blockSize()/8 + t.blockSize()/2
|
||||
case TensorTypeQ6_K:
|
||||
return t.blockSize()/2 + t.blockSize()/4 + t.blockSize()/16 + 2
|
||||
case TensorTypeQ8_K:
|
||||
return 4 + t.blockSize() + 2*t.blockSize()/16
|
||||
case tensorTypeIQ2_XXS:
|
||||
return 2 + 2*t.blockSize()/8
|
||||
case tensorTypeIQ2_XS:
|
||||
return 2 + 2*t.blockSize()/8 + t.blockSize()/32
|
||||
case tensorTypeIQ3_XXS:
|
||||
return 2 + t.blockSize()/4 + t.blockSize()/8
|
||||
case tensorTypeIQ1_S:
|
||||
return 2 + t.blockSize()/8 + t.blockSize()/16
|
||||
case tensorTypeIQ4_NL:
|
||||
return 2 + t.blockSize()/2
|
||||
case tensorTypeIQ3_S:
|
||||
return 2 + t.blockSize()/4 + t.blockSize()/8 + t.blockSize()/32 + 4
|
||||
case tensorTypeIQ2_S:
|
||||
return 2 + t.blockSize()/4 + t.blockSize()/16
|
||||
case tensorTypeIQ4_XS:
|
||||
return 2 + 2 + t.blockSize()/2 + t.blockSize()/64
|
||||
case TensorTypeI8:
|
||||
return 1
|
||||
case TensorTypeI16:
|
||||
return 2
|
||||
case TensorTypeI32:
|
||||
return 4
|
||||
case TensorTypeI64:
|
||||
return 8
|
||||
case TensorTypeF64:
|
||||
return 8
|
||||
case tensorTypeIQ1_M:
|
||||
return t.blockSize()/8 + t.blockSize()/16 + t.blockSize()/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) blockSize() int64 {
|
||||
switch t {
|
||||
case TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) String() string {
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return "f32"
|
||||
case TensorTypeF16:
|
||||
return "f16"
|
||||
case TensorTypeQ4_0:
|
||||
return "q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "q4_1"
|
||||
case tensorTypeQ4_2:
|
||||
return "q4_2"
|
||||
case tensorTypeQ4_3:
|
||||
return "q4_3"
|
||||
case TensorTypeQ5_0:
|
||||
return "q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "q2_k"
|
||||
case TensorTypeQ3_K:
|
||||
return "q3_k"
|
||||
case TensorTypeQ4_K:
|
||||
return "q4_k"
|
||||
case TensorTypeQ5_K:
|
||||
return "q5_k"
|
||||
case TensorTypeQ6_K:
|
||||
return "q6_k"
|
||||
case TensorTypeQ8_K:
|
||||
return "q8_k"
|
||||
case tensorTypeIQ2_XXS:
|
||||
return "iq2_xxs"
|
||||
case tensorTypeIQ2_XS:
|
||||
return "iq2_xs"
|
||||
case tensorTypeIQ3_XXS:
|
||||
return "iq3_xxs"
|
||||
case tensorTypeIQ1_S:
|
||||
return "iq1_s"
|
||||
case tensorTypeIQ4_NL:
|
||||
return "iq4_nl"
|
||||
case tensorTypeIQ3_S:
|
||||
return "iq3_s"
|
||||
case tensorTypeIQ2_S:
|
||||
return "iq2_s"
|
||||
case tensorTypeIQ4_XS:
|
||||
return "iq4_xs"
|
||||
case TensorTypeI8:
|
||||
return "i8"
|
||||
case TensorTypeI16:
|
||||
return "i16"
|
||||
case TensorTypeI32:
|
||||
return "i32"
|
||||
case TensorTypeI64:
|
||||
return "i64"
|
||||
case TensorTypeF64:
|
||||
return "f64"
|
||||
case tensorTypeIQ1_M:
|
||||
return "iq1_m"
|
||||
case TensorTypeBF16:
|
||||
return "bf16"
|
||||
case tensorTypeQ4_0_4_4:
|
||||
return "q4_0_4_4"
|
||||
case tensorTypeQ4_0_4_8:
|
||||
return "q4_0_4_8"
|
||||
case tensorTypeQ4_0_8_8:
|
||||
return "q4_0_8_8"
|
||||
case tensorTypeTQ1_0:
|
||||
return "tq1_0"
|
||||
case tensorTypeTQ2_0:
|
||||
return "tq2_0"
|
||||
case tensorTypeIQ4_NL_4_4:
|
||||
return "iq4_nl_4_4"
|
||||
case tensorTypeIQ4_NL_4_8:
|
||||
return "iq4_nl_4_8"
|
||||
case tensorTypeIQ4_NL_8_8:
|
||||
return "iq4_nl_8_8"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Uint64("value", uint64(t)),
|
||||
slog.String("name", strings.ToUpper(t.String())),
|
||||
slog.Int64("size", t.typeSize()),
|
||||
slog.Int64("block_size", t.blockSize()),
|
||||
slog.Float64("num_bytes", t.NumBytes()),
|
||||
)
|
||||
}
|
||||
12
go.mod
12
go.mod
@@ -11,7 +11,7 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sync v0.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -70,12 +70,12 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
golang.org/x/net v0.35.0 // indirect
|
||||
golang.org/x/sys v0.30.0
|
||||
golang.org/x/term v0.29.0
|
||||
golang.org/x/text v0.22.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -34,15 +34,13 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
|
||||
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
res, err := embeddingTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@@ -64,15 +62,13 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
func TestAllMiniLMEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@@ -102,15 +98,13 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@@ -150,8 +144,6 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
@@ -190,7 +182,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
||||
response, err := embedTestHelper(ctx, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@@ -206,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
@@ -218,7 +210,9 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
@@ -232,7 +226,9 @@ func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T,
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) {
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
model: "qwen2.5vl",
|
||||
model: "llava:7b",
|
||||
},
|
||||
{
|
||||
model: "llama3.2-vision",
|
||||
@@ -60,7 +60,6 @@ func TestVisionModels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
|
||||
@@ -48,6 +48,17 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
return 8 * time.Minute, 10 * time.Minute
|
||||
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
|
||||
t.Skip("too little time")
|
||||
return time.Duration(0), time.Duration(0)
|
||||
}
|
||||
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
|
||||
}
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
//go:build integration && models
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQuantization(t *testing.T) {
|
||||
sourceModels := []string{
|
||||
"qwen2.5:0.5b-instruct-fp16",
|
||||
}
|
||||
quantizations := []string{
|
||||
"Q8_0",
|
||||
"Q4_K_S",
|
||||
"Q4_K_M",
|
||||
"Q4_K",
|
||||
}
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
started := time.Now()
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, base := range sourceModels {
|
||||
if err := PullIfMissing(ctx, client, base); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
for _, quant := range quantizations {
|
||||
newName := fmt.Sprintf("%s__%s", base, quant)
|
||||
t.Run(newName, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
req := &api.CreateRequest{
|
||||
Model: newName,
|
||||
Quantization: quant,
|
||||
From: base,
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// fmt.Print(".")
|
||||
return nil
|
||||
}
|
||||
t.Logf("quantizing: %s -> %s", base, quant)
|
||||
if err := client.Create(ctx, req, fn); err != nil {
|
||||
t.Fatalf("create failed %s", err)
|
||||
}
|
||||
defer func() {
|
||||
req := &api.DeleteRequest{
|
||||
Model: newName,
|
||||
}
|
||||
t.Logf("deleting: %s -> %s", base, quant)
|
||||
if err := client.Delete(ctx, req); err != nil {
|
||||
t.Logf("failed to clean up %s: %s", req.Model, err)
|
||||
}
|
||||
}()
|
||||
// Check metadata on the model
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to show model: %s", err)
|
||||
}
|
||||
if !strings.Contains(resp.Details.QuantizationLevel, quant) {
|
||||
t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
|
||||
}
|
||||
|
||||
stream := true
|
||||
genReq := api.GenerateRequest{
|
||||
Model: newName,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Stream: &stream,
|
||||
}
|
||||
t.Logf("verifying: %s -> %s", base, quant)
|
||||
|
||||
// Some smaller quantizations can cause models to have poor quality
|
||||
// or get stuck in repetition loops, so we stop as soon as we have any matches
|
||||
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
atLeastOne := false
|
||||
var buf bytes.Buffer
|
||||
genfn := func(response api.GenerateResponse) error {
|
||||
buf.Write([]byte(response.Response))
|
||||
fullResp := strings.ToLower(buf.String())
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(fullResp, resp) {
|
||||
atLeastOne = true
|
||||
t.Log(fullResp)
|
||||
reqCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Generate(reqCtx, &genReq, genfn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if genErr != nil && !atLeastOne {
|
||||
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
|
||||
t.Logf("passed")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
1
integration/testdata/embed.json
vendored
1
integration/testdata/embed.json
vendored
File diff suppressed because one or more lines are too long
@@ -217,7 +217,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
defer fp.Close()
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
@@ -359,14 +358,3 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
return 8 * time.Minute, 10 * time.Minute
|
||||
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
|
||||
t.Skip("too little time")
|
||||
return time.Duration(0), time.Duration(0)
|
||||
}
|
||||
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
|
||||
}
|
||||
|
||||
@@ -30,11 +30,6 @@ type Causal struct {
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
@@ -164,13 +159,12 @@ func (c *Causal) Close() {
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.curReserve = reserve
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
if !c.curReserve {
|
||||
if !reserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
@@ -217,9 +211,10 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curCellRange.max = len(c.cells) - 1
|
||||
}
|
||||
|
||||
c.curMask = c.buildMask(ctx)
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func newRange() cellRange {
|
||||
@@ -244,7 +239,7 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
@@ -302,7 +297,7 @@ func roundUp(length, pad int) int {
|
||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
@@ -310,11 +305,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
if c.curReserve {
|
||||
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
|
||||
}
|
||||
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
@@ -335,7 +325,10 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
@@ -343,7 +336,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
maskTensor = out
|
||||
}
|
||||
|
||||
return maskTensor
|
||||
return maskTensor, nil
|
||||
}
|
||||
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||
@@ -498,7 +491,12 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
if !slices.Equal(c.opts.Except, opts.Except) {
|
||||
c.opts = opts
|
||||
if ctx != nil {
|
||||
c.curMask = c.buildMask(ctx)
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
if err != nil {
|
||||
// This error should never occur because we have previously built a mask with the same shape
|
||||
panic(fmt.Errorf("SetCausal: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -654,7 +652,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
}
|
||||
}
|
||||
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
|
||||
@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice(test.in, test.inShape...)
|
||||
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
out, _, mask := cache.Get(context)
|
||||
@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return c.Empty(dtype, shape...)
|
||||
}
|
||||
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
copy(t.data, s)
|
||||
|
||||
return t
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
|
||||
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
f := make([]float32, len(s))
|
||||
for i := range f {
|
||||
f[i] = float32(s[i])
|
||||
}
|
||||
|
||||
out := c.FromFloatSlice(f, shape...)
|
||||
out, _ := c.FromFloatSlice(f, shape...)
|
||||
out.(*testTensor).dtype = ml.DTypeI32
|
||||
|
||||
return out
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
|
||||
s = append(s, i)
|
||||
}
|
||||
|
||||
out := c.FromFloatSlice(s, len(s))
|
||||
out, _ := c.FromFloatSlice(s, len(s))
|
||||
out.(*testTensor).dtype = dtype
|
||||
return out
|
||||
}
|
||||
@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
func (c *testContext) Reserve() {}
|
||||
func (c *testContext) Reserve() error { return nil }
|
||||
|
||||
func (c *testContext) MaxGraphNodes() int {
|
||||
return 10
|
||||
|
||||
2
llama/build-info.cpp
generated
vendored
2
llama/build-info.cpp
generated
vendored
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
|
||||
char const *LLAMA_COMMIT = "2016f07bd106c73699ecbaace80f55db5ed95dac";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
@@ -10,11 +10,11 @@ include common/stb_image.*
|
||||
include include/
|
||||
include include/llama.*
|
||||
include include/llama-*.*
|
||||
include tools/
|
||||
include tools/mtmd/
|
||||
include tools/mtmd/clip.*
|
||||
include tools/mtmd/clip-impl.*
|
||||
include tools/mtmd/llava.*
|
||||
include examples/
|
||||
include examples/llava/
|
||||
include examples/llava/clip.*
|
||||
include examples/llava/clip-impl.*
|
||||
include examples/llava/llava.*
|
||||
include src/
|
||||
include src/llama.*
|
||||
include src/llama-*.*
|
||||
|
||||
19
llama/llama.cpp/common/common.cpp
vendored
19
llama/llama.cpp/common/common.cpp
vendored
@@ -1096,6 +1096,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.n_threads = params.cpuparams.n_threads;
|
||||
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
|
||||
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
|
||||
cparams.logits_all = params.logits_all;
|
||||
cparams.embeddings = params.embedding;
|
||||
cparams.rope_scaling_type = params.rope_scaling_type;
|
||||
cparams.rope_freq_base = params.rope_freq_base;
|
||||
@@ -1113,7 +1114,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.offload_kqv = !params.no_kv_offload;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
|
||||
if (params.reranking) {
|
||||
cparams.embeddings = true;
|
||||
@@ -1565,20 +1565,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
|
||||
const int64_t ne_datapoint = llama_n_ctx(ctx);
|
||||
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
|
||||
ggml_opt_dataset_t result = ggml_opt_dataset_init(
|
||||
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
|
||||
|
||||
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
|
||||
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
|
||||
|
||||
for (int64_t idata = 0; idata < ndata; ++idata) {
|
||||
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
|
||||
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
20
llama/llama.cpp/common/common.h
vendored
20
llama/llama.cpp/common/common.h
vendored
@@ -66,6 +66,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_COMMON,
|
||||
LLAMA_EXAMPLE_SPECULATIVE,
|
||||
LLAMA_EXAMPLE_MAIN,
|
||||
LLAMA_EXAMPLE_INFILL,
|
||||
LLAMA_EXAMPLE_EMBEDDING,
|
||||
LLAMA_EXAMPLE_PERPLEXITY,
|
||||
LLAMA_EXAMPLE_RETRIEVAL,
|
||||
@@ -95,7 +96,6 @@ enum common_sampler_type {
|
||||
COMMON_SAMPLER_TYPE_XTC = 8,
|
||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
|
||||
};
|
||||
|
||||
// dimensionality reduction methods, used by cvector-generator
|
||||
@@ -161,7 +161,6 @@ struct common_params_sampling {
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||
COMMON_SAMPLER_TYPE_DRY,
|
||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
||||
COMMON_SAMPLER_TYPE_TOP_P,
|
||||
@@ -324,6 +323,7 @@ struct common_params {
|
||||
bool ctx_shift = true; // context shift on inifinite text generation
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool logits_all = false; // return logits for all tokens in the batch
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
@@ -332,7 +332,6 @@ struct common_params {
|
||||
bool no_kv_offload = false; // disable KV offloading
|
||||
bool warmup = true; // warmup run
|
||||
bool check_tensors = false; // validate tensor data
|
||||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||
|
||||
bool single_turn = false; // single turn chat conversation
|
||||
|
||||
@@ -341,10 +340,8 @@ struct common_params {
|
||||
|
||||
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
|
||||
|
||||
// multimodal models (see tools/mtmd)
|
||||
// multimodal models (see examples/llava)
|
||||
struct common_params_model mmproj;
|
||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||
bool no_mmproj = false; // explicitly disable multimodal model
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
|
||||
// embedding
|
||||
@@ -410,14 +407,13 @@ struct common_params {
|
||||
|
||||
bool process_output = false; // collect data for the output tensor
|
||||
bool compute_ppl = true; // whether to compute perplexity
|
||||
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
||||
|
||||
// cvector-generator params
|
||||
int n_pca_batch = 100;
|
||||
int n_pca_iterations = 1000;
|
||||
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
|
||||
std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
|
||||
std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
|
||||
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
|
||||
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
|
||||
|
||||
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
||||
|
||||
@@ -666,9 +662,3 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
|
||||
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// training utils
|
||||
//
|
||||
|
||||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
||||
|
||||
@@ -16,9 +16,6 @@ using json = nlohmann::ordered_json;
|
||||
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
|
||||
auto has_max = max_items != std::numeric_limits<int>::max();
|
||||
|
||||
if (max_items == 0) {
|
||||
return "";
|
||||
}
|
||||
if (min_items == 0 && max_items == 1) {
|
||||
return item_rule + "?";
|
||||
}
|
||||
|
||||
107
llama/llama.cpp/common/sampling.cpp
vendored
107
llama/llama.cpp/common/sampling.cpp
vendored
@@ -1,7 +1,6 @@
|
||||
#include "sampling.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
@@ -230,48 +229,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
params.logit_bias.data()));
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
if (params.top_n_sigma >= 0) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
} else {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
@@ -473,7 +475,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
||||
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
@@ -489,7 +490,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
||||
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
@@ -504,7 +504,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
||||
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
@@ -518,7 +517,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
||||
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
@@ -535,16 +533,14 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
auto sampler = sampler_canonical_name_map.find(name);
|
||||
if (sampler != sampler_canonical_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
continue;
|
||||
}
|
||||
if (allow_alt_names) {
|
||||
sampler = sampler_alt_name_map.find(name);
|
||||
if (sampler != sampler_alt_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
continue;
|
||||
} else {
|
||||
if (allow_alt_names) {
|
||||
sampler = sampler_alt_name_map.find(name);
|
||||
if (sampler != sampler_alt_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
|
||||
}
|
||||
|
||||
return samplers;
|
||||
@@ -556,7 +552,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||
@@ -571,8 +566,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
||||
const auto sampler = sampler_name_map.find(c);
|
||||
if (sampler != sampler_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
} else {
|
||||
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "gguf.h"
|
||||
#include "clip.h"
|
||||
|
||||
#include "clip.h"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdarg>
|
||||
#include <string>
|
||||
@@ -15,29 +17,33 @@
|
||||
#define KEY_FTYPE "general.file_type"
|
||||
#define KEY_NAME "general.name"
|
||||
#define KEY_DESCRIPTION "general.description"
|
||||
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
|
||||
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
|
||||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
|
||||
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_USE_SILU "clip.use_silu"
|
||||
#define KEY_N_EMBD "clip.vision.embedding_length"
|
||||
#define KEY_N_FF "clip.vision.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.vision.block_count"
|
||||
#define KEY_N_HEAD "clip.vision.attention.head_count"
|
||||
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
|
||||
#define KEY_PROJ_DIM "clip.vision.projection_dim"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||
#define KEY_TOKENS "tokenizer.ggml.tokens"
|
||||
#define KEY_N_POSITIONS "clip.text.context_length"
|
||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
||||
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
||||
|
||||
|
||||
//
|
||||
@@ -53,16 +59,10 @@
|
||||
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
|
||||
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
|
||||
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
|
||||
#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s"
|
||||
#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s"
|
||||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
|
||||
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
|
||||
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
#define TN_LN_POST "%s.post_ln.%s"
|
||||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
@@ -70,14 +70,8 @@
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
|
||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
#define TN_MM_INP_NORM "mm.input_norm.weight"
|
||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
|
||||
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
|
||||
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
|
||||
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
|
||||
#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
|
||||
|
||||
// mimicpmv
|
||||
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
|
||||
@@ -93,23 +87,18 @@
|
||||
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
|
||||
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
|
||||
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
|
||||
|
||||
// align x to upper multiple of n
|
||||
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||
#define TN_GLM_BOI_W "adapter.boi"
|
||||
#define TN_GLM_EOI_W "adapter.eoi"
|
||||
|
||||
enum projector_type {
|
||||
PROJECTOR_TYPE_MLP,
|
||||
PROJECTOR_TYPE_MLP_NORM,
|
||||
PROJECTOR_TYPE_LDP,
|
||||
PROJECTOR_TYPE_LDPV2,
|
||||
PROJECTOR_TYPE_MINICPMV,
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_GLM_EDGE,
|
||||
PROJECTOR_TYPE_QWEN2VL,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_PIXTRAL,
|
||||
PROJECTOR_TYPE_QWEN25VL,
|
||||
PROJECTOR_TYPE_INTERNVL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -117,14 +106,10 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_MLP, "mlp" },
|
||||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
|
||||
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
@@ -239,15 +224,6 @@ struct clip_image_u8_batch {
|
||||
|
||||
struct clip_image_f32_batch {
|
||||
std::vector<clip_image_f32_ptr> entries;
|
||||
|
||||
clip_image_f32_batch clone() const {
|
||||
clip_image_f32_batch new_batch;
|
||||
new_batch.entries.reserve(entries.size());
|
||||
for (const auto & entry : entries) {
|
||||
new_batch.entries.emplace_back(new clip_image_f32(*entry));
|
||||
}
|
||||
return new_batch;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
2927
llama/llama.cpp/examples/llava/clip.cpp
vendored
Normal file
2927
llama/llama.cpp/examples/llava/clip.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -47,7 +47,7 @@ CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_par
|
||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
|
||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
|
||||
|
||||
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
|
||||
@@ -59,29 +59,18 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
|
||||
|
||||
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
|
||||
"use clip_n_output_tokens instead");
|
||||
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
|
||||
"use clip_n_output_tokens instead");
|
||||
|
||||
CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
|
||||
// for M-RoPE, this will be the number of token positions in X and Y directions
|
||||
// for other models, X will be the total number of tokens and Y will be 1
|
||||
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
|
||||
// this should be equal to the embedding dimension of the text model
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
||||
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
||||
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
||||
|
||||
CLIP_API struct clip_image_size * clip_image_size_init(void);
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init (void);
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_init(void);
|
||||
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
|
||||
CLIP_API struct clip_image_size * clip_image_size_init();
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_init();
|
||||
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
|
||||
|
||||
// nx, ny are the output image dimensions
|
||||
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
|
||||
@@ -125,6 +114,8 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
#include "llava.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cerrno>
|
||||
@@ -113,7 +112,7 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<
|
||||
}
|
||||
|
||||
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
|
||||
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) {
|
||||
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
|
||||
struct {
|
||||
struct ggml_context * ctx;
|
||||
} model;
|
||||
@@ -176,7 +175,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
|
||||
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4
|
||||
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
|
||||
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
|
||||
// fill it with the image embeddings, ignoring the base
|
||||
for (size_t i = 1; i < num_images; i++) {
|
||||
@@ -210,17 +209,13 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
||||
struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
|
||||
// ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
|
||||
ggml_build_forward_expand(gf, flatten);
|
||||
|
||||
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
||||
GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend");
|
||||
ggml_backend_graph_compute(backend.get(), gf);
|
||||
|
||||
ggml_graph_compute_with_ctx(model.ctx, gf, 1);
|
||||
struct ggml_tensor* result = ggml_graph_node(gf, -1);
|
||||
|
||||
memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
|
||||
// append without newline tokens (default behavior in llava_arch when not using unpad ):
|
||||
memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
|
||||
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input));
|
||||
memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
|
||||
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
|
||||
|
||||
// Debug: Test single segments
|
||||
// Current findings: sending base image, sending a segment embedding all works similar to python
|
||||
@@ -318,7 +313,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
|
||||
image_embd_v[i],
|
||||
clip_embd_nbytes_by_img(ctx_clip, nx, ny));
|
||||
n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res);
|
||||
n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
|
||||
}
|
||||
*n_img_pos = n_img_pos_out;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
@@ -347,8 +342,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
}
|
||||
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
|
||||
// flat / default llava-1.5 type embedding
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
|
||||
*n_img_pos = clip_n_output_tokens(ctx_clip, img_res);
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
|
||||
if (!encoded) {
|
||||
LOG_ERR("Unable to encode image\n");
|
||||
@@ -386,8 +381,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
|
||||
|
||||
int n_img_pos_out;
|
||||
clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0);
|
||||
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input);
|
||||
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
|
||||
*n_img_pos = n_img_pos_out;
|
||||
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
@@ -462,7 +456,7 @@ struct llava_embd_batch {
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
@@ -474,6 +468,7 @@ struct llava_embd_batch {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*n_embd =*/ n_embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
@@ -497,7 +492,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
||||
n_eval = n_batch;
|
||||
}
|
||||
float * embd = image_embed->embed+i*n_embd;
|
||||
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
|
||||
llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
|
||||
if (llama_decode(ctx_llama, llava_batch.batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
@@ -1,4 +1,4 @@
|
||||
package mtmd
|
||||
package llava
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++11
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../common
|
||||
69
llama/llama.cpp/include/llama.h
vendored
69
llama/llama.cpp/include/llama.h
vendored
@@ -4,7 +4,6 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@@ -112,8 +111,6 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@@ -258,6 +255,7 @@ extern "C" {
|
||||
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
int32_t n_embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
@@ -353,18 +351,20 @@ extern "C" {
|
||||
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
|
||||
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
// TODO: move at the end of the struct
|
||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool cross_attn; // whether to use cross attention
|
||||
|
||||
// Abort callback
|
||||
// if it returns true, execution of llama_decode() will be aborted
|
||||
// currently works only with CPU execution
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool op_offload; // whether to offload host tensor operations to device
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@@ -446,10 +446,6 @@ extern "C" {
|
||||
size_t n_paths,
|
||||
struct llama_model_params params);
|
||||
|
||||
LLAMA_API void llama_model_save_to_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_model);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
|
||||
"use llama_model_free instead");
|
||||
|
||||
@@ -464,6 +460,10 @@ extern "C" {
|
||||
struct llama_context_params params),
|
||||
"use llama_init_from_model instead");
|
||||
|
||||
// TODO (jmorganca): this should most likely be passed in as part of a batch
|
||||
// and not set on the context for all batches.
|
||||
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
@@ -929,19 +929,14 @@ extern "C" {
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// In contrast to llama_decode() - this call does not use KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
@@ -1242,7 +1237,6 @@ extern "C" {
|
||||
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
@@ -1438,37 +1432,6 @@ extern "C" {
|
||||
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
|
||||
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
// function that returns whether or not a given tensor contains trainable parameters
|
||||
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// always returns true
|
||||
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
struct llama_opt_params {
|
||||
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
|
||||
|
||||
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
|
||||
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
};
|
||||
|
||||
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
LLAMA_API void llama_opt_epoch(
|
||||
struct llama_context * lctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
6
llama/llama.cpp/src/llama-adapter.cpp
vendored
6
llama/llama.cpp/src/llama-adapter.cpp
vendored
@@ -253,9 +253,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
std::vector<ggml_backend_buffer_type_t> buft_extra;
|
||||
{
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
@@ -294,9 +291,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
|
||||
break;
|
||||
|
||||
81
llama/llama.cpp/src/llama-arch.cpp
vendored
81
llama/llama.cpp/src/llama-arch.cpp
vendored
@@ -6,6 +6,7 @@
|
||||
|
||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_LLAMA, "llama" },
|
||||
{ LLM_ARCH_MLLAMA, "mllama" },
|
||||
{ LLM_ARCH_LLAMA4, "llama4" },
|
||||
{ LLM_ARCH_DECI, "deci" },
|
||||
{ LLM_ARCH_FALCON, "falcon" },
|
||||
@@ -19,7 +20,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_REFACT, "refact" },
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
||||
{ LLM_ARCH_BLOOM, "bloom" },
|
||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||
@@ -73,6 +73,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -108,7 +109,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
||||
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
|
||||
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
|
||||
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
|
||||
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
||||
@@ -144,6 +144,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
|
||||
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
|
||||
@@ -273,6 +274,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MLLAMA,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
||||
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DECI,
|
||||
{
|
||||
@@ -476,24 +511,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
{
|
||||
@@ -1570,6 +1587,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MISTRAL3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@@ -1701,6 +1734,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
// this tensor is loaded for T5, but never used
|
||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
|
||||
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
||||
13
llama/llama.cpp/src/llama-arch.h
vendored
13
llama/llama.cpp/src/llama-arch.h
vendored
@@ -11,6 +11,7 @@
|
||||
enum llm_arch {
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_LLAMA4,
|
||||
LLM_ARCH_MLLAMA,
|
||||
LLM_ARCH_DECI,
|
||||
LLM_ARCH_FALCON,
|
||||
LLM_ARCH_BAICHUAN,
|
||||
@@ -23,7 +24,6 @@ enum llm_arch {
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
@@ -75,6 +75,7 @@ enum llm_arch {
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_SOLAR,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
@@ -112,7 +113,6 @@ enum llm_kv {
|
||||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||
LLM_KV_EXPERT_GATING_FUNC,
|
||||
LLM_KV_MOE_EVERY_N_LAYERS,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
@@ -148,6 +148,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
|
||||
LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
|
||||
@@ -349,6 +350,14 @@ enum llm_tensor {
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
LLM_TENSOR_BSKCN_TV,
|
||||
LLM_TENSOR_CROSS_ATTN_K_NORM,
|
||||
LLM_TENSOR_CROSS_ATTN_K_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_O_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_Q_NORM,
|
||||
LLM_TENSOR_CROSS_ATTN_Q_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_V_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
|
||||
LLM_TENSOR_CROSS_ATTN_MLP_GATE,
|
||||
LLM_TENSOR_CONV1D,
|
||||
LLM_TENSOR_CONVNEXT_DW,
|
||||
LLM_TENSOR_CONVNEXT_NORM,
|
||||
|
||||
9
llama/llama.cpp/src/llama-batch.cpp
vendored
9
llama/llama.cpp/src/llama-batch.cpp
vendored
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@@ -203,7 +203,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
ids[i] = i;
|
||||
}
|
||||
|
||||
if (simple_split) {
|
||||
seq.resize(1);
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
@@ -213,7 +212,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
||||
s.length = n_tokens;
|
||||
return;
|
||||
}
|
||||
|
||||
std::sort(ids.begin(), ids.end(),
|
||||
[&batch](size_t a, size_t b) {
|
||||
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
||||
@@ -241,7 +239,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
||||
return n_seq_a > n_seq_b;
|
||||
}
|
||||
);
|
||||
|
||||
// init seq
|
||||
llama_sbatch_seq * last_seq = nullptr;
|
||||
|
||||
@@ -265,7 +262,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq.back();
|
||||
}
|
||||
|
||||
// keep shared prompts first at the end, then sort by length descending.
|
||||
std::sort(seq.begin(), seq.end(),
|
||||
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
|
||||
@@ -320,6 +316,7 @@ struct llama_batch llama_batch_get_one(
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
/*n_embd =*/ 0,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
@@ -332,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||
/*n_tokens =*/ 0,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*n_embd =*/ 0,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
@@ -340,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||
|
||||
if (embd) {
|
||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
||||
batch.n_embd = embd;
|
||||
} else {
|
||||
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||
}
|
||||
|
||||
3
llama/llama.cpp/src/llama-batch.h
vendored
3
llama/llama.cpp/src/llama-batch.h
vendored
@@ -70,8 +70,7 @@ struct llama_sbatch {
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
llama_sbatch() = default;
|
||||
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
|
||||
55
llama/llama.cpp/src/llama-chat.cpp
vendored
55
llama/llama.cpp/src/llama-chat.cpp
vendored
@@ -35,7 +35,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
|
||||
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
|
||||
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
|
||||
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
|
||||
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
|
||||
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
|
||||
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
|
||||
@@ -51,8 +50,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
|
||||
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
|
||||
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
|
||||
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
|
||||
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 },
|
||||
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
|
||||
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
|
||||
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
|
||||
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
|
||||
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
|
||||
@@ -63,7 +62,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
|
||||
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
|
||||
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -83,9 +81,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
if (tmpl_contains("<|im_start|>")) {
|
||||
return tmpl_contains("<|im_sep|>")
|
||||
? LLM_CHAT_TEMPLATE_PHI_4
|
||||
: tmpl_contains("<end_of_utterance>")
|
||||
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
|
||||
: LLM_CHAT_TEMPLATE_CHATML;
|
||||
: LLM_CHAT_TEMPLATE_CHATML;
|
||||
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
|
||||
if (tmpl_contains("[SYSTEM_PROMPT]")) {
|
||||
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
|
||||
@@ -123,12 +119,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
}
|
||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
|
||||
return LLM_CHAT_TEMPLATE_PHI_3;
|
||||
} else if (tmpl_contains("[gMASK]<sop>")) {
|
||||
return LLM_CHAT_TEMPLATE_CHATGLM_4;
|
||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
||||
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
|
||||
return LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
|
||||
return LLM_CHAT_TEMPLATE_ZEPHYR;
|
||||
} else if (tmpl_contains("bos_token + message['role']")) {
|
||||
@@ -157,7 +149,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA_3;
|
||||
} else if (tmpl_contains("[gMASK]sop")) {
|
||||
// chatglm3-6b
|
||||
return LLM_CHAT_TEMPLATE_CHATGLM_3;
|
||||
return LLM_CHAT_TEMPLATE_CHATGML_3;
|
||||
} else if (tmpl_contains("[gMASK]<sop>")) {
|
||||
return LLM_CHAT_TEMPLATE_CHATGML_4;
|
||||
} else if (tmpl_contains(LU8("<用户>"))) {
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
return LLM_CHAT_TEMPLATE_MINICPM;
|
||||
@@ -203,20 +197,19 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
|
||||
// Official mistral 'v7' template
|
||||
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
|
||||
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
|
||||
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
std::string content(message->content);
|
||||
if (role == "system") {
|
||||
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
|
||||
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
|
||||
} else if (role == "user") {
|
||||
ss << "[INST]" << trailing_space << content << "[/INST]";
|
||||
} else {
|
||||
ss << trailing_space << content << "</s>";
|
||||
ss << "[INST] " << content << "[/INST]";
|
||||
}
|
||||
else {
|
||||
ss << " " << content << "</s>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
|
||||
@@ -439,7 +432,7 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
|
||||
// chatglm3-6b
|
||||
ss << "[gMASK]" << "sop";
|
||||
for (auto message : chat) {
|
||||
@@ -449,14 +442,14 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
|
||||
ss << "[gMASK]" << "<sop>";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||
for (auto message : chat) {
|
||||
@@ -627,23 +620,7 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|header_start|>assistant<|header_end|>\n\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
|
||||
// SmolVLM
|
||||
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << message->content << "\n\n";
|
||||
} else if (role == "user") {
|
||||
ss << "User: " << message->content << "<end_of_utterance>\n";
|
||||
} else {
|
||||
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
} else {
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
}
|
||||
|
||||
6
llama/llama.cpp/src/llama-chat.h
vendored
6
llama/llama.cpp/src/llama-chat.h
vendored
@@ -14,7 +14,6 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_PHI_3,
|
||||
LLM_CHAT_TEMPLATE_PHI_4,
|
||||
LLM_CHAT_TEMPLATE_FALCON_3,
|
||||
@@ -30,8 +29,8 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
|
||||
LLM_CHAT_TEMPLATE_COMMAND_R,
|
||||
LLM_CHAT_TEMPLATE_LLAMA_3,
|
||||
LLM_CHAT_TEMPLATE_CHATGLM_3,
|
||||
LLM_CHAT_TEMPLATE_CHATGLM_4,
|
||||
LLM_CHAT_TEMPLATE_CHATGML_3,
|
||||
LLM_CHAT_TEMPLATE_CHATGML_4,
|
||||
LLM_CHAT_TEMPLATE_GLMEDGE,
|
||||
LLM_CHAT_TEMPLATE_MINICPM,
|
||||
LLM_CHAT_TEMPLATE_EXAONE_3,
|
||||
@@ -42,7 +41,6 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_YANDEX,
|
||||
LLM_CHAT_TEMPLATE_BAILING,
|
||||
LLM_CHAT_TEMPLATE_LLAMA4,
|
||||
LLM_CHAT_TEMPLATE_SMOLVLM,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
929
llama/llama.cpp/src/llama-context.cpp
vendored
929
llama/llama.cpp/src/llama-context.cpp
vendored
File diff suppressed because it is too large
Load Diff
79
llama/llama.cpp/src/llama-context.h
vendored
79
llama/llama.cpp/src/llama-context.h
vendored
@@ -8,7 +8,6 @@
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
@@ -29,12 +28,7 @@ struct llama_context {
|
||||
|
||||
void synchronize();
|
||||
|
||||
const llama_model & get_model() const;
|
||||
const llama_cparams & get_cparams() const;
|
||||
|
||||
ggml_backend_sched_t get_sched() const;
|
||||
|
||||
ggml_context * get_ctx_compute() const;
|
||||
const llama_model & get_model() const;
|
||||
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_per_seq() const;
|
||||
@@ -72,6 +66,7 @@ struct llama_context {
|
||||
void set_embeddings (bool value);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
void set_cross_attn(bool value);
|
||||
|
||||
void set_adapter_lora(
|
||||
llama_adapter_lora * adapter,
|
||||
@@ -135,32 +130,6 @@ struct llama_context {
|
||||
llama_perf_context_data perf_get_data() const;
|
||||
void perf_reset();
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
void opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
void opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start);
|
||||
|
||||
private:
|
||||
//
|
||||
// output
|
||||
@@ -170,30 +139,51 @@ private:
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
int32_t output_reserve(int32_t n_outputs);
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
// TODO: maybe remove this
|
||||
void output_reorder();
|
||||
|
||||
//
|
||||
// graph
|
||||
//
|
||||
|
||||
public:
|
||||
int32_t graph_max_nodes() const;
|
||||
|
||||
// zero-out inputs and create the ctx_compute for the compute graph
|
||||
ggml_cgraph * graph_init();
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
ggml_status graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched);
|
||||
|
||||
private:
|
||||
llm_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype);
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
ggml_status graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched);
|
||||
|
||||
llm_graph_cb graph_get_cb() const;
|
||||
|
||||
// used by kv_self_update()
|
||||
ggml_tensor * build_rope_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
ggml_backend_buffer * bbuf) const;
|
||||
|
||||
llm_graph_result_ptr build_kv_self_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_kv_self_defrag(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
const std::vector<struct llama_kv_defrag_move> & moves) const;
|
||||
|
||||
// TODO: read/write lora adapters and cvec
|
||||
size_t state_write_data(llama_io_write_i & io);
|
||||
size_t state_read_data (llama_io_read_i & io);
|
||||
@@ -210,10 +200,14 @@ private:
|
||||
llama_cparams cparams;
|
||||
llama_adapter_cvec cvec;
|
||||
llama_adapter_loras loras;
|
||||
llama_sbatch sbatch;
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_self;
|
||||
|
||||
// TODO: remove
|
||||
bool logits_all = false;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
@@ -240,9 +234,6 @@ private:
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
||||
// training
|
||||
ggml_opt_context_t opt_ctx = nullptr;
|
||||
|
||||
ggml_threadpool_t threadpool = nullptr;
|
||||
ggml_threadpool_t threadpool_batch = nullptr;
|
||||
|
||||
|
||||
2
llama/llama.cpp/src/llama-cparams.h
vendored
2
llama/llama.cpp/src/llama-cparams.h
vendored
@@ -29,8 +29,8 @@ struct llama_cparams {
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
bool no_perf;
|
||||
bool cross_attn;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
||||
141
llama/llama.cpp/src/llama-graph.cpp
vendored
141
llama/llama.cpp/src/llama-graph.cpp
vendored
@@ -55,21 +55,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->pos && pos) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
if (ubatch->token && n_pos_per_embd == 4) {
|
||||
// in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
|
||||
// the 3 first dims are the same, and 4th dim is all 0
|
||||
std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
|
||||
// copy the first dimension
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
pos_data[ i] = ubatch->pos[i];
|
||||
pos_data[ n_tokens + i] = ubatch->pos[i];
|
||||
pos_data[2 * n_tokens + i] = ubatch->pos[i];
|
||||
pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
|
||||
}
|
||||
ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
|
||||
} else {
|
||||
ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
|
||||
}
|
||||
ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +71,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
||||
) * f_attn_temp_scale + 1.0;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
|
||||
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,7 +270,24 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
data[i] = kv_self->s_copy(i);
|
||||
const uint32_t cell_id = i + kv_self->head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
||||
|
||||
// prevent out-of-bound sources
|
||||
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
|
||||
data[i] = kv_cell.src;
|
||||
|
||||
// TODO: do not mutate the KV cache
|
||||
// ensure copy only happens once
|
||||
if (kv_cell.src != (int32_t) cell_id) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -300,7 +303,18 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
// clear unused states
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[i] = kv_self->s_mask(i);
|
||||
const uint32_t cell_id = i + kv_self->head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
||||
|
||||
data[i] = (float) (kv_cell.src >= 0);
|
||||
|
||||
// only clear once
|
||||
if (kv_cell.src < 0) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -532,6 +546,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->embd) {
|
||||
ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// llm_graph_context
|
||||
//
|
||||
@@ -578,7 +598,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
res (std::make_unique<llm_graph_result>()) {
|
||||
}
|
||||
|
||||
int64_t llm_graph_context::n_pos_per_embd() const {
|
||||
int64_t llm_graph_context::n_pos_per_token() const {
|
||||
return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
|
||||
}
|
||||
|
||||
@@ -782,17 +802,13 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
} break;
|
||||
}
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
if (type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
}
|
||||
|
||||
if (down) {
|
||||
cur = build_lora_mm(down, cur);
|
||||
if (arch == LLM_ARCH_GLM4) {
|
||||
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (down_b) {
|
||||
@@ -900,35 +916,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(up, "ffn_moe_up", il);
|
||||
|
||||
ggml_tensor * experts = nullptr;
|
||||
if (gate_exps) {
|
||||
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(cur, "ffn_moe_gate", il);
|
||||
} else {
|
||||
cur = up;
|
||||
}
|
||||
ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(gate, "ffn_moe_gate", il);
|
||||
|
||||
switch (type_op) {
|
||||
case LLM_FFN_SILU:
|
||||
{
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
cb(cur, "ffn_moe_silu", il);
|
||||
gate = ggml_silu(ctx0, gate);
|
||||
cb(gate, "ffn_moe_silu", il);
|
||||
} break;
|
||||
case LLM_FFN_GELU:
|
||||
{
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cb(cur, "ffn_moe_gelu", il);
|
||||
gate = ggml_gelu(ctx0, gate);
|
||||
cb(gate, "ffn_moe_gelu", il);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
if (gate_exps) {
|
||||
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(cur, "ffn_moe_gate_par", il);
|
||||
}
|
||||
ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(par, "ffn_moe_gate_par", il);
|
||||
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
if (!weight_before_ffn) {
|
||||
@@ -971,7 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
@@ -1012,11 +1020,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
||||
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
|
||||
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
|
||||
|
||||
auto & cur = inp->pos;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
@@ -1025,12 +1033,11 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
||||
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
|
||||
|
||||
auto & cur = inp->attn_scale;
|
||||
|
||||
// this need to be 1x1xN for broadcasting
|
||||
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
|
||||
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
@@ -1078,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
|
||||
|
||||
@@ -1095,7 +1102,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
|
||||
|
||||
@@ -1228,19 +1235,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
|
||||
if (v_mla) {
|
||||
#if 0
|
||||
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
|
||||
// However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
|
||||
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
|
||||
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
||||
#else
|
||||
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
|
||||
// The permutations are noops and only change how the tensor data is interpreted.
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
|
||||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
@@ -1420,6 +1416,8 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
GGML_ASSERT(!kv_self->recurrent);
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
|
||||
GGML_ASSERT(kv_self->size == n_ctx);
|
||||
@@ -1514,6 +1512,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
||||
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_cross_attn_state>();
|
||||
|
||||
ggml_tensor * cur = nullptr;
|
||||
|
||||
inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
|
||||
ggml_set_input(inp->cross_attn_state);
|
||||
|
||||
cur = inp->cross_attn_state;
|
||||
|
||||
cb(cur, "inp_cross_attn_state", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_cross * inp,
|
||||
ggml_cgraph * gf,
|
||||
@@ -1569,7 +1586,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
|
||||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
const auto kv_head = kv_self->head;
|
||||
@@ -1601,7 +1618,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
|
||||
@@ -1622,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
||||
ggml_tensor * token_shift,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
44
llama/llama.cpp/src/llama-graph.h
vendored
44
llama/llama.cpp/src/llama-graph.h
vendored
@@ -19,7 +19,6 @@ struct llama_cparams;
|
||||
|
||||
class llama_memory_i;
|
||||
class llama_kv_cache_unified;
|
||||
class llama_kv_cache_recurrent;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
@@ -87,31 +86,34 @@ public:
|
||||
|
||||
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
||||
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
||||
ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
|
||||
};
|
||||
|
||||
class llm_graph_input_pos : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
|
||||
llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
|
||||
virtual ~llm_graph_input_pos() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
||||
|
||||
const int64_t n_pos_per_embd = 1;
|
||||
const int64_t n_pos_per_token = 1;
|
||||
};
|
||||
|
||||
// temperature tuning, used by llama4
|
||||
class llm_graph_input_attn_temp : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
|
||||
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
|
||||
llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
|
||||
: n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
|
||||
virtual ~llm_graph_input_attn_temp() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
|
||||
|
||||
const int64_t n_pos_per_token = 1;
|
||||
|
||||
const uint32_t n_attn_temp_floor_scale;
|
||||
const float f_attn_temp_scale;
|
||||
};
|
||||
@@ -187,26 +189,26 @@ public:
|
||||
|
||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_s_copy() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
const llama_kv_cache_recurrent * kv_self;
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_s_mask() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||
|
||||
const llama_kv_cache_recurrent * kv_self;
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
@@ -284,6 +286,16 @@ public:
|
||||
const llama_cross * cross = nullptr;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_attn_state : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_cross_attn_state() = default;
|
||||
virtual ~llm_graph_input_cross_attn_state() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
@@ -298,7 +310,6 @@ class llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_tokens() = 0;
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
@@ -313,7 +324,6 @@ class llm_graph_result : public llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() override { return t_tokens; }
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
@@ -330,7 +340,6 @@ public:
|
||||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
@@ -354,8 +363,8 @@ struct llm_graph_params {
|
||||
const llama_cparams & cparams;
|
||||
const llama_ubatch & ubatch;
|
||||
|
||||
ggml_backend_sched_t sched;
|
||||
ggml_backend_t backend_cpu;
|
||||
ggml_backend_sched * sched;
|
||||
ggml_backend * backend_cpu;
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
@@ -406,9 +415,9 @@ struct llm_graph_context {
|
||||
|
||||
ggml_context * ctx0 = nullptr;
|
||||
|
||||
ggml_backend_sched_t sched;
|
||||
ggml_backend_sched * sched;
|
||||
|
||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
@@ -421,7 +430,7 @@ struct llm_graph_context {
|
||||
|
||||
llm_graph_context(const llm_graph_params & params);
|
||||
|
||||
int64_t n_pos_per_embd() const;
|
||||
int64_t n_pos_per_token() const;
|
||||
|
||||
void cb(ggml_tensor * cur, const char * name, int il) const;
|
||||
|
||||
@@ -495,6 +504,7 @@ struct llm_graph_context {
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_s_copy() const;
|
||||
ggml_tensor * build_inp_s_mask() const;
|
||||
ggml_tensor * build_inp_cross_attn_state() const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
|
||||
4
llama/llama.cpp/src/llama-hparams.cpp
vendored
4
llama/llama.cpp/src/llama-hparams.cpp
vendored
@@ -85,3 +85,7 @@ bool llama_hparams::is_swa(uint32_t il) const {
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
bool llama_hparams::cross_attention_layers(uint32_t il) const {
|
||||
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
|
||||
}
|
||||
|
||||
8
llama/llama.cpp/src/llama-hparams.h
vendored
8
llama/llama.cpp/src/llama-hparams.h
vendored
@@ -2,6 +2,8 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <array>
|
||||
|
||||
// bump if necessary
|
||||
@@ -42,6 +44,7 @@ struct llama_hparams {
|
||||
uint32_t n_expert = 0;
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_rel_attn_bkts = 0;
|
||||
uint32_t n_vocab = 0;
|
||||
|
||||
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
||||
uint32_t n_embd_head_k_mla = 0;
|
||||
@@ -56,6 +59,7 @@ struct llama_hparams {
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
||||
|
||||
std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
|
||||
|
||||
uint32_t n_layer_dense_lead = 0;
|
||||
uint32_t n_lora_q = 0;
|
||||
@@ -68,7 +72,6 @@ struct llama_hparams {
|
||||
float expert_weights_scale = 0.0;
|
||||
bool expert_weights_norm = false;
|
||||
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
|
||||
uint32_t moe_every_n_layers = 0;
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
@@ -159,6 +162,9 @@ struct llama_hparams {
|
||||
// Block skip connection
|
||||
bool n_bskcn(uint32_t n, uint32_t il) const;
|
||||
|
||||
// cross attention layers
|
||||
bool cross_attention_layers(uint32_t il) const;
|
||||
|
||||
bool is_swa(uint32_t il) const;
|
||||
};
|
||||
|
||||
|
||||
1826
llama/llama.cpp/src/llama-kv-cache.cpp
vendored
1826
llama/llama.cpp/src/llama-kv-cache.cpp
vendored
File diff suppressed because it is too large
Load Diff
367
llama/llama.cpp/src/llama-kv-cache.h
vendored
367
llama/llama.cpp/src/llama-kv-cache.h
vendored
@@ -2,72 +2,32 @@
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_ubatch;
|
||||
struct llama_sbatch;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
virtual ~llama_kv_cache() = default;
|
||||
using llama_memory_i::llama_memory_i;
|
||||
|
||||
// call if batch processing fails - restores the cache state
|
||||
virtual void restore() = 0;
|
||||
virtual void restore() = 0; // call if batch processing fails - restores the cache state
|
||||
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
||||
|
||||
// call after successful batch processing - clears any pending state
|
||||
virtual void commit() = 0;
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
|
||||
// process any pending defrag/shift/etc. operations
|
||||
// optionally call once before processing a new batch
|
||||
virtual bool update(llama_context & lctx) = 0;
|
||||
|
||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
||||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual void set_full() = 0;
|
||||
|
||||
//
|
||||
// batch processing
|
||||
//
|
||||
|
||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||
|
||||
// different KV caches require different batch splitting strategies
|
||||
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
||||
|
||||
// getters
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
virtual llama_pos get_pos_max() const = 0;
|
||||
virtual bool get_can_shift() const = 0;
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
bool get_can_edit() const override { return get_can_shift(); }
|
||||
|
||||
//
|
||||
// state write/read
|
||||
//
|
||||
|
||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_guard
|
||||
//
|
||||
|
||||
struct llama_kv_cache_guard {
|
||||
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
||||
|
||||
@@ -82,7 +42,7 @@ struct llama_kv_cache_guard {
|
||||
private:
|
||||
llama_kv_cache * kv;
|
||||
};
|
||||
|
||||
|
||||
// block of KV slots to move when defragging
|
||||
struct llama_kv_defrag_move {
|
||||
uint32_t src;
|
||||
@@ -90,50 +50,65 @@ struct llama_kv_defrag_move {
|
||||
uint32_t len;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
int32_t src = -1; // used by recurrent state models to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const llama_kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
// ring-buffer of cached KV data
|
||||
// TODO: pimpl
|
||||
// TODO: add notion of max sequences
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
public:
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
// can be used to query data from the model if needed
|
||||
struct callbacks {
|
||||
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
|
||||
};
|
||||
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
const llama_hparams & hparams,
|
||||
callbacks cbs);
|
||||
|
||||
virtual ~llama_kv_cache_unified() = default;
|
||||
|
||||
// TODO: become constructor
|
||||
bool init(
|
||||
const llama_model & model, // TODO: do not reference the model
|
||||
const llama_cparams & cparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t padding);
|
||||
bool offload);
|
||||
|
||||
~llama_kv_cache_unified() = default;
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
size_t total_size() const;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos pos_max() const;
|
||||
|
||||
void clear() override;
|
||||
void defrag() override;
|
||||
|
||||
virtual void restore() override;
|
||||
virtual void commit() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
@@ -143,76 +118,25 @@ public:
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
|
||||
bool update(llama_context & ctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
// Note: On success, it's important that cache.head points
|
||||
// to the first cell of the slot.
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
bool find_slot(const llama_ubatch & batch);
|
||||
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
// TODO: maybe not needed
|
||||
uint32_t get_padding(const llama_cparams & cparams) const;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos get_pos_max() const override;
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<kv_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// required padding
|
||||
uint32_t padding = 1;
|
||||
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
// defrag
|
||||
|
||||
struct {
|
||||
std::vector<llama_kv_defrag_move> moves;
|
||||
} defrag_info;
|
||||
@@ -221,6 +145,7 @@ private:
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// commit/restore cache
|
||||
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
@@ -231,125 +156,25 @@ private:
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_defrag(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const std::vector<llama_kv_defrag_move> & moves) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
//
|
||||
|
||||
class llama_kv_cache_recurrent : public llama_kv_cache {
|
||||
public:
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
llama_kv_cache_recurrent(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool offload,
|
||||
uint32_t kv_size);
|
||||
|
||||
~llama_kv_cache_recurrent() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos get_pos_max() const override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
||||
int32_t s_copy(int i) const;
|
||||
float s_mask(int i) const;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
||||
|
||||
// members
|
||||
|
||||
const llama_hparams & hparams;
|
||||
|
||||
callbacks cbs;
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
// TODO: remove this and implement llama_kv_cache_recurrent instead
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
@@ -361,41 +186,18 @@ public:
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<kv_cell> cells;
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
|
||||
private:
|
||||
//const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
// commit/restore cache
|
||||
// TODO: rework for recurrent cache
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
@@ -403,6 +205,11 @@ private:
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
|
||||
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
|
||||
//public:
|
||||
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
||||
//};
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
|
||||
12
llama/llama.cpp/src/llama-memory.h
vendored
12
llama/llama.cpp/src/llama-memory.h
vendored
@@ -2,22 +2,12 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
struct llama_memory_params {
|
||||
// kv cache
|
||||
ggml_type type_k;
|
||||
ggml_type type_v;
|
||||
|
||||
// parameters for other types of memory
|
||||
// ...
|
||||
};
|
||||
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
class llama_memory_i {
|
||||
public:
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
virtual void clear() = 0;
|
||||
virtual void defrag() = 0;
|
||||
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user