mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 01:21:26 -05:00
Compare commits
25 Commits
rmdisplayl
...
bmizerany/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cfd4152eb6 | ||
|
|
0fbb379373 | ||
|
|
7afb2e125a | ||
|
|
41a272de9f | ||
|
|
f335722275 | ||
|
|
6d53b67c2c | ||
|
|
969238b19e | ||
|
|
949d7832cf | ||
|
|
99d227c9db | ||
|
|
a27e419b47 | ||
|
|
e4d0db5a97 | ||
|
|
ba460802c2 | ||
|
|
e54a3c7fcd | ||
|
|
9f8691c6c8 | ||
|
|
a0b8a32eb4 | ||
|
|
7027f264fb | ||
|
|
9bee3b63b1 | ||
|
|
309aef7fee | ||
|
|
08655170aa | ||
|
|
2b341069a7 | ||
|
|
c00fee6936 | ||
|
|
c2d813bdc3 | ||
|
|
786f3a1c44 | ||
|
|
3397eff0cd | ||
|
|
0efb7931c7 |
@@ -42,7 +42,7 @@ ARG CGO_CFLAGS
|
||||
ARG AMDGPU_TARGETS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN mkdir /tmp/scratch && \
|
||||
for dep in $(cat /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/x86_64/rocm*/lib/deps.txt) ; do \
|
||||
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
|
||||
cp ${dep} /tmp/scratch/ || exit 1 ; \
|
||||
done && \
|
||||
(cd /opt/rocm/lib && tar cf - rocblas/library) | (cd /tmp/scratch/ && tar xf - ) && \
|
||||
|
||||
@@ -64,6 +64,7 @@ Here are some example models that can be downloaded:
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
|
||||
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
|
||||
| Solar | 10.7B | 6.1GB | `ollama run solar` |
|
||||
|
||||
> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
|
||||
@@ -316,7 +317,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Database
|
||||
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
|
||||
### Package managers
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -87,19 +86,29 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
// Re-wire context done behavior to attempt a graceful shutdown of the server
|
||||
cmd.Cancel = func() error {
|
||||
if cmd.Process != nil {
|
||||
cmd.Process.Signal(os.Interrupt) //nolint:errcheck
|
||||
err := terminate(cmd)
|
||||
if err != nil {
|
||||
slog.Warn("error trying to gracefully terminate server", "err", err)
|
||||
return cmd.Process.Kill()
|
||||
}
|
||||
|
||||
tick := time.NewTicker(10 * time.Millisecond)
|
||||
defer tick.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
// OS agnostic "is it still running"
|
||||
if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||
return nil //nolint:nilerr
|
||||
exited, err := isProcessExited(cmd.Process.Pid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exited {
|
||||
return nil
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
|
||||
cmd.Process.Kill() //nolint:errcheck
|
||||
return cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,35 @@ package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func getCmd(ctx context.Context, cmd string) *exec.Cmd {
|
||||
return exec.CommandContext(ctx, cmd, "serve")
|
||||
}
|
||||
|
||||
func terminate(cmd *exec.Cmd) error {
|
||||
return cmd.Process.Signal(os.Interrupt)
|
||||
}
|
||||
|
||||
func isProcessExited(pid int) (bool, error) {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to find process: %v", err)
|
||||
}
|
||||
|
||||
err = proc.Signal(syscall.Signal(0))
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("error signaling process: %v", err)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -2,12 +2,88 @@ package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func getCmd(ctx context.Context, exePath string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, exePath, "serve")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000}
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func terminate(cmd *exec.Cmd) error {
|
||||
dll, err := windows.LoadDLL("kernel32.dll")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dll.Release() // nolint: errcheck
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
f, err := dll.FindProc("AttachConsole")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err := f.Call(uintptr(pid))
|
||||
if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = dll.FindProc("SetConsoleCtrlHandler")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err = f.Call(0, 1)
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = dll.FindProc("GenerateConsoleCtrlEvent")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const STILL_ACTIVE = 259
|
||||
|
||||
func isProcessExited(pid int) (bool, error) {
|
||||
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to open process: %v", err)
|
||||
}
|
||||
defer windows.CloseHandle(hProcess) // nolint: errcheck
|
||||
|
||||
var exitCode uint32
|
||||
err = windows.GetExitCodeProcess(hProcess, &exitCode)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get exit code: %v", err)
|
||||
}
|
||||
|
||||
if exitCode == STILL_ACTIVE {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
51
cmd/cmd.go
51
cmd/cmd.go
@@ -105,24 +105,48 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
zf := zip.NewWriter(tf)
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors"))
|
||||
files := []string{}
|
||||
|
||||
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(tfiles) == 0 {
|
||||
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, tfiles...)
|
||||
|
||||
if len(files) == 0 {
|
||||
return fmt.Errorf("no safetensors files were found in '%s'", path)
|
||||
return fmt.Errorf("no models were found in '%s'", path)
|
||||
}
|
||||
|
||||
// add the safetensor config file + tokenizer
|
||||
// add the safetensor/torch config file + tokenizer
|
||||
files = append(files, filepath.Join(path, "config.json"))
|
||||
files = append(files, filepath.Join(path, "params.json"))
|
||||
files = append(files, filepath.Join(path, "added_tokens.json"))
|
||||
files = append(files, filepath.Join(path, "tokenizer.model"))
|
||||
|
||||
for _, fn := range files {
|
||||
f, err := os.Open(fn)
|
||||
if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") {
|
||||
continue
|
||||
|
||||
// just skip whatever files aren't there
|
||||
if os.IsNotExist(err) {
|
||||
if strings.HasSuffix(fn, "tokenizer.model") {
|
||||
// try the parent dir before giving up
|
||||
parentDir := filepath.Dir(path)
|
||||
newFn := filepath.Join(parentDir, "tokenizer.model")
|
||||
f, err = os.Open(newFn)
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -228,14 +252,6 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
if os.Getenv("OLLAMA_MODELS") != "" {
|
||||
return errors.New("OLLAMA_MODELS must only be set for 'ollama serve'")
|
||||
}
|
||||
|
||||
if err := checkServerHeartbeat(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -962,10 +978,11 @@ func NewCLI() *cobra.Command {
|
||||
showCmd.Flags().Bool("system", false, "Show system message of a model")
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run MODEL [PROMPT]",
|
||||
Short: "Run a model",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: RunHandler,
|
||||
Use: "run MODEL [PROMPT]",
|
||||
Short: "Run a model",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: RunHandler,
|
||||
}
|
||||
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/x448/float16"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
@@ -45,157 +40,45 @@ type ByteOrder interface {
|
||||
binary.AppendByteOrder
|
||||
}
|
||||
|
||||
type MetaData struct {
|
||||
Type string `mapstructure:"dtype"`
|
||||
Shape []int `mapstructure:"shape"`
|
||||
Offsets []int `mapstructure:"data_offsets"`
|
||||
}
|
||||
|
||||
type ModelArch interface {
|
||||
GetTensors() error
|
||||
LoadVocab() error
|
||||
WriteGGUF() (string, error)
|
||||
}
|
||||
|
||||
type ModelFormat interface {
|
||||
GetLayerName(string) (string, error)
|
||||
GetTensors(string, *Params) ([]llm.Tensor, error)
|
||||
GetParams(string) (*Params, error)
|
||||
GetModelArch(string, string, *Params) (ModelArch, error)
|
||||
}
|
||||
|
||||
type ModelData struct {
|
||||
Path string
|
||||
Name string
|
||||
Params *Params
|
||||
Vocab *Vocab
|
||||
Tensors []llm.Tensor
|
||||
Format ModelFormat
|
||||
}
|
||||
|
||||
func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var jsonSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
buf := make([]byte, jsonSize)
|
||||
_, err = io.ReadFull(f, buf)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewBuffer(buf))
|
||||
d.UseNumber()
|
||||
var parsed map[string]interface{}
|
||||
if err = d.Decode(&parsed); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for k := range parsed {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
slices.Sort(keys)
|
||||
|
||||
slog.Info("converting layers")
|
||||
|
||||
var tensors []llm.Tensor
|
||||
for _, k := range keys {
|
||||
vals := parsed[k].(map[string]interface{})
|
||||
var data MetaData
|
||||
if err = mapstructure.Decode(vals, &data); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var size uint64
|
||||
var kind uint32
|
||||
switch len(data.Shape) {
|
||||
case 0:
|
||||
// metadata
|
||||
continue
|
||||
case 1:
|
||||
// convert to float32
|
||||
kind = 0
|
||||
size = uint64(data.Shape[0] * 4)
|
||||
case 2:
|
||||
// convert to float16
|
||||
kind = 1
|
||||
size = uint64(data.Shape[0] * data.Shape[1] * 2)
|
||||
}
|
||||
|
||||
ggufName, err := GetTensorName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
for i := range data.Shape {
|
||||
shape[i] = uint64(data.Shape[i])
|
||||
}
|
||||
|
||||
t := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
t.WriterTo = safetensorWriterTo{
|
||||
t: &t,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
filename: fn,
|
||||
start: uint64(data.Offsets[0]),
|
||||
end: uint64(data.Offsets[1]),
|
||||
padding: 8 + jsonSize,
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("%v", t))
|
||||
tensors = append(tensors, t)
|
||||
offset += size
|
||||
}
|
||||
return tensors, offset, nil
|
||||
}
|
||||
|
||||
func GetSafeTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
var tensors []llm.Tensor
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
|
||||
func GetModelFormat(dirname string) (ModelFormat, error) {
|
||||
files, err := filepath.Glob(filepath.Join(dirname, "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
for _, f := range files {
|
||||
var t []llm.Tensor
|
||||
var err error
|
||||
t, offset, err = ReadSafeTensors(f, offset, params)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, err
|
||||
for _, fn := range files {
|
||||
slog.Debug(fmt.Sprintf("file = %s", fn))
|
||||
if strings.HasSuffix(fn, ".safetensors") {
|
||||
return &SafetensorFormat{}, nil
|
||||
} else if strings.HasSuffix(fn, ".bin") {
|
||||
slog.Debug("model is torch")
|
||||
return &TorchFormat{}, nil
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
}
|
||||
return tensors, nil
|
||||
}
|
||||
|
||||
func GetParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var params Params
|
||||
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(¶ms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
return nil, fmt.Errorf("couldn't determine model format")
|
||||
}
|
||||
|
||||
// Details on gguf's tokenizer can be found at:
|
||||
@@ -206,7 +89,7 @@ type Vocab struct {
|
||||
Types []int32
|
||||
}
|
||||
|
||||
func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
|
||||
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
|
||||
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
|
||||
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
|
||||
if err != nil {
|
||||
@@ -286,8 +169,8 @@ func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
|
||||
}
|
||||
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
|
||||
|
||||
if vocabSize > len(v.Tokens) {
|
||||
missingTokens := vocabSize - len(v.Tokens)
|
||||
if params.VocabSize > len(v.Tokens) {
|
||||
missingTokens := params.VocabSize - len(v.Tokens)
|
||||
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
||||
for cnt := 0; cnt < missingTokens; cnt++ {
|
||||
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
||||
@@ -298,136 +181,3 @@ func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func GetTensorName(n string) (string, error) {
|
||||
tMap := map[string]string{
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
||||
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
||||
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
||||
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
||||
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
||||
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
}
|
||||
|
||||
v, ok := tMap[n]
|
||||
if ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// quick hack to rename the layers to gguf format
|
||||
for k, v := range tMap {
|
||||
re := regexp.MustCompile(k)
|
||||
newName := re.ReplaceAllString(n, v)
|
||||
if newName != n {
|
||||
return newName, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
type safetensorWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
filename string
|
||||
|
||||
start, end, padding uint64
|
||||
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
|
||||
}
|
||||
|
||||
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
f, err := os.Open(r.filename)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r, f)
|
||||
}
|
||||
|
||||
remaining := r.end - r.start
|
||||
|
||||
bufSize := uint64(10240)
|
||||
var finished bool
|
||||
for {
|
||||
data := make([]byte, min(bufSize, remaining))
|
||||
|
||||
b, err := io.ReadFull(f, data)
|
||||
remaining -= uint64(b)
|
||||
|
||||
if err == io.EOF || remaining <= 0 {
|
||||
finished = true
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// convert bfloat16 -> ieee float32
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
// convert float32 -> float16
|
||||
tempBuf := make([]uint16, len(data)/2)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, tempBuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
|
||||
switch len(params.Architectures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "MistralForCausalLM":
|
||||
return &MistralModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
},
|
||||
}, nil
|
||||
case "GemmaForCausalLM":
|
||||
return &GemmaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown error")
|
||||
}
|
||||
|
||||
@@ -65,13 +65,14 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
|
||||
}
|
||||
|
||||
func (m *GemmaModel) GetTensors() error {
|
||||
t, err := GetSafeTensors(m.Path, m.Params)
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
for _, l := range t {
|
||||
if strings.HasSuffix(l.Name, "norm.weight") {
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
@@ -85,7 +86,7 @@ func (m *GemmaModel) GetTensors() error {
|
||||
}
|
||||
|
||||
func (m *GemmaModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
176
convert/llama.go
Normal file
176
convert/llama.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/nlpodyssey/gopickle/pytorch"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type LlamaModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
|
||||
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
|
||||
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
tData := make([]uint16, len(data))
|
||||
for cnt, v := range data {
|
||||
tData[cnt] = uint16(float16.Fromfloat32(v))
|
||||
}
|
||||
|
||||
var err error
|
||||
var heads uint32
|
||||
if strings.Contains(r.t.Name, "attn_q") {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
} else if strings.Contains(r.t.Name, "attn_k") {
|
||||
heads = uint32(r.params.KeyValHeads)
|
||||
if heads == 0 {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown layer type")
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("heads = %d", heads))
|
||||
|
||||
tData, err = llamaRepack(tData, int(heads), r.t.Shape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Write(w, r.bo, tData); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
|
||||
origShape := n.Shape().Clone()
|
||||
|
||||
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
|
||||
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(origShape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newN, err := native.SelectU16(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fullTensor []uint16
|
||||
for _, v := range newN {
|
||||
fullTensor = append(fullTensor, v...)
|
||||
}
|
||||
return fullTensor, nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) GetTensors() error {
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
|
||||
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, l := range t {
|
||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||
if len(matches) > 0 {
|
||||
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
||||
wt := l.WriterTo.(torchWriterTo)
|
||||
wt.handler = llamaLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) LoadVocab() error {
|
||||
var v *Vocab
|
||||
var err error
|
||||
|
||||
slog.Debug("loading vocab")
|
||||
v, err = LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("vocab loaded")
|
||||
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
|
||||
"llama.context_length": uint32(m.Params.ContextSize),
|
||||
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
@@ -97,7 +97,7 @@ func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
|
||||
}
|
||||
|
||||
func (m *MistralModel) GetTensors() error {
|
||||
t, err := GetSafeTensors(m.Path, m.Params)
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -124,7 +124,7 @@ func (m *MistralModel) GetTensors() error {
|
||||
}
|
||||
|
||||
func (m *MistralModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
304
convert/safetensors.go
Normal file
304
convert/safetensors.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type safetensorWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
filename string
|
||||
|
||||
start, end, padding uint64
|
||||
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
|
||||
}
|
||||
|
||||
type tensorMetaData struct {
|
||||
Type string `mapstructure:"dtype"`
|
||||
Shape []int `mapstructure:"shape"`
|
||||
Offsets []int `mapstructure:"data_offsets"`
|
||||
}
|
||||
|
||||
type SafetensorFormat struct{}
|
||||
|
||||
func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
slog.Debug("getting tensor data")
|
||||
var tensors []llm.Tensor
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
for _, f := range files {
|
||||
var t []llm.Tensor
|
||||
var err error
|
||||
t, offset, err = m.readTensors(f, offset, params)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
|
||||
return tensors, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var jsonSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
buf := make([]byte, jsonSize)
|
||||
_, err = io.ReadFull(f, buf)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewBuffer(buf))
|
||||
d.UseNumber()
|
||||
var parsed map[string]interface{}
|
||||
if err = d.Decode(&parsed); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for k := range parsed {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
slices.Sort(keys)
|
||||
|
||||
slog.Info("converting layers")
|
||||
|
||||
var tensors []llm.Tensor
|
||||
for _, k := range keys {
|
||||
vals := parsed[k].(map[string]interface{})
|
||||
var data tensorMetaData
|
||||
if err = mapstructure.Decode(vals, &data); err != nil {
|
||||
slog.Error("couldn't decode properly")
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("metadata = %#v", data))
|
||||
var size uint64
|
||||
var kind uint32
|
||||
switch len(data.Shape) {
|
||||
case 0:
|
||||
// metadata
|
||||
continue
|
||||
case 1:
|
||||
// convert to float32
|
||||
kind = 0
|
||||
size = uint64(data.Shape[0] * 4)
|
||||
case 2:
|
||||
// convert to float16
|
||||
kind = 1
|
||||
size = uint64(data.Shape[0] * data.Shape[1] * 2)
|
||||
}
|
||||
|
||||
ggufName, err := m.GetLayerName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
for i := range data.Shape {
|
||||
shape[i] = uint64(data.Shape[i])
|
||||
}
|
||||
|
||||
t := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
t.WriterTo = safetensorWriterTo{
|
||||
t: &t,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
filename: fn,
|
||||
start: uint64(data.Offsets[0]),
|
||||
end: uint64(data.Offsets[1]),
|
||||
padding: 8 + jsonSize,
|
||||
}
|
||||
|
||||
tensors = append(tensors, t)
|
||||
offset += size
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
|
||||
slog.Debug(fmt.Sprintf("offset = %d", offset))
|
||||
return tensors, offset, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var params Params
|
||||
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(¶ms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
|
||||
directMap := map[string]string{
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
}
|
||||
|
||||
tMap := map[string]string{
|
||||
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
||||
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
||||
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
||||
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
||||
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
||||
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||
}
|
||||
|
||||
v, ok := directMap[n]
|
||||
if ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// quick hack to rename the layers to gguf format
|
||||
for k, v := range tMap {
|
||||
re := regexp.MustCompile(k)
|
||||
newName := re.ReplaceAllString(n, v)
|
||||
if newName != n {
|
||||
return newName, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
f, err := os.Open(r.filename)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r, f)
|
||||
}
|
||||
|
||||
remaining := r.end - r.start
|
||||
|
||||
bufSize := uint64(10240)
|
||||
var finished bool
|
||||
for {
|
||||
data := make([]byte, min(bufSize, remaining))
|
||||
|
||||
b, err := io.ReadFull(f, data)
|
||||
remaining -= uint64(b)
|
||||
|
||||
if err == io.EOF || remaining <= 0 {
|
||||
finished = true
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// convert bfloat16 -> ieee float32
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
// convert float32 -> float16
|
||||
tempBuf := make([]uint16, len(data)/2)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
if err := binary.Write(w, r.bo, tempBuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
|
||||
switch len(params.Architectures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "MistralForCausalLM":
|
||||
return &MistralModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
case "GemmaForCausalLM":
|
||||
return &GemmaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown error")
|
||||
}
|
||||
286
convert/torch.go
Normal file
286
convert/torch.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/nlpodyssey/gopickle/pytorch"
|
||||
"github.com/nlpodyssey/gopickle/types"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type torchWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
storage pytorch.StorageInterface
|
||||
handler func(w io.Writer, r torchWriterTo) error
|
||||
}
|
||||
|
||||
type TorchFormat struct{}
|
||||
|
||||
func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
slog.Debug("getting torch tensors")
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
|
||||
if err != nil {
|
||||
slog.Error("didn't find any torch files")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
|
||||
var tensors []llm.Tensor
|
||||
for _, fn := range files {
|
||||
m, err := pytorch.Load(fn)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("error unpickling: %q", err))
|
||||
return []llm.Tensor{}, err
|
||||
}
|
||||
|
||||
for _, k := range m.(*types.Dict).Keys() {
|
||||
if strings.HasSuffix(k.(string), "self_attn.rotary_emb.inv_freq") {
|
||||
continue
|
||||
}
|
||||
|
||||
t, _ := m.(*types.Dict).Get(k)
|
||||
tshape := t.(*pytorch.Tensor).Size
|
||||
|
||||
var size uint64
|
||||
var kind uint32
|
||||
switch len(tshape) {
|
||||
case 0:
|
||||
continue
|
||||
case 1:
|
||||
// convert to float32
|
||||
kind = 0
|
||||
size = uint64(tshape[0] * 4)
|
||||
case 2:
|
||||
// convert to float16
|
||||
kind = 1
|
||||
size = uint64(tshape[0] * tshape[1] * 2)
|
||||
}
|
||||
|
||||
ggufName, err := tf.GetLayerName(k.(string))
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
for i := range tshape {
|
||||
shape[i] = uint64(tshape[i])
|
||||
}
|
||||
|
||||
tensor := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset, // calculate the offset
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
tensor.WriterTo = torchWriterTo{
|
||||
t: &tensor,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
storage: t.(*pytorch.Tensor).Source,
|
||||
}
|
||||
|
||||
tensors = append(tensors, tensor)
|
||||
offset += size
|
||||
}
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
|
||||
}
|
||||
|
||||
func getAltParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "params.json"))
|
||||
if err != nil {
|
||||
slog.Error("no params.json")
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
type TorchParams struct {
|
||||
HiddenSize int `json:"dim"`
|
||||
AttentionHeads int `json:"n_heads"`
|
||||
KeyValHeads int `json:"n_kv_heads"`
|
||||
HiddenLayers int `json:"n_layers"`
|
||||
RopeTheta int `json:"rope_theta"`
|
||||
NormEPS float64 `json:"norm_eps"`
|
||||
}
|
||||
|
||||
var tparams TorchParams
|
||||
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(&tparams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := &Params{
|
||||
HiddenSize: tparams.HiddenSize,
|
||||
AttentionHeads: tparams.AttentionHeads,
|
||||
KeyValHeads: tparams.KeyValHeads,
|
||||
HiddenLayers: tparams.HiddenLayers,
|
||||
NormEPS: tparams.NormEPS,
|
||||
}
|
||||
|
||||
switch {
|
||||
case tparams.RopeTheta == 1000000:
|
||||
// Codellama
|
||||
params.ContextSize = 16384
|
||||
case tparams.NormEPS == 1e-06:
|
||||
// llama2
|
||||
slog.Debug("Found llama2 - setting context size to 4096")
|
||||
params.ContextSize = 4096
|
||||
default:
|
||||
params.ContextSize = 2048
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (m *TorchFormat) GetParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// try params.json instead
|
||||
return getAltParams(dirpath)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var params Params
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(¶ms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
func (m *TorchFormat) GetLayerName(n string) (string, error) {
|
||||
directMap := map[string]string{
|
||||
"tok_embeddings.weight": "token_embd.weight",
|
||||
"output.weight": "output.weight",
|
||||
"norm.weight": "output_norm.weight",
|
||||
"rope.freqs": "rope_freqs.weight",
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
}
|
||||
|
||||
lMap := map[string]string{
|
||||
"layers.(\\d+).attention_norm.weight": "blk.$1.attn_norm.weight",
|
||||
"layers.(\\d+).attention_output_norm.weight": "blk.$1.attn_norm.weight",
|
||||
"layers.(\\d+).feed_forward.w2.weight": "blk.$1.ffn_down.weight",
|
||||
"layers.(\\d+).feed_forward.w1.weight": "blk.$1.ffn_gate.weight",
|
||||
"layers.(\\d+).feed_forward.w3.weight": "blk.$1.ffn_up.weight",
|
||||
"layers.(\\d+).ffn_norm.weight": "blk.$1.ffn_norm.weight",
|
||||
"layers.(\\d+).attention.wk.weight": "blk.$1.attn_k.weight",
|
||||
"layers.(\\d+).attention.wo.weight": "blk.$1.attn_output.weight",
|
||||
"layers.(\\d+).attention.wq.weight": "blk.$1.attn_q.weight",
|
||||
"layers.(\\d+).attention.wv.weight": "blk.$1.attn_v.weight",
|
||||
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
||||
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
||||
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
||||
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
||||
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
||||
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||
}
|
||||
|
||||
v, ok := directMap[n]
|
||||
if ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// quick hack to rename the layers to gguf format
|
||||
for k, v := range lMap {
|
||||
re := regexp.MustCompile(k)
|
||||
newName := re.ReplaceAllString(n, v)
|
||||
if newName != n {
|
||||
return newName, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r)
|
||||
}
|
||||
|
||||
switch r.storage.(type) {
|
||||
case *pytorch.FloatStorage:
|
||||
slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name))
|
||||
return 0, nil
|
||||
case *pytorch.HalfStorage:
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
|
||||
if err := binary.Write(w, r.bo, data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
tData := make([]uint16, len(data))
|
||||
for cnt, v := range data {
|
||||
tData[cnt] = uint16(float16.Fromfloat32(v))
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
|
||||
if err := binary.Write(w, r.bo, tData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
|
||||
switch len(params.Architectures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "LlamaForCausalLM":
|
||||
return &LlamaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown error")
|
||||
}
|
||||
@@ -139,9 +139,6 @@ PARAMETER <parameter> <parametervalue>
|
||||
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| num_gqa | The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b | int | num_gqa 1 |
|
||||
| num_gpu | The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 50 |
|
||||
| num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
|
||||
@@ -18,7 +18,7 @@ const ollama = new Ollama({
|
||||
model: "llama2",
|
||||
});
|
||||
|
||||
const answer = await ollama.call(`why is the sky blue?`);
|
||||
const answer = await ollama.invoke(`why is the sky blue?`);
|
||||
|
||||
console.log(answer);
|
||||
```
|
||||
|
||||
51
examples/go-chat/main.go
Normal file
51
examples/go-chat/main.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
messages := []api.Message{
|
||||
api.Message{
|
||||
Role: "system",
|
||||
Content: "Provide very brief, concise responses",
|
||||
},
|
||||
api.Message{
|
||||
Role: "user",
|
||||
Content: "Name some unusual animals",
|
||||
},
|
||||
api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Monotreme, platypus, echidna",
|
||||
},
|
||||
api.Message{
|
||||
Role: "user",
|
||||
Content: "which of these is the most dangerous?",
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := &api.ChatRequest{
|
||||
Model: "llama2",
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
respFunc := func(resp api.ChatResponse) error {
|
||||
fmt.Print(resp.Message.Content)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Chat(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
47
examples/go-multimodal/main.go
Normal file
47
examples/go-multimodal/main.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) <= 1 {
|
||||
log.Fatal("usage: <image name>")
|
||||
}
|
||||
|
||||
imgData, err := os.ReadFile(os.Args[1])
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: "llava",
|
||||
Prompt: "describe this image",
|
||||
Images: []api.ImageData{imgData},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
// In streaming mode, responses are partial so we call fmt.Print (and not
|
||||
// Println) in order to avoid spurious newlines being introduced. The
|
||||
// model will insert its own newlines if it wants.
|
||||
fmt.Print(resp.Response)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
31
examples/go-pull-progress/main.go
Normal file
31
examples/go-pull-progress/main.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
req := &api.PullRequest{
|
||||
Model: "mistral",
|
||||
}
|
||||
progressFunc := func(resp api.ProgressResponse) error {
|
||||
fmt.Printf("Progress: status=%v, total=%v, completed=%v\n", resp.Status, resp.Total, resp.Completed)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Pull(ctx, req, progressFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
7
go.mod
7
go.mod
@@ -19,7 +19,10 @@ require (
|
||||
golang.org/x/sync v0.3.0
|
||||
)
|
||||
|
||||
require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
require (
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
|
||||
@@ -68,7 +71,7 @@ require (
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.org/x/term v0.13.0
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
6
go.sum
6
go.sum
@@ -122,6 +122,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
|
||||
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9 h1:DV4iXjNn6fGeDl1AkZ1I0QB/0DBjrc7kPpxHrmuDzW4=
|
||||
@@ -236,8 +238,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -32,6 +32,7 @@ func CheckVRAM() (uint64, error) {
|
||||
// gpu not supported, this may not be metal
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return uint64(C.getRecommendedMaxVRAM()), nil
|
||||
}
|
||||
|
||||
@@ -52,7 +53,7 @@ func GetGPUInfo() GpuInfo {
|
||||
|
||||
func getCPUMem() (memInfo, error) {
|
||||
return memInfo{
|
||||
TotalMemory: 0,
|
||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||
FreeMemory: 0,
|
||||
DeviceCount: 0,
|
||||
}, nil
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#import <Metal/Metal.h>
|
||||
#include <stdint.h>
|
||||
uint64_t getRecommendedMaxVRAM();
|
||||
uint64_t getPhysicalMemory();
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
//go:build darwin
|
||||
// go:build darwin
|
||||
#include "gpu_info_darwin.h"
|
||||
|
||||
uint64_t getRecommendedMaxVRAM()
|
||||
{
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
uint64_t result = device.recommendedMaxWorkingSetSize;
|
||||
CFRelease(device);
|
||||
return result;
|
||||
uint64_t getRecommendedMaxVRAM() {
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
uint64_t result = device.recommendedMaxWorkingSetSize;
|
||||
CFRelease(device);
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t getPhysicalMemory() {
|
||||
return [[NSProcessInfo processInfo] physicalMemory];
|
||||
}
|
||||
|
||||
13
llm/ggml.go
13
llm/ggml.go
@@ -330,6 +330,8 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||
headsKV := llm.KV().HeadCountKV()
|
||||
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
|
||||
|
||||
layers := llm.Tensors().Layers()
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "llama":
|
||||
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
|
||||
@@ -339,6 +341,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
|
||||
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
|
||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||
partialOffload = max(
|
||||
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
||||
)
|
||||
}
|
||||
case "gemma":
|
||||
fullOffload = 4 * batch * (embedding + vocab)
|
||||
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
|
||||
@@ -350,7 +361,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+ 4*embedding*context+embedding*embedding*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
|
||||
)
|
||||
case "qwen2":
|
||||
fullOffload = max(
|
||||
|
||||
20
llm/gguf.go
20
llm/gguf.go
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type containerGGUF struct {
|
||||
@@ -52,6 +54,7 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
|
||||
}
|
||||
|
||||
model := newGGUF(c)
|
||||
slog.Debug(fmt.Sprintf("model = %#v", model))
|
||||
if err := model.Decode(rs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -187,6 +190,8 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
llm.kv[k] = v
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
|
||||
|
||||
// decode tensors
|
||||
for i := 0; uint64(i) < llm.numTensor(); i++ {
|
||||
name, err := readGGUFString(llm, rs)
|
||||
@@ -243,7 +248,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
}
|
||||
|
||||
padding := llm.padding(offset, int64(alignment))
|
||||
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
|
||||
if _, err := rs.Seek(padding-offset, io.SeekCurrent); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -451,6 +456,7 @@ var ggufKVOrder = map[string][]string{
|
||||
"llama": {
|
||||
"general.architecture",
|
||||
"general.name",
|
||||
"llama.vocab_size",
|
||||
"llama.context_length",
|
||||
"llama.embedding_length",
|
||||
"llama.block_count",
|
||||
@@ -509,11 +515,17 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
kvCheck := make(map[string]bool)
|
||||
for k := range kv {
|
||||
kvCheck[k] = false
|
||||
}
|
||||
|
||||
for _, k := range ggufKVOrder["llama"] {
|
||||
v, ok := kv[k]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
kvCheck[k] = true
|
||||
|
||||
if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
|
||||
return err
|
||||
@@ -567,6 +579,12 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range kvCheck {
|
||||
if !v {
|
||||
return fmt.Errorf("Didn't know how to write kv %s", k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tensor := range tensors {
|
||||
if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
|
||||
return err
|
||||
|
||||
Submodule llm/llama.cpp updated: 1b67731e18...7593639ce3
@@ -17,7 +17,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -36,10 +35,6 @@ type LlamaServer struct {
|
||||
options api.Options
|
||||
}
|
||||
|
||||
var cpuOnlyFamilies = []string{
|
||||
"mamba",
|
||||
}
|
||||
|
||||
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
|
||||
f, err := os.Open(model)
|
||||
if err != nil {
|
||||
@@ -91,7 +86,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
||||
|
||||
if info.Library != "metal" {
|
||||
if memoryRequiredPartial > memoryAvailable || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture()) {
|
||||
if memoryRequiredPartial > memoryAvailable {
|
||||
info.Library = "cpu"
|
||||
}
|
||||
}
|
||||
@@ -113,7 +108,11 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||
|
||||
memoryLayerOutput := layers["output"].size()
|
||||
memoryRequiredTotal += memoryLayerOutput
|
||||
if memoryAvailable > memoryRequiredTotal {
|
||||
|
||||
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
|
||||
// disable partial offloading when model is greater than total system memory
|
||||
opts.NumGPU = 0
|
||||
} else if memoryAvailable > memoryRequiredTotal {
|
||||
layerCount = int(ggml.KV().BlockCount()) + 1
|
||||
memoryRequiredPartial = memoryRequiredTotal
|
||||
}
|
||||
@@ -277,12 +276,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||
_ = s.cmd.Wait()
|
||||
}()
|
||||
|
||||
if err = s.waitUntilRunning(); err != nil {
|
||||
slog.Error("error starting llama server", "server", servers[i], "error", err)
|
||||
s.Close()
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -383,7 +376,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LlamaServer) waitUntilRunning() error {
|
||||
func (s *LlamaServer) WaitUntilRunning() error {
|
||||
start := time.Now()
|
||||
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
|
||||
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
||||
|
||||
@@ -322,7 +322,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
|
||||
|
||||
pathName := realpath(modelFileDir, c.Args)
|
||||
|
||||
ggufName, err := convertSafetensors(name, pathName, fn)
|
||||
ggufName, err := convertModel(name, pathName, fn)
|
||||
if err != nil {
|
||||
var pathErr *fs.PathError
|
||||
switch {
|
||||
@@ -633,7 +633,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
|
||||
func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
|
||||
r, err := zip.OpenReader(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -668,17 +668,22 @@ func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (
|
||||
rc.Close()
|
||||
}
|
||||
|
||||
params, err := convert.GetParams(tempDir)
|
||||
mf, err := convert.GetModelFormat(tempDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
|
||||
params, err := mf.GetParams(tempDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "processing safetensors"})
|
||||
mArch, err := mf.GetModelArch(name, tempDir, params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "processing tensors"})
|
||||
if err := mArch.GetTensors(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -68,6 +68,18 @@ var loaded struct {
|
||||
|
||||
var defaultSessionDuration = 5 * time.Minute
|
||||
|
||||
func unload() {
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
}
|
||||
|
||||
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
||||
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(c, 10*time.Second)
|
||||
@@ -83,12 +95,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
|
||||
if needLoad {
|
||||
if loaded.llama != nil {
|
||||
slog.Info("changing loaded model")
|
||||
loaded.llama.Close()
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
unload()
|
||||
}
|
||||
|
||||
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
@@ -108,22 +115,19 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
|
||||
loaded.projectors = model.ProjectorPaths
|
||||
loaded.llama = llama
|
||||
loaded.Options = &opts
|
||||
|
||||
if err = llama.WaitUntilRunning(); err != nil {
|
||||
slog.Error("error loading llama server", "error", err)
|
||||
unload()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if loaded.expireTimer == nil {
|
||||
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
unload()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1146,9 +1150,7 @@ func Serve(ln net.Listener) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
unload()
|
||||
gpu.Cleanup()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -51,8 +48,11 @@ var (
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if !ok {
|
||||
typ, digest, ok = strings.Cut(s, ":")
|
||||
}
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{s: s}
|
||||
return Digest{s: fmt.Sprintf("%s-%s", typ, digest)}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@ package model
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -25,11 +27,17 @@ var (
|
||||
|
||||
// Defaults
|
||||
const (
|
||||
// DefaultMask is the default mask used by [Name.DisplayShortest].
|
||||
DefaultMask = "registry.ollama.ai/library/_:latest"
|
||||
// MaskDefault is the default mask used by [Name.DisplayShortest].
|
||||
MaskDefault = "registry.ollama.ai/library/?:latest"
|
||||
|
||||
// MaskNothing is a mask that masks nothing.
|
||||
MaskNothing = "?/?/?:?"
|
||||
|
||||
// DefaultFill is the default fill used by [ParseName].
|
||||
DefaultFill = "registry.ollama.ai/library/_:latest"
|
||||
FillDefault = "registry.ollama.ai/library/?:latest+Q4_0"
|
||||
|
||||
// FillNothing is a fill that fills nothing.
|
||||
FillNothing = "?/?/?:?+?"
|
||||
)
|
||||
|
||||
const MaxNamePartLen = 128
|
||||
@@ -47,11 +55,11 @@ const (
|
||||
PartBuild
|
||||
PartDigest
|
||||
|
||||
// Invalid is a special part that is used to indicate that a part is
|
||||
// invalid. It is not a valid part of a Name.
|
||||
//
|
||||
// It should be kept as the last part in the list.
|
||||
PartInvalid
|
||||
// NumParts is the number of parts in a Name. In this list, it must
|
||||
// follow the final part.
|
||||
NumParts
|
||||
|
||||
PartExtraneous = -1
|
||||
)
|
||||
|
||||
var kindNames = map[PartKind]string{
|
||||
@@ -61,7 +69,6 @@ var kindNames = map[PartKind]string{
|
||||
PartTag: "Tag",
|
||||
PartBuild: "Build",
|
||||
PartDigest: "Digest",
|
||||
PartInvalid: "Invalid",
|
||||
}
|
||||
|
||||
func (k PartKind) String() string {
|
||||
@@ -96,11 +103,9 @@ func (k PartKind) String() string {
|
||||
// The parts can be obtained in their original form by calling [Name.Parts].
|
||||
//
|
||||
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
|
||||
//
|
||||
// To make a Name by filling in missing parts from another Name, use [Fill].
|
||||
type Name struct {
|
||||
_ structs.Incomparable
|
||||
parts [6]string // host, namespace, model, tag, build, digest
|
||||
parts [NumParts]string // host, namespace, model, tag, build, digest
|
||||
|
||||
// TODO(bmizerany): track offsets and hold s (raw string) here? We
|
||||
// could pack the offsets all into a single uint64 since the first
|
||||
@@ -109,7 +114,7 @@ type Name struct {
|
||||
// and mean zero allocations for String.
|
||||
}
|
||||
|
||||
// ParseNameFill parses s into a Name, and returns the result of filling it with
|
||||
// ParseName parses s into a Name, and returns the result of filling it with
|
||||
// defaults. The input string must be a valid string
|
||||
// representation of a model name in the form:
|
||||
//
|
||||
@@ -139,19 +144,19 @@ type Name struct {
|
||||
//
|
||||
// It returns the zero value if any part is invalid.
|
||||
//
|
||||
// As a rule of thumb, an valid name is one that can be round-tripped with
|
||||
// the [Name.String] method. That means ("x+") is invalid because
|
||||
// [Name.String] will not print a "+" if the build is empty.
|
||||
// # Fills
|
||||
//
|
||||
// For more about filling in missing parts, see [Fill].
|
||||
func ParseNameFill(s, defaults string) Name {
|
||||
// For any valid s, the fill string is used to fill in missing parts of the
|
||||
// Name. The fill string must be a valid Name with the exception that any part
|
||||
// may be the string ("?"), which will not be considered for filling.
|
||||
func ParseName(s, fill string) Name {
|
||||
var r Name
|
||||
parts(s)(func(kind PartKind, part string) bool {
|
||||
if kind == PartInvalid {
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
r = Name{}
|
||||
return false
|
||||
}
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
if kind == PartExtraneous || !isValidPart(kind, part) {
|
||||
r = Name{}
|
||||
return false
|
||||
}
|
||||
@@ -159,34 +164,51 @@ func ParseNameFill(s, defaults string) Name {
|
||||
return true
|
||||
})
|
||||
if r.IsValid() || r.IsResolved() {
|
||||
if defaults == "" {
|
||||
return r
|
||||
}
|
||||
return Fill(r, ParseNameFill(defaults, ""))
|
||||
return fillName(r, fill)
|
||||
}
|
||||
return Name{}
|
||||
}
|
||||
|
||||
// ParseName is equal to ParseNameFill(s, DefaultFill).
|
||||
func ParseName(s string) Name {
|
||||
return ParseNameFill(s, DefaultFill)
|
||||
func parseMask(s string) Name {
|
||||
var r Name
|
||||
parts(s)(func(kind PartKind, part string) bool {
|
||||
if part == "?" {
|
||||
// mask part; treat as empty but valid
|
||||
return true
|
||||
}
|
||||
if !isValidPart(kind, part) {
|
||||
panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
|
||||
}
|
||||
r.parts[kind] = part
|
||||
return true
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func MustParseNameFill(s, defaults string) Name {
|
||||
r := ParseNameFill(s, "")
|
||||
func MustParseName(s, fill string) Name {
|
||||
r := ParseName(s, fill)
|
||||
if !r.IsValid() {
|
||||
panic("model.MustParseName: invalid name: " + s)
|
||||
panic("invalid Name: " + s)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Fill fills in the missing parts of dst with the parts of src.
|
||||
// fillName fills in the missing parts of dst with the parts of src.
|
||||
//
|
||||
// The returned Name will only be valid if dst is valid.
|
||||
func Fill(dst, src Name) Name {
|
||||
var r Name
|
||||
//
|
||||
// It skipps fill parts that are "?".
|
||||
func fillName(r Name, fill string) Name {
|
||||
fill = cmp.Or(fill, FillDefault)
|
||||
f := parseMask(fill)
|
||||
if fill != FillNothing && f.IsZero() {
|
||||
panic("invalid fill")
|
||||
}
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
|
||||
if f.parts[i] == "?" {
|
||||
continue
|
||||
}
|
||||
r.parts[i] = cmp.Or(r.parts[i], f.parts[i])
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -212,7 +234,7 @@ func (r Name) MapHash() uint64 {
|
||||
// correctly hash the parts with case insensitive comparison
|
||||
var h maphash.Hash
|
||||
h.SetSeed(mapHashSeed)
|
||||
for _, part := range r.Parts() {
|
||||
for _, part := range r.parts {
|
||||
// downcase the part for hashing
|
||||
for i := range part {
|
||||
c := part[i]
|
||||
@@ -231,32 +253,59 @@ func (r Name) slice(from, to PartKind) Name {
|
||||
return v
|
||||
}
|
||||
|
||||
// DisplayShortest returns the shortest possible display string in form:
|
||||
// DisplayShortest returns the shortest possible, masked display string in form:
|
||||
//
|
||||
// [host/][<namespace>/]<model>[:<tag>]
|
||||
//
|
||||
// The host is omitted if it is the mask host is the same as r.
|
||||
// The namespace is omitted if the host and the namespace are the same as r.
|
||||
// The tag is omitted if it is the mask tag is the same as r.
|
||||
// # Masks
|
||||
//
|
||||
// The mask is a string that specifies which parts of the name to omit based
|
||||
// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name
|
||||
// that are the same as the mask, moving from left to right until the first
|
||||
// unequal part is found. It then moves right to left until the first unequal
|
||||
// part is found. The result is the shortest possible display string.
|
||||
//
|
||||
// Unlike a [Name] the mask can contain "?" characters which are treated as
|
||||
// wildcards. A "?" will never match a part of the name, since a valid name
|
||||
// can never contain a "?" character.
|
||||
//
|
||||
// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked
|
||||
// with ("registry.ollama.ai/library/?:latest") will produce the display string
|
||||
// ("mistral").
|
||||
//
|
||||
// If mask is the empty string, then [MaskDefault] is used.
|
||||
//
|
||||
// DisplayShortest panics if the mask is not the empty string, MaskNothing, and
|
||||
// invalid.
|
||||
//
|
||||
// # Builds
|
||||
//
|
||||
// For now, DisplayShortest does consider the build or return one in the
|
||||
// result. We can lift this restriction when needed.
|
||||
func (r Name) DisplayShortest(mask string) string {
|
||||
mask = cmp.Or(mask, DefaultMask)
|
||||
d := ParseName(mask)
|
||||
if !d.IsValid() {
|
||||
panic("mask is an invalid Name")
|
||||
mask = cmp.Or(mask, MaskDefault)
|
||||
d := parseMask(mask)
|
||||
if mask != MaskNothing && r.IsZero() {
|
||||
panic("invalid Name")
|
||||
}
|
||||
equalSlice := func(form, to PartKind) bool {
|
||||
return r.slice(form, to).EqualFold(d.slice(form, to))
|
||||
for i := range PartTag {
|
||||
if !strings.EqualFold(r.parts[i], d.parts[i]) {
|
||||
break
|
||||
}
|
||||
r.parts[i] = ""
|
||||
}
|
||||
if equalSlice(PartHost, PartNamespace) {
|
||||
r.parts[PartNamespace] = ""
|
||||
for i := PartTag; i >= 0; i-- {
|
||||
if !strings.EqualFold(r.parts[i], d.parts[i]) {
|
||||
break
|
||||
}
|
||||
r.parts[i] = ""
|
||||
}
|
||||
if equalSlice(PartHost, PartHost) {
|
||||
r.parts[PartHost] = ""
|
||||
}
|
||||
if equalSlice(PartTag, PartTag) {
|
||||
r.parts[PartTag] = ""
|
||||
}
|
||||
return r.slice(PartHost, PartTag).String()
|
||||
return r.slice(PartHost, PartTag).DisplayLong()
|
||||
}
|
||||
|
||||
// DisplayLongest returns the result of r.DisplayShortest(MaskNothing).
|
||||
func (r Name) DisplayLongest() string {
|
||||
return r.DisplayShortest(MaskNothing)
|
||||
}
|
||||
|
||||
var seps = [...]string{
|
||||
@@ -303,15 +352,12 @@ var builderPool = sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
// String returns the fullest possible display string in form:
|
||||
// DisplayLong returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// For the fullest possible display string without the build, use
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
func (r Name) DisplayLong() string {
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
defer builderPool.Put(b)
|
||||
b.Reset()
|
||||
@@ -321,14 +367,14 @@ func (r Name) String() string {
|
||||
}
|
||||
|
||||
// GoString implements fmt.GoStringer. It returns a string suitable for
|
||||
// debugging and logging. It is similar to [Name.String] but it always
|
||||
// debugging and logging. It is similar to [Name.DisplayLong] but it always
|
||||
// returns a string that includes all parts of the Name, with missing parts
|
||||
// replaced with a ("?").
|
||||
func (r Name) GoString() string {
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(r.parts[i], "?")
|
||||
}
|
||||
return r.String()
|
||||
return r.DisplayLong()
|
||||
}
|
||||
|
||||
// LogValue implements slog.Valuer.
|
||||
@@ -393,14 +439,11 @@ func downcase(r rune) rune {
|
||||
return r
|
||||
}
|
||||
|
||||
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
|
||||
|
||||
// Parts returns the parts of the Name in order of concreteness.
|
||||
//
|
||||
// The length of the returned slice is always 5.
|
||||
func (r Name) Parts() []string {
|
||||
return slices.Clone(r.parts[:])
|
||||
}
|
||||
func (r Name) Host() string { return r.parts[PartHost] }
|
||||
func (r Name) Namespace() string { return r.parts[PartNamespace] }
|
||||
func (r Name) Model() string { return r.parts[PartModel] }
|
||||
func (r Name) Build() string { return r.parts[PartBuild] }
|
||||
func (r Name) Tag() string { return r.parts[PartTag] }
|
||||
|
||||
// iter_Seq2 is a iter.Seq2 defined here to avoid the current build
|
||||
// restrictions in the go1.22 iter package requiring the
|
||||
@@ -418,27 +461,16 @@ type iter_Seq2[A, B any] func(func(A, B) bool)
|
||||
// No other normalizations are performed.
|
||||
func parts(s string) iter_Seq2[PartKind, string] {
|
||||
return func(yield func(PartKind, string) bool) {
|
||||
//nolint:gosimple
|
||||
if strings.HasPrefix(s, "http://") {
|
||||
s = s[len("http://"):]
|
||||
}
|
||||
//nolint:gosimple
|
||||
if strings.HasPrefix(s, "https://") {
|
||||
s = s[len("https://"):]
|
||||
s = strings.TrimPrefix(s, "http://")
|
||||
} else {
|
||||
s = strings.TrimPrefix(s, "https://")
|
||||
}
|
||||
|
||||
if len(s) > MaxNamePartLen || len(s) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
yieldValid := func(kind PartKind, part string) bool {
|
||||
if !isValidPart(kind, part) {
|
||||
yield(PartInvalid, "")
|
||||
return false
|
||||
}
|
||||
return yield(kind, part)
|
||||
}
|
||||
|
||||
numConsecutiveDots := 0
|
||||
partLen := 0
|
||||
state, j := PartDigest, len(s)
|
||||
@@ -448,7 +480,7 @@ func parts(s string) iter_Seq2[PartKind, string] {
|
||||
// we don't keep spinning on it, waiting for
|
||||
// an isInValidPart check which would scan
|
||||
// over it again.
|
||||
yield(PartInvalid, "")
|
||||
yield(state, s[i+1:j])
|
||||
return
|
||||
}
|
||||
|
||||
@@ -456,7 +488,7 @@ func parts(s string) iter_Seq2[PartKind, string] {
|
||||
case '@':
|
||||
switch state {
|
||||
case PartDigest:
|
||||
if !yieldValid(PartDigest, s[i+1:j]) {
|
||||
if !yield(PartDigest, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
if i == 0 {
|
||||
@@ -468,73 +500,69 @@ func parts(s string) iter_Seq2[PartKind, string] {
|
||||
}
|
||||
state, j, partLen = PartBuild, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
return
|
||||
}
|
||||
case '+':
|
||||
switch state {
|
||||
case PartBuild, PartDigest:
|
||||
if !yieldValid(PartBuild, s[i+1:j]) {
|
||||
if !yield(PartBuild, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartTag, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
return
|
||||
}
|
||||
case ':':
|
||||
switch state {
|
||||
case PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartTag, s[i+1:j]) {
|
||||
if !yield(PartTag, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartModel, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
return
|
||||
}
|
||||
case '/':
|
||||
switch state {
|
||||
case PartModel, PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartModel, s[i+1:j]) {
|
||||
if !yield(PartModel, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j = PartNamespace, i
|
||||
case PartNamespace:
|
||||
if !yieldValid(PartNamespace, s[i+1:j]) {
|
||||
if !yield(PartNamespace, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartHost, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
return
|
||||
}
|
||||
default:
|
||||
if s[i] == '.' {
|
||||
if numConsecutiveDots++; numConsecutiveDots > 1 {
|
||||
yield(PartInvalid, "")
|
||||
yield(state, "")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
numConsecutiveDots = 0
|
||||
}
|
||||
if !isValidByteFor(state, s[i]) {
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state <= PartNamespace {
|
||||
yieldValid(state, s[:j])
|
||||
yield(state, s[:j])
|
||||
} else {
|
||||
yieldValid(PartModel, s[:j])
|
||||
yield(PartModel, s[:j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r Name) IsZero() bool {
|
||||
return r.parts == [6]string{}
|
||||
return r.parts == [NumParts]string{}
|
||||
}
|
||||
|
||||
// IsValid reports if a model has at minimum a valid model part.
|
||||
@@ -544,13 +572,101 @@ func (r Name) IsValid() bool {
|
||||
return r.parts[PartModel] != ""
|
||||
}
|
||||
|
||||
// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically,
|
||||
// it trims any leading "/" and then calls [ParseName] with fill.
|
||||
func ParseNameFromURLPath(s, fill string) Name {
|
||||
s = strings.TrimPrefix(s, "/")
|
||||
return ParseName(s, fill)
|
||||
}
|
||||
|
||||
// URLPath returns a complete, canonicalized, relative URL path using the parts of a
|
||||
// complete Name.
|
||||
//
|
||||
// The parts maintain their original case.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
|
||||
func (r Name) URLPath() string {
|
||||
return r.DisplayShortest(MaskNothing)
|
||||
}
|
||||
|
||||
// ParseNameFromFilepath parses a file path into a Name. The input string must be a
|
||||
// valid file path representation of a model name in the form:
|
||||
//
|
||||
// host/namespace/model/tag/build
|
||||
//
|
||||
// The zero valid is returned if s does not contain all path elements
|
||||
// leading up to the model part, or if any path element is an invalid part
|
||||
// for the its corresponding part kind.
|
||||
//
|
||||
// The fill string is used to fill in missing parts of any constructed Name.
|
||||
// See [ParseName] for more information on the fill string.
|
||||
func ParseNameFromFilepath(s, fill string) Name {
|
||||
var r Name
|
||||
for i := range PartBuild + 1 {
|
||||
part, rest, _ := strings.Cut(s, string(filepath.Separator))
|
||||
if !isValidPart(i, part) {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[i] = part
|
||||
s = rest
|
||||
if s == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if s != "" {
|
||||
return Name{}
|
||||
}
|
||||
if !r.IsValid() {
|
||||
return Name{}
|
||||
}
|
||||
return fillName(r, fill)
|
||||
}
|
||||
|
||||
// Filepath returns a complete, canonicalized, relative file path using the
|
||||
// parts of a complete Name.
|
||||
//
|
||||
// Each parts is downcased, except for the build part which is upcased.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD"
|
||||
func (r Name) Filepath() string {
|
||||
for i := range r.parts {
|
||||
if PartKind(i) == PartBuild {
|
||||
r.parts[i] = strings.ToUpper(r.parts[i])
|
||||
} else {
|
||||
r.parts[i] = strings.ToLower(r.parts[i])
|
||||
}
|
||||
}
|
||||
return filepath.Join(r.parts[:]...)
|
||||
}
|
||||
|
||||
// FilepathNoBuild returns a complete, canonicalized, relative file path using
|
||||
// the parts of a complete Name, but without the build part.
|
||||
func (r Name) FilepathNoBuild() string {
|
||||
for i := range PartBuild {
|
||||
r.parts[i] = strings.ToLower(r.parts[i])
|
||||
}
|
||||
return filepath.Join(r.parts[:PartBuild]...)
|
||||
}
|
||||
|
||||
// isValidPart reports if s contains all valid characters for the given
|
||||
// part kind.
|
||||
func isValidPart(kind PartKind, s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
var consecutiveDots int
|
||||
for _, c := range []byte(s) {
|
||||
if c == '.' {
|
||||
if consecutiveDots++; consecutiveDots >= 2 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
consecutiveDots = 0
|
||||
}
|
||||
if !isValidByteFor(kind, c) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -111,11 +112,11 @@ func TestNameConsecutiveDots(t *testing.T) {
|
||||
for i := 1; i < 10; i++ {
|
||||
s := strings.Repeat(".", i)
|
||||
if i > 1 {
|
||||
if g := ParseNameFill(s, "").String(); g != "" {
|
||||
if g := ParseName(s, FillNothing).DisplayLong(); g != "" {
|
||||
t.Errorf("ParseName(%q) = %q; want empty string", s, g)
|
||||
}
|
||||
} else {
|
||||
if g := ParseNameFill(s, "").String(); g != s {
|
||||
if g := ParseName(s, FillNothing).DisplayLong(); g != s {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
|
||||
}
|
||||
}
|
||||
@@ -124,7 +125,7 @@ func TestNameConsecutiveDots(t *testing.T) {
|
||||
|
||||
func TestNameParts(t *testing.T) {
|
||||
var p Name
|
||||
if w, g := int(PartDigest+1), len(p.Parts()); w != g {
|
||||
if w, g := int(NumParts), len(p.parts); w != g {
|
||||
t.Errorf("Parts() = %d; want %d", g, w)
|
||||
}
|
||||
}
|
||||
@@ -148,21 +149,71 @@ func TestParseName(t *testing.T) {
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
name := ParseNameFill(s, "")
|
||||
name := ParseName(s, FillNothing)
|
||||
got := fieldsFromName(name)
|
||||
if got != want {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if !ParseNameFill(name.String(), "").EqualFold(name) {
|
||||
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
|
||||
if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) {
|
||||
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameFill(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
fill string
|
||||
want string
|
||||
}{
|
||||
{"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"},
|
||||
{"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"},
|
||||
{"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"},
|
||||
|
||||
// Invalid
|
||||
{"", "example.com/library/?:latest+Q4_0", ""},
|
||||
{"llama2:?", "example.com/library/?:latest+Q4_0", ""},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
name := ParseName(tt.in, tt.fill)
|
||||
if g := name.DisplayLong(); g != tt.want {
|
||||
t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("invalid fill", func(t *testing.T) {
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("expected panic")
|
||||
}
|
||||
}()
|
||||
ParseName("x", "^")
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseNameHTTPDoublePrefixStrip(t *testing.T) {
|
||||
cases := []string{
|
||||
"http://https://valid.com/valid/valid:latest",
|
||||
"https://http://valid.com/valid/valid:latest",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
name := ParseName(s, FillNothing)
|
||||
if name.IsValid() {
|
||||
t.Errorf("expected invalid path; got %#v", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
@@ -179,7 +230,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseNameFill(tt.in, "")
|
||||
p := ParseName(tt.in, FillNothing)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.IsComplete(); g != tt.complete {
|
||||
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
|
||||
@@ -194,7 +245,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
// inlined when used in Complete, preventing any allocations or
|
||||
// escaping to the heap.
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete())
|
||||
keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("Complete allocs = %v; want 0", allocs)
|
||||
@@ -211,7 +262,7 @@ func TestNameLogValue(t *testing.T) {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
log := slog.New(slog.NewTextHandler(&b, nil))
|
||||
name := ParseNameFill(s, "")
|
||||
name := ParseName(s, FillNothing)
|
||||
log.Info("", "name", name)
|
||||
want := fmt.Sprintf("name=%s", name.GoString())
|
||||
got := b.String()
|
||||
@@ -258,7 +309,7 @@ func TestNameGoString(t *testing.T) {
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := ParseNameFill(tt.in, "")
|
||||
p := ParseName(tt.in, FillNothing)
|
||||
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
|
||||
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
|
||||
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
|
||||
@@ -267,6 +318,13 @@ func TestNameGoString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisplayLongest(t *testing.T) {
|
||||
g := ParseName("example.com/library/mistral:latest+Q4_0", FillNothing).DisplayLongest()
|
||||
if g != "example.com/library/mistral:latest" {
|
||||
t.Errorf("got = %q; want %q", g, "example.com/library/mistral:latest")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisplayShortest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
@@ -286,11 +344,14 @@ func TestDisplayShortest(t *testing.T) {
|
||||
{"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
|
||||
{"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
|
||||
|
||||
// zero value
|
||||
{"", MaskDefault, "", true},
|
||||
|
||||
// invalid mask
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
|
||||
|
||||
// DefaultMask
|
||||
{"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false},
|
||||
{"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false},
|
||||
|
||||
// Auto-Fill
|
||||
{"x", "example.com/library/_:latest", "x", false},
|
||||
@@ -309,7 +370,7 @@ func TestDisplayShortest(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
p := ParseNameFill(tt.in, "")
|
||||
p := ParseName(tt.in, FillNothing)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.DisplayShortest(tt.mask); g != tt.want {
|
||||
t.Errorf("got = %q; want %q", g, tt.want)
|
||||
@@ -320,7 +381,7 @@ func TestDisplayShortest(t *testing.T) {
|
||||
|
||||
func TestParseNameAllocs(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
|
||||
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("ParseName allocs = %v; want 0", allocs)
|
||||
@@ -331,10 +392,26 @@ func BenchmarkParseName(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for range b.N {
|
||||
keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
|
||||
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParseNameFromFilepath(f *testing.F) {
|
||||
f.Add("example.com/library/mistral/7b/Q4_0")
|
||||
f.Add("example.com/../mistral/7b/Q4_0")
|
||||
f.Add("example.com/x/../7b/Q4_0")
|
||||
f.Add("example.com/x/../7b")
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
name := ParseNameFromFilepath(s, FillNothing)
|
||||
if strings.Contains(s, "..") && !name.IsZero() {
|
||||
t.Fatalf("non-zero value for path with '..': %q", s)
|
||||
}
|
||||
if name.IsValid() == name.IsZero() {
|
||||
t.Errorf("expected valid path to be non-zero value; got %#v", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzParseName(f *testing.F) {
|
||||
f.Add("example.com/mistral:7b+Q4_0")
|
||||
f.Add("example.com/mistral:7b+q4_0")
|
||||
@@ -346,7 +423,7 @@ func FuzzParseName(f *testing.F) {
|
||||
f.Add(":@!@")
|
||||
f.Add("...")
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
r0 := ParseNameFill(s, "")
|
||||
r0 := ParseName(s, FillNothing)
|
||||
|
||||
if strings.Contains(s, "..") && !r0.IsZero() {
|
||||
t.Fatalf("non-zero value for path with '..': %q", s)
|
||||
@@ -359,73 +436,214 @@ func FuzzParseName(f *testing.F) {
|
||||
t.Skipf("invalid path: %q", s)
|
||||
}
|
||||
|
||||
for _, p := range r0.Parts() {
|
||||
for _, p := range r0.parts {
|
||||
if len(p) > MaxNamePartLen {
|
||||
t.Errorf("part too long: %q", p)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.EqualFold(r0.String(), s) {
|
||||
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
|
||||
if !strings.EqualFold(r0.DisplayLong(), s) {
|
||||
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s)
|
||||
}
|
||||
|
||||
r1 := ParseNameFill(r0.String(), "")
|
||||
r1 := ParseName(r0.DisplayLong(), FillNothing)
|
||||
if !r0.EqualFold(r1) {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFill(t *testing.T) {
|
||||
cases := []struct {
|
||||
dst string
|
||||
src string
|
||||
want string
|
||||
}{
|
||||
{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.dst, func(t *testing.T) {
|
||||
r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, ""))
|
||||
if r.String() != tt.want {
|
||||
t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameStringAllocs(t *testing.T) {
|
||||
name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "")
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing)
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(name.String())
|
||||
keep(name.DisplayLong())
|
||||
})
|
||||
if allocs > 1 {
|
||||
t.Errorf("String allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleFill() {
|
||||
defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "")
|
||||
r := Fill(ParseNameFill("mistral", ""), defaults)
|
||||
fmt.Println(r)
|
||||
func TestNamePath(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"},
|
||||
|
||||
// Output:
|
||||
// registry.ollama.com/library/mistral:latest+Q4_0
|
||||
// incomplete
|
||||
{"example.com/library/mistral:latest", "example.com/library/mistral:latest"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in, FillNothing)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.URLPath(); g != tt.want {
|
||||
t.Errorf("got = %q; want %q", g, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameFilepath(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantNoBuild string
|
||||
}{
|
||||
{
|
||||
in: "example.com/library/mistral:latest+Q4_0",
|
||||
want: "example.com/library/mistral/latest/Q4_0",
|
||||
wantNoBuild: "example.com/library/mistral/latest",
|
||||
},
|
||||
{
|
||||
in: "Example.Com/Library/Mistral:Latest+Q4_0",
|
||||
want: "example.com/library/mistral/latest/Q4_0",
|
||||
wantNoBuild: "example.com/library/mistral/latest",
|
||||
},
|
||||
{
|
||||
in: "Example.Com/Library/Mistral:Latest+Q4_0",
|
||||
want: "example.com/library/mistral/latest/Q4_0",
|
||||
wantNoBuild: "example.com/library/mistral/latest",
|
||||
},
|
||||
{
|
||||
in: "example.com/library/mistral:latest",
|
||||
want: "example.com/library/mistral/latest",
|
||||
wantNoBuild: "example.com/library/mistral/latest",
|
||||
},
|
||||
{
|
||||
in: "",
|
||||
want: "",
|
||||
wantNoBuild: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in, FillNothing)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
g := p.Filepath()
|
||||
g = filepath.ToSlash(g)
|
||||
if g != tt.want {
|
||||
t.Errorf("got = %q; want %q", g, tt.want)
|
||||
}
|
||||
g = p.FilepathNoBuild()
|
||||
g = filepath.ToSlash(g)
|
||||
if g != tt.wantNoBuild {
|
||||
t.Errorf("got = %q; want %q", g, tt.wantNoBuild)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameFilepath(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
fill string // default is FillNothing
|
||||
want string
|
||||
}{
|
||||
{
|
||||
in: "example.com/library/mistral/latest/Q4_0",
|
||||
want: "example.com/library/mistral:latest+Q4_0",
|
||||
},
|
||||
{
|
||||
in: "example.com/library/mistral/latest",
|
||||
fill: "?/?/?:latest+Q4_0",
|
||||
want: "example.com/library/mistral:latest+Q4_0",
|
||||
},
|
||||
{
|
||||
in: "example.com/library/mistral",
|
||||
fill: "?/?/?:latest+Q4_0",
|
||||
want: "example.com/library/mistral:latest+Q4_0",
|
||||
},
|
||||
{
|
||||
in: "example.com/library",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: "example.com/",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: "example.com/^/mistral/latest/Q4_0",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: "example.com/library/mistral/../Q4_0",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: "example.com/library/mistral/latest/Q4_0/extra",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator))
|
||||
fill := cmp.Or(tt.fill, FillNothing)
|
||||
want := ParseName(tt.want, fill)
|
||||
if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) {
|
||||
t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameFromPath(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
fill string // default is FillNothing
|
||||
}{
|
||||
{
|
||||
in: "example.com/library/mistral:latest+Q4_0",
|
||||
want: "example.com/library/mistral:latest+Q4_0",
|
||||
},
|
||||
{
|
||||
in: "/example.com/library/mistral:latest+Q4_0",
|
||||
want: "example.com/library/mistral:latest+Q4_0",
|
||||
},
|
||||
{
|
||||
in: "/example.com/library/mistral",
|
||||
want: "example.com/library/mistral",
|
||||
},
|
||||
{
|
||||
in: "/example.com/library/mistral",
|
||||
fill: "?/?/?:latest+Q4_0",
|
||||
want: "example.com/library/mistral:latest+Q4_0",
|
||||
},
|
||||
{
|
||||
in: "/example.com/library",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: "/example.com/",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: "/example.com/^/mistral/latest",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
fill := cmp.Or(tt.fill, FillNothing)
|
||||
if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want {
|
||||
t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleName_MapHash() {
|
||||
m := map[uint64]bool{}
|
||||
|
||||
// key 1
|
||||
m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true
|
||||
m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true
|
||||
m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true
|
||||
m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true
|
||||
m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true
|
||||
m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true
|
||||
|
||||
// key 2
|
||||
m[ParseNameFill("mistral:LATest", "").MapHash()] = true
|
||||
m[ParseName("mistral:LATest", FillNothing).MapHash()] = true
|
||||
|
||||
fmt.Println(len(m))
|
||||
// Output:
|
||||
@@ -434,15 +652,15 @@ func ExampleName_MapHash() {
|
||||
|
||||
func ExampleName_CompareFold_sort() {
|
||||
names := []Name{
|
||||
ParseNameFill("mistral:latest", ""),
|
||||
ParseNameFill("mistRal:7b+q4", ""),
|
||||
ParseNameFill("MIstral:7b", ""),
|
||||
ParseName("mistral:latest", FillNothing),
|
||||
ParseName("mistRal:7b+q4", FillNothing),
|
||||
ParseName("MIstral:7b", FillNothing),
|
||||
}
|
||||
|
||||
slices.SortFunc(names, Name.CompareFold)
|
||||
|
||||
for _, n := range names {
|
||||
fmt.Println(n)
|
||||
fmt.Println(n.DisplayLong())
|
||||
}
|
||||
|
||||
// Output:
|
||||
@@ -457,7 +675,7 @@ func ExampleName_completeAndResolved() {
|
||||
"x/y/z:latest+q4_0",
|
||||
"@sha123-1",
|
||||
} {
|
||||
name := ParseNameFill(s, "")
|
||||
name := ParseName(s, FillNothing)
|
||||
fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
|
||||
}
|
||||
|
||||
@@ -468,7 +686,7 @@ func ExampleName_completeAndResolved() {
|
||||
}
|
||||
|
||||
func ExampleName_DisplayShortest() {
|
||||
name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "")
|
||||
name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing)
|
||||
|
||||
fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
|
||||
fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
|
||||
@@ -476,7 +694,7 @@ func ExampleName_DisplayShortest() {
|
||||
fmt.Println(name.DisplayShortest("_/_/_:_"))
|
||||
|
||||
// Default
|
||||
name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "")
|
||||
name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing)
|
||||
fmt.Println(name.DisplayShortest(""))
|
||||
|
||||
// Output:
|
||||
|
||||
Reference in New Issue
Block a user