Mirror of https://github.com/ollama/ollama.git (synced 2026-01-21 22:10:58 -05:00)

Compare commits: v0.14.2-rc...main (23 commits)
Commit SHAs (author and date columns did not survive in this view):

b5d0f72f16, 148a1be0a3, d6dd430abd, ae78112c50, 01cf7445f3, 31085d5e53,
c42e9d244f, e98b5e8b4e, 68e00c7c36, 4f138a1749, 03bf241c33, a887406c24,
d51e95ba7e, 3d01f2aa34, 634c416645, 57de86cc61, 12719b6e87, a077d996e3,
c23d5095de, 7601f0e93e, aad3f03890, 55d0b6e8b9, 38eac40d56
Dockerfile (18 lines changed)

@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
 FROM --platform=linux/arm64 almalinux:8 AS base-arm64
 # install epel-release for ccache
 RUN yum install -y yum-utils epel-release \
-    && dnf install -y clang ccache \
+    && dnf install -y clang ccache git \
     && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
 ENV CC=clang CXX=clang++

@@ -149,6 +149,7 @@ COPY CMakeLists.txt CMakePresets.json .
 COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 COPY x/ml/backend/mlx x/ml/backend/mlx
 COPY go.mod go.sum .
+COPY MLX_VERSION .
 RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
 ENV PATH=/usr/local/go/bin:$PATH
 RUN go mod download
@@ -156,14 +157,6 @@ RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
     && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
     && cmake --install build --component MLX --strip --parallel ${PARALLEL}
-COPY . .
-ARG GOFLAGS="'-ldflags=-w -s'"
-ENV CGO_ENABLED=1
-ARG CGO_CFLAGS
-ARG CGO_CXXFLAGS
-RUN mkdir -p dist/bin
-RUN --mount=type=cache,target=/root/.cache/go-build \
-    go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .

 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama
@@ -172,12 +165,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
 ENV PATH=/usr/local/go/bin:$PATH
 RUN go mod download
 COPY . .
+# Clone mlx-c headers for CGO (version from MLX_VERSION file)
+RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
 ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1
 ARG CGO_CFLAGS
+ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
 ARG CGO_CXXFLAGS
 RUN --mount=type=cache,target=/root/.cache/go-build \
-    go build -trimpath -buildmode=pie -o /bin/ollama .
+    go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .

 FROM --platform=linux/amd64 scratch AS amd64
 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -185,7 +180,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
 COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
 COPY --from=vulkan dist/lib/ollama /lib/ollama/
 COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
-COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/

 FROM --platform=linux/arm64 scratch AS arm64
 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
MLX_VERSION (new file, 1 line)

@@ -0,0 +1 @@
v0.4.1

(hunk from the MLX build docs; filename not preserved in this view)

@@ -270,10 +270,10 @@ cmake --build --preset MLX --parallel
 cmake --install build --component MLX
 ```

-Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
+When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:

 ```shell
-go build -tags mlx -o ollama-mlx .
+go build -tags mlx .
 ```

 Finally, start the server:
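The final step follows directly from the build above: run the binary it produced. A minimal sketch, assuming the repository root as working directory:

```shell
./ollama serve
```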
api/types.go (30 lines changed)

@@ -127,6 +127,20 @@ type GenerateRequest struct {
 	// each with an associated log probability. Only applies when Logprobs is true.
 	// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
 	TopLogprobs int `json:"top_logprobs,omitempty"`
+
+	// Experimental: Image generation fields (may change or be removed)
+
+	// Width is the width of the generated image in pixels.
+	// Only used for image generation models.
+	Width int32 `json:"width,omitempty"`
+
+	// Height is the height of the generated image in pixels.
+	// Only used for image generation models.
+	Height int32 `json:"height,omitempty"`
+
+	// Steps is the number of diffusion steps for image generation.
+	// Only used for image generation models.
+	Steps int32 `json:"steps,omitempty"`
 }

 // ChatRequest describes a request sent by [Client.Chat].
@@ -735,7 +749,7 @@ type ShowResponse struct {
 	Messages      []Message          `json:"messages,omitempty"`
 	RemoteModel   string             `json:"remote_model,omitempty"`
 	RemoteHost    string             `json:"remote_host,omitempty"`
-	ModelInfo     map[string]any     `json:"model_info,omitempty"`
+	ModelInfo     map[string]any     `json:"model_info"`
 	ProjectorInfo map[string]any     `json:"projector_info,omitempty"`
 	Tensors       []Tensor           `json:"tensors,omitempty"`
 	Capabilities  []model.Capability `json:"capabilities,omitempty"`
@@ -860,6 +874,20 @@ type GenerateResponse struct {
 	// Logprobs contains log probability information for the generated tokens,
 	// if requested via the Logprobs parameter.
 	Logprobs []Logprob `json:"logprobs,omitempty"`
+
+	// Experimental: Image generation fields (may change or be removed)
+
+	// Image contains a base64-encoded generated image.
+	// Only present for image generation models.
+	Image string `json:"image,omitempty"`
+
+	// Completed is the number of completed steps in image generation.
+	// Only present for image generation models during streaming.
+	Completed int64 `json:"completed,omitempty"`
+
+	// Total is the total number of steps for image generation.
+	// Only present for image generation models during streaming.
+	Total int64 `json:"total,omitempty"`
 }

 // ModelDetails provides details about a model.
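Taken together, these fields let a client drive image generation through the existing Generate API. A minimal sketch (not from the PR itself; the model name `x/z-image-turbo` is borrowed from the docs further down, and error handling is reduced to panics):

```go
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"os"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	req := &api.GenerateRequest{
		Model:  "x/z-image-turbo", // model name assumed from the docs below
		Prompt: "a sunset over mountains",
		Width:  1024,
		Height: 768,
		Steps:  20,
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		// Streaming responses carry step progress; the final one carries the image.
		if resp.Total > 0 && !resp.Done {
			fmt.Printf("step %d/%d\n", resp.Completed, resp.Total)
		}
		if resp.Done && resp.Image != "" {
			data, err := base64.StdEncoding.DecodeString(resp.Image)
			if err != nil {
				return err
			}
			return os.WriteFile("out.png", data, 0o644)
		}
		return nil
	})
	if err != nil {
		panic(err)
	}
}
```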
(hunks from the macOS app delegate, Objective-C; filename not preserved in this view)

@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
 @interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
 @property(strong, nonatomic) NSStatusItem *statusItem;
 @property(assign, nonatomic) BOOL updateAvailable;
+@property(assign, nonatomic) BOOL systemShutdownInProgress;
 @end

 @implementation AppDelegate
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
 }

 - (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
+  // Register for system shutdown/restart notification so we can allow termination
+  [[[NSWorkspace sharedWorkspace] notificationCenter]
+      addObserver:self
+         selector:@selector(systemWillPowerOff:)
+             name:NSWorkspaceWillPowerOffNotification
+           object:nil];
+
   // if we're in development mode, set the app icon
   NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
   if (![bundlePath hasSuffix:@".app"]) {
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
   [NSApp activateIgnoringOtherApps:YES];
 }

+- (void)systemWillPowerOff:(NSNotification *)notification {
+  // Set flag so applicationShouldTerminate: knows to allow termination.
+  // The system will call applicationShouldTerminate: after posting this notification.
+  self.systemShutdownInProgress = YES;
+}
+
 - (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
+  // Allow termination if the system is shutting down or restarting
+  if (self.systemShutdownInProgress) {
+    return NSTerminateNow;
+  }
+  // Otherwise just hide the app (for Cmd+Q, close button, etc.)
   [NSApp hide:nil];
   [NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
   return NSTerminateCancel;
cmd/cmd.go (105 lines changed)
@@ -46,8 +46,9 @@ import (
 	"github.com/ollama/ollama/types/syncmap"
 	"github.com/ollama/ollama/version"
 	xcmd "github.com/ollama/ollama/x/cmd"
+	"github.com/ollama/ollama/x/create"
+	xcreateclient "github.com/ollama/ollama/x/create/client"
 	"github.com/ollama/ollama/x/imagegen"
-	imagegenclient "github.com/ollama/ollama/x/imagegen/client"
 )

 const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -93,15 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()

+	// Validate model name early to fail fast
+	modelName := args[0]
+	name := model.ParseName(modelName)
+	if !name.IsValid() {
+		return fmt.Errorf("invalid model name: %s", modelName)
+	}
+
+	// Check for --experimental flag for safetensors model creation
+	experimental, _ := cmd.Flags().GetBool("experimental")
+	if experimental {
+		// Get Modelfile content - either from -f flag or default to "FROM ."
+		var reader io.Reader
+		filename, err := getModelfileName(cmd)
+		if os.IsNotExist(err) || filename == "" {
+			// No Modelfile specified or found - use default
+			reader = strings.NewReader("FROM .\n")
+		} else if err != nil {
+			return err
+		} else {
+			f, err := os.Open(filename)
+			if err != nil {
+				return err
+			}
+			defer f.Close()
+			reader = f
+		}
+
+		// Parse the Modelfile
+		modelfile, err := parser.ParseFile(reader)
+		if err != nil {
+			return fmt.Errorf("failed to parse Modelfile: %w", err)
+		}
+
+		// Extract FROM path and configuration
+		var modelDir string
+		mfConfig := &xcreateclient.ModelfileConfig{}
+
+		for _, cmd := range modelfile.Commands {
+			switch cmd.Name {
+			case "model":
+				modelDir = cmd.Args
+			case "template":
+				mfConfig.Template = cmd.Args
+			case "system":
+				mfConfig.System = cmd.Args
+			case "license":
+				mfConfig.License = cmd.Args
+			}
+		}
+
+		if modelDir == "" {
+			modelDir = "."
+		}
+
+		// Resolve relative paths based on Modelfile location
+		if !filepath.IsAbs(modelDir) && filename != "" {
+			modelDir = filepath.Join(filepath.Dir(filename), modelDir)
+		}
+
+		quantize, _ := cmd.Flags().GetString("quantize")
+		return xcreateclient.CreateModel(xcreateclient.CreateOptions{
+			ModelName: modelName,
+			ModelDir:  modelDir,
+			Quantize:  quantize,
+			Modelfile: mfConfig,
+		}, p)
+	}
+
 	var reader io.Reader

 	filename, err := getModelfileName(cmd)
 	if os.IsNotExist(err) {
 		if filename == "" {
 			// No Modelfile found - check if current directory is an image gen model
-			if imagegen.IsTensorModelDir(".") {
+			if create.IsTensorModelDir(".") {
 				quantize, _ := cmd.Flags().GetString("quantize")
-				return imagegenclient.CreateModel(args[0], ".", quantize, p)
+				return xcreateclient.CreateModel(xcreateclient.CreateOptions{
+					ModelName: modelName,
+					ModelDir:  ".",
+					Quantize:  quantize,
+				}, p)
 			}
 			reader = strings.NewReader("FROM .\n")
 		} else {
@@ -134,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}
 	spinner.Stop()

-	req.Model = args[0]
+	req.Model = modelName
 	quantize, _ := cmd.Flags().GetString("quantize")
 	if quantize != "" {
 		req.Quantize = quantize
@@ -527,7 +600,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}

 	// Check if this is an image generation model
-	if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
+	if slices.Contains(info.Capabilities, model.CapabilityImage) {
 		if opts.Prompt == "" && !interactive {
 			return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
 		}
@@ -826,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 	for _, arg := range args {
 		// Unload the model if it's running before deletion
 		if err := loadOrUnloadModel(cmd, &runOptions{
-			Model:     args[0],
+			Model:     arg,
 			KeepAlive: &api.Duration{Duration: 0},
 		}); err != nil {
 			if !strings.Contains(strings.ToLower(err.Error()), "not found") {
-				fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
+				fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
 			}
 		}

@@ -1742,15 +1815,22 @@ func NewCLI() *cobra.Command {
 	rootCmd.Flags().BoolP("version", "v", false, "Show version information")

 	createCmd := &cobra.Command{
-		Use:     "create MODEL",
-		Short:   "Create a model",
-		Args:    cobra.ExactArgs(1),
-		PreRunE: checkServerHeartbeat,
-		RunE:    CreateHandler,
+		Use:   "create MODEL",
+		Short: "Create a model",
+		Args:  cobra.ExactArgs(1),
+		PreRunE: func(cmd *cobra.Command, args []string) error {
+			// Skip server check for experimental mode (writes directly to disk)
+			if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
+				return nil
+			}
+			return checkServerHeartbeat(cmd, args)
+		},
+		RunE: CreateHandler,
 	}

 	createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
 	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
+	createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")

 	showCmd := &cobra.Command{
 		Use: "show MODEL",
@@ -1905,6 +1985,7 @@ func NewCLI() *cobra.Command {
 	} {
 		switch cmd {
 		case runCmd:
+			imagegen.AppendFlagsDocs(cmd)
 			appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
 		case serveCmd:
 			appendEnvDocs(cmd, []envconfig.EnvVar{
(hunk from a cmd package test; filename not preserved in this view)

@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
 			ParameterSize:     "10.3B",
 			QuantizationLevel: "FP8",
 		},
-		Capabilities: []model.Capability{model.CapabilityImageGeneration},
+		Capabilities: []model.Capability{model.CapabilityImage},
 		Requires:     "0.14.0",
 	}, false, &b)
 	if err != nil {
(hunk from the convert package; filename not preserved in this view)

@@ -311,6 +311,10 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
 		conv = &deepseekocr{}
 	case "DeepseekV3ForCausalLM":
 		conv = &deepseek2Model{}
+	case "Glm4MoeLiteForCausalLM":
+		conv = &glm4MoeLiteModel{}
+	case "Lfm2ForCausalLM":
+		conv = &lfm2Model{}
 	default:
 		return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
convert/convert_glm4moelite.go (new file, 150 lines)
@@ -0,0 +1,150 @@
package convert

import (
	"cmp"
	"fmt"
	"log/slog"
	"regexp"
	"strconv"

	"github.com/ollama/ollama/fs/ggml"
)

type glm4MoeLiteModel struct {
	ModelParameters
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	HiddenLayers          uint32  `json:"num_hidden_layers"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`

	RopeTheta     float32 `json:"rope_theta"`
	QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
	QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
	KVLoraRank    uint32  `json:"kv_lora_rank"`
	QLoraRank     uint32  `json:"q_lora_rank"`
	VHeadDim      uint32  `json:"v_head_dim"`

	ExpertCount            uint32  `json:"n_routed_experts"`
	ExpertSharedCount      uint32  `json:"n_shared_experts"`
	ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
	ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
	ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
	ExpertWeightsScale     float32 `json:"routed_scaling_factor"`

	LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}

func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "glm4moelite"
	kv["general.type"] = "model"
	kv["glm4moelite.block_count"] = p.HiddenLayers

	numHeads := p.NumAttentionHeads
	numKVHeads := p.NumKeyValueHeads

	kv["glm4moelite.attention.head_count"] = numHeads
	kv["glm4moelite.attention.head_count_kv"] = numKVHeads
	kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
	kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
	kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
	kv["glm4moelite.attention.value_length"] = p.VHeadDim
	kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
	kv["glm4moelite.embedding_length"] = p.HiddenSize
	kv["glm4moelite.expert_count"] = p.ExpertCount
	kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
	kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount

	kv["glm4moelite.expert_gating_func"] = uint32(2)
	kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
	kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
	kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
	kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
	kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount

	kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
	kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))

	kv["tokenizer.ggml.pre"] = "glm4"

	return kv
}

func (p *glm4MoeLiteModel) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"model.norm", "output_norm",
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
		"self_attn.kv_a_layernorm", "attn_kv_a_norm",
		"self_attn.kv_b_proj", "attn_kv_b",
		"self_attn.q_a_proj", "attn_q_a",
		"self_attn.q_a_layernorm", "attn_q_a_norm",
		"self_attn.q_b_proj", "attn_q_b",
		"self_attn.o_proj", "attn_output",
		"post_attention_layernorm", "ffn_norm",
		"mlp.shared_experts.down_proj", "ffn_down_shexp",
		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
		"mlp.shared_experts.up_proj", "ffn_up_shexp",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",
		"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mlp.gate", "ffn_gate_inp",
	}
}

func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
	merges := make([]merge, p.HiddenLayers*3)
	for i := range p.HiddenLayers {
		merges[i*3+0] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
		}
		merges[i*3+1] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
		}
		merges[i*3+2] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
		}
	}

	skipLayer := func(n string, minValue uint32) bool {
		re := regexp.MustCompile(`^blk\.(\d+)`)
		matches := re.FindStringSubmatch(n)
		if matches == nil {
			return false
		}

		blkNum, err := strconv.Atoi(matches[1])
		if err != nil {
			return false
		}

		return uint32(blkNum) >= minValue
	}

	out, s = mergeTensors(s, merges...)
	for _, t := range s {
		// skip any additional layers (such as the Multi-Token Prediction layer)
		if skipLayer(t.Name(), p.HiddenLayers) {
			slog.Debug("skipping layer", "name", t.Name())
			continue
		}
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}
	return out
}
convert/convert_lfm2.go (new file, 100 lines)
@@ -0,0 +1,100 @@
package convert

import (
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

type lfm2Model struct {
	ModelParameters
	HiddenSize            uint32   `json:"hidden_size"`
	NumHiddenLayers       uint32   `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32   `json:"max_position_embeddings"`
	IntermediateSize      uint32   `json:"intermediate_size"`
	NumAttentionHeads     uint32   `json:"num_attention_heads"`
	NumKeyValueHeads      uint32   `json:"num_key_value_heads"`
	RopeTheta             float32  `json:"rope_theta"`
	NormEps               float32  `json:"norm_eps"`
	ConvLCache            uint32   `json:"conv_L_cache"`
	LayerTypes            []string `json:"layer_types"`
	TieEmbedding          bool     `json:"tie_embedding"`
}

var _ ModelConverter = (*lfm2Model)(nil)

func (p *lfm2Model) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "lfm2"
	kv["lfm2.vocab_size"] = p.VocabSize
	kv["lfm2.block_count"] = p.NumHiddenLayers
	kv["lfm2.embedding_length"] = p.HiddenSize
	kv["lfm2.feed_forward_length"] = p.IntermediateSize
	kv["lfm2.context_length"] = p.MaxPositionEmbeddings

	// Build per-layer KV head count array based on layer_types
	// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
	kvHeadCounts := make([]uint32, p.NumHiddenLayers)
	for i := range p.NumHiddenLayers {
		if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
			kvHeadCounts[i] = p.NumKeyValueHeads
		}
	}

	kv["lfm2.attention.head_count"] = p.NumAttentionHeads
	kv["lfm2.attention.head_count_kv"] = kvHeadCounts
	kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
	kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
	kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
	kv["lfm2.rope.freq_base"] = p.RopeTheta
	kv["lfm2.shortconv.l_cache"] = p.ConvLCache

	return kv
}

func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	for _, t := range ts {
		shape := t.Shape()

		// Squeeze conv weights: [D, 1, K] -> [D, K]
		if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
			if len(shape) == 3 && shape[1] == 1 {
				shape = []uint64{shape[0], shape[2]}
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    slices.Clone(shape),
			WriterTo: t,
		})
	}

	return out
}

func (p *lfm2Model) Replacements() []string {
	return []string{
		"model.embed_tokens", "token_embd",
		"model.embedding_norm", "output_norm",
		"model.layers", "blk",
		"operator_norm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.out_proj", "attn_output",
		"self_attn.q_layernorm", "attn_q_norm",
		"self_attn.k_layernorm", "attn_k_norm",
		"conv.conv", "shortconv.conv",
		"conv.in_proj", "shortconv.in_proj",
		"conv.out_proj", "shortconv.out_proj",
		"feed_forward.w1", "ffn_gate",
		"feed_forward.w2", "ffn_down",
		"feed_forward.w3", "ffn_up",
		"ffn_norm", "ffn_norm",
	}
}
(hunk from the convert package tensor reader; filename not preserved in this view)

@@ -40,6 +40,7 @@ const (
 func (t tensorBase) Kind() uint32 {
 	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
 		strings.HasSuffix(t.name, ".bias") ||
+		strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
 		t.name == "token_types.weight" ||
 		t.name == "v.positional_embedding_vlm" ||
 		t.name == "v.tile_position_embd.weight" ||
docs/api.md (62 lines changed)
@@ -16,6 +16,7 @@
 - [Generate Embeddings](#generate-embeddings)
 - [List Running Models](#list-running-models)
 - [Version](#version)
+- [Experimental: Image Generation](#image-generation-experimental)

 ## Conventions

@@ -58,6 +59,15 @@ Advanced parameters (optional):
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 - `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory

+Experimental image generation parameters (for image generation models only):
+
+> [!WARNING]
+> These parameters are experimental and may change in future versions.
+
+- `width`: width of the generated image in pixels
+- `height`: height of the generated image in pixels
+- `steps`: number of diffusion steps
+
 #### Structured outputs

 Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1867,3 +1877,55 @@ curl http://localhost:11434/api/version
   "version": "0.5.1"
 }
 ```
+
+## Experimental Features
+
+### Image Generation (Experimental)
+
+> [!WARNING]
+> Image generation is experimental and may change in future versions.
+
+Image generation is now supported through the standard `/api/generate` endpoint when using image generation models. The API automatically detects when an image generation model is being used.
+
+See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
+
+#### Example
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/generate -d '{
+  "model": "x/z-image-turbo",
+  "prompt": "a sunset over mountains",
+  "width": 1024,
+  "height": 768
+}'
+```
+
+##### Response (streaming)
+
+Progress updates during generation:
+
+```json
+{
+  "model": "x/z-image-turbo",
+  "created_at": "2024-01-15T10:30:00.000000Z",
+  "completed": 5,
+  "total": 20,
+  "done": false
+}
+```
+
+##### Final Response
+
+```json
+{
+  "model": "x/z-image-turbo",
+  "created_at": "2024-01-15T10:30:15.000000Z",
+  "image": "iVBORw0KGgoAAAANSUhEUg...",
+  "done": true,
+  "done_reason": "stop",
+  "total_duration": 15000000000,
+  "load_duration": 2000000000
+}
+```
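Since the streamed responses are newline-delimited JSON, the final image can be extracted and decoded on the command line. A sketch, assuming jq and a base64 tool with a `-d` decode flag (GNU coreutils) are installed:

```shell
curl -s http://localhost:11434/api/generate -d '{
  "model": "x/z-image-turbo",
  "prompt": "a sunset over mountains",
  "width": 1024,
  "height": 768
}' | jq -r 'select(.image != null) | .image' | base64 -d > sunset.png
```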
(hunk from the OpenAI compatibility docs; filename not preserved in this view)

@@ -275,6 +275,73 @@ curl -X POST http://localhost:11434/v1/chat/completions \
 - [x] `dimensions`
 - [ ] `user`

+### `/v1/images/generations` (experimental)
+
+> Note: This endpoint is experimental and may change or be removed in future versions.
+
+Generate images using image generation models.
+
+<CodeGroup dropdown>
+
+```python images.py
+from openai import OpenAI
+
+client = OpenAI(
+    base_url='http://localhost:11434/v1/',
+    api_key='ollama',  # required but ignored
+)
+
+response = client.images.generate(
+    model='x/z-image-turbo',
+    prompt='A cute robot learning to paint',
+    size='1024x1024',
+    response_format='b64_json',
+)
+print(response.data[0].b64_json[:50] + '...')
+```
+
+```javascript images.js
+import OpenAI from "openai";
+
+const openai = new OpenAI({
+  baseURL: "http://localhost:11434/v1/",
+  apiKey: "ollama", // required but ignored
+});
+
+const response = await openai.images.generate({
+  model: "x/z-image-turbo",
+  prompt: "A cute robot learning to paint",
+  size: "1024x1024",
+  response_format: "b64_json",
+});
+
+console.log(response.data[0].b64_json.slice(0, 50) + "...");
+```
+
+```shell images.sh
+curl -X POST http://localhost:11434/v1/images/generations \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "x/z-image-turbo",
+    "prompt": "A cute robot learning to paint",
+    "size": "1024x1024",
+    "response_format": "b64_json"
+  }'
+```
+
+</CodeGroup>
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `prompt`
+- [x] `size` (e.g. "1024x1024")
+- [x] `response_format` (only `b64_json` supported)
+- [ ] `n`
+- [ ] `quality`
+- [ ] `style`
+- [ ] `user`
+
 ### `/v1/responses`

 > Note: Added in Ollama v0.13.3
(hunks from the Claude Code integration docs; filename not preserved in this view)

@@ -2,6 +2,12 @@
 title: Claude Code
 ---

+Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
+
+Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
+
+
+
 ## Install

 Install [Claude Code](https://code.claude.com/docs/en/overview):
@@ -27,21 +33,22 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
 ```shell
 export ANTHROPIC_AUTH_TOKEN=ollama
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama
 ```

 2. Run Claude Code with an Ollama model:

 ```shell
-claude --model qwen3-coder
+claude --model gpt-oss:20b
 ```

 Or run with environment variables inline:

 ```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
 ```

 **Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.

 ## Connecting to ollama.com

 1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
@@ -68,3 +75,4 @@ claude --model glm-4.7:cloud
 ### Local models
 - `qwen3-coder` - Excellent for coding tasks
 - `gpt-oss:20b` - Strong general-purpose model
+- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
(hunks from the GGML metadata package; filename not preserved in this view)

@@ -269,6 +269,8 @@ func (kv KV) OllamaEngineRequired() bool {
 		"qwen25vl",
 		"qwen3", "qwen3moe",
 		"qwen3vl", "qwen3vlmoe",
+		"glm4moelite",
+		"lfm2",
 	}, kv.Architecture())
 }

@@ -856,7 +858,9 @@ func (f GGML) FlashAttention() bool {
 	return slices.Contains([]string{
 		"bert",
 		"gemma3",
+		"glm4moelite",
 		"gptoss", "gpt-oss",
+		"lfm2",
 		"mistral3",
 		"olmo3",
 		"qwen3", "qwen3moe",
integration/imagegen_test.go (new file, 148 lines)
@@ -0,0 +1,148 @@
//go:build integration

package integration

import (
	"context"
	"encoding/base64"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestImageGeneration(t *testing.T) {
	skipUnderMinVRAM(t, 8)

	type testCase struct {
		imageGenModel string
		visionModel   string
		prompt        string
		expectedWords []string
	}

	testCases := []testCase{
		{
			imageGenModel: "jmorgan/z-image-turbo",
			visionModel:   "llama3.2-vision",
			prompt:        "A cartoon style llama flying like a superhero through the air with clouds in the background",
			expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
		},
	}

	for _, tc := range testCases {
		t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
			defer cancel()

			client, _, cleanup := InitServerConnection(ctx, t)
			defer cleanup()

			// Pull both models
			if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
				t.Fatalf("failed to pull image gen model: %v", err)
			}
			if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
				t.Fatalf("failed to pull vision model: %v", err)
			}

			// Generate the image
			t.Logf("Generating image with prompt: %s", tc.prompt)
			imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
			if err != nil {
				if strings.Contains(err.Error(), "image generation not available") {
					t.Skip("Target system does not support image generation")
				} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
					t.Skip("Windows does not support image generation yet")
				} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
					t.Skip("Driver is too old")
				} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
					t.Skip("insufficient memory for image generation")
				} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
					t.Skip("CUDA GPU is not available")
				} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
					// most likely linux arm - not supported yet
					t.Skip("unsupported architecture")
				}
				t.Fatalf("failed to generate image: %v", err)
			}

			imageData, err := base64.StdEncoding.DecodeString(imageBase64)
			if err != nil {
				t.Fatalf("failed to decode image: %v", err)
			}
			t.Logf("Generated image: %d bytes", len(imageData))

			// Preload vision model and check GPU loading
			err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
			if err != nil {
				t.Fatalf("failed to load vision model: %v", err)
			}

			// Use vision model to describe the image
			chatReq := api.ChatRequest{
				Model: tc.visionModel,
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
						Images:  []api.ImageData{imageData},
					},
				},
				Stream: &stream,
				Options: map[string]any{
					"seed":        42,
					"temperature": 0.0,
				},
			}

			// Verify the vision model's response contains expected keywords
			response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
			if response != nil {
				t.Logf("Vision model response: %s", response.Content)

				// Additional detailed check for keywords
				content := strings.ToLower(response.Content)
				foundWords := []string{}
				missingWords := []string{}
				for _, word := range tc.expectedWords {
					if strings.Contains(content, word) {
						foundWords = append(foundWords, word)
					} else {
						missingWords = append(missingWords, word)
					}
				}
				t.Logf("Found keywords: %v", foundWords)
				if len(missingWords) > 0 {
					t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
				}
			}
		})
	}
}

// generateImage calls the Ollama API to generate an image and returns the base64 image data
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
	var imageBase64 string

	err := client.Generate(ctx, &api.GenerateRequest{
		Model:  model,
		Prompt: prompt,
	}, func(resp api.GenerateResponse) error {
		if resp.Image != "" {
			imageBase64 = resp.Image
		}
		return nil
	})
	if err != nil {
		return "", fmt.Errorf("failed to generate image: %w", err)
	}

	if imageBase64 == "" {
		return "", fmt.Errorf("no image data in response")
	}

	return imageBase64, nil
}
(hunk from the integration API tests; filename not preserved in this view)

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
 				t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
 			}

-			if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
+			if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
 				t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
 			}
 		case <-ctx.Done():
(hunks from the integration test model lists; filename not preserved in this view)

@@ -38,6 +38,7 @@ var (

 	// Note: add newer models at the top of the list to test them first
 	ollamaEngineChatModels = []string{
+		"lfm2.5-thinking",
 		"ministral-3",
 		"qwen3-coder:30b",
 		"gpt-oss:20b",
@@ -143,6 +144,7 @@
 		"granite3.3",
 		"hermes3",
 		"internlm2",
+		"lfm2.5-thinking",
 		"llama-guard3",
 		"llama-pro",
 		"llama2-chinese",
@@ -263,6 +265,7 @@
 		"snowflake-arctic-embed2",
 	}
 	libraryToolsModels = []string{
+		"lfm2.5-thinking",
 		"qwen3-vl",
 		"gpt-oss:20b",
 		"gpt-oss:120b",
(hunks from the llm server package; filename not preserved in this view)

@@ -1464,6 +1464,12 @@ type CompletionRequest struct {

 	// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
 	TopLogprobs int
+
+	// Image generation fields
+	Width  int32 `json:"width,omitempty"`
+	Height int32 `json:"height,omitempty"`
+	Steps  int32 `json:"steps,omitempty"`
+	Seed   int64 `json:"seed,omitempty"`
 }

 // DoneReason represents the reason why a completion response is done
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {

 	// Logprobs contains log probability information if requested
 	Logprobs []Logprob `json:"logprobs,omitempty"`
+
+	// Image contains base64-encoded image data for image generation
+	Image string `json:"image,omitempty"`
+
+	// Step is the current step in image generation
+	Step int `json:"step,omitempty"`
+
+	// TotalSteps is the total number of steps for image generation
+	TotalSteps int `json:"total_steps,omitempty"`
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
(hunks from a file moved from package server to the new manifest package; filename not preserved in this view)

@@ -1,4 +1,4 @@
-package server
+package manifest

 import (
 	"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
 	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
 	Name      string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
-	status    string
+	Status    string `json:"-"`
 }

 const (
@@ -22,7 +22,7 @@ )
 func NewLayer(r io.Reader, mediatype string) (Layer, error) {
-	blobs, err := GetBlobsPath("")
+	blobs, err := BlobsPath("")
 	if err != nil {
 		return Layer{}, err
 	}
@@ -45,7 +45,7 @@ }

 	digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
-	blob, err := GetBlobsPath(digest)
+	blob, err := BlobsPath(digest)
 	if err != nil {
 		return Layer{}, err
 	}
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
 		MediaType: mediatype,
 		Digest:    digest,
 		Size:      n,
-		status:    fmt.Sprintf("%s %s", status, digest),
+		Status:    fmt.Sprintf("%s %s", status, digest),
 	}, nil
 }

@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
 		return Layer{}, errors.New("creating new layer from layer with empty digest")
 	}

-	blob, err := GetBlobsPath(digest)
+	blob, err := BlobsPath(digest)
 	if err != nil {
 		return Layer{}, err
 	}
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
 		Digest:    digest,
 		Size:      fi.Size(),
 		From:      from,
-		status:    fmt.Sprintf("using existing layer %s", digest),
+		Status:    fmt.Sprintf("using existing layer %s", digest),
 	}, nil
 }

@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
 		return nil, errors.New("opening layer with empty digest")
 	}

-	blob, err := GetBlobsPath(l.Digest)
+	blob, err := BlobsPath(l.Digest)
 	if err != nil {
 		return nil, err
 	}
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
 		}
 	}

-	blob, err := GetBlobsPath(l.Digest)
+	blob, err := BlobsPath(l.Digest)
 	if err != nil {
 		return err
 	}
(hunks from another file moved from package server to package manifest; filename not preserved in this view)

@@ -1,10 +1,9 @@
-package server
+package manifest

 import (
 	"crypto/sha256"
 	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
 	return
 }

+func (m *Manifest) Digest() string {
+	return m.digest
+}
+
+func (m *Manifest) FileInfo() os.FileInfo {
+	return m.fi
+}
+
+// ReadConfigJSON reads and unmarshals a config layer as JSON.
+func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
+	for _, layer := range m.Layers {
+		if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
+			blobPath, err := BlobsPath(layer.Digest)
+			if err != nil {
+				return err
+			}
+			data, err := os.ReadFile(blobPath)
+			if err != nil {
+				return err
+			}
+			return json.Unmarshal(data, v)
+		}
+	}
+	return fmt.Errorf("config %q not found in manifest", configPath)
+}
+
 func (m *Manifest) Remove() error {
 	if err := os.Remove(m.filepath); err != nil {
 		return err
 	}

-	manifests, err := GetManifestPath()
+	manifests, err := Path()
 	if err != nil {
 		return err
 	}
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
 		if _, used := inUse[layer.Digest]; used {
 			continue
 		}
-		blob, err := GetBlobsPath(layer.Digest)
+		blob, err := BlobsPath(layer.Digest)
 		if err != nil {
 			return err
 		}
-		if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
+		if err := os.Remove(blob); os.IsNotExist(err) {
 			slog.Debug("layer does not exist", "digest", layer.Digest)
 		} else if err != nil {
 			return err
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 		return nil, model.Unqualified(n)
 	}

-	manifests, err := GetManifestPath()
+	manifests, err := Path()
 	if err != nil {
 		return nil, err
 	}
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 	}

 func WriteManifest(name model.Name, config Layer, layers []Layer) error {
-	manifests, err := GetManifestPath()
+	manifests, err := Path()
 	if err != nil {
 		return err
 	}
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {

 func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
-	manifests, err := GetManifestPath()
+	manifests, err := Path()
 	if err != nil {
 		return nil, err
 	}

(hunk from a third file moved to package manifest; filename not preserved in this view)

@@ -1,4 +1,4 @@
-package server
+package manifest

 import (
 	"encoding/json"
manifest/paths.go (new file, 95 lines)
@@ -0,0 +1,95 @@
package manifest

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/types/model"
)

var ErrInvalidDigestFormat = errors.New("invalid digest format")

func Path() (string, error) {
	path := filepath.Join(envconfig.Models(), "manifests")
	if err := os.MkdirAll(path, 0o755); err != nil {
		return "", fmt.Errorf("%w: ensure path elements are traversable", err)
	}

	return path, nil
}

// PathForName returns the path to the manifest file for a specific model name.
func PathForName(n model.Name) (string, error) {
	if !n.IsValid() {
		return "", os.ErrNotExist
	}

	manifests, err := Path()
	if err != nil {
		return "", err
	}

	return filepath.Join(manifests, n.Filepath()), nil
}

func BlobsPath(digest string) (string, error) {
	// only accept actual sha256 digests
	pattern := "^sha256[:-][0-9a-fA-F]{64}$"
	re := regexp.MustCompile(pattern)

	if digest != "" && !re.MatchString(digest) {
		return "", ErrInvalidDigestFormat
	}

	digest = strings.ReplaceAll(digest, ":", "-")
	path := filepath.Join(envconfig.Models(), "blobs", digest)
	dirPath := filepath.Dir(path)
	if digest == "" {
		dirPath = path
	}

	if err := os.MkdirAll(dirPath, 0o755); err != nil {
		return "", fmt.Errorf("%w: ensure path elements are traversable", err)
	}

	return path, nil
}

// PruneDirectory removes empty directories recursively.
func PruneDirectory(path string) error {
	info, err := os.Lstat(path)
	if err != nil {
		return err
	}

	if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
		entries, err := os.ReadDir(path)
		if err != nil {
			return err
		}

		for _, entry := range entries {
			if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
				return err
			}
		}

		entries, err = os.ReadDir(path)
		if err != nil {
			return err
		}

		if len(entries) > 0 {
			return nil
		}

		return os.Remove(path)
	}

	return nil
}
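For illustration, `BlobsPath` validates the digest and rewrites the `sha256:` separator to `sha256-` before joining it under the models directory. A minimal sketch (the digest value is made up; the import path follows from the new manifest package location):

```go
package main

import (
	"fmt"
	"strings"

	"github.com/ollama/ollama/manifest"
)

func main() {
	// Made-up but syntactically valid digest: "sha256:" followed by 64 hex chars.
	digest := "sha256:" + strings.Repeat("ab", 32)

	// Resolves to <models dir>/blobs/sha256-abab..., creating the blobs dir if needed.
	path, err := manifest.BlobsPath(digest)
	if err != nil {
		panic(err)
	}
	fmt.Println(path)
}
```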
(hunks from the OpenAI-compatibility middleware; filename not preserved in this view)

@@ -8,6 +8,7 @@ import (
 	"math/rand"
 	"net/http"
 	"strings"
 	"time"

 	"github.com/gin-gonic/gin"

@@ -441,6 +442,7 @@ type ResponsesWriter struct {
 	stream     bool
 	responseID string
 	itemID     string
+	request    openai.ResponsesRequest
 }

 func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {

 	// Non-streaming response
 	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
+	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
+	completedAt := time.Now().Unix()
+	response.CompletedAt = &completedAt
 	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
 }

@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {

 		w := &ResponsesWriter{
 			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
+			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
 			model:      req.Model,
 			stream:     streamRequested,
 			responseID: responseID,
 			itemID:     itemID,
+			request:    req,
 		}

 		// Set headers based on streaming mode
@@ -541,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
 		c.Next()
 	}
 }
+
+type ImageWriter struct {
+	BaseWriter
+}
+
+func (w *ImageWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	if err := json.Unmarshal(data, &generateResponse); err != nil {
+		return 0, err
+	}
+
+	// Only write response when done with image
+	if generateResponse.Done && generateResponse.Image != "" {
+		w.ResponseWriter.Header().Set("Content-Type", "application/json")
+		return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
+	}
+
+	return len(data), nil
+}
+
+func (w *ImageWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ImageGenerationsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req openai.ImageGenerationRequest
+		if err := c.ShouldBindJSON(&req); err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Prompt == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
+			return
+		}
+
+		if req.Model == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &ImageWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+		c.Next()
+	}
+}
(hunk from the middleware tests; filename not preserved in this view)

@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
 }
 }
 }
+
+func TestImageGenerationsMiddleware(t *testing.T) {
+	type testCase struct {
+		name string
+		body string
+		req  api.GenerateRequest
+		err  openai.ErrorResponse
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	testCases := []testCase{
+		{
+			name: "image generation basic",
+			body: `{
+				"model": "test-model",
+				"prompt": "a beautiful sunset"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "a beautiful sunset",
+			},
+		},
+		{
+			name: "image generation with size",
+			body: `{
+				"model": "test-model",
+				"prompt": "a beautiful sunset",
+				"size": "512x768"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "a beautiful sunset",
+				Width:  512,
+				Height: 768,
+			},
+		},
+		{
+			name: "image generation missing prompt",
+			body: `{
+				"model": "test-model"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "prompt is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+		{
+			name: "image generation missing model",
+			body: `{
+				"prompt": "a beautiful sunset"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "model is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
+
+			defer func() { capturedRequest = nil }()
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			if tc.err.Error.Message != "" {
+				var errResp openai.ErrorResponse
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+				if diff := cmp.Diff(tc.err, errResp); diff != "" {
+					t.Fatalf("errors did not match:\n%s", diff)
+				}
+				return
+			}
+
+			if resp.Code != http.StatusOK {
+				t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+			}
+
+			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+				t.Fatalf("requests did not match:\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestImageWriterResponse(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// Test that ImageWriter transforms GenerateResponse to OpenAI format
+	endpoint := func(c *gin.Context) {
+		resp := api.GenerateResponse{
+			Model:     "test-model",
+			CreatedAt: time.Unix(1234567890, 0).UTC(),
+			Done:      true,
+			Image:     "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
+		}
+		data, _ := json.Marshal(resp)
+		c.Writer.Write(append(data, '\n'))
+	}
+
+	router := gin.New()
+	router.Use(ImageGenerationsMiddleware())
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	body := `{"model": "test-model", "prompt": "test"}`
+	req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+
+	resp := httptest.NewRecorder()
+	router.ServeHTTP(resp, req)
+
+	if resp.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+	}
+
+	var imageResp openai.ImageGenerationResponse
+	if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
+		t.Fatalf("failed to unmarshal response: %v", err)
+	}
+
+	if imageResp.Created != 1234567890 {
+		t.Errorf("expected created 1234567890, got %d", imageResp.Created)
+	}
+
+	if len(imageResp.Data) != 1 {
+		t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
+	}
+
+	if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
+		t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
+	}
+}
@@ -162,6 +162,7 @@ type Tensor interface {
	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
	SSMConv(ctx Context, kernel Tensor) Tensor

	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

@@ -1641,6 +1641,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
	return tt
}

func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
	}
}

func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	return &Tensor{
		b: t.b,
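An editorial aside, not part of this diff: C.ggml_ssm_conv is bound above without documentation. Judging from its use in the lfm2 shortconv block later in this changeset, it behaves as a per-channel causal sliding-window convolution; the plain-Go sketch below is a reference for that assumed semantics (one channel, float32), not the ggml implementation.

	// Editorial reference sketch of the assumed per-channel semantics of
	// ggml_ssm_conv. sx carries (len(kernel)-1) state values followed by
	// the new inputs, so the output has exactly one value per new input.
	func ssmConvRef(sx, kernel []float32) []float32 {
		w := len(kernel)
		out := make([]float32, len(sx)-w+1)
		for t := range out {
			var acc float32
			for k := 0; k < w; k++ {
				acc += sx[t+k] * kernel[k]
			}
			out[t] = acc
		}
		return out
	}
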
304 model/models/glm4moelite/model.go (new file)
@@ -0,0 +1,304 @@
package glm4moelite

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Options struct {
	numExpertsUsed      int
	numExperts          int
	normTopKProb        bool
	routedScalingFactor float32

	kvLoraRank,
	qkNopeHeadDim,
	qkRopeHeadDim,
	kqNopeHeadDim,
	qkHeadDim int
	qLoraRank int
	vHeadDim  int

	hiddenSize,
	numHeads,
	numKVHeads int

	eps,
	ropeBase float32
	kqScale float64
}

func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
	return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}

type Attention struct {
	Q *nn.Linear `gguf:"attn_q"`

	QA     *nn.Linear  `gguf:"attn_q_a"`
	QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
	QB     *nn.Linear  `gguf:"attn_q_b"`

	KVA     *nn.Linear  `gguf:"attn_kv_a_mqa"`
	KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
	KVB     *nn.Linear  `gguf:"attn_kv_b"`

	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}

func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	seqLength := hiddenStates.Dim(1)

	var query ml.Tensor
	if opts.qLoraRank == 0 {
		query = attn.Q.Forward(ctx, hiddenStates)
	} else {
		query = attn.QA.Forward(ctx, hiddenStates)
		query = attn.QANorm.Forward(ctx, query, opts.eps)
		query = attn.QB.Forward(ctx, query)
	}

	query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
	queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)

	compressedKV := attn.KVA.Forward(ctx, hiddenStates)
	kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
	kRot := compressedKV.View(ctx,
		opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
		compressedKV.Stride(1), 1,
		compressedKV.Stride(1), compressedKV.Dim(1),
	)

	qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
	kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
	kPass = attn.KVB.Forward(ctx, kPass)

	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
	query = qRot.Concat(ctx, queryChunks[0], 0)
	key := kRot.Concat(ctx, kvChunks[0], 0)
	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)

	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
	return attn.Output.Forward(ctx, attention)
}

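// Editorial note on the attention block above (an interpretation of the
// code, not an authoritative description of the model):
//   - queries are optionally low-rank factored (QA -> norm -> QB) when
//     attention.q_lora_rank is set, otherwise projected directly (Q);
//   - keys/values are compressed to kvLoraRank dims (KVA), normalized, then
//     expanded per head (KVB), resembling multi-head latent attention;
//   - only qkRopeHeadDim dims of each head receive rotary embeddings; the
//     shared kRot half is broadcast across query heads via Repeat before
//     being concatenated with the "nope" halves.
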
type MLP interface {
	Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}

type sparse struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`
	Gate         *nn.Linear `gguf:"ffn_gate_exps"`
	Up           *nn.Linear `gguf:"ffn_up_exps"`
	Down         *nn.Linear `gguf:"ffn_down_exps"`
	SharedExpert *dense     `gguf:",suf:_shexp"`
	ExpProbsBias ml.Tensor  `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}

func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

	upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
	hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
	hiddenStates = hiddenStates.SILU(ctx, upStates)

	experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
	experts = experts.Mul(ctx, topKWeights)

	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}
	return nextStates
}

func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
	if moe.ExpProbsBias != nil {
		scores = scores.Add(ctx, moe.ExpProbsBias)
	}
	topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
	return topKIndices
}

func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	residuals := hiddenStates

	routerLogits := moe.Router.Forward(ctx, hiddenStates)
	scores := routerLogits.Sigmoid(ctx)
	topKIndices := moe.topKIndices(ctx, scores, opts)
	topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)

	if opts.normTopKProb {
		topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
		topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
		topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
	}

	topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
	hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
	sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)

	hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
	return hiddenStates
}

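// Editorial walkthrough of the routing above (derived from the code): with
// numExperts = N and numExpertsUsed = k, sigmoid router scores are
// bias-adjusted (ExpProbsBias) only for top-k *selection*; the gathered
// per-expert weights come from the unbiased scores. If normTopKProb is set,
// the k weights are renormalized to sum to 1 per token, then scaled by
// routedScalingFactor before the expert mixture; a shared expert runs on
// the pre-routing residual stream and is added to the routed output.
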
type dense struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *Attention

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     MLP
}

func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	residual := hiddenStates
	hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)
	residual = hiddenStates

	hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	return hiddenStates
}

type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`

	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
	Output     *nn.Linear  `gguf:"output,alt:token_embd"`

	*Options
}

func New(c fs.Config) (model.Model, error) {
	layers := make([]Layer, c.Uint("block_count"))

	firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
	for i := range layers {
		if i < firstDenseLayerIndex {
			layers[i].MLP = &dense{}
		} else {
			layers[i].MLP = &sparse{}
		}
	}

	keyLength := int(c.Uint("attention.key_length"))
	valueLength := int(c.Uint("attention.value_length"))

	kqScale := 1.0 / math.Sqrt(float64(keyLength))

	var pre []string
	switch c.String("tokenizer.ggml.pre") {
	case "glm4":
		pre = []string{
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		}
	default:
		return nil, model.ErrUnsupportedTokenizer
	}

	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
				EOS: append(
					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
			},
			pre...,
		),
		Layers: layers,
		Options: &Options{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			eps:            c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:       c.Float("rope.freq_base"),
			numExperts:     int(c.Uint("expert_count")),
			numExpertsUsed: int(c.Uint("expert_used_count")),
			normTopKProb:   c.Bool("expert_weights_norm", true),

			qLoraRank:     int(c.Uint("attention.q_lora_rank")),
			kvLoraRank:    int(c.Uint("attention.kv_lora_rank")),
			qkHeadDim:     keyLength,
			vHeadDim:      valueLength,
			qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
			qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
			kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),

			routedScalingFactor: c.Float("expert_weights_scale"),

			kqScale: kqScale,
		},
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)
	return &m, nil
}

func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)

		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = batch.Outputs
		}

		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates), nil
}

func init() {
	model.Register("glm4moelite", New)
}
410 model/models/lfm2/cache.go (new file)
@@ -0,0 +1,410 @@
package lfm2

import (
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

var _ kvcache.Cache = (*HybridCache)(nil)

// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
	kv *kvcache.Causal

	backend      ml.Backend
	dtype        ml.DType
	maxSequences int

	hiddenSize int
	dConv      int

	// slot mapping for recurrent state
	slotForSeq map[int]int
	refCount   []int
	freeSlots  []int

	// per-layer conv state buffers (allocated lazily)
	convCtxs   map[int]ml.Context
	convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]

	// current forward batch (derived in StartForward)
	curSeqs       []int
	curSlots      []int
	curSlotsInput ml.Tensor
	curSeqTokens  int

	// track if EnsureWritable has been called for this forward pass
	writableEnsured bool
	// track any error from EnsureWritable to propagate later
	writableError error
}

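// Editorial example of the layout described above (illustrative numbers,
// not from a real model file): with L_cache = 3 (so dConv = 2) and
// hiddenSize = 2048, each layer's buffer is [2*2048, maxSlots] =
// [4096, maxSlots]; slot s holds one sequence's last dConv per-channel
// inputs, flattened into a single column.
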
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
	return &HybridCache{
		kv:         kvcache.NewCausalCache(shift),
		hiddenSize: hiddenSize,
		dConv:      dConv,
		slotForSeq: make(map[int]int),
		convCtxs:   make(map[int]ml.Context),
		convStates: make(map[int]ml.Tensor),
	}
}

func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.backend = backend
	c.dtype = dtype
	c.maxSequences = maxSequences

	// initialize slot allocator
	c.refCount = make([]int, maxSequences)
	c.freeSlots = c.freeSlots[:0]
	for i := maxSequences - 1; i >= 0; i-- {
		c.freeSlots = append(c.freeSlots, i)
	}

	c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

func (c *HybridCache) Close() {
	for _, ctx := range c.convCtxs {
		ctx.Close()
	}
	c.kv.Close()
}

func (c *HybridCache) SetConfig(config ml.CacheConfig) {
	c.kv.SetConfig(config)
}

func (c *HybridCache) SetLayer(layer int) {
	c.kv.SetLayer(layer)
}

func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.kv.Get(ctx)
}

func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.kv.Put(ctx, key, value)
}

func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
		return err
	}

	// Derive equal-length sequence layout for shortconv.
	// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
	seqCounts := make(map[int]int)
	c.curSeqs = c.curSeqs[:0]
	for _, s := range batch.Sequences {
		if _, ok := seqCounts[s]; !ok {
			c.curSeqs = append(c.curSeqs, s)
		}
		seqCounts[s]++
	}

	if len(c.curSeqs) == 0 {
		return nil
	}

	nTokens := len(batch.Sequences)
	nSeqs := len(c.curSeqs)
	want := nTokens / nSeqs
	for _, s := range c.curSeqs {
		if seqCounts[s] != want {
			return kvcache.ErrNotSupported
		}
	}

	c.curSeqTokens = want

	// When reserving memory for estimation, use fake slot assignments
	// without modifying permanent state (slotForSeq, refCount)
	if reserve {
		c.curSlots = c.curSlots[:0]
		slots := make([]int32, nSeqs)
		for i := range nSeqs {
			c.curSlots = append(c.curSlots, i)
			slots[i] = int32(i)
		}
		c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
		return nil
	}

	// Ensure slots exist for sequences in this batch
	c.curSlots = c.curSlots[:0]
	var newSlots []int // track newly allocated slots that need zeroing
	for _, s := range c.curSeqs {
		slot, ok := c.slotForSeq[s]
		if !ok {
			var err error
			slot, err = c.allocSlot()
			if err != nil {
				return err
			}
			c.slotForSeq[s] = slot
			c.refCount[slot] = 1
			newSlots = append(newSlots, slot)
		}
		c.curSlots = append(c.curSlots, slot)
	}

	// Zero conv state for newly allocated slots to clear stale data from previous sequences
	if len(newSlots) > 0 {
		c.zeroConvSlots(ctx, newSlots)
	}

	// Create a tensor for the current slots
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	// Reset writable state for new forward pass
	c.writableEnsured = false
	c.writableError = nil

	return nil
}

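// Editorial example (derived from the checks above): batch.Sequences of
// [0, 0, 1, 1] yields curSeqs = [0, 1] and curSeqTokens = 2, a valid
// [seq_tokens, seqs] grid; a ragged batch such as [0, 0, 1] fails the
// equal-count check and returns kvcache.ErrNotSupported.
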
func (c *HybridCache) allocSlot() (int, error) {
	if len(c.freeSlots) == 0 {
		return 0, kvcache.ErrKvCacheFull
	}
	slot := c.freeSlots[len(c.freeSlots)-1]
	c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
	return slot, nil
}

func (c *HybridCache) freeSlot(slot int) {
	// Bounds check before freeing
	if slot >= 0 && slot < c.maxSequences {
		c.freeSlots = append(c.freeSlots, slot)
	}
}

// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
	if len(slots) == 0 || len(c.convStates) == 0 {
		return
	}

	// Use input context for creating tensors
	inputCtx := ctx.Input()

	// Create slot indices tensor
	slotIndices := make([]int32, len(slots))
	for i, s := range slots {
		slotIndices[i] = int32(s)
	}
	slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))

	// Create zero tensor for the slots (SetRows requires F32 source)
	zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))

	// Zero each layer's conv state for these slots
	for _, buf := range c.convStates {
		ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
	}
}

// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
	for i, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok {
			continue
		}

		// Bounds check
		if slot < 0 || slot >= len(c.refCount) {
			continue
		}

		if c.refCount[slot] <= 1 {
			continue
		}

		newSlot, err := c.allocSlot()
		if err != nil {
			return err
		}
		c.refCount[slot]--
		c.refCount[newSlot] = 1
		c.slotForSeq[seq] = newSlot
		c.curSlots[i] = newSlot

		// Copy existing conv state for all initialized layers
		for _, buf := range c.convStates {
			// buf: [dConv*hiddenSize, maxSlots]
			src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
			// SetRows requires F32 source
			srcF32 := src.Cast(ctx, ml.DTypeF32)
			ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
		}
	}

	// Rebuild current slots tensor
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	return nil
}

func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

	// For shortconv state we implement copy-on-write: dst shares the same slot as src.
	// On the first write to dst, EnsureWritable will create a private slot.
	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		// Bounds check before decrementing
		if dstSlot >= 0 && dstSlot < len(c.refCount) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}

	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		// src may not have a slot yet; dst will allocate on demand
		return
	}

	// Bounds check before incrementing
	if srcSlot >= 0 && srcSlot < len(c.refCount) {
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}

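// Editorial sketch of the copy-on-write lifecycle implied above:
//   1. CopyPrefix(src, dst): dst points at src's slot, refCount++ (no copy);
//   2. the first ConvState call for a batch containing dst sees
//      refCount > 1 and triggers EnsureWritable;
//   3. EnsureWritable allocates a private slot for dst, copies the conv
//      rows, and decrements the shared slot's refCount.
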
func (c *HybridCache) CanResume(seq int, pos int32) bool {
	return c.kv.CanResume(seq, pos)
}

func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
	if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
		return err
	}

	// For recurrent state, any removal invalidates the state because
	// the state at position N depends on all previous positions.
	// Drop the slot mapping so it resets on next use.
	slot, ok := c.slotForSeq[seq]
	if !ok {
		return nil
	}

	// Bounds check
	if slot < 0 || slot >= len(c.refCount) {
		delete(c.slotForSeq, seq)
		return nil
	}

	c.refCount[slot]--
	if c.refCount[slot] <= 0 {
		c.refCount[slot] = 0
		c.freeSlot(slot)
	}
	delete(c.slotForSeq, seq)

	return nil
}

func (c *HybridCache) slotsTensor() ml.Tensor {
	return c.curSlotsInput
}

func (c *HybridCache) seqTokens() int {
	return c.curSeqTokens
}

func (c *HybridCache) numSeqs() int {
	return len(c.curSeqs)
}

func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}

	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}

// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	if !c.writableEnsured {
		needsWritable := false
		for _, seq := range c.curSeqs {
			slot, ok := c.slotForSeq[seq]
			if !ok {
				continue
			}
			if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
				needsWritable = true
				break
			}
		}

		if needsWritable {
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
		}
		c.writableEnsured = true
	}

	if c.writableError != nil {
		return nil, c.writableError
	}

	buf := c.convBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}

// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(ctx, layer)
	src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
	// SetRows requires F32 source
	srcF32 := src.Cast(ctx, ml.DTypeF32)
	ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}

// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
	return slices.Clone(c.curSeqs)
}
444 model/models/lfm2/cache_test.go (new file)
@@ -0,0 +1,444 @@
package lfm2

import (
	"testing"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.

// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
	return &HybridCache{
		hiddenSize:   256,
		dConv:        3,
		maxSequences: maxSequences,
		refCount:     make([]int, maxSequences),
		freeSlots:    initFreeSlots(maxSequences),
		slotForSeq:   make(map[int]int),
		convCtxs:     make(map[int]ml.Context),
		convStates:   make(map[int]ml.Tensor),
	}
}

func initFreeSlots(n int) []int {
	slots := make([]int, 0, n)
	for i := n - 1; i >= 0; i-- {
		slots = append(slots, i)
	}
	return slots
}

func TestHybridCache_SlotAllocation(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Verify initial state
	if len(cache.freeSlots) != 4 {
		t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
	}

	// Allocate all slots
	for range 4 {
		slot, err := cache.allocSlot()
		if err != nil {
			t.Fatalf("allocSlot failed: %v", err)
		}
		cache.refCount[slot] = 1
	}

	// Should be full now
	if len(cache.freeSlots) != 0 {
		t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
	}

	// Trying to allocate another should fail
	_, err := cache.allocSlot()
	if err != kvcache.ErrKvCacheFull {
		t.Errorf("expected ErrKvCacheFull, got %v", err)
	}
}

func TestHybridCache_SlotReuse(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate a slot
	slot1, _ := cache.allocSlot()
	cache.refCount[slot1] = 1

	// Free it
	cache.refCount[slot1] = 0
	cache.freeSlot(slot1)

	// Allocate again - should get the same slot back (LIFO)
	slot2, _ := cache.allocSlot()
	if slot2 != slot1 {
		t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
	}
}

func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Simulate sharing slot with seq 2 (copy-on-write style)
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Should share the same slot
	if cache.slotForSeq[2] != slot1 {
		t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
	}

	// Ref count should be 2
	if cache.refCount[slot1] != 2 {
		t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Share with seq 2
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Unshare seq 2
	cache.refCount[slot1]--
	delete(cache.slotForSeq, 2)

	// Ref count should be back to 1
	if cache.refCount[slot1] != 1 {
		t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
	}

	// Seq 2 should no longer have a slot
	if _, ok := cache.slotForSeq[2]; ok {
		t.Error("seq 2 should not have a slot after unshare")
	}
}

func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
	cache := createSlotOnlyCache(4)

	initialFreeSlots := len(cache.freeSlots)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Free the slot when refCount drops to 0
	cache.refCount[slot1]--
	if cache.refCount[slot1] <= 0 {
		cache.refCount[slot1] = 0
		cache.freeSlot(slot1)
	}
	delete(cache.slotForSeq, 1)

	// Slot should be freed
	if len(cache.freeSlots) != initialFreeSlots {
		t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
	}

	// Ref count should be 0
	if cache.refCount[slot1] != 0 {
		t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_SlotOverwrite(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slots for seq 1 and seq 2
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	slot2, _ := cache.allocSlot()
	cache.slotForSeq[2] = slot2
	cache.refCount[slot2] = 1

	initialFreeSlots := len(cache.freeSlots)

	// Simulate overwriting seq 2's slot with slot1 (sharing)
	// First free the old slot
	cache.refCount[slot2]--
	if cache.refCount[slot2] <= 0 {
		cache.refCount[slot2] = 0
		cache.freeSlot(slot2)
	}
	// Then share slot1
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Seq 2 should now share slot1
	if cache.slotForSeq[2] != slot1 {
		t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
	}

	// Old slot2 should be freed
	if len(cache.freeSlots) != initialFreeSlots+1 {
		t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
	}
}

func TestHybridCache_BoundsChecking(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Test freeing invalid slot (should not panic)
	cache.freeSlot(-1)
	cache.freeSlot(100) // out of bounds

	// freeSlot does bounds checking, so invalid slots should be ignored
	if len(cache.freeSlots) != 4 {
		t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
	}
}

func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
	cache := createSlotOnlyCache(8)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Fork to seq 2, 3, 4 (all share slot1)
	for _, seq := range []int{2, 3, 4} {
		cache.slotForSeq[seq] = slot1
		cache.refCount[slot1]++
	}

	// Ref count should be 4
	if cache.refCount[slot1] != 4 {
		t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
	}

	// Remove seq 2, 3
	for _, seq := range []int{2, 3} {
		delete(cache.slotForSeq, seq)
		cache.refCount[slot1]--
	}

	if cache.refCount[slot1] != 2 {
		t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
	}

	// Slot should still be allocated (not in free list)
	found := false
	for _, s := range cache.freeSlots {
		if s == slot1 {
			found = true
			break
		}
	}
	if found {
		t.Error("slot1 should not be in free list yet")
	}

	// Remove remaining sequences
	for _, seq := range []int{1, 4} {
		delete(cache.slotForSeq, seq)
		cache.refCount[slot1]--
	}

	if cache.refCount[slot1] != 0 {
		t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_ChainedSharing(t *testing.T) {
	cache := createSlotOnlyCache(8)

	// Create seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Share 1 -> 2
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Share 2 -> 3 (should still share slot1)
	cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
	cache.refCount[slot1]++

	// All should share slot1
	if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
		t.Error("all sequences should share slot1")
	}

	if cache.refCount[slot1] != 3 {
		t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_CacheParameters(t *testing.T) {
	cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5

	if cache.hiddenSize != 512 {
		t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
	}
	if cache.dConv != 5 {
		t.Errorf("expected dConv 5, got %d", cache.dConv)
	}
}

func TestHybridCache_NumSeqs(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially no sequences
	if cache.numSeqs() != 0 {
		t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
	}

	// Manually set up current batch state
	cache.curSeqs = []int{1, 2, 3}

	if cache.numSeqs() != 3 {
		t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
	}
}

func TestHybridCache_SeqTokens(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially 0
	if cache.seqTokens() != 0 {
		t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
	}

	// Manually set up current batch state
	cache.curSeqTokens = 16

	if cache.seqTokens() != 16 {
		t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
	}
}

// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
	cache := createSlotOnlyCache(4)

	cache.curSeqs = []int{1, 2, 3}

	seqs := cache.Seqs()

	// Modify returned slice
	seqs[0] = 999

	// Original should be unchanged
	if cache.curSeqs[0] != 1 {
		t.Error("Seqs should return a clone, not the original slice")
	}
}

func TestHybridCache_IsSupportedForBatch(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially not supported (no batch set up)
	if cache.IsSupportedForBatch() {
		t.Error("expected IsSupportedForBatch to be false initially")
	}

	// Set up a valid batch
	cache.curSeqTokens = 1
	cache.curSeqs = []int{1}

	if !cache.IsSupportedForBatch() {
		t.Error("expected IsSupportedForBatch to be true with valid batch")
	}
}

func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// zeroConvSlots should handle empty slots without panicking
	cache.zeroConvSlots(nil, nil)
	cache.zeroConvSlots(nil, []int{})

	// zeroConvSlots should handle empty convStates without panicking
	cache.zeroConvSlots(nil, []int{0, 1, 2})
}

func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Free the slot (simulating sequence removal)
	cache.refCount[slot1]--
	cache.freeSlot(slot1)
	delete(cache.slotForSeq, 1)

	// Verify slot is in free list
	if len(cache.freeSlots) != 4 {
		t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
	}

	// Allocate for new seq 2 - should get recycled slot
	slot2, _ := cache.allocSlot()
	if slot2 != slot1 {
		t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
	}

	// This recycled slot would need zeroing in the real implementation
	// The actual zeroing is tested via integration tests since it requires ML context
}

func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Simulate the slot allocation flow from StartForward
	// When a sequence doesn't have a slot, it gets allocated and tracked as "new"

	newSlots := []int{}

	// Seq 1 doesn't have a slot - allocate and track
	seq := 1
	if _, ok := cache.slotForSeq[seq]; !ok {
		slot, err := cache.allocSlot()
		if err != nil {
			t.Fatalf("allocSlot failed: %v", err)
		}
		cache.slotForSeq[seq] = slot
		cache.refCount[slot] = 1
		newSlots = append(newSlots, slot)
	}

	// Verify newSlots contains the allocated slot
	if len(newSlots) != 1 {
		t.Errorf("expected 1 new slot, got %d", len(newSlots))
	}

	// Seq 1 already has a slot - should NOT be tracked as new
	newSlots2 := []int{}
	if _, ok := cache.slotForSeq[seq]; !ok {
		slot, _ := cache.allocSlot()
		cache.slotForSeq[seq] = slot
		cache.refCount[slot] = 1
		newSlots2 = append(newSlots2, slot)
	}

	// Verify no new slots for existing sequence
	if len(newSlots2) != 0 {
		t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
	}
}
253 model/models/lfm2/model.go (new file)
@@ -0,0 +1,253 @@
package lfm2

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Options struct {
	hiddenSize       int
	headDim, ropeDim int

	eps, ropeBase, ropeScale float32

	ropeType              string
	originalContextLength int

	// per-layer head counts (LFM2 alternates attention and recurrent layers)
	numHeadsByLayer   []int
	numKVHeadsByLayer []int
}

func (o Options) headDimValue() int {
	// Head dim is shared across layers; fall back to first attention layer head count.
	for _, h := range o.numHeadsByLayer {
		if h > 0 {
			return cmp.Or(o.headDim, o.hiddenSize/h)
		}
	}
	return cmp.Or(o.headDim, o.hiddenSize)
}

func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	opts := []func(*rope.Options){rope.WithTypeNeoX()}
	if o.ropeType == "yarn" {
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
		opts = append(opts,
			rope.WithOriginalContextLength(o.originalContextLength),
			rope.WithExtrapolationFactor(1.),
			rope.WithAttentionFactor(attnFactor),
		)
	}

	headCount := 1
	for _, h := range o.numHeadsByLayer {
		if h > 0 {
			headCount = h
			break
		}
	}
	return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}

type Model struct {
	model.Base
	model.TextProcessor

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm,alt:token_embd_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	Options
}

func New(c fs.Config) (model.Model, error) {
	if c.Uint("expert_count") > 0 {
		return nil, model.ErrUnsupportedModel
	}

	if c.String("tokenizer.ggml.model") != "gpt2" {
		return nil, model.ErrUnsupportedTokenizer
	}

	vocabulary := model.Vocabulary{
		Values: c.Strings("tokenizer.ggml.tokens"),
		Scores: c.Floats("tokenizer.ggml.scores"),
		Types:  c.Ints("tokenizer.ggml.token_type"),
		Merges: c.Strings("tokenizer.ggml.merges"),
		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
		EOS: append(
			[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
			c.Ints("tokenizer.ggml.eos_token_ids")...,
		),
	}

	var pretokenizers []string
	switch c.String("tokenizer.ggml.pre") {
	case "default":
		// use default BPE pretokenizer
	default:
		// llama-bpe style (default for LFM2)
		pretokenizers = []string{
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		}
	}

	m := Model{
		TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
		Layers:        make([]Layer, c.Uint("block_count")),
		Options: Options{
			hiddenSize:            int(c.Uint("embedding_length")),
			headDim:               int(c.Uint("attention.key_length")),
			ropeDim:               int(c.Uint("rope.dimension_count")),
			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
			ropeType:              c.String("rope.scaling.type"),
			ropeBase:              c.Float("rope.freq_base"),
			ropeScale:             c.Float("rope.scaling.factor", 1),
			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
		},
	}

	type headCounts interface {
		HeadCount() []uint64
		HeadCountKV() []uint64
	}
	hc, ok := c.(headCounts)
	if !ok {
		return nil, model.ErrUnsupportedModel
	}

	headCount := hc.HeadCount()
	headCountKV := hc.HeadCountKV()

	m.numHeadsByLayer = make([]int, len(m.Layers))
	m.numKVHeadsByLayer = make([]int, len(m.Layers))
	for i := range m.Layers {
		m.numHeadsByLayer[i] = int(headCount[i])
		m.numKVHeadsByLayer[i] = int(headCountKV[i])

		if m.numKVHeadsByLayer[i] == 0 {
			m.Layers[i].Operator = &ShortConv{}
		} else {
			m.Layers[i].Operator = &Attention{}
		}
	}

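	// Editorial example (illustrative values, not from a real GGUF): with
	// attention.head_count_kv = [0, 0, 8, 0, ...], layers 0, 1 and 3 get a
	// ShortConv operator and layer 2 gets full Attention, which is how LFM2
	// interleaves recurrent and attention blocks.
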
	lCache := int(c.Uint("shortconv.l_cache"))
	dConv := max(0, lCache-1)
	m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
	return &m, nil
}

type Operator interface {
	Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}

type Attention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output,alt:attn_out"`
}

func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
	batchSize := hiddenStates.Dim(1)
	headDim := opts.headDimValue()
	numHeads := opts.numHeadsByLayer[layer]
	numKVHeads := opts.numKVHeadsByLayer[layer]

	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	query = query.Reshape(ctx, headDim, numHeads, batchSize)
	key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, numKVHeads, batchSize)

	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
	key = sa.KeyNorm.Forward(ctx, key, opts.eps)

	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
	return sa.Output.Forward(ctx, attention)
}

type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Operator      Operator
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)

	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)

		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = batch.Outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
}

func init() {
	model.Register("lfm2", New)
}
50 model/models/lfm2/shortconv.go (new file)
@@ -0,0 +1,50 @@
package lfm2

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type shortConvKernel struct {
	Weight ml.Tensor `gguf:"weight"`
}

// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
	Conv    *shortConvKernel `gguf:"shortconv.conv"`
	InProj  *nn.Linear       `gguf:"shortconv.in_proj"`
	OutProj *nn.Linear       `gguf:"shortconv.out_proj"`
}

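// Editorial summary of the dataflow in Forward below (a reading of the
// code, using paper-style names B, C, x for the three projections):
//   in_proj -> [B; C; x]           (each [hiddenSize, seqTokens, nSeqs])
//   bx      = B * x, laid out as [seqTokens, hiddenSize, nSeqs]
//   sx      = concat(cached state, bx) along the token axis
//   conv    = SSMConv(sx, kernel)  (causal window of width L_cache)
//   y       = C * conv; the last dConv rows of sx become the new state.
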
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
	nSeqs := cache.numSeqs()
	seqTokens := cache.seqTokens()
	hiddenSize := hiddenStates.Dim(0)
	if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
		panic("lfm2: unsupported batch layout for shortconv")
	}

	bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)

	elementSize := bcx.Stride(0)
	b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
	c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
	x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)

	bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)

	state, err := cache.ConvState(ctx, layer)
	if err != nil {
		panic("lfm2: failed to get conv state: " + err.Error())
	}
	sx := state.Concat(ctx, bx, 0)

	convOut := sx.SSMConv(ctx, sc.Conv.Weight)
	y := c.Mul(ctx, convOut)

	dConv := sx.Dim(0) - seqTokens
	cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))

	return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}
@@ -7,7 +7,9 @@ import (
	_ "github.com/ollama/ollama/model/models/gemma2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/gemma3n"
	_ "github.com/ollama/ollama/model/models/glm4moelite"
	_ "github.com/ollama/ollama/model/models/gptoss"
	_ "github.com/ollama/ollama/model/models/lfm2"
	_ "github.com/ollama/ollama/model/models/llama"
	_ "github.com/ollama/ollama/model/models/llama4"
	_ "github.com/ollama/ollama/model/models/mistral3"

410 model/parsers/glm46.go (new file)
@@ -0,0 +1,410 @@
package parsers

import (
	"context"
	"encoding/xml"
	"fmt"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)

type glm46ParserState int

const (
	glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
	glm46ParserState_ThinkingStartedEatingWhitespace
	glm46ParserState_CollectingThinking
	glm46ParserState_ThinkingDoneEatingWhitespace
	glm46ParserState_CollectingContent
	glm46ParserState_ToolStartedEatingWhitespace
	glm46ParserState_CollectingToolContent
)

const (
	glm46ThinkingOpenTag  = "<think>"
	glm46ThinkingCloseTag = "</think>"
	glm46ToolOpenTag      = "<tool_call>"
	glm46ToolCloseTag     = "</tool_call>"
)

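// Editorial example of the wire format these tags imply (assumed from the
// constants above; the exact GLM-4.6 output template is not shown in this
// diff):
//
//	<think>reasoning tokens...</think>
//	visible answer text
//	<tool_call>...tool call payload...</tool_call>
//
// The parser below consumes this shape incrementally, withholding
// ambiguous trailing bytes that might begin a closing or opening tag.
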
type GLM46Parser struct {
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
return tools
|
||||
}
|
||||
|
||||
type glm46Event interface {
|
||||
isGLM46Event()
|
||||
}
|
||||
|
||||
type glm46EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (glm46EventContent) isGLM46Event() {}
|
||||
|
||||
type glm46EventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
func (glm46EventRawToolCall) isGLM46Event() {}
|
||||
|
||||
type glm46EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (glm46EventThinkingContent) isGLM46Event() {}
|
||||
|
||||
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case glm46EventRawToolCall:
|
||||
toolCall, err := parseGLM46ToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case glm46EventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
// events, we naively append them together here.
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) parseEvents() []glm46Event {
|
||||
var all []glm46Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []glm46Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
|
||||
// and transitions to the next state. Returns (nil, false) if only whitespace remains
|
||||
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
|
||||
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false // Still only whitespace, keep waiting for more input
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true // Successfully transitioned
|
||||
}
|
||||
|
||||
// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
|
||||
// the content after (optionally trimmed of leading whitespace), and updates the buffer
|
||||
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return before, after
|
||||
}

func (p *GLM46Parser) eat() ([]glm46Event, bool) {
	var events []glm46Event

	switch p.state {
	case glm46ParserState_LookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
			// Found <think> opening tag
			after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = glm46ParserState_ThinkingStartedEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
			// Partial opening tag seen, keep accumulating
			return events, false
		} else if trimmed == "" {
			// Only whitespace, keep accumulating
			return events, false
		} else {
			// No thinking tag found, skip to content collection
			p.state = glm46ParserState_CollectingContent
			// Don't trim - we want to keep the original content
			return events, true
		}

	case glm46ParserState_ThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)

	case glm46ParserState_CollectingThinking:
		acc := p.buffer.String()
		if strings.Contains(acc, glm46ThinkingCloseTag) {
			thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, glm46EventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = glm46ParserState_ThinkingDoneEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
			// Partial closing tag - withhold it along with any trailing whitespace before it
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventThinkingContent{content: unambiguous})
			}
			return events, false
		} else {
			// Pure thinking content - withhold trailing whitespace (might precede closing tag)
			whitespaceLen := trailingWhitespaceLen(acc)
			ambiguousStart := len(acc) - whitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case glm46ParserState_ThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)

	case glm46ParserState_CollectingContent:
		if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
			before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
			if len(before) > 0 {
				events = append(events, glm46EventContent{content: before})
			}
			if after == "" {
				p.state = glm46ParserState_ToolStartedEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
			beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventContent{content: unambiguous})
			}
			return events, false
		} else {
			whitespaceLen := trailingWhitespaceLen(p.buffer.String())
			ambiguousStart := len(p.buffer.String()) - whitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventContent{content: unambiguous})
			}
			return events, false
		}

	case glm46ParserState_ToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)

	case glm46ParserState_CollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, glm46ToolCloseTag) {
			toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("glm46 tool call closing tag found but no content before it")
			}
			events = append(events, glm46EventRawToolCall{raw: toolContent})
			p.state = glm46ParserState_CollectingContent
			return events, true
		} else {
			// Keep accumulating - tool calls are not streamed
			// We just wait for the closing tag
			return events, false
		}

	default:
		panic("unreachable")
	}
}
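
// Illustrative sketch, not part of the committed file: the whitespace and
// partial-tag withholding above in action, assuming the parsers package
// internals. Ambiguous trailing bytes stay buffered until the next chunk
// resolves them.
//
//	p := &GLM46Parser{state: glm46ParserState_CollectingThinking}
//	p.buffer.WriteString("thinking </th")
//	_ = p.parseEvents() // emits thinking "thinking"; " </th" stays buffered
//	p.buffer.WriteString("ink>done")
//	_ = p.parseEvents() // tag completed and consumed; emits content "done"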

// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string   `xml:",chardata"` // Function name (text nodes between tags)
	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
	Values  []string `xml:"arg_value"` // All arg_value elements in document order
}

// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
	var result strings.Builder
	inTag := false

	for i := range len(s) {
		ch := s[i]

		if ch == '<' {
			// Check if this is a known tag
			if strings.HasPrefix(s[i:], "<arg_key>") ||
				strings.HasPrefix(s[i:], "</arg_key>") ||
				strings.HasPrefix(s[i:], "<arg_value>") ||
				strings.HasPrefix(s[i:], "</arg_value>") {
				inTag = true
			}
		}

		if inTag {
			result.WriteByte(ch)
			if ch == '>' {
				inTag = false
			}
		} else {
			// Escape special characters in text content
			switch ch {
			case '&':
				result.WriteString("&amp;")
			case '<':
				result.WriteString("&lt;")
			case '>':
				result.WriteString("&gt;")
			default:
				result.WriteByte(ch)
			}
		}
	}

	return result.String()
}
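
// Illustrative sketch, not part of the committed file: raw characters in text
// are escaped while the known tags pass through untouched.
//
//	in := "x < y<arg_key>a&b</arg_key><arg_value>1 > 0</arg_value>"
//	out := escapeGLM46Content(in)
//	// out == "x &lt; y<arg_key>a&amp;b</arg_key><arg_value>1 &gt; 0</arg_value>"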

func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	// Escape any unescaped entities in text content
	// We need to escape text between tags, but not the tags themselves
	escaped := escapeGLM46Content(raw.raw)

	// Wrap the content in a root element to make it valid XML
	xmlString := "<tool_call>" + escaped + "</tool_call>"

	// Parse XML into struct
	var parsed GLMToolCallXML
	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
	}

	// Extract and trim function name
	functionName := strings.TrimSpace(parsed.Content)
	if functionName == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	// Verify keys and values are paired correctly
	if len(parsed.Keys) != len(parsed.Values) {
		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
	}

	// Find the matching tool to get parameter types
	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == functionName {
			matchedTool = &tools[i]
			break
		}
	}

	// Build arguments map by pairing keys and values
	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      functionName,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for i := range parsed.Keys {
		key := strings.TrimSpace(parsed.Keys[i])
		value := parsed.Values[i] // Don't trim here - parseValue handles it

		// Look up parameter type
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				// Handle anyOf by collecting all types from the union
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}

		// Parse value with type coercion
		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
	}

	return toolCall, nil
}
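
// Illustrative sketch, not part of the committed file: end-to-end use with no
// tool schema, assuming parseValue's untyped fallback keeps plain text as a
// string (as the tests below exercise).
//
//	raw := glm46EventRawToolCall{raw: "get_weather\n<arg_key>city</arg_key>\n<arg_value>Paris</arg_value>"}
//	call, err := parseGLM46ToolCall(raw, nil)
//	// err == nil; call.Function.Name == "get_weather"
//	// call.Function.Arguments holds {"city": "Paris"}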

862
model/parsers/glm46_test.go
Normal file
@@ -0,0 +1,862 @@
package parsers

import (
	"encoding/xml"
	"reflect"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestGLM46ParserStreaming(t *testing.T) {
	type step struct {
		input      string
		wantEvents []glm46Event
	}

	cases := []struct {
		desc  string
		steps []step
		only  bool
	}{
		{
			desc: "leading whitespace before think tag",
			steps: []step{
				{
					input:      " \n\t ",
					wantEvents: []glm46Event{},
				},
				{
					input:      "<think>thinking</think>",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
				},
			},
		},
		{
			desc: "think tag with whitespace inside",
			steps: []step{
				{
					input: "<think> \n thinking content \n </think>regular content",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking content"},
						glm46EventContent{content: "regular content"},
					},
				},
			},
		},
		{
			desc: "tool call with leading whitespace after opening tag",
			steps: []step{
				{
					input: "<think></think><tool_call> \n test \n </tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "test"},
					},
				},
			},
		},
		{
			desc: "simple thinking then content",
			steps: []step{
				{
					input: "<think>I am thinking</think>Now I respond",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "I am thinking"},
						glm46EventContent{content: "Now I respond"},
					},
				},
			},
		},
		{
			desc: "streamed thinking content",
			steps: []step{
				{
					input:      "<think>hello",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
				},
				{
					input:      " world",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
				},
				{
					input: "</think>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "content before tool call",
			steps: []step{
				{
					input: "<think>Let me call a tool</think>here is text<tool_call>",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "Let me call a tool"},
						glm46EventContent{content: "here is text"},
					},
				},
				{
					input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
					},
				},
			},
		},
		{
			desc: "tool call with content after",
			steps: []step{
				{
					input: "<think>thinking</think><tool_call>test</tool_call>after tool",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "after tool"},
					},
				},
			},
		},
		{
			desc: "trailing whitespace between content and tool call is trimmed",
			steps: []step{
				{
					input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventContent{content: "content"},
						glm46EventRawToolCall{raw: "test"},
					},
				},
			},
		},
		{
			desc: "trailing whitespace between tool call and content is trimmed",
			steps: []step{
				{
					input: "<think>think</think><tool_call>test</tool_call>\n\t after",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "think"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "split thinking close tag",
			steps: []step{
				{
					input:      "<think>thinking content</th",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
				},
				{
					input: "ink>after",
					wantEvents: []glm46Event{
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "split thinking open tag",
			steps: []step{
				{
					input:      " <thi",
					wantEvents: []glm46Event{},
				},
				{
					input:      "nk>content</think>",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
				},
			},
		},
		{
			desc: "split tool open tag",
			steps: []step{
				{
					input:      "<think>think</think>content<tool",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
				},
				{
					input:      "_call>inside",
					wantEvents: []glm46Event{},
				},
				{
					input: "</tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "inside"},
					},
				},
			},
		},
		{
			desc: "partial thinking close tag fakeout",
			steps: []step{
				{
					input:      "<think>content</th",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
				},
				{
					input:      "ought more",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
				},
			},
		},
		{
			desc: "partial thinking open tag fakeout",
			steps: []step{
				{
					input:      " <thi",
					wantEvents: []glm46Event{},
				},
				{
					input: "nking is fun",
					wantEvents: []glm46Event{
						glm46EventContent{content: " <thinking is fun"},
					},
				},
			},
		},
		{
			desc: "partial tool open tag fakeout",
			steps: []step{
				{
					input: "<think></think>content\n<tool",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
				{
					input: " fakeout",
					wantEvents: []glm46Event{
						glm46EventContent{content: "\n<tool fakeout"},
					},
				},
			},
		},
		{
			desc: "partial tool close tag fakeout",
			steps: []step{
				{
					input:      "<think></think><tool_call>content</tool",
					wantEvents: []glm46Event{},
				},
				{
					input:      " fakeout",
					wantEvents: []glm46Event{},
				},
				{
					input: "</tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "content</tool fakeout"},
					},
				},
			},
		},
		{
			desc: "empty thinking tag",
			steps: []step{
				{
					input: "<think></think>content here",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content here"},
					},
				},
			},
		},
		{
			desc: "multiple tool calls in sequence",
			steps: []step{
				{
					input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "think"},
						glm46EventRawToolCall{raw: "first"},
						glm46EventContent{content: "between"},
						glm46EventRawToolCall{raw: "second"},
						glm46EventContent{content: "end"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - direct to content",
			steps: []step{
				{
					input: "just content here",
					wantEvents: []glm46Event{
						glm46EventContent{content: "just content here"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - skip to content then tool call",
			steps: []step{
				{
					input: "Here's the answer:<tool_call>test</tool_call>done",
					wantEvents: []glm46Event{
						glm46EventContent{content: "Here's the answer:"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "done"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - whitespace preserved when no tags",
			steps: []step{
				{
					input: " \n content with leading whitespace",
					wantEvents: []glm46Event{
						glm46EventContent{content: " \n content with leading whitespace"},
					},
				},
			},
		},
		{
			desc: "whitespace after think close tag gets eaten",
			steps: []step{
				{
					input: "<think>thinking</think> \n\t content",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "whitespace after tool_call close tag gets eaten",
			steps: []step{
				{
					input: "<think></think><tool_call>test</tool_call> \n\t content",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace (single chunk)",
			steps: []step{
				{
					input: "<think>thinking content ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking content"},
					},
				},
				{
					input: "</think>after",
					wantEvents: []glm46Event{
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace with newlines",
			steps: []step{
				{
					input: "<think>thinking\n\n ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "</think>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "thinking content trailing whitespace emitted when more content arrives",
			steps: []step{
				{
					input: "<think>thinking ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "more thinking",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: " more thinking"},
					},
				},
				{
					input:      "</think>",
					wantEvents: []glm46Event{},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace before partial close tag",
			steps: []step{
				{
					input: "<think>thinking </th",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "ink>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
	}

	anyOnlies := false
	for _, tc := range cases {
		if tc.only {
			anyOnlies = true
		}
	}

	for _, tc := range cases {
		if anyOnlies && !tc.only {
			continue
		}

		t.Run(tc.desc, func(t *testing.T) {
			parser := GLM46Parser{}

			for i, step := range tc.steps {
				parser.buffer.WriteString(step.input)
				gotEvents := parser.parseEvents()

				if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
					// avoid deep equal on empty vs. nil slices
					continue
				}

				if !reflect.DeepEqual(gotEvents, step.wantEvents) {
					t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
				}
			}
		})
	}
}

// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
	testCases := []struct {
		name       string
		xml        string
		wantKeys   []string
		wantValues []string
	}{
		{
			name: "alternating keys and values",
			xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
			wantKeys:   []string{"first", "second", "third"},
			wantValues: []string{"A", "B", "C"},
		},
		{
			name: "all keys then all values",
			xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
			wantKeys:   []string{"key1", "key2", "key3"},
			wantValues: []string{"val1", "val2", "val3"},
		},
		{
			name: "mixed grouping",
			xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
			wantKeys:   []string{"a", "b", "c"},
			wantValues: []string{"1", "2", "3"},
		},
		{
			name: "reverse order - all values then all keys",
			xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
			wantKeys:   []string{"x", "y", "z"},
			wantValues: []string{"X", "Y", "Z"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var parsed GLMToolCallXML
			err := xml.Unmarshal([]byte(tc.xml), &parsed)
			if err != nil {
				t.Fatalf("failed to unmarshal XML: %v", err)
			}

			if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
				t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
			}

			if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
				t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
			}
		})
	}
}

func TestGLM46ToolCallParsing(t *testing.T) {
	type testCase struct {
		name         string
		rawToolCall  string
		tools        []api.Tool
		wantToolCall api.ToolCall
	}

	cases := []testCase{
		{
			name:  "simple tool call",
			tools: []api.Tool{},
			rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "get-current-weather",
					Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
				},
			},
		},
		{
			name: "tool call with typed parameters",
			tools: []api.Tool{
				tool("calculate", map[string]api.ToolProperty{
					"x":       {Type: api.PropertyType{"number"}},
					"y":       {Type: api.PropertyType{"integer"}},
					"enabled": {Type: api.PropertyType{"boolean"}},
					"items":   {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "calculate",
					Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
				},
			},
		},
		{
			name:  "function name with whitespace",
			tools: []api.Tool{},
			rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "get-weather",
					Arguments: args(`{"city": "Paris"}`),
				},
			},
		},
		{
			name:  "values with special characters",
			tools: []api.Tool{},
			rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "execute-command",
					Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
				},
			},
		},
		{
			name:  "unicode in function names and values",
			tools: []api.Tool{},
			rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "获取天气",
					Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
				},
			},
		},
		{
			name:  "empty value",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"param1": ""}`),
				},
			},
		},
		{
			name:  "special chars in arg_key names",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
				},
			},
		},
		{
			name:  "multiple consecutive ampersands",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"param": "test &&&& more"}`),
				},
			},
		},
		{
			name:  "mixed special chars together",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"param": "<>&<>&"}`),
				},
			},
		},
		{
			name:  "newlines and tabs in parameter values",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
	indented line2
line3</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
				},
			},
		},
		{
			name:  "single and double quotes in values",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
				},
			},
		},
		{
			name:  "CDATA-like content that should be treated as text",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
				},
			},
		},
		{
			name:  "all special XML entities",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value>&lt;&gt;&amp;&apos;&quot;</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"entities": "&lt;&gt;&amp;&apos;&quot;"}`),
				},
			},
		},
		{
			name:  "order preservation with multiple parameters",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test-function",
					Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
				},
			},
		},
		{
			name:  "order preservation with identical key names but different positions",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					// Later occurrence should overwrite earlier one
					Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
				},
			},
		},
		{
			name: "array with mixed types",
			tools: []api.Tool{
				tool("process", map[string]api.ToolProperty{
					"items": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "process",
					Arguments: args(`{"items": [1, "hello", true, null]}`),
				},
			},
		},
		{
			name: "empty array",
			tools: []api.Tool{
				tool("test", map[string]api.ToolProperty{
					"tags": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "test",
					Arguments: args(`{"tags": []}`),
				},
			},
		},
		{
			name: "anyOf array or string - with array of objects",
			tools: []api.Tool{
				tool("TodoWrite", map[string]api.ToolProperty{
					"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
				}),
			},
			// <tool_call>TodoWrite
			// <arg_key>todos</arg_key>
			// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
			// </tool_call>
			rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "TodoWrite",
					Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
				},
			},
		},
		{
			name: "anyOf array or string - with plain string",
			tools: []api.Tool{
				tool("TodoWrite", map[string]api.ToolProperty{
					"todos": {Type: api.PropertyType{"array", "string"}},
				}),
			},
			rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "TodoWrite",
					Arguments: args(`{"todos": "Error: could not load todos"}`),
				},
			},
		},
	}

	for i, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
			if err != nil {
				t.Errorf("case %d (%s): %v", i, tc.name, err)
			}
			if !toolCallEqual(gotToolCall, tc.wantToolCall) {
				t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
			}
		})
	}
}

20
model/parsers/glm47.go
Normal file
@@ -0,0 +1,20 @@
package parsers

import "github.com/ollama/ollama/api"

// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
	GLM46Parser
}

func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	// When thinking is enabled (nil or true), the prompt ends with <think>,
	// so model output starts directly with thinking content (no opening tag).
	if thinkValue == nil || thinkValue.Bool() {
		p.state = glm46ParserState_CollectingThinking
	}
	return tools
}

99
model/parsers/glm47_test.go
Normal file
@@ -0,0 +1,99 @@
package parsers

import (
	"reflect"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestGLM47ParserAdd(t *testing.T) {
	parser := GLM47Parser{}
	parser.Init([]api.Tool{
		tool("calculate", map[string]api.ToolProperty{
			"count":   {Type: api.PropertyType{"integer"}},
			"enabled": {Type: api.PropertyType{"boolean"}},
		}),
	}, nil, nil)

	// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
	// so the model output does NOT include the opening <think> tag.
	content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "plan" {
		t.Fatalf("expected thinking 'plan', got %q", thinking)
	}
	if content != "Answer" {
		t.Fatalf("expected content 'Answer', got %q", content)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	expectedArgs := args(`{"count": 3, "enabled": true}`)
	if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
		t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
	}
}

func TestGLM47ParserNoThinkingContent(t *testing.T) {
	parser := GLM47Parser{}
	parser.Init(nil, nil, nil)

	// When thinking is enabled but model has no thinking to output,
	// it should output </think> immediately followed by content.
	content, thinking, calls, err := parser.Add("</think>Plain answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer" {
		t.Fatalf("expected content 'Plain answer', got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestGLM47ParserThinkingDisabled(t *testing.T) {
	parser := GLM47Parser{}
	// When thinking is disabled, parser stays in LookingForThinkingOpen state
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	// Model outputs plain content (prompt ended with </think>)
	content, thinking, calls, err := parser.Add("Plain answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer" {
		t.Fatalf("expected content 'Plain answer', got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestGLM47ParserToolCallEscaping(t *testing.T) {
	toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	expected := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      "exec",
			Arguments: args(`{"expr": "a < b && c > d"}`),
		},
	}
	if !reflect.DeepEqual(toolCall, expected) {
		t.Fatalf("expected %#v, got %#v", expected, toolCall)
	}
}

498
model/parsers/lfm2.go
Normal file
@@ -0,0 +1,498 @@
package parsers

import (
	"encoding/json"
	"errors"
	"log/slog"
	"strconv"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
)

type LFM2ParserState int

const (
	LFM2CollectingThinking LFM2ParserState = iota
	LFM2CollectingContent
	LFM2CollectingToolCalls
)

const (
	lfm2ThinkingOpenTag  = "<think>"
	lfm2ThinkingCloseTag = "</think>"
	lfm2ToolCallStartTag = "<|tool_call_start|>"
	lfm2ToolCallEndTag   = "<|tool_call_end|>"
)

type LFM2Parser struct {
	state                    LFM2ParserState
	buffer                   strings.Builder
	hasThinkingSupport       bool
	needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
	needsContentLeadingTrim  bool // trim leading whitespace after </think> tag
}

func (p *LFM2Parser) HasToolSupport() bool {
	return true
}

func (p *LFM2Parser) HasThinkingSupport() bool {
	return p.hasThinkingSupport
}

func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
	prefill := lastMessage != nil && lastMessage.Role == "assistant"

	// Check both model capability AND request preference
	thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())

	if !thinkingEnabled {
		p.state = LFM2CollectingContent
		return
	}

	if prefill && lastMessage.Content != "" {
		p.state = LFM2CollectingContent
		return
	}

	p.state = LFM2CollectingThinking
	p.needsThinkingLeadingTrim = true
}

func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.setInitialState(lastMessage, thinkValue)
	return tools
}

type lfm2Event interface {
	isLFM2Event()
}

type lfm2EventThinkingContent struct {
	content string
}

type lfm2EventContent struct {
	content string
}

type lfm2EventToolCall struct {
	toolCall api.ToolCall
}

func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event()         {}
func (lfm2EventToolCall) isLFM2Event()        {}

func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case lfm2EventToolCall:
			toolCalls = append(toolCalls, event.toolCall)
		case lfm2EventThinkingContent:
			thinkingSb.WriteString(event.content)
		case lfm2EventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *LFM2Parser) parseEvents() []lfm2Event {
	var all []lfm2Event

	keepLooping := true
	for keepLooping {
		var events []lfm2Event
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	return all
}

func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
	var events []lfm2Event
	bufStr := p.buffer.String()
	if bufStr == "" {
		return events, false
	}

	switch p.state {
	case LFM2CollectingThinking:
		// Strip opening <think> tag if present
		if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
			bufStr = bufStr[len(lfm2ThinkingOpenTag):]
			p.needsThinkingLeadingTrim = true
			p.buffer.Reset()
			p.buffer.WriteString(bufStr)
		}

		// Trim leading whitespace after <think> tag (may span multiple chunks)
		if p.needsThinkingLeadingTrim {
			if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
				bufStr = trimmed
				p.buffer.Reset()
				p.buffer.WriteString(bufStr)
			}
			// Clear flag once we have non-whitespace content or buffer is empty
			if len(bufStr) > 0 {
				p.needsThinkingLeadingTrim = false
			}
		}

		if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
			split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
			thinking := split[0]
			thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)

			remaining := split[1]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = LFM2CollectingContent
			p.needsThinkingLeadingTrim = false
			// Set flag to trim any additional whitespace that may arrive in later chunks
			p.needsContentLeadingTrim = len(remaining) == 0

			if len(thinking) > 0 {
				events = append(events, lfm2EventThinkingContent{content: thinking})
			}
			return events, true
		} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
			beforePartialTag := bufStr[:len(bufStr)-overlapLen]
			trailingLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, lfm2EventThinkingContent{content: unambiguous})
			}
			return events, false
		} else { // otherwise it's thinking content
			whitespaceLen := trailingWhitespaceLen(bufStr)
			ambiguousStart := len(bufStr) - whitespaceLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, lfm2EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case LFM2CollectingContent:
		// Trim leading whitespace after </think> tag (may span multiple chunks)
		if p.needsContentLeadingTrim {
			if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
				bufStr = trimmed
				p.buffer.Reset()
				p.buffer.WriteString(bufStr)
			}
			// Clear flag once we have non-whitespace content
			if len(bufStr) > 0 {
				p.needsContentLeadingTrim = false
			}
		}

		if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
			split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
			contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
			remaining := split[1]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = LFM2CollectingToolCalls

			if len(contentBefore) > 0 {
				events = append(events, lfm2EventContent{content: contentBefore})
			}
			return events, true
		} else { // otherwise it's content
			p.buffer.Reset()
			if len(bufStr) > 0 {
				events = append(events, lfm2EventContent{content: bufStr})
			}
			return events, false
		}

	case LFM2CollectingToolCalls:
		// Look for complete tool call JSON between tags
		if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
			toolCallContent := bufStr[:idx]

			if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
				remaining := bufStr[idx+len(lfm2ToolCallEndTag):]

				// Check if there's another tool call
				if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
					remaining = remaining[len(lfm2ToolCallStartTag):]
				} else {
					// No more tool calls, go back to content
					remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
					p.state = LFM2CollectingContent
				}

				p.buffer.Reset()
				p.buffer.WriteString(remaining)

				for _, tc := range toolCalls {
					events = append(events, lfm2EventToolCall{toolCall: tc})
				}
				return events, true
			} else if err != nil {
				slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
			}
		}

		return events, false
	}

	return events, false
}

// parseToolCallsContent parses one or more tool calls from content.
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
	content = strings.TrimSpace(content)

	// Try JSON format first: {"name": "func", "arguments": {...}}
	var parsed struct {
		Name      string          `json:"name"`
		Arguments json.RawMessage `json:"arguments"`
	}

	if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
		var args api.ToolCallFunctionArguments
		if len(parsed.Arguments) > 0 {
			if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
				return nil, err
			}
		} else {
			args = api.NewToolCallFunctionArguments()
		}

		return []api.ToolCall{{
			Function: api.ToolCallFunction{
				Name:      parsed.Name,
				Arguments: args,
			},
		}}, nil
	}

	// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
	return p.parsePythonStyleToolCalls(content)
}
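
// Illustrative sketch, not part of the committed file: the two accepted
// formats side by side, assuming the parsers package.
//
//	p := &LFM2Parser{}
//	jsonCalls, _ := p.parseToolCallsContent(`{"name": "bash", "arguments": {"command": "ls"}}`)
//	pyCalls, _ := p.parseToolCallsContent(`[bash(command='ls'),bash(command='pwd')]`)
//	// jsonCalls holds one "bash" call; pyCalls holds two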

// parsePythonStyleToolCalls parses one or more Python-style tool calls
// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
	content = strings.TrimSpace(content)

	// Strip outer brackets if present: [func(...)] -> func(...)
	if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
		content = content[1 : len(content)-1]
	}

	var toolCalls []api.ToolCall

	// Parse multiple function calls separated by commas at the top level
	for len(content) > 0 {
		content = strings.TrimSpace(content)
		if content == "" {
			break
		}

		// Skip leading comma from previous iteration
		if strings.HasPrefix(content, ",") {
			content = strings.TrimSpace(content[1:])
			if content == "" {
				break
			}
		}

		// Find function name
		parenIdx := strings.Index(content, "(")
		if parenIdx == -1 {
			return nil, errors.New("invalid tool call: no opening parenthesis")
		}

		funcName := strings.TrimSpace(content[:parenIdx])
		if funcName == "" {
			return nil, errors.New("invalid tool call: empty function name")
		}

		// Find matching closing parenthesis
		closeIdx := findMatchingParen(content, parenIdx)
		if closeIdx == -1 {
			return nil, errors.New("invalid tool call: no matching closing parenthesis")
		}

		argsStr := content[parenIdx+1 : closeIdx]
		args := api.NewToolCallFunctionArguments()

		if argsStr != "" {
			if err := parsePythonArgs(argsStr, &args); err != nil {
				return nil, err
			}
		}

		toolCalls = append(toolCalls, api.ToolCall{
			Function: api.ToolCallFunction{
				Name:      funcName,
				Arguments: args,
			},
		})

		// Move past this function call
		content = content[closeIdx+1:]
	}

	if len(toolCalls) == 0 {
		return nil, errors.New("no tool calls found")
	}

	return toolCalls, nil
}

// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx
// Returns -1 if not found. Handles nested parentheses and quoted strings.
func findMatchingParen(s string, openIdx int) int {
	depth := 1
	i := openIdx + 1
	for i < len(s) && depth > 0 {
		switch s[i] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				return i
			}
		case '\'', '"':
			// Skip quoted string
			quote := s[i]
			i++
			for i < len(s) && s[i] != quote {
				if s[i] == '\\' && i+1 < len(s) {
					i++ // skip escaped char
				}
				i++
			}
		}
		i++
	}
	return -1
}
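
// Illustrative sketch, not part of the committed file: parentheses inside
// quoted strings are skipped, so the match is the call's own closing paren.
//
//	s := "bash(command='ls (1)')"
//	idx := findMatchingParen(s, 4)
//	// idx == len(s)-1, the final ')', not the one inside the quotes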

// parseToolCallContent parses a single tool call (for backward compatibility with tests)
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
	calls, err := p.parseToolCallsContent(content)
	if err != nil {
		return api.ToolCall{}, err
	}
	if len(calls) == 0 {
		return api.ToolCall{}, errors.New("no tool call found")
	}
	return calls[0], nil
}

// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
	// Simple state machine to parse key='value' pairs
	// Handles: command='ls', flag="-la", count=42, enabled=true
	var key string
	i := 0

	for i < len(argsStr) {
		// Skip whitespace
		for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
			i++
		}
		if i >= len(argsStr) {
			break
		}

		// Parse key
		keyStart := i
		for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
			i++
		}
		if i >= len(argsStr) || argsStr[i] != '=' {
			return errors.New("invalid argument: expected '='")
		}
		key = strings.TrimSpace(argsStr[keyStart:i])
		i++ // skip '='

		// Skip whitespace after =
		for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
			i++
		}

		// Parse value
		var value string
		if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
			// Quoted string
			quote := argsStr[i]
			i++
			valueStart := i
			for i < len(argsStr) && argsStr[i] != quote {
				if argsStr[i] == '\\' && i+1 < len(argsStr) {
					i += 2 // skip escaped char
				} else {
					i++
				}
			}
			value = argsStr[valueStart:i]
			if i < len(argsStr) {
				i++ // skip closing quote
			}
			args.Set(key, value)
		} else {
			// Unquoted value (number, bool, etc)
			valueStart := i
			for i < len(argsStr) && argsStr[i] != ',' {
				i++
			}
			value = strings.TrimSpace(argsStr[valueStart:i])

			// Try to parse as number or bool
			if v, err := strconv.ParseInt(value, 10, 64); err == nil {
				args.Set(key, v)
			} else if v, err := strconv.ParseFloat(value, 64); err == nil {
				args.Set(key, v)
			} else if value == "true" {
				args.Set(key, true)
			} else if value == "false" {
				args.Set(key, false)
			} else {
				args.Set(key, value)
			}
		}

		// Skip comma and whitespace
		for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
			i++
		}
	}

	return nil
}
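
// Illustrative sketch, not part of the committed file: the coercion rules
// above applied to a mixed argument list, assuming the parsers package.
//
//	args := api.NewToolCallFunctionArguments()
//	_ = parsePythonArgs(`command='ls -la', count=42, ratio=0.5, enabled=true`, &args)
//	// command -> "ls -la" (string), count -> int64(42),
//	// ratio -> float64(0.5), enabled -> true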

1088
model/parsers/lfm2_test.go
Normal file
File diff suppressed because it is too large
@@ -1,7 +1,6 @@
package parsers

import (
	"regexp"
	"strings"
	"unicode"

@@ -14,243 +13,114 @@ const (
	Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
	Nemotron3NanoSkipWhitespaceAfterThinking
	Nemotron3NanoCollectingContent
	Nemotron3NanoCollectingToolCalls
)

const (
	nemotronThinkClose = "</think>"
	nemotronToolCallOpen = "<tool_call>"
	nemotronToolCallClose = "</tool_call>"
	nemotronThinkClose = "</think>"
	nemotronToolCallOpen = "<tool_call>"
)

type Nemotron3NanoParser struct {
	state Nemotron3NanoParserState
	buffer strings.Builder
	tools []api.Tool
	state Nemotron3NanoParserState
	buffer strings.Builder
	toolParser *Qwen3CoderParser
}

func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }

func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	p.toolParser = &Qwen3CoderParser{}
	p.toolParser.Init(tools, nil, nil)

	// thinking is enabled if user requests it
	thinkingEnabled := thinkValue != nil && thinkValue.Bool()

	prefill := lastMessage != nil && lastMessage.Role == "assistant"

	if !thinkingEnabled {
	if !thinkingEnabled || (prefill && lastMessage.Content != "") {
		p.state = Nemotron3NanoCollectingContent
		return tools
	} else {
		p.state = Nemotron3NanoCollectingThinking
	}

	if prefill && lastMessage.Content != "" {
		p.state = Nemotron3NanoCollectingContent
		return tools
	}

	p.state = Nemotron3NanoCollectingThinking
	return tools
}

type nemotronEvent interface {
	isNemotronEvent()
}

type nemotronEventThinkingContent struct {
	content string
}

type nemotronEventContent struct {
	content string
}

type nemotronEventToolCall struct {
	toolCall api.ToolCall
}

func (nemotronEventThinkingContent) isNemotronEvent() {}
func (nemotronEventContent) isNemotronEvent() {}
func (nemotronEventToolCall) isNemotronEvent() {}

func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case nemotronEventToolCall:
			toolCalls = append(toolCalls, event.toolCall)
		case nemotronEventThinkingContent:
			thinkingSb.WriteString(event.content)
		case nemotronEventContent:
			contentSb.WriteString(event.content)
		}
	if p.state == Nemotron3NanoCollectingContent {
		return p.toolParser.Add(s, done)
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent {
	var all []nemotronEvent

	keepLooping := true
	for keepLooping {
		var events []nemotronEvent
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	return all
}

// emitWithPartialCheck extracts unambiguous content before a potential partial tag
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
	if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
		beforePartialTag := bufStr[:len(bufStr)-overlapLen]
		trailingLen := trailingWhitespaceLen(beforePartialTag)
		return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
	}
	wsLen := trailingWhitespaceLen(bufStr)
	return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
}

func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
	bufStr := p.buffer.String()
	if bufStr == "" {
		return nil, false
	}

	switch p.state {
	case Nemotron3NanoCollectingThinking:
		if strings.Contains(bufStr, nemotronThinkClose) {
			split := strings.SplitN(bufStr, nemotronThinkClose, 2)
			thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
			p.buffer.Reset()
			remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
			p.buffer.WriteString(remainder)
			// Transition to whitespace-skipping state if buffer is empty,
			// otherwise go directly to content collection
			if remainder == "" {
				p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
}
|
||||
if thinking != "" {
|
||||
return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
|
||||
}
|
||||
return nil, false
|
||||
|
||||
// We only want to skip whitespace between thinking and content
|
||||
case Nemotron3NanoSkipWhitespaceAfterThinking:
|
||||
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
if bufStr == "" {
|
||||
return nil, false
|
||||
if p.state == Nemotron3NanoSkipWhitespaceAfterThinking {
|
||||
s = strings.TrimLeftFunc(s, unicode.IsSpace)
|
||||
if s == "" {
|
||||
return "", "", nil, nil
|
||||
}
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
return nil, true
|
||||
return p.toolParser.Add(s, done)
|
||||
}
|
||||
|
||||
case Nemotron3NanoCollectingContent:
|
||||
if strings.Contains(bufStr, nemotronToolCallOpen) {
|
||||
split := strings.SplitN(bufStr, nemotronToolCallOpen, 2)
|
||||
content := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(split[1])
|
||||
p.state = Nemotron3NanoCollectingToolCalls
|
||||
if content != "" {
|
||||
return []nemotronEvent{nemotronEventContent{content: content}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen)
|
||||
// Nemotron3NanoCollectingThinking - buffer and look for end markers
|
||||
p.buffer.WriteString(s)
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
|
||||
thinkIdx := strings.Index(bufStr, nemotronThinkClose)
|
||||
toolIdx := strings.Index(bufStr, nemotronToolCallOpen)
|
||||
|
||||
var endIdx int = -1
|
||||
var remainder string
|
||||
|
||||
if thinkIdx != -1 && (toolIdx == -1 || thinkIdx < toolIdx) {
|
||||
endIdx = thinkIdx
|
||||
remainder = strings.TrimLeftFunc(bufStr[thinkIdx+len(nemotronThinkClose):], unicode.IsSpace)
|
||||
} else if toolIdx != -1 {
|
||||
endIdx = toolIdx
|
||||
remainder = bufStr[toolIdx:] // Include <tool_call> tag
|
||||
}
|
||||
|
||||
if endIdx != -1 {
|
||||
thinking = strings.TrimRightFunc(bufStr[:endIdx], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []nemotronEvent{nemotronEventContent{content: unambig}}, false
|
||||
|
||||
if remainder == "" {
|
||||
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
content, _, calls, err = p.toolParser.Add(remainder, done)
|
||||
}
|
||||
return nil, false
|
||||
|
||||
case Nemotron3NanoCollectingToolCalls:
|
||||
if strings.Contains(bufStr, nemotronToolCallClose) {
|
||||
split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
|
||||
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
|
||||
var events []nemotronEvent
|
||||
if tc, err := p.parseToolCall(split[0]); err == nil {
|
||||
events = append(events, nemotronEventToolCall{toolCall: tc})
|
||||
}
|
||||
|
||||
if !strings.Contains(remaining, nemotronToolCallOpen) {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
return nil, false
|
||||
return content, thinking, calls, err
|
||||
}
|
||||
|
||||
return nil, false
|
||||
// No end marker - emit unambiguous thinking
|
||||
thinking = p.emitThinking(bufStr)
|
||||
return "", thinking, nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`)
|
||||
nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`)
|
||||
)
|
||||
// emitThinking returns unambiguous thinking content, keeping potential partial tags in buffer
|
||||
func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
|
||||
// Check for partial </think> or <tool_call> at end
|
||||
thinkOverlap := overlap(bufStr, nemotronThinkClose)
|
||||
toolOverlap := overlap(bufStr, nemotronToolCallOpen)
|
||||
maxOverlap := max(thinkOverlap, toolOverlap)
|
||||
|
||||
func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
// Extract function name
|
||||
fnMatch := nemotronFunctionRegex.FindStringSubmatch(content)
|
||||
if len(fnMatch) < 2 {
|
||||
return toolCall, nil
|
||||
}
|
||||
toolCall.Function.Name = fnMatch[1]
|
||||
|
||||
// Extract parameters
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range paramMatches {
|
||||
if len(match) >= 3 {
|
||||
paramName := match[1]
|
||||
paramValue := strings.TrimSpace(match[2])
|
||||
|
||||
// Try to parse as typed value based on tool definition
|
||||
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
|
||||
}
|
||||
if maxOverlap > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-maxOverlap]
|
||||
unambiguous = strings.TrimRightFunc(unambiguous, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr[len(bufStr)-maxOverlap:])
|
||||
return unambiguous
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any {
|
||||
// Find the matching tool to get parameter type
|
||||
var paramType api.PropertyType
|
||||
for _, tool := range p.tools {
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parseValue(raw, paramType)
|
||||
// No partial tags - emit all but trailing whitespace
|
||||
wsLen := trailingWhitespaceLen(bufStr)
|
||||
if wsLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-wsLen]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr[len(bufStr)-wsLen:])
|
||||
return unambiguous
|
||||
}
|
||||
|
||||
// Nothing to hold back
|
||||
p.buffer.Reset()
|
||||
return bufStr
|
||||
}
|
||||
|
||||
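The rewritten parser leans on two small helpers, overlap and trailingWhitespaceLen, that are called in emitThinking but defined elsewhere in the parsers package. A minimal sketch of the behavior those call sites assume (only the names appear in this diff; the bodies below are illustrative, not the package's actual implementation):

	// overlap returns the length of the longest proper prefix of tag that is
	// also a suffix of s, i.e. how many bytes of a partially streamed tag may
	// be sitting at the end of the buffer.
	func overlap(s, tag string) int {
		for n := min(len(s), len(tag)-1); n > 0; n-- {
			if strings.HasSuffix(s, tag[:n]) {
				return n
			}
		}
		return 0
	}

	// trailingWhitespaceLen reports how many trailing bytes of s are whitespace.
	func trailingWhitespaceLen(s string) int {
		return len(s) - len(strings.TrimRightFunc(s, unicode.IsSpace))
	}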
@@ -8,6 +8,8 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
+// TestNemotron3NanoParser tests Nemotron-specific behavior (thinking support).
+// Tool call parsing is tested in qwen3coder_test.go since Nemotron delegates to Qwen3CoderParser.
 func TestNemotron3NanoParser(t *testing.T) {
 	tests := []struct {
 		name             string
@@ -17,18 +19,6 @@ func TestNemotron3NanoParser(t *testing.T) {
 		expectedThinking string
 		expectedCalls    []api.ToolCall
 	}{
-		{
-			name:            "simple content - no thinking",
-			input:           "Hello, how can I help you?",
-			thinkValue:      nil,
-			expectedContent: "Hello, how can I help you?",
-		},
-		{
-			name:            "simple content - thinking disabled",
-			input:           "Hello, how can I help you?",
-			thinkValue:      &api.ThinkValue{Value: false},
-			expectedContent: "Hello, how can I help you?",
-		},
 		{
 			name:  "thinking then content",
 			input: "Let me think about this...</think>\nHere is my answer.",
@@ -43,69 +33,6 @@ func TestNemotron3NanoParser(t *testing.T) {
 			expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
 			expectedContent:  "The answer is 42.",
 		},
-		{
-			name:       "simple tool call",
-			input:      "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "Paris"}),
-					},
-				},
-			},
-		},
-		{
-			name:            "content then tool call",
-			input:           "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
-			thinkValue:      nil,
-			expectedContent: "Let me check the weather.",
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "NYC"}),
-					},
-				},
-			},
-		},
-		{
-			name:       "tool call with multiple parameters",
-			input:      "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name: "book_flight",
-						Arguments: testArgs(map[string]any{
-							"from": "SFO",
-							"to":   "NYC",
-						}),
-					},
-				},
-			},
-		},
-		{
-			name: "multiple tool calls",
-			input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
-				"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "San Francisco"}),
-					},
-				},
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "New York"}),
-					},
-				},
-			},
-		},
 		{
 			name:  "thinking then tool call",
 			input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
@@ -135,19 +62,6 @@ func TestNemotron3NanoParser(t *testing.T) {
 				},
 			},
 		},
-		{
-			name:       "tool call with multiline parameter value",
-			input:      "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "create_note",
-						Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
-					},
-				},
-			},
-		},
 		{
 			name:  "empty thinking block - immediate close",
 			input: "</think>\nHere is my answer.",
@@ -161,18 +75,6 @@ func TestNemotron3NanoParser(t *testing.T) {
 			thinkValue:      &api.ThinkValue{Value: false},
 			expectedContent: "</think>\nSome content after spurious tag.",
 		},
-		{
-			name:          "tool call with no function name - returns empty tool call",
-			input:         "<tool_call>\n<function=>\n</function>\n</tool_call>",
-			thinkValue:    nil,
-			expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
-		},
-		{
-			name:            "content with newlines preserved",
-			input:           "Line 1\n\nLine 2\n\n\nLine 3",
-			thinkValue:      nil,
-			expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
-		},
 		{
 			name:  "thinking with only whitespace after close tag",
 			input: "My thoughts...</think> \n\t\n  Content here.",
@@ -180,25 +82,6 @@ func TestNemotron3NanoParser(t *testing.T) {
 			expectedThinking: "My thoughts...",
 			expectedContent:  "Content here.",
 		},
-		{
-			name:            "unicode content",
-			input:           "Hello 世界! 🌍 Ñoño",
-			thinkValue:      nil,
-			expectedContent: "Hello 世界! 🌍 Ñoño",
-		},
-		{
-			name:       "tool call with numeric parameter",
-			input:      "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "set_temp",
-						Arguments: testArgs(map[string]any{"value": "42"}),
-					},
-				},
-			},
-		},
 	}
 
 	for _, tt := range tests {
@@ -233,6 +116,8 @@ func TestNemotron3NanoParser(t *testing.T) {
 	}
 }
 
+// TestNemotron3NanoParser_Streaming tests streaming behavior for thinking support.
+// Tool call streaming is tested in qwen3coder_test.go.
 func TestNemotron3NanoParser_Streaming(t *testing.T) {
 	tests := []struct {
 		name             string
@@ -242,18 +127,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
 		expectedThinking string
 		expectedCalls    []api.ToolCall
 	}{
-		{
-			name:            "streaming content character by character",
-			chunks:          []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
-			thinkValue:      nil,
-			expectedContent: "Hello, world!",
-		},
-		{
-			name:            "streaming content small tokens",
-			chunks:          []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
-			thinkValue:      nil,
-			expectedContent: "Hello, how can I help you today?",
-		},
 		{
 			name:   "streaming thinking then content - granular",
 			chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
@@ -268,45 +141,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
 			expectedThinking: "Step 1: Analyze\nStep 2: Process",
 			expectedContent:  "The answer.",
 		},
-		{
-			name:       "streaming tool call - highly granular",
-			chunks:     []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "Paris"}),
-					},
-				},
-			},
-		},
-		{
-			name:            "streaming content then tool call - granular",
-			chunks:          []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
-			thinkValue:      nil,
-			expectedContent: "Let me check the weather.",
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "NYC"}),
-					},
-				},
-			},
-		},
-		{
-			name:   "tool call tag split character by character",
-			chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "test",
-						Arguments: api.NewToolCallFunctionArguments(),
-					},
-				},
-			},
-		},
 		{
 			name:   "thinking close tag split character by character",
 			chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
@@ -321,22 +155,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
 			expectedThinking: "Thinking...",
 			expectedContent:  "Content here.",
 		},
-		{
-			name:       "tool call with multiple parameters - streaming",
-			chunks:     []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name: "book_flight",
-						Arguments: testArgs(map[string]any{
-							"from": "SFO",
-							"to":   "NYC",
-						}),
-					},
-				},
-			},
-		},
 		{
 			name:   "thinking then content then tool call - streaming",
 			chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
@@ -352,45 +170,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
 				},
 			},
 		},
-		{
-			name: "multiple tool calls - streaming",
-			chunks: []string{
-				"<tool_call>", "\n", "<function=", "get_weather>", "\n",
-				"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
-				"</function>", "\n", "</tool_call>", "\n",
-				"<tool_", "call>\n", "<function", "=get_weather", ">\n",
-				"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
-				"</function>\n", "</tool_call>",
-			},
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "San Francisco"}),
-					},
-				},
-				{
-					Function: api.ToolCallFunction{
-						Name:      "get_weather",
-						Arguments: testArgs(map[string]any{"city": "New York"}),
-					},
-				},
-			},
-		},
-		{
-			name:       "tool call with multiline parameter - streaming",
-			chunks:     []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "create_note",
-						Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
-					},
-				},
-			},
-		},
 		{
 			name:   "empty thinking block",
 			chunks: []string{"</think>", "\n", "Just content."},
@@ -398,12 +177,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
 			expectedThinking: "",
 			expectedContent:  "Just content.",
 		},
-		{
-			name:            "empty input chunks interspersed",
-			chunks:          []string{"Hello", "", " ", "", "world", "", "!"},
-			thinkValue:      nil,
-			expectedContent: "Hello world!",
-		},
 		{
 			name:   "tool call immediately after think close - no content",
 			chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
@@ -418,25 +191,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
 				},
 			},
 		},
-		{
-			name:       "tool call with empty parameter value",
-			chunks:     []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
-			thinkValue: nil,
-			expectedCalls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Name:      "test",
-						Arguments: testArgs(map[string]any{"name": ""}),
-					},
-				},
-			},
-		},
-		{
-			name:            "partial tool call tag at end - buffered",
-			chunks:          []string{"Here's some content", "<tool"},
-			thinkValue:      nil,
-			expectedContent: "Here's some content",
-		},
 	}
 
 	for _, tt := range tests {
@@ -572,3 +326,65 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
 			t.Errorf("calls mismatch (-got +want):\n%s", diff)
 		}
 	}
 }
+
+// TestNemotron3NanoParser_ToolCallWithoutThinkClose tests the case where thinking is enabled
+// but the model outputs content + tool call WITHOUT the </think> tag.
+// The parser should still parse the tool call (content before is treated as thinking).
+func TestNemotron3NanoParser_ToolCallWithoutThinkClose(t *testing.T) {
+	chunks := []string{
+		"Let", " me", " analyze", " this", ".", "\n",
+		"<tool_call>", "\n",
+		"<function=get_weather>", "\n",
+		"<parameter=city>", "Paris", "</parameter>", "\n",
+		"</function>", "\n",
+		"</tool_call>",
+	}
+
+	p := &Nemotron3NanoParser{}
+	p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking ENABLED but model doesn't output </think>
+
+	var allContent string
+	var allThinking string
+	var allCalls []api.ToolCall
+
+	for _, chunk := range chunks {
+		content, thinking, calls, err := p.Add(chunk, false)
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		allContent += content
+		allThinking += thinking
+		allCalls = append(allCalls, calls...)
+	}
+
+	// Drain
+	content, thinking, calls, err := p.Add("", true)
+	if err != nil {
+		t.Fatalf("unexpected error on done: %v", err)
+	}
+	allContent += content
+	allThinking += thinking
+	allCalls = append(allCalls, calls...)
+
+	// The parser was in thinking mode, so text before <tool_call> is emitted as thinking.
+	expectedThinking := "Let me analyze this."
+
+	expectedCalls := []api.ToolCall{
+		{
+			Function: api.ToolCallFunction{
+				Name:      "get_weather",
+				Arguments: testArgs(map[string]any{"city": "Paris"}),
+			},
+		},
+	}
+
+	if allContent != "" {
+		t.Errorf("expected no content (text was streamed as thinking), got: %q", allContent)
+	}
+	if diff := cmp.Diff(allThinking, expectedThinking); diff != "" {
+		t.Errorf("thinking mismatch (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(allCalls, expectedCalls, argsComparer); diff != "" {
+		t.Errorf("calls mismatch (-got +want):\n%s", diff)
+	}
+}
@@ -68,6 +68,12 @@ func ParserForName(name string) Parser {
 		return &Nemotron3NanoParser{}
 	case "functiongemma":
 		return &FunctionGemmaParser{}
+	case "glm-4.7":
+		return &GLM47Parser{}
+	case "lfm2":
+		return &LFM2Parser{hasThinkingSupport: false}
+	case "lfm2-thinking":
+		return &LFM2Parser{hasThinkingSupport: true}
 	default:
 		return nil
 	}
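For orientation, a hedged sketch of how a caller might drive one of the parsers registered above. The Init/Add signatures follow the Nemotron3NanoParser shown earlier in this diff, and the final Add("", true) mirrors the drain pattern used by the tests; this is illustrative code, not part of the change:

	func runParser(chunks []string) error {
		p := ParserForName("glm-4.7")
		if p == nil {
			return errors.New("unknown parser name")
		}
		p.Init(nil, nil, &api.ThinkValue{Value: true})
		for _, chunk := range chunks {
			content, thinking, calls, err := p.Add(chunk, false)
			if err != nil {
				return err
			}
			fmt.Println(content, thinking, calls)
		}
		_, _, _, err := p.Add("", true) // drain buffered output
		return err
	}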
@@ -91,6 +91,37 @@ func TestQwenParserStreaming(t *testing.T) {
 				},
 			},
 		},
+		{
+			desc: "tool call tags split character by character",
+			steps: []step{
+				{input: "<", wantEvents: []qwenEvent{}},
+				{input: "t", wantEvents: []qwenEvent{}},
+				{input: "o", wantEvents: []qwenEvent{}},
+				{input: "o", wantEvents: []qwenEvent{}},
+				{input: "l", wantEvents: []qwenEvent{}},
+				{input: "_", wantEvents: []qwenEvent{}},
+				{input: "c", wantEvents: []qwenEvent{}},
+				{input: "a", wantEvents: []qwenEvent{}},
+				{input: "l", wantEvents: []qwenEvent{}},
+				{input: "l", wantEvents: []qwenEvent{}},
+				{input: ">", wantEvents: []qwenEvent{}},
+				{input: "a", wantEvents: []qwenEvent{}},
+				{input: "b", wantEvents: []qwenEvent{}},
+				{input: "c", wantEvents: []qwenEvent{}},
+				{input: "<", wantEvents: []qwenEvent{}},
+				{input: "/", wantEvents: []qwenEvent{}},
+				{input: "t", wantEvents: []qwenEvent{}},
+				{input: "o", wantEvents: []qwenEvent{}},
+				{input: "o", wantEvents: []qwenEvent{}},
+				{input: "l", wantEvents: []qwenEvent{}},
+				{input: "_", wantEvents: []qwenEvent{}},
+				{input: "c", wantEvents: []qwenEvent{}},
+				{input: "a", wantEvents: []qwenEvent{}},
+				{input: "l", wantEvents: []qwenEvent{}},
+				{input: "l", wantEvents: []qwenEvent{}},
+				{input: ">", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "abc"}}},
+			},
+		},
 		{
 			desc: "trailing whitespace between content and tool call",
 			steps: []step{
@@ -96,3 +96,11 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
 	}
 	return args
 }
+
+func args(s string) api.ToolCallFunctionArguments {
+	var result api.ToolCallFunctionArguments
+	if err := json.Unmarshal([]byte(s), &result); err != nil {
+		panic("invalid JSON in args(): " + err.Error())
+	}
+	return result
+}
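The args helper is used throughout the renderer tests that follow to build expected tool-call arguments from a JSON literal, for example:

	locArgs := args(`{"location": "Tokyo", "unit": "celsius"}`)

Because it panics on malformed JSON, a typo in a test fixture fails loudly instead of silently comparing against empty arguments.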
110 model/renderers/glm46.go Normal file
@@ -0,0 +1,110 @@
package renderers

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

type GLM46Renderer struct{}

func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	var lastUserIndex int
	for i, message := range messages {
		if message.Role == "user" {
			lastUserIndex = i
		}
	}

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(string(d) + "\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}\n")
		sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
		sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
		sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
		sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
		sb.WriteString("...\n")
		sb.WriteString("</tool_call>")
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>\n")
			sb.WriteString(message.Content)
			if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
				sb.WriteString("/nothink")
			}
		case "assistant":
			sb.WriteString("<|assistant|>")
			if i > lastUserIndex {
				if message.Thinking != "" {
					sb.WriteString("\n<think>" + message.Thinking + "</think>")
				} else {
					sb.WriteString("\n<think></think>")
				}
			}
			if message.Content != "" {
				sb.WriteString("\n" + message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
					for key, value := range toolCall.Function.Arguments.All() {
						sb.WriteString("<arg_key>" + key + "</arg_key>\n")

						var valueStr string
						if str, ok := value.(string); ok {
							valueStr = str
						} else {
							jsonBytes, err := json.Marshal(value)
							if err != nil {
								valueStr = fmt.Sprintf("%v", value)
							} else {
								valueStr = string(jsonBytes)
							}
						}

						sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
					}

					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("\n<tool_response>\n")
			sb.WriteString(message.Content)
			sb.WriteString("\n</tool_response>")
		case "system":
			sb.WriteString("<|system|>\n")
			sb.WriteString(message.Content)
		}
	}

	// Add generation prompt
	sb.WriteString("<|assistant|>")
	if thinkValue != nil && !thinkValue.Bool() {
		sb.WriteString("\n<think></think>\n")
	}

	return sb.String(), nil
}
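One subtlety worth noting: Render computes lastUserIndex so that only assistant turns after the last user message get an explicit <think> block; earlier assistant turns have their thinking dropped. The "basic with user assistant user" case in the test file below shows the effect - the assistant turn that precedes the final user message renders as plain content:

	<|assistant|>
	The capital of France is Paris.

whereas an assistant turn after the final user message would carry a <think>...</think> block.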
223 model/renderers/glm46_test.go Normal file
@@ -0,0 +1,223 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/api"
)

func TestGLM46Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
		skip       string
	}{
		{
			name: "basic",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with user assistant user",
			messages: []api.Message{
				{Role: "user", Content: "What is the capital of France?"},
				{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
				{Role: "user", Content: "Fantastic!"},
			},
			expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
		},
		{
			skip: "tool call ordering not guaranteed yet",
			name: "tools",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
		},
		{
			skip: "tool call ordering not guaranteed yet",
			name: "tool calls",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
							},
						},
					},
				},
				{
					Role:     "tool",
					Content:  "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
					ToolName: "get_weather",
				},
				{
					Role:     "tool",
					Content:  "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
					ToolName: "get_weather",
				},
				{
					Role:    "assistant",
					Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
		},
		{
			name: "think true",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "think false",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.skip != "" {
				t.Skip(tt.skip)
			}
			renderer := &GLM46Renderer{}
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
				t.Logf("Got:\n%s", rendered)
				t.Logf("Expected:\n%s", tt.expected)
			}
		})
	}
}
170 model/renderers/glm47.go Normal file
@@ -0,0 +1,170 @@
package renderers

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

// GLM47Renderer renders messages for GLM-4.7 models.
//
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
//    The model thinks between tool calls and after receiving tool results.
//    This enables complex step-by-step reasoning: interpreting each tool output
//    before deciding what to do next. Thinking blocks are preserved and returned
//    with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
//    The model retains reasoning content from previous assistant turns in context.
//    This preserves reasoning continuity across multi-turn conversations. The
//    upstream API has a "clear_thinking" parameter to control this:
//    - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
//    - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
//    Controls whether the model should reason on each turn. The upstream API
//    uses the "enable_thinking" parameter:
//    - enable_thinking=true: outputs <think> to start reasoning
//    - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
//   - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
//   - Thinking is PRESERVED by default (reasoning content from previous turns is always
//     included in <think>...</think> blocks, equivalent to clear_thinking=false)
//   - Users can disable thinking per-turn via thinkValue=false
type GLM47Renderer struct{}

func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(formatGLM47ToolJSON(d))
			sb.WriteString("\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
	}

	think := true
	if thinkValue != nil && !thinkValue.Bool() {
		think = false
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>")
			sb.WriteString(message.Content)
		case "assistant":
			sb.WriteString("<|assistant|>")
			if message.Thinking != "" {
				sb.WriteString("<think>" + message.Thinking + "</think>")
			} else {
				sb.WriteString("</think>")
			}
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<tool_call>" + toolCall.Function.Name)
					sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("<tool_response>")
			sb.WriteString(message.Content)
			sb.WriteString("</tool_response>")
		case "system":
			sb.WriteString("<|system|>")
			sb.WriteString(message.Content)
		}
	}

	sb.WriteString("<|assistant|>")
	if think {
		sb.WriteString("<think>")
	} else {
		sb.WriteString("</think>")
	}

	return sb.String(), nil
}

func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
	var sb strings.Builder
	for key, value := range args.All() {
		sb.WriteString("<arg_key>" + key + "</arg_key>")
		var valueStr string
		if str, ok := value.(string); ok {
			valueStr = str
		} else {
			jsonBytes, err := json.Marshal(value)
			if err != nil {
				valueStr = fmt.Sprintf("%v", value)
			} else {
				valueStr = string(jsonBytes)
			}
		}

		sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
	}

	return sb.String()
}

func formatGLM47ToolJSON(raw []byte) string {
	var sb strings.Builder
	sb.Grow(len(raw) + len(raw)/10)

	inString := false
	escaped := false
	for i := range raw {
		ch := raw[i]
		sb.WriteByte(ch)

		if inString {
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == '"' {
				inString = false
			}
			continue
		}

		if ch == '"' {
			inString = true
			continue
		}

		if ch == ':' || ch == ',' {
			sb.WriteByte(' ')
		}
	}

	return sb.String()
}
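formatGLM47ToolJSON re-serializes the marshaled tool definition, inserting a space after every ':' and ',' that occurs outside a JSON string (the inString/escaped tracking keeps separators inside quoted values untouched). An input/output pair consistent with the expected strings in the test file below:

	in:  {"type":"function","function":{"name":"get_weather"}}
	out: {"type": "function", "function": {"name": "get_weather"}}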
191 model/renderers/glm47_test.go Normal file
@@ -0,0 +1,191 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/api"
)

func TestGLM47Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
	}{
		{
			name: "basic user message",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
		},
		{
			name: "thinking disabled",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "[gMASK]<sop><|user|>Hello<|assistant|></think>",
		},
		{
			name: "system and user",
			messages: []api.Message{
				{Role: "system", Content: "You are helpful."},
				{Role: "user", Content: "Hello"},
			},
			expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
		},
		{
			name: "multi-turn conversation",
			messages: []api.Message{
				{Role: "user", Content: "Hi"},
				{Role: "assistant", Content: "Hello there"},
				{Role: "user", Content: "How are you?"},
			},
			expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
		},
		{
			name: "assistant with reasoning_content",
			messages: []api.Message{
				{Role: "user", Content: "Answer with reasoning."},
				{Role: "assistant", Thinking: "Plan.", Content: "Done."},
			},
			expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
		},
		{
			name: "tool call with empty content",
			messages: []api.Message{
				{Role: "user", Content: "Weather?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
							},
						},
					},
				},
				{Role: "tool", Content: `{"temperature":22}`},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string"}}`),
						},
					},
				},
			},
			expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
		},
		{
			name: "tool call with content",
			messages: []api.Message{
				{Role: "user", Content: "Weather?"},
				{
					Role:    "assistant",
					Content: "Let me check",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo"}`),
							},
						},
					},
				},
				{Role: "tool", Content: `{"temperature":22}`},
				{Role: "assistant", Content: "It is 22C."},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string"}}`),
						},
					},
				},
			},
			expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
		},
		{
			name: "multiple tool calls and responses",
			messages: []api.Message{
				{Role: "user", Content: "Compare weather"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo"}`),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Paris"}`),
							},
						},
					},
				},
				{Role: "tool", Content: `{"temperature":22}`},
				{Role: "tool", Content: `{"temperature":18}`},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string"}}`),
						},
					},
				},
			},
			expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
		},
		{
			name: "preserved thinking in multi-turn",
			messages: []api.Message{
				{Role: "user", Content: "Think step by step"},
				{Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
				{Role: "user", Content: "Continue"},
			},
			expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			renderer := &GLM47Renderer{}
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
				t.Logf("Got:\n%s", rendered)
				t.Logf("Expected:\n%s", tt.expected)
			}
		})
	}
}
144 model/renderers/lfm2.go Normal file
@@ -0,0 +1,144 @@
package renderers

import (
	"encoding/json"
	"strings"

	"github.com/ollama/ollama/api"
)

type LFM2Renderer struct {
	IsThinking bool
}

func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer

	// Extract first system message if present (to combine with tools)
	var firstSystemContent string
	startIdx := 0
	if len(messages) > 0 && messages[0].Role == "system" {
		firstSystemContent = messages[0].Content
		startIdx = 1
	}

	// Append tools to first system content
	if len(tools) > 0 {
		if firstSystemContent != "" {
			firstSystemContent += "\n"
		}
		firstSystemContent += "List of tools: ["
		for i, tool := range tools {
			toolJSON, err := json.Marshal(tool)
			if err != nil {
				return "", err
			}
			firstSystemContent += string(toolJSON)
			if i < len(tools)-1 {
				firstSystemContent += ", "
			}
		}
		firstSystemContent += "]"
	}

	// Output first system block if it has content
	if firstSystemContent != "" {
		sb.WriteString("<|im_start|>system\n")
		sb.WriteString(firstSystemContent)
		sb.WriteString("<|im_end|>\n")
	}

	// Find the index of the last assistant message for thinking stripping
	lastAssistantIndex := -1
	for i := len(messages) - 1; i >= startIdx; i-- {
		if messages[i].Role == "assistant" {
			lastAssistantIndex = i
			break
		}
	}

	// Track whether we need to add a generation prompt
	needsGenerationPrompt := len(messages) > 0

	for i := startIdx; i < len(messages); i++ {
		message := messages[i]
		switch message.Role {
		case "system":
			// Additional system messages (after the first) are rendered normally
			sb.WriteString("<|im_start|>system\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")

		case "user":
			sb.WriteString("<|im_start|>user\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")
			needsGenerationPrompt = true

		case "assistant":
			sb.WriteString("<|im_start|>assistant\n")

			// Check if this is the last assistant message
			isLastAssistant := i == lastAssistantIndex

			// Process content (may need thinking stripped)
			content := message.Content

			// Handle thinking tags in assistant content
			keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
			if strings.Contains(content, "</think>") {
				parts := strings.SplitN(content, "</think>", 2)
				if len(parts) > 1 {
					if !isLastAssistant && !keepPastThinking {
						// Strip thinking entirely for past assistant messages
						content = strings.TrimSpace(parts[1])
					} else {
						// Preserve thinking but trim whitespace after </think>
						content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
					}
				}
			}

			if len(message.ToolCalls) > 0 {
				// Assistant with tool calls - write content first (if any after stripping)
				if content != "" {
					sb.WriteString(content)
				}

				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<|tool_call_start|>")
					toolCallJSON := map[string]any{
						"name":      toolCall.Function.Name,
						"arguments": toolCall.Function.Arguments,
					}
					callJSON, _ := json.Marshal(toolCallJSON)
					sb.WriteString(string(callJSON))
					sb.WriteString("<|tool_call_end|>")
				}
			} else {
				sb.WriteString(content)
			}

			sb.WriteString("<|im_end|>\n")
			needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true

		case "tool":
			// Tool responses are rendered as plain messages per the chat template
			sb.WriteString("<|im_start|>tool\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")
			needsGenerationPrompt = true
		}
	}

	// Add generation prompt
	if needsGenerationPrompt {
		sb.WriteString("<|im_start|>assistant\n")
		// Note: Model is a "thinking-only" model - it will output <think> itself
		// We don't add <think> tag to the prompt
	}

	return sb.String(), nil
}
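As a quick illustration of the template this renderer produces, here is a minimal usage sketch (not part of the diff); the expected output mirrors the "basic user message" case in the tests below.

// Hypothetical usage sketch for LFM2Renderer; the output string is taken
// from the "basic user message" test case in lfm2_test.go.
out, err := (&LFM2Renderer{IsThinking: false}).Render([]api.Message{
	{Role: "user", Content: "Hello!"},
}, nil, &api.ThinkValue{Value: false})
// err == nil
// out == "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"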
427 model/renderers/lfm2_test.go Normal file
@@ -0,0 +1,427 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

func TestLFM2Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
	}{
		{
			name: "basic user message",
			messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "multiple system messages rendered separately",
			messages: []api.Message{
				{Role: "system", Content: "First instruction."},
				{Role: "system", Content: "Second instruction."},
				{Role: "user", Content: "Hello!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "multi-turn conversation",
			messages: []api.Message{
				{Role: "user", Content: "What is 2+2?"},
				{Role: "assistant", Content: "The answer is 4."},
				{Role: "user", Content: "Thanks!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "only system message",
			messages: []api.Message{
				{Role: "system", Content: "You are helpful."},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
			name: "user-assistant-user: last assistant preserves thinking",
			messages: []api.Message{
				{Role: "user", Content: "Q1"},
				{Role: "assistant", Content: "<think>reasoning</think>A1"},
				{Role: "user", Content: "Q2"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// With two assistants, first is stripped (not last), second preserved (is last)
			name: "multi-turn thinking: first stripped, second preserved",
			messages: []api.Message{
				{Role: "user", Content: "Q1"},
				{Role: "assistant", Content: "<think>reason1</think>A1"},
				{Role: "user", Content: "Q2"},
				{Role: "assistant", Content: "<think>reason2</think>A2"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// With thinking enabled (keep_past_thinking=true), both preserved
			name: "multi-turn thinking: both preserved when thinking enabled",
			messages: []api.Message{
				{Role: "user", Content: "Q1"},
				{Role: "assistant", Content: "<think>reason1</think>A1"},
				{Role: "user", Content: "Q2"},
				{Role: "assistant", Content: "<think>reason2</think>A2"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected:   "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "assistant with tool calls",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "assistant with content and tool calls",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather in Paris?"},
				{
					Role:    "assistant",
					Content: "Let me check.",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "tool response",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{Role: "assistant", Content: "Let me check."},
				{Role: "tool", Content: "22C, Sunny"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "multiple tool calls",
			messages: []api.Message{
				{Role: "user", Content: "Get weather for Paris and London"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "London",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "tools definitions with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are helpful."},
				{Role: "user", Content: "What's the weather?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get current weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: testPropsMap(map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "City name",
								},
							}),
							Required: []string{"location"},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "tools definitions without system message",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get current weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: testPropsMap(map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "City name",
								},
							}),
							Required: []string{"location"},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "multiple tools without system message",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
					},
				},
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_time",
						Description: "Get time",
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "user-tool sequence",
			messages: []api.Message{
				{Role: "user", Content: "Check weather"},
				{Role: "tool", Content: "22C"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "full tool call cycle",
			messages: []api.Message{
				{Role: "user", Content: "Check weather"},
				{Role: "assistant", Content: "Let me check"},
				{Role: "tool", Content: "22C"},
				{Role: "assistant", Content: "It's 22C"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "unicode content",
			messages: []api.Message{
				{Role: "user", Content: "你好世界! مرحبا 🌍"},
				{Role: "assistant", Content: "Hello! 👋"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "newlines in content",
			messages: []api.Message{
				{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
				{Role: "assistant", Content: "Response with\nmultiple\nlines"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "empty assistant content",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: ""},
				{Role: "user", Content: "OK"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// Generation prompt does NOT include <think> - model outputs it
			name: "generation prompt has no think tag",
			messages: []api.Message{
				{Role: "user", Content: "Think hard"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected:   "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// Interleaved: thinking before tool call - last assistant preserves thinking
			name: "thinking before tool call (last assistant)",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{
					Role:    "assistant",
					Content: "<think>I need to check the weather</think>",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// Two assistants with tool calls - first has thinking stripped
			name: "two assistants with tools: first thinking stripped",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{
					Role:    "assistant",
					Content: "<think>checking</think>",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
				{Role: "tool", Content: "22C"},
				{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// Two assistants with tools - both preserved when thinking enabled
			name: "two assistants with tools: both preserved when thinking enabled",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{
					Role:    "assistant",
					Content: "<think>checking</think>",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
				{Role: "tool", Content: "22C"},
				{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected:   "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// Content before thinking before tool call
			name: "content then thinking then tool call",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{
					Role:    "assistant",
					Content: "Let me check.<think>Using weather API</think>",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
		},
	}

	renderer := &LFM2Renderer{IsThinking: true}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatalf("Render() error = %v", err)
			}
			if diff := cmp.Diff(tt.expected, rendered); diff != "" {
				t.Errorf("Render() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
@@ -80,6 +80,12 @@ func rendererForName(name string) Renderer {
		return &Nemotron3NanoRenderer{}
	case "functiongemma":
		return &FunctionGemmaRenderer{}
	case "glm-4.7":
		return &GLM47Renderer{}
	case "lfm2":
		return &LFM2Renderer{IsThinking: false}
	case "lfm2-thinking":
		return &LFM2Renderer{IsThinking: true}
	default:
		return nil
	}
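The two new registry entries expose the same renderer type in two modes. A minimal lookup sketch (hypothetical call site; assumes a `messages` slice is in scope):

// Hypothetical sketch: "lfm2" and "lfm2-thinking" differ only in the
// IsThinking flag set by rendererForName above.
if r := rendererForName("lfm2-thinking"); r != nil {
	prompt, err := r.Render(messages, nil, &api.ThinkValue{Value: true})
	_, _ = prompt, err
}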
@@ -1,6 +1,26 @@
package renderers

import "github.com/ollama/ollama/api"
import (
	"encoding/json"

	"github.com/ollama/ollama/api"
)

func args(s string) api.ToolCallFunctionArguments {
	var result api.ToolCallFunctionArguments
	if err := json.Unmarshal([]byte(s), &result); err != nil {
		panic("invalid JSON in args(): " + err.Error())
	}
	return result
}

func propsMap(s string) *api.ToolPropertiesMap {
	var result api.ToolPropertiesMap
	if err := json.Unmarshal([]byte(s), &result); err != nil {
		panic("invalid JSON in propsMap(): " + err.Error())
	}
	return &result
}

// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {

// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
	if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
		return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
	}

	types := []string{"jpeg", "jpg", "png", "webp"}

	// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -733,3 +737,60 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}

// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt"`
	N              int    `json:"n,omitempty"`
	Size           string `json:"size,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	Seed           *int64 `json:"seed,omitempty"`
}

// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
	Created int64            `json:"created"`
	Data    []ImageURLOrData `json:"data"`
}

// ImageURLOrData contains either a URL or base64-encoded image data.
type ImageURLOrData struct {
	URL     string `json:"url,omitempty"`
	B64JSON string `json:"b64_json,omitempty"`
}

// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
	req := api.GenerateRequest{
		Model:  r.Model,
		Prompt: r.Prompt,
	}
	// Parse size if provided (e.g., "1024x768")
	if r.Size != "" {
		var w, h int32
		if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
			req.Width = w
			req.Height = h
		}
	}
	if r.Seed != nil {
		if req.Options == nil {
			req.Options = map[string]any{}
		}
		req.Options["seed"] = *r.Seed
	}
	return req
}

// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
	var data []ImageURLOrData
	if resp.Image != "" {
		data = []ImageURLOrData{{B64JSON: resp.Image}}
	}
	return ImageGenerationResponse{
		Created: resp.CreatedAt.Unix(),
		Data:    data,
	}
}
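To show how the size and seed handling above behaves, here is a minimal sketch of the request conversion (the model name and prompt are made up; the field behavior follows the function in this diff):

// Hypothetical conversion sketch for FromImageGenerationRequest.
seed := int64(42)
req := FromImageGenerationRequest(ImageGenerationRequest{
	Model:  "some-image-model", // hypothetical model name
	Prompt: "a lighthouse at dawn",
	Size:   "1024x768",
	Seed:   &seed,
})
// req.Width == 1024, req.Height == 768, req.Options["seed"] == int64(42)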
@@ -4,6 +4,7 @@ import (
	"encoding/json"
	"fmt"
	"math/rand"
	"time"

	"github.com/ollama/ollama/api"
)
@@ -265,9 +266,9 @@ type ResponsesText struct {
type ResponsesTool struct {
	Type        string         `json:"type"` // "function"
	Name        string         `json:"name"`
	Description string         `json:"description,omitempty"`
	Strict      bool           `json:"strict,omitempty"`
	Parameters  map[string]any `json:"parameters,omitempty"`
	Description *string        `json:"description"` // nullable but required
	Strict      *bool          `json:"strict"`      // nullable but required
	Parameters  map[string]any `json:"parameters"`  // nullable but required
}

type ResponsesRequest struct {
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
		}
	}

	var description string
	if t.Description != nil {
		description = *t.Description
	}

	return api.Tool{
		Type: t.Type,
		Function: api.ToolFunction{
			Name:        t.Name,
			Description: t.Description,
			Description: description,
			Parameters:  params,
		},
	}, nil
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {

// Response types for the Responses API

// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
	Format ResponsesTextFormat `json:"format"`
}

// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
	Effort  *string `json:"effort,omitempty"`
	Summary *string `json:"summary,omitempty"`
}

// ResponsesError represents an error in the response.
type ResponsesError struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
	Reason string `json:"reason"`
}

type ResponsesResponse struct {
	ID        string                `json:"id"`
	Object    string                `json:"object"`
	CreatedAt int64                 `json:"created_at"`
	Status    string                `json:"status"`
	Model     string                `json:"model"`
	Output    []ResponsesOutputItem `json:"output"`
	Usage     *ResponsesUsage       `json:"usage,omitempty"`
	// TODO(drifkin): add `temperature` and `top_p` to the response, but this
	// requires additional plumbing to find the effective values since the
	// defaults can come from the model or the request
	ID                 string                      `json:"id"`
	Object             string                      `json:"object"`
	CreatedAt          int64                       `json:"created_at"`
	CompletedAt        *int64                      `json:"completed_at"`
	Status             string                      `json:"status"`
	IncompleteDetails  *ResponsesIncompleteDetails `json:"incomplete_details"`
	Model              string                      `json:"model"`
	PreviousResponseID *string                     `json:"previous_response_id"`
	Instructions       *string                     `json:"instructions"`
	Output             []ResponsesOutputItem       `json:"output"`
	Error              *ResponsesError             `json:"error"`
	Tools              []ResponsesTool             `json:"tools"`
	ToolChoice         any                         `json:"tool_choice"`
	Truncation         string                      `json:"truncation"`
	ParallelToolCalls  bool                        `json:"parallel_tool_calls"`
	Text               ResponsesTextField          `json:"text"`
	TopP               float64                     `json:"top_p"`
	PresencePenalty    float64                     `json:"presence_penalty"`
	FrequencyPenalty   float64                     `json:"frequency_penalty"`
	TopLogprobs        int                         `json:"top_logprobs"`
	Temperature        float64                     `json:"temperature"`
	Reasoning          *ResponsesReasoningOutput   `json:"reasoning"`
	Usage              *ResponsesUsage             `json:"usage"`
	MaxOutputTokens    *int                        `json:"max_output_tokens"`
	MaxToolCalls       *int                        `json:"max_tool_calls"`
	Store              bool                        `json:"store"`
	Background         bool                        `json:"background"`
	ServiceTier        string                      `json:"service_tier"`
	Metadata           map[string]any              `json:"metadata"`
	SafetyIdentifier   *string                     `json:"safety_identifier"`
	PromptCacheKey     *string                     `json:"prompt_cache_key"`
}

type ResponsesOutputItem struct {
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
}

type ResponsesOutputContent struct {
	Type string `json:"type"` // "output_text"
	Text string `json:"text"`
	Type        string `json:"type"` // "output_text"
	Text        string `json:"text"`
	Annotations []any  `json:"annotations"`
	Logprobs    []any  `json:"logprobs"`
}

type ResponsesInputTokensDetails struct {
	CachedTokens int `json:"cached_tokens"`
}

type ResponsesOutputTokensDetails struct {
	ReasoningTokens int `json:"reasoning_tokens"`
}

type ResponsesUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
	InputTokens         int                          `json:"input_tokens"`
	OutputTokens        int                          `json:"output_tokens"`
	TotalTokens         int                          `json:"total_tokens"`
	InputTokensDetails  ResponsesInputTokensDetails  `json:"input_tokens_details"`
	OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
}

// ToResponse converts an api.ChatResponse to a Responses API response
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
func derefFloat64(p *float64, def float64) float64 {
	if p != nil {
		return *p
	}
	return def
}

// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
	var output []ResponsesOutputItem

	// Add reasoning item if thinking is present
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
	output = append(output, ResponsesOutputItem{
		ID:        fmt.Sprintf("fc_%s_%d", responseID, i),
		Type:      "function_call",
		Status:    "completed",
		CallID:    tc.ID,
		Name:      tc.Function.Name,
		Arguments: tc.Function.Arguments,
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
		Role: "assistant",
		Content: []ResponsesOutputContent{
			{
				Type: "output_text",
				Text: chatResponse.Message.Content,
				Type:        "output_text",
				Text:        chatResponse.Message.Content,
				Annotations: []any{},
				Logprobs:    []any{},
			},
		},
	})
}

var instructions *string
if request.Instructions != "" {
	instructions = &request.Instructions
}

// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
	truncation = *request.Truncation
}

tools := request.Tools
if tools == nil {
	tools = []ResponsesTool{}
}

text := ResponsesTextField{
	Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
	text.Format = *request.Text.Format
}

// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
	reasoning = &ResponsesReasoningOutput{}
	if request.Reasoning.Effort != "" {
		reasoning.Effort = &request.Reasoning.Effort
	}
	if request.Reasoning.Summary != "" {
		reasoning.Summary = &request.Reasoning.Summary
	}
}

return ResponsesResponse{
	ID:        responseID,
	Object:    "response",
	CreatedAt: chatResponse.CreatedAt.Unix(),
	Status:    "completed",
	Model:     model,
	Output:    output,
	ID:                 responseID,
	Object:             "response",
	CreatedAt:          chatResponse.CreatedAt.Unix(),
	CompletedAt:        nil, // Set by middleware when writing final response
	Status:             "completed",
	IncompleteDetails:  nil, // Only populated if response incomplete
	Model:              model,
	PreviousResponseID: nil, // Not supported
	Instructions:       instructions,
	Output:             output,
	Error:              nil, // Only populated on failure
	Tools:              tools,
	ToolChoice:         "auto", // Default value
	Truncation:         truncation,
	ParallelToolCalls:  true, // Default value
	Text:               text,
	TopP:               derefFloat64(request.TopP, 1.0),
	PresencePenalty:    0, // Default value
	FrequencyPenalty:   0, // Default value
	TopLogprobs:        0, // Default value
	Temperature:        derefFloat64(request.Temperature, 1.0),
	Reasoning:          reasoning,
	Usage: &ResponsesUsage{
		InputTokens:  chatResponse.PromptEvalCount,
		OutputTokens: chatResponse.EvalCount,
		TotalTokens:  chatResponse.PromptEvalCount + chatResponse.EvalCount,
		// TODO(drifkin): wire through the actual values
		InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
		// TODO(drifkin): wire through the actual values
		OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
	},
	MaxOutputTokens:  request.MaxOutputTokens,
	MaxToolCalls:     nil, // Not supported
	Store:            false, // We don't store responses
	Background:       request.Background,
	ServiceTier:      "default", // Default value
	Metadata:         map[string]any{},
	SafetyIdentifier: nil, // Not supported
	PromptCacheKey:   nil, // Not supported
}
}

@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
	responseID string
	itemID     string
	model      string
	request    ResponsesRequest

	// State tracking (mutated across Process calls)
	firstWrite bool
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}

// NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
	return &ResponsesStreamConverter{
		responseID: responseID,
		itemID:     itemID,
		model:      model,
		request:    request,
		firstWrite: true,
	}
}
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
	return events
}

// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
	var instructions any = nil
	if c.request.Instructions != "" {
		instructions = c.request.Instructions
	}

	truncation := "disabled"
	if c.request.Truncation != nil {
		truncation = *c.request.Truncation
	}

	var tools []any
	if c.request.Tools != nil {
		for _, t := range c.request.Tools {
			tools = append(tools, map[string]any{
				"type":        t.Type,
				"name":        t.Name,
				"description": t.Description,
				"strict":      t.Strict,
				"parameters":  t.Parameters,
			})
		}
	}
	if tools == nil {
		tools = []any{}
	}

	textFormat := map[string]any{"type": "text"}
	if c.request.Text != nil && c.request.Text.Format != nil {
		textFormat = map[string]any{
			"type": c.request.Text.Format.Type,
		}
		if c.request.Text.Format.Name != "" {
			textFormat["name"] = c.request.Text.Format.Name
		}
		if c.request.Text.Format.Schema != nil {
			textFormat["schema"] = c.request.Text.Format.Schema
		}
		if c.request.Text.Format.Strict != nil {
			textFormat["strict"] = *c.request.Text.Format.Strict
		}
	}

	var reasoning any = nil
	if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
		r := map[string]any{}
		if c.request.Reasoning.Effort != "" {
			r["effort"] = c.request.Reasoning.Effort
		} else {
			r["effort"] = nil
		}
		if c.request.Reasoning.Summary != "" {
			r["summary"] = c.request.Reasoning.Summary
		} else {
			r["summary"] = nil
		}
		reasoning = r
	}

	// Build top_p and temperature with defaults
	topP := 1.0
	if c.request.TopP != nil {
		topP = *c.request.TopP
	}
	temperature := 1.0
	if c.request.Temperature != nil {
		temperature = *c.request.Temperature
	}

	return map[string]any{
		"id":                   c.responseID,
		"object":               "response",
		"created_at":           time.Now().Unix(),
		"completed_at":         nil,
		"status":               status,
		"incomplete_details":   nil,
		"model":                c.model,
		"previous_response_id": nil,
		"instructions":         instructions,
		"output":               output,
		"error":                nil,
		"tools":                tools,
		"tool_choice":          "auto",
		"truncation":           truncation,
		"parallel_tool_calls":  true,
		"text":                 map[string]any{"format": textFormat},
		"top_p":                topP,
		"presence_penalty":     0,
		"frequency_penalty":    0,
		"top_logprobs":         0,
		"temperature":          temperature,
		"reasoning":            reasoning,
		"usage":                usage,
		"max_output_tokens":    c.request.MaxOutputTokens,
		"max_tool_calls":       nil,
		"store":                false,
		"background":           c.request.Background,
		"service_tier":         "default",
		"metadata":             map[string]any{},
		"safety_identifier":    nil,
		"prompt_cache_key":     nil,
	}
}

func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
	return c.newEvent("response.created", map[string]any{
		"response": map[string]any{
			"id":     c.responseID,
			"object": "response",
			"status": "in_progress",
			"output": []any{},
		},
		"response": c.buildResponseObject("in_progress", []any{}, nil),
	})
}

func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
	return c.newEvent("response.in_progress", map[string]any{
		"response": map[string]any{
			"id":     c.responseID,
			"object": "response",
			"status": "in_progress",
			"output": []any{},
		},
		"response": c.buildResponseObject("in_progress", []any{}, nil),
	})
}

@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS

	// Emit delta
	events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
		"item_id":      c.reasoningItemID,
		"output_index": c.outputIndex,
		"delta":        thinking,
		"item_id":       c.reasoningItemID,
		"output_index":  c.outputIndex,
		"summary_index": 0,
		"delta":         thinking,
	}))

	// TODO(drifkin): consider adding
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {

	events := []ResponsesStreamEvent{
		c.newEvent("response.reasoning_summary_text.done", map[string]any{
			"item_id":      c.reasoningItemID,
			"output_index": c.outputIndex,
			"text":         c.accumulatedThinking,
			"item_id":       c.reasoningItemID,
			"output_index":  c.outputIndex,
			"summary_index": 0,
			"text":          c.accumulatedThinking,
		}),
		c.newEvent("response.output_item.done", map[string]any{
			"output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
		"output_index":  c.outputIndex,
		"content_index": c.contentIndex,
		"part": map[string]any{
			"type": "output_text",
			"text": "",
			"type":        "output_text",
			"text":        "",
			"annotations": []any{},
			"logprobs":    []any{},
		},
	}))
}
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
		"output_index":  c.outputIndex,
		"content_index": 0,
		"delta":         content,
		"logprobs":      []any{},
	}))

	return events
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
		"status": "completed",
		"role":   "assistant",
		"content": []map[string]any{{
			"type": "output_text",
			"text": c.accumulatedText,
			"type":        "output_text",
			"text":        c.accumulatedText,
			"annotations": []any{},
			"logprobs":    []any{},
		}},
	})
}
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
		"output_index":  c.outputIndex,
		"content_index": 0,
		"text":          c.accumulatedText,
		"logprobs":      []any{},
	}))

	// response.content_part.done
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
		"output_index":  c.outputIndex,
		"content_index": 0,
		"part": map[string]any{
			"type": "output_text",
			"text": c.accumulatedText,
			"type":        "output_text",
			"text":        c.accumulatedText,
			"annotations": []any{},
			"logprobs":    []any{},
		},
	}))

@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
		"status": "completed",
		"role":   "assistant",
		"content": []map[string]any{{
			"type": "output_text",
			"text": c.accumulatedText,
			"type":        "output_text",
			"text":        c.accumulatedText,
			"annotations": []any{},
			"logprobs":    []any{},
		}},
		},
	}))
}

	// response.completed
	events = append(events, c.newEvent("response.completed", map[string]any{
		"response": map[string]any{
			"id":     c.responseID,
			"object": "response",
			"status": "completed",
			"output": c.buildFinalOutput(),
			"usage": map[string]any{
				"input_tokens":  r.PromptEvalCount,
				"output_tokens": r.EvalCount,
				"total_tokens":  r.PromptEvalCount + r.EvalCount,
			},
	usage := map[string]any{
		"input_tokens":  r.PromptEvalCount,
		"output_tokens": r.EvalCount,
		"total_tokens":  r.PromptEvalCount + r.EvalCount,
		"input_tokens_details": map[string]any{
			"cached_tokens": 0,
		},
		"output_tokens_details": map[string]any{
			"reasoning_tokens": 0,
		},
	}
	response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
	response["completed_at"] = time.Now().Unix()
	events = append(events, c.newEvent("response.completed", map[string]any{
		"response": response,
	}))

	return events
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}

func TestResponsesStreamConverter_TextOnly(t *testing.T) {
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	// First chunk with content
	events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}

func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	events := converter.Process(api.ChatResponse{
		Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}

func TestResponsesStreamConverter_Reasoning(t *testing.T) {
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	// First chunk with thinking
	events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
			Content: "The answer is 42",
		},
		Done: true,
	})
	}, ResponsesRequest{})

	// Should have 2 output items: reasoning + message
	if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {

func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
	// Verify that response.output_item.done includes content field for messages
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	// First chunk
	converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {

func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
	// Verify that response.completed includes the output array
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	// Process some content
	converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)

func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
	// Verify that response.created includes an empty output array
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	events := converter.Process(api.ChatResponse{
		Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {

func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
	// Verify that events include incrementing sequence numbers
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	events := converter.Process(api.ChatResponse{
		Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {

func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
	// Verify that function call items include status field
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

	events := converter.Process(api.ChatResponse{
		Message: api.Message{
@@ -60,7 +60,7 @@ _build_darwin() {
        cmake --install $BUILD_DIR --component MLX
        # Override CGO flags to point to the amd64 build directory
        MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
        MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
        MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
    else
        BUILD_DIR=build
        cmake --preset MLX \
@@ -71,10 +71,12 @@ _build_darwin() {
        cmake --install $BUILD_DIR --component MLX
        # Use default CGO flags from mlx.go for arm64
        MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
        MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
        MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
    fi
    GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
    GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
    GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
    # Copy MLX libraries to same directory as executable for dlopen
    cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
    cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
    done
}

@@ -82,12 +84,10 @@ _sign_darwin() {
    status "Creating universal binary..."
    mkdir -p dist/darwin
    lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
    lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
    chmod +x dist/darwin/ollama
    chmod +x dist/darwin/ollama-mlx

    if [ -n "$APPLE_IDENTITY" ]; then
        for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
        for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
            codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
        done

@@ -154,7 +154,6 @@ _build_macapp() {
    mkdir -p dist/Ollama.app/Contents/Resources
    if [ -d dist/darwin-amd64 ]; then
        lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
        lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
        for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
            lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
        done
@@ -166,13 +165,12 @@ _build_macapp() {
        cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
        cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
    fi
    cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
    chmod a+x dist/Ollama.app/Contents/Resources/ollama

    # Sign
    if [ -n "$APPLE_IDENTITY" ]; then
        codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
        for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
        for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
            codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
        done
        codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -180,7 +178,7 @@ _build_macapp() {

    rm -f dist/Ollama-darwin.zip
    ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz

    # Notarize and Staple
    if [ -n "$APPLE_IDENTITY" ]; then

@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
	return redirectURL, nil
}

func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
	redirectURL, err := challenge.URL()
	if err != nil {
		return "", err
	}

	// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
	if redirectURL.Host != originalHost {
		return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
	}

	sha256sum := sha256.Sum256(nil)
	data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
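The new host check refuses to send a signed token to an auth realm on a different origin than the registry that issued the challenge. A minimal sketch of the rejected case (hosts are made up; the error text follows the check above):

// Hypothetical sketch: a challenge whose realm points at a foreign host
// is rejected before any token is signed or sent.
challenge := registryChallenge{Realm: "https://attacker.example/token"}
if _, err := getAuthorizationToken(ctx, challenge, "registry.ollama.ai"); err != nil {
	// err: realm host "attacker.example" does not match original host "registry.ollama.ai"
}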
113 server/auth_test.go Normal file
@@ -0,0 +1,113 @@
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
realm string
|
||||
originalHost string
|
||||
wantMismatch bool
|
||||
}{
|
||||
{"https://example.com/token", "example.com", false},
|
||||
{"https://example.com/token", "other.com", true},
|
||||
{"https://example.com/token", "localhost:8000", true},
|
||||
{"https://localhost:5000/token", "localhost:5000", false},
|
||||
{"https://localhost:5000/token", "localhost:6000", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.originalHost, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||
|
||||
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||
if tt.wantMismatch && !isMismatch {
|
||||
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||
}
|
||||
if !tt.wantMismatch && isMismatch {
|
||||
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantRealm, wantService, wantScope string
|
||||
}{
|
||||
{
|
||||
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||
},
|
||||
{
|
||||
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||
},
|
||||
{"", "", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := parseRegistryChallenge(tt.input)
|
||||
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||
tt.input, result.Realm, result.Service, result.Scope,
|
||||
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURL(t *testing.T) {
|
||||
challenge := registryChallenge{
|
||||
Realm: "https://auth.example.com/token",
|
||||
Service: "registry",
|
||||
Scope: "repo:foo:pull repo:bar:push",
|
||||
}
|
||||
|
||||
u, err := challenge.URL()
|
||||
if err != nil {
|
||||
t.Fatalf("URL() error: %v", err)
|
||||
}
|
||||
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||
}
|
||||
if u.Path != "/token" {
|
||||
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
if q.Get("service") != "registry" {
|
||||
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||
}
|
||||
if scopes := q["scope"]; len(scopes) != 2 {
|
||||
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||
}
|
||||
if q.Get("ts") == "" {
|
||||
t.Error("missing ts")
|
||||
}
|
||||
if q.Get("nonce") == "" {
|
||||
t.Error("missing nonce")
|
||||
}
|
||||
|
||||
// Nonces should differ between calls
|
||||
u2, _ := challenge.URL()
|
||||
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||
t.Error("nonce should be unique per call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||
challenge := registryChallenge{Realm: "://invalid"}
|
||||
if _, err := challenge.URL(); err == nil {
|
||||
t.Error("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
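The new tests are table-driven. Note that the cross-domain cases appear to fail at the host check before any network request is made, while the matching cases are bounded by the 100 ms context timeout, so the suite stays fast even though getAuthorizationToken would otherwise attempt a real token fetch. They should be runnable in isolation with something like go test ./server -run 'TestGetAuthorizationToken|TestRegistryChallenge' (command illustrative).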
@@ -28,6 +28,7 @@ import (
	"github.com/ollama/ollama/format"
	ofs "github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/fs/ggml"
+	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
		ch <- resp
	}

-	oldManifest, _ := ParseNamedManifest(name)
+	oldManifest, _ := manifest.ParseNamedManifest(name)

	var baseLayers []*layerGGML
	var err error
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
	}

	if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
-		manifest, mErr := ParseNamedManifest(fromName)
-		if mErr == nil && manifest.Config.Digest != "" {
-			configPath, pErr := GetBlobsPath(manifest.Config.Digest)
+		mf, mErr := manifest.ParseNamedManifest(fromName)
+		if mErr == nil && mf.Config.Digest != "" {
+			configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
			if pErr == nil {
				if cfgFile, fErr := os.Open(configPath); fErr == nil {
					var baseConfig model.ConfigV2
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
		return "gguf"
	} else {
		// try to see if we can find a gguf file even without the file extension
-		blobPath, err := GetBlobsPath(files[fn])
+		blobPath, err := manifest.BlobsPath(files[fn])
		if err != nil {
			slog.Error("error getting blobs path", "file", fn)
			return ""
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
		return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
	}

-	blobPath, err := GetBlobsPath(digest)
+	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		return nil, err
	}
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
		return nil, err
	}

-	layer, err := NewLayer(t, mediaType)
+	layer, err := manifest.NewLayer(t, mediaType)
	if err != nil {
		return nil, err
	}
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}

func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
-	var layers []Layer
+	var layers []manifest.Layer
	for _, layer := range baseLayers {
		if layer.GGML != nil {
			quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
	}

	for _, layer := range layers {
-		if layer.status != "" {
-			fn(api.ProgressResponse{Status: layer.status})
+		if layer.Status != "" {
+			fn(api.ProgressResponse{Status: layer.Status})
		}
	}

	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, *configLayer, layers); err != nil {
+	if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
		return err
	}

@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
		return nil, err
	}

-	blob, err := GetBlobsPath(layer.Digest)
+	blob, err := manifest.BlobsPath(layer.Digest)
	if err != nil {
		return nil, err
	}
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
	}
	temp.Seek(0, io.SeekStart)
	fn(api.ProgressResponse{Status: "verifying conversion"})
-	newLayer, err := NewLayer(temp, layer.MediaType)
+	newLayer, err := manifest.NewLayer(temp, layer.MediaType)
	if err != nil {
		return nil, err
	}
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
	var layers []*layerGGML

	fn(api.ProgressResponse{Status: "parsing GGUF"})
-	blobPath, err := GetBlobsPath(digest)
+	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		return nil, err
	}
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
		mediatype = "application/vnd.ollama.image.projector"
	}

-	layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
+	layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
	if err != nil {
		slog.Debug("could not create new layer from layer", "error", err)
		return nil, err
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
	return detectChatTemplate(layers)
}

-func removeLayer(layers []Layer, mediatype string) []Layer {
-	return slices.DeleteFunc(layers, func(layer Layer) bool {
+func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
+	return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
		if layer.MediaType != mediatype {
			return false
		}
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
	})
}

-func setTemplate(layers []Layer, t string) ([]Layer, error) {
+func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
	layers = removeLayer(layers, "application/vnd.ollama.image.template")
	if _, err := template.Parse(t); err != nil {
		return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
	}

	blob := strings.NewReader(t)
-	layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
+	layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
	if err != nil {
		return nil, err
	}
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
	return layers, nil
}

-func setSystem(layers []Layer, s string) ([]Layer, error) {
+func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
	layers = removeLayer(layers, "application/vnd.ollama.image.system")
	if s != "" {
		blob := strings.NewReader(s)
-		layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
+		layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
		if err != nil {
			return nil, err
		}
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
	return layers, nil
}

-func setLicense(layers []Layer, l string) ([]Layer, error) {
+func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
	blob := strings.NewReader(l)
-	layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
+	layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
	if err != nil {
		return nil, err
	}
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
	return layers, nil
}

-func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
+func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
	if p == nil {
		p = make(map[string]any)
	}
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
			continue
		}

-		digestPath, err := GetBlobsPath(layer.Digest)
+		digestPath, err := manifest.BlobsPath(layer.Digest)
		if err != nil {
			return nil, err
		}
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
	if err := json.NewEncoder(&b).Encode(p); err != nil {
		return nil, err
	}
-	layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+	layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
	if err != nil {
		return nil, err
	}
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
	return layers, nil
}

-func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
+func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
	// this leaves the old messages intact if no new messages were specified
	// which may not be the correct behaviour
	if len(m) == 0 {
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
	if err := json.NewEncoder(&b).Encode(m); err != nil {
		return nil, err
	}
-	layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
+	layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
	if err != nil {
		return nil, err
	}
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
	return layers, nil
}

-func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
+func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
	digests := make([]string, len(layers))
	for i, layer := range layers {
		digests[i] = layer.Digest
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
	if err := json.NewEncoder(&b).Encode(config); err != nil {
		return nil, err
	}
-	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		return nil, err
	}
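The create.go hunks above are a mechanical migration: layer and manifest helpers that previously lived in package server (Layer, NewLayer, NewLayerFromLayer, ParseNamedManifest, WriteManifest, GetBlobsPath) are now imported from the new manifest package, with GetBlobsPath renamed to manifest.BlobsPath and the unexported status field promoted to the exported Status. A rough sketch of the resulting call pattern; signatures are inferred from the call sites in this diff, not confirmed against the package, and writeSystemPrompt is a hypothetical helper:

package server

import (
	"strings"

	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
)

// writeSystemPrompt sketches the post-refactor API: build a layer from a
// reader, then persist a manifest referencing it.
func writeSystemPrompt(name model.Name, configLayer manifest.Layer, layers []manifest.Layer, system string) error {
	layer, err := manifest.NewLayer(strings.NewReader(system), "application/vnd.ollama.image.system")
	if err != nil {
		return err
	}
	layers = append(layers, layer)
	return manifest.WriteManifest(name, configLayer, layers)
}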
@@ -10,6 +10,7 @@ import (
	"testing"

	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/manifest"
)

func TestConvertFromSafetensors(t *testing.T) {
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {

	// Helper function to create a new layer and return its digest
	makeTemp := func(content string) string {
-		l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
+		l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
		if err != nil {
			t.Fatalf("Failed to create layer: %v", err)
		}
@@ -24,6 +24,8 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/manifest"
+	"github.com/ollama/ollama/types/model"
)

const maxRetries = 6
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}

type downloadOpts struct {
-	mp      ModelPath
+	n       model.Name
	digest  string
	regOpts *registryOptions
	fn      func(api.ProgressResponse)
@@ -465,10 +467,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
	if opts.digest == "" {
-		return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
+		return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
	}

-	fp, err := GetBlobsPath(opts.digest)
+	fp, err := manifest.BlobsPath(opts.digest)
	if err != nil {
		return false, err
	}
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
	download := data.(*blobDownload)
	if !ok {
-		requestURL := opts.mp.BaseURL()
-		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
+		requestURL := opts.n.BaseURL()
+		requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
			blobDownloadManager.Delete(opts.digest)
			return false, err
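downloadOpts now carries a model.Name instead of a ModelPath, and the blob URL is assembled from name-derived parts (BaseURL, DisplayNamespaceModel). The URL shape can be sketched with the standard library alone; the host and digest here are made up:

package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Mirrors the requestURL construction in downloadBlob:
	// <base>/v2/<namespace>/<model>/blobs/<digest>
	base := &url.URL{Scheme: "https", Host: "registry.ollama.ai"}
	u := base.JoinPath("v2", "library/gemma3", "blobs", "sha256:0123abcd")
	fmt.Println(u) // https://registry.ollama.ai/v2/library/gemma3/blobs/sha256:0123abcd
}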
215
server/images.go
@@ -4,7 +4,6 @@ import (
	"bytes"
	"context"
	"crypto/sha256"
-	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
@@ -24,6 +23,7 @@ import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/gguf"
+	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/template"
@@ -41,6 +41,7 @@ var (
	errCapabilityVision    = errors.New("vision")
	errCapabilityEmbedding = errors.New("embedding")
	errCapabilityThinking  = errors.New("thinking")
+	errCapabilityImage     = errors.New("image generation")
	errInsecureProtocol    = errors.New("insecure protocol http")
)

@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {

	// Check for image generation model via config capabilities
	if slices.Contains(m.Config.Capabilities, "image") {
-		return []model.Capability{model.CapabilityImageGeneration}
+		return []model.Capability{model.CapabilityImage}
	}

	// Check for completion capability
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
		model.CapabilityVision:    errCapabilityVision,
		model.CapabilityEmbedding: errCapabilityEmbedding,
		model.CapabilityThinking:  errCapabilityThinking,
+		model.CapabilityImage:     errCapabilityImage,
	}

	for _, cap := range want {
@@ -272,44 +274,22 @@ func (m *Model) String() string {
	return modelfile.String()
}

-func GetManifest(mp ModelPath) (*Manifest, string, error) {
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return nil, "", err
-	}
-
-	f, err := os.Open(fp)
-	if err != nil {
-		return nil, "", err
-	}
-	defer f.Close()
-
-	sha256sum := sha256.New()
-
-	var manifest Manifest
-	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
-		return nil, "", err
-	}
-
-	return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
-}
-
func GetModel(name string) (*Model, error) {
-	mp := ParseModelPath(name)
-	manifest, digest, err := GetManifest(mp)
+	n := model.ParseName(name)
+	mf, err := manifest.ParseNamedManifest(n)
	if err != nil {
		return nil, err
	}

-	model := &Model{
-		Name:      mp.GetFullTagname(),
-		ShortName: mp.GetShortTagname(),
-		Digest:    digest,
+	m := &Model{
+		Name:      n.String(),
+		ShortName: n.DisplayShortest(),
+		Digest:    mf.Digest(),
		Template:  template.DefaultTemplate,
	}

-	if manifest.Config.Digest != "" {
-		filename, err := GetBlobsPath(manifest.Config.Digest)
+	if mf.Config.Digest != "" {
+		filename, err := manifest.BlobsPath(mf.Config.Digest)
		if err != nil {
			return nil, err
		}
@@ -320,29 +300,29 @@ func GetModel(name string) (*Model, error) {
		}
		defer configFile.Close()

-		if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
+		if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
			return nil, err
		}
	}

-	for _, layer := range manifest.Layers {
-		filename, err := GetBlobsPath(layer.Digest)
+	for _, layer := range mf.Layers {
+		filename, err := manifest.BlobsPath(layer.Digest)
		if err != nil {
			return nil, err
		}

		switch layer.MediaType {
		case "application/vnd.ollama.image.model":
-			model.ModelPath = filename
-			model.ParentModel = layer.From
+			m.ModelPath = filename
+			m.ParentModel = layer.From
		case "application/vnd.ollama.image.embed":
			// Deprecated in versions > 0.1.2
			// TODO: remove this warning in a future version
			slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
		case "application/vnd.ollama.image.adapter":
-			model.AdapterPaths = append(model.AdapterPaths, filename)
+			m.AdapterPaths = append(m.AdapterPaths, filename)
		case "application/vnd.ollama.image.projector":
-			model.ProjectorPaths = append(model.ProjectorPaths, filename)
+			m.ProjectorPaths = append(m.ProjectorPaths, filename)
		case "application/vnd.ollama.image.prompt",
			"application/vnd.ollama.image.template":
			bts, err := os.ReadFile(filename)
@@ -350,7 +330,7 @@ func GetModel(name string) (*Model, error) {
				return nil, err
			}

-			model.Template, err = template.Parse(string(bts))
+			m.Template, err = template.Parse(string(bts))
			if err != nil {
				return nil, err
			}
@@ -360,7 +340,7 @@ func GetModel(name string) (*Model, error) {
				return nil, err
			}

-			model.System = string(bts)
+			m.System = string(bts)
		case "application/vnd.ollama.image.params":
			params, err := os.Open(filename)
			if err != nil {
@@ -369,7 +349,7 @@ func GetModel(name string) (*Model, error) {
			defer params.Close()

			// parse model options parameters into a map so that we can see which fields have been specified explicitly
-			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
+			if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.messages":
@@ -379,7 +359,7 @@ func GetModel(name string) (*Model, error) {
			}
			defer msgs.Close()

-			if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
+			if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.license":
@@ -387,11 +367,11 @@ func GetModel(name string) (*Model, error) {
			if err != nil {
				return nil, err
			}
-			model.License = append(model.License, string(bts))
+			m.License = append(m.License, string(bts))
		}
	}

-	return model, nil
+	return m, nil
}

func CopyModel(src, dst model.Name) error {
@@ -406,7 +386,7 @@ func CopyModel(src, dst model.Name) error {
		return nil
	}

-	manifests, err := GetManifestPath()
+	manifests, err := manifest.Path()
	if err != nil {
		return err
	}
@@ -435,7 +415,7 @@ func CopyModel(src, dst model.Name) error {

func deleteUnusedLayers(deleteMap map[string]struct{}) error {
	// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
-	manifests, err := Manifests(true)
+	manifests, err := manifest.Manifests(true)
	if err != nil {
		return err
	}
@@ -450,7 +430,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {

	// only delete the files which are still in the deleteMap
	for k := range deleteMap {
-		fp, err := GetBlobsPath(k)
+		fp, err := manifest.BlobsPath(k)
		if err != nil {
			slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
			continue
@@ -466,7 +446,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {

func PruneLayers() error {
	deleteMap := make(map[string]struct{})
-	p, err := GetBlobsPath("")
+	p, err := manifest.BlobsPath("")
	if err != nil {
		return err
	}
@@ -481,9 +461,9 @@ func PruneLayers() error {
		name := blob.Name()
		name = strings.ReplaceAll(name, "-", ":")

-		_, err := GetBlobsPath(name)
+		_, err := manifest.BlobsPath(name)
		if err != nil {
-			if errors.Is(err, ErrInvalidDigestFormat) {
+			if errors.Is(err, manifest.ErrInvalidDigestFormat) {
				// remove invalid blobs (e.g. partial downloads)
				if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
					slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -508,63 +488,30 @@ func PruneLayers() error {
	return nil
}

-func PruneDirectory(path string) error {
-	info, err := os.Lstat(path)
-	if err != nil {
-		return err
-	}
-
-	if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
-		entries, err := os.ReadDir(path)
-		if err != nil {
-			return err
-		}
-
-		for _, entry := range entries {
-			if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
-				return err
-			}
-		}
-
-		entries, err = os.ReadDir(path)
-		if err != nil {
-			return err
-		}
-
-		if len(entries) > 0 {
-			return nil
-		}
-
-		return os.Remove(path)
-	}
-
-	return nil
-}
-
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
+	n := model.ParseName(name)
	fn(api.ProgressResponse{Status: "retrieving manifest"})

-	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
+	if n.ProtocolScheme == "http" && !regOpts.Insecure {
		return errInsecureProtocol
	}

-	manifest, _, err := GetManifest(mp)
+	mf, err := manifest.ParseNamedManifest(n)
	if err != nil {
		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
		return err
	}

-	var layers []Layer
-	layers = append(layers, manifest.Layers...)
-	if manifest.Config.Digest != "" {
-		layers = append(layers, manifest.Config)
+	var layers []manifest.Layer
+	layers = append(layers, mf.Layers...)
+	if mf.Config.Digest != "" {
+		layers = append(layers, mf.Config)
	}

	// Use fast transfer for models with tensor layers (many small blobs)
	if hasTensorLayers(layers) {
		// Read raw manifest JSON to preserve tensor metadata fields
-		manifestPath, err := mp.GetManifestPath()
+		manifestPath, err := manifest.PathForName(n)
		if err != nil {
			return err
		}
@@ -572,7 +519,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
		if err != nil {
			return err
		}
-		if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
+		if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
			return err
		}
		fn(api.ProgressResponse{Status: "success"})
@@ -580,17 +527,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
	}

	for _, layer := range layers {
-		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
+		if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
			slog.Info(fmt.Sprintf("error uploading blob: %v", err))
			return err
		}
	}

	fn(api.ProgressResponse{Status: "pushing manifest"})
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
+	requestURL := n.BaseURL()
+	requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)

-	manifestJSON, err := json.Marshal(manifest)
+	manifestJSON, err := json.Marshal(mf)
	if err != nil {
		return err
	}
@@ -609,44 +556,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}

func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
+	n := model.ParseName(name)

	// build deleteMap to prune unused layers
	deleteMap := make(map[string]struct{})
-	manifest, _, err := GetManifest(mp)
+	existingMf, err := manifest.ParseNamedManifest(n)
	if errors.Is(err, os.ErrNotExist) {
		// noop
	} else if err != nil {
		slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
	} else {
-		for _, l := range manifest.Layers {
+		for _, l := range existingMf.Layers {
			deleteMap[l.Digest] = struct{}{}
		}
-		if manifest.Config.Digest != "" {
-			deleteMap[manifest.Config.Digest] = struct{}{}
+		if existingMf.Config.Digest != "" {
+			deleteMap[existingMf.Config.Digest] = struct{}{}
		}
	}

-	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
+	if n.ProtocolScheme == "http" && !regOpts.Insecure {
		return errInsecureProtocol
	}

	fn(api.ProgressResponse{Status: "pulling manifest"})

-	manifest, err = pullModelManifest(ctx, mp, regOpts)
+	mf, err := pullModelManifest(ctx, n, regOpts)
	if err != nil {
		return fmt.Errorf("pull model manifest: %s", err)
	}

-	var layers []Layer
-	layers = append(layers, manifest.Layers...)
-	if manifest.Config.Digest != "" {
-		layers = append(layers, manifest.Config)
+	var layers []manifest.Layer
+	layers = append(layers, mf.Layers...)
+	if mf.Config.Digest != "" {
+		layers = append(layers, mf.Config)
	}

	// Use fast transfer for models with tensor layers (many small blobs)
	if hasTensorLayers(layers) {
-		if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
+		if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
			return err
		}
		fn(api.ProgressResponse{Status: "success"})
@@ -656,7 +603,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
	skipVerify := make(map[string]bool)
	for _, layer := range layers {
		cacheHit, err := downloadBlob(ctx, downloadOpts{
-			mp:      mp,
+			n:       n,
			digest:  layer.Digest,
			regOpts: regOpts,
			fn:      fn,
@@ -675,7 +622,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
		}
		if err := verifyBlob(layer.Digest); err != nil {
			if errors.Is(err, errDigestMismatch) {
-				fp, err := GetBlobsPath(layer.Digest)
+				fp, err := manifest.BlobsPath(layer.Digest)
				if err != nil {
					return err
				}
@@ -690,16 +637,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
	for _, layer := range layers {
		delete(deleteMap, layer.Digest)
	}
-	delete(deleteMap, manifest.Config.Digest)
+	delete(deleteMap, mf.Config.Digest)

	fn(api.ProgressResponse{Status: "writing manifest"})

-	manifestJSON, err := json.Marshal(manifest)
+	manifestJSON, err := json.Marshal(mf)
	if err != nil {
		return err
	}

-	fp, err := mp.GetManifestPath()
+	fp, err := manifest.PathForName(n)
	if err != nil {
		return err
	}
@@ -726,9 +673,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}

// hasTensorLayers checks if any layer has tensor media type.
-func hasTensorLayers(layers []Layer) bool {
+func hasTensorLayers(layers []manifest.Layer) bool {
	for _, layer := range layers {
-		if layer.MediaType == MediaTypeImageTensor {
+		if layer.MediaType == manifest.MediaTypeImageTensor {
			return true
		}
	}
@@ -736,7 +683,7 @@ func hasTensorLayers(layers []Layer) bool {
}

// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
-func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
+func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
	blobs := make([]transfer.Blob, len(layers))
	for i, layer := range layers {
		blobs[i] = transfer.Blob{
@@ -745,12 +692,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
		}
	}

-	destDir, err := GetBlobsPath("")
+	destDir, err := manifest.BlobsPath("")
	if err != nil {
		return err
	}

-	base := mp.BaseURL()
+	base := n.BaseURL()
	if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
		base.Scheme = "http"
	}
@@ -775,14 +722,14 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
				Realm:   challenge.Realm,
				Service: challenge.Service,
				Scope:   challenge.Scope,
			})
		}, base.Host)
	}

	if err := transfer.Download(ctx, transfer.DownloadOptions{
		Blobs:      blobs,
		BaseURL:    baseURL,
		DestDir:    destDir,
-		Repository: mp.GetNamespaceRepository(),
+		Repository: n.DisplayNamespaceModel(),
		Progress:   progress,
		Token:      regOpts.Token,
		GetToken:   getToken,
@@ -793,12 +740,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes

	// Write manifest
	fn(api.ProgressResponse{Status: "writing manifest"})
-	manifestJSON, err := json.Marshal(manifest)
+	manifestJSON, err := json.Marshal(mf)
	if err != nil {
		return err
	}

-	fp, err := mp.GetManifestPath()
+	fp, err := manifest.PathForName(n)
	if err != nil {
		return err
	}
@@ -810,7 +757,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}

// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
-func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
+func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
	blobs := make([]transfer.Blob, len(layers))
	for i, layer := range layers {
		blobs[i] = transfer.Blob{
@@ -820,12 +767,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
		}
	}

-	srcDir, err := GetBlobsPath("")
+	srcDir, err := manifest.BlobsPath("")
	if err != nil {
		return err
	}

-	base := mp.BaseURL()
+	base := n.BaseURL()
	if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
		base.Scheme = "http"
	}
@@ -850,7 +797,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
				Realm:   challenge.Realm,
				Service: challenge.Service,
				Scope:   challenge.Scope,
			})
		}, base.Host)
	}

	return transfer.Upload(ctx, transfer.UploadOptions{
@@ -862,13 +809,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
		GetToken:    getToken,
		Logger:      slog.Default(),
		Manifest:    manifestJSON,
-		ManifestRef: mp.Tag,
-		Repository:  mp.GetNamespaceRepository(),
+		ManifestRef: n.Tag,
+		Repository:  n.DisplayNamespaceModel(),
	})
}

-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
-	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
+func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
+	requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)

	headers := make(http.Header)
	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -878,7 +825,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
	}
	defer resp.Body.Close()

-	var m Manifest
+	var m manifest.Manifest
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}
@@ -916,7 +863,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR

	// Handle authentication error with one retry
	challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
-	token, err := getAuthorizationToken(ctx, challenge)
+	token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
	if err != nil {
		return nil, err
	}
@@ -1040,7 +987,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

func verifyBlob(digest string) error {
-	fp, err := GetBlobsPath(digest)
+	fp, err := manifest.BlobsPath(digest)
	if err != nil {
		return err
	}
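Through all of the renames, PullModel keeps its pruning strategy: digests referenced by any pre-existing manifest are collected into deleteMap, every digest the fresh manifest still references is removed from the set, and the leftovers are handed to deleteUnusedLayers. The set idiom in isolation, with made-up digests:

package main

import "fmt"

func main() {
	// Digests from the old manifest start out as deletion candidates.
	deleteMap := map[string]struct{}{
		"sha256:old-layer": {},
		"sha256:shared":    {},
	}
	// Digests the new manifest still references are rescued.
	for _, digest := range []string{"sha256:shared", "sha256:new-layer"} {
		delete(deleteMap, digest)
	}
	fmt.Println(deleteMap) // map[sha256:old-layer:{}]: only orphans remain
}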
@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
				Capabilities: []string{"image"},
			},
		},
-		expectedCaps: []model.Capability{model.CapabilityImageGeneration},
+		expectedCaps: []model.Capability{model.CapabilityImage},
	},
	{
		name: "model with completion capability",
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
			checkCaps:      []model.Capability{"unknown"},
			expectedErrMsg: "unknown capability",
		},
+		{
+			name: "model missing image generation capability",
+			model: Model{
+				ModelPath: completionModelPath,
+				Template:  chatTemplate,
+			},
+			checkCaps:      []model.Capability{model.CapabilityImage},
+			expectedErrMsg: "does not support image generation",
+		},
+		{
+			name: "model with image generation capability",
+			model: Model{
+				Config: model.ConfigV2{
+					Capabilities: []string{"image"},
+				},
+			},
+			checkCaps: []model.Capability{model.CapabilityImage},
+		},
	}

	for _, tt := range tests {
@@ -13,6 +13,7 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/fs/ggml"
+	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
)
@@ -20,19 +21,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)

type layerGGML struct {
-	Layer
+	manifest.Layer
	*ggml.GGML
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
-	m, err := ParseNamedManifest(name)
+	m, err := manifest.ParseNamedManifest(name)
	switch {
	case errors.Is(err, os.ErrNotExist):
		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
			return nil, err
		}

-		m, err = ParseNamedManifest(name)
+		m, err = manifest.ParseNamedManifest(name)
		if err != nil {
			return nil, err
		}
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
	}

	for _, layer := range m.Layers {
-		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
+		layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
		if err != nil {
			return nil, err
		}
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
		case "application/vnd.ollama.image.model",
			"application/vnd.ollama.image.projector",
			"application/vnd.ollama.image.adapter":
-			blobpath, err := GetBlobsPath(layer.Digest)
+			blobpath, err := manifest.BlobsPath(layer.Digest)
			if err != nil {
				return nil, err
			}
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
		if t, err := template.Named(s); err != nil {
			slog.Debug("template detection", "error", err, "template", s)
		} else {
-			layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+			layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
			if err != nil {
				return nil, err
			}

-			layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+			layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
			layers = append(layers, &layerGGML{layer, nil})

			if t.Parameters != nil {
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
				return nil, err
			}

-			layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+			layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
			if err != nil {
				return nil, err
			}
@@ -1,146 +0,0 @@
package server

import (
	"errors"
	"fmt"
	"io/fs"
	"net/url"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/types/model"
)

type ModelPath struct {
	ProtocolScheme string
	Registry       string
	Namespace      string
	Repository     string
	Tag            string
}

const (
	DefaultRegistry       = "registry.ollama.ai"
	DefaultNamespace      = "library"
	DefaultTag            = "latest"
	DefaultProtocolScheme = "https"
)

var (
	ErrInvalidImageFormat  = errors.New("invalid image format")
	ErrInvalidDigestFormat = errors.New("invalid digest format")
	ErrInvalidProtocol     = errors.New("invalid protocol scheme")
	ErrInsecureProtocol    = errors.New("insecure protocol http")
	ErrModelPathInvalid    = errors.New("invalid model path")
)

func ParseModelPath(name string) ModelPath {
	mp := ModelPath{
		ProtocolScheme: DefaultProtocolScheme,
		Registry:       DefaultRegistry,
		Namespace:      DefaultNamespace,
		Repository:     "",
		Tag:            DefaultTag,
	}

	before, after, found := strings.Cut(name, "://")
	if found {
		mp.ProtocolScheme = before
		name = after
	}

	name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
	parts := strings.Split(name, "/")
	switch len(parts) {
	case 3:
		mp.Registry = parts[0]
		mp.Namespace = parts[1]
		mp.Repository = parts[2]
	case 2:
		mp.Namespace = parts[0]
		mp.Repository = parts[1]
	case 1:
		mp.Repository = parts[0]
	}

	if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
		mp.Repository = repo
		mp.Tag = tag
	}

	return mp
}

func (mp ModelPath) GetNamespaceRepository() string {
	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}

func (mp ModelPath) GetFullTagname() string {
	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

func (mp ModelPath) GetShortTagname() string {
	if mp.Registry == DefaultRegistry {
		if mp.Namespace == DefaultNamespace {
			return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
		}
		return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
	}
	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
	name := model.Name{
		Host:      mp.Registry,
		Namespace: mp.Namespace,
		Model:     mp.Repository,
		Tag:       mp.Tag,
	}
	if !name.IsValid() {
		return "", fs.ErrNotExist
	}
	return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}

func (mp ModelPath) BaseURL() *url.URL {
	return &url.URL{
		Scheme: mp.ProtocolScheme,
		Host:   mp.Registry,
	}
}

func GetManifestPath() (string, error) {
	path := filepath.Join(envconfig.Models(), "manifests")
	if err := os.MkdirAll(path, 0o755); err != nil {
		return "", fmt.Errorf("%w: ensure path elements are traversable", err)
	}

	return path, nil
}

func GetBlobsPath(digest string) (string, error) {
	// only accept actual sha256 digests
	pattern := "^sha256[:-][0-9a-fA-F]{64}$"
	re := regexp.MustCompile(pattern)

	if digest != "" && !re.MatchString(digest) {
		return "", ErrInvalidDigestFormat
	}

	digest = strings.ReplaceAll(digest, ":", "-")
	path := filepath.Join(envconfig.Models(), "blobs", digest)
	dirPath := filepath.Dir(path)
	if digest == "" {
		dirPath = path
	}

	if err := os.MkdirAll(dirPath, 0o755); err != nil {
		return "", fmt.Errorf("%w: ensure path elements are traversable", err)
	}

	return path, nil
}
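With this whole file deleted, the ModelPath type and its defaults are gone from package server; the call sites above now rely on model.ParseName from types/model, which presumably applies the same defaults (registry.ollama.ai, the library namespace, the latest tag). A quick way to check the equivalence, assuming ParseName behaves as the deleted tests below expect:

package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Expectation (unverified): a bare name expands to the same defaults
	// the deleted ParseModelPath used to apply.
	n := model.ParseName("gemma3")
	fmt.Println(n.String()) // expected: registry.ollama.ai/library/gemma3:latest
}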
@@ -1,153 +0,0 @@
package server

import (
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestGetBlobsPath(t *testing.T) {
	// GetBlobsPath expects an actual directory to exist
	tempDir := t.TempDir()

	tests := []struct {
		name     string
		digest   string
		expected string
		err      error
	}{
		{
			"empty digest",
			"",
			filepath.Join(tempDir, "blobs"),
			nil,
		},
		{
			"valid with colon",
			"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
			filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
			nil,
		},
		{
			"valid with dash",
			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
			filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
			nil,
		},
		{
			"digest too short",
			"sha256-45640291",
			"",
			ErrInvalidDigestFormat,
		},
		{
			"digest too long",
			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
			"",
			ErrInvalidDigestFormat,
		},
		{
			"digest invalid chars",
			"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
			"",
			ErrInvalidDigestFormat,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv("OLLAMA_MODELS", tempDir)

			got, err := GetBlobsPath(tc.digest)

			require.ErrorIs(t, tc.err, err, tc.name)
			assert.Equal(t, tc.expected, got, tc.name)
		})
	}
}

func TestParseModelPath(t *testing.T) {
	tests := []struct {
		name string
		arg  string
		want ModelPath
	}{
		{
			"full path https",
			"https://example.com/ns/repo:tag",
			ModelPath{
				ProtocolScheme: "https",
				Registry:       "example.com",
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			"full path http",
			"http://example.com/ns/repo:tag",
			ModelPath{
				ProtocolScheme: "http",
				Registry:       "example.com",
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			"no protocol",
			"example.com/ns/repo:tag",
			ModelPath{
				ProtocolScheme: "https",
				Registry:       "example.com",
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			"no registry",
			"ns/repo:tag",
			ModelPath{
				ProtocolScheme: "https",
				Registry:       DefaultRegistry,
				Namespace:      "ns",
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			"no namespace",
			"repo:tag",
			ModelPath{
				ProtocolScheme: "https",
				Registry:       DefaultRegistry,
				Namespace:      DefaultNamespace,
				Repository:     "repo",
				Tag:            "tag",
			},
		},
		{
			"no tag",
			"repo",
			ModelPath{
				ProtocolScheme: "https",
				Registry:       DefaultRegistry,
				Namespace:      DefaultNamespace,
				Repository:     "repo",
				Tag:            DefaultTag,
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := ParseModelPath(tc.arg)

			if got != tc.want {
				t.Errorf("got: %q want: %q", got, tc.want)
			}
		})
	}
}
@@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
	name := t.Name
	quantize := strings.HasSuffix(name, "weight")

-	// don't quantize vision stuff
-	quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
+	// don't quantize vision encoder tensors (named with "v." prefix)
+	quantize = quantize && !strings.HasPrefix(name, "v.")
	quantize = quantize && !strings.Contains(name, "mm.")

	// quantize only 2D and 3D tensors (experts)
@@ -219,6 +219,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
	// NOTE: can't use LLM_TN here because the layer number is not known
	quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")

+	// do not quantize LFM2's shortconv kernel weights
+	quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
+
	// do not quantize RWKV's time_mix_first tensors
	quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
	quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
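The rewritten vision filter is stricter in the right way: the old expression excluded any tensor whose name merely contained "v." (unless it also contained "_v."), which could catch unrelated names, while the new one only skips tensors whose name starts with the "v." prefix used by vision encoders. The behavioural difference in isolation; the second tensor name is illustrative:

package main

import (
	"fmt"
	"strings"
)

func main() {
	for _, name := range []string{
		"v.blk.0.attn_q.weight", // vision encoder tensor: both versions skip it
		"blk.0.attn_kv.weight",  // contains "v." but is not a vision tensor
	} {
		before := !strings.Contains(name, "v.") || strings.Contains(name, "_v.")
		after := !strings.HasPrefix(name, "v.")
		fmt.Printf("%-24s quantize(old)=%v quantize(new)=%v\n", name, before, after)
	}
	// blk.0.attn_kv.weight: old=false (wrongly excluded), new=true (quantized)
}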
208
server/routes.go
@@ -39,6 +39,7 @@ import (
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/middleware"
	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/model/renderers"
@@ -51,7 +52,7 @@ import (
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
	"github.com/ollama/ollama/x/imagegen"
-	imagegenapi "github.com/ollama/ollama/x/imagegen/api"
+	xserver "github.com/ollama/ollama/x/server"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +165,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
	return runner.llama, model, &opts, nil
}

-// ScheduleImageGenRunner schedules an image generation model runner.
-// This implements the imagegenapi.RunnerScheduler interface.
-func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
-	m := &Model{
-		Name:      modelName,
-		ShortName: modelName,
-		ModelPath: modelName, // For image gen, ModelPath is just the model name
-		Config: model.ConfigV2{
-			Capabilities: []string{"image"},
-		},
-	}
-
-	runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-runnerCh:
-	case err := <-errCh:
-		return nil, err
-	}
-
-	return runner.llama, nil
-}
-
func signinURL() (string, error) {
	pubKey, err := auth.GetPublicKey()
	if err != nil {
@@ -214,12 +192,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

-	// Check if this is a known image generation model
-	if imagegen.ResolveModelName(req.Model) != "" {
-		imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
-		return
-	}
-
	name := model.ParseName(req.Model)
	if !name.IsValid() {
		// Ideally this is "invalid model name" but we're keeping with
@@ -344,7 +316,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

-	// expire the runner
+	// expire the runner if unload is requested (empty prompt, keep alive is 0)
	if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
		s.sched.expireRunner(m)

@@ -358,6 +330,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

+	// Handle image generation models
+	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
+		s.handleImageGenerate(c, req, name.String(), checkpointStart)
+		return
+	}
+
	if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
		return
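Net effect of the two GenerateHandler hunks: image generation is no longer special-cased by model name (via imagegen.ResolveModelName) before the request is parsed; instead the request flows through the normal name validation and model load, and only then branches on the loaded model's Image capability. That appears to let image models share the standard unload path (empty prompt with keep_alive 0) and error handling with text models, since those checks now run before the capability dispatch.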
@@ -997,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
	var zero model.Name
-	existing, err := Manifests(true)
+	existing, err := manifest.Manifests(true)
	if err != nil {
		return zero, err
	}
@@ -1041,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
		return
	}

-	m, err := ParseNamedManifest(n)
+	m, err := manifest.ParseNamedManifest(n)
	if err != nil {
		switch {
		case os.IsNotExist(err):
@@ -1103,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	name := model.ParseName(req.Model)
	if !name.IsValid() {
-		return nil, ErrModelPathInvalid
+		return nil, model.Unqualified(name)
	}
	name, err := getExistingName(name)
	if err != nil {
@@ -1125,7 +1103,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	}

	// For image generation models, populate details from imagegen package
-	if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
+	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
		if info, err := imagegen.GetModelInfo(name.String()); err == nil {
			modelDetails.Family = info.Architecture
			modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
@@ -1133,6 +1111,22 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
		}
	}

+	// For safetensors LLM models (experimental), populate details from config.json
+	if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
+		if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
+			if arch, ok := info["general.architecture"].(string); ok && arch != "" {
+				modelDetails.Family = arch
+			}
+			if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
+				modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
+			}
+		}
+		// Get torch_dtype directly from config.json for quantization level
+		if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
+			modelDetails.QuantizationLevel = dtype
+		}
+	}
+
	if req.System != "" {
		m.System = req.System
	}
@@ -1142,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
	}

-	manifest, err := ParseNamedManifest(name)
+	mf, err := manifest.ParseNamedManifest(name)
	if err != nil {
		return nil, err
	}
@@ -1154,8 +1148,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
		Details:      modelDetails,
		Messages:     msgs,
		Capabilities: m.Capabilities(),
-		ModifiedAt:   manifest.fi.ModTime(),
+		ModifiedAt:   mf.FileInfo().ModTime(),
		Requires:     m.Config.Requires,
+		// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
+		// default we return an empty map.
+		ModelInfo: make(map[string]any),
	}

	if m.Config.RemoteHost != "" {
@@ -1215,7 +1212,27 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
		return resp, nil
	}

-	if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
+	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
+		// Populate tensor info if verbose
+		if req.Verbose {
+			if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
+				resp.Tensors = tensors
+			}
+		}
		return resp, nil
	}

+	// For safetensors LLM models (experimental), populate ModelInfo from config.json
+	if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
+		if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
+			resp.ModelInfo = info
+		}
+		// Populate tensor info if verbose
+		if req.Verbose {
+			if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
+				resp.Tensors = tensors
+			}
+		}
+		return resp, nil
+	}

@@ -1269,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
}

func (s *Server) ListHandler(c *gin.Context) {
-	ms, err := Manifests(true)
+	ms, err := manifest.Manifests(true)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
@@ -1300,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
			RemoteModel: cf.RemoteModel,
			RemoteHost:  cf.RemoteHost,
			Size:        m.Size(),
-			Digest:      m.digest,
-			ModifiedAt:  m.fi.ModTime(),
+			Digest:      m.Digest(),
+			ModifiedAt:  m.FileInfo().ModTime(),
			Details: api.ModelDetails{
				Format: cf.ModelFormat,
				Family: cf.ModelFamily,
@@ -1360,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}

func (s *Server) HeadBlobHandler(c *gin.Context) {
-	path, err := GetBlobsPath(c.Param("digest"))
+	path, err := manifest.BlobsPath(c.Param("digest"))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
@@ -1376,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {

func (s *Server) CreateBlobHandler(c *gin.Context) {
	if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
-		p, err := GetBlobsPath(ib)
+		p, err := manifest.BlobsPath(ib)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
@@ -1394,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
		}
	}

-	path, err := GetBlobsPath(c.Param("digest"))
+	path, err := manifest.BlobsPath(c.Param("digest"))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
@@ -1412,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
		return
	}

-	layer, err := NewLayer(c.Request.Body, "")
+	layer, err := manifest.NewLayer(c.Request.Body, "")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
@@ -1587,13 +1604,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
+	// OpenAI-compatible image generation endpoint
+	r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)

	// Inference (Anthropic compatibility)
	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)

-	// Experimental image generation support
-	imagegenapi.RegisterRoutes(r, s)
-
	if rc != nil {
		// wrap old with new
		rs := &registry.Local{
@@ -1613,7 +1629,7 @@ func Serve(ln net.Listener) error {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("server config", "env", envconfig.Values())

-	blobsDir, err := GetBlobsPath("")
+	blobsDir, err := manifest.BlobsPath("")
	if err != nil {
		return err
	}
@@ -1622,7 +1638,7 @@ func Serve(ln net.Listener) error {
	}

	if !envconfig.NoPrune() {
-		if _, err := Manifests(false); err != nil {
+		if _, err := manifest.Manifests(false); err != nil {
			slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
		} else {
			// clean up unused layers and manifests
@@ -1630,12 +1646,12 @@ func Serve(ln net.Listener) error {
				return err
			}

-			manifestsPath, err := GetManifestPath()
+			manifestsPath, err := manifest.Path()
			if err != nil {
				return err
			}

-			if err := PruneDirectory(manifestsPath); err != nil {
+			if err := manifest.PruneDirectory(manifestsPath); err != nil {
				return err
			}
		}
@@ -2460,3 +2476,91 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
	}
	return msgs
}

+// handleImageGenerate handles image generation requests within GenerateHandler.
+// This is called when the model has the Image capability.
|
||||
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
|
||||
// Validate image dimensions
|
||||
const maxDimension int32 = 4096
|
||||
if req.Width > maxDimension || req.Height > maxDimension {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule the runner for image generation
|
||||
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
// Handle load-only request (empty prompt)
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for streaming response
|
||||
c.Header("Content-Type", "application/x-ndjson")
|
||||
|
||||
// Get seed from options if provided
|
||||
var seed int64
|
||||
if s, ok := req.Options["seed"]; ok {
|
||||
switch v := s.(type) {
|
||||
case int:
|
||||
seed = int64(v)
|
||||
case int64:
|
||||
seed = v
|
||||
case float64:
|
||||
seed = int64(v)
|
||||
}
|
||||
}
|
||||
|
||||
var streamStarted bool
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: cr.Done,
|
||||
}
|
||||
|
||||
if cr.TotalSteps > 0 {
|
||||
res.Completed = int64(cr.Step)
|
||||
res.Total = int64(cr.TotalSteps)
|
||||
}
|
||||
|
||||
if cr.Image != "" {
|
||||
res.Image = cr.Image
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.Metrics.TotalDuration = time.Since(checkpointStart)
|
||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
}); err != nil {
|
||||
// Only send JSON error if streaming hasn't started yet
|
||||
// (once streaming starts, headers are committed and we can't change status code)
|
||||
if !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
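An aside for readers of handleImageGenerate above: it streams one JSON object per line (NDJSON). The sketch below is not part of this changeset; it shows how a client might consume that stream. The endpoint, model name, and JSON field names are assumptions inferred from the GenerateResponse fields used above.

// progress_stream.go — a minimal, hedged sketch of consuming the image
// generation progress stream. Field tags are assumptions, not confirmed API.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

type genChunk struct {
	Completed int64  `json:"completed"`       // assumed tag for the current step
	Total     int64  `json:"total"`           // assumed tag for total steps
	Image     string `json:"image,omitempty"` // base64 image payload when present
	Done      bool   `json:"done"`
}

func main() {
	body := []byte(`{"model": "my-image-model", "prompt": "a lighthouse at dusk", "width": 1024, "height": 1024}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each line is one JSON object, matching the application/x-ndjson header set above.
	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 0, 1<<20), 64<<20) // final chunks carrying an image can be large
	for sc.Scan() {
		var c genChunk
		if err := json.Unmarshal(sc.Bytes(), &c); err != nil {
			panic(err)
		}
		if c.Total > 0 {
			fmt.Printf("step %d/%d\n", c.Completed, c.Total)
		}
		if c.Done {
			fmt.Println("done; base64 image length:", len(c.Image))
		}
	}
}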
@@ -25,6 +25,7 @@ import (
	"github.com/ollama/ollama/convert"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
)

@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	manifest, err := ParseNamedManifest(model.ParseName("child"))
	mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
	if err != nil {
		t.Fatalf("parse manifest: %v", err)
	}
	if manifest.Config.Digest == "" {
	if mf.Config.Digest == "" {
		t.Fatalf("unexpected empty config digest for child manifest")
	}

	configPath, err := GetBlobsPath(manifest.Config.Digest)
	configPath, err := manifest.BlobsPath(mf.Config.Digest)
	if err != nil {
		t.Fatalf("config blob path: %v", err)
	}

@@ -10,6 +10,7 @@ import (
	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
)

@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
		t.Fatal(err)
	}

	config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
	config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatal(err)
	}

	// create a manifest with duplicate layers
	if err := WriteManifest(n, config, []Layer{config}); err != nil {
	if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
		t.Fatal(err)
	}

@@ -2101,3 +2101,95 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
		}
	})
}

func TestGenerateUnload(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var loadFnCalled bool

	s := Server{
		sched: &Scheduler{
			pendingReqCh:    make(chan *LlmRequest, 1),
			finishedReqCh:   make(chan *LlmRequest, 1),
			expiredCh:       make(chan *runnerRef, 1),
			unloadedCh:      make(chan any, 1),
			loaded:          make(map[string]*runnerRef),
			newServerFn:     newMockServer(&mockRunner{}),
			getGpuFn:        getGpuFn,
			getSystemInfoFn: getSystemInfoFn,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
				loadFnCalled = true
				req.successCh <- &runnerRef{llama: &mockRunner{}}
				return false
			},
		},
	}

	go s.sched.Run(t.Context())

	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []*ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "test",
		Files:  map[string]string{"file.gguf": digest},
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
		loadFnCalled = false

		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:     "test",
			Prompt:    "",
			KeepAlive: &api.Duration{Duration: 0},
			Stream:    &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var resp api.GenerateResponse
		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
			t.Fatalf("failed to unmarshal response: %v", err)
		}

		if resp.DoneReason != "unload" {
			t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
		}

		if !resp.Done {
			t.Error("expected done to be true")
		}

		if loadFnCalled {
			t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
		}
	})
}
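For context on the test above: the unload path is triggered by an empty prompt combined with keep_alive set to 0, and the server replies with done_reason "unload" without scheduling a runner. A minimal sketch of the request shape, assuming the default local endpoint and a model named "test":

// unload_example.go — a hedged sketch of the request the test exercises.
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	payload := []byte(`{"model": "test", "prompt": "", "keep_alive": 0}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // expect done: true with done_reason "unload"
}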
@@ -574,7 +574,8 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
		Options:         &req.opts,
		loading:         false,
		sessionDuration: sessionDuration,
		refCount:        1,
		totalSize:       server.TotalSize(),
		vramSize:        server.VRAMSize(),
	}

	s.loadedMu.Lock()

@@ -6,7 +6,6 @@ import (
	"errors"
	"log/slog"
	"os"
	"slices"
	"testing"
	"time"

@@ -17,7 +16,6 @@ import (
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/types/model"
)

func TestMain(m *testing.M) {

@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool                  { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }

// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
	// Model with image capability should be detected
	imageModel := &Model{
		Config: model.ConfigV2{
			Capabilities: []string{"image"},
		},
	}
	require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))

	// Model without image capability should not be detected
	langModel := &Model{
		Config: model.ConfigV2{
			Capabilities: []string{"completion"},
		},
	}
	require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))

	// Empty capabilities should not match
	emptyModel := &Model{}
	require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}

// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
// loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()

@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
	require.NotNil(t, runner)
	require.Equal(t, "/fake/image/model", runner.modelPath)
}

// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()

	s := InitScheduler(ctx)
	s.getGpuFn = getGpuFn
	s.getSystemInfoFn = getSystemInfoFn

	// Load both an imagegen runner and a language model runner
	imageGenRunner := &runnerRef{
		model:           &Model{Name: "flux", ModelPath: "/fake/flux/model"},
		modelPath:       "/fake/flux/model",
		llama:           &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
		sessionDuration: 10 * time.Millisecond,
		numParallel:     1,
		refCount:        0,
	}

	langModelRunner := &runnerRef{
		model:           &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
		modelPath:       "/fake/llama3/model",
		llama:           &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
		sessionDuration: 10 * time.Millisecond,
		numParallel:     1,
		refCount:        0,
	}

	s.loadedMu.Lock()
	s.loaded["/fake/flux/model"] = imageGenRunner
	s.loaded["/fake/llama3/model"] = langModelRunner
	s.loadedMu.Unlock()

	// Verify both are loaded
	s.loadedMu.Lock()
	require.Len(t, s.loaded, 2)
	require.NotNil(t, s.loaded["/fake/flux/model"])
	require.NotNil(t, s.loaded["/fake/llama3/model"])
	s.loadedMu.Unlock()

	// Verify updateFreeSpace accounts for both
	gpus := []ml.DeviceInfo{
		{
			DeviceID:    ml.DeviceID{Library: "Metal"},
			TotalMemory: 24 * format.GigaByte,
			FreeMemory:  24 * format.GigaByte,
		},
	}
	s.updateFreeSpace(gpus)

	// Free memory should be reduced by both models
	expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
	require.Equal(t, expectedFree, gpus[0].FreeMemory)
}

@@ -21,12 +21,14 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
)

var blobUploadManager sync.Map

type blobUpload struct {
	Layer
	manifest.Layer

	Total     int64
	Completed atomic.Int64

@@ -51,7 +53,7 @@ const (
)

func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
	p, err := GetBlobsPath(b.Digest)
	p, err := manifest.BlobsPath(b.Digest)
	if err != nil {
		return err
	}

@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
	if b.From != "" {
		values := requestURL.Query()
		values.Add("mount", b.Digest)
		values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
		values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
		requestURL.RawQuery = values.Encode()
	}

@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
	defer blobUploadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

	p, err := GetBlobsPath(b.Digest)
	p, err := manifest.BlobsPath(b.Digest)
	if err != nil {
		b.err = err
		return

@@ -279,7 +281,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
	case resp.StatusCode == http.StatusUnauthorized:
		w.Rollback()
		challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
		token, err := getAuthorizationToken(ctx, challenge)
		token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
		if err != nil {
			return err
		}

@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
	p.written = 0
}

func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
	requestURL := mp.BaseURL()
	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
	requestURL := n.BaseURL()
	requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)

	resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
	switch {

@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
	data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
	upload := data.(*blobUpload)
	if !ok {
		requestURL := mp.BaseURL()
		requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
		requestURL := n.BaseURL()
		requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
		if err := upload.Prepare(ctx, requestURL, opts); err != nil {
			blobUploadManager.Delete(layer.Digest)
			return err

@@ -9,7 +9,7 @@ const (
	CapabilityVision    = Capability("vision")
	CapabilityEmbedding = Capability("embedding")
	CapabilityThinking  = Capability("thinking")
	CapabilityImageGeneration = Capability("image")
	CapabilityImage           = Capability("image")
)

func (c Capability) String() string {

@@ -7,6 +7,7 @@ import (
	"errors"
	"fmt"
	"log/slog"
	"net/url"
	"path/filepath"
	"strings"
)

@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"

const (
	defaultHost      = "registry.ollama.ai"
	defaultNamespace = "library"
	defaultTag       = "latest"
	defaultHost           = "registry.ollama.ai"
	defaultNamespace      = "library"
	defaultTag            = "latest"
	defaultProtocolScheme = "https"
)

// DefaultName returns a name with the default values for the host, namespace,
// and tag parts. The model and digest parts are empty.
// tag, and protocol scheme parts. The model and digest parts are empty.
//
//   - The default host is ("registry.ollama.ai")
//   - The default namespace is ("library")
//   - The default tag is ("latest")
//   - The default protocol scheme is ("https")
func DefaultName() Name {
	return Name{
		Host:      defaultHost,
		Namespace: defaultNamespace,
		Tag:       defaultTag,
		Host:           defaultHost,
		Namespace:      defaultNamespace,
		Tag:            defaultTag,
		ProtocolScheme: defaultProtocolScheme,
	}
}

@@ -87,10 +91,11 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
	Host      string
	Namespace string
	Model     string
	Tag       string
	Host           string
	Namespace      string
	Model          string
	Tag            string
	ProtocolScheme string
}

// ParseName parses and assembles a Name from a name string. The

@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
	}

	scheme, host, ok := strings.Cut(s, "://")
	if !ok {
	if ok {
		n.ProtocolScheme = scheme
	} else {
		host = scheme
	}
	n.Host = host

@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
	return n
}

// Merge merges the host, namespace, and tag parts of the two names,
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
	a.Host = cmp.Or(a.Host, b.Host)
	a.Namespace = cmp.Or(a.Namespace, b.Namespace)
	a.Tag = cmp.Or(a.Tag, b.Tag)
	a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
	return a
}

@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
		strings.EqualFold(n.Tag, o.Tag)
}

// BaseURL returns the base URL for the registry.
func (n Name) BaseURL() *url.URL {
	return &url.URL{
		Scheme: n.ProtocolScheme,
		Host:   n.Host,
	}
}

// DisplayNamespaceModel returns the namespace and model joined by "/".
func (n Name) DisplayNamespaceModel() string {
	var b strings.Builder
	b.WriteString(n.Namespace)
	b.WriteByte('/')
	b.WriteString(n.Model)
	return b.String()
}

func isValidLen(kind partKind, s string) bool {
	switch kind {
	case kindHost:

@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
		{
			in: "scheme://host:port/namespace/model:tag",
			want: Name{
				Host:      "host:port",
				Namespace: "namespace",
				Model:     "model",
				Tag:       "tag",
				Host:           "host:port",
				Namespace:      "namespace",
				Model:          "model",
				Tag:            "tag",
				ProtocolScheme: "scheme",
			},
			wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
		},
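The ProtocolScheme changes above let uploadBlob build registry URLs from a Name alone, as the new BaseURL and DisplayNamespaceModel methods show. A self-contained sketch of that composition (the types here are simplified stand-ins and the digest is a placeholder, not real API):

// name_url_sketch.go — illustrates the URL shape that
// n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "blobs", digest) produces.
package main

import (
	"fmt"
	"net/url"
)

type name struct {
	ProtocolScheme, Host, Namespace, Model string
}

func (n name) baseURL() *url.URL {
	return &url.URL{Scheme: n.ProtocolScheme, Host: n.Host}
}

func main() {
	n := name{ProtocolScheme: "https", Host: "registry.ollama.ai", Namespace: "library", Model: "llama3"}
	u := n.baseURL().JoinPath("v2", n.Namespace+"/"+n.Model, "blobs", "sha256:0000")
	fmt.Println(u) // https://registry.ollama.ai/v2/library/llama3/blobs/sha256:0000
}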
282
x/create/client/create.go
Normal file
@@ -0,0 +1,282 @@
// Package client provides client-side model creation for safetensors-based models.
//
// This package is in x/ because the safetensors model storage format is under development.
// It also exists to break an import cycle: server imports x/create, so x/create
// cannot import server. This sub-package can import server because server doesn't
// import it.
package client

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"

	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/progress"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/create"
)

// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"

// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
	Template string
	System   string
	License  string
}

// CreateOptions holds all options for model creation.
type CreateOptions struct {
	ModelName string
	ModelDir  string
	Quantize  string           // "fp8" for quantization
	Modelfile *ModelfileConfig // template/system/license from Modelfile
}

// CreateModel imports a model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
func CreateModel(opts CreateOptions, p *progress.Progress) error {
	// Detect model type
	isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
	isImageGen := create.IsTensorModelDir(opts.ModelDir)

	if !isSafetensors && !isImageGen {
		return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
	}

	// Determine model type settings
	var modelType, spinnerKey string
	var capabilities []string
	if isSafetensors {
		modelType = "safetensors model"
		spinnerKey = "create"
		capabilities = []string{"completion"}
	} else {
		modelType = "image generation model"
		spinnerKey = "imagegen"
		capabilities = []string{"image"}
	}

	// Set up progress spinner
	statusMsg := "importing " + modelType
	spinner := progress.NewSpinner(statusMsg)
	p.Add(spinnerKey, spinner)

	progressFn := func(msg string) {
		spinner.Stop()
		statusMsg = msg
		spinner = progress.NewSpinner(statusMsg)
		p.Add(spinnerKey, spinner)
	}

	// Create the model using shared callbacks
	var err error
	if isSafetensors {
		err = create.CreateSafetensorsModel(
			opts.ModelName, opts.ModelDir, opts.Quantize,
			newLayerCreator(), newTensorLayerCreator(),
			newManifestWriter(opts, capabilities),
			progressFn,
		)
	} else {
		err = create.CreateImageGenModel(
			opts.ModelName, opts.ModelDir, opts.Quantize,
			newLayerCreator(), newTensorLayerCreator(),
			newManifestWriter(opts, capabilities),
			progressFn,
		)
	}

	spinner.Stop()
	if err != nil {
		return err
	}

	fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
	return nil
}

// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
	return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
		layer, err := manifest.NewLayer(r, mediaType)
		if err != nil {
			return create.LayerInfo{}, err
		}

		return create.LayerInfo{
			Digest:    layer.Digest,
			Size:      layer.Size,
			MediaType: layer.MediaType,
			Name:      name,
		}, nil
	}
}

// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
	return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
		if quantize != "" {
			return createQuantizedLayers(r, name, dtype, shape, quantize)
		}
		return createUnquantizedLayer(r, name)
	}
}

// createQuantizedLayers quantizes a tensor and returns the resulting layers.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
	if !QuantizeSupported() {
		return nil, fmt.Errorf("quantization requires MLX support")
	}

	// Quantize the tensor
	qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
	if err != nil {
		return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
	}

	// Create layer for quantized weight
	weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
	if err != nil {
		return nil, err
	}

	// Create layer for scales
	scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
	if err != nil {
		return nil, err
	}

	layers := []create.LayerInfo{
		{
			Digest:    weightLayer.Digest,
			Size:      weightLayer.Size,
			MediaType: weightLayer.MediaType,
			Name:      name,
		},
		{
			Digest:    scalesLayer.Digest,
			Size:      scalesLayer.Size,
			MediaType: scalesLayer.MediaType,
			Name:      name + "_scale",
		},
	}

	// Add qbiases layer if present (affine mode)
	if qbiasData != nil {
		qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
		if err != nil {
			return nil, err
		}
		layers = append(layers, create.LayerInfo{
			Digest:    qbiasLayer.Digest,
			Size:      qbiasLayer.Size,
			MediaType: qbiasLayer.MediaType,
			Name:      name + "_qbias",
		})
	}

	return layers, nil
}

// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
	layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
	if err != nil {
		return nil, err
	}

	return []create.LayerInfo{
		{
			Digest:    layer.Digest,
			Size:      layer.Size,
			MediaType: layer.MediaType,
			Name:      name,
		},
	}, nil
}

// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
	return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
		name := model.ParseName(modelName)
		if !name.IsValid() {
			return fmt.Errorf("invalid model name: %s", modelName)
		}

		// Create config blob with version requirement
		configData := model.ConfigV2{
			ModelFormat:  "safetensors",
			Capabilities: capabilities,
			Requires:     MinOllamaVersion,
		}
		configJSON, err := json.Marshal(configData)
		if err != nil {
			return fmt.Errorf("failed to marshal config: %w", err)
		}

		// Create config layer blob
		configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
		if err != nil {
			return fmt.Errorf("failed to create config layer: %w", err)
		}

		// Convert LayerInfo to manifest.Layer
		manifestLayers := make([]manifest.Layer, 0, len(layers))
		for _, l := range layers {
			manifestLayers = append(manifestLayers, manifest.Layer{
				MediaType: l.MediaType,
				Digest:    l.Digest,
				Size:      l.Size,
				Name:      l.Name,
			})
		}

		// Add Modelfile layers if present
		if opts.Modelfile != nil {
			modelfileLayers, err := createModelfileLayers(opts.Modelfile)
			if err != nil {
				return err
			}
			manifestLayers = append(manifestLayers, modelfileLayers...)
		}

		return manifest.WriteManifest(name, configLayer, manifestLayers)
	}
}

// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
	var layers []manifest.Layer

	if mf.Template != "" {
		layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
		if err != nil {
			return nil, fmt.Errorf("failed to create template layer: %w", err)
		}
		layers = append(layers, layer)
	}

	if mf.System != "" {
		layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
		if err != nil {
			return nil, fmt.Errorf("failed to create system layer: %w", err)
		}
		layers = append(layers, layer)
	}

	if mf.License != "" {
		layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
		if err != nil {
			return nil, fmt.Errorf("failed to create license layer: %w", err)
		}
		layers = append(layers, layer)
	}

	return layers, nil
}
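A hedged sketch of how a caller might drive client.CreateModel from the new package above. The directory path is a placeholder, and progress.NewProgress is assumed to match its use elsewhere in the CLI:

// create_usage.go — a minimal sketch, not a definitive integration.
package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/progress"
	"github.com/ollama/ollama/x/create/client"
)

func main() {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	err := client.CreateModel(client.CreateOptions{
		ModelName: "my-safetensors-model",
		ModelDir:  "/path/to/hf-model", // placeholder; needs config.json + *.safetensors
		Quantize:  "fp8",               // optional; requires an MLX-enabled build
	}, p)
	if err != nil {
		log.Fatal(err)
	}
}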
146
x/create/client/create_test.go
Normal file
@@ -0,0 +1,146 @@
package client

import (
	"testing"
)

func TestModelfileConfig(t *testing.T) {
	// Test that ModelfileConfig struct works as expected
	config := &ModelfileConfig{
		Template: "{{ .Prompt }}",
		System:   "You are a helpful assistant.",
		License:  "MIT",
	}

	if config.Template != "{{ .Prompt }}" {
		t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
	}
	if config.System != "You are a helpful assistant." {
		t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
	}
	if config.License != "MIT" {
		t.Errorf("License = %q, want %q", config.License, "MIT")
	}
}

func TestModelfileConfig_Empty(t *testing.T) {
	config := &ModelfileConfig{}

	if config.Template != "" {
		t.Errorf("Template should be empty, got %q", config.Template)
	}
	if config.System != "" {
		t.Errorf("System should be empty, got %q", config.System)
	}
	if config.License != "" {
		t.Errorf("License should be empty, got %q", config.License)
	}
}

func TestModelfileConfig_PartialFields(t *testing.T) {
	// Test config with only some fields set
	config := &ModelfileConfig{
		Template: "{{ .Prompt }}",
		// System and License intentionally empty
	}

	if config.Template == "" {
		t.Error("Template should not be empty")
	}
	if config.System != "" {
		t.Error("System should be empty")
	}
	if config.License != "" {
		t.Error("License should be empty")
	}
}

func TestMinOllamaVersion(t *testing.T) {
	// Verify the minimum version constant is set
	if MinOllamaVersion == "" {
		t.Error("MinOllamaVersion should not be empty")
	}
	if MinOllamaVersion != "0.14.0" {
		t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
	}
}

func TestCreateModel_InvalidDir(t *testing.T) {
	// Test that CreateModel returns error for invalid directory
	err := CreateModel(CreateOptions{
		ModelName: "test-model",
		ModelDir:  "/nonexistent/path",
	}, nil)
	if err == nil {
		t.Error("expected error for nonexistent directory, got nil")
	}
}

func TestCreateModel_NotSafetensorsDir(t *testing.T) {
	// Test that CreateModel returns error for directory without safetensors
	dir := t.TempDir()

	err := CreateModel(CreateOptions{
		ModelName: "test-model",
		ModelDir:  dir,
	}, nil)
	if err == nil {
		t.Error("expected error for empty directory, got nil")
	}
}

func TestCreateOptions(t *testing.T) {
	opts := CreateOptions{
		ModelName: "my-model",
		ModelDir:  "/path/to/model",
		Quantize:  "fp8",
		Modelfile: &ModelfileConfig{
			Template: "test",
			System:   "system",
			License:  "MIT",
		},
	}

	if opts.ModelName != "my-model" {
		t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
	}
	if opts.ModelDir != "/path/to/model" {
		t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
	}
	if opts.Quantize != "fp8" {
		t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
	}
	if opts.Modelfile == nil {
		t.Error("Modelfile should not be nil")
	}
	if opts.Modelfile.Template != "test" {
		t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
	}
}

func TestCreateOptions_Defaults(t *testing.T) {
	opts := CreateOptions{
		ModelName: "test",
		ModelDir:  "/tmp",
	}

	// Quantize should default to empty
	if opts.Quantize != "" {
		t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
	}

	// Modelfile should default to nil
	if opts.Modelfile != nil {
		t.Error("Modelfile should be nil by default")
	}
}

func TestQuantizeSupported(t *testing.T) {
	// This just verifies the function exists and returns a boolean
	// The actual value depends on build tags (mlx vs non-mlx)
	supported := QuantizeSupported()

	// In non-mlx builds, this should be false
	// We can't easily test both cases, so just verify it returns something
	_ = supported
}

@@ -11,10 +11,11 @@ import (
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	tmpDir := ensureTempDir()

	// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)

@@ -50,9 +51,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData
		mlx.Eval(arr)
	}

	// Quantize with affine mode: group_size=32, bits=8
	// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
	qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
	// Quantize based on quantization type
	var qweight, scales, qbiases *mlx.Array
	switch quantize {
	case "fp4":
		// affine mode: group_size=32, bits=4
		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
	case "fp8":
		// affine mode: group_size=32, bits=8
		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
	default:
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
	}

	// Eval and make contiguous for data access
	qweight = mlx.Contiguous(qweight)

@@ -8,7 +8,7 @@ import (
)

// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
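The mlx.Quantize calls above use affine group quantization with group_size=32. A pure-Go illustration of the per-group arithmetic: each group of weights gets one scale and one bias so that w ≈ q*scale + bias with q in [0, 2^bits-1]. This mirrors the math only, not MLX's packed storage layout, and it assumes the group is not constant (scale > 0):

// affine_sketch.go — a minimal, hedged sketch of affine group quantization.
package main

import (
	"fmt"
	"math"
)

// quantizeGroup maps one group of weights onto bits-wide integer levels.
func quantizeGroup(w []float64, bits int) (q []int, scale, bias float64) {
	lo, hi := w[0], w[0]
	for _, v := range w {
		lo, hi = math.Min(lo, v), math.Max(hi, v)
	}
	levels := float64(1<<bits - 1) // 255 for fp8-style affine 8-bit, 15 for 4-bit
	scale = (hi - lo) / levels
	bias = lo
	q = make([]int, len(w))
	for i, v := range w {
		q[i] = int(math.Round((v - bias) / scale))
	}
	return q, scale, bias
}

func main() {
	group := []float64{-0.50, -0.25, 0.00, 0.25, 0.50}
	q, scale, bias := quantizeGroup(group, 8)
	fmt.Println(q, scale, bias)
	// Dequantize: w ≈ q*scale + bias
	fmt.Println(float64(q[1])*scale + bias) // ≈ -0.25
}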
399
x/create/create.go
Normal file
@@ -0,0 +1,399 @@
package create

import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
	ModelFormat  string   `json:"model_format"`
	Capabilities []string `json:"capabilities"`
}

// Manifest represents the manifest JSON structure.
type Manifest struct {
	SchemaVersion int             `json:"schemaVersion"`
	MediaType     string          `json:"mediaType"`
	Config        ManifestLayer   `json:"config"`
	Layers        []ManifestLayer `json:"layers"`
}

// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
	Name      string `json:"name,omitempty"`
}

// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
	return filepath.Join(envconfig.Models(), "manifests")
}

// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
	return filepath.Join(envconfig.Models(), "blobs")
}

// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
	host := "registry.ollama.ai"
	namespace := "library"
	name := modelName
	tag := "latest"

	if idx := strings.LastIndex(name, ":"); idx != -1 {
		tag = name[idx+1:]
		name = name[:idx]
	}

	parts := strings.Split(name, "/")
	switch len(parts) {
	case 3:
		host = parts[0]
		namespace = parts[1]
		name = parts[2]
	case 2:
		namespace = parts[0]
		name = parts[1]
	}

	return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}

// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
	manifestPath := resolveManifestPath(modelName)

	data, err := os.ReadFile(manifestPath)
	if err != nil {
		return nil, err
	}

	var manifest Manifest
	if err := json.Unmarshal(data, &manifest); err != nil {
		return nil, err
	}

	return &manifest, nil
}

// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
	manifest, err := loadManifest(modelName)
	if err != nil {
		return nil, err
	}

	// Read the config blob
	blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
	blobPath := filepath.Join(defaultBlobDir(), blobName)

	data, err := os.ReadFile(blobPath)
	if err != nil {
		return nil, err
	}

	var config ModelConfig
	if err := json.Unmarshal(data, &config); err != nil {
		return nil, err
	}

	return &config, nil
}

// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
	config, err := loadModelConfig(modelName)
	if err != nil {
		return false
	}
	return config.ModelFormat == "safetensors"
}

// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
	config, err := loadModelConfig(modelName)
	if err != nil {
		return false
	}
	return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}

// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
	config, err := loadModelConfig(modelName)
	if err != nil {
		return false
	}
	return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}

// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
	manifest, err := loadManifest(modelName)
	if err != nil {
		return "", err
	}

	// Find the config.json layer
	for _, layer := range manifest.Layers {
		if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
			blobName := strings.Replace(layer.Digest, ":", "-", 1)
			blobPath := filepath.Join(defaultBlobDir(), blobName)

			data, err := os.ReadFile(blobPath)
			if err != nil {
				return "", err
			}

			var cfg struct {
				Architectures []string `json:"architectures"`
				ModelType     string   `json:"model_type"`
			}
			if err := json.Unmarshal(data, &cfg); err != nil {
				return "", err
			}

			// Prefer model_type, fall back to first architecture
			if cfg.ModelType != "" {
				return cfg.ModelType, nil
			}
			if len(cfg.Architectures) > 0 {
				return cfg.Architectures[0], nil
			}
		}
	}

	return "", fmt.Errorf("architecture not found in model config")
}

// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
	_, err := os.Stat(filepath.Join(dir, "model_index.json"))
	return err == nil
}

// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
	// Must have config.json
	if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
		return false
	}

	// Must have at least one .safetensors file
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}

	for _, entry := range entries {
		if strings.HasSuffix(entry.Name(), ".safetensors") {
			return true
		}
	}

	return false
}

// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
	Digest    string
	Size      int64
	MediaType string
	Name      string // Path-style name: "component/tensor" or "path/to/config.json"
}

// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)

// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)

// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)

// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error

// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
	// Image gen specific: skip VAE entirely
	if component == "vae" {
		return false
	}

	// Skip embeddings
	if strings.Contains(name, "embed") {
		return false
	}

	// Skip layer norms and RMS norms
	if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
		return false
	}

	// Skip biases
	if strings.HasSuffix(name, ".bias") {
		return false
	}

	// Only quantize weights
	return strings.HasSuffix(name, ".weight")
}

// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
	// Use basic name-based check first
	if !ShouldQuantize(name, "") {
		return false
	}

	// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
	if len(shape) != 2 {
		return false
	}

	// Skip small tensors (less than 1024 elements) - not worth quantizing
	if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
		return false
	}

	// MLX quantization requires last dimension to be divisible by group size (32)
	if shape[len(shape)-1]%32 != 0 {
		return false
	}

	return true
}

// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
	var layers []LayerInfo
	var configLayer LayerInfo

	entries, err := os.ReadDir(modelDir)
	if err != nil {
		return fmt.Errorf("failed to read directory: %w", err)
	}

	// Process all safetensors files
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
			continue
		}

		stPath := filepath.Join(modelDir, entry.Name())

		// Extract individual tensors from safetensors file
		extractor, err := safetensors.OpenForExtraction(stPath)
		if err != nil {
			return fmt.Errorf("failed to open %s: %w", stPath, err)
		}

		tensorNames := extractor.ListTensors()
		quantizeMsg := ""
		if quantize != "" {
			quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
		}
		fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))

		for _, tensorName := range tensorNames {
			td, err := extractor.GetTensor(tensorName)
			if err != nil {
				extractor.Close()
				return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
			}

			// Determine quantization type for this tensor (empty string if not quantizing)
			quantizeType := ""
			if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
				quantizeType = quantize
			}

			// Store as minimal safetensors format (88 bytes header overhead)
			// This enables native mmap loading via mlx_load_safetensors
			// createTensorLayer returns multiple layers if quantizing (weight + scales)
			newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
			if err != nil {
				extractor.Close()
				return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
			}
			layers = append(layers, newLayers...)
		}

		extractor.Close()
	}

	// Process all JSON config files
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
			continue
		}

		// Skip the index file as we don't need it after extraction
		if entry.Name() == "model.safetensors.index.json" {
			continue
		}

		cfgPath := entry.Name()
		fullPath := filepath.Join(modelDir, cfgPath)

		fn(fmt.Sprintf("importing config %s", cfgPath))

		f, err := os.Open(fullPath)
		if err != nil {
			return fmt.Errorf("failed to open %s: %w", cfgPath, err)
		}

		layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
		f.Close()
		if err != nil {
			return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
		}

		// Use config.json as the config layer
		if cfgPath == "config.json" {
			configLayer = layer
		}

		layers = append(layers, layer)
	}

	if configLayer.Digest == "" {
		return fmt.Errorf("config.json not found in %s", modelDir)
	}

	fn(fmt.Sprintf("writing manifest for %s", modelName))

	if err := writeManifest(modelName, configLayer, layers); err != nil {
		return fmt.Errorf("failed to write manifest: %w", err)
	}

	fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
	return nil
}
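To make the quantization eligibility rules above concrete, a short sketch exercising ShouldQuantizeTensor on a few representative tensors; the names and shapes are illustrative only:

// should_quantize_example.go — demonstrates the decision rules from create.go.
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/create"
)

func main() {
	fmt.Println(create.ShouldQuantizeTensor("model.layers.0.self_attn.q_proj.weight", []int32{4096, 4096})) // true: 2D linear weight, last dim divisible by 32
	fmt.Println(create.ShouldQuantizeTensor("model.embed_tokens.weight", []int32{32000, 4096}))             // false: embeddings are skipped
	fmt.Println(create.ShouldQuantizeTensor("model.norm.weight", []int32{4096}))                            // false: norms and 1D tensors are skipped
	fmt.Println(create.ShouldQuantizeTensor("model.layers.0.mlp.up_proj.weight", []int32{11008, 4100}))     // false: 4100 % 32 != 0
}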
752
x/create/create_test.go
Normal file
752
x/create/create_test.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package create

import (
    "bytes"
    "encoding/binary"
    "encoding/json"
    "io"
    "os"
    "path/filepath"
    "strings"
    "testing"
)

func TestIsTensorModelDir(t *testing.T) {
    tests := []struct {
        name     string
        setup    func(dir string) error
        expected bool
    }{
        {
            name: "valid diffusers model with model_index.json",
            setup: func(dir string) error {
                return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
            },
            expected: true,
        },
        {
            name: "empty directory",
            setup: func(dir string) error {
                return nil
            },
            expected: false,
        },
        {
            name: "directory with other files but no model_index.json",
            setup: func(dir string) error {
                return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
            },
            expected: false,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            dir := t.TempDir()
            if err := tt.setup(dir); err != nil {
                t.Fatalf("setup failed: %v", err)
            }

            got := IsTensorModelDir(dir)
            if got != tt.expected {
                t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
            }
        })
    }
}

func TestIsSafetensorsModelDir(t *testing.T) {
    tests := []struct {
        name     string
        setup    func(dir string) error
        expected bool
    }{
        {
            name: "valid safetensors model with config.json and .safetensors file",
            setup: func(dir string) error {
                if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
                    return err
                }
                return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
            },
            expected: true,
        },
        {
            name: "config.json only, no safetensors files",
            setup: func(dir string) error {
                return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
            },
            expected: false,
        },
        {
            name: "safetensors file only, no config.json",
            setup: func(dir string) error {
                return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
            },
            expected: false,
        },
        {
            name: "empty directory",
            setup: func(dir string) error {
                return nil
            },
            expected: false,
        },
        {
            name: "multiple safetensors files with config.json",
            setup: func(dir string) error {
                if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
                    return err
                }
                if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
                    return err
                }
                return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
            },
            expected: true,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            dir := t.TempDir()
            if err := tt.setup(dir); err != nil {
                t.Fatalf("setup failed: %v", err)
            }

            got := IsSafetensorsModelDir(dir)
            if got != tt.expected {
                t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
            }
        })
    }
}

func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
    got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
    if got != false {
        t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
    }
}

// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
    t.Helper()

    // Create a minimal safetensors file with a single float32 tensor
    header := map[string]interface{}{
        "test_tensor": map[string]interface{}{
            "dtype":        "F32",
            "shape":        []int{2, 2},
            "data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
        },
    }
    headerJSON, err := json.Marshal(header)
    if err != nil {
        t.Fatalf("failed to marshal header: %v", err)
    }

    // Pad header to 8-byte alignment
    padding := (8 - len(headerJSON)%8) % 8
    headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

    // Write file
    f, err := os.Create(path)
    if err != nil {
        t.Fatalf("failed to create file: %v", err)
    }
    defer f.Close()

    // Write header size (8 bytes, little endian)
    if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
        t.Fatalf("failed to write header size: %v", err)
    }

    // Write header
    if _, err := f.Write(headerJSON); err != nil {
        t.Fatalf("failed to write header: %v", err)
    }

    // Write tensor data (16 bytes of zeros for 4 float32 values)
    if _, err := f.Write(make([]byte, 16)); err != nil {
        t.Fatalf("failed to write tensor data: %v", err)
    }
}

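// Illustrative sketch (added for this review, not part of the original test
// file): reading back the layout createMinimalSafetensors writes above — an
// 8-byte little-endian header length, a JSON header, then raw tensor bytes.
func readSafetensorsHeader(t *testing.T, path string) map[string]interface{} {
    t.Helper()

    f, err := os.Open(path)
    if err != nil {
        t.Fatalf("failed to open %s: %v", path, err)
    }
    defer f.Close()

    // First 8 bytes: header length, little endian.
    var headerLen uint64
    if err := binary.Read(f, binary.LittleEndian, &headerLen); err != nil {
        t.Fatalf("failed to read header size: %v", err)
    }

    // Next headerLen bytes: the JSON header (space-padded to 8-byte alignment).
    headerJSON := make([]byte, headerLen)
    if _, err := io.ReadFull(f, headerJSON); err != nil {
        t.Fatalf("failed to read header: %v", err)
    }

    var header map[string]interface{}
    if err := json.Unmarshal(headerJSON, &header); err != nil {
        t.Fatalf("failed to unmarshal header: %v", err)
    }
    return header
}
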
func TestCreateSafetensorsModel(t *testing.T) {
    dir := t.TempDir()

    // Create config.json
    configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    // Create a minimal safetensors file
    createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

    // Track what was created
    var createdLayers []LayerInfo
    var manifestWritten bool
    var manifestModelName string
    var manifestConfigLayer LayerInfo
    var manifestLayers []LayerInfo
    var statusMessages []string

    // Mock callbacks
    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return LayerInfo{}, err
        }
        layer := LayerInfo{
            Digest:    "sha256:test",
            Size:      int64(len(data)),
            MediaType: mediaType,
            Name:      name,
        }
        createdLayers = append(createdLayers, layer)
        return layer, nil
    }

    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return nil, err
        }
        layer := LayerInfo{
            Digest:    "sha256:tensor_" + name,
            Size:      int64(len(data)),
            MediaType: "application/vnd.ollama.image.tensor",
            Name:      name,
        }
        createdLayers = append(createdLayers, layer)
        return []LayerInfo{layer}, nil
    }

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        manifestWritten = true
        manifestModelName = modelName
        manifestConfigLayer = config
        manifestLayers = layers
        return nil
    }

    progressFn := func(status string) {
        statusMessages = append(statusMessages, status)
    }

    // Run CreateSafetensorsModel
    err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
    if err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    // Verify manifest was written
    if !manifestWritten {
        t.Error("manifest was not written")
    }

    if manifestModelName != "test-model" {
        t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
    }

    // Verify config layer was set
    if manifestConfigLayer.Name != "config.json" {
        t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
    }

    // Verify we have at least one tensor and one config layer
    hasTensor := false
    hasConfig := false
    for _, layer := range manifestLayers {
        if layer.Name == "test_tensor" {
            hasTensor = true
        }
        if layer.Name == "config.json" {
            hasConfig = true
        }
    }

    if !hasTensor {
        t.Error("no tensor layer found in manifest")
    }
    if !hasConfig {
        t.Error("no config layer found in manifest")
    }

    // Verify status messages were sent
    if len(statusMessages) == 0 {
        t.Error("no status messages received")
    }
}

func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
    dir := t.TempDir()

    // Create only a safetensors file, no config.json
    createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

    // Mock callbacks (minimal)
    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        io.ReadAll(r)
        return LayerInfo{Name: name}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        io.ReadAll(r)
        return []LayerInfo{{Name: name}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        return nil
    }
    progressFn := func(status string) {}

    err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
    if err == nil {
        t.Error("expected error for missing config.json, got nil")
    }
}

func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
    dir := t.TempDir()

    // Mock callbacks
    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        return LayerInfo{}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        return []LayerInfo{{}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        return nil
    }
    progressFn := func(status string) {}

    err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
    if err == nil {
        t.Error("expected error for empty directory, got nil")
    }
}

func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
    dir := t.TempDir()

    // Create config.json
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    // Create model.safetensors.index.json (should be skipped)
    indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
    if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
        t.Fatalf("failed to write index.json: %v", err)
    }

    // Create a minimal safetensors file
    createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

    var configNames []string

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        io.ReadAll(r)
        configNames = append(configNames, name)
        return LayerInfo{Name: name, Digest: "sha256:test"}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        io.ReadAll(r)
        return []LayerInfo{{Name: name}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        return nil
    }
    progressFn := func(status string) {}

    err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
    if err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    // Verify model.safetensors.index.json was not included
    for _, name := range configNames {
        if name == "model.safetensors.index.json" {
            t.Error("model.safetensors.index.json should have been skipped")
        }
    }
}

func TestResolveManifestPath(t *testing.T) {
    tests := []struct {
        name      string
        modelName string
        wantParts []string // Parts that should appear in the path
    }{
        {
            name:      "simple model name",
            modelName: "llama2",
            wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
        },
        {
            name:      "model name with tag",
            modelName: "llama2:7b",
            wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
        },
        {
            name:      "model name with namespace",
            modelName: "myuser/mymodel",
            wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
        },
        {
            name:      "model name with namespace and tag",
            modelName: "myuser/mymodel:v1",
            wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
        },
        {
            name:      "fully qualified model name",
            modelName: "registry.example.com/namespace/model:tag",
            wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := resolveManifestPath(tt.modelName)

            for _, part := range tt.wantParts {
                if !strings.Contains(got, part) {
                    t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
                }
            }
        })
    }
}

func TestLayerInfo(t *testing.T) {
    layer := LayerInfo{
        Digest:    "sha256:abc123",
        Size:      1024,
        MediaType: "application/vnd.ollama.image.tensor",
        Name:      "model.weight",
    }

    if layer.Digest != "sha256:abc123" {
        t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
    }
    if layer.Size != 1024 {
        t.Errorf("Size = %d, want %d", layer.Size, 1024)
    }
    if layer.MediaType != "application/vnd.ollama.image.tensor" {
        t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
    }
    if layer.Name != "model.weight" {
        t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
    }
}

func TestModelConfig(t *testing.T) {
    config := ModelConfig{
        ModelFormat:  "safetensors",
        Capabilities: []string{"completion", "chat"},
    }

    if config.ModelFormat != "safetensors" {
        t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
    }
    if len(config.Capabilities) != 2 {
        t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
    }
}

func TestManifest(t *testing.T) {
    manifest := Manifest{
        SchemaVersion: 2,
        MediaType:     "application/vnd.oci.image.manifest.v1+json",
        Config: ManifestLayer{
            MediaType: "application/vnd.docker.container.image.v1+json",
            Digest:    "sha256:config",
            Size:      100,
        },
        Layers: []ManifestLayer{
            {
                MediaType: "application/vnd.ollama.image.tensor",
                Digest:    "sha256:layer1",
                Size:      1000,
                Name:      "weight.bin",
            },
        },
    }

    if manifest.SchemaVersion != 2 {
        t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
    }
    if manifest.Config.Digest != "sha256:config" {
        t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
    }
    if len(manifest.Layers) != 1 {
        t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
    }
    if manifest.Layers[0].Name != "weight.bin" {
        t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
    }
}

func TestShouldQuantize(t *testing.T) {
    tests := []struct {
        name      string
        tensor    string
        component string
        want      bool
    }{
        // VAE component should never be quantized
        {"vae weight", "decoder.weight", "vae", false},
        {"vae bias", "decoder.bias", "vae", false},

        // Embeddings should not be quantized
        {"embedding weight", "embed_tokens.weight", "", false},
        {"embedding in name", "token_embedding.weight", "", false},

        // Norms should not be quantized
        {"layer norm", "layer_norm.weight", "", false},
        {"rms norm", "rms_norm.weight", "", false},
        {"ln prefix", "ln_1.weight", "", false},
        {"layernorm in name", "input_layernorm.weight", "", false},

        // Biases should not be quantized
        {"bias tensor", "attention.bias", "", false},
        {"proj bias", "o_proj.bias", "", false},

        // Linear weights should be quantized
        {"linear weight", "q_proj.weight", "", true},
        {"attention weight", "self_attn.weight", "", true},
        {"mlp weight", "mlp.gate_proj.weight", "", true},

        // Transformer component weights should be quantized
        {"transformer weight", "layers.0.weight", "transformer", true},
        {"text_encoder weight", "encoder.weight", "text_encoder", true},
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := ShouldQuantize(tt.tensor, tt.component)
            if got != tt.want {
                t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
            }
        })
    }
}

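// Illustrative sketch (not the implementation under test in create.go): a
// predicate consistent with the table above, assuming plain substring and
// suffix matching on tensor names.
func shouldQuantizeSketch(tensor, component string) bool {
    if component == "vae" {
        return false // VAE weights stay unquantized
    }
    lower := strings.ToLower(tensor)
    if strings.HasSuffix(lower, ".bias") {
        return false // biases stay unquantized
    }
    for _, skip := range []string{"embed", "norm", "ln_"} {
        if strings.Contains(lower, skip) {
            return false // embeddings and norms stay unquantized
        }
    }
    return strings.HasSuffix(lower, ".weight") // linear weights are quantized
}
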
func TestShouldQuantizeTensor(t *testing.T) {
    tests := []struct {
        name   string
        tensor string
        shape  []int32
        want   bool
    }{
        // 2D tensors with sufficient size should be quantized
        {"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
        {"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},

        // Small tensors should not be quantized (< 1024 elements)
        {"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
        {"small 2D weight", "small.weight", []int32{31, 31}, false},

        // 1D tensors should not be quantized
        {"1D tensor", "layer_norm.weight", []int32{4096}, false},

        // 3D+ tensors should not be quantized
        {"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
        {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},

        // Embeddings should not be quantized regardless of shape
        {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},

        // Norms should not be quantized regardless of shape
        {"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},

        // Biases should not be quantized
        {"bias 2D", "proj.bias", []int32{4096, 1}, false},
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := ShouldQuantizeTensor(tt.tensor, tt.shape)
            if got != tt.want {
                t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
            }
        })
    }
}

func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
    dir := t.TempDir()

    // Create config.json
    configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    // Create a minimal safetensors file
    createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))

    var quantizeRequested []string

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        io.ReadAll(r)
        return LayerInfo{Name: name, Digest: "sha256:test"}, nil
    }

    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        io.ReadAll(r)
        quantizeRequested = append(quantizeRequested, quantize)
        return []LayerInfo{{Name: name}}, nil
    }

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        return nil
    }

    progressFn := func(status string) {}

    // Run with quantize enabled
    err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
    if err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    // Verify quantize was passed to callback (will be false for small test tensor)
    if len(quantizeRequested) == 0 {
        t.Error("no tensors processed")
    }
}

// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
    t.Helper()

    // Create model_index.json
    modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
    if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
        t.Fatalf("failed to write model_index.json: %v", err)
    }

    // Create transformer directory with a safetensors file
    transformerDir := filepath.Join(dir, "transformer")
    if err := os.MkdirAll(transformerDir, 0o755); err != nil {
        t.Fatalf("failed to create transformer dir: %v", err)
    }
    createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))

    // Create transformer config
    transformerConfig := `{"hidden_size": 3072}`
    if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
        t.Fatalf("failed to write transformer config: %v", err)
    }
}

func TestCreateImageGenModel(t *testing.T) {
    dir := t.TempDir()
    createMinimalImageGenModel(t, dir)

    var manifestWritten bool
    var manifestModelName string
    var statusMessages []string

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        io.ReadAll(r)
        return LayerInfo{Name: name, Digest: "sha256:test"}, nil
    }

    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        io.ReadAll(r)
        return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
    }

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        manifestWritten = true
        manifestModelName = modelName
        return nil
    }

    progressFn := func(status string) {
        statusMessages = append(statusMessages, status)
    }

    err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
    if err != nil {
        t.Fatalf("CreateImageGenModel failed: %v", err)
    }

    if !manifestWritten {
        t.Error("manifest was not written")
    }

    if manifestModelName != "test-imagegen" {
        t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
    }

    if len(statusMessages) == 0 {
        t.Error("no status messages received")
    }
}

func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
    dir := t.TempDir()

    // Create only transformer without model_index.json
    transformerDir := filepath.Join(dir, "transformer")
    if err := os.MkdirAll(transformerDir, 0o755); err != nil {
        t.Fatalf("failed to create transformer dir: %v", err)
    }
    createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        io.ReadAll(r)
        return LayerInfo{Name: name}, nil
    }
    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        io.ReadAll(r)
        return []LayerInfo{{Name: name}}, nil
    }
    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        return nil
    }
    progressFn := func(status string) {}

    err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
    if err == nil {
        t.Error("expected error for missing model_index.json, got nil")
    }
}

func TestCreateImageGenModel_WithQuantize(t *testing.T) {
    dir := t.TempDir()
    createMinimalImageGenModel(t, dir)

    var quantizeRequested []string

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        io.ReadAll(r)
        return LayerInfo{Name: name, Digest: "sha256:test"}, nil
    }

    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        io.ReadAll(r)
        quantizeRequested = append(quantizeRequested, quantize)
        return []LayerInfo{{Name: name}}, nil
    }

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
        return nil
    }

    progressFn := func(status string) {}

    err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
    if err != nil {
        t.Fatalf("CreateImageGenModel failed: %v", err)
    }

    if len(quantizeRequested) == 0 {
        t.Error("no tensors processed")
    }
}
@@ -1,4 +1,4 @@
package imagegen
package create

import (
    "bytes"
@@ -12,40 +12,24 @@ import (
    "github.com/ollama/ollama/x/imagegen/safetensors"
)

// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
    _, err := os.Stat(filepath.Join(dir, "model_index.json"))
    return err == nil
}

// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
    Digest    string
    Size      int64
    MediaType string
    Name      string // Path-style name: "component/tensor" or "path/to/config.json"
}

// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)

// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)

// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error

// CreateModel imports an image generation model from a directory.
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
    // Validate quantization type
    switch quantize {
    case "", "fp4", "fp8":
        // valid
    default:
        return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
    }

    var layers []LayerInfo
    var configLayer LayerInfo
    var totalParams int64 // Count parameters from original tensor shapes
    var torchDtype string // Read from component config for quantization display

    // Components to process - extract individual tensors from each
    components := []string{"text_encoder", "transformer", "vae"}
@@ -77,8 +61,8 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,

    tensorNames := extractor.ListTensors()
    quantizeMsg := ""
    if quantize == "fp8" && component != "vae" {
        quantizeMsg = ", quantizing to fp8"
    if quantize != "" && component != "vae" {
        quantizeMsg = ", quantizing to " + quantize
    }
    fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))

@@ -103,11 +87,14 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
    // Use path-style name: "component/tensor_name"
    fullName := component + "/" + tensorName

    // Determine if this tensor should be quantized
    doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
    // Determine quantization type for this tensor (empty string if not quantizing)
    quantizeType := ""
    if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
        quantizeType = quantize
    }

    // createTensorLayer returns multiple layers if quantizing (weight + scales)
    newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
    newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
    if err != nil {
        extractor.Close()
        return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
@@ -119,6 +106,19 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
        }
    }

    // Read torch_dtype from text_encoder config for quantization display
    if torchDtype == "" {
        textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
        if data, err := os.ReadFile(textEncoderConfig); err == nil {
            var cfg struct {
                TorchDtype string `json:"torch_dtype"`
            }
            if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
                torchDtype = cfg.TorchDtype
            }
        }
    }

    // Import config files
    configFiles := []string{
        "model_index.json",
@@ -164,11 +164,11 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
    // Add parameter count (counted from tensor shapes during import)
    cfg["parameter_count"] = totalParams

    // Add quantization info
    if quantize == "fp8" {
        cfg["quantization"] = "FP8"
    // Add quantization info - use quantize type if set, otherwise torch_dtype
    if quantize != "" {
        cfg["quantization"] = strings.ToUpper(quantize)
    } else {
        cfg["quantization"] = "BF16"
        cfg["quantization"] = torchDtype
    }

    data, err = json.MarshalIndent(cfg, "", " ")
@@ -211,3 +211,12 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
    fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
    return nil
}

// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
    if len(shape) < 2 {
        return false
    }
    return shape[len(shape)-1]%32 == 0
}
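Worked examples of the shape gate above (the results follow directly from the modulo check):

    canQuantizeShape([]int32{4096, 4096}) // true:  4096 % 32 == 0
    canQuantizeShape([]int32{3072, 33})   // false: last dimension not a multiple of 32
    canQuantizeShape([]int32{4096})       // false: fewer than two dimensions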
@@ -1,231 +0,0 @@
package api

import (
    "fmt"
    "net/http"
    "strconv"
    "strings"
    "time"

    "github.com/gin-gonic/gin"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/x/imagegen"
)

// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
    ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}

// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
    r.POST("/v1/images/generations", func(c *gin.Context) {
        ImageGenerationHandler(c, scheduler)
    })
}

// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
    var req ImageGenerationRequest
    if err := c.BindJSON(&req); err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
        return
    }

    // Validate required fields
    if req.Model == "" {
        c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
        return
    }
    if req.Prompt == "" {
        c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
        return
    }

    // Apply defaults
    if req.N == 0 {
        req.N = 1
    }
    if req.Size == "" {
        req.Size = "1024x1024"
    }
    if req.ResponseFormat == "" {
        req.ResponseFormat = "b64_json"
    }

    // Verify model exists
    if imagegen.ResolveModelName(req.Model) == "" {
        c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
        return
    }

    // Parse size
    width, height := parseSize(req.Size)

    // Build options - we repurpose NumCtx/NumGPU for width/height
    opts := api.Options{}
    opts.NumCtx = int(width)
    opts.NumGPU = int(height)

    // Schedule runner
    runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
    if err != nil {
        status := http.StatusInternalServerError
        if strings.Contains(err.Error(), "not found") {
            status = http.StatusNotFound
        }
        c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
        return
    }

    // Build completion request
    completionReq := llm.CompletionRequest{
        Prompt:  req.Prompt,
        Options: &opts,
    }

    if req.Stream {
        handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
    } else {
        handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
    }
}

func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
    c.Header("Content-Type", "text/event-stream")
    c.Header("Cache-Control", "no-cache")
    c.Header("Connection", "keep-alive")

    var imageBase64 string
    err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
        if resp.Done {
            imageBase64 = extractBase64(resp.Content)
        } else {
            progress := parseProgress(resp.Content)
            if progress.Total > 0 {
                c.SSEvent("progress", progress)
                c.Writer.Flush()
            }
        }
    })
    if err != nil {
        c.SSEvent("error", gin.H{"error": err.Error()})
        return
    }

    c.SSEvent("done", buildResponse(imageBase64, format))
}

func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
    var imageBase64 string
    err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
        if resp.Done {
            imageBase64 = extractBase64(resp.Content)
        }
    })
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
        return
    }

    c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}

func parseSize(size string) (int32, int32) {
    parts := strings.Split(size, "x")
    if len(parts) != 2 {
        return 1024, 1024
    }
    w, _ := strconv.Atoi(parts[0])
    h, _ := strconv.Atoi(parts[1])
    if w == 0 {
        w = 1024
    }
    if h == 0 {
        h = 1024
    }
    return int32(w), int32(h)
}

func extractBase64(content string) string {
    if strings.HasPrefix(content, "IMAGE_BASE64:") {
        return content[13:]
    }
    return ""
}

func parseProgress(content string) ImageProgressEvent {
    var step, total int
    fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
    return ImageProgressEvent{Step: step, Total: total}
}

func buildResponse(imageBase64, format string) ImageGenerationResponse {
    resp := ImageGenerationResponse{
        Created: time.Now().Unix(),
        Data:    make([]ImageData, 1),
    }

    if imageBase64 == "" {
        return resp
    }

    if format == "url" {
        // URL format not supported when using base64 transfer
        resp.Data[0].B64JSON = imageBase64
    } else {
        resp.Data[0].B64JSON = imageBase64
    }

    return resp
}

// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
    opts := api.Options{}

    // Schedule runner
    runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    // Build completion request
    completionReq := llm.CompletionRequest{
        Prompt:  prompt,
        Options: &opts,
    }

    // Stream responses via channel
    ch := make(chan any)
    go func() {
        defer close(ch)
        err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
            ch <- GenerateResponse{
                Model:     modelName,
                CreatedAt: time.Now().UTC(),
                Response:  resp.Content,
                Done:      resp.Done,
            }
        })
        if err != nil {
            // Log error but don't block - channel is already being consumed
            _ = err
        }
    }()

    streamFn(c, ch)
}

// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
    Model     string    `json:"model"`
    CreatedAt time.Time `json:"created_at"`
    Response  string    `json:"response"`
    Done      bool      `json:"done"`
}
@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api

// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
    Model          string `json:"model"`
    Prompt         string `json:"prompt"`
    N              int    `json:"n,omitempty"`
    Size           string `json:"size,omitempty"`
    ResponseFormat string `json:"response_format,omitempty"`
    Stream         bool   `json:"stream,omitempty"`
}

// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
    Created int64       `json:"created"`
    Data    []ImageData `json:"data"`
}

// ImageData contains the generated image data.
type ImageData struct {
    URL           string `json:"url,omitempty"`
    B64JSON       string `json:"b64_json,omitempty"`
    RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
    Step  int `json:"step"`
    Total int `json:"total"`
}
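For reference, a request the removed handler accepted, built from the struct above; the model name is a placeholder:

    req := ImageGenerationRequest{
        Model:          "my-image-model", // placeholder
        Prompt:         "a lighthouse at dusk",
        N:              1,
        Size:           "1024x1024",
        ResponseFormat: "b64_json",
        Stream:         true,
    }

With Stream set, the handler emitted SSE "progress" events carrying ImageProgressEvent and a final "done" event carrying the ImageGenerationResponse.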
6 x/imagegen/cache/step.go vendored
@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// Supports both single-stream and dual-stream architectures:
//   - Single-stream: use Get/Set for the single output per layer
//   - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
}

// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Get2(layer int) *mlx.Array {
    if layer < len(c.layers2) {
        return c.layers2[layer]
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
}

// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
// Used for dual-stream architectures.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
    if layer < len(c.layers2) {
        if c.layers2[layer] != nil {
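A sketch of how a denoising step might consult this cache — the Get/Set names come from the code above, while the block loop and refresh policy are assumed:

    for layer := 0; layer < numLayers; layer++ {
        if !refreshStep {
            if cached := cache.Get(layer); cached != nil {
                h = cached // reuse the shallow layer's output from the last step
                continue
            }
        }
        h = blocks[layer].Forward(h) // recompute on refresh steps
        cache.Set(layer, h)          // Set frees the previously cached value
    }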
@@ -7,7 +7,6 @@ package imagegen

import (
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "io"
@@ -39,79 +38,20 @@ func DefaultOptions() ImageGenOptions {
    return ImageGenOptions{
        Width:  1024,
        Height: 1024,
        Steps:  9,
        Steps:  0, // 0 means model default
        Seed:   0, // 0 means random
    }
}

// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
    Architecture   string
    ParameterCount int64
    Quantization   string
}

// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
    manifest, err := LoadManifest(modelName)
    if err != nil {
        return nil, fmt.Errorf("failed to load manifest: %w", err)
    }

    info := &ModelInfo{}

    // Read model_index.json for architecture, parameter count, and quantization
    if data, err := manifest.ReadConfig("model_index.json"); err == nil {
        var index struct {
            Architecture   string `json:"architecture"`
            ParameterCount int64  `json:"parameter_count"`
            Quantization   string `json:"quantization"`
        }
        if json.Unmarshal(data, &index) == nil {
            info.Architecture = index.Architecture
            info.ParameterCount = index.ParameterCount
            info.Quantization = index.Quantization
        }
    }

    // Fallback: detect quantization from tensor names if not in config
    if info.Quantization == "" {
        for _, layer := range manifest.Manifest.Layers {
            if strings.HasSuffix(layer.Name, ".weight_scale") {
                info.Quantization = "FP8"
                break
            }
        }
        if info.Quantization == "" {
            info.Quantization = "BF16"
        }
    }

    // Fallback: estimate parameter count if not in config
    if info.ParameterCount == 0 {
        var totalSize int64
        for _, layer := range manifest.Manifest.Layers {
            if layer.MediaType == "application/vnd.ollama.image.tensor" {
                if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
                    totalSize += layer.Size
                }
            }
        }
        // Assume BF16 (2 bytes/param) as rough estimate
        info.ParameterCount = totalSize / 2
    }

    return info, nil
}

// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
    cmd.Flags().Int("width", 1024, "Image width")
    cmd.Flags().Int("height", 1024, "Image height")
    cmd.Flags().Int("steps", 9, "Denoising steps")
    cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
    cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
    cmd.Flags().String("negative", "", "Negative prompt")
    // Hide from main flags section - shown in separate section via AppendFlagsDocs
    cmd.Flags().MarkHidden("width")
    cmd.Flags().MarkHidden("height")
    cmd.Flags().MarkHidden("steps")
@@ -119,6 +59,19 @@ func RegisterFlags(cmd *cobra.Command) {
    cmd.Flags().MarkHidden("negative")
}

// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
func AppendFlagsDocs(cmd *cobra.Command) {
    usage := `
Image Generation Flags (experimental):
      --width int        Image width
      --height int       Image height
      --steps int        Denoising steps
      --seed int         Random seed
      --negative str     Negative prompt
`
    cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
}

// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
@@ -158,17 +111,15 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
        return err
    }

    // Build request with image gen options encoded in Options fields
    // NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
    req := &api.GenerateRequest{
        Model:  modelName,
        Prompt: prompt,
        Options: map[string]any{
            "num_ctx":     opts.Width,
            "num_gpu":     opts.Height,
            "num_predict": opts.Steps,
            "seed":        opts.Seed,
        },
        Width:  int32(opts.Width),
        Height: int32(opts.Height),
        Steps:  int32(opts.Steps),
    }
    if opts.Seed != 0 {
        req.Options = map[string]any{"seed": opts.Seed}
    }
    if keepAlive != nil {
        req.KeepAlive = keepAlive
@@ -182,32 +133,25 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
    var stepBar *progress.StepBar
    var imageBase64 string
    err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
        content := resp.Response

        // Handle progress updates - parse step info and switch to step bar
        if strings.HasPrefix(content, "\rGenerating:") {
            var step, total int
            fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
            if stepBar == nil && total > 0 {
        // Handle progress updates using structured fields
        if resp.Total > 0 {
            if stepBar == nil {
                spinner.Stop()
                stepBar = progress.NewStepBar("Generating", total)
                stepBar = progress.NewStepBar("Generating", int(resp.Total))
                p.Add("", stepBar)
            }
            if stepBar != nil {
                stepBar.Set(step)
            }
            return nil
            stepBar.Set(int(resp.Completed))
        }

        // Handle final response with base64 image data
        if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
            imageBase64 = content[13:]
        // Handle final response with image data
        if resp.Done && resp.Image != "" {
            imageBase64 = resp.Image
        }

        return nil
    })

    p.Stop()
    p.StopAndClear()
    if err != nil {
        return err
    }
@@ -245,6 +189,23 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
        return err
    }

    // Preload the model with the specified keepalive
    p := progress.NewProgress(os.Stderr)
    spinner := progress.NewSpinner("")
    p.Add("", spinner)

    preloadReq := &api.GenerateRequest{
        Model:     modelName,
        KeepAlive: keepAlive,
    }
    if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
        return nil
    }); err != nil {
        p.StopAndClear()
        return fmt.Errorf("failed to load model: %w", err)
    }
    p.StopAndClear()

    scanner, err := readline.New(readline.Prompt{
        Prompt:      ">>> ",
        Placeholder: "Describe an image to generate (/help for commands)",
@@ -282,7 +243,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
    case strings.HasPrefix(line, "/bye"):
        return nil
    case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
        printInteractiveHelp(opts)
        printInteractiveHelp()
        continue
    case strings.HasPrefix(line, "/set "):
        if err := handleSetCommand(line[5:], &opts); err != nil {
@@ -301,12 +262,12 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
    req := &api.GenerateRequest{
        Model:  modelName,
        Prompt: line,
        Options: map[string]any{
            "num_ctx":     opts.Width,
            "num_gpu":     opts.Height,
            "num_predict": opts.Steps,
            "seed":        opts.Seed,
        },
        Width:  int32(opts.Width),
        Height: int32(opts.Height),
        Steps:  int32(opts.Steps),
    }
    if opts.Seed != 0 {
        req.Options = map[string]any{"seed": opts.Seed}
    }
    if keepAlive != nil {
        req.KeepAlive = keepAlive
@@ -321,32 +282,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
    var imageBase64 string

    err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
        content := resp.Response

        // Handle progress updates - parse step info and switch to step bar
        if strings.HasPrefix(content, "\rGenerating:") {
            var step, total int
            fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
            if stepBar == nil && total > 0 {
        // Handle progress updates using structured fields
        if resp.Total > 0 {
            if stepBar == nil {
                spinner.Stop()
                stepBar = progress.NewStepBar("Generating", total)
                stepBar = progress.NewStepBar("Generating", int(resp.Total))
                p.Add("", stepBar)
            }
            if stepBar != nil {
                stepBar.Set(step)
            }
            return nil
            stepBar.Set(int(resp.Completed))
        }

        // Handle final response with base64 image data
        if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
            imageBase64 = content[13:]
        // Handle final response with image data
        if resp.Done && resp.Image != "" {
            imageBase64 = resp.Image
        }

        return nil
    })

    p.Stop()
    p.StopAndClear()
    if err != nil {
        fmt.Fprintf(os.Stderr, "Error: %v\n", err)
        continue
@@ -397,12 +351,13 @@ func sanitizeFilename(s string) string {
}

// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
// TODO: reconcile /set commands with /set parameter in text gen REPL (cmd/cmd.go)
func printInteractiveHelp() {
    fmt.Fprintln(os.Stderr, "Commands:")
    fmt.Fprintln(os.Stderr, "  /set width <n>     Set image width (current:", opts.Width, ")")
    fmt.Fprintln(os.Stderr, "  /set height <n>    Set image height (current:", opts.Height, ")")
    fmt.Fprintln(os.Stderr, "  /set steps <n>     Set denoising steps (current:", opts.Steps, ")")
    fmt.Fprintln(os.Stderr, "  /set seed <n>      Set random seed (current:", opts.Seed, ", 0=random)")
    fmt.Fprintln(os.Stderr, "  /set width <n>     Set image width")
    fmt.Fprintln(os.Stderr, "  /set height <n>    Set image height")
    fmt.Fprintln(os.Stderr, "  /set steps <n>     Set denoising steps")
    fmt.Fprintln(os.Stderr, "  /set seed <n>      Set random seed")
    fmt.Fprintln(os.Stderr, "  /set negative <s>  Set negative prompt")
    fmt.Fprintln(os.Stderr, "  /show              Show current settings")
    fmt.Fprintln(os.Stderr, "  /bye               Exit")
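The net effect of this change on clients: progress and the final image now arrive as structured fields (Total, Completed, Image) instead of a "\rGenerating: step %d/%d" string and an "IMAGE_BASE64:" prefix. A callback reduces to roughly the following sketch (the output path is a placeholder):

    err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
        if resp.Total > 0 {
            fmt.Printf("step %d/%d\n", resp.Completed, resp.Total)
        }
        if resp.Done && resp.Image != "" {
            img, err := base64.StdEncoding.DecodeString(resp.Image)
            if err != nil {
                return err
            }
            return os.WriteFile("out.png", img, 0o644)
        }
        return nil
    })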
@@ -1,190 +0,0 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
	if !imagegen.IsTensorModelDir(modelDir) {
		return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
	}

	status := "importing image generation model"
	spinner := progress.NewSpinner(status)
	p.Add("imagegen", spinner)

	// Create layer callback for config files
	createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
		layer, err := server.NewLayer(r, mediaType)
		if err != nil {
			return imagegen.LayerInfo{}, err
		}
		layer.Name = name

		return imagegen.LayerInfo{
			Digest:    layer.Digest,
			Size:      layer.Size,
			MediaType: layer.MediaType,
			Name:      name,
		}, nil
	}

	// Create tensor layer callback for individual tensors.
	// name is path-style: "component/tensor_name".
	// When doQuantize is true, returns multiple layers (weight + scales).
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
		if doQuantize {
			// Check if quantization is supported
			if !QuantizeSupported() {
				return nil, fmt.Errorf("quantization requires MLX support")
			}

			// Quantize the tensor (affine mode returns weight, scales, qbiases)
			qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
			if err != nil {
				return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
			}

			// Create layer for quantized weight
			weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
			if err != nil {
				return nil, err
			}

			// Create layer for scales (use _scale suffix convention)
			scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
			if err != nil {
				return nil, err
			}

			layers := []imagegen.LayerInfo{
				{
					Digest:    weightLayer.Digest,
					Size:      weightLayer.Size,
					MediaType: weightLayer.MediaType,
					Name:      name, // Keep original name for weight
				},
				{
					Digest:    scalesLayer.Digest,
					Size:      scalesLayer.Size,
					MediaType: scalesLayer.MediaType,
					Name:      name + "_scale", // Add _scale suffix
				},
			}

			// Add qbiases layer if present (affine mode)
			if qbiasData != nil {
				qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
				if err != nil {
					return nil, err
				}
				layers = append(layers, imagegen.LayerInfo{
					Digest:    qbiasLayer.Digest,
					Size:      qbiasLayer.Size,
					MediaType: qbiasLayer.MediaType,
					Name:      name + "_qbias", // Add _qbias suffix
				})
			}

			return layers, nil
		}

		// Non-quantized path: just create a single layer
		layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
		if err != nil {
			return nil, err
		}

		return []imagegen.LayerInfo{
			{
				Digest:    layer.Digest,
				Size:      layer.Size,
				MediaType: layer.MediaType,
				Name:      name,
			},
		}, nil
	}

	// Create manifest writer callback
	writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
		name := model.ParseName(modelName)
		if !name.IsValid() {
			return fmt.Errorf("invalid model name: %s", modelName)
		}

		// Create a proper config blob with version requirement
		configData := model.ConfigV2{
			ModelFormat:  "safetensors",
			Capabilities: []string{"image"},
			Requires:     MinOllamaVersion,
		}
		configJSON, err := json.Marshal(configData)
		if err != nil {
			return fmt.Errorf("failed to marshal config: %w", err)
		}

		// Create config layer blob
		configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
		if err != nil {
			return fmt.Errorf("failed to create config layer: %w", err)
		}

		// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
		serverLayers := make([]server.Layer, len(layers))
		for i, l := range layers {
			serverLayers[i] = server.Layer{
				MediaType: l.MediaType,
				Digest:    l.Digest,
				Size:      l.Size,
				Name:      l.Name,
			}
		}

		return server.WriteManifest(name, configLayer, serverLayers)
	}

	// Progress callback
	progressFn := func(msg string) {
		spinner.Stop()
		status = msg
		spinner = progress.NewSpinner(status)
		p.Add("imagegen", spinner)
	}

	err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
	spinner.Stop()
	if err != nil {
		return err
	}

	fmt.Printf("Created image generation model '%s'\n", modelName)
	return nil
}
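For orientation, a minimal sketch of how this helper might be driven; the call site and model name are hypothetical, only CreateModel's signature comes from the code above:

	p := progress.NewProgress(os.Stderr)
	defer p.Stop()
	// "my-image-model" and the directory path are placeholders.
	if err := CreateModel("my-image-model", "/path/to/model-dir", "fp8", p); err != nil {
		log.Fatal(err)
	}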
@@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string {
 	return result
 }

-func init() {
-	generationStream = mlx.NewStream()
-}
-
 // withStream runs fn with the generation stream as default
 func withStream(fn func()) {
+	// Lazy initialization of generationStream
+	if generationStream == nil {
+		generationStream = mlx.NewStream()
+	}
 	orig := mlx.GetDefaultStream()
 	mlx.SetDefaultStream(generationStream)
 	fn()
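The hunk cuts off after fn(); presumably the function then restores the previous default stream. A minimal sketch of the full save/restore pattern, assuming that tail (not verbatim from the diff):

	func withStream(fn func()) {
		if generationStream == nil {
			generationStream = mlx.NewStream() // lazy init avoids touching MLX at package load
		}
		orig := mlx.GetDefaultStream()
		mlx.SetDefaultStream(generationStream)
		defer mlx.SetDefaultStream(orig) // assumed restore; the diff truncates here
		fn()
	}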
@@ -7,17 +7,20 @@ import (
	"encoding/json"
	"flag"
	"fmt"
	"image"
	_ "image/jpeg"
	_ "image/png"
	"log"
	"os"
	"path/filepath"
	"runtime/pprof"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/flux2"
	"github.com/ollama/ollama/x/imagegen/models/gemma3"
	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
	"github.com/ollama/ollama/x/imagegen/models/llama"
	"github.com/ollama/ollama/x/imagegen/models/qwen_image"
	"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
	"github.com/ollama/ollama/x/imagegen/models/zimage"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -46,9 +49,9 @@ func main() {
	imagePath := flag.String("image", "", "Image path for multimodal models")

	// Image generation params
-	width := flag.Int("width", 1024, "Image width")
-	height := flag.Int("height", 1024, "Image height")
-	steps := flag.Int("steps", 9, "Denoising steps")
+	width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
+	height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
+	steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
	seed := flag.Int64("seed", 42, "Random seed")
	out := flag.String("output", "output.png", "Output path")
@@ -56,13 +59,11 @@ func main() {
	listTensors := flag.Bool("list", false, "List tensors only")
	cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
	gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
	layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
	wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")

	// Legacy mode flags
	zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
	qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
	qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
	flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
	var inputImages stringSlice
	flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
	negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
@@ -78,6 +79,11 @@ func main() {
		return
	}

	// Check if MLX initialized successfully
	if !mlx.IsMLXAvailable() {
		log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
	}

	// CPU profiling
	if *cpuProfile != "" {
		f, err := os.Create(*cpuProfile)
@@ -117,60 +123,44 @@ func main() {
		if err == nil {
			err = saveImageArray(img, *out)
		}
-	case *qwenImage:
-		m, loadErr := qwen_image.LoadPersistent(*modelPath)
-		if loadErr != nil {
+	case *flux2Flag:
+		m := &flux2.Model{}
+		if loadErr := m.Load(*modelPath); loadErr != nil {
			log.Fatal(loadErr)
		}
+		// Load input images with EXIF orientation correction
+		var loadedImages []image.Image
+		for _, path := range inputImages {
+			img, loadErr := loadImageWithEXIF(path)
+			if loadErr != nil {
+				log.Fatalf("Failed to load image %s: %v", path, loadErr)
+			}
+			loadedImages = append(loadedImages, img)
+		}
+		// When input images provided and user didn't override dimensions, use 0 to match input
+		fluxWidth := int32(*width)
+		fluxHeight := int32(*height)
+		if len(loadedImages) > 0 && *width == 0 && *height == 0 {
+			// Both unset, will auto-detect from input
+		} else if len(loadedImages) > 0 && *width == 0 {
+			fluxWidth = 0 // Compute from height + aspect ratio
+		} else if len(loadedImages) > 0 && *height == 0 {
+			fluxHeight = 0 // Compute from width + aspect ratio
+		}
		var img *mlx.Array
-		img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
-			Prompt:         *prompt,
-			NegativePrompt: *negativePrompt,
-			CFGScale:       float32(*cfgScale),
-			Width:          int32(*width),
-			Height:         int32(*height),
-			Steps:          *steps,
-			Seed:           *seed,
-			LayerCache:     *layerCache,
+		img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
+			Prompt:        *prompt,
+			Width:         fluxWidth,
+			Height:        fluxHeight,
+			Steps:         *steps,
+			GuidanceScale: float32(*cfgScale),
+			Seed:          *seed,
+			CapturePath:   *gpuCapture,
+			InputImages:   loadedImages,
		})
		if err == nil {
			err = saveImageArray(img, *out)
		}
	case *qwenImageEdit:
		if len(inputImages) == 0 {
			log.Fatal("qwen-image-edit requires at least one -input-image")
		}

		m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
		if loadErr != nil {
			log.Fatal(loadErr)
		}
		// For image editing, use 0 for dimensions to auto-detect from input image
		// unless explicitly overridden from defaults
		editWidth := int32(0)
		editHeight := int32(0)
		if *width != 1024 {
			editWidth = int32(*width)
		}
		if *height != 1024 {
			editHeight = int32(*height)
		}

		cfg := &qwen_image_edit.GenerateConfig{
			Prompt:         *prompt,
			NegativePrompt: *negativePrompt,
			CFGScale:       float32(*cfgScale),
			Width:          editWidth,
			Height:         editHeight,
			Steps:          *steps,
			Seed:           *seed,
		}

		var img *mlx.Array
		img, err = m.EditFromConfig(inputImages, cfg)
		if err == nil {
			err = saveImageArray(img, *out)
		}
	case *listTensors:
		err = listModelTensors(*modelPath)
	default:
@@ -271,6 +261,8 @@ func detectModelKind(modelPath string) (string, error) {
		switch index.ClassName {
		case "FluxPipeline", "ZImagePipeline":
			return "zimage", nil
+		case "Flux2KleinPipeline":
+			return "flux2", nil
		}
	}
	return "zimage", nil
@@ -291,3 +283,12 @@ func detectModelKind(modelPath string) (string, error) {

	return cfg.ModelType, nil
}

// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
func loadImageWithEXIF(path string) (image.Image, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read file: %w", err)
	}
	return imagegen.DecodeImage(data)
}
@@ -7,6 +7,7 @@ import (
	"encoding/base64"
	"fmt"
	"image"
+	_ "image/jpeg"
	"image/png"
	"os"
	"path/filepath"
@@ -108,3 +109,160 @@ func clampF(v, min, max float32) float32 {
	}
	return v
}

// DecodeImage decodes image bytes with EXIF orientation applied.
func DecodeImage(data []byte) (image.Image, error) {
	orientation := readJPEGOrientation(data)

	img, _, err := image.Decode(bytes.NewReader(data))
	if err != nil {
		return nil, err
	}

	return applyOrientation(img, orientation), nil
}

// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {
	if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
		return 1 // Not JPEG
	}

	r := bytes.NewReader(data[2:])
	for {
		var marker [2]byte
		if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
			return 1
		}

		if marker[1] == 0xE1 { // APP1 (EXIF)
			var lenBytes [2]byte
			if _, err := r.Read(lenBytes[:]); err != nil {
				return 1
			}
			segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
			if segLen < 14 {
				r.Seek(int64(segLen), 1)
				continue
			}
			seg := make([]byte, segLen)
			if _, err := r.Read(seg); err != nil {
				return 1
			}
			if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
				return parseTIFFOrientation(seg[6:])
			}
			continue
		}

		if marker[1] == 0xD9 || marker[1] == 0xDA {
			return 1 // EOI or SOS
		}
		if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
			continue // RST markers
		}

		var lenBytes [2]byte
		if _, err := r.Read(lenBytes[:]); err != nil {
			return 1
		}
		segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
		if segLen > 0 {
			r.Seek(int64(segLen), 1)
		}
	}
}

func parseTIFFOrientation(tiff []byte) int {
	if len(tiff) < 8 {
		return 1
	}

	var big bool
	switch string(tiff[:2]) {
	case "MM":
		big = true
	case "II":
		big = false
	default:
		return 1
	}

	u16 := func(b []byte) uint16 {
		if big {
			return uint16(b[0])<<8 | uint16(b[1])
		}
		return uint16(b[1])<<8 | uint16(b[0])
	}
	u32 := func(b []byte) uint32 {
		if big {
			return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
		}
		return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
	}

	if u16(tiff[2:4]) != 42 {
		return 1
	}

	ifdOffset := u32(tiff[4:8])
	if int(ifdOffset)+2 > len(tiff) {
		return 1
	}

	numEntries := u16(tiff[ifdOffset : ifdOffset+2])
	for i := range int(numEntries) {
		offset := ifdOffset + 2 + uint32(i)*12
		if int(offset)+12 > len(tiff) {
			break
		}
		if u16(tiff[offset:offset+2]) == 0x0112 { // Orientation tag
			o := int(u16(tiff[offset+8 : offset+10]))
			if o >= 1 && o <= 8 {
				return o
			}
			return 1
		}
	}
	return 1
}
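As a sanity check, a hedged test sketch for parseTIFFOrientation built from a hand-assembled little-endian TIFF header (test name and placement are hypothetical; it assumes access to the unexported function from the same package):

	func TestParseTIFFOrientation(t *testing.T) {
		tiff := []byte{
			'I', 'I', 42, 0, // little-endian byte order, TIFF magic 42
			8, 0, 0, 0, // first IFD at offset 8
			1, 0, // one IFD entry
			0x12, 0x01, // tag 0x0112 (Orientation)
			3, 0, // type SHORT
			1, 0, 0, 0, // count 1
			6, 0, 0, 0, // value: orientation 6 (rotate 90° CW)
		}
		if got := parseTIFFOrientation(tiff); got != 6 {
			t.Fatalf("parseTIFFOrientation = %d, want 6", got)
		}
	}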

func applyOrientation(img image.Image, orientation int) image.Image {
	if orientation <= 1 || orientation > 8 {
		return img
	}

	bounds := img.Bounds()
	w, h := bounds.Dx(), bounds.Dy()

	outW, outH := w, h
	if orientation >= 5 {
		outW, outH = h, w
	}

	out := image.NewRGBA(image.Rect(0, 0, outW, outH))
	for y := range h {
		for x := range w {
			var dx, dy int
			switch orientation {
			case 2:
				dx, dy = w-1-x, y
			case 3:
				dx, dy = w-1-x, h-1-y
			case 4:
				dx, dy = x, h-1-y
			case 5:
				dx, dy = y, x
			case 6:
				dx, dy = h-1-y, x
			case 7:
				dx, dy = h-1-y, w-1-x
			case 8:
				dx, dy = y, w-1-x
			}
			out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
		}
	}
	return out
}
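And a companion sketch for applyOrientation: orientation 6 is a 90° clockwise rotation, so a 2x1 image should become 1x2 with the left source pixel on top. Again a hypothetical test (assumes image, image/color, and testing imports), verified only by reading the index mapping above:

	func TestApplyOrientationRotateCW(t *testing.T) {
		src := image.NewRGBA(image.Rect(0, 0, 2, 1))
		src.Set(0, 0, color.RGBA{R: 255, A: 255}) // left pixel: red
		src.Set(1, 0, color.RGBA{B: 255, A: 255}) // right pixel: blue
		out := applyOrientation(src, 6)
		if b := out.Bounds(); b.Dx() != 1 || b.Dy() != 2 {
			t.Fatalf("bounds = %v, want 1x2", b)
		}
		if r, _, _, _ := out.At(0, 0).RGBA(); r == 0 {
			t.Errorf("want red (left source pixel) at the top after rotation")
		}
	}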
@@ -6,8 +6,9 @@ import (
	"io"
	"os"
	"path/filepath"
-	"runtime"
+	"strings"

+	"github.com/ollama/ollama/envconfig"
)

// ManifestLayer represents a layer in the manifest.
@@ -32,31 +33,15 @@ type ModelManifest struct {
	BlobDir string
}

// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
-	home, err := os.UserHomeDir()
-	if err != nil {
-		home = "."
-	}
-	switch runtime.GOOS {
-	case "darwin":
-		return filepath.Join(home, ".ollama", "models", "blobs")
-	case "linux":
-		return filepath.Join(home, ".ollama", "models", "blobs")
-	case "windows":
-		return filepath.Join(home, ".ollama", "models", "blobs")
-	default:
-		return filepath.Join(home, ".ollama", "models", "blobs")
-	}
+	return filepath.Join(envconfig.Models(), "blobs")
}

-// DefaultManifestDir returns the default manifest storage directory.
+// DefaultManifestDir returns the manifest storage directory.
+// Respects OLLAMA_MODELS.
func DefaultManifestDir() string {
-	home, err := os.UserHomeDir()
-	if err != nil {
-		home = "."
-	}
-	return filepath.Join(home, ".ollama", "models", "manifests")
+	return filepath.Join(envconfig.Models(), "manifests")
}

// LoadManifest loads a manifest for the given model name.
@@ -175,3 +160,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
	}
	return false
}

// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
	Architecture   string
	ParameterCount int64
	Quantization   string
}

// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
	manifest, err := LoadManifest(modelName)
	if err != nil {
		return nil, fmt.Errorf("failed to load manifest: %w", err)
	}

	info := &ModelInfo{}

	// Read model_index.json for architecture, parameter count, and quantization
	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
		var index struct {
			Architecture   string `json:"architecture"`
			ParameterCount int64  `json:"parameter_count"`
			Quantization   string `json:"quantization"`
		}
		if json.Unmarshal(data, &index) == nil {
			info.Architecture = index.Architecture
			info.ParameterCount = index.ParameterCount
			info.Quantization = index.Quantization
		}
	}

	// Fallback: detect quantization from tensor names if not in config
	if info.Quantization == "" {
		for _, layer := range manifest.Manifest.Layers {
			if strings.HasSuffix(layer.Name, ".weight_scale") {
				info.Quantization = "FP8"
				break
			}
		}
		if info.Quantization == "" {
			info.Quantization = "BF16"
		}
	}

	// Fallback: estimate parameter count if not in config
	if info.ParameterCount == 0 {
		var totalSize int64
		for _, layer := range manifest.Manifest.Layers {
			if layer.MediaType == "application/vnd.ollama.image.tensor" {
				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
					totalSize += layer.Size
				}
			}
		}
		// Assume BF16 (2 bytes/param) as a rough estimate, so e.g. ~16 GiB
		// of tensor data maps to roughly 8B parameters
		info.ParameterCount = totalSize / 2
	}

	return info, nil
}
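A hedged usage sketch from outside the package (the model name is hypothetical):

	info, err := imagegen.GetModelInfo("my-image-model")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s, ~%.1fB params, %s\n",
		info.Architecture, float64(info.ParameterCount)/1e9, info.Quantization)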
26	x/imagegen/manifest_test.go	Normal file
@@ -0,0 +1,26 @@
package imagegen

import (
	"path/filepath"
	"testing"
)

func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
	modelsDir := filepath.Join(t.TempDir(), "models")

	// Simulate packaged/systemd environment
	t.Setenv("OLLAMA_MODELS", modelsDir)
	t.Setenv("HOME", "/usr/share/ollama")

	// Manifest dir must respect OLLAMA_MODELS
	wantManifest := filepath.Join(modelsDir, "manifests")
	if got := DefaultManifestDir(); got != wantManifest {
		t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
	}

	// Blob dir must respect OLLAMA_MODELS
	wantBlobs := filepath.Join(modelsDir, "blobs")
	if got := DefaultBlobDir(); got != wantBlobs {
		t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
	}
}
@@ -24,9 +24,8 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}

// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
-	"ZImagePipeline":    21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
-	"FluxPipeline":      21 * GB, // ~21GB for Flux (same architecture)
-	"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
+	"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
+	"FluxPipeline":   20 * GB, // ~20GB for Flux
}

// CheckPlatformSupport validates that image generation is supported on the current platform.
@@ -72,31 +71,38 @@ func ResolveModelName(modelName string) string {
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
-	manifest, err := LoadManifest(modelName)
-	if err != nil {
-		return 21 * GB
-	}
-
-	data, err := manifest.ReadConfig("model_index.json")
-	if err != nil {
-		return 21 * GB
-	}
-
-	// Parse just the class name
-	var index struct {
-		ClassName string `json:"_class_name"`
-	}
-	if err := json.Unmarshal(data, &index); err != nil {
-		return 21 * GB
-	}
-
-	if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
+	className := DetectModelType(modelName)
+	if estimate, ok := modelVRAMEstimates[className]; ok {
		return estimate
	}
	return 21 * GB
}

-// HasTensorLayers checks if the given model has tensor layers.
-func HasTensorLayers(modelName string) bool {
-	return ResolveModelName(modelName) != ""
+// DetectModelType reads model_index.json and returns the model type.
+// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
+// Returns empty string if detection fails.
+func DetectModelType(modelName string) string {
+	manifest, err := LoadManifest(modelName)
+	if err != nil {
+		return ""
+	}
+
+	data, err := manifest.ReadConfig("model_index.json")
+	if err != nil {
+		return ""
+	}
+
+	var index struct {
+		Architecture string `json:"architecture"`
+		ClassName    string `json:"_class_name"`
+	}
+	if err := json.Unmarshal(data, &index); err != nil {
+		return ""
+	}
+
+	// Prefer architecture (Ollama format), fall back to _class_name (diffusers)
+	if index.Architecture != "" {
+		return index.Architecture
+	}
+	return index.ClassName
}
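For reference, the two model_index.json shapes DetectModelType understands (field values illustrative, trailing fields elided):

	{"architecture": "ZImagePipeline", ...}   // Ollama format, takes precedence
	{"_class_name": "ZImagePipeline", ...}    // diffusers format, fallback

If neither field parses, the empty string propagates to EstimateVRAM, which then falls back to the 21 GB default.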
@@ -72,9 +72,8 @@ func TestCheckMemoryRequirements(t *testing.T) {
func TestModelVRAMEstimates(t *testing.T) {
	// Verify the VRAM estimates map has expected entries
	expected := map[string]uint64{
-		"ZImagePipeline":    21 * GB,
-		"FluxPipeline":      21 * GB,
-		"QwenImagePipeline": 80 * GB,
+		"ZImagePipeline": 21 * GB,
+		"FluxPipeline":   20 * GB,
	}

	for name, expectedVRAM := range expected {
@@ -94,13 +93,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
	}
}

-func TestHasTensorLayers(t *testing.T) {
-	// Non-existent model should return false
-	if HasTensorLayers("nonexistent-model") {
-		t.Error("HasTensorLayers() should return false for non-existent model")
-	}
-}
-
func TestResolveModelName(t *testing.T) {
	// Non-existent model should return empty string
	result := ResolveModelName("nonexistent-model")
@@ -3,7 +3,7 @@
package mlx

/*
-#include "mlx/c/mlx.h"
+#include "mlx.h"
#include <stdlib.h>

// Forward declaration for Go callback
6	x/imagegen/mlx/doc.go	Normal file
@@ -0,0 +1,6 @@
//go:build mlx

// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
//
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
package mlx
439	x/imagegen/mlx/generate_wrappers.go	Normal file
@@ -0,0 +1,439 @@
//go:build ignore

// This tool generates MLX-C dynamic loading wrappers.
// Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]
package main

import (
	"bytes"
	"flag"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"strings"
)

type Function struct {
	Name            string
	ReturnType      string
	Params          string
	ParamNames      []string
	NeedsARM64Guard bool
}

func findHeaders(directory string) ([]string, error) {
	var headers []string
	err := filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !d.IsDir() && strings.HasSuffix(path, ".h") {
			headers = append(headers, path)
		}
		return nil
	})
	return headers, err
}

func cleanContent(content string) string {
	// Remove single-line comments
	re := regexp.MustCompile(`//.*?\n`)
	content = re.ReplaceAllString(content, "\n")

	// Remove multi-line comments
	// (note: without the (?s) flag, `.` does not match newlines, so this
	// only strips /* ... */ comments that sit on a single line)
	re = regexp.MustCompile(`/\*.*?\*/`)
	content = re.ReplaceAllString(content, "")

	// Remove preprocessor directives (lines starting with #) - use multiline mode
	re = regexp.MustCompile(`(?m)^\s*#.*?$`)
	content = re.ReplaceAllString(content, "")

	// Remove extern "C" { and } blocks more conservatively:
	// only remove the extern "C" { line, not the content inside
	re = regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`)
	content = re.ReplaceAllString(content, "\n")
	// Remove standalone closing braces that are not part of function declarations
	re = regexp.MustCompile(`\n\s*\}\s*\n`)
	content = re.ReplaceAllString(content, "\n")

	// Collapse whitespace and newlines
	re = regexp.MustCompile(`\s+`)
	content = re.ReplaceAllString(content, " ")

	return content
}

func extractParamNames(params string) []string {
	if params == "" || strings.TrimSpace(params) == "void" {
		return []string{}
	}

	var names []string

	// Split by comma, but respect parentheses (for function pointers)
	parts := splitParams(params)

	// Remove array brackets
	arrayBrackets := regexp.MustCompile(`\[.*?\]`)

	// Function pointer pattern
	funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`)

	// Type keywords to skip
	typeKeywords := map[string]bool{
		"const":     true,
		"struct":    true,
		"unsigned":  true,
		"signed":    true,
		"long":      true,
		"short":     true,
		"int":       true,
		"char":      true,
		"float":     true,
		"double":    true,
		"void":      true,
		"size_t":    true,
		"uint8_t":   true,
		"uint16_t":  true,
		"uint32_t":  true,
		"uint64_t":  true,
		"int8_t":    true,
		"int16_t":   true,
		"int32_t":   true,
		"int64_t":   true,
		"intptr_t":  true,
		"uintptr_t": true,
	}

	for _, part := range parts {
		if part == "" {
			continue
		}

		// Remove array brackets
		part = arrayBrackets.ReplaceAllString(part, "")

		// For function pointers like "void (*callback)(int)"
		if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 {
			names = append(names, matches[1])
			continue
		}

		// Regular parameter: last identifier
		tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1)
		if len(tokens) > 0 {
			// The last token is usually the parameter name;
			// skip type keywords
			for i := len(tokens) - 1; i >= 0; i-- {
				if !typeKeywords[tokens[i]] {
					names = append(names, tokens[i])
					break
				}
			}
		}
	}

	return names
}

func splitParams(params string) []string {
	var parts []string
	var current bytes.Buffer
	depth := 0

	for _, char := range params + "," {
		switch char {
		case '(':
			depth++
			current.WriteRune(char)
		case ')':
			depth--
			current.WriteRune(char)
		case ',':
			if depth == 0 {
				parts = append(parts, strings.TrimSpace(current.String()))
				current.Reset()
			} else {
				current.WriteRune(char)
			}
		default:
			current.WriteRune(char)
		}
	}

	return parts
}
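To make the splitting behavior concrete, here is what these helpers would produce for a function-pointer parameter list (expected values reasoned from the code above, not from running it):

	// splitParams respects nesting, so the callback's inner comma
	// does not split the list:
	//   splitParams("const mlx_array a, void (*cb)(int, float)")
	//     -> ["const mlx_array a", "void (*cb)(int, float)"]
	// extractParamNames then takes the function-pointer name from the
	// (*name) group, and the last non-keyword token otherwise:
	//   extractParamNames("const mlx_array a, void (*cb)(int, float)")
	//     -> ["a", "cb"]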

func parseFunctions(content string) []Function {
	var functions []Function

	// Match function declarations: return_type function_name(params);
	// Matches both mlx_* and _mlx_* functions
	pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`)

	matches := pattern.FindAllStringSubmatch(content, -1)
	for _, match := range matches {
		returnType := strings.TrimSpace(match[1])
		funcName := strings.TrimSpace(match[2])
		params := strings.TrimSpace(match[3])

		// Skip if this looks like a variable declaration
		if params == "" || strings.Contains(params, "{") {
			continue
		}

		// Clean up return type
		returnType = strings.Join(strings.Fields(returnType), " ")

		// Extract parameter names
		paramNames := extractParamNames(params)

		// Check if ARM64 guard is needed
		needsGuard := needsARM64Guard(funcName, returnType, params)

		functions = append(functions, Function{
			Name:            funcName,
			ReturnType:      returnType,
			Params:          params,
			ParamNames:      paramNames,
			NeedsARM64Guard: needsGuard,
		})
	}

	return functions
}

func needsARM64Guard(name, retType, params string) bool {
	return strings.Contains(name, "float16") ||
		strings.Contains(name, "bfloat16") ||
		strings.Contains(retType, "float16_t") ||
		strings.Contains(retType, "bfloat16_t") ||
		strings.Contains(params, "float16_t") ||
		strings.Contains(params, "bfloat16_t")
}
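A note on the guard: float16/bfloat16 declarations end up fenced behind __aarch64__/_M_ARM64, presumably because the C-level float16/bfloat16 types MLX-C uses are only provided by the ARM64 toolchains targeted here; on other architectures the generated declarations would not compile. This rationale is inferred, not stated in the diff.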

func generateWrapperFiles(functions []Function, headerPath, implPath string) error {
	// Generate header file
	var headerBuf bytes.Buffer

	headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
	headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n")
	headerBuf.WriteString("//\n")
	headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n")
	headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n")
	headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n")
	headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n")
	headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n")

	headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n")
	headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
	headerBuf.WriteString("#include \"mlx_dynamic.h\"\n")
	headerBuf.WriteString("#include <stdio.h>\n\n")

	// Undef all MLX functions to avoid conflicts
	headerBuf.WriteString("// Undefine any existing MLX function macros\n")
	for _, fn := range functions {
		headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name))
	}
	headerBuf.WriteString("\n")

	// Function pointer extern declarations
	headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n")
	for _, fn := range functions {
		if fn.NeedsARM64Guard {
			headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
		}
		headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params))
		if fn.NeedsARM64Guard {
			headerBuf.WriteString("#endif\n")
		}
	}
	headerBuf.WriteString("\n")

	// Initialization function declaration
	headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n")
	headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n")

	// Wrapper function declarations
	headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n")
	headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n")
	for _, fn := range functions {
		if fn.NeedsARM64Guard {
			headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
		}
		headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params))
		if fn.NeedsARM64Guard {
			headerBuf.WriteString("#endif\n")
		}
		headerBuf.WriteString("\n")
	}

	headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n")

	// Write header file
	if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil {
		return fmt.Errorf("failed to write header file: %w", err)
	}

	// Generate implementation file
	var implBuf bytes.Buffer

	implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
	implBuf.WriteString("// This file contains the function pointer definitions and initialization\n")
	implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n")

	implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
	implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
	implBuf.WriteString("#include <stdio.h>\n")
	implBuf.WriteString("#include <dlfcn.h>\n\n")

	// Function pointer definitions
	implBuf.WriteString("// Function pointer definitions\n")
	for _, fn := range functions {
		if fn.NeedsARM64Guard {
			implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
		}
		implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params))
		if fn.NeedsARM64Guard {
			implBuf.WriteString("#endif\n")
		}
	}
	implBuf.WriteString("\n")

	// Initialization function
	implBuf.WriteString("// Initialize all function pointers via dlsym\n")
	implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
	implBuf.WriteString(" if (handle == NULL) {\n")
	implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
	implBuf.WriteString(" return -1;\n")
	implBuf.WriteString(" }\n\n")

	for _, fn := range functions {
		if fn.NeedsARM64Guard {
			implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
		}
		implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
		implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
		implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
		implBuf.WriteString(" return -1;\n")
		implBuf.WriteString(" }\n")
		if fn.NeedsARM64Guard {
			implBuf.WriteString("#endif\n")
		}
	}

	implBuf.WriteString(" return 0;\n")
	implBuf.WriteString("}\n\n")

	// Wrapper function implementations
	implBuf.WriteString("// Wrapper function implementations that call through function pointers\n")
	for _, fn := range functions {
		if fn.NeedsARM64Guard {
			implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
		}
		implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params))

		// Call through function pointer
		if fn.ReturnType != "void" {
			implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name))
		} else {
			implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name))
		}

		// Pass parameters
		implBuf.WriteString(strings.Join(fn.ParamNames, ", "))
		implBuf.WriteString(");\n")
		implBuf.WriteString("}\n")
		if fn.NeedsARM64Guard {
			implBuf.WriteString("#endif\n")
		}
		implBuf.WriteString("\n")
	}

	// Write implementation file
	if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil {
		return fmt.Errorf("failed to write implementation file: %w", err)
	}

	return nil
}
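Tracing the Sprintf calls above, the emitted pair for a hypothetical declaration `int mlx_foo(int x);` would look roughly like this (illustrative, reconstructed from the format strings, not verbatim tool output):

	// mlx.h (header):
	//   extern int (*mlx_foo_ptr)(int x);
	//   int mlx_foo(int x);
	//
	// mlx.c (implementation):
	//   int (*mlx_foo_ptr)(int x) = NULL;
	//   ...inside mlx_load_functions:
	//   mlx_foo_ptr = dlsym(handle, "mlx_foo");
	//   if (mlx_foo_ptr == NULL) { ... return -1; }
	//   ...and the wrapper itself:
	//   int mlx_foo(int x) {
	//       return mlx_foo_ptr(x);
	//   }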

func main() {
	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]\n")
		fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n")
		flag.PrintDefaults()
	}
	flag.Parse()

	args := flag.Args()
	if len(args) < 2 {
		fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n")
		flag.Usage()
		os.Exit(1)
	}

	headerDir := args[0]
	outputHeader := args[1]
	// Default implementation file is same name with .c extension
	outputImpl := outputHeader
	if len(args) > 2 {
		outputImpl = args[2]
	} else if strings.HasSuffix(outputHeader, ".h") {
		outputImpl = outputHeader[:len(outputHeader)-2] + ".c"
	}

	// Check if header directory exists
	if _, err := os.Stat(headerDir); os.IsNotExist(err) {
		fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir)
		fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n")
		fmt.Fprintf(os.Stderr, " cmake -B build\n\n")
		fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n")
		os.Exit(1)
	}

	fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir)

	// Find all headers
	headers, err := findHeaders(headerDir)
	if err != nil {
		fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err)
		os.Exit(1)
	}
	fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers))

	// Parse all headers
	var allFunctions []Function
	seen := make(map[string]bool)

	for _, header := range headers {
		content, err := os.ReadFile(header)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err)
			continue
		}

		cleaned := cleanContent(string(content))
		functions := parseFunctions(cleaned)

		// Deduplicate
		for _, fn := range functions {
			if !seen[fn.Name] {
				seen[fn.Name] = true
				allFunctions = append(allFunctions, fn)
			}
		}
	}

	fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions))

	// Generate wrapper files
	if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil {
		fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err)
		os.Exit(1)
	}

	fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl)
}
5786	x/imagegen/mlx/mlx.c	Normal file
File diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.