mirror of
https://github.com/ollama/ollama.git
synced 2026-02-27 04:27:01 -05:00
Compare commits
4 Commits
pdevine/sa
...
jessegross
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6586e61525 | ||
|
|
f6229d2464 | ||
|
|
fda19c1282 | ||
|
|
93ca624a7c |
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
||||
TotalSize() uint64
|
||||
MemorySize() (total, vram uint64)
|
||||
VRAMByGPU(id ml.DeviceID) uint64
|
||||
Pid() int
|
||||
GetPort() int
|
||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
// Windows CUDA should not use mmap for best performance
|
||||
// Linux with a model larger than free space, mmap leads to thrashing
|
||||
// For CPU loads we want the memory to be allocated, not FS cache
|
||||
totalSize, _ := s.MemorySize()
|
||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMSize() uint64 {
|
||||
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
var mem uint64
|
||||
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
vram += g.Size()
|
||||
}
|
||||
|
||||
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||
|
||||
// Some elements are always on CPU. However, if we have allocated all layers
|
||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||
noCPULayers := true
|
||||
@@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
||||
}
|
||||
}
|
||||
if noCPULayers {
|
||||
mem += s.mem.InputWeights
|
||||
mem += s.mem.CPU.Graph
|
||||
vram += s.mem.InputWeights
|
||||
vram += s.mem.CPU.Graph
|
||||
}
|
||||
|
||||
return mem
|
||||
}
|
||||
|
||||
func (s *llmServer) TotalSize() uint64 {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
mem := s.mem.InputWeights
|
||||
mem += s.mem.CPU.Size()
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
}
|
||||
|
||||
return mem
|
||||
return total, vram
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
|
||||
@@ -71,6 +71,10 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
func (m *Model) IsMLX() bool {
|
||||
return m.Config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
if truncate {
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1951,6 +1952,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
}
|
||||
if v.llama != nil {
|
||||
mr.ContextLength = v.llama.ContextLength()
|
||||
total, vram := v.llama.MemorySize()
|
||||
mr.Size = int64(total)
|
||||
mr.SizeVRAM = int64(vram)
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
@@ -2213,6 +2217,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
if m.IsMLX() {
|
||||
truncate = false
|
||||
}
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
|
||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if pending.model.IsMLX() {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
@@ -536,6 +536,7 @@ iGPUScan:
|
||||
}
|
||||
}
|
||||
|
||||
totalSize, vramSize := llama.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -545,8 +546,8 @@ iGPUScan:
|
||||
sessionDuration: sessionDuration,
|
||||
gpus: gpuIDs,
|
||||
discreteGPUs: discreteGPUs,
|
||||
vramSize: llama.VRAMSize(),
|
||||
totalSize: llama.TotalSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
loading: true,
|
||||
pid: llama.Pid(),
|
||||
}
|
||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
totalSize, vramSize := server.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
loading: false,
|
||||
isImagegen: isImagegen,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
||||
s.closeCalled = true
|
||||
return s.closeResp
|
||||
}
|
||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
||||
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||
func (s *mockLlm) Pid() int { return -1 }
|
||||
func (s *mockLlm) GetPort() int { return -1 }
|
||||
|
||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// VRAMSize returns the estimated VRAM usage.
|
||||
func (s *Server) VRAMSize() uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// TotalSize returns the total memory usage.
|
||||
func (s *Server) TotalSize() uint64 {
|
||||
return s.vramSize
|
||||
// MemorySize returns the total and VRAM memory usage.
|
||||
func (s *Server) MemorySize() (total, vram uint64) {
|
||||
return s.vramSize, s.vramSize
|
||||
}
|
||||
|
||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -21,23 +20,24 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||
type Client struct {
|
||||
port int
|
||||
modelName string
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
contextLength int
|
||||
memory uint
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||
@@ -98,18 +98,9 @@ func NewClient(modelName string) (*Client, error) {
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
cmd: cmd,
|
||||
@@ -201,6 +192,19 @@ type completionOpts struct {
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason int
|
||||
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
@@ -260,28 +264,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
var raw CompletionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Error != nil {
|
||||
return *raw.Error
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
EvalDuration: raw.EvalDuration,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
@@ -294,7 +294,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
return math.MaxInt
|
||||
return c.contextLength
|
||||
}
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
@@ -347,9 +347,16 @@ func (c *Client) Pid() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
type statusResponse struct {
|
||||
Status int
|
||||
Progress int
|
||||
ContextLength int
|
||||
Memory uint
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -362,6 +369,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var status statusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.contextLength = status.ContextLength
|
||||
c.memory = status.Memory
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -388,19 +404,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// TotalSize implements llm.LlamaServer.
|
||||
func (c *Client) TotalSize() uint64 {
|
||||
return c.vramSize
|
||||
func (c *Client) currentMemory() uint64 {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := c.Ping(ctx); err != nil {
|
||||
slog.Warn("failed to get current memory", "error", err)
|
||||
}
|
||||
return uint64(c.memory)
|
||||
}
|
||||
|
||||
// MemorySize implements llm.LlamaServer.
|
||||
func (c *Client) MemorySize() (total, vram uint64) {
|
||||
mem := c.currentMemory()
|
||||
return mem, mem
|
||||
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
|
||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMSize implements llm.LlamaServer.
|
||||
func (c *Client) VRAMSize() uint64 {
|
||||
return c.vramSize
|
||||
return c.currentMemory()
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements llm.LlamaServer.
|
||||
|
||||
@@ -20,6 +20,7 @@ type Model interface {
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
MaxContextLength() int
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
@@ -46,12 +49,28 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
|
||||
if len(inputs) >= r.contextLength {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(inputs)
|
||||
if request.Options.MaxTokens <= 0 {
|
||||
request.Options.MaxTokens = maxGenerate
|
||||
} else {
|
||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||
}
|
||||
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
|
||||
now := time.Now()
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
@@ -93,8 +112,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
@@ -105,7 +123,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
@@ -113,18 +131,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
final.EvalCount = i
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
case request.Responses <- CompletionResponse{
|
||||
Content: r.Decode(output, &b),
|
||||
}:
|
||||
}
|
||||
|
||||
@@ -137,7 +153,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
final.EvalDuration = time.Since(now)
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
|
||||
@@ -4,14 +4,15 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
Responses chan Response
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(Request) error
|
||||
|
||||
Ctx context.Context
|
||||
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Text string `json:"content,omitempty"`
|
||||
Token int `json:"token,omitempty"`
|
||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
|
||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
CompletionTokens int `json:"eval_count,omitempty"`
|
||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
contextLength int
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
statusErr = api.StatusError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
select {
|
||||
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||
case <-request.Ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
|
||||
@@ -50,9 +50,11 @@ func Execute(args []string) error {
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": 0,
|
||||
"progress": 100,
|
||||
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
ContextLength: runner.contextLength,
|
||||
Memory: uint(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -78,7 +80,7 @@ func Execute(args []string) error {
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan Response)}
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
@@ -87,9 +89,6 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||
if request.Options.MaxTokens < 1 {
|
||||
request.Options.MaxTokens = 16 << 10
|
||||
}
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(
|
||||
|
||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user