mirror of
https://github.com/ollama/ollama.git
synced 2025-12-30 19:19:41 -05:00
Compare commits
12 Commits
jmorganca/
...
parth/agen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96d69ee2b2 | ||
|
|
89f74a8b05 | ||
|
|
ca43de117f | ||
|
|
7ff2b373f4 | ||
|
|
805177c054 | ||
|
|
6f9fc4e1bf | ||
|
|
fc62078ba4 | ||
|
|
d08c33faa0 | ||
|
|
253b035b4a | ||
|
|
d4f9bd5fe5 | ||
|
|
18fdcc94e5 | ||
|
|
7ad036992f |
22
api/types.go
22
api/types.go
@@ -17,6 +17,12 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// SkillRef is an alias for model.SkillRef representing a skill reference.
|
||||
type SkillRef = model.SkillRef
|
||||
|
||||
// MCPRef is an alias for model.MCPRef representing an MCP server reference.
|
||||
type MCPRef = model.MCPRef
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
@@ -557,6 +563,18 @@ type CreateRequest struct {
|
||||
// Requires is the minimum version of Ollama required by the model.
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
// Skills is a list of skill references for the agent (local paths or registry refs)
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
|
||||
// MCPs is a list of MCP server references for the agent
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
|
||||
// AgentType defines the type of agent (e.g., "conversational", "task-based")
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
|
||||
// Entrypoint specifies an external command to run instead of the built-in chat loop
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
@@ -608,6 +626,10 @@ type ShowResponse struct {
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
|
||||
444
cmd/cmd.go
444
cmd/cmd.go
@@ -15,6 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -494,6 +495,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
|
||||
// Check if this is an agent
|
||||
isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != ""
|
||||
if isAgent {
|
||||
opts.IsAgent = true
|
||||
opts.AgentType = info.AgentType
|
||||
opts.Skills = info.Skills
|
||||
opts.MCPs = info.MCPs
|
||||
opts.Entrypoint = info.Entrypoint
|
||||
}
|
||||
|
||||
// Check if this is an embedding model
|
||||
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
|
||||
|
||||
@@ -517,6 +528,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// If agent has entrypoint, run it instead of chat loop
|
||||
if opts.Entrypoint != "" {
|
||||
return runEntrypoint(cmd, opts)
|
||||
}
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
@@ -545,9 +561,62 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
// For agents, use chat API even in non-interactive mode to support tools
|
||||
if opts.IsAgent {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
|
||||
_, err := chat(cmd, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop.
|
||||
func runEntrypoint(cmd *cobra.Command, opts runOptions) error {
|
||||
entrypoint := opts.Entrypoint
|
||||
|
||||
// Check if entrypoint contains $PROMPT placeholder
|
||||
hasPlaceholder := strings.Contains(entrypoint, "$PROMPT")
|
||||
|
||||
if hasPlaceholder && opts.Prompt != "" {
|
||||
// Replace $PROMPT with the actual prompt
|
||||
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt)
|
||||
} else if hasPlaceholder {
|
||||
// No prompt provided but placeholder exists - remove placeholder
|
||||
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "")
|
||||
}
|
||||
|
||||
// Parse entrypoint into command and args
|
||||
parts := strings.Fields(entrypoint)
|
||||
if len(parts) == 0 {
|
||||
return fmt.Errorf("empty entrypoint")
|
||||
}
|
||||
|
||||
command := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
// If user provided a prompt and no placeholder was used, append it as argument
|
||||
if opts.Prompt != "" && !hasPlaceholder {
|
||||
args = append(args, opts.Prompt)
|
||||
}
|
||||
|
||||
// Look up command in PATH
|
||||
execPath, err := exec.LookPath(command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("entrypoint command not found: %s", command)
|
||||
}
|
||||
|
||||
// Create subprocess
|
||||
proc := exec.Command(execPath, args...)
|
||||
proc.Stdin = os.Stdin
|
||||
proc.Stdout = os.Stdout
|
||||
proc.Stderr = os.Stderr
|
||||
|
||||
// Run and wait
|
||||
return proc.Run()
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -907,47 +976,96 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
// Only show Model section if there's actual model info (not for entrypoint-only agents)
|
||||
hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != ""
|
||||
if hasModelInfo {
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// Display agent information if this is an agent
|
||||
if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" {
|
||||
tableRender("Agent", func() (rows [][]string) {
|
||||
if resp.AgentType != "" {
|
||||
rows = append(rows, []string{"", "type", resp.AgentType})
|
||||
}
|
||||
if resp.Entrypoint != "" {
|
||||
rows = append(rows, []string{"", "entrypoint", resp.Entrypoint})
|
||||
}
|
||||
if len(resp.Skills) > 0 {
|
||||
for i, skill := range resp.Skills {
|
||||
label := "skill"
|
||||
if i > 0 {
|
||||
label = ""
|
||||
}
|
||||
// Show skill name or digest
|
||||
skillDisplay := skill.Name
|
||||
if skillDisplay == "" && skill.Digest != "" {
|
||||
skillDisplay = skill.Digest[:12] + "..."
|
||||
}
|
||||
rows = append(rows, []string{"", label, skillDisplay})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
if len(resp.MCPs) > 0 {
|
||||
for i, mcp := range resp.MCPs {
|
||||
label := "mcp"
|
||||
if i > 0 {
|
||||
label = ""
|
||||
}
|
||||
// Show MCP name and command
|
||||
mcpDisplay := mcp.Name
|
||||
if mcp.Command != "" {
|
||||
cmdLine := mcp.Command
|
||||
if len(mcp.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(mcp.Args, " ")
|
||||
}
|
||||
mcpDisplay += " (" + cmdLine + ")"
|
||||
}
|
||||
rows = append(rows, []string{"", label, mcpDisplay})
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if len(resp.Capabilities) > 0 {
|
||||
tableRender("Capabilities", func() (rows [][]string) {
|
||||
@@ -1189,6 +1307,11 @@ type runOptions struct {
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
IsAgent bool
|
||||
AgentType string
|
||||
Skills []api.SkillRef
|
||||
MCPs []api.MCPRef
|
||||
Entrypoint string
|
||||
}
|
||||
|
||||
func (r runOptions) Copy() runOptions {
|
||||
@@ -1218,6 +1341,12 @@ func (r runOptions) Copy() runOptions {
|
||||
think = &cThink
|
||||
}
|
||||
|
||||
var skills []api.SkillRef
|
||||
if r.Skills != nil {
|
||||
skills = make([]api.SkillRef, len(r.Skills))
|
||||
copy(skills, r.Skills)
|
||||
}
|
||||
|
||||
return runOptions{
|
||||
Model: r.Model,
|
||||
ParentModel: r.ParentModel,
|
||||
@@ -1233,6 +1362,9 @@ func (r runOptions) Copy() runOptions {
|
||||
Think: think,
|
||||
HideThinking: r.HideThinking,
|
||||
ShowConnect: r.ShowConnect,
|
||||
IsAgent: r.IsAgent,
|
||||
AgentType: r.AgentType,
|
||||
Skills: skills,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1316,6 +1448,65 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load skills for agents
|
||||
var skillsCatalog *skillCatalog
|
||||
if opts.IsAgent && len(opts.Skills) > 0 {
|
||||
skillsCatalog, err = loadSkillsFromRefs(opts.Skills)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
if skillsCatalog != nil && len(skillsCatalog.Skills) > 0 {
|
||||
var skillNames []string
|
||||
for _, s := range skillsCatalog.Skills {
|
||||
skillNames = append(skillNames, s.Name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// Load MCP servers for agents (from opts and global config)
|
||||
var mcpMgr *mcpManager
|
||||
allMCPs := opts.MCPs
|
||||
|
||||
// Load global MCPs from ~/.ollama/mcp.json
|
||||
if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 {
|
||||
for name, srv := range globalConfig.MCPServers {
|
||||
// Skip disabled MCPs
|
||||
if srv.Disabled {
|
||||
continue
|
||||
}
|
||||
// Check if already in opts.MCPs (model takes precedence)
|
||||
found := false
|
||||
for _, m := range opts.MCPs {
|
||||
if m.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allMCPs = append(allMCPs, api.MCPRef{
|
||||
Name: name,
|
||||
Command: srv.Command,
|
||||
Args: srv.Args,
|
||||
Env: srv.Env,
|
||||
Type: srv.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allMCPs) > 0 {
|
||||
mcpMgr = newMCPManager()
|
||||
if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil {
|
||||
return nil, fmt.Errorf("failed to load MCP servers: %w", err)
|
||||
}
|
||||
if mcpMgr.ToolCount() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n",
|
||||
strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount())
|
||||
}
|
||||
defer mcpMgr.Shutdown()
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
@@ -1339,6 +1530,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var fullResponse strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
|
||||
role := "assistant"
|
||||
|
||||
@@ -1379,7 +1571,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
if skillsCatalog != nil || mcpMgr != nil {
|
||||
// Store tool calls for execution after response is complete
|
||||
pendingToolCalls = append(pendingToolCalls, toolCalls...)
|
||||
} else {
|
||||
// No skills catalog or MCP, just display tool calls
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1392,31 +1590,159 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
opts.Format = `"` + opts.Format + `"`
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
// Prepare messages with agent-specific system prompt
|
||||
messages := opts.Messages
|
||||
if skillsCatalog != nil {
|
||||
// Add skills system prompt as the first system message
|
||||
skillsPrompt := skillsCatalog.SystemPrompt()
|
||||
if skillsPrompt != "" {
|
||||
// Insert skills prompt at the beginning, or append to existing system message
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
// Append to existing system message
|
||||
messages[0].Content = messages[0].Content + "\n\n" + skillsPrompt
|
||||
} else {
|
||||
// Insert new system message at the beginning
|
||||
systemMsg := api.Message{Role: "system", Content: skillsPrompt}
|
||||
messages = append([]api.Message{systemMsg}, messages...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
// Agentic loop: continue until no more tool calls
|
||||
for {
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
// Add tools for agents (combine skills and MCP tools)
|
||||
var allTools api.Tools
|
||||
if skillsCatalog != nil {
|
||||
allTools = append(allTools, skillsCatalog.Tools()...)
|
||||
}
|
||||
return nil, err
|
||||
if mcpMgr != nil {
|
||||
allTools = append(allTools, mcpMgr.Tools()...)
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
req.Tools = allTools
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) {
|
||||
break
|
||||
}
|
||||
|
||||
// Execute tool calls and continue the conversation
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Add assistant's tool call message to history
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: fullResponse.String(),
|
||||
ToolCalls: pendingToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// Execute each tool call and collect results
|
||||
var toolResults []api.Message
|
||||
for _, call := range pendingToolCalls {
|
||||
// Show what's being executed
|
||||
switch call.Function.Name {
|
||||
case "run_skill_script":
|
||||
skill, _ := call.Function.Arguments["skill"].(string)
|
||||
command, _ := call.Function.Arguments["command"].(string)
|
||||
fmt.Fprintf(os.Stderr, "Running script in %s: %s\n", skill, command)
|
||||
case "read_skill_file":
|
||||
skill, _ := call.Function.Arguments["skill"].(string)
|
||||
path, _ := call.Function.Arguments["path"].(string)
|
||||
fmt.Fprintf(os.Stderr, "Reading file from %s: %s\n", skill, path)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name)
|
||||
}
|
||||
|
||||
var result api.Message
|
||||
var handled bool
|
||||
var err error
|
||||
|
||||
// Try skill catalog first
|
||||
if skillsCatalog != nil {
|
||||
result, handled, err = skillsCatalog.RunToolCall(call)
|
||||
}
|
||||
|
||||
// If not handled by skills, try MCP
|
||||
if !handled && mcpMgr != nil {
|
||||
result, handled, err = mcpMgr.RunToolCall(call)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
// Add error result
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if !handled {
|
||||
fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Display tool output
|
||||
if result.Content != "" {
|
||||
fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content)
|
||||
}
|
||||
|
||||
// Add tool result to messages
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: result.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool results to message history
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Reset state for next iteration
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
|
||||
// Start new progress spinner for next API call
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
}
|
||||
|
||||
if len(opts.Messages) > 0 {
|
||||
@@ -1908,6 +2234,8 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
NewSkillCommand(),
|
||||
NewMCPCommand(),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -34,6 +34,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /skills Show available skills")
|
||||
fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically")
|
||||
fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
@@ -443,6 +446,411 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
} else {
|
||||
usageShow()
|
||||
}
|
||||
case strings.HasPrefix(line, "/skill "):
|
||||
args := strings.Fields(line)
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Usage:")
|
||||
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
|
||||
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
|
||||
fmt.Fprintln(os.Stderr, " /skill list List current skills")
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /skill add <path>")
|
||||
continue
|
||||
}
|
||||
skillPath := args[2]
|
||||
|
||||
// Expand ~ to home directory
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
fmt.Printf("Error expanding path: %v\n", err)
|
||||
continue
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
// Make absolute
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error resolving path: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify SKILL.md exists
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
fmt.Printf("Error: %s does not contain SKILL.md\n", skillPath)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract skill name from SKILL.md
|
||||
content, err := os.ReadFile(skillMdPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading SKILL.md: %v\n", err)
|
||||
continue
|
||||
}
|
||||
skillName, _ := extractSkillMetadata(string(content))
|
||||
if skillName == "" {
|
||||
skillName = filepath.Base(absPath)
|
||||
}
|
||||
|
||||
// Check if already added
|
||||
for _, s := range opts.Skills {
|
||||
if s.Name == skillName {
|
||||
fmt.Printf("Skill '%s' is already loaded\n", skillName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Add to skills (using path as Name, no digest for local skills)
|
||||
opts.Skills = append(opts.Skills, api.SkillRef{Name: absPath})
|
||||
opts.IsAgent = true // Enable agent mode if not already
|
||||
fmt.Printf("Added skill '%s' from %s\n", skillName, skillPath)
|
||||
|
||||
case "remove", "rm":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /skill remove <name>")
|
||||
continue
|
||||
}
|
||||
skillName := args[2]
|
||||
|
||||
found := false
|
||||
newSkills := make([]api.SkillRef, 0, len(opts.Skills))
|
||||
for _, s := range opts.Skills {
|
||||
// Match by name or by path basename
|
||||
name := s.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name)
|
||||
}
|
||||
if name == skillName || s.Name == skillName {
|
||||
found = true
|
||||
fmt.Printf("Removed skill '%s'\n", skillName)
|
||||
} else {
|
||||
newSkills = append(newSkills, s)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
fmt.Printf("Skill '%s' not found\n", skillName)
|
||||
} else {
|
||||
opts.Skills = newSkills
|
||||
}
|
||||
|
||||
case "list", "ls":
|
||||
if len(opts.Skills) == 0 {
|
||||
fmt.Println("No skills loaded in this session.")
|
||||
} else {
|
||||
fmt.Println("Skills loaded in this session:")
|
||||
for _, skill := range opts.Skills {
|
||||
if skill.Digest != "" {
|
||||
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
|
||||
} else {
|
||||
// For local paths, show basename
|
||||
name := skill.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name) + " (local: " + skill.Name + ")"
|
||||
}
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown skill command '%s'. Use /skill add, /skill remove, or /skill list\n", args[1])
|
||||
}
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/skills"):
|
||||
// Show skills from model (bundled) + session skills
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model info")
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine model skills with session skills
|
||||
allSkills := make([]api.SkillRef, 0)
|
||||
allSkills = append(allSkills, resp.Skills...)
|
||||
|
||||
// Add session skills that aren't already in model skills
|
||||
for _, sessionSkill := range opts.Skills {
|
||||
found := false
|
||||
for _, modelSkill := range resp.Skills {
|
||||
if modelSkill.Name == sessionSkill.Name || modelSkill.Digest == sessionSkill.Digest {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allSkills = append(allSkills, sessionSkill)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allSkills) == 0 {
|
||||
fmt.Println("No skills available.")
|
||||
} else {
|
||||
fmt.Println("Available Skills:")
|
||||
for _, skill := range allSkills {
|
||||
if skill.Digest != "" {
|
||||
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
|
||||
} else {
|
||||
name := skill.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name) + " (session)"
|
||||
}
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/mcp"):
|
||||
args := strings.Fields(line)
|
||||
|
||||
// If just "/mcp" with no args, show all MCP servers
|
||||
if len(args) == 1 {
|
||||
// Show MCPs from model (bundled) + global config
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model info")
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine model MCPs with global config MCPs
|
||||
allMCPs := make([]api.MCPRef, 0)
|
||||
allMCPs = append(allMCPs, resp.MCPs...)
|
||||
|
||||
// Load global config
|
||||
globalConfig, _ := loadMCPConfig()
|
||||
globalMCPNames := make(map[string]bool)
|
||||
|
||||
if globalConfig != nil {
|
||||
for name, srv := range globalConfig.MCPServers {
|
||||
// Check if already in model MCPs
|
||||
found := false
|
||||
for _, modelMCP := range resp.MCPs {
|
||||
if modelMCP.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allMCPs = append(allMCPs, api.MCPRef{
|
||||
Name: name,
|
||||
Command: srv.Command,
|
||||
Args: srv.Args,
|
||||
Env: srv.Env,
|
||||
Type: srv.Type,
|
||||
})
|
||||
}
|
||||
globalMCPNames[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allMCPs) == 0 {
|
||||
fmt.Println("No MCP servers available.")
|
||||
fmt.Println("Use '/mcp add <name> <command> [args...]' to add one.")
|
||||
} else {
|
||||
fmt.Println("Available MCP Servers:")
|
||||
for _, mcp := range allMCPs {
|
||||
cmdLine := mcp.Command
|
||||
if len(mcp.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(mcp.Args, " ")
|
||||
}
|
||||
source := ""
|
||||
disabled := ""
|
||||
// Check if it's from model or global config
|
||||
isFromModel := false
|
||||
for _, modelMCP := range resp.MCPs {
|
||||
if modelMCP.Name == mcp.Name {
|
||||
isFromModel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFromModel {
|
||||
source = " (model)"
|
||||
} else if globalMCPNames[mcp.Name] {
|
||||
source = " (global)"
|
||||
// Check if disabled
|
||||
if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled {
|
||||
disabled = " [disabled]"
|
||||
}
|
||||
}
|
||||
fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "add":
|
||||
if len(args) < 4 {
|
||||
fmt.Println("Usage: /mcp add <name> <command> [args...]")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
mcpCommand := args[3]
|
||||
mcpArgs := args[4:]
|
||||
|
||||
// Load global config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if already exists
|
||||
if _, exists := config.MCPServers[mcpName]; exists {
|
||||
fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName)
|
||||
}
|
||||
|
||||
// Add to global config
|
||||
config.MCPServers[mcpName] = MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: mcpCommand,
|
||||
Args: mcpArgs,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmdLine := mcpCommand
|
||||
if len(mcpArgs) > 0 {
|
||||
cmdLine += " " + strings.Join(mcpArgs, " ")
|
||||
}
|
||||
fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath())
|
||||
fmt.Println("Note: MCP server will be started on next message.")
|
||||
|
||||
case "remove", "rm":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp remove <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
// Load global config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := config.MCPServers[mcpName]; !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
delete(config.MCPServers, mcpName)
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath())
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
case "disable":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp disable <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
srv, exists := config.MCPServers[mcpName]
|
||||
if !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
if srv.Disabled {
|
||||
fmt.Printf("MCP server '%s' is already disabled\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = true
|
||||
config.MCPServers[mcpName] = srv
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Disabled MCP server '%s'\n", mcpName)
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
case "enable":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp enable <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
srv, exists := config.MCPServers[mcpName]
|
||||
if !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
if !srv.Disabled {
|
||||
fmt.Printf("MCP server '%s' is already enabled\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = false
|
||||
config.MCPServers[mcpName] = srv
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Enabled MCP server '%s'\n", mcpName)
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1])
|
||||
}
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
@@ -451,6 +859,20 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
usageSet()
|
||||
case "show", "/show":
|
||||
usageShow()
|
||||
case "skill", "/skill":
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
|
||||
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
|
||||
fmt.Fprintln(os.Stderr, " /skill list List current session skills")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
case "mcp", "/mcp":
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers")
|
||||
fmt.Fprintln(os.Stderr, " /mcp add <name> <command> [args...] Add an MCP server to global config")
|
||||
fmt.Fprintln(os.Stderr, " /mcp remove <name> Remove an MCP server from global config")
|
||||
fmt.Fprintln(os.Stderr, " /mcp disable <name> Disable an MCP server (keep in config)")
|
||||
fmt.Fprintln(os.Stderr, " /mcp enable <name> Re-enable a disabled MCP server")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
case "shortcut", "shortcuts":
|
||||
usageShortcuts()
|
||||
}
|
||||
|
||||
545
cmd/mcp.go
Normal file
545
cmd/mcp.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Timeouts governing the MCP stdio protocol.
const (
	// mcpInitTimeout bounds the initialize handshake; it is also used to
	// bound the initial tools/list request.
	mcpInitTimeout = 30 * time.Second
	// mcpCallTimeout bounds a single tools/call invocation.
	mcpCallTimeout = 60 * time.Second
	// mcpShutdownTimeout is how long stop() waits after closing stdin
	// before killing the server process.
	mcpShutdownTimeout = 5 * time.Second
)
|
||||
|
||||
// JSON-RPC 2.0 message types used on the newline-delimited stdio transport.

// jsonrpcRequest is an outgoing request; when ID is zero it is omitted
// from the wire form, which makes the message a notification.
type jsonrpcRequest struct {
	JSONRPC string `json:"jsonrpc"` // always "2.0"
	ID      int    `json:"id,omitempty"`
	Method  string `json:"method"`
	Params  any    `json:"params,omitempty"`
}

// jsonrpcResponse is an incoming reply. Exactly one of Result or Error
// is expected to be populated; Result is kept raw so each caller can
// decode it into its own type.
type jsonrpcResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      int             `json:"id"`
	Result  json.RawMessage `json:"result,omitempty"`
	Error   *jsonrpcError   `json:"error,omitempty"`
}

// jsonrpcError is the structured error object carried by a failed response.
type jsonrpcError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Data    any    `json:"data,omitempty"`
}
|
||||
|
||||
// MCP protocol types (wire structures for the Model Context Protocol).

// mcpInitializeParams is the payload of the "initialize" request.
type mcpInitializeParams struct {
	ProtocolVersion string         `json:"protocolVersion"`
	Capabilities    map[string]any `json:"capabilities"`
	ClientInfo      mcpClientInfo  `json:"clientInfo"`
}

// mcpClientInfo identifies this client to the server.
type mcpClientInfo struct {
	Name    string `json:"name"`
	Version string `json:"version"`
}

// mcpInitializeResult is the server's reply to "initialize".
type mcpInitializeResult struct {
	ProtocolVersion string          `json:"protocolVersion"`
	Capabilities    mcpCapabilities `json:"capabilities"`
	ServerInfo      mcpServerInfo   `json:"serverInfo"`
}

// mcpCapabilities lists the server capabilities we care about.
type mcpCapabilities struct {
	Tools *mcpToolsCapability `json:"tools,omitempty"`
}

// mcpToolsCapability describes the server's tool support.
type mcpToolsCapability struct {
	ListChanged bool `json:"listChanged,omitempty"`
}

// mcpServerInfo identifies the server implementation.
type mcpServerInfo struct {
	Name    string `json:"name"`
	Version string `json:"version"`
}

// mcpTool is one entry of the server's tool catalog.
type mcpTool struct {
	Name        string             `json:"name"`
	Description string             `json:"description,omitempty"`
	InputSchema mcpToolInputSchema `json:"inputSchema"`
}

// mcpToolInputSchema is the (JSON Schema-shaped) argument schema of a tool.
// Property values are left as raw maps and converted in convertMCPSchema.
type mcpToolInputSchema struct {
	Type       string         `json:"type"`
	Properties map[string]any `json:"properties,omitempty"`
	Required   []string       `json:"required,omitempty"`
}

// mcpToolsListResult is the reply to "tools/list".
type mcpToolsListResult struct {
	Tools []mcpTool `json:"tools"`
}

// mcpToolCallParams is the payload of a "tools/call" request.
type mcpToolCallParams struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments,omitempty"`
}

// mcpToolCallResult is the reply to "tools/call"; IsError marks a
// tool-level failure whose message is carried in Content.
type mcpToolCallResult struct {
	Content []mcpContent `json:"content"`
	IsError bool         `json:"isError,omitempty"`
}

// mcpContent is one content block of a tool result; only "text" blocks
// are consumed by this client.
type mcpContent struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}
|
||||
|
||||
// mcpServer represents a running MCP server process speaking JSON-RPC
// over newline-delimited stdio.
type mcpServer struct {
	ref    api.MCPRef     // launch spec: command, args, env
	cmd    *exec.Cmd      // the running child process
	stdin  io.WriteCloser // requests/notifications are written here
	stdout *bufio.Reader  // responses are read line-by-line from here
	stderr io.ReadCloser  // drained and discarded by a background goroutine
	tools  []mcpTool      // tool catalog cached by listTools after startup
	mu     sync.Mutex     // guards nextID, started, and stdio access
	nextID int            // next JSON-RPC request id
	started bool          // true once the process has been started
}
|
||||
|
||||
// mcpManager manages multiple MCP servers for an agent session.
// All exported behavior is safe for concurrent use via mu.
type mcpManager struct {
	servers map[string]*mcpServer // keyed by server name (ref.Name)
	mu      sync.RWMutex          // guards servers
}
|
||||
|
||||
// newMCPManager creates a new MCP manager
|
||||
func newMCPManager() *mcpManager {
|
||||
return &mcpManager{
|
||||
servers: make(map[string]*mcpServer),
|
||||
}
|
||||
}
|
||||
|
||||
// loadMCPsFromRefs initializes MCP servers from refs
|
||||
func (m *mcpManager) loadMCPsFromRefs(refs []api.MCPRef) error {
|
||||
if len(refs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ref := range refs {
|
||||
if err := m.addServer(ref); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to initialize MCP server %q: %v\n", ref.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addServer adds and starts an MCP server
|
||||
func (m *mcpManager) addServer(ref api.MCPRef) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.servers[ref.Name]; exists {
|
||||
return fmt.Errorf("MCP server %q already exists", ref.Name)
|
||||
}
|
||||
|
||||
srv := &mcpServer{
|
||||
ref: ref,
|
||||
nextID: 1,
|
||||
}
|
||||
|
||||
if err := srv.start(); err != nil {
|
||||
return fmt.Errorf("starting MCP server: %w", err)
|
||||
}
|
||||
|
||||
m.servers[ref.Name] = srv
|
||||
return nil
|
||||
}
|
||||
|
||||
// start launches the MCP server process, wires up its stdio pipes,
// performs the initialize handshake, and caches the tool catalog.
// Idempotent: returns nil immediately if the process is already started.
//
// Locking: s.mu is held manually (not deferred) across process setup and
// released just before initialize/listTools, because those helpers
// acquire s.mu themselves via call(). Every early-return path below must
// therefore unlock explicitly.
func (s *mcpServer) start() error {
	s.mu.Lock()

	if s.started {
		s.mu.Unlock()
		return nil
	}

	s.cmd = exec.Command(s.ref.Command, s.ref.Args...)

	// Child inherits our environment plus any per-server overrides from ref.Env.
	s.cmd.Env = os.Environ()
	for k, v := range s.ref.Env {
		s.cmd.Env = append(s.cmd.Env, fmt.Sprintf("%s=%s", k, v))
	}

	var err error
	s.stdin, err = s.cmd.StdinPipe()
	if err != nil {
		s.mu.Unlock()
		return fmt.Errorf("creating stdin pipe: %w", err)
	}

	stdout, err := s.cmd.StdoutPipe()
	if err != nil {
		s.mu.Unlock()
		return fmt.Errorf("creating stdout pipe: %w", err)
	}
	s.stdout = bufio.NewReader(stdout)

	s.stderr, err = s.cmd.StderrPipe()
	if err != nil {
		s.mu.Unlock()
		return fmt.Errorf("creating stderr pipe: %w", err)
	}

	// Drain stderr so the child never blocks on a full pipe; output is
	// discarded. The goroutine exits when the pipe closes at process exit.
	go func() {
		scanner := bufio.NewScanner(s.stderr)
		for scanner.Scan() {
			_ = scanner.Text()
		}
	}()

	if err := s.cmd.Start(); err != nil {
		s.mu.Unlock()
		return fmt.Errorf("starting process: %w", err)
	}

	// started is set before the handshake so that a failed initialize can
	// be cleaned up via stop() (which is a no-op unless started).
	s.started = true
	s.mu.Unlock() // Release lock before calling initialize/listTools which use the mutex

	// Perform the MCP initialize handshake.
	if err := s.initialize(); err != nil {
		s.stop()
		return fmt.Errorf("initializing MCP server: %w", err)
	}

	// Fetch and cache the server's tool catalog.
	if err := s.listTools(); err != nil {
		s.stop()
		return fmt.Errorf("listing tools: %w", err)
	}

	return nil
}
|
||||
|
||||
// initialize sends the MCP initialize request
|
||||
func (s *mcpServer) initialize() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout)
|
||||
defer cancel()
|
||||
|
||||
params := mcpInitializeParams{
|
||||
ProtocolVersion: "2024-11-05",
|
||||
Capabilities: map[string]any{},
|
||||
ClientInfo: mcpClientInfo{
|
||||
Name: "ollama",
|
||||
Version: "0.1.0",
|
||||
},
|
||||
}
|
||||
|
||||
var result mcpInitializeResult
|
||||
if err := s.call(ctx, "initialize", params, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send initialized notification
|
||||
return s.notify("notifications/initialized", nil)
|
||||
}
|
||||
|
||||
// listTools fetches the available tools from the MCP server
|
||||
func (s *mcpServer) listTools() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout)
|
||||
defer cancel()
|
||||
|
||||
var result mcpToolsListResult
|
||||
if err := s.call(ctx, "tools/list", nil, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.tools = result.Tools
|
||||
return nil
|
||||
}
|
||||
|
||||
// call sends a JSON-RPC request and waits for the response
|
||||
func (s *mcpServer) call(ctx context.Context, method string, params any, result any) error {
|
||||
s.mu.Lock()
|
||||
id := s.nextID
|
||||
s.nextID++
|
||||
s.mu.Unlock()
|
||||
|
||||
req := jsonrpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
// Send request
|
||||
s.mu.Lock()
|
||||
_, err = s.stdin.Write(append(reqBytes, '\n'))
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing request: %w", err)
|
||||
}
|
||||
|
||||
// Read response with timeout
|
||||
respCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
s.mu.Lock()
|
||||
line, err := s.stdout.ReadBytes('\n')
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
respCh <- line
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
return fmt.Errorf("reading response: %w", err)
|
||||
case line := <-respCh:
|
||||
var resp jsonrpcResponse
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
return fmt.Errorf("unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("MCP error %d: %s", resp.Error.Code, resp.Error.Message)
|
||||
}
|
||||
|
||||
if result != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, result); err != nil {
|
||||
return fmt.Errorf("unmarshaling result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends a JSON-RPC notification (no response expected)
|
||||
func (s *mcpServer) notify(method string, params any) error {
|
||||
req := jsonrpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling notification: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, err := s.stdin.Write(append(reqBytes, '\n')); err != nil {
|
||||
return fmt.Errorf("writing notification: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callTool executes a tool call on the MCP server
|
||||
func (s *mcpServer) callTool(ctx context.Context, name string, arguments map[string]any) (string, error) {
|
||||
params := mcpToolCallParams{
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
var result mcpToolCallResult
|
||||
if err := s.call(ctx, "tools/call", params, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Concatenate text content
|
||||
var sb strings.Builder
|
||||
for _, content := range result.Content {
|
||||
if content.Type == "text" {
|
||||
sb.WriteString(content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
return sb.String(), errors.New(sb.String())
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// stop shuts down the MCP server
|
||||
func (s *mcpServer) stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stdin to signal shutdown
|
||||
if s.stdin != nil {
|
||||
s.stdin.Close()
|
||||
}
|
||||
|
||||
// Wait for process with timeout
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- s.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(mcpShutdownTimeout):
|
||||
s.cmd.Process.Kill()
|
||||
case <-done:
|
||||
}
|
||||
|
||||
s.started = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tools returns all tools from all MCP servers as api.Tools
|
||||
func (m *mcpManager) Tools() api.Tools {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var tools api.Tools
|
||||
|
||||
for serverName, srv := range m.servers {
|
||||
for _, t := range srv.tools {
|
||||
// Namespace tool names: mcp_{servername}_{toolname}
|
||||
namespacedName := fmt.Sprintf("mcp_%s_%s", serverName, t.Name)
|
||||
|
||||
tool := api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: namespacedName,
|
||||
Description: t.Description,
|
||||
Parameters: convertMCPSchema(t.InputSchema),
|
||||
},
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertMCPSchema converts MCP input schema to api.ToolFunctionParameters
|
||||
func convertMCPSchema(schema mcpToolInputSchema) api.ToolFunctionParameters {
|
||||
params := api.ToolFunctionParameters{
|
||||
Type: schema.Type,
|
||||
Required: schema.Required,
|
||||
Properties: make(map[string]api.ToolProperty),
|
||||
}
|
||||
|
||||
for name, prop := range schema.Properties {
|
||||
if propMap, ok := prop.(map[string]any); ok {
|
||||
tp := api.ToolProperty{}
|
||||
if t, ok := propMap["type"].(string); ok {
|
||||
tp.Type = api.PropertyType{t}
|
||||
}
|
||||
if d, ok := propMap["description"].(string); ok {
|
||||
tp.Description = d
|
||||
}
|
||||
params.Properties[name] = tp
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// RunToolCall routes a tool call to the appropriate MCP server
|
||||
func (m *mcpManager) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
|
||||
name := call.Function.Name
|
||||
|
||||
// Check if this is an MCP tool (mcp_servername_toolname)
|
||||
if !strings.HasPrefix(name, "mcp_") {
|
||||
return api.Message{}, false, nil
|
||||
}
|
||||
|
||||
// Parse server name and tool name
|
||||
rest := strings.TrimPrefix(name, "mcp_")
|
||||
idx := strings.Index(rest, "_")
|
||||
if idx == -1 {
|
||||
return toolMessage(call, fmt.Sprintf("invalid MCP tool name: %s", name)), true, nil
|
||||
}
|
||||
|
||||
serverName := rest[:idx]
|
||||
toolName := rest[idx+1:]
|
||||
|
||||
m.mu.RLock()
|
||||
srv, ok := m.servers[serverName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("MCP server %q not found", serverName)), true, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := srv.callTool(ctx, toolName, call.Function.Arguments)
|
||||
if err != nil {
|
||||
return toolMessage(call, fmt.Sprintf("error: %v", err)), true, nil
|
||||
}
|
||||
|
||||
return toolMessage(call, result), true, nil
|
||||
}
|
||||
|
||||
// Shutdown stops all MCP servers
|
||||
func (m *mcpManager) Shutdown() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, srv := range m.servers {
|
||||
srv.stop()
|
||||
}
|
||||
|
||||
m.servers = make(map[string]*mcpServer)
|
||||
}
|
||||
|
||||
// ServerNames returns the names of all running MCP servers
|
||||
func (m *mcpManager) ServerNames() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(m.servers))
|
||||
for name := range m.servers {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ToolCount returns the total number of tools across all servers
|
||||
func (m *mcpManager) ToolCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, srv := range m.servers {
|
||||
count += len(srv.tools)
|
||||
}
|
||||
return count
|
||||
}
|
||||
898
cmd/mcp_cmd.go
Normal file
898
cmd/mcp_cmd.go
Normal file
@@ -0,0 +1,898 @@
|
||||
package cmd
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"text/tabwriter"
	"time"

	"github.com/spf13/cobra"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/progress"
	"github.com/ollama/ollama/server"
	"github.com/ollama/ollama/types/model"
)
|
||||
|
||||
// MCPConfigFile represents the global MCP configuration file structure
// stored at ~/.ollama/mcp.json.
type MCPConfigFile struct {
	MCPServers map[string]MCPServerConfig `json:"mcpServers"` // keyed by server name
}

// MCPServerConfig represents a single MCP server configuration.
type MCPServerConfig struct {
	Type     string            `json:"type,omitempty"` // transport; "stdio" is what the CLI writes
	Command  string            `json:"command"`        // executable to launch
	Args     []string          `json:"args,omitempty"`
	Env      map[string]string `json:"env,omitempty"`      // extra environment for the child process
	Disabled bool              `json:"disabled,omitempty"` // kept in config but not started
}
|
||||
|
||||
// getMCPConfigPath returns the path to the global MCP config file.
|
||||
func getMCPConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "mcp.json")
|
||||
}
|
||||
|
||||
// loadMCPConfig loads the global MCP configuration file.
|
||||
func loadMCPConfig() (*MCPConfigFile, error) {
|
||||
configPath := getMCPConfigPath()
|
||||
if configPath == "" {
|
||||
return nil, fmt.Errorf("could not determine home directory")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Return empty config if file doesn't exist
|
||||
return &MCPConfigFile{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("reading config: %w", err)
|
||||
}
|
||||
|
||||
var config MCPConfigFile
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
if config.MCPServers == nil {
|
||||
config.MCPServers = make(map[string]MCPServerConfig)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// saveMCPConfig saves the global MCP configuration file.
|
||||
func saveMCPConfig(config *MCPConfigFile) error {
|
||||
configPath := getMCPConfigPath()
|
||||
if configPath == "" {
|
||||
return fmt.Errorf("could not determine home directory")
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating config directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
return fmt.Errorf("writing config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPAddHandler handles the mcp add command.
|
||||
func MCPAddHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("usage: ollama mcp add NAME COMMAND [ARGS...]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
command := args[1]
|
||||
cmdArgs := args[2:]
|
||||
|
||||
// Load existing config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
// Check if already exists
|
||||
if _, exists := config.MCPServers[name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: overwriting existing MCP server '%s'\n", name)
|
||||
}
|
||||
|
||||
// Add the new server
|
||||
config.MCPServers[name] = MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: command,
|
||||
Args: cmdArgs,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
configPath := getMCPConfigPath()
|
||||
fmt.Fprintf(os.Stderr, "Added MCP server '%s' to %s\n", name, configPath)
|
||||
fmt.Fprintf(os.Stderr, " Command: %s %s\n", command, strings.Join(cmdArgs, " "))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPRemoveGlobalHandler handles removing an MCP from global config.
|
||||
func MCPRemoveGlobalHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp remove-global NAME [NAME...]")
|
||||
}
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
if _, exists := config.MCPServers[name]; !exists {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
delete(config.MCPServers, name)
|
||||
fmt.Fprintf(os.Stderr, "Removed MCP server '%s' from global config\n", name)
|
||||
}
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPListGlobalHandler handles listing global MCP servers.
|
||||
func MCPListGlobalHandler(cmd *cobra.Command, args []string) error {
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
if len(config.MCPServers) == 0 {
|
||||
fmt.Println("No global MCP servers configured")
|
||||
fmt.Printf("Add one with: ollama mcp add NAME COMMAND [ARGS...]\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Global MCP servers (%s):\n\n", getMCPConfigPath())
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tCOMMAND\tSTATUS")
|
||||
|
||||
for name, srv := range config.MCPServers {
|
||||
cmdLine := srv.Command
|
||||
if len(srv.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(srv.Args, " ")
|
||||
}
|
||||
status := "enabled"
|
||||
if srv.Disabled {
|
||||
status = "disabled"
|
||||
}
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\n", name, cmdLine, status)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// MCPDisableHandler handles disabling an MCP server in global config.
|
||||
func MCPDisableHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp disable NAME [NAME...]")
|
||||
}
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
srv, exists := config.MCPServers[name]
|
||||
if !exists {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
if srv.Disabled {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' is already disabled\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = true
|
||||
config.MCPServers[name] = srv
|
||||
fmt.Fprintf(os.Stderr, "Disabled MCP server '%s'\n", name)
|
||||
}
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPEnableHandler handles enabling an MCP server in global config.
|
||||
func MCPEnableHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp enable NAME [NAME...]")
|
||||
}
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
srv, exists := config.MCPServers[name]
|
||||
if !exists {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
if !srv.Disabled {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' is already enabled\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = false
|
||||
config.MCPServers[name] = srv
|
||||
fmt.Fprintf(os.Stderr, "Enabled MCP server '%s'\n", name)
|
||||
}
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPPushHandler handles the mcp push command.
|
||||
func MCPPushHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("usage: ollama mcp push NAME[:TAG] PATH")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
path := args[1]
|
||||
|
||||
// Expand path
|
||||
if strings.HasPrefix(path, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
path = filepath.Join(home, path[1:])
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving path: %w", err)
|
||||
}
|
||||
|
||||
// Validate MCP directory - check for mcp.json, package.json, or any config file
|
||||
validFiles := []string{"mcp.json", "package.json", "server.py", "server.js", "main.py", "index.js"}
|
||||
found := false
|
||||
for _, vf := range validFiles {
|
||||
if _, err := os.Stat(filepath.Join(absPath, vf)); err == nil {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("MCP directory should contain one of: %s", strings.Join(validFiles, ", "))
|
||||
}
|
||||
|
||||
// Parse MCP name (will set Kind="mcp")
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid MCP name: %s", name)
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Create MCP layer
|
||||
displayName := n.DisplayShortest()
|
||||
status := fmt.Sprintf("Creating MCP layer for %s", displayName)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
layer, err := server.CreateMCPLayer(absPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating MCP layer: %w", err)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
|
||||
// Create MCP manifest
|
||||
manifest, configLayer, err := createMCPManifest(absPath, layer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating MCP manifest: %w", err)
|
||||
}
|
||||
|
||||
// Write manifest locally
|
||||
manifestPath, err := server.GetMCPManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "MCP %s created locally\n", displayName)
|
||||
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
|
||||
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
|
||||
|
||||
// Push to registry
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &api.PushRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Push(context.Background(), req, fn); err != nil {
|
||||
// If push fails, still show success for local creation
|
||||
fmt.Fprintf(os.Stderr, "\nNote: Local MCP created but push failed: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama mcp push %s\n", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPPullHandler handles the mcp pull command.
|
||||
func MCPPullHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama mcp pull NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid MCP name: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
req := &api.PullRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Pull(context.Background(), req, fn); err != nil {
|
||||
return fmt.Errorf("pulling MCP: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPListHandler handles the mcp list command.
|
||||
func MCPListHandler(cmd *cobra.Command, args []string) error {
|
||||
mcps, err := listLocalMCPs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing MCPs: %w", err)
|
||||
}
|
||||
|
||||
if len(mcps) == 0 {
|
||||
fmt.Println("No MCPs installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
|
||||
|
||||
for _, mcp := range mcps {
|
||||
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
|
||||
mcp.Namespace,
|
||||
mcp.Name,
|
||||
mcp.Tag,
|
||||
format.HumanBytes(mcp.Size),
|
||||
format.HumanTime(mcp.ModifiedAt, "Never"),
|
||||
)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// MCPRemoveHandler handles the mcp rm command.
|
||||
func MCPRemoveHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp rm NAME[:TAG] [NAME[:TAG]...]")
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
fmt.Fprintf(os.Stderr, "Invalid MCP name: %s\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetMCPManifestPath(n)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "MCP not found: %s\n", displayName)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Remove(manifestPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up empty parent directories
|
||||
dir := filepath.Dir(manifestPath)
|
||||
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
|
||||
entries, _ := os.ReadDir(dir)
|
||||
if len(entries) == 0 {
|
||||
os.Remove(dir)
|
||||
dir = filepath.Dir(dir)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPShowHandler handles the mcp show command.
|
||||
func MCPShowHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama mcp show NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid MCP name: %s", name)
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetMCPManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("MCP not found: %s", displayName)
|
||||
}
|
||||
return fmt.Errorf("reading manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("MCP: %s\n\n", displayName)
|
||||
|
||||
fmt.Println("Layers:")
|
||||
for _, layer := range manifest.Layers {
|
||||
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
|
||||
}
|
||||
|
||||
// Try to read and display mcp.json or package.json content
|
||||
if len(manifest.Layers) > 0 {
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == server.MediaTypeMCP {
|
||||
mcpPath, err := server.GetMCPsPath(layer.Digest)
|
||||
if err == nil {
|
||||
// Try mcp.json first
|
||||
mcpJSONPath := filepath.Join(mcpPath, "mcp.json")
|
||||
if content, err := os.ReadFile(mcpJSONPath); err == nil {
|
||||
fmt.Println("\nConfig (mcp.json):")
|
||||
fmt.Println(string(content))
|
||||
} else {
|
||||
// Try package.json
|
||||
pkgJSONPath := filepath.Join(mcpPath, "package.json")
|
||||
if content, err := os.ReadFile(pkgJSONPath); err == nil {
|
||||
fmt.Println("\nConfig (package.json):")
|
||||
fmt.Println(string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// List files in the MCP
|
||||
fmt.Println("\nFiles:")
|
||||
filepath.Walk(mcpPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
relPath, _ := filepath.Rel(mcpPath, path)
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
fmt.Printf(" %s/\n", relPath)
|
||||
} else {
|
||||
fmt.Printf(" %s (%s)\n", relPath, format.HumanBytes(info.Size()))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPInfo represents information about an installed MCP.
type MCPInfo struct {
	// Namespace is the display namespace, built as "<namespace>/<kind>"
	// by listLocalMCPs (e.g. "library/mcp").
	Namespace string
	// Name is the MCP's model name (the fourth path segment of the
	// manifest layout).
	Name string
	// Tag is the manifest tag (e.g. "latest").
	Tag string
	// Size is the sum of all layer sizes in the manifest, in bytes.
	Size int64
	// ModifiedAt is the manifest file's modification time.
	ModifiedAt time.Time
}
|
||||
|
||||
// listLocalMCPs returns a list of locally installed MCPs.
|
||||
// MCPs are stored with 5-part paths: host/namespace/kind/model/tag
|
||||
// where kind is "mcp".
|
||||
func listLocalMCPs() ([]MCPInfo, error) {
|
||||
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
|
||||
|
||||
var mcps []MCPInfo
|
||||
|
||||
// Walk through all registries
|
||||
registries, err := os.ReadDir(manifestsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return mcps, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, registry := range registries {
|
||||
if !registry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk namespaces
|
||||
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
if !namespace.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk kinds looking for "mcp"
|
||||
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kind := range kinds {
|
||||
if !kind.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only process mcp kind
|
||||
if kind.Name() != server.MCPNamespace {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk MCP names (model names)
|
||||
mcpNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, mcpName := range mcpNames {
|
||||
if !mcpName.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk tags
|
||||
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), mcpName.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), mcpName.Name(), tag.Name())
|
||||
fi, err := os.Stat(manifestPath)
|
||||
if err != nil || fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read manifest to get size
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Layers {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
|
||||
// Build display name using model.Name
|
||||
n := model.Name{
|
||||
Host: registry.Name(),
|
||||
Namespace: namespace.Name(),
|
||||
Kind: kind.Name(),
|
||||
Model: mcpName.Name(),
|
||||
Tag: tag.Name(),
|
||||
}
|
||||
|
||||
mcps = append(mcps, MCPInfo{
|
||||
Namespace: n.Namespace + "/" + n.Kind,
|
||||
Name: n.Model,
|
||||
Tag: n.Tag,
|
||||
Size: totalSize,
|
||||
ModifiedAt: fi.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mcps, nil
|
||||
}
|
||||
|
||||
// createMCPManifest creates a manifest for a standalone MCP.
|
||||
func createMCPManifest(mcpDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
|
||||
// Try to read mcp.json or package.json to extract metadata
|
||||
name, description := extractMCPMetadata(mcpDir)
|
||||
if name == "" {
|
||||
// Use directory name as fallback
|
||||
name = filepath.Base(mcpDir)
|
||||
}
|
||||
|
||||
// Create config
|
||||
config := map[string]any{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"architecture": "amd64",
|
||||
"os": "linux",
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer
|
||||
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating config layer: %w", err)
|
||||
}
|
||||
|
||||
manifest := &server.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: configLayer,
|
||||
Layers: []server.Layer{layer},
|
||||
}
|
||||
|
||||
return manifest, &configLayer, nil
|
||||
}
|
||||
|
||||
// extractMCPMetadata extracts name and description from mcp.json or package.json.
|
||||
func extractMCPMetadata(mcpDir string) (name, description string) {
|
||||
// Try mcp.json first
|
||||
mcpJSONPath := filepath.Join(mcpDir, "mcp.json")
|
||||
if data, err := os.ReadFile(mcpJSONPath); err == nil {
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(data, &config); err == nil {
|
||||
if n, ok := config["name"].(string); ok {
|
||||
name = n
|
||||
}
|
||||
if d, ok := config["description"].(string); ok {
|
||||
description = d
|
||||
}
|
||||
return name, description
|
||||
}
|
||||
}
|
||||
|
||||
// Try package.json
|
||||
pkgJSONPath := filepath.Join(mcpDir, "package.json")
|
||||
if data, err := os.ReadFile(pkgJSONPath); err == nil {
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(data, &config); err == nil {
|
||||
if n, ok := config["name"].(string); ok {
|
||||
name = n
|
||||
}
|
||||
if d, ok := config["description"].(string); ok {
|
||||
description = d
|
||||
}
|
||||
return name, description
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// NewMCPCommand creates the mcp parent command with subcommands.
|
||||
func NewMCPCommand() *cobra.Command {
|
||||
mcpCmd := &cobra.Command{
|
||||
Use: "mcp",
|
||||
Short: "Manage MCP servers",
|
||||
Long: "Commands for managing MCP (Model Context Protocol) servers (add, push, pull, list, rm, show)",
|
||||
}
|
||||
|
||||
// Global config commands
|
||||
addCmd := &cobra.Command{
|
||||
Use: "add NAME COMMAND [ARGS...]",
|
||||
Short: "Add an MCP server to global config",
|
||||
Long: `Add an MCP server to the global config (~/.ollama/mcp.json).
|
||||
Global MCP servers are available to all agents.
|
||||
|
||||
Examples:
|
||||
ollama mcp add web-search uv run ./mcp-server.py
|
||||
ollama mcp add calculator python3 /path/to/calc.py`,
|
||||
Args: cobra.MinimumNArgs(2),
|
||||
RunE: MCPAddHandler,
|
||||
DisableFlagParsing: true, // Allow args with dashes
|
||||
}
|
||||
|
||||
removeGlobalCmd := &cobra.Command{
|
||||
Use: "remove-global NAME [NAME...]",
|
||||
Aliases: []string{"rm-global"},
|
||||
Short: "Remove an MCP server from global config",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPRemoveGlobalHandler,
|
||||
}
|
||||
|
||||
listGlobalCmd := &cobra.Command{
|
||||
Use: "list-global",
|
||||
Short: "List global MCP servers",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: MCPListGlobalHandler,
|
||||
}
|
||||
|
||||
// Registry commands
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push NAME[:TAG] PATH",
|
||||
Short: "Push an MCP server to a registry",
|
||||
Long: "Package a local MCP server directory and push it to a registry",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: MCPPushHandler,
|
||||
}
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull NAME[:TAG]",
|
||||
Short: "Pull an MCP server from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: MCPPullHandler,
|
||||
}
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List installed MCP servers (from registry)",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: MCPListHandler,
|
||||
}
|
||||
|
||||
rmCmd := &cobra.Command{
|
||||
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
|
||||
Aliases: []string{"remove", "delete"},
|
||||
Short: "Remove an MCP server (from registry)",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPRemoveHandler,
|
||||
}
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show NAME[:TAG]",
|
||||
Short: "Show MCP server details",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: MCPShowHandler,
|
||||
}
|
||||
|
||||
disableCmd := &cobra.Command{
|
||||
Use: "disable NAME [NAME...]",
|
||||
Short: "Disable an MCP server (keep in config)",
|
||||
Long: `Disable an MCP server without removing it from config.
|
||||
Disabled servers will not be started when running agents.
|
||||
Use 'ollama mcp enable' to re-enable.`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPDisableHandler,
|
||||
}
|
||||
|
||||
enableCmd := &cobra.Command{
|
||||
Use: "enable NAME [NAME...]",
|
||||
Short: "Enable a disabled MCP server",
|
||||
Long: `Re-enable a previously disabled MCP server.`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPEnableHandler,
|
||||
}
|
||||
|
||||
mcpCmd.AddCommand(addCmd, removeGlobalCmd, listGlobalCmd, disableCmd, enableCmd, pushCmd, pullCmd, listCmd, rmCmd, showCmd)
|
||||
|
||||
return mcpCmd
|
||||
}
|
||||
570
cmd/skill_cmd.go
Normal file
570
cmd/skill_cmd.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// SkillPushHandler handles the skill push command.
//
// It packages the skill directory at PATH into a content layer, writes a
// manifest locally under the models/manifests tree, and then pushes the
// result to the registry. If the push fails, the locally created skill is
// kept, a retry hint is printed, and the command still returns nil.
func SkillPushHandler(cmd *cobra.Command, args []string) error {
	if len(args) != 2 {
		return fmt.Errorf("usage: ollama skill push NAME[:TAG] PATH")
	}

	name := args[0]
	path := args[1]

	// Expand a leading "~" to the user's home directory.
	if strings.HasPrefix(path, "~") {
		home, err := os.UserHomeDir()
		if err != nil {
			return fmt.Errorf("expanding home directory: %w", err)
		}
		path = filepath.Join(home, path[1:])
	}

	absPath, err := filepath.Abs(path)
	if err != nil {
		return fmt.Errorf("resolving path: %w", err)
	}

	// Validate skill directory: a SKILL.md file is required.
	skillMdPath := filepath.Join(absPath, "SKILL.md")
	if _, err := os.Stat(skillMdPath); err != nil {
		return fmt.Errorf("skill directory must contain SKILL.md: %w", err)
	}

	// Parse skill name (will set Kind="skill")
	n := server.ParseSkillName(name)
	if n.Model == "" {
		return fmt.Errorf("invalid skill name: %s", name)
	}

	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	// Create the content layer from the skill directory, with a spinner
	// while the layer is being built.
	displayName := n.DisplayShortest()
	status := fmt.Sprintf("Creating skill layer for %s", displayName)
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)

	layer, err := server.CreateSkillLayer(absPath)
	if err != nil {
		return fmt.Errorf("creating skill layer: %w", err)
	}

	spinner.Stop()

	// Create the skill manifest (config layer + content layer).
	manifest, configLayer, err := createSkillManifest(absPath, layer)
	if err != nil {
		return fmt.Errorf("creating skill manifest: %w", err)
	}

	// Write the manifest locally so the skill exists even if the push fails.
	manifestPath, err := server.GetSkillManifestPath(n)
	if err != nil {
		return fmt.Errorf("getting manifest path: %w", err)
	}

	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
		return fmt.Errorf("creating manifest directory: %w", err)
	}

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return fmt.Errorf("marshaling manifest: %w", err)
	}

	if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
		return fmt.Errorf("writing manifest: %w", err)
	}

	fmt.Fprintf(os.Stderr, "Skill %s created locally\n", displayName)
	fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
	fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))

	// Push to registry
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return fmt.Errorf("creating client: %w", err)
	}

	insecure, _ := cmd.Flags().GetBool("insecure")

	// For now, we'll use the existing push mechanism
	fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")

	// Progress callback: a bar for digest-bearing updates, a spinner for
	// status-only updates.
	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
			p.Add(resp.Digest, bar)
		} else if resp.Status != "" {
			spinner := progress.NewSpinner(resp.Status)
			p.Add(resp.Status, spinner)
		}
		return nil
	}

	req := &api.PushRequest{
		Model:    displayName,
		Insecure: insecure,
	}

	if err := client.Push(context.Background(), req, fn); err != nil {
		// If push fails, still show success for local creation — the
		// error is deliberately swallowed (nil is returned) so the local
		// artifact remains usable.
		fmt.Fprintf(os.Stderr, "\nNote: Local skill created but push failed: %v\n", err)
		fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama skill push %s\n", name)
		return nil
	}

	fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
	return nil
}
|
||||
|
||||
// SkillPullHandler handles the skill pull command.
|
||||
func SkillPullHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama skill pull NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
req := &api.PullRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Pull(context.Background(), req, fn); err != nil {
|
||||
return fmt.Errorf("pulling skill: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillListHandler handles the skill list command.
|
||||
func SkillListHandler(cmd *cobra.Command, args []string) error {
|
||||
skills, err := listLocalSkills()
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing skills: %w", err)
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
fmt.Println("No skills installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
|
||||
|
||||
for _, skill := range skills {
|
||||
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
|
||||
skill.Namespace,
|
||||
skill.Name,
|
||||
skill.Tag,
|
||||
format.HumanBytes(skill.Size),
|
||||
format.HumanTime(skill.ModifiedAt, "Never"),
|
||||
)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// SkillRemoveHandler handles the skill rm command.
|
||||
func SkillRemoveHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama skill rm NAME[:TAG] [NAME[:TAG]...]")
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
fmt.Fprintf(os.Stderr, "Invalid skill name: %s\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "Skill not found: %s\n", displayName)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Remove(manifestPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up empty parent directories
|
||||
dir := filepath.Dir(manifestPath)
|
||||
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
|
||||
entries, _ := os.ReadDir(dir)
|
||||
if len(entries) == 0 {
|
||||
os.Remove(dir)
|
||||
dir = filepath.Dir(dir)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillShowHandler handles the skill show command.
|
||||
func SkillShowHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama skill show NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("skill not found: %s", displayName)
|
||||
}
|
||||
return fmt.Errorf("reading manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Skill: %s\n\n", displayName)
|
||||
|
||||
fmt.Println("Layers:")
|
||||
for _, layer := range manifest.Layers {
|
||||
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
|
||||
}
|
||||
|
||||
// Try to read and display SKILL.md content
|
||||
if len(manifest.Layers) > 0 {
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == server.MediaTypeSkill {
|
||||
skillPath, err := server.GetSkillsPath(layer.Digest)
|
||||
if err == nil {
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
if content, err := os.ReadFile(skillMdPath); err == nil {
|
||||
fmt.Println("\nContent:")
|
||||
fmt.Println(string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillInfo represents information about an installed skill.
type SkillInfo struct {
	// Namespace is the display namespace, built as "<namespace>/<kind>"
	// by listLocalSkills (e.g. "library/skill").
	Namespace string
	// Name is the skill's model name (the fourth path segment of the
	// manifest layout).
	Name string
	// Tag is the manifest tag (e.g. "latest").
	Tag string
	// Size is the sum of all layer sizes in the manifest, in bytes.
	Size int64
	// ModifiedAt is the manifest file's modification time.
	ModifiedAt time.Time
}
|
||||
|
||||
// listLocalSkills returns a list of locally installed skills.
|
||||
// Skills are stored with 5-part paths: host/namespace/kind/model/tag
|
||||
// where kind is "skill".
|
||||
func listLocalSkills() ([]SkillInfo, error) {
|
||||
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
|
||||
|
||||
var skills []SkillInfo
|
||||
|
||||
// Walk through all registries
|
||||
registries, err := os.ReadDir(manifestsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return skills, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, registry := range registries {
|
||||
if !registry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk namespaces
|
||||
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
if !namespace.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk kinds looking for "skill"
|
||||
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kind := range kinds {
|
||||
if !kind.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only process skill kind
|
||||
if kind.Name() != server.SkillNamespace {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk skill names (model names)
|
||||
skillNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, skillName := range skillNames {
|
||||
if !skillName.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk tags
|
||||
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name(), tag.Name())
|
||||
fi, err := os.Stat(manifestPath)
|
||||
if err != nil || fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read manifest to get size
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Layers {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
|
||||
// Build display name using model.Name
|
||||
n := model.Name{
|
||||
Host: registry.Name(),
|
||||
Namespace: namespace.Name(),
|
||||
Kind: kind.Name(),
|
||||
Model: skillName.Name(),
|
||||
Tag: tag.Name(),
|
||||
}
|
||||
|
||||
skills = append(skills, SkillInfo{
|
||||
Namespace: n.Namespace + "/" + n.Kind,
|
||||
Name: n.Model,
|
||||
Tag: n.Tag,
|
||||
Size: totalSize,
|
||||
ModifiedAt: fi.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
// createSkillManifest creates a manifest for a standalone skill.
|
||||
func createSkillManifest(skillDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
|
||||
// Read SKILL.md to extract metadata
|
||||
skillMdPath := filepath.Join(skillDir, "SKILL.md")
|
||||
content, err := os.ReadFile(skillMdPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Extract name and description from frontmatter
|
||||
name, description := extractSkillMetadata(string(content))
|
||||
if name == "" {
|
||||
return nil, nil, errors.New("skill name not found in SKILL.md frontmatter")
|
||||
}
|
||||
|
||||
// Create config
|
||||
config := map[string]any{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"architecture": "amd64",
|
||||
"os": "linux",
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer
|
||||
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating config layer: %w", err)
|
||||
}
|
||||
|
||||
manifest := &server.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: configLayer,
|
||||
Layers: []server.Layer{layer},
|
||||
}
|
||||
|
||||
return manifest, &configLayer, nil
|
||||
}
|
||||
|
||||
// extractSkillMetadata pulls the "name:" and "description:" values out of
// the frontmatter block (delimited by "---" lines) at the top of a SKILL.md
// document. Values are returned whitespace-trimmed; both are empty when no
// frontmatter is present. Scanning stops at the closing delimiter.
func extractSkillMetadata(content string) (name, description string) {
	inside := false
	for _, raw := range strings.Split(content, "\n") {
		line := strings.TrimSpace(raw)

		if line == "---" {
			if inside {
				// Closing delimiter: done.
				break
			}
			inside = true
			continue
		}
		if !inside {
			continue
		}

		switch {
		case strings.HasPrefix(line, "name:"):
			name = strings.TrimSpace(strings.TrimPrefix(line, "name:"))
		case strings.HasPrefix(line, "description:"):
			description = strings.TrimSpace(strings.TrimPrefix(line, "description:"))
		}
	}

	return name, description
}
|
||||
|
||||
// NewSkillCommand creates the skill parent command with subcommands.
|
||||
func NewSkillCommand() *cobra.Command {
|
||||
skillCmd := &cobra.Command{
|
||||
Use: "skill",
|
||||
Short: "Manage skills",
|
||||
Long: "Commands for managing agent skills (push, pull, list, rm, show)",
|
||||
}
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push NAME[:TAG] PATH",
|
||||
Short: "Push a skill to a registry",
|
||||
Long: "Package a local skill directory and push it to a registry",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SkillPushHandler,
|
||||
}
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull NAME[:TAG]",
|
||||
Short: "Pull a skill from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SkillPullHandler,
|
||||
}
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List installed skills",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: SkillListHandler,
|
||||
}
|
||||
|
||||
rmCmd := &cobra.Command{
|
||||
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
|
||||
Aliases: []string{"remove", "delete"},
|
||||
Short: "Remove a skill",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: SkillRemoveHandler,
|
||||
}
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show NAME[:TAG]",
|
||||
Short: "Show skill details",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: SkillShowHandler,
|
||||
}
|
||||
|
||||
skillCmd.AddCommand(pushCmd, pullCmd, listCmd, rmCmd, showCmd)
|
||||
|
||||
return skillCmd
|
||||
}
|
||||
589
cmd/skills.go
Normal file
589
cmd/skills.go
Normal file
@@ -0,0 +1,589 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/server"
|
||||
)
|
||||
|
||||
const (
	// skillFileName is the manifest file every skill directory must contain.
	skillFileName = "SKILL.md"
	// maxSkillDescription caps the frontmatter description length, in bytes.
	maxSkillDescription = 1024
	// maxSkillNameLength caps the frontmatter name length, in bytes.
	maxSkillNameLength = 64
)

// skillNamePattern restricts skill names to lowercase kebab-case
// (e.g. "pdf-tools"): alphanumeric runs separated by single hyphens,
// with no leading or trailing hyphen.
var skillNamePattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
|
||||
// skillMetadata mirrors the YAML frontmatter at the top of a SKILL.md file.
type skillMetadata struct {
	Name        string `yaml:"name"`
	Description string `yaml:"description"`
}

// skillDefinition is a fully parsed and validated skill, ready to be
// injected into a system prompt and served via tool calls.
type skillDefinition struct {
	Name        string
	Description string
	Content     string // Full SKILL.md content (without frontmatter)
	Dir         string // absolute path to the skill directory
	SkillPath   string // absolute path to the SKILL.md file itself
}

// skillCatalog holds all loaded skills, with a by-name index used to
// resolve tool calls from the model.
type skillCatalog struct {
	Skills []skillDefinition
	byName map[string]skillDefinition
}
|
||||
|
||||
func loadSkills(paths []string) (*skillCatalog, error) {
|
||||
if len(paths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var skills []skillDefinition
|
||||
byName := make(map[string]skillDefinition)
|
||||
for _, root := range paths {
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skills directory %q: %w", root, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, fmt.Errorf("skills path %q is not a directory", root)
|
||||
}
|
||||
|
||||
err = filepath.WalkDir(root, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != skillFileName {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillDir := filepath.Dir(path)
|
||||
skill, err := parseSkillFile(path, skillDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(skills, func(i, j int) bool {
|
||||
return skills[i].Name < skills[j].Name
|
||||
})
|
||||
|
||||
return &skillCatalog{Skills: skills, byName: byName}, nil
|
||||
}
|
||||
|
||||
// loadSkillsFromRefs loads skills from a list of SkillRef objects.
// Skills can be referenced by:
//   - Digest: loaded from the extracted skill cache (for bundled/pulled skills)
//   - Name (local path): loaded from the filesystem (for development)
//
// Refs with both fields empty are silently skipped. The resulting catalog is
// sorted by skill name; duplicate names are skipped with a warning on stderr.
// Returns (nil, nil) when refs is empty or no skills were resolved.
func loadSkillsFromRefs(refs []api.SkillRef) (*skillCatalog, error) {
	if len(refs) == 0 {
		return nil, nil
	}

	var skills []skillDefinition
	byName := make(map[string]skillDefinition)

	for _, ref := range refs {
		// skillDir is the single skill directory to parse for this ref.
		// It stays empty when the ref expands to multiple skills (the
		// parent-directory walk below), which appends directly and continues.
		var skillDir string

		if ref.Digest != "" {
			// Load from extracted skill cache
			path, err := server.GetSkillsPath(ref.Digest)
			if err != nil {
				return nil, fmt.Errorf("getting skill path for %s: %w", ref.Digest, err)
			}

			// Check if skill is already extracted
			skillMdPath := filepath.Join(path, skillFileName)
			if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
				// Try to extract the skill blob
				path, err = server.ExtractSkillBlob(ref.Digest)
				if err != nil {
					return nil, fmt.Errorf("extracting skill %s: %w", ref.Digest, err)
				}
			}

			skillDir = path
		} else if ref.Name != "" {
			// Check if this is a local path or a registry reference
			if !server.IsLocalSkillPath(ref.Name) {
				// Registry reference without a digest - skill needs to be pulled first
				// This happens when an agent references a skill that hasn't been bundled
				return nil, fmt.Errorf("skill %q is a registry reference but has no digest - the agent may need to be recreated or the skill pulled separately", ref.Name)
			}

			// Local path - resolve it, expanding a leading "~" to the home dir
			skillPath := ref.Name
			if strings.HasPrefix(skillPath, "~") {
				home, err := os.UserHomeDir()
				if err != nil {
					return nil, fmt.Errorf("expanding home directory: %w", err)
				}
				skillPath = filepath.Join(home, skillPath[1:])
			}

			absPath, err := filepath.Abs(skillPath)
			if err != nil {
				return nil, fmt.Errorf("resolving skill path %q: %w", ref.Name, err)
			}

			// Check if this is a directory containing skills or a single skill
			info, err := os.Stat(absPath)
			if err != nil {
				return nil, fmt.Errorf("skill path %q: %w", ref.Name, err)
			}

			if info.IsDir() {
				// Check if it's a skill directory (has SKILL.md) or a parent of skill directories
				skillMdPath := filepath.Join(absPath, skillFileName)
				if _, err := os.Stat(skillMdPath); err == nil {
					// Direct skill directory
					skillDir = absPath
				} else {
					// Parent directory - walk to find skill subdirectories.
					// Unparseable or duplicate skills are skipped with a warning
					// so one bad skill doesn't block the rest.
					err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
						if walkErr != nil {
							return walkErr
						}
						if entry.IsDir() {
							return nil
						}
						if entry.Name() != skillFileName {
							return nil
						}

						skillSubDir := filepath.Dir(path)
						skill, err := parseSkillFile(path, skillSubDir)
						if err != nil {
							fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
							return nil
						}

						if _, exists := byName[skill.Name]; exists {
							fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
							return nil
						}

						byName[skill.Name] = skill
						skills = append(skills, skill)
						return nil
					})
					if err != nil {
						return nil, err
					}
					// All skills under this parent are already collected;
					// skip the single-skill parse below.
					continue
				}
			} else {
				return nil, fmt.Errorf("skill path %q is not a directory", ref.Name)
			}
		} else {
			// Both empty - skip
			continue
		}

		// Parse the skill from skillDir if set
		if skillDir != "" {
			skillMdPath := filepath.Join(skillDir, skillFileName)
			skill, err := parseSkillFile(skillMdPath, skillDir)
			if err != nil {
				// Unlike the walk above, a bad single-skill ref is a hard error:
				// the caller asked for this skill explicitly.
				return nil, fmt.Errorf("parsing skill at %s: %w", skillDir, err)
			}

			if _, exists := byName[skill.Name]; exists {
				fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q\n", skill.Name)
				continue
			}

			byName[skill.Name] = skill
			skills = append(skills, skill)
		}
	}

	if len(skills) == 0 {
		return nil, nil
	}

	sort.Slice(skills, func(i, j int) bool {
		return skills[i].Name < skills[j].Name
	})

	return &skillCatalog{Skills: skills, byName: byName}, nil
}
|
||||
|
||||
func parseSkillFile(path, skillDir string) (skillDefinition, error) {
|
||||
rawContent, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
frontmatter, bodyContent, err := extractFrontmatterAndContent(string(rawContent))
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
var meta skillMetadata
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &meta); err != nil {
|
||||
return skillDefinition{}, fmt.Errorf("invalid frontmatter: %w", err)
|
||||
}
|
||||
|
||||
if err := validateSkillMetadata(meta, skillDir); err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
absDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
return skillDefinition{
|
||||
Name: meta.Name,
|
||||
Description: meta.Description,
|
||||
Content: bodyContent,
|
||||
Dir: absDir,
|
||||
SkillPath: absPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractFrontmatterAndContent splits a SKILL.md document into its YAML
// frontmatter and markdown body.
//
// The document must begin with a "---" line (surrounding whitespace on the
// delimiter line is tolerated); everything up to the next "---" line is the
// frontmatter, and the remainder — trimmed of leading/trailing whitespace —
// is the body. Returns an error for an empty document, a missing opening
// delimiter, or an unterminated frontmatter block.
func extractFrontmatterAndContent(content string) (frontmatter string, body string, err error) {
	scanner := bufio.NewScanner(strings.NewReader(content))
	// Raise the scanner's per-line limit (default 64KiB) so unusually long
	// markdown lines don't abort the scan.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	if !scanner.Scan() {
		if err := scanner.Err(); err != nil {
			return "", "", err
		}
		return "", "", errors.New("empty SKILL.md")
	}
	if strings.TrimSpace(scanner.Text()) != "---" {
		return "", "", errors.New("missing YAML frontmatter")
	}

	var fmLines []string
	foundEnd := false
	for scanner.Scan() {
		line := scanner.Text()
		if strings.TrimSpace(line) == "---" {
			foundEnd = true
			break
		}
		fmLines = append(fmLines, line)
	}
	// A scan error (e.g. a line exceeding the buffer) must not be reported
	// as "not terminated".
	if err := scanner.Err(); err != nil {
		return "", "", err
	}
	if !foundEnd {
		return "", "", errors.New("frontmatter not terminated")
	}

	// Collect remaining content as body.
	var bodyLines []string
	for scanner.Scan() {
		bodyLines = append(bodyLines, scanner.Text())
	}
	// Without this check a scan error would silently truncate the body.
	if err := scanner.Err(); err != nil {
		return "", "", err
	}

	return strings.Join(fmLines, "\n"), strings.TrimSpace(strings.Join(bodyLines, "\n")), nil
}
|
||||
|
||||
func validateSkillMetadata(meta skillMetadata, skillDir string) error {
|
||||
name := strings.TrimSpace(meta.Name)
|
||||
description := strings.TrimSpace(meta.Description)
|
||||
|
||||
switch {
|
||||
case name == "":
|
||||
return errors.New("missing skill name")
|
||||
case len(name) > maxSkillNameLength:
|
||||
return fmt.Errorf("skill name exceeds %d characters", maxSkillNameLength)
|
||||
case !skillNamePattern.MatchString(name):
|
||||
return fmt.Errorf("invalid skill name %q", name)
|
||||
}
|
||||
|
||||
if description == "" {
|
||||
return errors.New("missing skill description")
|
||||
}
|
||||
if len(description) > maxSkillDescription {
|
||||
return fmt.Errorf("skill description exceeds %d characters", maxSkillDescription)
|
||||
}
|
||||
|
||||
// Skip directory name check for digest-based paths (extracted from blobs)
|
||||
dirName := filepath.Base(skillDir)
|
||||
if !strings.HasPrefix(dirName, "sha256-") && dirName != name {
|
||||
return fmt.Errorf("skill directory %q does not match name %q", dirName, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *skillCatalog) SystemPrompt() string {
|
||||
if c == nil || len(c.Skills) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("# Skills\n\n")
|
||||
b.WriteString("You have the following skills loaded. Each skill provides instructions and may include executable scripts.\n\n")
|
||||
b.WriteString("## Available Tools\n\n")
|
||||
b.WriteString("- `run_skill_script`: Execute a script bundled with a skill. Use this when the skill instructions tell you to run a script.\n")
|
||||
b.WriteString("- `read_skill_file`: Read additional files from a skill directory.\n\n")
|
||||
|
||||
for _, skill := range c.Skills {
|
||||
fmt.Fprintf(&b, "## Skill: %s\n\n", skill.Name)
|
||||
fmt.Fprintf(&b, "%s\n\n", skill.Content)
|
||||
b.WriteString("---\n\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Tools returns the tool definitions advertised to the model when skills are
// loaded: run_skill_script to execute bundled scripts and read_skill_file to
// read bundled files. Returns nil for a nil or empty catalog so callers can
// skip tool wiring entirely.
func (c *skillCatalog) Tools() api.Tools {
	if c == nil || len(c.Skills) == 0 {
		return nil
	}

	return api.Tools{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "run_skill_script",
				Description: "Execute a script or command within a skill's directory. Use this to run Python scripts, shell scripts, or other executables bundled with a skill.",
				Parameters: api.ToolFunctionParameters{
					Type:     "object",
					Required: []string{"skill", "command"},
					Properties: map[string]api.ToolProperty{
						"skill": {
							Type:        api.PropertyType{"string"},
							Description: "The name of the skill containing the script",
						},
						"command": {
							Type:        api.PropertyType{"string"},
							Description: "The command to execute (e.g., 'python scripts/calculate.py 25 4' or './scripts/run.sh')",
						},
					},
				},
			},
		},
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "read_skill_file",
				Description: "Read a file from a skill's directory. Use this to read additional documentation, reference files, or data files bundled with a skill.",
				Parameters: api.ToolFunctionParameters{
					Type:     "object",
					Required: []string{"skill", "path"},
					Properties: map[string]api.ToolProperty{
						"skill": {
							Type:        api.PropertyType{"string"},
							Description: "The name of the skill containing the file",
						},
						"path": {
							Type:        api.PropertyType{"string"},
							Description: "The relative path to the file within the skill directory",
						},
					},
				},
			},
		},
	}
}
|
||||
|
||||
func (c *skillCatalog) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
|
||||
switch call.Function.Name {
|
||||
case "read_skill_file":
|
||||
skillName, err := requireStringArg(call.Function.Arguments, "skill")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
relPath, err := requireStringArg(call.Function.Arguments, "path")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
skill, ok := c.byName[skillName]
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
|
||||
}
|
||||
content, err := readSkillFile(skill.Dir, relPath)
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
return toolMessage(call, content), true, nil
|
||||
|
||||
case "run_skill_script":
|
||||
skillName, err := requireStringArg(call.Function.Arguments, "skill")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
command, err := requireStringArg(call.Function.Arguments, "command")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
skill, ok := c.byName[skillName]
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
|
||||
}
|
||||
output, err := runSkillScript(skill.Dir, command)
|
||||
if err != nil {
|
||||
return toolMessage(call, fmt.Sprintf("error: %v\noutput: %s", err, output)), true, nil
|
||||
}
|
||||
return toolMessage(call, output), true, nil
|
||||
|
||||
default:
|
||||
return api.Message{}, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// runSkillScript executes a shell command within a skill's directory.
|
||||
//
|
||||
// SECURITY LIMITATIONS (TODO):
|
||||
// - No sandboxing: commands run with full user permissions
|
||||
// - No path validation: model can run any command, not just scripts in skill dir
|
||||
// - Shell injection risk: sh -c is used, malicious input could be crafted
|
||||
// - No executable allowlist: any program can be called (curl, rm, etc.)
|
||||
// - No environment isolation: scripts inherit full environment variables
|
||||
//
|
||||
// POTENTIAL IMPROVEMENTS:
|
||||
// - Restrict commands to only reference files within skill directory
|
||||
// - Allowlist specific executables (python3, node, bash)
|
||||
// - Use sandboxing (Docker, nsjail, seccomp)
|
||||
// - Require explicit script registration in SKILL.md frontmatter
|
||||
// - Add per-skill configurable timeouts
|
||||
func runSkillScript(skillDir, command string) (string, error) {
|
||||
// Validate the skill directory exists
|
||||
absSkillDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := os.Stat(absSkillDir); err != nil {
|
||||
return "", fmt.Errorf("skill directory not found: %w", err)
|
||||
}
|
||||
|
||||
// Create command with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", command)
|
||||
cmd.Dir = absSkillDir
|
||||
|
||||
// Inject the current working directory (where ollama run was called from)
|
||||
// as an environment variable so scripts can reference files in that directory
|
||||
workingDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
cmd.Env = append(os.Environ(), "OLLAMA_WORKING_DIR="+workingDir)
|
||||
|
||||
// Capture both stdout and stderr
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
// Combine output
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
if output != "" {
|
||||
output += "\n"
|
||||
}
|
||||
output += stderr.String()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return output, fmt.Errorf("command timed out after 30 seconds")
|
||||
}
|
||||
return output, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func readSkillFile(skillDir, relPath string) (string, error) {
|
||||
relPath = filepath.Clean(strings.TrimSpace(relPath))
|
||||
if relPath == "" {
|
||||
return "", errors.New("path is required")
|
||||
}
|
||||
if filepath.IsAbs(relPath) {
|
||||
return "", errors.New("path must be relative to the skill directory")
|
||||
}
|
||||
|
||||
target := filepath.Join(skillDir, relPath)
|
||||
absTarget, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
absSkillDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rel, err := filepath.Rel(absSkillDir, absTarget)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.HasPrefix(rel, "..") {
|
||||
return "", errors.New("path escapes the skill directory")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absTarget)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read %q: %w", relPath, err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
func requireStringArg(args api.ToolCallFunctionArguments, name string) (string, error) {
|
||||
value, ok := args[name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required argument %q", name)
|
||||
}
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("argument %q must be a string", name)
|
||||
}
|
||||
if strings.TrimSpace(str) == "" {
|
||||
return "", fmt.Errorf("argument %q cannot be empty", name)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func toolMessage(call api.ToolCall, content string) api.Message {
|
||||
msg := api.Message{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
ToolName: call.Function.Name,
|
||||
}
|
||||
if call.ID != "" {
|
||||
msg.ToolCallID = call.ID
|
||||
}
|
||||
return msg
|
||||
}
|
||||
211
docs/ENTRYPOINT_FEATURE.md
Normal file
211
docs/ENTRYPOINT_FEATURE.md
Normal file
@@ -0,0 +1,211 @@
|
||||
# ENTRYPOINT Feature for Ollama Agents
|
||||
|
||||
## Overview
|
||||
|
||||
The ENTRYPOINT command allows agents to specify an external program to run instead of the built-in Ollama chat loop. This makes Ollama a packaging/distribution mechanism for agents with custom runtimes.
|
||||
|
||||
## Status: Implemented ✓
|
||||
|
||||
## What Was Done
|
||||
|
||||
### 1. Types & API
|
||||
|
||||
**`types/model/config.go`**
|
||||
- Added `Entrypoint string` field to `ConfigV2` struct
|
||||
|
||||
**`api/types.go`**
|
||||
- Added `Entrypoint string` to `CreateRequest` (line ~576)
|
||||
- Added `Entrypoint string` to `ShowResponse` (line ~632)
|
||||
|
||||
### 2. Parser
|
||||
|
||||
**`parser/parser.go`**
|
||||
- Added "entrypoint" to `isValidCommand()` switch
|
||||
- Added case in `CreateRequest()` to set `req.Entrypoint = c.Args`
|
||||
- Updated `ParseFile()` to allow ENTRYPOINT without FROM (entrypoint-only agents)
|
||||
- Added entrypoint serialization in `Command.String()`
|
||||
|
||||
### 3. Server
|
||||
|
||||
**`server/create.go`**
|
||||
- Added `config.Entrypoint = r.Entrypoint` to store entrypoint in config
|
||||
- Made FROM optional when ENTRYPOINT is specified:
|
||||
```go
|
||||
} else if r.Entrypoint != "" {
|
||||
// Entrypoint-only agent: no base model needed
|
||||
slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint)
|
||||
}
|
||||
```
|
||||
|
||||
**`server/routes.go`**
|
||||
- Added `Entrypoint: m.Config.Entrypoint` to ShowResponse in `GetModelInfo()`
|
||||
|
||||
**`server/images.go`**
|
||||
- Added entrypoint serialization in `Model.String()`:
|
||||
```go
|
||||
if m.Config.Entrypoint != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||
Name: "entrypoint",
|
||||
Args: m.Config.Entrypoint,
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### 4. CLI
|
||||
|
||||
**`cmd/cmd.go`**
|
||||
- Added `Entrypoint string` to `runOptions` struct
|
||||
- Updated agent detection to include Entrypoint check
|
||||
- Added entrypoint check before interactive mode:
|
||||
```go
|
||||
if opts.Entrypoint != "" {
|
||||
return runEntrypoint(cmd, opts)
|
||||
}
|
||||
```
|
||||
- Implemented `runEntrypoint()` function:
|
||||
- Parses entrypoint into command and args
|
||||
- Appends user prompt as additional argument if provided
|
||||
- Looks up command in PATH
|
||||
- Creates subprocess with stdin/stdout/stderr connected
|
||||
- Runs and waits for completion
|
||||
- Updated `showInfo()` to display entrypoint in Agent section
|
||||
- Updated `showInfo()` to hide Model section for entrypoint-only agents (no blank fields)
|
||||
- Added `$PROMPT` placeholder support in `runEntrypoint()`:
|
||||
- If entrypoint contains `$PROMPT`, it's replaced with the user's prompt
|
||||
- If no placeholder, prompt is appended as positional argument (backwards compatible)
|
||||
- If no prompt provided, `$PROMPT` is removed from the command
|
||||
|
||||
## Usage
|
||||
|
||||
### Agentfile
|
||||
```dockerfile
|
||||
# Minimal entrypoint agent (no model required)
|
||||
ENTRYPOINT ducky
|
||||
|
||||
# Or with full path
|
||||
ENTRYPOINT /usr/local/bin/ducky
|
||||
|
||||
# Or with arguments
|
||||
ENTRYPOINT ducky --verbose
|
||||
|
||||
# Use $PROMPT placeholder to control where prompt is inserted
|
||||
ENTRYPOINT ducky -p $PROMPT
|
||||
|
||||
# Without placeholder, prompt is appended as positional argument
|
||||
ENTRYPOINT echo "Hello" # becomes: echo "Hello" <prompt>
|
||||
|
||||
# Can still bundle skills/MCPs with entrypoint agents
|
||||
SKILL ./my-skill
|
||||
MCP calculator python3 ./calc.py
|
||||
ENTRYPOINT my-custom-runtime
|
||||
```
|
||||
|
||||
### CLI
|
||||
```bash
|
||||
# Create the agent
|
||||
ollama create ducky -f ducky.Agentfile
|
||||
|
||||
# Run it - starts the entrypoint (e.g., REPL)
|
||||
ollama run ducky
|
||||
|
||||
# With prompt (passed as argument to entrypoint)
|
||||
ollama run ducky "hello"
|
||||
|
||||
# Show agent info
|
||||
ollama show ducky
|
||||
# Agent
|
||||
# entrypoint ducky
|
||||
```
|
||||
|
||||
## Testing Done
|
||||
|
||||
1. **Basic entrypoint execution**: ✓
|
||||
```bash
|
||||
# Agentfile: ENTRYPOINT echo "Hello from entrypoint"
|
||||
ollama run test-entry # Output: "Hello from entrypoint"
|
||||
```
|
||||
|
||||
2. **Prompt passing (positional)**: ✓
|
||||
```bash
|
||||
# Agentfile: ENTRYPOINT echo "Args:"
|
||||
ollama run echo-test "hello world" # Output: "Args:" hello world
|
||||
```
|
||||
|
||||
3. **Prompt passing ($PROMPT placeholder)**: ✓
|
||||
```bash
|
||||
# Agentfile: ENTRYPOINT echo "Prompt was:" $PROMPT "end"
|
||||
ollama run echo-placeholder "hello world" # Output: "Prompt was:" hello world "end"
|
||||
ollama run echo-placeholder # Output: "Prompt was:" "end"
|
||||
```
|
||||
|
||||
4. **Show command**: ✓
|
||||
```bash
|
||||
ollama show ducky
|
||||
# Agent
|
||||
# entrypoint ducky
|
||||
# (Model section hidden for entrypoint-only agents)
|
||||
```
|
||||
|
||||
5. **List command**: ✓
|
||||
- Entrypoint-only agents show with small sizes (~200 bytes)
|
||||
|
||||
## Left Over / Future Enhancements
|
||||
|
||||
### 1. Context Passing via Environment Variables
|
||||
Pass agent context to entrypoint via env vars:
|
||||
- `OLLAMA_AGENT_NAME` - Name of the agent
|
||||
- `OLLAMA_SKILLS_PATH` - Path to bundled skills
|
||||
- `OLLAMA_MCPS` - JSON of MCP configurations
|
||||
|
||||
### ~~2. Arguments Placeholder~~ ✓ DONE
|
||||
~~Support placeholder syntax for more control:~~
|
||||
```dockerfile
|
||||
# Now supported!
|
||||
ENTRYPOINT ducky -p $PROMPT
|
||||
```
|
||||
|
||||
### 3. Working Directory
|
||||
Set working directory for entrypoint:
|
||||
```dockerfile
|
||||
WORKDIR /app
|
||||
ENTRYPOINT ./run.sh
|
||||
```
|
||||
|
||||
### 4. Interactive Mode Detection
|
||||
Different behavior for REPL vs single-shot:
|
||||
- Detect if stdin is a TTY
|
||||
- Pass different flags based on mode
|
||||
|
||||
### 5. Signal Handling
|
||||
Improved signal forwarding to subprocess:
|
||||
- Forward SIGINT, SIGTERM gracefully
|
||||
- Handle cleanup on parent exit
|
||||
|
||||
### 6. Entrypoint with Model
|
||||
Allow both model and entrypoint:
|
||||
```dockerfile
|
||||
FROM llama3.2
|
||||
ENTRYPOINT my-custom-ui
|
||||
```
|
||||
The entrypoint could then use the model via Ollama API.
|
||||
|
||||
### 7. Pull/Push for Entrypoint Agents
|
||||
- Currently entrypoint agents can be created locally
|
||||
- Need to test/verify push/pull to registry works correctly
|
||||
- May need to handle entrypoint binaries (or just reference system commands)
|
||||
|
||||
### 8. Error Handling
|
||||
- Better error messages when entrypoint command not found
|
||||
- Validation of entrypoint during create (optional, warn if not found)
|
||||
|
||||
## Design Decisions
|
||||
|
||||
1. **Subprocess mode (not exec)**: Ollama stays as parent process to handle signals and cleanup
|
||||
|
||||
2. **No context passing initially**: Keep it simple, entrypoint handles its own config
|
||||
|
||||
3. **Skills/MCPs allowed**: Enables packaging assets with the agent even if entrypoint manages execution
|
||||
|
||||
4. **FROM optional**: Entrypoint agents don't need a model, just the runtime
|
||||
|
||||
5. **Prompt as argument**: User prompt is appended as argument to entrypoint command (simplest approach)
|
||||
332
docs/agent-skills-changes.md
Normal file
332
docs/agent-skills-changes.md
Normal file
@@ -0,0 +1,332 @@
|
||||
# Agent Skills Feature - Implementation Summary
|
||||
|
||||
This document summarizes all changes made to implement agent skills in Ollama, enabling `ollama run <agent>` with skill-based capabilities.
|
||||
|
||||
## Overview
|
||||
|
||||
Agents are models with attached skills. Skills are directories containing a `SKILL.md` file with instructions and optional executable scripts. When an agent runs, skills are loaded and injected into the system prompt, and the model can execute scripts via tool calls.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### 1. `cmd/skills.go` (NEW FILE)
|
||||
|
||||
Core skills implementation:
|
||||
|
||||
```go
|
||||
// Key types
|
||||
type skillMetadata struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
}
|
||||
|
||||
type skillDefinition struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string // SKILL.md body content
|
||||
Dir string // Absolute path to skill directory
|
||||
SkillPath string // Absolute path to SKILL.md
|
||||
}
|
||||
|
||||
type skillCatalog struct {
|
||||
Skills []skillDefinition
|
||||
byName map[string]skillDefinition
|
||||
}
|
||||
```
|
||||
|
||||
**Key functions:**
|
||||
- `loadSkills(paths []string)` - Walks skill directories, parses SKILL.md files
|
||||
- `parseSkillFile(path, skillDir)` - Extracts YAML frontmatter and body content
|
||||
- `SystemPrompt()` - Generates system prompt with skill instructions
|
||||
- `Tools()` - Returns `run_skill_script` and `read_skill_file` tools
|
||||
- `RunToolCall(call)` - Executes tool calls from the model
|
||||
- `runSkillScript(skillDir, command)` - Executes shell commands in skill directory
|
||||
|
||||
**Tools provided to model:**
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `run_skill_script` | Execute a script in a skill's directory |
|
||||
| `read_skill_file` | Read a file from a skill's directory |
|
||||
|
||||
**Security note:** `runSkillScript` has documented limitations (no sandboxing, no path validation). See the function's doc comment for details.
|
||||
|
||||
---
|
||||
|
||||
### 2. `cmd/cmd.go`
|
||||
|
||||
**Changes to `runOptions` struct:**
|
||||
```go
|
||||
type runOptions struct {
|
||||
// ... existing fields ...
|
||||
IsAgent bool
|
||||
AgentType string
|
||||
Skills []string
|
||||
}
|
||||
```
|
||||
|
||||
**Agent detection in `RunHandler`** (~line 497-503):
|
||||
```go
|
||||
// Check if this is an agent
|
||||
isAgent := info.AgentType != "" || len(info.Skills) > 0
|
||||
if isAgent {
|
||||
opts.IsAgent = true
|
||||
opts.AgentType = info.AgentType
|
||||
opts.Skills = info.Skills
|
||||
}
|
||||
```
|
||||
|
||||
**Route agents to chat API** (~line 557-562):
|
||||
```go
|
||||
// For agents, use chat API even in non-interactive mode to support tools
|
||||
if opts.IsAgent {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
|
||||
_, err := chat(cmd, opts)
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
**Skills loading in `chat` function** (~line 1347-1361):
|
||||
```go
|
||||
var skillsCatalog *skillCatalog
|
||||
if opts.IsAgent && len(opts.Skills) > 0 {
|
||||
skillsCatalog, err = loadSkills(opts.Skills)
|
||||
// ... error handling ...
|
||||
// Print loaded skills
|
||||
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
|
||||
}
|
||||
```
|
||||
|
||||
**System prompt injection** (~line 1448-1455):
|
||||
- Skills system prompt is prepended to messages
|
||||
|
||||
**Tool execution** (~line 1497-1533):
|
||||
- Executes pending tool calls via `skillsCatalog.RunToolCall()`
|
||||
- Displays script execution and output to terminal
|
||||
|
||||
---
|
||||
|
||||
### 3. `parser/parser.go`
|
||||
|
||||
**New valid commands** in `isValidCommand()`:
|
||||
```go
|
||||
case "from", "license", "template", "system", "adapter", "renderer",
|
||||
"parser", "parameter", "message", "requires", "skill", "agent_type":
|
||||
```
|
||||
|
||||
**Command handling in `CreateRequest()`**:
|
||||
```go
|
||||
case "skill":
|
||||
skills = append(skills, c.Args)
|
||||
case "agent_type":
|
||||
req.AgentType = c.Args
|
||||
```
|
||||
|
||||
**Underscore support in command names** (~line 545):
|
||||
```go
|
||||
case isAlpha(r), r == '_':
|
||||
return stateName, r, nil
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. `api/types.go`
|
||||
|
||||
**CreateRequest additions** (~line 560-564):
|
||||
```go
|
||||
// Skills is a list of skill directories for the agent
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
|
||||
// AgentType defines the type of agent (e.g., "conversational", "task-based")
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
```
|
||||
|
||||
**ShowResponse additions** (~line 633-637):
|
||||
```go
|
||||
// Skills loaded for this agent
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
|
||||
// AgentType for this agent
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. `types/model/config.go`
|
||||
|
||||
**ConfigV2 additions**:
|
||||
```go
|
||||
type ConfigV2 struct {
|
||||
// ... existing fields ...
|
||||
|
||||
// Agent-specific fields
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 6. `server/create.go`
|
||||
|
||||
**Store agent fields** (~line 65-66):
|
||||
```go
|
||||
config.Skills = r.Skills
|
||||
config.AgentType = r.AgentType
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 7. `server/routes.go`
|
||||
|
||||
**Return agent fields in ShowResponse** (~line 1107):
|
||||
```go
|
||||
resp := &api.ShowResponse{
|
||||
// ... existing fields ...
|
||||
Skills: m.Config.Skills,
|
||||
AgentType: m.Config.AgentType,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 8. `envconfig/config.go`
|
||||
|
||||
**Environment variable support**:
|
||||
```go
|
||||
func Skills() []string {
|
||||
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(raw, ",")
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Agentfile Format
|
||||
|
||||
Agentfiles use the same syntax as Modelfiles with additional commands:
|
||||
|
||||
```dockerfile
|
||||
FROM gpt-oss:20b
|
||||
|
||||
AGENT_TYPE conversational
|
||||
SKILL /path/to/skills/directory
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
|
||||
PARAMETER temperature 0.3
|
||||
PARAMETER top_p 0.9
|
||||
```
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `SKILL` | Path to a directory containing skill subdirectories |
|
||||
| `AGENT_TYPE` | Type of agent (e.g., "conversational") |
|
||||
|
||||
---
|
||||
|
||||
## SKILL.md Format
|
||||
|
||||
Each skill is a directory with a `SKILL.md` file:
|
||||
|
||||
```
|
||||
calculator-skill/
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
└── calculate.py
|
||||
```
|
||||
|
||||
**SKILL.md structure:**
|
||||
```markdown
|
||||
---
|
||||
name: calculator-skill
|
||||
description: A skill for performing calculations.
|
||||
---
|
||||
|
||||
# Calculator Skill
|
||||
|
||||
## Instructions
|
||||
|
||||
1. Use `run_skill_script` to execute calculations
|
||||
2. Call: `python3 scripts/calculate.py '<expression>'`
|
||||
|
||||
## Examples
|
||||
|
||||
For "What is 25 * 4?":
|
||||
- Call: run_skill_script with skill="calculator-skill" and command="python3 scripts/calculate.py '25 * 4'"
|
||||
```
|
||||
|
||||
**Requirements:**
|
||||
- `name` must match directory name
|
||||
- `name` must be lowercase alphanumeric with hyphens only
|
||||
- `name` max 64 characters
|
||||
- `description` required, max 1024 characters
|
||||
|
||||
---
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Create an agent
|
||||
ollama create math-agent -f math-agent.Agentfile
|
||||
|
||||
# Run the agent
|
||||
ollama run math-agent "What is 25 * 4?"
|
||||
|
||||
# Output:
|
||||
# Loaded skills: calculator-skill
|
||||
# Running script in calculator-skill: python3 scripts/calculate.py '25 * 4'
|
||||
# Output:
|
||||
# 25 * 4 = 100
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Flow Diagram
|
||||
|
||||
```
|
||||
1. ollama run math-agent "query"
|
||||
│
|
||||
▼
|
||||
2. RunHandler detects agent (AgentType or Skills present)
|
||||
│
|
||||
▼
|
||||
3. Routes to chat() instead of generate()
|
||||
│
|
||||
▼
|
||||
4. loadSkills() parses SKILL.md files
|
||||
│
|
||||
▼
|
||||
5. SystemPrompt() injects skill instructions
|
||||
│
|
||||
▼
|
||||
6. Tools() provides run_skill_script, read_skill_file
|
||||
│
|
||||
▼
|
||||
7. Model generates response (may include tool calls)
|
||||
│
|
||||
▼
|
||||
8. RunToolCall() executes scripts, returns output
|
||||
│
|
||||
▼
|
||||
9. Display results to user
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Security Considerations
|
||||
|
||||
The `runSkillScript` function has known limitations documented in the code:
|
||||
|
||||
- No sandboxing (commands run with user permissions)
|
||||
- No path validation (model can run any command)
|
||||
- Shell injection risk (`sh -c` is used)
|
||||
- No executable allowlist
|
||||
- No environment isolation
|
||||
|
||||
**Potential improvements** (documented as TODOs):
|
||||
- Restrict to skill directory paths only
|
||||
- Allowlist executables (python3, node, bash)
|
||||
- Use sandboxing (Docker, nsjail, seccomp)
|
||||
- Require explicit script registration in SKILL.md
|
||||
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu.md).
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
|
||||
10
docs/gpu.mdx
10
docs/gpu.mdx
@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
|
||||
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
|
||||
|
||||
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
|
||||
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
|
||||
|
||||
### GPU Selection
|
||||
|
||||
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
|
||||
|
||||
Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Additional AMD GPU support is provided by the Vulkan Library - see below.
|
||||
|
||||
|
||||
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
|
||||
|
||||
## Vulkan GPU Support
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server)
|
||||
|
||||
Additional GPU support on Windows and Linux is provided via
|
||||
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
|
||||
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
|
||||
|
||||
To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
265
docs/mcp-integration.md
Normal file
265
docs/mcp-integration.md
Normal file
@@ -0,0 +1,265 @@
|
||||
# MCP (Model Context Protocol) Integration
|
||||
|
||||
This document describes the MCP integration for Ollama agents, enabling agents to use external tools via the Model Context Protocol.
|
||||
|
||||
## Overview
|
||||
|
||||
MCP allows Ollama agents to communicate with external tool servers over JSON-RPC 2.0 via stdio. This enables agents to access capabilities like web search, file operations, databases, and more through standardized tool interfaces.
|
||||
|
||||
## Status
|
||||
|
||||
| Phase | Description | Status |
|
||||
|-------|-------------|--------|
|
||||
| Phase 1 | Types & Parser | ✅ Complete |
|
||||
| Phase 2 | Layer Handling | ✅ Complete |
|
||||
| Phase 3 | Runtime Manager | ✅ Complete |
|
||||
| Phase 4 | CLI Commands | ✅ Complete |
|
||||
|
||||
## Agentfile Syntax
|
||||
|
||||
### Simple Command Format
|
||||
```dockerfile
|
||||
MCP <name> <command> [args...]
|
||||
```
|
||||
|
||||
Example:
|
||||
```dockerfile
|
||||
FROM llama3.2
|
||||
AGENT_TYPE conversational
|
||||
SYSTEM You are a helpful assistant with MCP tools.
|
||||
MCP calculator python3 ./mcp-server.py
|
||||
MCP websearch node ./search-server.js
|
||||
```
|
||||
|
||||
### JSON Format
|
||||
```dockerfile
|
||||
MCP {"name": "custom", "command": "uv", "args": ["run", "server.py"], "env": {"API_KEY": "xxx"}}
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Type Definitions
|
||||
|
||||
**MCPRef** (`types/model/config.go`):
|
||||
```go
|
||||
type MCPRef struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Type string `json:"type,omitempty"` // "stdio"
|
||||
}
|
||||
```
|
||||
|
||||
### Tool Namespacing
|
||||
|
||||
MCP tools are namespaced to avoid conflicts:
|
||||
- Format: `mcp_{serverName}_{toolName}`
|
||||
- Example: Server "calculator" with tool "add" → `mcp_calculator_add`
|
||||
|
||||
### Runtime Flow
|
||||
|
||||
1. Agent starts → MCP servers spawn as subprocesses
|
||||
2. Initialize via JSON-RPC: `initialize` → `notifications/initialized`
|
||||
3. Discover tools: `tools/list`
|
||||
4. During chat, model calls tools → routed via `tools/call`
|
||||
5. On shutdown, MCP servers are gracefully terminated
|
||||
|
||||
## Files
|
||||
|
||||
### Created
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `cmd/mcp.go` | Runtime MCP manager with JSON-RPC protocol |
|
||||
| `cmd/mcp_cmd.go` | CLI commands for managing MCPs (push, pull, list, etc.) |
|
||||
| `server/mcp.go` | MCP layer utilities (extraction, creation) |
|
||||
|
||||
### Modified
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `types/model/config.go` | Added `MCPRef` type, `MCPs` field to `ConfigV2` |
|
||||
| `types/model/name.go` | Added `"mcp"` to `ValidKinds` for 5-part name parsing |
|
||||
| `api/types.go` | Added `MCPRef` alias, `MCPs` to `CreateRequest`/`ShowResponse` |
|
||||
| `parser/parser.go` | Added `MCP` command parsing with JSON and simple formats |
|
||||
| `server/create.go` | Added `setMCPLayers()` for MCP config handling |
|
||||
| `server/routes.go` | Added `MCPs` to show response |
|
||||
| `cmd/cmd.go` | MCP integration in `chat()` function |
|
||||
| `cmd/interactive.go` | Added `/mcp` and `/mcps` REPL commands |
|
||||
|
||||
## Usage Example
|
||||
|
||||
### 1. Create an MCP Server
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
# mcp-server.py
|
||||
import json
|
||||
import sys
|
||||
|
||||
def handle_request(req):
|
||||
method = req.get("method", "")
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "example", "version": "1.0"}
|
||||
}
|
||||
elif method == "tools/list":
|
||||
return {
|
||||
"tools": [{
|
||||
"name": "add",
|
||||
"description": "Adds two numbers",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
args = req["params"]["arguments"]
|
||||
return {"content": [{"type": "text", "text": f"{args['a'] + args['b']}"}]}
|
||||
return {}
|
||||
|
||||
for line in sys.stdin:
|
||||
req = json.loads(line)
|
||||
if "id" in req:
|
||||
result = handle_request(req)
|
||||
print(json.dumps({"jsonrpc": "2.0", "id": req["id"], "result": result}), flush=True)
|
||||
```
|
||||
|
||||
### 2. Create an Agent
|
||||
|
||||
```dockerfile
|
||||
# my-agent.Agentfile
|
||||
FROM gpt-oss:20b
|
||||
AGENT_TYPE conversational
|
||||
SYSTEM You have access to a calculator. Use the add tool when asked to add numbers.
|
||||
MCP calculator python3 ./mcp-server.py
|
||||
```
|
||||
|
||||
### 3. Build and Run
|
||||
|
||||
```bash
|
||||
ollama create my-agent -f my-agent.Agentfile
|
||||
ollama run my-agent "What is 15 + 27?"
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
Loaded MCP servers: calculator (1 tools)
|
||||
Executing: mcp_calculator_add
|
||||
Output: 42
|
||||
The result is 42.
|
||||
```
|
||||
|
||||
## CLI Commands
|
||||
|
||||
The `ollama mcp` command provides utilities for managing MCP servers:
|
||||
|
||||
### Global Config Commands
|
||||
|
||||
Add an MCP server to the global config (`~/.ollama/mcp.json`):
|
||||
```bash
|
||||
# Add MCP to global config (available to all agents)
|
||||
ollama mcp add web-search uv run ./mcp-server.py
|
||||
ollama mcp add calculator python3 /path/to/calc.py
|
||||
|
||||
# List global MCP servers (shows enabled/disabled status)
|
||||
ollama mcp list-global
|
||||
|
||||
# Disable an MCP server (keeps in config but won't be loaded)
|
||||
ollama mcp disable web-search
|
||||
|
||||
# Re-enable a disabled MCP server
|
||||
ollama mcp enable web-search
|
||||
|
||||
# Remove from global config
|
||||
ollama mcp remove-global web-search
|
||||
```
|
||||
|
||||
### Registry Commands
|
||||
|
||||
Package and push MCPs to a registry:
|
||||
```bash
|
||||
# Push MCP to registry (creates locally first)
|
||||
ollama mcp push mcp/websearch:1.0 ./my-mcp-server/
|
||||
|
||||
# Pull MCP from registry
|
||||
ollama mcp pull mcp/websearch:1.0
|
||||
|
||||
# List installed MCPs (from registry)
|
||||
ollama mcp list
|
||||
|
||||
# Show MCP details
|
||||
ollama mcp show mcp/websearch:1.0
|
||||
|
||||
# Remove MCP
|
||||
ollama mcp rm mcp/websearch:1.0
|
||||
```
|
||||
|
||||
## REPL Commands
|
||||
|
||||
Inside `ollama run`, you can manage MCP servers dynamically:
|
||||
|
||||
```
|
||||
>>> /mcp # Show all MCP servers (model + global)
|
||||
>>> /mcp add calc python3 ./calc-server.py # Add MCP server to global config
|
||||
>>> /mcp remove calc # Remove MCP server from global config
|
||||
>>> /mcp disable calc # Disable an MCP server (keep in config)
|
||||
>>> /mcp enable calc # Re-enable a disabled MCP server
|
||||
>>> /? mcp # Get help for MCP commands
|
||||
```
|
||||
|
||||
The `/mcp` command shows all available MCP servers (both bundled with the model and from global config). Disabled servers are shown with a `[disabled]` marker. Use `/mcp add` and `/mcp remove` to manage MCPs in `~/.ollama/mcp.json`. Changes take effect on the next message.
|
||||
|
||||
## Global Config
|
||||
|
||||
MCPs can be configured globally in `~/.ollama/mcp.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"web-search": {
|
||||
"type": "stdio",
|
||||
"command": "uv",
|
||||
"args": ["run", "./mcp-server.py"]
|
||||
},
|
||||
"calculator": {
|
||||
"type": "stdio",
|
||||
"command": "python3",
|
||||
"args": ["/path/to/calc.py"],
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `disabled` field is optional. When set to `true`, the MCP server will not be loaded when running agents.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. **Remote Registry Push/Pull**: Full support for pushing/pulling MCPs to/from remote registries
|
||||
2. **Use go-sdk**: Consider using `github.com/modelcontextprotocol/go-sdk` for protocol handling
|
||||
3. **Resource Support**: Add MCP resources (not just tools)
|
||||
4. **Prompt Support**: Add MCP prompts
|
||||
|
||||
## Protocol Reference
|
||||
|
||||
MCP uses JSON-RPC 2.0 over stdio with these key methods:
|
||||
|
||||
| Method | Direction | Purpose |
|
||||
|--------|-----------|---------|
|
||||
| `initialize` | Client→Server | Handshake with capabilities |
|
||||
| `notifications/initialized` | Client→Server | Confirm initialization |
|
||||
| `tools/list` | Client→Server | Discover available tools |
|
||||
| `tools/call` | Client→Server | Execute a tool |
|
||||
|
||||
See [MCP Specification](https://modelcontextprotocol.io/docs) for full details.
|
||||
362
docs/skill-registry-design.md
Normal file
362
docs/skill-registry-design.md
Normal file
@@ -0,0 +1,362 @@
|
||||
# Skill Registry Design
|
||||
|
||||
## Overview
|
||||
|
||||
Skills are distributable capability packages for Ollama agents. They can be:
|
||||
- Bundled with agents at creation time (local paths)
|
||||
- Pulled from the registry (skill references)
|
||||
- Pushed to the registry for sharing
|
||||
|
||||
## User Experience
|
||||
|
||||
### Push a Skill
|
||||
|
||||
```bash
|
||||
# Push a local skill directory to the registry
|
||||
ollama skill push myname/calculator:1.0.0 ./skills/calculator-skill
|
||||
|
||||
# Output:
|
||||
# Creating skill layer for skill/myname/calculator:1.0.0
|
||||
# pushing sha256:abc123... 1.2KB
|
||||
# pushing sha256:def456... 220B
|
||||
# pushing manifest
|
||||
# Successfully pushed skill/myname/calculator:1.0.0
|
||||
```
|
||||
|
||||
### Pull a Skill
|
||||
|
||||
```bash
|
||||
# Pull a skill from the registry
|
||||
ollama skill pull calculator:1.0.0
|
||||
|
||||
# Output:
|
||||
# pulling manifest
|
||||
# pulling sha256:abc123... 1.2KB
|
||||
# extracting skill...
|
||||
# Successfully pulled skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
### List Installed Skills
|
||||
|
||||
```bash
|
||||
ollama skill list
|
||||
|
||||
# Output:
|
||||
# NAME TAG SIZE MODIFIED
|
||||
# skill/calculator 1.0.0 1.2 KB 2 hours ago
|
||||
# myname/skill/hello       latest      0.8 KB    1 day ago
|
||||
```
|
||||
|
||||
### Remove a Skill
|
||||
|
||||
```bash
|
||||
ollama skill rm calculator:1.0.0
|
||||
# Deleted 'skill/calculator:1.0.0'
|
||||
```
|
||||
|
||||
### Use Skills in Agentfile
|
||||
|
||||
```dockerfile
|
||||
FROM llama3.2:3b
|
||||
|
||||
AGENT_TYPE conversational
|
||||
SKILL skill/calculator:1.0.0 # Registry reference
|
||||
SKILL ./local-skill # Local path (for development)
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
```
|
||||
|
||||
## Technical Implementation
|
||||
|
||||
### Skill Manifest Format
|
||||
|
||||
```json
|
||||
{
|
||||
"schemaVersion": 2,
|
||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": {
|
||||
"mediaType": "application/vnd.docker.container.image.v1+json",
|
||||
"digest": "sha256:config...",
|
||||
"size": 220
|
||||
},
|
||||
"layers": [
|
||||
{
|
||||
"mediaType": "application/vnd.ollama.image.skill",
|
||||
"digest": "sha256:skill...",
|
||||
"size": 1234
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Skill Config Format
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "calculator",
|
||||
"description": "A skill for performing calculations",
|
||||
"architecture": "amd64",
|
||||
"os": "linux"
|
||||
}
|
||||
```
|
||||
|
||||
### Storage Layout
|
||||
|
||||
Skills use a 5-part manifest structure: `host/namespace/kind/model/tag`
|
||||
|
||||
```
|
||||
~/.ollama/models/
|
||||
├── blobs/
|
||||
│ └── sha256-<skill-digest> # Skill tar.gz blob
|
||||
├── manifests/
|
||||
│ └── registry.ollama.ai/
|
||||
│ └── library/
|
||||
│ └── skill/ # Kind = skill
|
||||
│ └── calculator/
|
||||
│ └── 1.0.0
|
||||
│ └── myname/
|
||||
│ └── skill/ # User skills
|
||||
│ └── my-skill/
|
||||
│ └── latest
|
||||
└── skills/
|
||||
└── sha256-<digest>/ # Extracted skill cache
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
```
|
||||
|
||||
### Name Structure
|
||||
|
||||
Skills use a 5-part name structure with `kind` to distinguish from models:
|
||||
|
||||
| Skill Reference | Namespace | Kind | Model | Tag |
|
||||
|-----------------|-----------|------|-------|-----|
|
||||
| `skill/calculator:1.0.0` | library | skill | calculator | 1.0.0 |
|
||||
| `myname/skill/calc:latest` | myname | skill | calc | latest |
|
||||
|
||||
### Media Type
|
||||
|
||||
```go
|
||||
const MediaTypeSkill = "application/vnd.ollama.image.skill"
|
||||
```
|
||||
|
||||
### Key Types
|
||||
|
||||
```go
|
||||
// SkillRef represents a skill reference in agent config
|
||||
type SkillRef struct {
|
||||
Name string `json:"name,omitempty"` // "calculator-skill" or "myname/skill/calc:1.0.0"
|
||||
Digest string `json:"digest,omitempty"` // "sha256:abc..." (set when bundled)
|
||||
}
|
||||
|
||||
// model.Name represents a parsed 5-part name
|
||||
type Name struct {
|
||||
Host string // "registry.ollama.ai"
|
||||
Namespace string // "library" or "myname"
|
||||
Kind string // "skill" or "agent" or "" for models
|
||||
Model string // "calculator"
|
||||
Tag string // "1.0.0"
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Files
|
||||
|
||||
### Client (ollama)
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `server/skill.go` | Skill blob handling, path parsing, extraction |
|
||||
| `cmd/skill_cmd.go` | CLI commands (push, pull, list, rm, show) |
|
||||
| `cmd/skills.go` | Skill loading and catalog management |
|
||||
| `server/create.go` | Skill layer creation during agent create |
|
||||
| `server/images.go` | Skill extraction during pull |
|
||||
| `types/model/config.go` | SkillRef type definition |
|
||||
|
||||
### Registry (ollama.com)
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `ollamadotcom/registry/store.go` | MediaTypeSkill constant |
|
||||
| `ollamadotcom/store/store.go` | RecordPush handles skill layers |
|
||||
|
||||
## Registry Integration
|
||||
|
||||
### What Works
|
||||
|
||||
- Blob uploads (content-addressable, no auth required)
|
||||
- Layer indexing (skill layers stored with mediatype)
|
||||
- Manifest structure (4-part path compatible)
|
||||
|
||||
### What's Needed
|
||||
|
||||
1. **Namespace Configuration**: The `skill` namespace needs to be configured with:
|
||||
- Public read access
|
||||
- Authenticated write access
|
||||
|
||||
2. **Permission Model**: Decide who can push to `skill/` namespace:
|
||||
- Only Ollama team (curated library)
|
||||
- Verified publishers
|
||||
- Anyone (open registry)
|
||||
|
||||
## Pull Flow
|
||||
|
||||
### Agent with Bundled Skills
|
||||
|
||||
```
|
||||
ollama pull my-agent
|
||||
→ GET manifest (includes skill layers)
|
||||
→ Download all blobs (model + skills)
|
||||
→ Extract skill blobs to ~/.ollama/models/skills/
|
||||
→ Ready to run
|
||||
```
|
||||
|
||||
### Standalone Skill
|
||||
|
||||
```
|
||||
ollama skill pull calculator:1.0.0
|
||||
→ Parse as skill/calculator:1.0.0
|
||||
  → Convert to model.Name{Namespace: "library", Kind: "skill", Model: "calculator", Tag: "1.0.0"}
|
||||
→ GET manifest from registry
|
||||
→ Download skill blob
|
||||
→ Extract to ~/.ollama/models/skills/sha256-<digest>/
|
||||
→ Available for agents to reference
|
||||
```
|
||||
|
||||
## Push Flow
|
||||
|
||||
```
|
||||
ollama skill push myname/calculator:1.0.0 ./my-skill
|
||||
→ Validate SKILL.md exists
|
||||
→ Create tar.gz of skill directory
|
||||
→ Compute SHA256 digest
|
||||
→ Store blob locally
|
||||
→ Create skill manifest with config layer
|
||||
→ Store manifest locally
|
||||
→ Push blobs to registry
|
||||
→ Push manifest to registry
|
||||
```
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
- Old agents with `Skills: []string` (paths) continue to work
|
||||
- New agents use `Skills: []SkillRef` with name and digest
|
||||
- Parser detects format and handles both
|
||||
|
||||
## Local Registry Testing
|
||||
|
||||
To test push/pull locally, you need MinIO and the Docker registry running:
|
||||
|
||||
```bash
|
||||
# 1. Start MinIO (for blob storage)
|
||||
minio server ~/.minio-data --console-address ':9001' &
|
||||
|
||||
# 2. Create the ollama-dev bucket (first time only)
|
||||
mc config host add local http://localhost:9000 minioadmin minioadmin
|
||||
mc mb local/ollama-dev
|
||||
|
||||
# 3. Start the registry (from ollama.com repo)
|
||||
cd /path/to/ollama.com/registry
|
||||
go run cmd/registry/main.go serve config-dev.yml &
|
||||
|
||||
# 4. Verify registry is running
|
||||
curl http://localhost:6000/v2/
|
||||
```
|
||||
|
||||
**Important:** The `config-dev.yml` must have matching ports:
|
||||
```yaml
|
||||
http:
|
||||
addr: :6000
|
||||
host: http://localhost:6000 # Must match addr!
|
||||
```
|
||||
|
||||
### Test Commands
|
||||
|
||||
```bash
|
||||
# Push skill from local folder
|
||||
ollama skill push localhost:6000/testuser/skill/calculator:1.0.0 ./skills/calculator-skill --insecure
|
||||
|
||||
# Pull skill from registry
|
||||
ollama skill pull localhost:6000/testuser/skill/calculator:1.0.0 --insecure
|
||||
|
||||
# List skills
|
||||
ollama skill list
|
||||
|
||||
# Show skill
|
||||
ollama skill show localhost:6000/testuser/skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Skill Naming Structure"
|
||||
A["skill/calculator:1.0.0"] --> B["host: registry.ollama.ai"]
|
||||
A --> C["namespace: library"]
|
||||
A --> D["kind: skill"]
|
||||
A --> E["model: calculator"]
|
||||
A --> F["tag: 1.0.0"]
|
||||
end
|
||||
|
||||
subgraph "Storage Layout"
|
||||
G["~/.ollama/models/"]
|
||||
G --> H["blobs/"]
|
||||
H --> I["sha256-<skill-digest>"]
|
||||
G --> J["manifests/"]
|
||||
J --> K["registry.ollama.ai/"]
|
||||
K --> L["library/skill/calculator/1.0.0"]
|
||||
K --> M["myname/skill/my-skill/latest"]
|
||||
G --> N["skills/"]
|
||||
N --> O["sha256-<digest>/"]
|
||||
O --> P["SKILL.md"]
|
||||
O --> Q["scripts/"]
|
||||
end
|
||||
|
||||
subgraph "Push Flow"
|
||||
R["User Command: ollama skill push"]
|
||||
R --> S["Validate SKILL.md"]
|
||||
S --> T["Create tar.gz of skill dir"]
|
||||
T --> U["Compute SHA256 digest"]
|
||||
U --> V["Store blob locally"]
|
||||
V --> W["Create skill manifest"]
|
||||
W --> X["Store manifest locally"]
|
||||
X --> Y["Push blobs to registry"]
|
||||
Y --> Z["Push manifest to registry"]
|
||||
end
|
||||
|
||||
subgraph "Pull Flow - Standalone Skill"
|
||||
AA["User Command: ollama skill pull"]
|
||||
AA --> AB["Parse name structure"]
|
||||
AB --> AC["GET manifest from registry"]
|
||||
AC --> AD["Download skill blob"]
|
||||
AD --> AE["Extract to skills/ directory"]
|
||||
AE --> AF["Available for agents"]
|
||||
end
|
||||
|
||||
subgraph "Pull Flow - Agent with Skills"
|
||||
AG["Pull Agent: ollama pull my-agent"]
|
||||
AG --> AH["GET manifest (includes skill layers)"]
|
||||
AH --> AI["Download all blobs (model + skills)"]
|
||||
AI --> AJ["Extract skill blobs"]
|
||||
AJ --> AK["Ready to run"]
|
||||
end
|
||||
|
||||
subgraph "Agentfile Integration"
|
||||
AL["Agentfile"]
|
||||
AL --> AM["FROM llama3.2:3b"]
|
||||
AL --> AN["SKILL skill/calculator:1.0.0"]
|
||||
AL --> AO["SKILL ./local-skill"]
|
||||
AO --> AP["Local path (development)"]
|
||||
AN --> AQ["Registry reference"]
|
||||
end
|
||||
|
||||
subgraph "Registry Components"
|
||||
AR["Registry Server"]
|
||||
AR --> AS["Blob Storage (MinIO)"]
|
||||
AR --> AT["Layer Indexing"]
|
||||
AR --> AU["Manifest Storage"]
|
||||
AR --> AV["Namespace Config"]
|
||||
end
|
||||
|
||||
Z --> AR
|
||||
AC --> AR
|
||||
AH --> AR
|
||||
```
|
||||
548
docs/skills.md
Normal file
548
docs/skills.md
Normal file
@@ -0,0 +1,548 @@
|
||||
# Ollama Skills
|
||||
|
||||
Skills are reusable capability packages that extend what agents can do. They bundle instructions, scripts, and data that teach an agent how to perform specific tasks.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Creating a Skill
|
||||
|
||||
Create a directory with a `SKILL.md` file:
|
||||
|
||||
```
|
||||
my-skill/
|
||||
├── SKILL.md # Required: Instructions for the agent
|
||||
└── scripts/ # Optional: Executable scripts
|
||||
└── run.py
|
||||
```
|
||||
|
||||
The `SKILL.md` file must have YAML frontmatter:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: my-skill
|
||||
description: A brief description of what this skill does
|
||||
---
|
||||
|
||||
# My Skill
|
||||
|
||||
## Purpose
|
||||
Explain what this skill does and when to use it.
|
||||
|
||||
## Instructions
|
||||
Step-by-step instructions for the agent on how to use this skill.
|
||||
|
||||
## Examples
|
||||
Show example inputs and expected outputs.
|
||||
```
|
||||
|
||||
### Using Skills in an Agent
|
||||
|
||||
Reference skills in your Agentfile:
|
||||
|
||||
```dockerfile
|
||||
FROM llama3.2:3b
|
||||
AGENT_TYPE conversational
|
||||
|
||||
# Local skill (bundled with agent)
|
||||
SKILL ./path/to/my-skill
|
||||
|
||||
# Registry skill (pulled from ollama.com)
|
||||
SKILL library/skill/calculator:1.0.0
|
||||
|
||||
# User skill from registry
|
||||
SKILL myname/skill/calculator:1.0.0
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
```
|
||||
|
||||
### Managing Skills
|
||||
|
||||
```bash
|
||||
# Push a skill to the registry (uses your namespace)
|
||||
ollama skill push myname/skill/calculator:1.0.0 ./my-skill
|
||||
|
||||
# Pull a skill from the official library
|
||||
ollama skill pull skill/calculator:1.0.0
|
||||
|
||||
# Pull a skill from a user's namespace
|
||||
ollama skill pull myname/skill/calculator:1.0.0
|
||||
|
||||
# List installed skills
|
||||
ollama skill list
|
||||
|
||||
# Show skill details
|
||||
ollama skill show skill/calculator:1.0.0
|
||||
|
||||
# Remove a skill
|
||||
ollama skill rm skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
### Dynamic Skills in Chat
|
||||
|
||||
You can add and remove skills dynamically during an interactive chat session:
|
||||
|
||||
```
|
||||
>>> /skills
|
||||
Available Skills:
|
||||
calculator (sha256:abc123def456...)
|
||||
|
||||
>>> /skill add ./my-local-skill
|
||||
Added skill 'my-skill' from ./my-local-skill
|
||||
|
||||
>>> /skill list
|
||||
Skills loaded in this session:
|
||||
my-skill (local: /path/to/my-local-skill)
|
||||
|
||||
>>> /skill remove my-skill
|
||||
Removed skill 'my-skill'
|
||||
```
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/skills` | Show all available skills (model + session) |
|
||||
| `/skill add <path>` | Add a skill from a local path |
|
||||
| `/skill remove <name>` | Remove a skill by name |
|
||||
| `/skill list` | List skills loaded in this session |
|
||||
|
||||
Dynamic skills take effect on the next message. This is useful for:
|
||||
- Testing skills during development
|
||||
- Temporarily adding capabilities to a model
|
||||
- Experimenting with skill combinations
|
||||
|
||||
## Skill Reference Formats
|
||||
|
||||
Skills use a 5-part name structure: `host/namespace/kind/model:tag`
|
||||
|
||||
| Format | Example | Description |
|
||||
|--------|---------|-------------|
|
||||
| Local path | `./skills/calc` | Bundled with agent at create time |
|
||||
| Library skill | `skill/calculator:1.0.0` | From the official skill library (library/skill/calculator) |
|
||||
| User skill | `alice/skill/calc:1.0.0` | From a user's namespace |
|
||||
| Full path | `registry.ollama.ai/alice/skill/calc:1.0.0` | Fully qualified with host |
|
||||
|
||||
The `kind` field distinguishes skills from models:
|
||||
- `skill` - Skill packages
|
||||
- `agent` - Agent packages (future)
|
||||
- (empty) - Regular models
|
||||
|
||||
## SKILL.md Structure
|
||||
|
||||
### Required Frontmatter
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: skill-name # Must match directory name
|
||||
description: Brief description of the skill
|
||||
---
|
||||
```
|
||||
|
||||
### Recommended Sections
|
||||
|
||||
1. **Purpose**: What the skill does and when to use it
|
||||
2. **When to use**: Trigger conditions for the agent
|
||||
3. **Instructions**: Step-by-step usage guide
|
||||
4. **Examples**: Input/output examples
|
||||
5. **Scripts**: Documentation for any bundled scripts
|
||||
|
||||
### Example: Calculator Skill
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: calculator
|
||||
description: Performs mathematical calculations using Python
|
||||
---
|
||||
|
||||
# Calculator Skill
|
||||
|
||||
## Purpose
|
||||
This skill performs mathematical calculations using a bundled Python script.
|
||||
|
||||
## When to use
|
||||
- User asks to calculate something
|
||||
- User wants to do math operations
|
||||
- Any arithmetic is needed
|
||||
|
||||
## Instructions
|
||||
1. When calculation is needed, use the `run_skill_script` tool
|
||||
2. Call: `python3 scripts/calculate.py "<expression>"`
|
||||
3. Return the result to the user
|
||||
|
||||
## Examples
|
||||
|
||||
**Input**: "What is 25 * 4?"
|
||||
**Action**: `run_skill_script` with command `python3 scripts/calculate.py '25 * 4'`
|
||||
**Output**: "25 * 4 = 100"
|
||||
```
|
||||
|
||||
## Storage Layout
|
||||
|
||||
```
|
||||
~/.ollama/models/
|
||||
├── blobs/
|
||||
│ └── sha256-<digest> # Skill tar.gz blob
|
||||
├── manifests/
|
||||
│ └── registry.ollama.ai/
|
||||
│ └── skill/ # Library skills
|
||||
│ └── calculator/
|
||||
│ └── 1.0.0
|
||||
│ └── skill-username/ # User skills
|
||||
│ └── my-skill/
|
||||
│ └── latest
|
||||
└── skills/
|
||||
└── sha256-<digest>/ # Extracted skill cache
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Security Considerations
|
||||
|
||||
## Current State (Development)
|
||||
|
||||
The current implementation has several security considerations that need to be addressed before production use.
|
||||
|
||||
### 1. Script Execution
|
||||
|
||||
**Risk**: Skills can bundle arbitrary scripts that execute on the host system.
|
||||
|
||||
**Current behavior**:
|
||||
- Scripts run with the same permissions as the Ollama process
|
||||
- No sandboxing or isolation
|
||||
- Full filesystem access
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Sandbox script execution (containers, seccomp, etc.)
|
||||
- [ ] Resource limits (CPU, memory, time)
|
||||
- [ ] Filesystem isolation (read-only mounts, restricted paths)
|
||||
- [ ] Network policy controls
|
||||
- [ ] Capability dropping
|
||||
|
||||
### 2. Skill Provenance
|
||||
|
||||
**Risk**: Malicious skills could be pushed to the registry.
|
||||
|
||||
**Current behavior**:
|
||||
- No code signing or verification
|
||||
- No malware scanning
|
||||
- Trust based on namespace ownership
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Skill signing with author keys
|
||||
- [ ] Registry-side malware scanning
|
||||
- [ ] Content policy enforcement
|
||||
- [ ] Reputation system for skill authors
|
||||
|
||||
### 3. Namespace Squatting
|
||||
|
||||
**Risk**: Malicious actors could register skill names that impersonate official tools.
|
||||
|
||||
**Current behavior**:
|
||||
- First-come-first-served namespace registration
|
||||
- No verification of skill names
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Reserved namespace list (official tools, common names)
|
||||
- [ ] Trademark/name verification for popular skills
|
||||
- [ ] Clear namespacing conventions
|
||||
|
||||
### 4. Supply Chain Attacks
|
||||
|
||||
**Risk**: Compromised skills could inject malicious code into agents.
|
||||
|
||||
**Current behavior**:
|
||||
- Skills pulled without integrity verification beyond digest
|
||||
- No dependency tracking
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] SBOM (Software Bill of Materials) for skills
|
||||
- [ ] Dependency vulnerability scanning
|
||||
- [ ] Pinned versions in Agentfiles
|
||||
- [ ] Audit logging of skill usage
|
||||
|
||||
### 5. Data Exfiltration
|
||||
|
||||
**Risk**: Skills could exfiltrate sensitive data from conversations or the host.
|
||||
|
||||
**Current behavior**:
|
||||
- Skills have access to conversation context
|
||||
- Scripts can make network requests
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Network egress controls
|
||||
- [ ] Sensitive data detection/masking
|
||||
- [ ] Audit logging of script network activity
|
||||
- [ ] User consent for data access
|
||||
|
||||
### 6. Privilege Escalation
|
||||
|
||||
**Risk**: Skills could escalate privileges through script execution.
|
||||
|
||||
**Current behavior**:
|
||||
- Scripts inherit Ollama process privileges
|
||||
- No capability restrictions
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Run scripts as unprivileged user
|
||||
- [ ] Drop all capabilities
|
||||
- [ ] Mandatory access controls (SELinux/AppArmor)
|
||||
|
||||
## Recommended Security Model
|
||||
|
||||
### Skill Trust Levels
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Level 0: Untrusted (default) │
|
||||
│ - No script execution │
|
||||
│ - Instructions only │
|
||||
│ - Safe for any skill │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 1: Sandboxed │
|
||||
│ - Scripts run in isolated container │
|
||||
│ - No network access │
|
||||
│ - Read-only filesystem │
|
||||
│ - Resource limits enforced │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 2: Trusted │
|
||||
│ - Scripts run with network access │
|
||||
│ - Can write to designated directories │
|
||||
│ - Requires explicit user approval │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 3: Privileged (admin only) │
|
||||
│ - Full host access │
|
||||
│ - System administration skills │
|
||||
│ - Requires admin approval │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Skill Manifest Security Fields (Future)
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
description: A skill description
|
||||
security:
|
||||
trust_level: sandboxed
|
||||
permissions:
|
||||
- network:read # Can make HTTP GET requests
|
||||
- filesystem:read:/data # Can read from /data
|
||||
resource_limits:
|
||||
max_memory: 256MB
|
||||
max_cpu_time: 30s
|
||||
max_disk: 100MB
|
||||
signature: sha256:abc... # Author signature
|
||||
---
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Future Considerations
|
||||
|
||||
## Feature Roadmap
|
||||
|
||||
### Phase 1: Foundation (Current)
|
||||
- [x] Skill bundling with agents
|
||||
- [x] Local skill development
|
||||
- [x] Basic CLI commands (push, pull, list, rm, show)
|
||||
- [x] Registry blob storage
|
||||
- [ ] Registry namespace configuration
|
||||
|
||||
### Phase 2: Security
|
||||
- [ ] Script sandboxing
|
||||
- [ ] Permission model
|
||||
- [ ] Skill signing
|
||||
- [ ] Audit logging
|
||||
|
||||
### Phase 3: Discovery
|
||||
- [ ] Skill search on ollama.com
|
||||
- [ ] Skill ratings and reviews
|
||||
- [ ] Usage analytics
|
||||
- [ ] Featured/trending skills
|
||||
|
||||
### Phase 4: Advanced Features
|
||||
- [ ] Skill dependencies
|
||||
- [ ] Skill versioning constraints
|
||||
- [ ] Skill composition (skills using skills)
|
||||
- [ ] Skill testing framework
|
||||
|
||||
## Open Questions
|
||||
|
||||
### 1. Skill Execution Model
|
||||
|
||||
**Question**: How should skills execute scripts?
|
||||
|
||||
Options:
|
||||
- **A) In-process**: Fast but unsafe
|
||||
- **B) Subprocess**: Current approach, moderate isolation
|
||||
- **C) Container**: Good isolation, requires container runtime
|
||||
- **D) WASM**: Portable and safe, limited capabilities
|
||||
- **E) Remote execution**: Offload to secure service
|
||||
|
||||
### 2. Skill Versioning
|
||||
|
||||
**Question**: How strict should version pinning be?
|
||||
|
||||
Options:
|
||||
- **A) Always latest**: Simple but risky
|
||||
- **B) Semantic versioning**: `^1.0.0` allows minor updates
|
||||
- **C) Exact pinning**: `=1.0.0` requires explicit updates
|
||||
- **D) Digest pinning**: `@sha256:abc` immutable reference
|
||||
|
||||
### 3. Skill Permissions
|
||||
|
||||
**Question**: How should users grant permissions to skills?
|
||||
|
||||
Options:
|
||||
- **A) All or nothing**: Accept all permissions or don't use
|
||||
- **B) Granular consent**: Approve each permission individually
|
||||
- **C) Trust levels**: Pre-defined permission bundles
|
||||
- **D) Runtime prompts**: Ask when permission is first used
|
||||
|
||||
### 4. Skill Discovery
|
||||
|
||||
**Question**: How should users find skills?
|
||||
|
||||
Options:
|
||||
- **A) Central registry only**: ollama.com/skills
|
||||
- **B) Federated registries**: Multiple skill sources
|
||||
- **C) Git repositories**: Pull from GitHub, etc.
|
||||
- **D) All of the above**: Multiple discovery mechanisms
|
||||
|
||||
### 5. Skill Monetization
|
||||
|
||||
**Question**: Should skill authors be able to monetize?
|
||||
|
||||
Options:
|
||||
- **A) Free only**: All skills are free and open
|
||||
- **B) Paid skills**: Authors can charge for skills
|
||||
- **C) Freemium**: Free tier with paid features
|
||||
- **D) Donations**: Voluntary support for authors
|
||||
|
||||
### 6. Skill Updates
|
||||
|
||||
**Question**: How should skill updates be handled?
|
||||
|
||||
Options:
|
||||
- **A) Manual**: User explicitly updates
|
||||
- **B) Auto-update**: Always use latest
|
||||
- **C) Notify**: Alert user to available updates
|
||||
- **D) Policy-based**: Organization controls update policy
|
||||
|
||||
## API Considerations
|
||||
|
||||
### Skill Metadata API
|
||||
|
||||
```
|
||||
GET /api/skills
|
||||
GET /api/skills/:namespace/:name
|
||||
GET /api/skills/:namespace/:name/versions
|
||||
GET /api/skills/:namespace/:name/readme
|
||||
```
|
||||
|
||||
### Skill Execution API
|
||||
|
||||
```
|
||||
POST /api/skills/:namespace/:name/execute
|
||||
{
|
||||
"command": "python3 scripts/run.py",
|
||||
"args": ["--input", "data"],
|
||||
"timeout": 30
|
||||
}
|
||||
```
|
||||
|
||||
### Skill Permissions API
|
||||
|
||||
```
|
||||
GET /api/skills/:namespace/:name/permissions
|
||||
POST /api/skills/:namespace/:name/permissions/grant
|
||||
DELETE /api/skills/:namespace/:name/permissions/revoke
|
||||
```
|
||||
|
||||
## Testing Considerations
|
||||
|
||||
### Skill Testing Framework
|
||||
|
||||
```bash
|
||||
# Run skill tests
|
||||
ollama skill test ./my-skill
|
||||
|
||||
# Test with specific model
|
||||
ollama skill test ./my-skill --model llama3.2:3b
|
||||
|
||||
# Generate test report
|
||||
ollama skill test ./my-skill --report
|
||||
```
|
||||
|
||||
### Test File Format
|
||||
|
||||
```yaml
|
||||
# my-skill/tests/test.yaml
|
||||
tests:
|
||||
- name: "basic calculation"
|
||||
input: "What is 2 + 2?"
|
||||
expect:
|
||||
contains: "4"
|
||||
tool_called: "run_skill_script"
|
||||
|
||||
- name: "complex expression"
|
||||
input: "Calculate 15% of 200"
|
||||
expect:
|
||||
contains: "30"
|
||||
```
|
||||
|
||||
## Compatibility Considerations
|
||||
|
||||
### Minimum Ollama Version
|
||||
|
||||
Skills should declare minimum Ollama version:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
requires:
|
||||
ollama: ">=0.4.0"
|
||||
---
|
||||
```
|
||||
|
||||
### Model Compatibility
|
||||
|
||||
Skills may require specific model capabilities:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: vision-skill
|
||||
requires:
|
||||
capabilities:
|
||||
- vision
|
||||
- tools
|
||||
---
|
||||
```
|
||||
|
||||
## Migration Path
|
||||
|
||||
### From Local to Registry
|
||||
|
||||
```bash
|
||||
# Develop locally
|
||||
SKILL ./my-skill
|
||||
|
||||
# Push when ready
|
||||
ollama skill push myname/skill/my-skill:1.0.0 ./my-skill
|
||||
|
||||
# Update Agentfile
|
||||
SKILL myname/skill/my-skill:1.0.0
|
||||
```
|
||||
|
||||
### Version Upgrades
|
||||
|
||||
```bash
|
||||
# Check for updates
|
||||
ollama skill outdated
|
||||
|
||||
# Update specific skill
|
||||
ollama skill update calculator:1.0.0
|
||||
|
||||
# Update all skills
|
||||
ollama skill update --all
|
||||
```
|
||||
@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
|
||||
|
||||
### Linux NVIDIA Troubleshooting
|
||||
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
|
||||
|
||||
Sometimes Ollama can have difficulty initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem.
|
||||
|
||||
|
||||
3
ducky.Agentfile
Normal file
3
ducky.Agentfile
Normal file
@@ -0,0 +1,3 @@
|
||||
SKILL ./skills/calculator-skill
|
||||
ENTRYPOINT ducky
|
||||
|
||||
@@ -148,6 +148,16 @@ func Remotes() []string {
|
||||
return r
|
||||
}
|
||||
|
||||
// Skills returns the list of skill directories. Skills directories can be configured via the OLLAMA_SKILLS environment variable.
|
||||
// Returns empty slice if not configured.
|
||||
func Skills() []string {
|
||||
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(raw, ",")
|
||||
}
|
||||
|
||||
func BoolWithDefault(k string) func(defaultValue bool) bool {
|
||||
return func(defaultValue bool) bool {
|
||||
if s := Var(k); s != "" {
|
||||
@@ -317,6 +327,9 @@ func AsMap() map[string]EnvVar {
|
||||
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
|
||||
}
|
||||
|
||||
// Skills configuration would go here when added
|
||||
ret["OLLAMA_SKILLS"] = EnvVar{"OLLAMA_SKILLS", Skills(), "Comma-separated list of skill directories"}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -83,5 +83,5 @@ require (
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 3 +
|
||||
ggml/src/ggml-impl.h | 8 +
|
||||
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
|
||||
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
|
||||
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
|
||||
ggml/src/mem_nvml.cpp | 209 +++++++++++
|
||||
9 files changed, 976 insertions(+), 17 deletions(-)
|
||||
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++-
|
||||
ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++
|
||||
ggml/src/mem_nvml.cpp | 209 ++++++++++
|
||||
9 files changed, 1005 insertions(+), 17 deletions(-)
|
||||
create mode 100644 ggml/src/mem_hip.cpp
|
||||
create mode 100644 ggml/src/mem_nvml.cpp
|
||||
|
||||
@@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644
|
||||
|
||||
set_target_properties(ggml-base PROPERTIES
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 6852d2e20..48cdb1dcf 100644
|
||||
index 6852d2e20..334a30135 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
@@ -109,7 +109,7 @@ index 6852d2e20..48cdb1dcf 100644
|
||||
+
|
||||
+#if defined(GGML_USE_HIP)
|
||||
+ if (ggml_hip_mgmt_init() == 0) {
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
|
||||
+ if (status == 0) {
|
||||
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
|
||||
+ ggml_hip_mgmt_release();
|
||||
@@ -204,7 +204,7 @@ index 4e162258d..d89e35a8e 100644
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||
index fe57d4c58..1c07e767a 100644
|
||||
index fe57d4c58..dba8f4695 100644
|
||||
--- a/ggml/src/ggml-impl.h
|
||||
+++ b/ggml/src/ggml-impl.h
|
||||
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
@@ -216,7 +216,7 @@ index fe57d4c58..1c07e767a 100644
|
||||
+GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
|
||||
+GGML_API void ggml_nvml_release();
|
||||
+GGML_API int ggml_hip_mgmt_init();
|
||||
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
+GGML_API void ggml_hip_mgmt_release();
|
||||
+
|
||||
#ifdef __cplusplus
|
||||
@@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644
|
||||
/* .async = */ true,
|
||||
/* .host_buffer = */ false,
|
||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
index 5349bce24..d43d46d1d 100644
|
||||
index 5349bce24..0103fd03a 100644
|
||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
@@ -236,6 +236,7 @@ class vk_memory_logger;
|
||||
@@ -334,7 +334,7 @@ index 5349bce24..d43d46d1d 100644
|
||||
+ switch (props2.properties.vendorID) {
|
||||
+ case VK_VENDOR_ID_AMD:
|
||||
+ if (ggml_hip_mgmt_init() == 0) {
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
|
||||
+ if (status == 0) {
|
||||
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
|
||||
+ ggml_hip_mgmt_release();
|
||||
@@ -505,10 +505,10 @@ index 5349bce24..d43d46d1d 100644
|
||||
}
|
||||
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
|
||||
new file mode 100644
|
||||
index 000000000..c1949b899
|
||||
index 000000000..23c765806
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/mem_hip.cpp
|
||||
@@ -0,0 +1,529 @@
|
||||
@@ -0,0 +1,558 @@
|
||||
+#include "ggml.h"
|
||||
+#include "ggml-impl.h"
|
||||
+
|
||||
@@ -842,7 +842,7 @@ index 000000000..c1949b899
|
||||
+ if (gpus != NULL) gpus->pVtbl->Release(gpus); \
|
||||
+ if (gpu != NULL) gpu->pVtbl->Release(gpu)
|
||||
+
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
+ std::lock_guard<std::mutex> lock(ggml_adlx_lock);
|
||||
+ if (adlx.handle == NULL) {
|
||||
+ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
|
||||
@@ -966,13 +966,16 @@ index 000000000..c1949b899
|
||||
+ return 0;
|
||||
+}
|
||||
+void ggml_hip_mgmt_release() {}
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
|
||||
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
|
||||
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
|
||||
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
|
||||
+ const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
|
||||
+ const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
|
||||
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
|
||||
+
|
||||
+
|
||||
+ glob_t glob_result;
|
||||
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
|
||||
+
|
||||
@@ -1006,7 +1009,6 @@ index 000000000..c1949b899
|
||||
+
|
||||
+ uint64_t memory;
|
||||
+ totalFileStream >> memory;
|
||||
+ *total = memory;
|
||||
+
|
||||
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
|
||||
+ std::ifstream usedFileStream(usedFile.c_str());
|
||||
@@ -1019,6 +1021,33 @@ index 000000000..c1949b899
|
||||
+
|
||||
+ uint64_t memoryUsed;
|
||||
+ usedFileStream >> memoryUsed;
|
||||
+
|
||||
+ if (is_integrated_gpu) {
|
||||
+ std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
|
||||
+ std::ifstream totalFileStream(totalFile.c_str());
|
||||
+ if (!totalFileStream.is_open()) {
|
||||
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
|
||||
+ file.close();
|
||||
+ globfree(&glob_result);
|
||||
+ return 1;
|
||||
+ }
|
||||
+ uint64_t gtt;
|
||||
+ totalFileStream >> gtt;
|
||||
+ std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
|
||||
+ std::ifstream usedFileStream(usedFile.c_str());
|
||||
+ if (!usedFileStream.is_open()) {
|
||||
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
|
||||
+ file.close();
|
||||
+ globfree(&glob_result);
|
||||
+ return 1;
|
||||
+ }
|
||||
+ uint64_t gttUsed;
|
||||
+ usedFileStream >> gttUsed;
|
||||
+ memory += gtt;
|
||||
+ memoryUsed += gttUsed;
|
||||
+ }
|
||||
+
|
||||
+ *total = memory;
|
||||
+ *free = memory - memoryUsed;
|
||||
+
|
||||
+ file.close();
|
||||
|
||||
@@ -24,12 +24,12 @@ index 99ae293cc..9a134b7af 100644
|
||||
|
||||
set_target_properties(ggml-base PROPERTIES
|
||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||
index 1c07e767a..0da3e065b 100644
|
||||
index dba8f4695..7e17032c7 100644
|
||||
--- a/ggml/src/ggml-impl.h
|
||||
+++ b/ggml/src/ggml-impl.h
|
||||
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
|
||||
GGML_API int ggml_hip_mgmt_init();
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
GGML_API void ggml_hip_mgmt_release();
|
||||
+GGML_API int ggml_dxgi_pdh_init();
|
||||
+GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
@@ -38,7 +38,7 @@ index 1c07e767a..0da3e065b 100644
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
index d43d46d1d..df79f9f79 100644
|
||||
index 0103fd03a..9cc4ebdef 100644
|
||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
@@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||
|
||||
@@ -10,7 +10,7 @@ fallback to cpu
|
||||
1 file changed, 3 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 48cdb1dcf..3102d7ea7 100644
|
||||
index 334a30135..5c9dfd032 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
|
||||
@@ -4436,7 +4436,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
if (ggml_hip_mgmt_init() == 0) {
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
|
||||
if (status == 0) {
|
||||
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
|
||||
ggml_hip_mgmt_release();
|
||||
|
||||
2
ml/backend/ggml/ggml/src/ggml-impl.h
vendored
2
ml/backend/ggml/ggml/src/ggml-impl.h
vendored
@@ -682,7 +682,7 @@ GGML_API int ggml_nvml_init();
|
||||
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
|
||||
GGML_API void ggml_nvml_release();
|
||||
GGML_API int ggml_hip_mgmt_init();
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
GGML_API void ggml_hip_mgmt_release();
|
||||
GGML_API int ggml_dxgi_pdh_init();
|
||||
GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
|
||||
@@ -13710,7 +13710,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
|
||||
switch (props2.properties.vendorID) {
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (ggml_hip_mgmt_init() == 0) {
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
|
||||
if (status == 0) {
|
||||
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
|
||||
ggml_hip_mgmt_release();
|
||||
|
||||
35
ml/backend/ggml/ggml/src/mem_hip.cpp
vendored
35
ml/backend/ggml/ggml/src/mem_hip.cpp
vendored
@@ -331,7 +331,7 @@ void ggml_hip_mgmt_release() {
|
||||
if (gpus != NULL) gpus->pVtbl->Release(gpus); \
|
||||
if (gpu != NULL) gpu->pVtbl->Release(gpu)
|
||||
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
std::lock_guard<std::mutex> lock(ggml_adlx_lock);
|
||||
if (adlx.handle == NULL) {
|
||||
GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
|
||||
@@ -455,13 +455,16 @@ int ggml_hip_mgmt_init() {
|
||||
return 0;
|
||||
}
|
||||
void ggml_hip_mgmt_release() {}
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
|
||||
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
|
||||
const std::string drmTotalMemoryFile = "mem_info_vram_total";
|
||||
const std::string drmUsedMemoryFile = "mem_info_vram_used";
|
||||
const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
|
||||
const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
|
||||
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
|
||||
|
||||
|
||||
glob_t glob_result;
|
||||
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
|
||||
|
||||
@@ -495,7 +498,6 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
|
||||
uint64_t memory;
|
||||
totalFileStream >> memory;
|
||||
*total = memory;
|
||||
|
||||
std::string usedFile = dir + "/" + drmUsedMemoryFile;
|
||||
std::ifstream usedFileStream(usedFile.c_str());
|
||||
@@ -508,6 +510,33 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
|
||||
uint64_t memoryUsed;
|
||||
usedFileStream >> memoryUsed;
|
||||
|
||||
if (is_integrated_gpu) {
|
||||
std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
|
||||
std::ifstream totalFileStream(totalFile.c_str());
|
||||
if (!totalFileStream.is_open()) {
|
||||
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
|
||||
file.close();
|
||||
globfree(&glob_result);
|
||||
return 1;
|
||||
}
|
||||
uint64_t gtt;
|
||||
totalFileStream >> gtt;
|
||||
std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
|
||||
std::ifstream usedFileStream(usedFile.c_str());
|
||||
if (!usedFileStream.is_open()) {
|
||||
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
|
||||
file.close();
|
||||
globfree(&glob_result);
|
||||
return 1;
|
||||
}
|
||||
uint64_t gttUsed;
|
||||
usedFileStream >> gttUsed;
|
||||
memory += gtt;
|
||||
memoryUsed += gttUsed;
|
||||
}
|
||||
|
||||
*total = memory;
|
||||
*free = memory - memoryUsed;
|
||||
|
||||
file.close();
|
||||
|
||||
117
parser/parser.go
117
parser/parser.go
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -58,6 +59,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
|
||||
var messages []api.Message
|
||||
var licenses []string
|
||||
var skills []api.SkillRef
|
||||
var mcps []api.MCPRef
|
||||
params := make(map[string]any)
|
||||
|
||||
for _, c := range f.Commands {
|
||||
@@ -118,6 +121,23 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
case "message":
|
||||
role, msg, _ := strings.Cut(c.Args, ": ")
|
||||
messages = append(messages, api.Message{Role: role, Content: msg})
|
||||
case "skill":
|
||||
skills = append(skills, api.SkillRef{Name: c.Args})
|
||||
case "mcp":
|
||||
mcpRef, err := parseMCPArg(c.Args, relativeDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid MCP: %w", err)
|
||||
}
|
||||
mcps = append(mcps, mcpRef)
|
||||
case "agent_type":
|
||||
// Handle "AGENT TYPE conversational" -> strip "TYPE " prefix
|
||||
args := c.Args
|
||||
if strings.HasPrefix(strings.ToLower(args), "type ") {
|
||||
args = strings.TrimSpace(args[5:])
|
||||
}
|
||||
req.AgentType = args
|
||||
case "entrypoint":
|
||||
req.Entrypoint = c.Args
|
||||
default:
|
||||
if slices.Contains(deprecatedParameters, c.Name) {
|
||||
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
|
||||
@@ -150,6 +170,12 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
if len(licenses) > 0 {
|
||||
req.License = licenses
|
||||
}
|
||||
if len(skills) > 0 {
|
||||
req.Skills = skills
|
||||
}
|
||||
if len(mcps) > 0 {
|
||||
req.MCPs = mcps
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
@@ -333,7 +359,7 @@ func (c Command) String() string {
|
||||
switch c.Name {
|
||||
case "model":
|
||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires", "skill", "agent_type", "entrypoint":
|
||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(c.Args, ": ")
|
||||
@@ -359,7 +385,7 @@ const (
|
||||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", \"requires\", \"skill\", \"agent_type\", \"mcp\", or \"entrypoint\"")
|
||||
)
|
||||
|
||||
type ParserError struct {
|
||||
@@ -423,6 +449,9 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
||||
switch s := strings.ToLower(b.String()); s {
|
||||
case "from":
|
||||
cmd.Name = "model"
|
||||
case "agent":
|
||||
// "AGENT TYPE" -> "agent_type", consume next word
|
||||
cmd.Name = "agent_type"
|
||||
case "parameter":
|
||||
// transition to stateParameter which sets command name
|
||||
next = stateParameter
|
||||
@@ -500,6 +529,10 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
||||
if cmd.Name == "model" {
|
||||
return &f, nil
|
||||
}
|
||||
// Allow entrypoint-only agents without FROM
|
||||
if cmd.Name == "entrypoint" {
|
||||
return &f, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errMissingFrom
|
||||
@@ -518,7 +551,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
|
||||
}
|
||||
case stateName:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
case isAlpha(r), r == '_':
|
||||
return stateName, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
@@ -619,7 +652,7 @@ func isValidMessageRole(role string) bool {
|
||||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires", "skill", "agent_type", "agent", "mcp", "entrypoint":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -666,3 +699,79 @@ func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User
|
||||
func expandPath(path, relativeDir string) (string, error) {
|
||||
return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
|
||||
}
|
||||
|
||||
// parseMCPArg parses MCP command arguments.
|
||||
// Supports two formats:
|
||||
//
|
||||
// JSON: {"name": "web-search", "command": "uv", "args": ["run", "./script.py"]}
|
||||
// Simple: web-search uv run ./script.py (name, command, args...)
|
||||
func parseMCPArg(args string, relativeDir string) (api.MCPRef, error) {
|
||||
args = strings.TrimSpace(args)
|
||||
if args == "" {
|
||||
return api.MCPRef{}, errors.New("MCP requires arguments")
|
||||
}
|
||||
|
||||
// Try JSON format first
|
||||
if strings.HasPrefix(args, "{") {
|
||||
var ref api.MCPRef
|
||||
if err := json.Unmarshal([]byte(args), &ref); err != nil {
|
||||
return api.MCPRef{}, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
if ref.Name == "" {
|
||||
return api.MCPRef{}, errors.New("MCP name is required")
|
||||
}
|
||||
if ref.Command == "" {
|
||||
return api.MCPRef{}, errors.New("MCP command is required")
|
||||
}
|
||||
if ref.Type == "" {
|
||||
ref.Type = "stdio"
|
||||
}
|
||||
// Expand relative paths in args
|
||||
for i, arg := range ref.Args {
|
||||
if isLocalPath(arg) {
|
||||
expanded, err := expandPath(arg, relativeDir)
|
||||
if err != nil {
|
||||
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
|
||||
}
|
||||
ref.Args[i] = expanded
|
||||
}
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Simple format: name command args...
|
||||
parts := strings.Fields(args)
|
||||
if len(parts) < 2 {
|
||||
return api.MCPRef{}, errors.New("MCP requires at least name and command")
|
||||
}
|
||||
|
||||
ref := api.MCPRef{
|
||||
Name: parts[0],
|
||||
Command: parts[1],
|
||||
Type: "stdio",
|
||||
}
|
||||
if len(parts) > 2 {
|
||||
ref.Args = parts[2:]
|
||||
}
|
||||
|
||||
// Expand relative paths in args
|
||||
for i, arg := range ref.Args {
|
||||
if isLocalPath(arg) {
|
||||
expanded, err := expandPath(arg, relativeDir)
|
||||
if err != nil {
|
||||
return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err)
|
||||
}
|
||||
ref.Args[i] = expanded
|
||||
}
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// isLocalPath checks if a string looks like a local filesystem path.
|
||||
func isLocalPath(s string) bool {
|
||||
return strings.HasPrefix(s, "/") ||
|
||||
strings.HasPrefix(s, "./") ||
|
||||
strings.HasPrefix(s, "../") ||
|
||||
strings.HasPrefix(s, "~")
|
||||
}
|
||||
|
||||
148
server/create.go
148
server/create.go
@@ -62,6 +62,10 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config.Renderer = r.Renderer
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
config.Skills = r.Skills
|
||||
config.MCPs = r.MCPs
|
||||
config.AgentType = r.AgentType
|
||||
config.Entrypoint = r.Entrypoint
|
||||
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
@@ -157,6 +161,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
} else if r.Entrypoint != "" {
|
||||
// Entrypoint-only agent: no base model needed
|
||||
slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint)
|
||||
} else {
|
||||
ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
@@ -543,6 +550,18 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle skill layers for agents
|
||||
layers, config.Skills, err = setSkillLayers(layers, config.Skills, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle MCP layers for agents
|
||||
layers, config.MCPs, err = setMCPLayers(layers, config.MCPs, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := createConfigLayer(layers, *config)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -793,6 +812,135 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// setSkillLayers creates skill layers for local skill paths and updates the skill refs.
|
||||
// Local paths are converted to bundled skill layers with digests.
|
||||
// Registry references are kept as-is for later resolution during pull.
|
||||
func setSkillLayers(layers []Layer, skills []model.SkillRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.SkillRef, error) {
|
||||
if len(skills) == 0 {
|
||||
return layers, skills, nil
|
||||
}
|
||||
|
||||
// Remove any existing skill layers
|
||||
layers = removeLayer(layers, MediaTypeSkill)
|
||||
|
||||
var updatedSkills []model.SkillRef
|
||||
|
||||
for _, skill := range skills {
|
||||
// Check if this is a local path
|
||||
if IsLocalSkillPath(skill.Name) {
|
||||
// Expand home directory if needed
|
||||
skillPath := skill.Name
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
// Make absolute
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolving skill path %q: %w", skill.Name, err)
|
||||
}
|
||||
|
||||
// Check if this is a direct skill directory or a parent containing skills
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err == nil {
|
||||
// Direct skill directory
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", filepath.Base(absPath))})
|
||||
|
||||
layer, err := CreateSkillLayer(absPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating skill layer for %q: %w", skill.Name, err)
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
updatedSkills = append(updatedSkills, model.SkillRef{
|
||||
Name: filepath.Base(absPath),
|
||||
Digest: layer.Digest,
|
||||
})
|
||||
} else {
|
||||
// Parent directory - walk to find skill subdirectories
|
||||
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != "SKILL.md" {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillDir := filepath.Dir(path)
|
||||
skillName := filepath.Base(skillDir)
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("packaging skill: %s", skillName)})
|
||||
|
||||
layer, err := CreateSkillLayer(skillDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating skill layer for %q: %w", skillDir, err)
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
updatedSkills = append(updatedSkills, model.SkillRef{
|
||||
Name: skillName,
|
||||
Digest: layer.Digest,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("walking skill directory %q: %w", skill.Name, err)
|
||||
}
|
||||
}
|
||||
} else if skill.Digest != "" {
|
||||
// Already has a digest (from a pulled agent), keep as-is
|
||||
updatedSkills = append(updatedSkills, skill)
|
||||
} else {
|
||||
// Registry reference - keep as-is for later resolution
|
||||
updatedSkills = append(updatedSkills, skill)
|
||||
}
|
||||
}
|
||||
|
||||
return layers, updatedSkills, nil
|
||||
}
|
||||
|
||||
// setMCPLayers handles MCP server references.
|
||||
// Currently, MCPs are stored as config data (command/args).
|
||||
// Future: support bundling MCP server directories as layers.
|
||||
func setMCPLayers(layers []Layer, mcps []model.MCPRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.MCPRef, error) {
|
||||
if len(mcps) == 0 {
|
||||
return layers, mcps, nil
|
||||
}
|
||||
|
||||
// Remove any existing MCP layers
|
||||
layers = removeLayer(layers, MediaTypeMCP)
|
||||
|
||||
var updatedMCPs []model.MCPRef
|
||||
|
||||
for _, mcp := range mcps {
|
||||
// Validate MCP has required fields
|
||||
if mcp.Name == "" {
|
||||
return nil, nil, fmt.Errorf("MCP server requires a name")
|
||||
}
|
||||
if mcp.Command == "" {
|
||||
return nil, nil, fmt.Errorf("MCP server %q requires a command", mcp.Name)
|
||||
}
|
||||
|
||||
// Set default type if not specified
|
||||
if mcp.Type == "" {
|
||||
mcp.Type = "stdio"
|
||||
}
|
||||
|
||||
// For now, just keep MCPs as config data
|
||||
// Future: detect local paths in args and bundle them
|
||||
updatedMCPs = append(updatedMCPs, mcp)
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("configuring MCP: %s", mcp.Name)})
|
||||
}
|
||||
|
||||
return layers, updatedMCPs, nil
|
||||
}
|
||||
|
||||
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
digests := make([]string, len(layers))
|
||||
for i, layer := range layers {
|
||||
|
||||
@@ -2,11 +2,9 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -33,45 +31,9 @@ const maxRetries = 6
|
||||
var (
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errPartSlow = errors.New("part slow, racing")
|
||||
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||
)
|
||||
|
||||
// speedTracker tracks download speeds and computes rolling median.
|
||||
type speedTracker struct {
|
||||
mu sync.Mutex
|
||||
speeds []float64 // bytes per second
|
||||
}
|
||||
|
||||
func (s *speedTracker) Record(bytesPerSec float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.speeds = append(s.speeds, bytesPerSec)
|
||||
// Keep last 100 samples
|
||||
if len(s.speeds) > 100 {
|
||||
s.speeds = s.speeds[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *speedTracker) Median() float64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.speeds) < 3 {
|
||||
return 0 // not enough data
|
||||
}
|
||||
// Simple median: sort a copy and take middle
|
||||
sorted := make([]float64, len(s.speeds))
|
||||
copy(sorted, s.speeds)
|
||||
for i := range sorted {
|
||||
for j := i + 1; j < len(sorted); j++ {
|
||||
if sorted[j] < sorted[i] {
|
||||
sorted[i], sorted[j] = sorted[j], sorted[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return sorted[len(sorted)/2]
|
||||
}
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
|
||||
type blobDownload struct {
|
||||
@@ -132,127 +94,26 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
|
||||
downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
|
||||
const (
|
||||
numDownloadParts = 16
|
||||
minDownloadPartSize int64 = 100 * format.MegaByte
|
||||
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
||||
)
|
||||
|
||||
func envInt(key string, defaultVal int) int {
|
||||
if s := os.Getenv(key); s != "" {
|
||||
if v, err := strconv.Atoi(s); err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// streamHasher reads a file sequentially and hashes it as chunks complete.
|
||||
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
|
||||
// Works by reading from OS page cache - data just written is still in RAM.
|
||||
type streamHasher struct {
|
||||
file *os.File
|
||||
hasher hash.Hash
|
||||
parts []*blobDownloadPart
|
||||
total int64 // total bytes to hash
|
||||
hashed atomic.Int64
|
||||
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
completed []bool
|
||||
done bool
|
||||
err error
|
||||
}
|
||||
|
||||
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
|
||||
h := &streamHasher{
|
||||
file: file,
|
||||
hasher: sha256.New(),
|
||||
parts: parts,
|
||||
total: total,
|
||||
completed: make([]bool, len(parts)),
|
||||
}
|
||||
h.cond = sync.NewCond(&h.mu)
|
||||
return h
|
||||
}
|
||||
|
||||
// MarkComplete signals that a part has been written to disk.
|
||||
func (h *streamHasher) MarkComplete(partIndex int) {
|
||||
h.mu.Lock()
|
||||
h.completed[partIndex] = true
|
||||
h.cond.Broadcast()
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Run reads and hashes the file sequentially. Call in a goroutine.
|
||||
func (h *streamHasher) Run() {
|
||||
buf := make([]byte, 64*1024) // 64KB read buffer
|
||||
var offset int64
|
||||
|
||||
for i, part := range h.parts {
|
||||
// Wait for this part to be written
|
||||
h.mu.Lock()
|
||||
for !h.completed[i] && !h.done {
|
||||
h.cond.Wait()
|
||||
}
|
||||
if h.done {
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
// Read and hash this part (from page cache)
|
||||
remaining := part.Size
|
||||
for remaining > 0 {
|
||||
n := int64(len(buf))
|
||||
if n > remaining {
|
||||
n = remaining
|
||||
}
|
||||
nr, err := h.file.ReadAt(buf[:n], offset)
|
||||
if err != nil && err != io.EOF {
|
||||
h.mu.Lock()
|
||||
h.err = err
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
h.hasher.Write(buf[:nr])
|
||||
offset += int64(nr)
|
||||
remaining -= int64(nr)
|
||||
h.hashed.Store(offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop signals the hasher to exit early.
|
||||
func (h *streamHasher) Stop() {
|
||||
h.mu.Lock()
|
||||
h.done = true
|
||||
h.cond.Broadcast()
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Hashed returns bytes hashed so far.
|
||||
func (h *streamHasher) Hashed() int64 {
|
||||
return h.hashed.Load()
|
||||
}
|
||||
|
||||
// Digest returns the computed hash.
|
||||
func (h *streamHasher) Digest() string {
|
||||
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// Err returns any error from hashing.
|
||||
func (h *streamHasher) Err() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.err
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Name() string {
|
||||
return strings.Join([]string{
|
||||
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
||||
}, "-")
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StartsAt() int64 {
|
||||
return p.Offset + p.Completed.Load()
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StopsAt() int64 {
|
||||
return p.Offset + p.Size
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
@@ -290,7 +151,14 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||
|
||||
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
|
||||
size := downloadPartSize
|
||||
size := b.Total / numDownloadParts
|
||||
switch {
|
||||
case size < minDownloadPartSize:
|
||||
size = minDownloadPartSize
|
||||
case size > maxDownloadPartSize:
|
||||
size = maxDownloadPartSize
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < b.Total {
|
||||
if offset+size > b.Total {
|
||||
@@ -352,6 +220,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
setSparse(file)
|
||||
|
||||
_ = file.Truncate(b.Total)
|
||||
|
||||
directURL, err := func() (*url.URL, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
@@ -399,106 +270,44 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return err
|
||||
}
|
||||
|
||||
// Download chunks to disk, hash by reading from page cache.
|
||||
// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
|
||||
// The hasher follows behind the downloaders, reading recently-written
|
||||
// data from OS page cache (RAM) rather than disk.
|
||||
sh := newStreamHasher(file, b.Parts, b.Total)
|
||||
tracker := &speedTracker{}
|
||||
|
||||
// Start hasher goroutine
|
||||
hashDone := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(hashDone)
|
||||
}()
|
||||
|
||||
// Log progress periodically
|
||||
// Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM
|
||||
const pageCacheWarningBytes = 4 << 30 // 4GB
|
||||
progressDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
downloaded := b.Completed.Load()
|
||||
hashed := sh.Hashed()
|
||||
dlPct := int(downloaded * 100 / b.Total)
|
||||
hPct := int(hashed * 100 / b.Total)
|
||||
spread := dlPct - hPct
|
||||
spreadBytes := downloaded - hashed
|
||||
|
||||
slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread))
|
||||
if spreadBytes > pageCacheWarningBytes {
|
||||
slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30)))
|
||||
}
|
||||
case <-progressDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
g.SetLimit(downloadConcurrency)
|
||||
g.SetLimit(numDownloadParts)
|
||||
for i := range b.Parts {
|
||||
part := b.Parts[i]
|
||||
if part.Completed.Load() == part.Size {
|
||||
sh.MarkComplete(part.N)
|
||||
continue
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
var slowRetries int
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
// After 3 slow retries, stop checking slowness and let it complete
|
||||
skipSlowCheck := slowRetries >= 3
|
||||
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
|
||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||
err = b.downloadChunk(inner, directURL, w, part)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
return err
|
||||
case errors.Is(err, errPartStalled):
|
||||
try--
|
||||
continue
|
||||
case errors.Is(err, errPartSlow):
|
||||
// Kill slow request, retry immediately (stays within concurrency limit)
|
||||
slowRetries++
|
||||
try--
|
||||
continue
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
default:
|
||||
sh.MarkComplete(part.N)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
close(progressDone)
|
||||
sh.Stop()
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for hasher to finish
|
||||
<-hashDone
|
||||
close(progressDone)
|
||||
if err := sh.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify hash
|
||||
if computed := sh.Digest(); computed != b.Digest {
|
||||
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
|
||||
}
|
||||
|
||||
// explicitly close the file so we can rename it
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
@@ -517,69 +326,38 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadChunkToDisk streams a part directly to disk at its offset.
|
||||
// Memory: ~32KB (read buffer only).
|
||||
// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
|
||||
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
startTime := time.Now()
|
||||
var bytesAtLastCheck atomic.Int64
|
||||
|
||||
g.Go(func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
w := io.NewOffsetWriter(file, part.Offset)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
var written int64
|
||||
for written < part.Size {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
written += int64(n)
|
||||
b.Completed.Add(int64(n))
|
||||
bytesAtLastCheck.Store(written)
|
||||
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Now()
|
||||
part.lastUpdatedMu.Unlock()
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
b.Completed.Add(-written)
|
||||
return err
|
||||
}
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
// Record speed for this part
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
if elapsed > 0 {
|
||||
tracker.Record(float64(part.Size) / elapsed)
|
||||
part.Completed.Add(n)
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed.Store(part.Size)
|
||||
return b.writePart(part.Name(), part)
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
var lastBytes int64
|
||||
checksWithoutProgress := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
@@ -587,47 +365,19 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
|
||||
return nil
|
||||
}
|
||||
|
||||
currentBytes := bytesAtLastCheck.Load()
|
||||
|
||||
// Check for complete stall (30 seconds no progress)
|
||||
part.lastUpdatedMu.Lock()
|
||||
lastUpdated := part.lastUpdated
|
||||
part.lastUpdatedMu.Unlock()
|
||||
|
||||
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
|
||||
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
|
||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Time{}
|
||||
part.lastUpdatedMu.Unlock()
|
||||
return errPartStalled
|
||||
}
|
||||
|
||||
// Check for slow speed after 5+ seconds (only for multi-part downloads)
|
||||
// Skip if we've already retried for slowness too many times
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
|
||||
currentSpeed := float64(currentBytes) / elapsed
|
||||
median := tracker.Median()
|
||||
|
||||
// If we're below 10% of median speed, flag as slow
|
||||
if median > 0 && currentSpeed < median*0.1 {
|
||||
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
|
||||
b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
|
||||
return errPartSlow
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if speed dropped significantly mid-download
|
||||
if currentBytes == lastBytes {
|
||||
checksWithoutProgress++
|
||||
if checksWithoutProgress >= 10 {
|
||||
slog.Info(fmt.Sprintf("%s part %d no progress for 10s; retrying", b.Digest[7:19], part.N))
|
||||
return errPartStalled
|
||||
}
|
||||
} else {
|
||||
checksWithoutProgress = 0
|
||||
}
|
||||
lastBytes = currentBytes
|
||||
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
@@ -1,319 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSpeedTracker_Median(t *testing.T) {
|
||||
s := &speedTracker{}
|
||||
|
||||
// Less than 3 samples returns 0
|
||||
s.Record(100)
|
||||
s.Record(200)
|
||||
if got := s.Median(); got != 0 {
|
||||
t.Errorf("expected 0 with < 3 samples, got %f", got)
|
||||
}
|
||||
|
||||
// With 3+ samples, returns median
|
||||
s.Record(300)
|
||||
// Samples: [100, 200, 300] -> median = 200
|
||||
if got := s.Median(); got != 200 {
|
||||
t.Errorf("expected median 200, got %f", got)
|
||||
}
|
||||
|
||||
// Add more samples
|
||||
s.Record(50)
|
||||
s.Record(250)
|
||||
// Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200
|
||||
if got := s.Median(); got != 200 {
|
||||
t.Errorf("expected median 200, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeedTracker_RollingWindow(t *testing.T) {
|
||||
s := &speedTracker{}
|
||||
|
||||
// Add 105 samples (should keep only last 100)
|
||||
for i := 0; i < 105; i++ {
|
||||
s.Record(float64(i))
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if len(s.speeds) != 100 {
|
||||
t.Errorf("expected 100 samples, got %d", len(s.speeds))
|
||||
}
|
||||
// First sample should be 5 (0-4 were dropped)
|
||||
if s.speeds[0] != 5 {
|
||||
t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestSpeedTracker_Concurrent(t *testing.T) {
|
||||
s := &speedTracker{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(v int) {
|
||||
defer wg.Done()
|
||||
s.Record(float64(v))
|
||||
s.Median() // concurrent read
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic, and should have reasonable state
|
||||
s.mu.Lock()
|
||||
if len(s.speeds) == 0 || len(s.speeds) > 100 {
|
||||
t.Errorf("unexpected speeds length: %d", len(s.speeds))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStreamHasher_Sequential(t *testing.T) {
|
||||
// Create temp file
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
// Write test data
|
||||
data := []byte("hello world, this is a test of the stream hasher")
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create parts
|
||||
parts := []*blobDownloadPart{
|
||||
{Offset: 0, Size: int64(len(data))},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, int64(len(data)))
|
||||
|
||||
// Mark complete and run
|
||||
sh.MarkComplete(0)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
<-done
|
||||
|
||||
// Verify digest
|
||||
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
|
||||
if got := sh.Digest(); got != expected {
|
||||
t.Errorf("digest mismatch: got %s, want %s", got, expected)
|
||||
}
|
||||
|
||||
if err := sh.Err(); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
|
||||
// Create temp file
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
// Write test data (3 parts of 10 bytes each)
|
||||
data := []byte("0123456789ABCDEFGHIJabcdefghij")
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create 3 parts
|
||||
parts := []*blobDownloadPart{
|
||||
{N: 0, Offset: 0, Size: 10},
|
||||
{N: 1, Offset: 10, Size: 10},
|
||||
{N: 2, Offset: 20, Size: 10},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, int64(len(data)))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Mark parts complete out of order: 2, 0, 1
|
||||
sh.MarkComplete(2)
|
||||
sh.MarkComplete(0) // This should trigger hashing of part 0
|
||||
sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2
|
||||
|
||||
<-done
|
||||
|
||||
// Verify digest
|
||||
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
|
||||
if got := sh.Digest(); got != expected {
|
||||
t.Errorf("digest mismatch: got %s, want %s", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHasher_Stop(t *testing.T) {
|
||||
// Create temp file
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
parts := []*blobDownloadPart{
|
||||
{Offset: 0, Size: 100},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, 100)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Stop without completing any parts
|
||||
sh.Stop()
|
||||
<-done
|
||||
|
||||
// Should exit cleanly without error
|
||||
if err := sh.Err(); err != nil {
|
||||
t.Errorf("unexpected error after Stop: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHasher_HashedProgress(t *testing.T) {
|
||||
// Create temp file with known data
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
data := make([]byte, 1000)
|
||||
rand.Read(data)
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
parts := []*blobDownloadPart{
|
||||
{N: 0, Offset: 0, Size: 500},
|
||||
{N: 1, Offset: 500, Size: 500},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, 1000)
|
||||
|
||||
// Initially no progress
|
||||
if got := sh.Hashed(); got != 0 {
|
||||
t.Errorf("expected 0 hashed initially, got %d", got)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Complete part 0
|
||||
sh.MarkComplete(0)
|
||||
|
||||
// Give hasher time to process
|
||||
for i := 0; i < 100; i++ {
|
||||
if sh.Hashed() >= 500 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Complete part 1
|
||||
sh.MarkComplete(1)
|
||||
<-done
|
||||
|
||||
if got := sh.Hashed(); got != 1000 {
|
||||
t.Errorf("expected 1000 hashed, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSpeedTracker_Record(b *testing.B) {
|
||||
s := &speedTracker{}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Record(float64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSpeedTracker_Median(b *testing.B) {
|
||||
s := &speedTracker{}
|
||||
// Pre-populate with 100 samples
|
||||
for i := 0; i < 100; i++ {
|
||||
s.Record(float64(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Median()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStreamHasher(b *testing.B) {
|
||||
// Create temp file with test data
|
||||
f, err := os.CreateTemp("", "streamhasher_bench")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
size := 64 * 1024 * 1024 // 64MB
|
||||
data := make([]byte, size)
|
||||
rand.Read(data)
|
||||
if _, err := f.Write(data); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
parts := []*blobDownloadPart{
|
||||
{Offset: 0, Size: int64(size)},
|
||||
}
|
||||
|
||||
b.SetBytes(int64(size))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
sh := newStreamHasher(f, parts, int64(size))
|
||||
sh.MarkComplete(0)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashThroughput(b *testing.B) {
|
||||
// Baseline: raw SHA256 throughput on this machine
|
||||
size := 256 * 1024 * 1024 // 256MB
|
||||
data := make([]byte, size)
|
||||
rand.Read(data)
|
||||
|
||||
b.SetBytes(int64(size))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := sha256.New()
|
||||
h.Write(data)
|
||||
h.Sum(nil)
|
||||
}
|
||||
}
|
||||
@@ -232,6 +232,13 @@ func (m *Model) String() string {
|
||||
})
|
||||
}
|
||||
|
||||
if m.Config.Entrypoint != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||
Name: "entrypoint",
|
||||
Args: m.Config.Entrypoint,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range m.Options {
|
||||
switch v := v.(type) {
|
||||
case []any:
|
||||
@@ -620,8 +627,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
_, err := downloadBlob(ctx, downloadOpts{
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
mp: mp,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
@@ -630,12 +638,41 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
skipVerify[layer.Digest] = cacheHit
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
// Note: Digest verification now happens inline during download in blobDownload.run()
|
||||
// via the orderedWriter, so no separate verification pass is needed.
|
||||
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
||||
for _, layer := range layers {
|
||||
if skipVerify[layer.Digest] {
|
||||
continue
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
// something went wrong, delete the blob
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(fp); err != nil {
|
||||
// log this, but return the original error
|
||||
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Extract skill layers to the skills cache
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == MediaTypeSkill {
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("extracting skill %s", layer.Digest)})
|
||||
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
|
||||
return fmt.Errorf("extracting skill layer %s: %w", layer.Digest, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
|
||||
52
server/internal/cache/blob/cache.go
vendored
52
server/internal/cache/blob/cache.go
vendored
@@ -10,6 +10,7 @@ import (
|
||||
"hash"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -326,19 +327,21 @@ func (c *DiskCache) GetFile(d Digest) string {
|
||||
return absJoin(c.dir, "blobs", filename)
|
||||
}
|
||||
|
||||
// Links returns a slice of link names in lexical order.
|
||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
||||
// Names are converted from their relative path form to their name form but are
|
||||
// not guaranteed to be valid. Callers should validate the names before using.
|
||||
func (c *DiskCache) Links() ([]string, error) {
|
||||
paths, err := c.links()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for path, err := range c.links() {
|
||||
if err != nil {
|
||||
yield("", err)
|
||||
return
|
||||
}
|
||||
if !yield(pathToName(path), nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
names := make([]string, len(paths))
|
||||
for i, path := range paths {
|
||||
names[i] = pathToName(path)
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// pathToName converts a path to a name. It is the inverse of nameToPath. The
|
||||
@@ -369,11 +372,10 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
|
||||
}
|
||||
|
||||
maybe := filepath.Join("manifests", np)
|
||||
paths, err := c.links()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, l := range paths {
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.EqualFold(maybe, l) {
|
||||
return filepath.Join(c.dir, l), nil
|
||||
}
|
||||
@@ -381,10 +383,22 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
|
||||
return filepath.Join(c.dir, maybe), nil
|
||||
}
|
||||
|
||||
// links returns a slice of link paths in the cache in lexical order.
|
||||
func (c *DiskCache) links() ([]string, error) {
|
||||
fsys := os.DirFS(c.dir)
|
||||
return fs.Glob(fsys, "manifests/*/*/*/*")
|
||||
// links returns a sequence of links in the cache in lexical order.
|
||||
func (c *DiskCache) links() iter.Seq2[string, error] {
|
||||
// TODO(bmizerany): reuse empty dirnames if exist
|
||||
return func(yield func(string, error) bool) {
|
||||
fsys := os.DirFS(c.dir)
|
||||
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
|
||||
if err != nil {
|
||||
yield("", err)
|
||||
return
|
||||
}
|
||||
for _, manifest := range manifests {
|
||||
if !yield(manifest, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type checkWriter struct {
|
||||
|
||||
27
server/internal/cache/blob/cache_test.go
vendored
27
server/internal/cache/blob/cache_test.go
vendored
@@ -466,9 +466,12 @@ func testManifestNameReuse(t *testing.T) {
|
||||
t.Fatalf("g = %v, want %v", g, w)
|
||||
}
|
||||
|
||||
got, err := c.links()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
var got []string
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
}
|
||||
want := []string{"manifests/h/n/m/t"}
|
||||
if !slices.Equal(got, want) {
|
||||
@@ -484,9 +487,12 @@ func testManifestNameReuse(t *testing.T) {
|
||||
err = c.Link("h/n/m:T", d1)
|
||||
check(err)
|
||||
|
||||
got, err = c.links()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
got = got[:0]
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
}
|
||||
|
||||
// we should have only one link that is same case as the last link
|
||||
@@ -548,9 +554,12 @@ func TestNames(t *testing.T) {
|
||||
check(c.Link("h/n/m:t", mkdigest("1")))
|
||||
check(c.Link("h/n/m:u", mkdigest("2")))
|
||||
|
||||
got, err := c.Links()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
var got []string
|
||||
for l, err := range c.Links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
}
|
||||
want := []string{"h/n/m:t", "h/n/m:u"}
|
||||
if !slices.Equal(got, want) {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -545,7 +546,18 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
})
|
||||
}()
|
||||
|
||||
err = r.chunksums(ctx, name, l, func(cs chunksum) bool {
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
// Note the chunksum stream
|
||||
// interruption, but do not cancel
|
||||
// in-flight downloads. We can still
|
||||
// make progress on them. Once they are
|
||||
// done, ErrIncomplete will be returned
|
||||
// below.
|
||||
update(0, err)
|
||||
break
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
@@ -557,7 +569,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
update(cs.Chunk.Size(), ErrCached)
|
||||
return true // continue
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -608,13 +620,6 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
// Record the downloading of this chunk.
|
||||
return blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
})
|
||||
return true // continue processing chunks
|
||||
})
|
||||
if err != nil {
|
||||
// Note the chunksum stream interruption, but do not cancel
|
||||
// in-flight downloads. We can still make progress on them.
|
||||
// Once they are done, ErrIncomplete will be returned below.
|
||||
update(0, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -669,6 +674,19 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manifest) All() iter.Seq[*Layer] {
|
||||
return func(yield func(*Layer) bool) {
|
||||
if !yield(m.Config) {
|
||||
return
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if !yield(l) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() int64 {
|
||||
var size int64
|
||||
if m.Config != nil {
|
||||
@@ -793,114 +811,125 @@ type chunksum struct {
|
||||
Digest blob.Digest
|
||||
}
|
||||
|
||||
// chunksums calls fn for each chunksum in the layer. If the layer is under the
|
||||
// chunking threshold, a single chunksum covering the entire layer is passed to fn.
|
||||
// If the layer is over the chunking threshold, chunksums are read from the chunksums endpoint.
|
||||
// Returns an error if the chunksum stream fails, or nil if all chunksums were processed.
|
||||
// If fn returns false, iteration stops early and chunksums returns nil.
|
||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer, fn func(chunksum) bool) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if l.Size < r.maxChunkingThreshold() {
|
||||
// any layer under the threshold should be downloaded
|
||||
// in one go.
|
||||
cs := chunksum{
|
||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
),
|
||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
||||
Digest: l.Digest,
|
||||
}
|
||||
fn(cs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
)
|
||||
|
||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
return fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
return s.Err()
|
||||
}
|
||||
d, err := blob.ParseDigest(s.Bytes())
|
||||
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
|
||||
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
|
||||
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
|
||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
|
||||
return func(yield func(chunksum, error) bool) {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid digest: %q", s.Bytes())
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.Scan() {
|
||||
err := s.Err()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
||||
if l.Size < r.maxChunkingThreshold() {
|
||||
// any layer under the threshold should be downloaded
|
||||
// in one go.
|
||||
cs := chunksum{
|
||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
),
|
||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
||||
Digest: l.Digest,
|
||||
}
|
||||
return err
|
||||
}
|
||||
chunk, err := parseChunk(s.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes())
|
||||
yield(cs, nil)
|
||||
return
|
||||
}
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
Digest: d,
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
)
|
||||
|
||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
if !fn(cs) {
|
||||
return nil
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
if s.Err() != nil {
|
||||
yield(chunksum{}, s.Err())
|
||||
}
|
||||
return
|
||||
}
|
||||
d, err := blob.ParseDigest(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
if !s.Scan() {
|
||||
err := s.Err()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
||||
}
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
chunk, err := parseChunk(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
Digest: d,
|
||||
}
|
||||
if !yield(cs, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1147,8 +1176,8 @@ func splitExtended(s string) (scheme, name, digest string) {
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
// parseChunk parses a byte slice in the form "start-end" and returns the Chunk.
|
||||
func parseChunk(s []byte) (blob.Chunk, error) {
|
||||
// parseChunk parses a string in the form "start-end" and returns the Chunk.
|
||||
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
||||
startPart, endPart, found := strings.Cut(string(s), "-")
|
||||
if !found {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
|
||||
|
||||
@@ -27,20 +27,46 @@ type Trace struct {
|
||||
}
|
||||
|
||||
func (t *Trace) update(l *Layer, n int64, err error) {
|
||||
if t != nil && t.Update != nil {
|
||||
if t.Update != nil {
|
||||
t.Update(l, n, err)
|
||||
}
|
||||
}
|
||||
|
||||
type traceKey struct{}
|
||||
|
||||
// WithTrace attaches a Trace to the context for transfer progress reporting.
|
||||
// WithTrace adds a trace to the context for transfer progress reporting.
|
||||
func WithTrace(ctx context.Context, t *Trace) context.Context {
|
||||
return context.WithValue(ctx, traceKey{}, t)
|
||||
old := traceFromContext(ctx)
|
||||
if old == t {
|
||||
// No change, return the original context. This also prevents
|
||||
// infinite recursion below, if the caller passes the same
|
||||
// Trace.
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Create a new Trace that wraps the old one, if any. If we used the
|
||||
// same pointer t, we end up with a recursive structure.
|
||||
composed := &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if old != nil {
|
||||
old.update(l, n, err)
|
||||
}
|
||||
t.update(l, n, err)
|
||||
},
|
||||
}
|
||||
return context.WithValue(ctx, traceKey{}, composed)
|
||||
}
|
||||
|
||||
// traceFromContext returns the Trace associated with ctx, or nil if none.
|
||||
var emptyTrace = &Trace{}
|
||||
|
||||
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
|
||||
// none is found.
|
||||
//
|
||||
// It never returns nil.
|
||||
func traceFromContext(ctx context.Context) *Trace {
|
||||
t, _ := ctx.Value(traceKey{}).(*Trace)
|
||||
if t == nil {
|
||||
return emptyTrace
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -2,46 +2,44 @@ package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Retry calls fn repeatedly with exponential backoff until it returns nil,
|
||||
// a non-retryable error (shouldRetry returns false), or the context is cancelled.
|
||||
// The shouldRetry function determines if an error is retryable.
|
||||
// Returns the last error encountered, or nil if fn succeeded.
|
||||
func Retry(ctx context.Context, maxBackoff time.Duration, shouldRetry func(error) bool, fn func() error) error {
|
||||
var t *time.Timer
|
||||
for n := 0; ; n++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
var n int
|
||||
return func(yield func(int, error) bool) {
|
||||
var t *time.Timer
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
yield(n, ctx.Err())
|
||||
return
|
||||
}
|
||||
|
||||
err := fn()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !shouldRetry(err) {
|
||||
return err
|
||||
}
|
||||
if !yield(n, nil) {
|
||||
return
|
||||
}
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
n++
|
||||
|
||||
if t == nil {
|
||||
t = time.NewTimer(d)
|
||||
} else {
|
||||
t.Reset(d)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return ctx.Err()
|
||||
case <-t.C:
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
if t == nil {
|
||||
t = time.NewTimer(d)
|
||||
} else {
|
||||
t.Reset(d)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
case <-t.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,70 +10,31 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
func TestLoop(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
n := 0
|
||||
last := -1
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
err := Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
|
||||
n++
|
||||
for n, err := range Loop(ctx, 100*time.Millisecond) {
|
||||
if !errors.Is(err, ctx.Err()) {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n != last+1 {
|
||||
t.Errorf("n = %d, want %d", n, last+1)
|
||||
}
|
||||
last = n
|
||||
if n > 5 {
|
||||
cancel()
|
||||
}
|
||||
return errors.New("keep going")
|
||||
})
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("err = %v, want context.Canceled", err)
|
||||
}
|
||||
|
||||
if n != 6 {
|
||||
t.Errorf("n = %d, want 6", n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetrySuccess(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
n := 0
|
||||
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool { return true }, func() error {
|
||||
n++
|
||||
if n >= 3 {
|
||||
return nil // success
|
||||
}
|
||||
return errors.New("retry")
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("n = %d, want 3", n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryNonRetryable(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
permanent := errors.New("permanent error")
|
||||
n := 0
|
||||
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool {
|
||||
return !errors.Is(err, permanent)
|
||||
}, func() error {
|
||||
n++
|
||||
if n >= 2 {
|
||||
return permanent
|
||||
}
|
||||
return errors.New("retry")
|
||||
})
|
||||
|
||||
if !errors.Is(err, permanent) {
|
||||
t.Errorf("err = %v, want permanent", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Errorf("n = %d, want 2", n)
|
||||
if last != 6 {
|
||||
t.Errorf("last = %d, want 6", last)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,46 +3,37 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
var errRetry = errors.New("retry")
|
||||
|
||||
func TestRetryAllocs(t *testing.T) {
|
||||
func TestLoopAllocs(t *testing.T) {
|
||||
for i := range 3 {
|
||||
got := testing.AllocsPerRun(1000, func() {
|
||||
tick := 0
|
||||
Retry(t.Context(), 1, func(err error) bool { return true }, func() error {
|
||||
tick++
|
||||
for tick := range Loop(t.Context(), 1) {
|
||||
if tick >= i {
|
||||
return nil
|
||||
break
|
||||
}
|
||||
return errRetry
|
||||
})
|
||||
}
|
||||
})
|
||||
want := float64(0)
|
||||
if i > 0 {
|
||||
want = 3 // due to time.NewTimer
|
||||
}
|
||||
if got > want {
|
||||
t.Errorf("[%d ticks]: allocs = %v, want <= %v", i, got, want)
|
||||
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRetry(b *testing.B) {
|
||||
func BenchmarkLoop(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
synctest.Run(func() {
|
||||
n := 0
|
||||
Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
|
||||
n++
|
||||
for n := range Loop(ctx, 100*time.Millisecond) {
|
||||
if n == b.N {
|
||||
return nil
|
||||
break
|
||||
}
|
||||
return errRetry
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "DELETE" {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
p, err := decodeParams(r.Body)
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
|
||||
p, err := decodeParams(r.Body)
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -293,14 +293,10 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
// ticker controls periodic progress flushing. It starts paused (very long
|
||||
// interval) and is activated by start() once all layers are registered,
|
||||
// so clients see a complete total before progress begins.
|
||||
ticker := time.NewTicker(1 << 62) // effectively paused until started
|
||||
defer ticker.Stop()
|
||||
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
|
||||
start := sync.OnceFunc(func() {
|
||||
flushProgress()
|
||||
ticker.Reset(100 * time.Millisecond)
|
||||
flushProgress() // flush initial state
|
||||
t.Reset(100 * time.Millisecond)
|
||||
})
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
@@ -324,21 +320,36 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
})
|
||||
}()
|
||||
|
||||
// Block flushing progress updates until every
|
||||
// layer is accounted for. Clients depend on a
|
||||
// complete model size to calculate progress
|
||||
// correctly; if they use an incomplete total,
|
||||
// progress indicators would erratically jump
|
||||
// as new layers are registered.
|
||||
start()
|
||||
},
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- backoff.Retry(ctx, 3*time.Second, canRetry, func() error {
|
||||
return s.Client.Pull(ctx, p.model())
|
||||
})
|
||||
go func() (err error) {
|
||||
defer func() { done <- err }()
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := s.Client.Pull(ctx, p.model())
|
||||
if canRetry(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-t.C:
|
||||
flushProgress()
|
||||
case err := <-done:
|
||||
flushProgress()
|
||||
@@ -363,13 +374,20 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
func decodeParams(r io.Reader) (*params, error) {
|
||||
var p params
|
||||
err := json.NewDecoder(r).Decode(&p)
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
var v T
|
||||
err := json.NewDecoder(r).Decode(&v)
|
||||
if err == nil {
|
||||
return &p, nil
|
||||
return v, nil
|
||||
}
|
||||
var zero T
|
||||
|
||||
// Not sure why, but I can't seem to be able to use:
|
||||
//
|
||||
// errors.As(err, &json.UnmarshalTypeError{})
|
||||
//
|
||||
// This is working fine in stdlib, so I'm not sure what rules changed
|
||||
// and why this no longer works here. So, we do it the verbose way.
|
||||
var a *json.UnmarshalTypeError
|
||||
var b *json.SyntaxError
|
||||
if errors.As(err, &a) || errors.As(err, &b) {
|
||||
@@ -378,7 +396,7 @@ func decodeParams(r io.Reader) (*params, error) {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
|
||||
}
|
||||
return nil, err
|
||||
return zero, err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
@@ -390,8 +408,10 @@ func canRetry(err error) bool {
|
||||
return oe.Temporary()
|
||||
}
|
||||
s := err.Error()
|
||||
return errors.Is(err, context.DeadlineExceeded) ||
|
||||
strings.Contains(s, "unreachable") ||
|
||||
strings.Contains(s, "no route to host") ||
|
||||
strings.Contains(s, "connection reset by peer")
|
||||
return cmp.Or(
|
||||
errors.Is(err, context.DeadlineExceeded),
|
||||
strings.Contains(s, "unreachable"),
|
||||
strings.Contains(s, "no route to host"),
|
||||
strings.Contains(s, "connection reset by peer"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -129,11 +129,30 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(mxyng): use something less brittle
|
||||
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||
// Find both 4-part (models) and 5-part (skills/agents) manifest paths
|
||||
matches4, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matches5, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Combine matches, filtering to only include files
|
||||
var matches []string
|
||||
for _, match := range matches4 {
|
||||
fi, err := os.Stat(match)
|
||||
if err == nil && !fi.IsDir() {
|
||||
matches = append(matches, match)
|
||||
}
|
||||
}
|
||||
for _, match := range matches5 {
|
||||
fi, err := os.Stat(match)
|
||||
if err == nil && !fi.IsDir() {
|
||||
matches = append(matches, match)
|
||||
}
|
||||
}
|
||||
|
||||
ms := make(map[model.Name]*Manifest)
|
||||
for _, match := range matches {
|
||||
|
||||
315
server/mcp.go
Normal file
315
server/mcp.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// MediaTypeMCP is the media type for MCP server layers in manifests.
|
||||
const MediaTypeMCP = "application/vnd.ollama.image.mcp"
|
||||
|
||||
// GetMCPsPath returns the path to the extracted MCPs cache directory.
|
||||
// If digest is empty, returns the mcps directory itself.
|
||||
// If digest is provided, returns the path to the extracted MCP for that digest.
|
||||
func GetMCPsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "mcps", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ExtractMCPBlob extracts an MCP tar.gz blob to the mcps cache.
|
||||
// The blob is expected to be at the blobs path for the given digest.
|
||||
// Returns the path to the extracted MCP directory.
|
||||
func ExtractMCPBlob(digest string) (string, error) {
|
||||
// Get the blob path
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting blob path: %w", err)
|
||||
}
|
||||
|
||||
// Get the extraction path
|
||||
mcpPath, err := GetMCPsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting mcp path: %w", err)
|
||||
}
|
||||
|
||||
// Check if already extracted (look for any file)
|
||||
entries, err := os.ReadDir(mcpPath)
|
||||
if err == nil && len(entries) > 0 {
|
||||
return mcpPath, nil
|
||||
}
|
||||
|
||||
// Open the blob
|
||||
f, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("opening blob: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create gzip reader
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
// Create tar reader
|
||||
tr := tar.NewReader(gzr)
|
||||
|
||||
// Create the mcp directory
|
||||
if err := os.MkdirAll(mcpPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating mcp directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract files
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading tar: %w", err)
|
||||
}
|
||||
|
||||
// Clean the name and ensure it doesn't escape the target directory
|
||||
name := filepath.Clean(header.Name)
|
||||
if strings.HasPrefix(name, "..") {
|
||||
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
|
||||
}
|
||||
|
||||
target := filepath.Join(mcpPath, name)
|
||||
|
||||
// Verify the target is within mcpPath
|
||||
if !strings.HasPrefix(target, filepath.Clean(mcpPath)+string(os.PathSeparator)) && target != filepath.Clean(mcpPath) {
|
||||
return "", fmt.Errorf("path escapes mcp directory: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(target, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating directory: %w", err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating parent directory: %w", err)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(outFile, tr); err != nil {
|
||||
outFile.Close()
|
||||
return "", fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
outFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return mcpPath, nil
|
||||
}
|
||||
|
||||
// CreateMCPLayer creates an MCP layer from a local directory.
|
||||
// The directory can optionally contain an mcp.json or package.json file.
|
||||
// Returns the created layer.
|
||||
func CreateMCPLayer(mcpDir string) (Layer, error) {
|
||||
// Verify directory exists
|
||||
info, err := os.Stat(mcpDir)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("mcp directory not found: %w", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return Layer{}, fmt.Errorf("mcp path is not a directory: %s", mcpDir)
|
||||
}
|
||||
|
||||
// Create a temporary file for the tar.gz
|
||||
blobsPath, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(blobsPath, "mcp-*.tar.gz")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer func() {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
}()
|
||||
|
||||
// Create gzip writer
|
||||
gzw := gzip.NewWriter(tmpFile)
|
||||
defer gzw.Close()
|
||||
|
||||
// Create tar writer
|
||||
tw := tar.NewWriter(gzw)
|
||||
defer tw.Close()
|
||||
|
||||
// Walk the mcp directory and add files to tar
|
||||
err = filepath.Walk(mcpDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(mcpDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip the root directory itself
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = relPath
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write file contents if it's a regular file
|
||||
if !info.IsDir() {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(tw, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
|
||||
}
|
||||
|
||||
// Close writers to flush
|
||||
if err := tw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
|
||||
}
|
||||
if err := gzw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing temp file: %w", err)
|
||||
}
|
||||
|
||||
// Open the temp file for reading
|
||||
tmpFile, err = os.Open(tmpPath)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
// Create the layer (this will compute the digest and move to blobs)
|
||||
layer, err := NewLayer(tmpFile, MediaTypeMCP)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating layer: %w", err)
|
||||
}
|
||||
|
||||
// Extract the mcp to the cache so it's ready to use
|
||||
if _, err := ExtractMCPBlob(layer.Digest); err != nil {
|
||||
return Layer{}, fmt.Errorf("extracting mcp: %w", err)
|
||||
}
|
||||
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
// IsLocalMCPPath checks if an MCP reference looks like a local path.
|
||||
// Local paths are explicitly prefixed with /, ./, ../, or ~.
|
||||
func IsLocalMCPPath(name string) bool {
|
||||
return strings.HasPrefix(name, "/") ||
|
||||
strings.HasPrefix(name, "./") ||
|
||||
strings.HasPrefix(name, "../") ||
|
||||
strings.HasPrefix(name, "~")
|
||||
}
|
||||
|
||||
// MCPNamespace is the namespace used for standalone MCPs in the registry.
|
||||
const MCPNamespace = "mcp"
|
||||
|
||||
// IsMCPReference checks if a name refers to an MCP (has mcp/ prefix).
|
||||
func IsMCPReference(name string) bool {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
|
||||
// mcp/name or mcp/name:tag
|
||||
if len(parts) >= 1 && parts[0] == MCPNamespace {
|
||||
return true
|
||||
}
|
||||
// namespace/mcp/name (e.g., myuser/mcp/websearch)
|
||||
if len(parts) >= 2 && parts[1] == MCPNamespace {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseMCPName parses an MCP reference string into a model.Name.
|
||||
// The Kind field is set to "mcp".
|
||||
func ParseMCPName(name string) model.Name {
|
||||
n := model.ParseName(name)
|
||||
|
||||
// If Kind wasn't set (old format without mcp/), set it
|
||||
if n.Kind == "" {
|
||||
n.Kind = MCPNamespace
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// GetMCPManifestPath returns the path to the MCP manifest file.
|
||||
func GetMCPManifestPath(n model.Name) (string, error) {
|
||||
if n.Model == "" {
|
||||
return "", fmt.Errorf("mcp name is required")
|
||||
}
|
||||
|
||||
// Ensure Kind is set
|
||||
if n.Kind == "" {
|
||||
n.Kind = MCPNamespace
|
||||
}
|
||||
|
||||
path := filepath.Join(
|
||||
envconfig.Models(),
|
||||
"manifests",
|
||||
n.Filepath(),
|
||||
)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
@@ -18,6 +18,7 @@ type ModelPath struct {
|
||||
ProtocolScheme string
|
||||
Registry string
|
||||
Namespace string
|
||||
Kind string // Optional: "skill", "agent", or empty for models
|
||||
Repository string
|
||||
Tag string
|
||||
}
|
||||
@@ -42,6 +43,7 @@ func ParseModelPath(name string) ModelPath {
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Kind: "",
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
@@ -55,13 +57,41 @@ func ParseModelPath(name string) ModelPath {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
case 4:
|
||||
// host/namespace/kind/model or host/namespace/model:tag with kind
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
if model.ValidKinds[parts[2]] {
|
||||
mp.Kind = parts[2]
|
||||
mp.Repository = parts[3]
|
||||
} else {
|
||||
// Not a valid kind, treat as old format with extra part
|
||||
mp.Repository = parts[3]
|
||||
}
|
||||
case 3:
|
||||
// Could be: host/namespace/model OR namespace/kind/model
|
||||
if model.ValidKinds[parts[1]] {
|
||||
// namespace/kind/model
|
||||
mp.Namespace = parts[0]
|
||||
mp.Kind = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
} else {
|
||||
// host/namespace/model
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
}
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
// Could be: namespace/model OR kind/model
|
||||
if model.ValidKinds[parts[0]] {
|
||||
// kind/model (library skill)
|
||||
mp.Kind = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
} else {
|
||||
// namespace/model
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
}
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
}
|
||||
@@ -75,20 +105,35 @@ func ParseModelPath(name string) ModelPath {
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", mp.Namespace, mp.Kind, mp.Repository)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetFullTagname() string {
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetShortTagname() string {
|
||||
if mp.Registry == DefaultRegistry {
|
||||
if mp.Namespace == DefaultNamespace {
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
|
||||
}
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
if mp.Kind != "" {
|
||||
return fmt.Sprintf("%s/%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Kind, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
@@ -97,6 +142,7 @@ func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
name := model.Name{
|
||||
Host: mp.Registry,
|
||||
Namespace: mp.Namespace,
|
||||
Kind: mp.Kind,
|
||||
Model: mp.Repository,
|
||||
Tag: mp.Tag,
|
||||
}
|
||||
|
||||
@@ -969,6 +969,9 @@ func getExistingName(n model.Name) (model.Name, error) {
|
||||
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
|
||||
n.Namespace = e.Namespace
|
||||
}
|
||||
if set.Kind == "" && strings.EqualFold(e.Kind, n.Kind) {
|
||||
n.Kind = e.Kind
|
||||
}
|
||||
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
|
||||
n.Model = e.Model
|
||||
}
|
||||
@@ -1107,6 +1110,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
Requires: m.Config.Requires,
|
||||
Skills: m.Config.Skills,
|
||||
MCPs: m.Config.MCPs,
|
||||
AgentType: m.Config.AgentType,
|
||||
Entrypoint: m.Config.Entrypoint,
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
@@ -1161,11 +1168,16 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
fmt.Fprint(&sb, m.String())
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
// skip loading tensor information if this is a remote model
|
||||
// skip loading tensor information if this is a remote model or a skill
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Skills don't have model weights, skip tensor loading
|
||||
if m.ModelPath == "" {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
326
server/skill.go
Normal file
326
server/skill.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// MediaTypeSkill is the media type for skill layers in manifests.
|
||||
const MediaTypeSkill = "application/vnd.ollama.image.skill"
|
||||
|
||||
// GetSkillsPath returns the path to the extracted skills cache directory.
|
||||
// If digest is empty, returns the skills directory itself.
|
||||
// If digest is provided, returns the path to the extracted skill for that digest.
|
||||
func GetSkillsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "skills", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ExtractSkillBlob extracts a skill tar.gz blob to the skills cache.
|
||||
// The blob is expected to be at the blobs path for the given digest.
|
||||
// Returns the path to the extracted skill directory.
|
||||
func ExtractSkillBlob(digest string) (string, error) {
|
||||
// Get the blob path
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting blob path: %w", err)
|
||||
}
|
||||
|
||||
// Get the extraction path
|
||||
skillPath, err := GetSkillsPath(digest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting skill path: %w", err)
|
||||
}
|
||||
|
||||
// Check if already extracted
|
||||
if _, err := os.Stat(filepath.Join(skillPath, "SKILL.md")); err == nil {
|
||||
return skillPath, nil
|
||||
}
|
||||
|
||||
// Open the blob
|
||||
f, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("opening blob: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create gzip reader
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
// Create tar reader
|
||||
tr := tar.NewReader(gzr)
|
||||
|
||||
// Create the skill directory
|
||||
if err := os.MkdirAll(skillPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating skill directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract files
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading tar: %w", err)
|
||||
}
|
||||
|
||||
// Clean the name and ensure it doesn't escape the target directory
|
||||
name := filepath.Clean(header.Name)
|
||||
if strings.HasPrefix(name, "..") {
|
||||
return "", fmt.Errorf("invalid path in archive: %s", header.Name)
|
||||
}
|
||||
|
||||
target := filepath.Join(skillPath, name)
|
||||
|
||||
// Verify the target is within skillPath
|
||||
if !strings.HasPrefix(target, filepath.Clean(skillPath)+string(os.PathSeparator)) && target != filepath.Clean(skillPath) {
|
||||
return "", fmt.Errorf("path escapes skill directory: %s", header.Name)
|
||||
}
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(target, 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating directory: %w", err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
return "", fmt.Errorf("creating parent directory: %w", err)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(outFile, tr); err != nil {
|
||||
outFile.Close()
|
||||
return "", fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
outFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return skillPath, nil
|
||||
}
|
||||
|
||||
// CreateSkillLayer creates a skill layer from a local directory.
|
||||
// The directory must contain a SKILL.md file.
|
||||
// Returns the created layer.
|
||||
func CreateSkillLayer(skillDir string) (Layer, error) {
|
||||
// Verify SKILL.md exists
|
||||
skillMdPath := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
return Layer{}, fmt.Errorf("skill directory must contain SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Create a temporary file for the tar.gz
|
||||
blobsPath, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("getting blobs path: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(blobsPath, "skill-*.tar.gz")
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer func() {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
}()
|
||||
|
||||
// Create gzip writer
|
||||
gzw := gzip.NewWriter(tmpFile)
|
||||
defer gzw.Close()
|
||||
|
||||
// Create tar writer
|
||||
tw := tar.NewWriter(gzw)
|
||||
defer tw.Close()
|
||||
|
||||
// Walk the skill directory and add files to tar
|
||||
err = filepath.Walk(skillDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(skillDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip the root directory itself
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create tar header
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = relPath
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write file contents if it's a regular file
|
||||
if !info.IsDir() {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(tw, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating tar archive: %w", err)
|
||||
}
|
||||
|
||||
// Close writers to flush
|
||||
if err := tw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing tar writer: %w", err)
|
||||
}
|
||||
if err := gzw.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing gzip writer: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return Layer{}, fmt.Errorf("closing temp file: %w", err)
|
||||
}
|
||||
|
||||
// Open the temp file for reading
|
||||
tmpFile, err = os.Open(tmpPath)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("reopening temp file: %w", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
// Create the layer (this will compute the digest and move to blobs)
|
||||
layer, err := NewLayer(tmpFile, MediaTypeSkill)
|
||||
if err != nil {
|
||||
return Layer{}, fmt.Errorf("creating layer: %w", err)
|
||||
}
|
||||
|
||||
// Extract the skill to the cache so it's ready to use
|
||||
if _, err := ExtractSkillBlob(layer.Digest); err != nil {
|
||||
return Layer{}, fmt.Errorf("extracting skill: %w", err)
|
||||
}
|
||||
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
// IsLocalSkillPath checks if a skill reference looks like a local path.
|
||||
// Local paths are explicitly prefixed with /, ./, ../, or ~.
|
||||
// Registry references like "skill/calculator:1.0.0" should NOT be treated as local paths.
|
||||
func IsLocalSkillPath(name string) bool {
|
||||
// Local paths are explicitly indicated by path prefixes
|
||||
return strings.HasPrefix(name, "/") ||
|
||||
strings.HasPrefix(name, "./") ||
|
||||
strings.HasPrefix(name, "../") ||
|
||||
strings.HasPrefix(name, "~")
|
||||
}
|
||||
|
||||
// SkillNamespace is the namespace used for standalone skills in the registry.
|
||||
const SkillNamespace = "skill"
|
||||
|
||||
// IsSkillReference checks if a name refers to a skill (has skill/ prefix).
|
||||
func IsSkillReference(name string) bool {
|
||||
// Check for skill/ prefix (handles both "skill/foo" and "registry/skill/foo")
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
|
||||
// skill/name or skill/name:tag
|
||||
if len(parts) >= 1 && parts[0] == SkillNamespace {
|
||||
return true
|
||||
}
|
||||
// namespace/skill/name (e.g., myuser/skill/calc) - not a skill ref
|
||||
// registry/skill/name (e.g., registry.ollama.ai/skill/calc)
|
||||
if len(parts) >= 2 && parts[1] == SkillNamespace {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseSkillName parses a skill reference string into a model.Name.
|
||||
// The Kind field is set to "skill".
|
||||
// Examples:
|
||||
// - "calculator" -> library/skill/calculator:latest
|
||||
// - "myname/calculator" -> myname/skill/calculator:latest
|
||||
// - "myname/skill/calculator:1.0.0" -> myname/skill/calculator:1.0.0
|
||||
func ParseSkillName(name string) model.Name {
|
||||
// Use the standard parser which now handles Kind
|
||||
n := model.ParseName(name)
|
||||
|
||||
// If Kind wasn't set (old format without skill/), set it
|
||||
if n.Kind == "" {
|
||||
n.Kind = SkillNamespace
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// SkillDisplayName returns a user-friendly display name for a skill.
|
||||
func SkillDisplayName(n model.Name) string {
|
||||
return n.DisplayShortest()
|
||||
}
|
||||
|
||||
// GetSkillManifestPath returns the path to the skill manifest file.
|
||||
// Uses the 5-part structure: host/namespace/kind/model/tag
|
||||
func GetSkillManifestPath(n model.Name) (string, error) {
|
||||
if n.Model == "" {
|
||||
return "", fmt.Errorf("skill name is required")
|
||||
}
|
||||
|
||||
// Ensure Kind is set
|
||||
if n.Kind == "" {
|
||||
n.Kind = SkillNamespace
|
||||
}
|
||||
|
||||
path := filepath.Join(
|
||||
envconfig.Models(),
|
||||
"manifests",
|
||||
n.Filepath(),
|
||||
)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
8
server/sparse_common.go
Normal file
8
server/sparse_common.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import "os"
|
||||
|
||||
func setSparse(*os.File) {
|
||||
}
|
||||
17
server/sparse_windows.go
Normal file
17
server/sparse_windows.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setSparse(file *os.File) {
|
||||
// exFat (and other FS types) don't support sparse files, so ignore errors
|
||||
windows.DeviceIoControl( //nolint:errcheck
|
||||
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
|
||||
nil, 0,
|
||||
nil, 0,
|
||||
nil, nil,
|
||||
)
|
||||
}
|
||||
37
skills/calculator-skill/SKILL.md
Normal file
37
skills/calculator-skill/SKILL.md
Normal file
@@ -0,0 +1,37 @@
|
||||
---
|
||||
name: calculator-skill
|
||||
description: A skill for performing mathematical calculations using a Python script. Use when the user asks to calculate, compute, or do math operations.
|
||||
---
|
||||
|
||||
# Calculator Skill
|
||||
|
||||
## Purpose
|
||||
|
||||
This skill performs mathematical calculations using a bundled Python script for accuracy.
|
||||
|
||||
## When to use
|
||||
|
||||
- The user asks to calculate something
|
||||
- The user wants to do math (add, subtract, multiply, divide)
|
||||
- The user asks about percentages or conversions
|
||||
- Any arithmetic or mathematical operation is needed
|
||||
|
||||
## Instructions
|
||||
|
||||
1. When the user asks for a calculation, use the `run_skill_script` tool to execute the calculation script.
|
||||
2. Call the script like this: `python3 scripts/calculate.py "<expression>"`
|
||||
3. Return the result from the script output to the user.
|
||||
|
||||
## Examples
|
||||
|
||||
For "What is 25 * 4?":
|
||||
- Call: `run_skill_script` with skill="calculator-skill" and command="python3 scripts/calculate.py '25 * 4'"
|
||||
- Output: "25 * 4 = 100"
|
||||
|
||||
For "Calculate 15% of 200":
|
||||
- Call: `run_skill_script` with skill="calculator-skill" and command="python3 scripts/calculate.py '15/100 * 200'"
|
||||
- Output: "15/100 * 200 = 30.0"
|
||||
|
||||
For "Add 123 and 456":
|
||||
- Call: `run_skill_script` with skill="calculator-skill" and command="python3 scripts/calculate.py '123 + 456'"
|
||||
- Output: "123 + 456 = 579"
|
||||
41
skills/calculator-skill/scripts/calculate.py
Executable file
41
skills/calculator-skill/scripts/calculate.py
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Calculator script for performing mathematical operations.
|
||||
Usage: python calculate.py <expression>
|
||||
Example: python calculate.py "25 * 4"
|
||||
"""
|
||||
import sys
|
||||
import re
|
||||
|
||||
def safe_eval(expression):
|
||||
"""Safely evaluate a mathematical expression."""
|
||||
# Only allow numbers, operators, parentheses, and whitespace
|
||||
if not re.match(r'^[\d\s\+\-\*\/\.\(\)\%]+$', expression):
|
||||
raise ValueError(f"Invalid expression: {expression}")
|
||||
|
||||
# Replace % with /100* for percentage calculations
|
||||
# e.g., "15% of 200" would be passed as "15/100*200"
|
||||
|
||||
try:
|
||||
result = eval(expression)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not evaluate: {e}")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python calculate.py <expression>")
|
||||
print("Example: python calculate.py '25 * 4'")
|
||||
sys.exit(1)
|
||||
|
||||
expression = ' '.join(sys.argv[1:])
|
||||
|
||||
try:
|
||||
result = safe_eval(expression)
|
||||
print(f"{expression} = {result}")
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
26
skills/demo-agent.Agentfile
Normal file
26
skills/demo-agent.Agentfile
Normal file
@@ -0,0 +1,26 @@
|
||||
FROM gpt-oss:20b
|
||||
|
||||
AGENT_TYPE conversational
|
||||
|
||||
SKILL /Users/parth/Documents/repos/ollama/skills/calculator-skill
|
||||
SKILL /Users/parth/Documents/repos/ollama/skills/mock-logs-skill
|
||||
SKILL /Users/parth/Documents/repos/ollama/skills/ducky-skill
|
||||
|
||||
SYSTEM """You are a helpful assistant with access to specialized skills.
|
||||
|
||||
When asked to perform calculations, use the calculator skill's run_skill_script tool.
|
||||
When asked to generate logs or show sample log output, use the mock-logs skill's run_skill_script tool.
|
||||
When asked to run ducky or process directories with ducky, use the ducky skill's run_skill_script tool.
|
||||
|
||||
CRITICAL INSTRUCTION - YOU MUST FOLLOW THIS:
|
||||
After ANY tool call completes and returns output, you MUST write additional text analyzing, explaining, or summarizing the results. Your response is NOT complete until you have provided this analysis. Do NOT end your turn immediately after tool output appears.
|
||||
|
||||
Example workflow for mock logs:
|
||||
1. Call run_skill_script to generate logs
|
||||
2. Tool returns log output
|
||||
3. YOU MUST THEN WRITE: An analysis of the logs - identify patterns, note log levels, highlight any errors/warnings, and explain what the logs show
|
||||
|
||||
Never just show raw output and stop. Always add your analysis afterwards."""
|
||||
|
||||
PARAMETER temperature 0.3
|
||||
PARAMETER top_p 0.9
|
||||
38
skills/ducky-skill/SKILL.md
Normal file
38
skills/ducky-skill/SKILL.md
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
name: ducky
|
||||
description: Run DuckY CLI tool for processing directories with AI models
|
||||
---
|
||||
|
||||
# DuckY Skill
|
||||
|
||||
## Purpose
|
||||
|
||||
This skill provides access to the DuckY CLI tool, which processes directories using AI models.
|
||||
|
||||
## When to use
|
||||
|
||||
- User asks to run ducky on a directory
|
||||
- User wants to process files with ducky
|
||||
- User asks about ducky or wants to use ducky features
|
||||
- User wants to poll a crumb
|
||||
|
||||
## Instructions
|
||||
|
||||
1. When the user asks to run ducky, use the `run_skill_script` tool
|
||||
2. Call: `./scripts/run_ducky.sh [args]`
|
||||
- `-d <directory>` - Directory to process
|
||||
- `-m <model>` - Model to use
|
||||
- `-l` - Run locally with Ollama
|
||||
- `--poll <crumb>` - Poll a specific crumb
|
||||
- `-i <seconds>` - Polling interval
|
||||
|
||||
## Examples
|
||||
|
||||
For "Run ducky on the current directory":
|
||||
- Call: `run_skill_script` with skill="ducky" and command="./scripts/run_ducky.sh -d . -l"
|
||||
|
||||
For "Run ducky locally on src folder":
|
||||
- Call: `run_skill_script` with skill="ducky" and command="./scripts/run_ducky.sh -d src -l"
|
||||
|
||||
For "Poll the build crumb every 30 seconds":
|
||||
- Call: `run_skill_script` with skill="ducky" and command="./scripts/run_ducky.sh --poll build -i 30 -l"
|
||||
5
skills/ducky-skill/scripts/run_ducky.sh
Executable file
5
skills/ducky-skill/scripts/run_ducky.sh
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
# Wrapper script for ducky CLI
|
||||
|
||||
# Pass all arguments to ducky
|
||||
exec ducky "$@"
|
||||
119
skills/excel-skill/SKILL.md
Normal file
119
skills/excel-skill/SKILL.md
Normal file
@@ -0,0 +1,119 @@
|
||||
---
|
||||
name: excel-skill
|
||||
description: Help non-technical users process Excel and CSV data - summarize spreadsheets, find duplicates, filter rows, calculate statistics, and clean up data. Use when the user mentions Excel, spreadsheet, CSV, or asks about their data.
|
||||
---
|
||||
|
||||
# Excel Data Processing Skill
|
||||
|
||||
## Purpose
|
||||
|
||||
This skill helps users work with Excel (.xlsx) and CSV files without needing technical knowledge. It can summarize data, find problems, answer questions about the data, and perform common cleanup tasks.
|
||||
|
||||
## When to use
|
||||
|
||||
- User uploads or mentions an Excel or CSV file
|
||||
- User wants to understand what's in their data
|
||||
- User asks about duplicates, missing values, or data quality
|
||||
- User wants to filter, sort, or summarize data
|
||||
- User asks questions like "how many", "what's the average", "show me the top 10"
|
||||
|
||||
## Instructions
|
||||
|
||||
### Step 1: Understand the data first
|
||||
|
||||
When a user provides a file, ALWAYS start by running a summary to understand what you're working with:
|
||||
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" summary
|
||||
```
|
||||
|
||||
This shows:
|
||||
- Number of rows and columns
|
||||
- Column names and their data types
|
||||
- Sample of the data
|
||||
- Missing value counts
|
||||
|
||||
### Step 2: Answer their question
|
||||
|
||||
Based on what the user asks, use the appropriate command:
|
||||
|
||||
**Get statistics for a column:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" stats "<column_name>"
|
||||
```
|
||||
Shows count, average, min, max, and common values.
|
||||
|
||||
**Find duplicate rows:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" duplicates
|
||||
```
|
||||
Or check duplicates in specific columns:
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" duplicates "<column_name>"
|
||||
```
|
||||
|
||||
**Filter rows:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" filter "<column>" "<operator>" "<value>"
|
||||
```
|
||||
Operators: equals, contains, greater, less, not_equals
|
||||
Examples:
|
||||
- `filter "Status" "equals" "Active"`
|
||||
- `filter "Amount" "greater" "1000"`
|
||||
- `filter "Name" "contains" "Smith"`
|
||||
|
||||
**Sort data:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" sort "<column>" [asc|desc]
|
||||
```
|
||||
|
||||
**Count values in a column:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" count "<column_name>"
|
||||
```
|
||||
Shows how many times each value appears.
|
||||
|
||||
**Get top/bottom rows:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" top "<column>" <number>
|
||||
uv run scripts/process_data.py "<filepath>" bottom "<column>" <number>
|
||||
```
|
||||
|
||||
**Find missing values:**
|
||||
```
|
||||
uv run scripts/process_data.py "<filepath>" missing
|
||||
```
|
||||
|
||||
**Export filtered/processed data:**
|
||||
Add `--output "<new_filepath>"` to any command to save results.
|
||||
|
||||
## Examples
|
||||
|
||||
**User: "What's in this spreadsheet?"**
|
||||
Run: `uv run scripts/process_data.py "sales.xlsx" summary`
|
||||
|
||||
**User: "Are there any duplicate entries?"**
|
||||
Run: `uv run scripts/process_data.py "sales.xlsx" duplicates`
|
||||
|
||||
**User: "How many sales per region?"**
|
||||
Run: `uv run scripts/process_data.py "sales.xlsx" count "Region"`
|
||||
|
||||
**User: "Show me orders over $500"**
|
||||
Run: `uv run scripts/process_data.py "orders.csv" filter "Amount" "greater" "500"`
|
||||
|
||||
**User: "What's the average order value?"**
|
||||
Run: `uv run scripts/process_data.py "orders.csv" stats "Amount"`
|
||||
|
||||
**User: "Find all rows with missing email addresses"**
|
||||
Run: `uv run scripts/process_data.py "contacts.xlsx" filter "Email" "equals" ""`
|
||||
|
||||
**User: "Show me the top 10 customers by revenue"**
|
||||
Run: `uv run scripts/process_data.py "customers.csv" top "Revenue" 10`
|
||||
|
||||
## Tips for helping non-technical users
|
||||
|
||||
1. Always explain what you found in plain language
|
||||
2. If there are issues (duplicates, missing data), explain why it matters
|
||||
3. Offer to help fix problems you discover
|
||||
4. When showing numbers, provide context ("this is high/low compared to...")
|
||||
5. Ask clarifying questions if the column names are ambiguous
|
||||
11
skills/excel-skill/sample_data.csv
Normal file
11
skills/excel-skill/sample_data.csv
Normal file
@@ -0,0 +1,11 @@
|
||||
Name,Region,Amount,Status,Email
|
||||
Alice,North,1500,Active,alice@example.com
|
||||
Bob,South,2300,Active,bob@example.com
|
||||
Charlie,North,800,Inactive,charlie@example.com
|
||||
Diana,East,1500,Active,diana@example.com
|
||||
Eve,South,3200,Active,
|
||||
Frank,North,950,Inactive,frank@example.com
|
||||
Grace,West,2100,Active,grace@example.com
|
||||
Alice,North,1500,Active,alice@example.com
|
||||
Henry,East,1800,Active,henry@example.com
|
||||
Ivy,South,2300,Inactive,ivy@example.com
|
||||
|
395
skills/excel-skill/scripts/process_data.py
Normal file
395
skills/excel-skill/scripts/process_data.py
Normal file
@@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "pandas",
|
||||
# "openpyxl",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Excel/CSV Data Processing Script for non-technical users.
|
||||
Handles common data operations: summary, statistics, filtering, duplicates, etc.
|
||||
|
||||
Usage: uv run scripts/process_data.py <filepath> <command> [args...] [--output <output_path>]
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_file(filepath):
|
||||
"""Load Excel or CSV file into a DataFrame."""
|
||||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
print(f"Error: File not found: {filepath}")
|
||||
sys.exit(1)
|
||||
|
||||
suffix = path.suffix.lower()
|
||||
try:
|
||||
if suffix in ['.xlsx', '.xls']:
|
||||
df = pd.read_excel(filepath)
|
||||
elif suffix == '.csv':
|
||||
df = pd.read_csv(filepath)
|
||||
else:
|
||||
# Try CSV as default
|
||||
df = pd.read_csv(filepath)
|
||||
return df
|
||||
except Exception as e:
|
||||
print(f"Error reading file: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def save_output(df, output_path):
|
||||
"""Save DataFrame to file."""
|
||||
path = Path(output_path)
|
||||
suffix = path.suffix.lower()
|
||||
try:
|
||||
if suffix in ['.xlsx', '.xls']:
|
||||
df.to_excel(output_path, index=False)
|
||||
else:
|
||||
df.to_csv(output_path, index=False)
|
||||
print(f"\nSaved {len(df)} rows to: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving file: {e}")
|
||||
|
||||
|
||||
def cmd_summary(df, args):
|
||||
"""Show overview of the data."""
|
||||
print("=" * 60)
|
||||
print("DATA SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"\nRows: {len(df):,}")
|
||||
print(f"Columns: {len(df.columns)}")
|
||||
|
||||
print("\n" + "-" * 40)
|
||||
print("COLUMNS:")
|
||||
print("-" * 40)
|
||||
for col in df.columns:
|
||||
dtype = df[col].dtype
|
||||
non_null = df[col].notna().sum()
|
||||
null_count = df[col].isna().sum()
|
||||
|
||||
type_label = "text" if dtype == 'object' else ("number" if dtype in ['int64', 'float64'] else str(dtype))
|
||||
null_info = f" ({null_count} missing)" if null_count > 0 else ""
|
||||
print(f" - {col}: {type_label}{null_info}")
|
||||
|
||||
print("\n" + "-" * 40)
|
||||
print("SAMPLE DATA (first 5 rows):")
|
||||
print("-" * 40)
|
||||
print(df.head().to_string())
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def cmd_stats(df, args):
|
||||
"""Show statistics for a column."""
|
||||
if not args.column:
|
||||
print("Error: Please specify a column name")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
col = args.column
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\nSTATISTICS FOR: {col}")
|
||||
print("=" * 40)
|
||||
|
||||
series = df[col]
|
||||
print(f"Total values: {len(series):,}")
|
||||
print(f"Non-empty: {series.notna().sum():,}")
|
||||
print(f"Empty/missing: {series.isna().sum():,}")
|
||||
print(f"Unique values: {series.nunique():,}")
|
||||
|
||||
if pd.api.types.is_numeric_dtype(series):
|
||||
print(f"\nNumeric Statistics:")
|
||||
print(f" Sum: {series.sum():,.2f}")
|
||||
print(f" Average: {series.mean():,.2f}")
|
||||
print(f" Median: {series.median():,.2f}")
|
||||
print(f" Min: {series.min():,.2f}")
|
||||
print(f" Max: {series.max():,.2f}")
|
||||
print(f" Std Dev: {series.std():,.2f}")
|
||||
else:
|
||||
print(f"\nMost common values:")
|
||||
for val, count in series.value_counts().head(10).items():
|
||||
pct = count / len(series) * 100
|
||||
print(f" {val}: {count:,} ({pct:.1f}%)")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def cmd_duplicates(df, args):
|
||||
"""Find duplicate rows."""
|
||||
col = args.column
|
||||
|
||||
if col:
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
dups = df[df.duplicated(subset=[col], keep=False)]
|
||||
print(f"\nDUPLICATES IN COLUMN: {col}")
|
||||
else:
|
||||
dups = df[df.duplicated(keep=False)]
|
||||
print(f"\nDUPLICATE ROWS (all columns)")
|
||||
|
||||
print("=" * 40)
|
||||
|
||||
if len(dups) == 0:
|
||||
print("No duplicates found!")
|
||||
else:
|
||||
print(f"Found {len(dups):,} duplicate rows")
|
||||
print("\nDuplicate entries:")
|
||||
print(dups.to_string())
|
||||
|
||||
return dups
|
||||
|
||||
|
||||
def cmd_filter(df, args):
|
||||
"""Filter rows based on condition."""
|
||||
if not args.column or not args.operator or args.value is None:
|
||||
print("Error: Filter requires column, operator, and value")
|
||||
print("Usage: filter <column> <operator> <value>")
|
||||
print("Operators: equals, not_equals, contains, greater, less")
|
||||
sys.exit(1)
|
||||
|
||||
col = args.column
|
||||
op = args.operator.lower()
|
||||
val = args.value
|
||||
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
original_count = len(df)
|
||||
|
||||
if op == 'equals':
|
||||
if val == '':
|
||||
result = df[df[col].isna() | (df[col] == '')]
|
||||
else:
|
||||
# Try numeric comparison if possible
|
||||
try:
|
||||
result = df[df[col] == float(val)]
|
||||
except:
|
||||
result = df[df[col].astype(str).str.lower() == val.lower()]
|
||||
elif op == 'not_equals':
|
||||
try:
|
||||
result = df[df[col] != float(val)]
|
||||
except:
|
||||
result = df[df[col].astype(str).str.lower() != val.lower()]
|
||||
elif op == 'contains':
|
||||
result = df[df[col].astype(str).str.lower().str.contains(val.lower(), na=False)]
|
||||
elif op == 'greater':
|
||||
try:
|
||||
result = df[pd.to_numeric(df[col], errors='coerce') > float(val)]
|
||||
except:
|
||||
print(f"Error: Cannot compare '{col}' as numbers")
|
||||
sys.exit(1)
|
||||
elif op == 'less':
|
||||
try:
|
||||
result = df[pd.to_numeric(df[col], errors='coerce') < float(val)]
|
||||
except:
|
||||
print(f"Error: Cannot compare '{col}' as numbers")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"Error: Unknown operator '{op}'")
|
||||
print("Valid operators: equals, not_equals, contains, greater, less")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\nFILTER: {col} {op} '{val}'")
|
||||
print("=" * 40)
|
||||
print(f"Found {len(result):,} matching rows (out of {original_count:,})")
|
||||
|
||||
if len(result) > 0:
|
||||
print("\nResults:")
|
||||
if len(result) > 50:
|
||||
print(result.head(50).to_string())
|
||||
print(f"\n... and {len(result) - 50} more rows")
|
||||
else:
|
||||
print(result.to_string())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def cmd_sort(df, args):
|
||||
"""Sort data by column."""
|
||||
if not args.column:
|
||||
print("Error: Please specify a column to sort by")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
col = args.column
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
ascending = args.order != 'desc'
|
||||
result = df.sort_values(by=col, ascending=ascending)
|
||||
|
||||
order_label = "ascending" if ascending else "descending"
|
||||
print(f"\nSORTED BY: {col} ({order_label})")
|
||||
print("=" * 40)
|
||||
|
||||
if len(result) > 50:
|
||||
print(result.head(50).to_string())
|
||||
print(f"\n... and {len(result) - 50} more rows")
|
||||
else:
|
||||
print(result.to_string())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def cmd_count(df, args):
|
||||
"""Count values in a column."""
|
||||
if not args.column:
|
||||
print("Error: Please specify a column to count")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
col = args.column
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
counts = df[col].value_counts()
|
||||
|
||||
print(f"\nVALUE COUNTS FOR: {col}")
|
||||
print("=" * 40)
|
||||
print(f"Total unique values: {len(counts):,}")
|
||||
print()
|
||||
|
||||
for val, count in counts.items():
|
||||
pct = count / len(df) * 100
|
||||
print(f" {val}: {count:,} ({pct:.1f}%)")
|
||||
|
||||
# Return as DataFrame for potential export
|
||||
return counts.reset_index().rename(columns={'index': col, col: 'count'})
|
||||
|
||||
|
||||
def cmd_top(df, args):
|
||||
"""Get top N rows by column value."""
|
||||
if not args.column:
|
||||
print("Error: Please specify a column")
|
||||
sys.exit(1)
|
||||
|
||||
col = args.column
|
||||
# Number can be in args.operator position due to positional parsing
|
||||
n = int(args.number) if args.number else (int(args.operator) if args.operator and args.operator.isdigit() else 10)
|
||||
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
result = df.nlargest(n, col)
|
||||
|
||||
print(f"\nTOP {n} BY: {col}")
|
||||
print("=" * 40)
|
||||
print(result.to_string())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def cmd_bottom(df, args):
|
||||
"""Get bottom N rows by column value."""
|
||||
if not args.column:
|
||||
print("Error: Please specify a column")
|
||||
sys.exit(1)
|
||||
|
||||
col = args.column
|
||||
# Number can be in args.operator position due to positional parsing
|
||||
n = int(args.number) if args.number else (int(args.operator) if args.operator and args.operator.isdigit() else 10)
|
||||
|
||||
if col not in df.columns:
|
||||
print(f"Error: Column '{col}' not found")
|
||||
print(f"Available columns: {', '.join(df.columns)}")
|
||||
sys.exit(1)
|
||||
|
||||
result = df.nsmallest(n, col)
|
||||
|
||||
print(f"\nBOTTOM {n} BY: {col}")
|
||||
print("=" * 40)
|
||||
print(result.to_string())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def cmd_missing(df, args):
|
||||
"""Find rows with missing values."""
|
||||
print("\nMISSING VALUE ANALYSIS")
|
||||
print("=" * 40)
|
||||
|
||||
# Summary by column
|
||||
print("\nMissing values per column:")
|
||||
for col in df.columns:
|
||||
missing = df[col].isna().sum()
|
||||
if missing > 0:
|
||||
pct = missing / len(df) * 100
|
||||
print(f" {col}: {missing:,} ({pct:.1f}%)")
|
||||
|
||||
total_missing = df.isna().sum().sum()
|
||||
if total_missing == 0:
|
||||
print(" No missing values found!")
|
||||
return df
|
||||
|
||||
# Rows with any missing values
|
||||
rows_with_missing = df[df.isna().any(axis=1)]
|
||||
print(f"\nRows with missing values: {len(rows_with_missing):,}")
|
||||
|
||||
if len(rows_with_missing) > 0 and len(rows_with_missing) <= 50:
|
||||
print("\nRows with missing data:")
|
||||
print(rows_with_missing.to_string())
|
||||
elif len(rows_with_missing) > 50:
|
||||
print("\nFirst 50 rows with missing data:")
|
||||
print(rows_with_missing.head(50).to_string())
|
||||
print(f"\n... and {len(rows_with_missing) - 50} more rows")
|
||||
|
||||
return rows_with_missing
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Process Excel/CSV data')
|
||||
parser.add_argument('filepath', help='Path to Excel or CSV file')
|
||||
parser.add_argument('command', choices=['summary', 'stats', 'duplicates', 'filter', 'sort', 'count', 'top', 'bottom', 'missing'],
|
||||
help='Command to run')
|
||||
parser.add_argument('column', nargs='?', help='Column name (for stats, filter, sort, count, top, bottom, duplicates)')
|
||||
parser.add_argument('operator', nargs='?', help='Operator for filter (equals, contains, greater, less, not_equals)')
|
||||
parser.add_argument('value', nargs='?', help='Value for filter')
|
||||
parser.add_argument('number', nargs='?', help='Number for top/bottom')
|
||||
parser.add_argument('--order', choices=['asc', 'desc'], default='asc', help='Sort order')
|
||||
parser.add_argument('--output', '-o', help='Output file path')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the file
|
||||
df = load_file(args.filepath)
|
||||
|
||||
# Run the command
|
||||
commands = {
|
||||
'summary': cmd_summary,
|
||||
'stats': cmd_stats,
|
||||
'duplicates': cmd_duplicates,
|
||||
'filter': cmd_filter,
|
||||
'sort': cmd_sort,
|
||||
'count': cmd_count,
|
||||
'top': cmd_top,
|
||||
'bottom': cmd_bottom,
|
||||
'missing': cmd_missing,
|
||||
}
|
||||
|
||||
result = commands[args.command](df, args)
|
||||
|
||||
# Save output if requested
|
||||
if args.output and isinstance(result, pd.DataFrame):
|
||||
save_output(result, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
25
skills/hello-skill/SKILL.md
Normal file
25
skills/hello-skill/SKILL.md
Normal file
@@ -0,0 +1,25 @@
|
||||
---
|
||||
name: hello-skill
|
||||
description: Simple test skill for verifying Agent Skills integration in ollama run. Use when the user asks to test skills, sample skills, or wants a quick hello workflow.
|
||||
---
|
||||
|
||||
# Hello Skill
|
||||
|
||||
## Purpose
|
||||
|
||||
This is a minimal skill to validate that skills load correctly and that tool calls can read additional files.
|
||||
|
||||
## When to use
|
||||
|
||||
- The user asks to test skills integration.
|
||||
- The user wants a simple example skill.
|
||||
|
||||
## Instructions
|
||||
|
||||
1. Reply with a short greeting that mentions the skill name.
|
||||
2. If you need a template greeting, read `references/GREETING.md` using the `read_skill_file` tool.
|
||||
|
||||
## Example
|
||||
|
||||
User: "Test the skills feature."
|
||||
Assistant: "Hello from hello-skill."
|
||||
2
skills/hello-skill/references/GREETING.md
Normal file
2
skills/hello-skill/references/GREETING.md
Normal file
@@ -0,0 +1,2 @@
|
||||
Template greeting:
|
||||
Hello from hello-skill. Skills are working.
|
||||
8
skills/math-agent.Agentfile
Normal file
8
skills/math-agent.Agentfile
Normal file
@@ -0,0 +1,8 @@
|
||||
FROM gpt-oss:20b
|
||||
|
||||
AGENT_TYPE conversational
|
||||
SKILL /Users/parth/Documents/repos/ollama/skills
|
||||
SYSTEM You are a helpful math assistant. Follow the instructions from your loaded skills when performing tasks.
|
||||
|
||||
PARAMETER temperature 0.3
|
||||
PARAMETER top_p 0.9
|
||||
7
skills/mcp-agent.Agentfile
Normal file
7
skills/mcp-agent.Agentfile
Normal file
@@ -0,0 +1,7 @@
|
||||
FROM gpt-oss:20b
|
||||
AGENT TYPE conversational
|
||||
SYSTEM You are a helpful assistant with MCP tools. You can echo text and add numbers using the mcp_test-mcp_echo and mcp_test-mcp_add tools.
|
||||
MCP test-mcp python3 ./test-mcp/server.py
|
||||
SKILL ./skills/excel-skill
|
||||
SKILL ./skills/pdf-skill
|
||||
|
||||
36
skills/mock-logs-skill/SKILL.md
Normal file
36
skills/mock-logs-skill/SKILL.md
Normal file
@@ -0,0 +1,36 @@
|
||||
---
|
||||
name: mock-logs
|
||||
description: Outputs mock log entries for testing and demonstration purposes
|
||||
---
|
||||
|
||||
# Mock Logs Skill
|
||||
|
||||
## Purpose
|
||||
|
||||
This skill generates mock log entries for testing, debugging, and demonstration purposes.
|
||||
|
||||
## When to use
|
||||
|
||||
- User asks to generate sample logs
|
||||
- User wants to see example log output
|
||||
- User needs test data for log parsing
|
||||
- User asks about log formats
|
||||
|
||||
## Instructions
|
||||
|
||||
1. When the user asks for mock logs, use the `run_skill_script` tool
|
||||
2. Call: `python3 scripts/generate_logs.py [count] [level]`
|
||||
- count: Number of log entries (default: 5)
|
||||
- level: Log level filter - info, warn, error, debug, or all (default: all)
|
||||
3. Return the generated logs to the user
|
||||
|
||||
## Examples
|
||||
|
||||
For "Generate some sample logs":
|
||||
- Call: `run_skill_script` with skill="mock-logs" and command="python3 scripts/generate_logs.py 5"
|
||||
|
||||
For "Show me 10 error logs":
|
||||
- Call: `run_skill_script` with skill="mock-logs" and command="python3 scripts/generate_logs.py 10 error"
|
||||
|
||||
For "Generate debug logs":
|
||||
- Call: `run_skill_script` with skill="mock-logs" and command="python3 scripts/generate_logs.py 5 debug"
|
||||
107
skills/mock-logs-skill/scripts/generate_logs.py
Normal file
107
skills/mock-logs-skill/scripts/generate_logs.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate mock log entries for testing."""
|
||||
|
||||
import sys
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
LEVELS = ["INFO", "WARN", "ERROR", "DEBUG"]
|
||||
|
||||
SERVICES = [
|
||||
"api-gateway",
|
||||
"auth-service",
|
||||
"user-service",
|
||||
"payment-service",
|
||||
"notification-service",
|
||||
"cache-manager",
|
||||
"db-connector",
|
||||
"queue-worker",
|
||||
]
|
||||
|
||||
MESSAGES = {
|
||||
"INFO": [
|
||||
"Request processed successfully",
|
||||
"User session started",
|
||||
"Cache hit for key: user_{}",
|
||||
"Connection established to database",
|
||||
"Health check passed",
|
||||
"Configuration reloaded",
|
||||
"Scheduled task completed",
|
||||
"Message published to queue",
|
||||
],
|
||||
"WARN": [
|
||||
"High memory usage detected: {}%",
|
||||
"Slow query detected: {}ms",
|
||||
"Rate limit approaching for client {}",
|
||||
"Retry attempt {} of 3",
|
||||
"Connection pool running low",
|
||||
"Deprecated API endpoint called",
|
||||
"Certificate expires in {} days",
|
||||
],
|
||||
"ERROR": [
|
||||
"Failed to connect to database: timeout",
|
||||
"Authentication failed for user {}",
|
||||
"Payment processing error: insufficient funds",
|
||||
"Service unavailable: upstream timeout",
|
||||
"Invalid request payload",
|
||||
"Queue message processing failed",
|
||||
"Disk space critical: {}% used",
|
||||
],
|
||||
"DEBUG": [
|
||||
"Entering function: process_request",
|
||||
"Variable state: count={}",
|
||||
"SQL query: SELECT * FROM users WHERE id={}",
|
||||
"HTTP response: status={}, body_size={}",
|
||||
"Cache miss for key: session_{}",
|
||||
"Decoding JWT token",
|
||||
"Validating input parameters",
|
||||
],
|
||||
}
|
||||
|
||||
def generate_log_entry(level=None, base_time=None):
|
||||
if level is None:
|
||||
level = random.choice(LEVELS)
|
||||
|
||||
service = random.choice(SERVICES)
|
||||
message_template = random.choice(MESSAGES[level])
|
||||
|
||||
# Fill in placeholders with random values
|
||||
message = message_template
|
||||
while "{}" in message:
|
||||
placeholder_value = random.randint(1, 9999)
|
||||
message = message.replace("{}", str(placeholder_value), 1)
|
||||
|
||||
if base_time is None:
|
||||
base_time = datetime.now()
|
||||
|
||||
timestamp = base_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
||||
|
||||
return f"[{timestamp}] [{level:5}] [{service}] {message}"
|
||||
|
||||
def main():
|
||||
count = 5
|
||||
level_filter = None
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
count = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print(f"Error: Invalid count '{sys.argv[1]}'", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if len(sys.argv) > 2:
|
||||
level_arg = sys.argv[2].upper()
|
||||
if level_arg != "ALL" and level_arg in LEVELS:
|
||||
level_filter = level_arg
|
||||
elif level_arg != "ALL":
|
||||
print(f"Error: Invalid level '{sys.argv[2]}'. Use: info, warn, error, debug, or all", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
base_time = datetime.now() - timedelta(seconds=count)
|
||||
|
||||
for i in range(count):
|
||||
log_time = base_time + timedelta(seconds=i, milliseconds=random.randint(0, 999))
|
||||
print(generate_log_entry(level=level_filter, base_time=log_time))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109
skills/pdf-skill/SKILL.md
Normal file
109
skills/pdf-skill/SKILL.md
Normal file
@@ -0,0 +1,109 @@
|
||||
---
|
||||
name: pdf-skill
|
||||
description: Help users work with PDF files - extract text, get document info, search content, extract pages, and merge PDFs. Use when the user mentions PDF, document extraction, or wants to read/combine PDF files.
|
||||
---
|
||||
|
||||
# PDF Processing Skill
|
||||
|
||||
## Purpose
|
||||
|
||||
This skill helps users work with PDF files without needing technical knowledge. It can extract text, search for content, get document information, split and merge PDFs.
|
||||
|
||||
## When to use
|
||||
|
||||
- User uploads or mentions a PDF file
|
||||
- User wants to extract text from a document
|
||||
- User asks "what's in this PDF" or similar
|
||||
- User wants to search for something in a PDF
|
||||
- User wants to combine or split PDF files
|
||||
- User asks about page counts or document info
|
||||
|
||||
## Instructions
|
||||
|
||||
### Step 1: Understand the document first
|
||||
|
||||
When a user provides a PDF, start by getting info about it:
|
||||
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" info
|
||||
```
|
||||
|
||||
This shows:
|
||||
- Number of pages
|
||||
- Document metadata (title, author, etc.)
|
||||
- File size
|
||||
|
||||
### Step 2: Perform the requested operation
|
||||
|
||||
Based on what the user asks, use the appropriate command:
|
||||
|
||||
**Extract all text:**
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" text
|
||||
```
|
||||
Extracts text from all pages.
|
||||
|
||||
**Extract text from specific pages:**
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" text --pages 1,2,3
|
||||
uv run scripts/process_pdf.py "<filepath>" text --pages 1-5
|
||||
```
|
||||
|
||||
**Search for text:**
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" search "<query>"
|
||||
```
|
||||
Finds all occurrences and shows surrounding context.
|
||||
|
||||
**Extract tables:**
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" tables
|
||||
```
|
||||
Attempts to extract tables from the PDF as CSV format.
|
||||
|
||||
**Extract specific pages to new PDF:**
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" split --pages 1-3 --output "extracted.pdf"
|
||||
```
|
||||
|
||||
**Merge multiple PDFs:**
|
||||
```
|
||||
uv run scripts/process_pdf.py merge "<file1.pdf>" "<file2.pdf>" --output "combined.pdf"
|
||||
```
|
||||
|
||||
**Get word/character count:**
|
||||
```
|
||||
uv run scripts/process_pdf.py "<filepath>" count
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
**User: "What's in this PDF?"**
|
||||
Run: `uv run scripts/process_pdf.py "document.pdf" info`
|
||||
Then: `uv run scripts/process_pdf.py "document.pdf" text --pages 1` (for first page preview)
|
||||
|
||||
**User: "Extract the text from this document"**
|
||||
Run: `uv run scripts/process_pdf.py "document.pdf" text`
|
||||
|
||||
**User: "Find all mentions of 'invoice' in this PDF"**
|
||||
Run: `uv run scripts/process_pdf.py "document.pdf" search "invoice"`
|
||||
|
||||
**User: "How many pages is this?"**
|
||||
Run: `uv run scripts/process_pdf.py "document.pdf" info`
|
||||
|
||||
**User: "Get me just pages 5-10"**
|
||||
Run: `uv run scripts/process_pdf.py "document.pdf" split --pages 5-10 --output "pages_5_10.pdf"`
|
||||
|
||||
**User: "Combine these two PDFs"**
|
||||
Run: `uv run scripts/process_pdf.py merge "doc1.pdf" "doc2.pdf" --output "combined.pdf"`
|
||||
|
||||
**User: "Are there any tables in this PDF?"**
|
||||
Run: `uv run scripts/process_pdf.py "document.pdf" tables`
|
||||
|
||||
## Tips for helping non-technical users
|
||||
|
||||
1. Always start with `info` to understand what you're working with
|
||||
2. For long documents, extract just the first page first to preview
|
||||
3. If text extraction looks garbled, the PDF might be scanned images (OCR needed)
|
||||
4. Explain what you found in plain language
|
||||
5. If tables don't extract well, mention that PDF tables can be tricky
|
||||
114
skills/pdf-skill/sample_invoice.pdf
Normal file
114
skills/pdf-skill/sample_invoice.pdf
Normal file
@@ -0,0 +1,114 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Count 2
|
||||
/Kids [3 0 R
|
||||
5 0 R]
|
||||
/MediaBox [0 0 595.28 841.89]
|
||||
/Type /Pages
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/OpenAction [3 0 R /FitH null]
|
||||
/PageLayout /OneColumn
|
||||
/Pages 1 0 R
|
||||
/Type /Catalog
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Contents 4 0 R
|
||||
/Parent 1 0 R
|
||||
/Resources 9 0 R
|
||||
/Type /Page
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Filter /FlateDecode
|
||||
/Length 442
|
||||
>>
|
||||
stream
|
||||
x<EFBFBD>}<7D><>N<EFBFBD>0<10><><<3C>H<EFBFBD><48>"-<2D>3<EFBFBD><33>vo<76>n%8<><38><EFBFBD>9<02><>h<EFBFBD><68><EFBFBD>o<><6F><EFBFBD>P<EFBFBD>j}<1D><>e<EFBFBD>yl<79><6C>)<29><><EFBFBD><EFBFBD><EFBFBD><1A>W<08><><EFBFBD>P?ßz(JA<06><><EFBFBD><EFBFBD>?<3F><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>g<EFBFBD>~<7E><>P<EFBFBD>N<EFBFBD><4E>+<02><><EFBFBD>B<EFBFBD>L`<15>S
|
||||
~ )<29>Ix<49><78>, ec<65><63><EFBFBD><EFBFBD>K<EFBFBD>i<EFBFBD>M<>ȿ<>$Hp<48>W<EFBFBD><57><EFBFBD><EFBFBD>߾<EFBFBD><DFBE>n <17>Wm<57><6D>vM<76>n<EFBFBD>{C<>ސd<DE90>b<EFBFBD><62>U<>8<1B><><EFBFBD>{<7B><>;3<14>O<EFBFBD>~3#<23>crX<1E><><EFBFBD>d<EFBFBD><64>t
|
||||
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>R`5<><11>P.hB<68><42>p<>%<25>5+<2B><><EFBFBD>`<60>B<EFBFBD>Q<>}<7D>V8<56>1<EFBFBD>6)<29><>ܼ<0F><><17><>E&g<><1C><><EFBFBD>P2<50>p<EFBFBD>fJQ<><51><EFBFBD>1<EFBFBD>/<2F> <09>27ES\ϔ<>:@<40><07>T<EFBFBD>U<><55>+NU\<5C>*FHQ<48><0E>rc<><63>b_<>C*J<>YE*L<>BI!mYE<59>%<25><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ί<EFBFBD><CEAF>m<EFBFBD><6D><EFBFBD><EFBFBD>sc<><63>
|
||||
6]<5D><><05><>/s%Ȭ<>(<28><><EFBFBD>O<>?<3F><>B<19><>l<EFBFBD>$5<><35><EFBFBD><EFBFBD>rA4<41>=<3D><01>z:<3A>
|
||||
endstream
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Contents 6 0 R
|
||||
/Parent 1 0 R
|
||||
/Resources 10 0 R
|
||||
/Type /Page
|
||||
>>
|
||||
endobj
|
||||
6 0 obj
|
||||
<<
|
||||
/Filter /FlateDecode
|
||||
/Length 306
|
||||
>>
|
||||
stream
|
||||
x<EFBFBD>m<EFBFBD><EFBFBD>N<EFBFBD>0<14><>><3E>a@ <09>8<EFBFBD>M<EFBFBD><4D>#: <06><><02>s<EFBFBD>;<3B>%o<><6F>P<EFBFBD><50><EFBFBD><EFBFBD><EFBFBD>w<EFBFBD>w<EFBFBD>/<2F><><EFBFBD><12><>c<EFBFBD>嚃<EFBFBD>X<EFBFBD><58><EFBFBD><EFBFBD><EFBFBD>Rg<>BY<42><59><EFBFBD>B<EFBFBD><42><EFBFBD>!<21>H<><48><EFBFBD><EFBFBD>VG<56>l<EFBFBD>F<EFBFBD><46>\<5C><><EFBFBD><EFBFBD>DQ<44>zu<7A>x<EFBFBD>SO6B<07>#a<><61>N[Z9<05><0E>~;<3B>(ő^Ӊa<D389> <09><><EFBFBD><EFBFBD>j<EFBFBD><6A><EFBFBD>L\<5C>w6<77>̄<EFBFBD><CC84><EFBFBD><03>]Wu<57><75><19><>i(<28>{B<><42>4<0C>GoSJ)"<22>'~7<>B8<42><38><EFBFBD>vxR<78><52>x<EFBFBD>FTL<54>G<EFBFBD><47><EFBFBD><EFBFBD>)r<>ƈ$NK<4E>`0$A%n<>"6<>m<><6D><EFBFBD>#<23>Tb<54><62><EFBFBD>J&^!<21>m<EFBFBD><6D><EFBFBD><EFBFBD>#<23>X<EFBFBD><58>?<3F><><EFBFBD><13>ѥ<>t{lƠ-<2D>pq<70><71><05><1C>|
|
||||
endstream
|
||||
endobj
|
||||
7 0 obj
|
||||
<<
|
||||
/BaseFont /Helvetica-Bold
|
||||
/Encoding /WinAnsiEncoding
|
||||
/Subtype /Type1
|
||||
/Type /Font
|
||||
>>
|
||||
endobj
|
||||
8 0 obj
|
||||
<<
|
||||
/BaseFont /Helvetica
|
||||
/Encoding /WinAnsiEncoding
|
||||
/Subtype /Type1
|
||||
/Type /Font
|
||||
>>
|
||||
endobj
|
||||
9 0 obj
|
||||
<<
|
||||
/Font <</F1 7 0 R
|
||||
/F2 8 0 R>>
|
||||
/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]
|
||||
>>
|
||||
endobj
|
||||
10 0 obj
|
||||
<<
|
||||
/Font <</F1 7 0 R
|
||||
/F2 8 0 R>>
|
||||
/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]
|
||||
>>
|
||||
endobj
|
||||
11 0 obj
|
||||
<<
|
||||
/CreationDate (D:20251230034342Z)
|
||||
>>
|
||||
endobj
|
||||
xref
|
||||
0 12
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000108 00000 n
|
||||
0000000211 00000 n
|
||||
0000000291 00000 n
|
||||
0000000805 00000 n
|
||||
0000000886 00000 n
|
||||
0000001264 00000 n
|
||||
0000001366 00000 n
|
||||
0000001463 00000 n
|
||||
0000001560 00000 n
|
||||
0000001658 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 12
|
||||
/Root 2 0 R
|
||||
/Info 11 0 R
|
||||
/ID [<2B10F02FFCC93A7FD39B360714BACC88><2B10F02FFCC93A7FD39B360714BACC88>]
|
||||
>>
|
||||
startxref
|
||||
1714
|
||||
367
skills/pdf-skill/scripts/process_pdf.py
Normal file
367
skills/pdf-skill/scripts/process_pdf.py
Normal file
@@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "pypdf",
|
||||
# "pdfplumber",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
PDF Processing Script for non-technical users.
|
||||
Handles common PDF operations: info, text extraction, search, split, merge.
|
||||
|
||||
Usage: uv run scripts/process_pdf.py <filepath> <command> [args...] [--output <output_path>]
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_pdf_pypdf(filepath):
|
||||
"""Load PDF using pypdf."""
|
||||
from pypdf import PdfReader
|
||||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
print(f"Error: File not found: {filepath}")
|
||||
sys.exit(1)
|
||||
try:
|
||||
return PdfReader(filepath)
|
||||
except Exception as e:
|
||||
print(f"Error reading PDF: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_pdf_plumber(filepath):
|
||||
"""Load PDF using pdfplumber (better for text/tables)."""
|
||||
import pdfplumber
|
||||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
print(f"Error: File not found: {filepath}")
|
||||
sys.exit(1)
|
||||
try:
|
||||
return pdfplumber.open(filepath)
|
||||
except Exception as e:
|
||||
print(f"Error reading PDF: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse_page_range(pages_str, max_pages):
|
||||
"""Parse page range string like '1,2,3' or '1-5' or '1,3-5,7'."""
|
||||
if not pages_str:
|
||||
return list(range(1, max_pages + 1))
|
||||
|
||||
pages = set()
|
||||
parts = pages_str.split(',')
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if '-' in part:
|
||||
start, end = part.split('-', 1)
|
||||
start = int(start.strip())
|
||||
end = int(end.strip())
|
||||
pages.update(range(start, end + 1))
|
||||
else:
|
||||
pages.add(int(part))
|
||||
|
||||
# Filter to valid range and sort
|
||||
valid_pages = sorted([p for p in pages if 1 <= p <= max_pages])
|
||||
return valid_pages
|
||||
|
||||
|
||||
def cmd_info(args):
|
||||
"""Show PDF information."""
|
||||
reader = load_pdf_pypdf(args.filepath)
|
||||
|
||||
print("=" * 60)
|
||||
print("PDF INFORMATION")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nFile: {args.filepath}")
|
||||
print(f"Pages: {len(reader.pages)}")
|
||||
|
||||
# File size
|
||||
path = Path(args.filepath)
|
||||
size_bytes = path.stat().st_size
|
||||
if size_bytes < 1024:
|
||||
size_str = f"{size_bytes} bytes"
|
||||
elif size_bytes < 1024 * 1024:
|
||||
size_str = f"{size_bytes / 1024:.1f} KB"
|
||||
else:
|
||||
size_str = f"{size_bytes / (1024 * 1024):.1f} MB"
|
||||
print(f"Size: {size_str}")
|
||||
|
||||
# Metadata
|
||||
meta = reader.metadata
|
||||
if meta:
|
||||
print("\n" + "-" * 40)
|
||||
print("METADATA:")
|
||||
print("-" * 40)
|
||||
if meta.title:
|
||||
print(f" Title: {meta.title}")
|
||||
if meta.author:
|
||||
print(f" Author: {meta.author}")
|
||||
if meta.subject:
|
||||
print(f" Subject: {meta.subject}")
|
||||
if meta.creator:
|
||||
print(f" Creator: {meta.creator}")
|
||||
if meta.creation_date:
|
||||
print(f" Created: {meta.creation_date}")
|
||||
if meta.modification_date:
|
||||
print(f" Modified: {meta.modification_date}")
|
||||
|
||||
|
||||
def cmd_text(args):
|
||||
"""Extract text from PDF."""
|
||||
pdf = load_pdf_plumber(args.filepath)
|
||||
|
||||
pages = parse_page_range(args.pages, len(pdf.pages))
|
||||
|
||||
print("=" * 60)
|
||||
if args.pages:
|
||||
print(f"TEXT EXTRACTION (pages {args.pages})")
|
||||
else:
|
||||
print("TEXT EXTRACTION (all pages)")
|
||||
print("=" * 60)
|
||||
|
||||
for page_num in pages:
|
||||
page = pdf.pages[page_num - 1] # 0-indexed
|
||||
text = page.extract_text() or ""
|
||||
|
||||
print(f"\n--- Page {page_num} ---\n")
|
||||
if text.strip():
|
||||
print(text)
|
||||
else:
|
||||
print("(No text found on this page - may be an image or scan)")
|
||||
|
||||
pdf.close()
|
||||
|
||||
|
||||
def cmd_search(args):
|
||||
"""Search for text in PDF."""
|
||||
if not args.query:
|
||||
print("Error: Please provide a search query")
|
||||
sys.exit(1)
|
||||
|
||||
pdf = load_pdf_plumber(args.filepath)
|
||||
query = args.query.lower()
|
||||
|
||||
print("=" * 60)
|
||||
print(f"SEARCH RESULTS: '{args.query}'")
|
||||
print("=" * 60)
|
||||
|
||||
total_matches = 0
|
||||
|
||||
for i, page in enumerate(pdf.pages):
|
||||
page_num = i + 1
|
||||
text = page.extract_text() or ""
|
||||
|
||||
# Find matches with context
|
||||
text_lower = text.lower()
|
||||
if query in text_lower:
|
||||
# Count occurrences
|
||||
count = text_lower.count(query)
|
||||
total_matches += count
|
||||
|
||||
print(f"\n--- Page {page_num} ({count} match{'es' if count > 1 else ''}) ---")
|
||||
|
||||
# Show context around each match
|
||||
lines = text.split('\n')
|
||||
for j, line in enumerate(lines):
|
||||
if query in line.lower():
|
||||
# Highlight the match (uppercase)
|
||||
highlighted = re.sub(
|
||||
f'({re.escape(args.query)})',
|
||||
r'>>>\1<<<',
|
||||
line,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
print(f" {highlighted}")
|
||||
|
||||
print(f"\n{'=' * 40}")
|
||||
if total_matches == 0:
|
||||
print(f"No matches found for '{args.query}'")
|
||||
else:
|
||||
print(f"Total: {total_matches} match{'es' if total_matches > 1 else ''} found")
|
||||
|
||||
pdf.close()
|
||||
|
||||
|
||||
def cmd_tables(args):
|
||||
"""Extract tables from PDF."""
|
||||
pdf = load_pdf_plumber(args.filepath)
|
||||
|
||||
print("=" * 60)
|
||||
print("TABLE EXTRACTION")
|
||||
print("=" * 60)
|
||||
|
||||
table_count = 0
|
||||
|
||||
for i, page in enumerate(pdf.pages):
|
||||
page_num = i + 1
|
||||
tables = page.extract_tables()
|
||||
|
||||
if tables:
|
||||
for j, table in enumerate(tables):
|
||||
table_count += 1
|
||||
print(f"\n--- Table {table_count} (Page {page_num}) ---\n")
|
||||
|
||||
# Print as CSV-like format
|
||||
for row in table:
|
||||
# Clean up None values
|
||||
cleaned = [str(cell).strip() if cell else "" for cell in row]
|
||||
print(",".join(cleaned))
|
||||
|
||||
if table_count == 0:
|
||||
print("\nNo tables found in this PDF.")
|
||||
print("Note: Table extraction works best with clearly structured tables.")
|
||||
else:
|
||||
print(f"\n{'=' * 40}")
|
||||
print(f"Total: {table_count} table{'s' if table_count > 1 else ''} found")
|
||||
|
||||
pdf.close()
|
||||
|
||||
|
||||
def cmd_count(args):
|
||||
"""Count words and characters in PDF."""
|
||||
pdf = load_pdf_plumber(args.filepath)
|
||||
|
||||
total_chars = 0
|
||||
total_words = 0
|
||||
page_stats = []
|
||||
|
||||
for i, page in enumerate(pdf.pages):
|
||||
text = page.extract_text() or ""
|
||||
chars = len(text)
|
||||
words = len(text.split())
|
||||
total_chars += chars
|
||||
total_words += words
|
||||
page_stats.append((i + 1, words, chars))
|
||||
|
||||
print("=" * 60)
|
||||
print("DOCUMENT STATISTICS")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nTotal pages: {len(pdf.pages)}")
|
||||
print(f"Total words: {total_words:,}")
|
||||
print(f"Total characters: {total_chars:,}")
|
||||
|
||||
if len(pdf.pages) > 1:
|
||||
print(f"\nAverage words per page: {total_words // len(pdf.pages):,}")
|
||||
|
||||
print("\n" + "-" * 40)
|
||||
print("PER-PAGE BREAKDOWN:")
|
||||
print("-" * 40)
|
||||
for page_num, words, chars in page_stats:
|
||||
print(f" Page {page_num}: {words:,} words, {chars:,} chars")
|
||||
|
||||
pdf.close()
|
||||
|
||||
|
||||
def cmd_split(args):
|
||||
"""Extract specific pages to a new PDF."""
|
||||
from pypdf import PdfReader, PdfWriter
|
||||
|
||||
if not args.output:
|
||||
print("Error: Please specify output file with --output")
|
||||
sys.exit(1)
|
||||
|
||||
reader = load_pdf_pypdf(args.filepath)
|
||||
pages = parse_page_range(args.pages, len(reader.pages))
|
||||
|
||||
if not pages:
|
||||
print("Error: No valid pages specified")
|
||||
sys.exit(1)
|
||||
|
||||
writer = PdfWriter()
|
||||
|
||||
for page_num in pages:
|
||||
writer.add_page(reader.pages[page_num - 1])
|
||||
|
||||
with open(args.output, 'wb') as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"Extracted {len(pages)} page(s) to: {args.output}")
|
||||
print(f"Pages included: {', '.join(map(str, pages))}")
|
||||
|
||||
|
||||
def cmd_merge(args):
|
||||
"""Merge multiple PDFs into one."""
|
||||
from pypdf import PdfReader, PdfWriter
|
||||
|
||||
if not args.output:
|
||||
print("Error: Please specify output file with --output")
|
||||
sys.exit(1)
|
||||
|
||||
# Collect all input files
|
||||
files = [args.filepath]
|
||||
if args.query:
|
||||
files.append(args.query)
|
||||
if args.pages:
|
||||
files.append(args.pages)
|
||||
# Check for additional files in remaining args
|
||||
|
||||
# Validate all files exist
|
||||
for f in files:
|
||||
if not Path(f).exists():
|
||||
print(f"Error: File not found: {f}")
|
||||
sys.exit(1)
|
||||
|
||||
writer = PdfWriter()
|
||||
total_pages = 0
|
||||
|
||||
for filepath in files:
|
||||
reader = PdfReader(filepath)
|
||||
for page in reader.pages:
|
||||
writer.add_page(page)
|
||||
total_pages += 1
|
||||
print(f" Added: {filepath} ({len(reader.pages)} pages)")
|
||||
|
||||
with open(args.output, 'wb') as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"\nMerged {len(files)} files ({total_pages} total pages) to: {args.output}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Process PDF files')
|
||||
parser.add_argument('filepath', help='Path to PDF file (or "merge" command)')
|
||||
parser.add_argument('command', nargs='?', default='info',
|
||||
help='Command: info, text, search, tables, count, split, merge')
|
||||
parser.add_argument('query', nargs='?', help='Search query or second file for merge')
|
||||
parser.add_argument('--pages', '-p', help='Page range (e.g., "1-3" or "1,2,5")')
|
||||
parser.add_argument('--output', '-o', help='Output file path')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle merge as special case (first arg is "merge")
|
||||
if args.filepath == 'merge':
|
||||
if not args.command:
|
||||
print("Error: merge requires at least 2 PDF files")
|
||||
print("Usage: process_pdf.py merge file1.pdf file2.pdf --output combined.pdf")
|
||||
sys.exit(1)
|
||||
# Shift args for merge
|
||||
args.filepath = args.command
|
||||
args.command = 'merge'
|
||||
|
||||
# Run the command
|
||||
commands = {
|
||||
'info': cmd_info,
|
||||
'text': cmd_text,
|
||||
'search': cmd_search,
|
||||
'tables': cmd_tables,
|
||||
'count': cmd_count,
|
||||
'split': cmd_split,
|
||||
'merge': cmd_merge,
|
||||
}
|
||||
|
||||
if args.command not in commands:
|
||||
print(f"Error: Unknown command '{args.command}'")
|
||||
print(f"Available commands: {', '.join(commands.keys())}")
|
||||
sys.exit(1)
|
||||
|
||||
commands[args.command](args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
skills/test-agent.Agentfile
Normal file
7
skills/test-agent.Agentfile
Normal file
@@ -0,0 +1,7 @@
|
||||
FROM qwq
|
||||
|
||||
SYSTEM You are a test agent with calculator skills.
|
||||
|
||||
AGENT TYPE conversational
|
||||
|
||||
SKILL ./calculator-skill
|
||||
4
skills/test-mcp/mcp.json
Normal file
4
skills/test-mcp/mcp.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"name": "test-mcp",
|
||||
"description": "A test MCP server"
|
||||
}
|
||||
109
skills/test-mcp/server.py
Executable file
109
skills/test-mcp/server.py
Executable file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
A simple test MCP server that exposes an echo tool.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
def handle_request(req):
|
||||
method = req.get("method", "")
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-mcp", "version": "1.0.0"}
|
||||
}
|
||||
elif method == "notifications/initialized":
|
||||
# Notification, no response needed
|
||||
return None
|
||||
elif method == "tools/list":
|
||||
return {
|
||||
"tools": [
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echoes back the input text",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to echo"
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "add",
|
||||
"description": "Adds two numbers together",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "number",
|
||||
"description": "First number"
|
||||
},
|
||||
"b": {
|
||||
"type": "number",
|
||||
"description": "Second number"
|
||||
}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
params = req.get("params", {})
|
||||
tool_name = params.get("name", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
if tool_name == "echo":
|
||||
text = args.get("text", "")
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Echo: {text}"}]
|
||||
}
|
||||
elif tool_name == "add":
|
||||
a = args.get("a", 0)
|
||||
b = args.get("b", 0)
|
||||
result = a + b
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Result: {a} + {b} = {result}"}]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
||||
"isError": True
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def main():
|
||||
for line in sys.stdin:
|
||||
try:
|
||||
req = json.loads(line.strip())
|
||||
result = handle_request(req)
|
||||
|
||||
# Only send response if there's an ID (not a notification)
|
||||
if "id" in req and result is not None:
|
||||
resp = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req["id"],
|
||||
"result": result
|
||||
}
|
||||
print(json.dumps(resp), flush=True)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
if "id" in req:
|
||||
resp = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.get("id"),
|
||||
"error": {"code": -32603, "message": str(e)}
|
||||
}
|
||||
print(json.dumps(resp), flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,5 +1,29 @@
|
||||
package model
|
||||
|
||||
// SkillRef represents a reference to a skill, either by local path or by registry digest.
|
||||
type SkillRef struct {
|
||||
// Name is the local path (for development) or registry name (e.g., "skill/calculator:1.0.0")
|
||||
Name string `json:"name,omitempty"`
|
||||
// Digest is the content-addressable digest of the skill blob (e.g., "sha256:abc123...")
|
||||
Digest string `json:"digest,omitempty"`
|
||||
}
|
||||
|
||||
// MCPRef represents a reference to an MCP (Model Context Protocol) server.
|
||||
type MCPRef struct {
|
||||
// Name is the identifier for the MCP server (used for tool namespacing)
|
||||
Name string `json:"name,omitempty"`
|
||||
// Digest is the content-addressable digest of the bundled MCP server blob
|
||||
Digest string `json:"digest,omitempty"`
|
||||
// Command is the executable to run (e.g., "uv", "node", "python3")
|
||||
Command string `json:"command,omitempty"`
|
||||
// Args are the arguments to pass to the command
|
||||
Args []string `json:"args,omitempty"`
|
||||
// Env is optional environment variables for the MCP server
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
// Type is the transport type (currently only "stdio" is supported)
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// ConfigV2 represents the configuration metadata for a model.
|
||||
type ConfigV2 struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
@@ -20,6 +44,12 @@ type ConfigV2 struct {
|
||||
EmbedLen int `json:"embedding_length,omitempty"`
|
||||
BaseName string `json:"base_name,omitempty"`
|
||||
|
||||
// agent-specific fields
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
|
||||
// required by spec
|
||||
Architecture string `json:"architecture"`
|
||||
OS string `json:"os"`
|
||||
|
||||
@@ -59,6 +59,7 @@ type partKind int
|
||||
const (
|
||||
kindHost partKind = iota
|
||||
kindNamespace
|
||||
kindKind
|
||||
kindModel
|
||||
kindTag
|
||||
kindDigest
|
||||
@@ -70,6 +71,8 @@ func (k partKind) String() string {
|
||||
return "host"
|
||||
case kindNamespace:
|
||||
return "namespace"
|
||||
case kindKind:
|
||||
return "kind"
|
||||
case kindModel:
|
||||
return "model"
|
||||
case kindTag:
|
||||
@@ -89,6 +92,7 @@ func (k partKind) String() string {
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Kind string // Optional: "skill", "agent", or empty for models
|
||||
Model string
|
||||
Tag string
|
||||
}
|
||||
@@ -97,34 +101,27 @@ type Name struct {
|
||||
// format of a valid name string is:
|
||||
//
|
||||
// s:
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { kind } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model }
|
||||
// { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { namespace } "/" { kind } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } "@" { digest }
|
||||
// { namespace } "/" { model }
|
||||
// { model } ":" { tag } "@" { digest }
|
||||
// { model } ":" { tag }
|
||||
// { model } "@" { digest }
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// kind:
|
||||
// pattern: "skill" | "agent" | "" (empty for models)
|
||||
// length: [0, 80]
|
||||
// model:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// Most users should use [ParseName] instead, unless need to support
|
||||
// different defaults than DefaultName.
|
||||
@@ -136,6 +133,13 @@ func ParseName(s string) Name {
|
||||
return Merge(ParseNameBare(s), DefaultName())
|
||||
}
|
||||
|
||||
// ValidKinds are the allowed values for the Kind field
|
||||
var ValidKinds = map[string]bool{
|
||||
"skill": true,
|
||||
"agent": true,
|
||||
"mcp": true,
|
||||
}
|
||||
|
||||
// ParseNameBare parses s as a name string and returns a Name. No merge with
|
||||
// [DefaultName] is performed.
|
||||
func ParseNameBare(s string) Name {
|
||||
@@ -153,6 +157,30 @@ func ParseNameBare(s string) Name {
|
||||
return n
|
||||
}
|
||||
|
||||
s, n.Kind, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
// Only 2 parts: namespace/model - what we parsed as Kind is actually Namespace
|
||||
n.Namespace = n.Kind
|
||||
n.Kind = ""
|
||||
return n
|
||||
}
|
||||
|
||||
// Check if what we parsed as Kind is actually a valid kind value
|
||||
if !ValidKinds[n.Kind] {
|
||||
// Not a valid kind - this is the old 3-part format: host/namespace/model
|
||||
// Shift: Kind -> Namespace, s -> Host
|
||||
n.Namespace = n.Kind
|
||||
n.Kind = ""
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if !ok {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
return n
|
||||
}
|
||||
|
||||
// Valid kind found - continue parsing for namespace and optional host
|
||||
s, n.Namespace, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
n.Namespace = s
|
||||
@@ -168,20 +196,32 @@ func ParseNameBare(s string) Name {
|
||||
return n
|
||||
}
|
||||
|
||||
// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are
|
||||
// ParseNameFromFilepath parses a 4 or 5-part filepath as a Name. The parts are
|
||||
// expected to be in the form:
|
||||
//
|
||||
// { host } "/" { namespace } "/" { model } "/" { tag }
|
||||
// { host } "/" { namespace } "/" { kind } "/" { model } "/" { tag }
|
||||
func ParseNameFromFilepath(s string) (n Name) {
|
||||
parts := strings.Split(s, string(filepath.Separator))
|
||||
if len(parts) != 4 {
|
||||
|
||||
switch len(parts) {
|
||||
case 4:
|
||||
// Old format: host/namespace/model/tag
|
||||
n.Host = parts[0]
|
||||
n.Namespace = parts[1]
|
||||
n.Model = parts[2]
|
||||
n.Tag = parts[3]
|
||||
case 5:
|
||||
// New format: host/namespace/kind/model/tag
|
||||
n.Host = parts[0]
|
||||
n.Namespace = parts[1]
|
||||
n.Kind = parts[2]
|
||||
n.Model = parts[3]
|
||||
n.Tag = parts[4]
|
||||
default:
|
||||
return Name{}
|
||||
}
|
||||
|
||||
n.Host = parts[0]
|
||||
n.Namespace = parts[1]
|
||||
n.Model = parts[2]
|
||||
n.Tag = parts[3]
|
||||
if !n.IsFullyQualified() {
|
||||
return Name{}
|
||||
}
|
||||
@@ -189,11 +229,12 @@ func ParseNameFromFilepath(s string) (n Name) {
|
||||
return n
|
||||
}
|
||||
|
||||
// Merge merges the host, namespace, and tag parts of the two names,
|
||||
// Merge merges the host, namespace, kind, and tag parts of the two names,
|
||||
// preferring the non-empty parts of a.
|
||||
func Merge(a, b Name) Name {
|
||||
a.Host = cmp.Or(a.Host, b.Host)
|
||||
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
||||
a.Kind = cmp.Or(a.Kind, b.Kind)
|
||||
a.Tag = cmp.Or(a.Tag, b.Tag)
|
||||
return a
|
||||
}
|
||||
@@ -211,6 +252,10 @@ func (n Name) String() string {
|
||||
b.WriteString(n.Namespace)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
if n.Kind != "" {
|
||||
b.WriteString(n.Kind)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
b.WriteString(n.Model)
|
||||
if n.Tag != "" {
|
||||
b.WriteByte(':')
|
||||
@@ -233,6 +278,12 @@ func (n Name) DisplayShortest() string {
|
||||
sb.WriteByte('/')
|
||||
}
|
||||
|
||||
// include kind if present
|
||||
if n.Kind != "" {
|
||||
sb.WriteString(n.Kind)
|
||||
sb.WriteByte('/')
|
||||
}
|
||||
|
||||
// always include model and tag
|
||||
sb.WriteString(n.Model)
|
||||
sb.WriteString(":")
|
||||
@@ -256,18 +307,23 @@ func (n Name) IsValid() bool {
|
||||
}
|
||||
|
||||
// IsFullyQualified returns true if all parts of the name are present and
|
||||
// valid without the digest.
|
||||
// valid without the digest. Kind is optional and only validated if non-empty.
|
||||
func (n Name) IsFullyQualified() bool {
|
||||
parts := []string{
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
if !isValidPart(kindHost, n.Host) {
|
||||
return false
|
||||
}
|
||||
for i, part := range parts {
|
||||
if !isValidPart(partKind(i), part) {
|
||||
return false
|
||||
}
|
||||
if !isValidPart(kindNamespace, n.Namespace) {
|
||||
return false
|
||||
}
|
||||
// Kind is optional - only validate if present
|
||||
if n.Kind != "" && !isValidPart(kindKind, n.Kind) {
|
||||
return false
|
||||
}
|
||||
if !isValidPart(kindModel, n.Model) {
|
||||
return false
|
||||
}
|
||||
if !isValidPart(kindTag, n.Tag) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -276,6 +332,7 @@ func (n Name) IsFullyQualified() bool {
|
||||
// host to tag as a directory in the form:
|
||||
//
|
||||
// {host}/{namespace}/{model}/{tag}
|
||||
// {host}/{namespace}/{kind}/{model}/{tag}
|
||||
//
|
||||
// It uses the system's filepath separator and ensures the path is clean.
|
||||
//
|
||||
@@ -285,6 +342,15 @@ func (n Name) Filepath() string {
|
||||
if !n.IsFullyQualified() {
|
||||
panic("illegal attempt to get filepath of invalid name")
|
||||
}
|
||||
if n.Kind != "" {
|
||||
return filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Kind,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
)
|
||||
}
|
||||
return filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
@@ -301,6 +367,7 @@ func (n Name) LogValue() slog.Value {
|
||||
func (n Name) EqualFold(o Name) bool {
|
||||
return strings.EqualFold(n.Host, o.Host) &&
|
||||
strings.EqualFold(n.Namespace, o.Namespace) &&
|
||||
strings.EqualFold(n.Kind, o.Kind) &&
|
||||
strings.EqualFold(n.Model, o.Model) &&
|
||||
strings.EqualFold(n.Tag, o.Tag)
|
||||
}
|
||||
@@ -317,6 +384,11 @@ func isValidLen(kind partKind, s string) bool {
|
||||
}
|
||||
|
||||
func isValidPart(kind partKind, s string) bool {
|
||||
// Kind must be one of the valid values
|
||||
if kind == kindKind {
|
||||
return ValidKinds[s]
|
||||
}
|
||||
|
||||
if !isValidLen(kind, s) {
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user