mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 05:03:13 -04:00
chore: update cogito and simplify MCP logics (#6413)
* chore: update cogito and simplify MCP logics Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Refine signal handling Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
459b6ab86d
commit
27c4161401
@@ -2,14 +2,12 @@ package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/app"
|
||||
"fyne.io/fyne/v2/driver/desktop"
|
||||
coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -42,7 +40,12 @@ func main() {
|
||||
}
|
||||
|
||||
// Setup signal handling for graceful shutdown
|
||||
setupSignalHandling(launcher)
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
// Perform cleanup
|
||||
if err := launcher.Shutdown(); err != nil {
|
||||
log.Printf("Error during shutdown: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Initialize the launcher state
|
||||
go func() {
|
||||
@@ -67,26 +70,3 @@ func main() {
|
||||
// Run the application in background (window only shown when "Settings" is clicked)
|
||||
myApp.Run()
|
||||
}
|
||||
|
||||
// setupSignalHandling sets up signal handlers for graceful shutdown
|
||||
func setupSignalHandling(launcher *coreLauncher.Launcher) {
|
||||
// Create a channel to receive OS signals
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
|
||||
// Register for interrupt and terminate signals
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Handle signals in a separate goroutine
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
log.Printf("Received signal %v, shutting down gracefully...", sig)
|
||||
|
||||
// Perform cleanup
|
||||
if err := launcher.Shutdown(); err != nil {
|
||||
log.Printf("Error during shutdown: %v", err)
|
||||
}
|
||||
|
||||
// Exit the application
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
"time"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/signals"
|
||||
"github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ExplorerCMD struct {
|
||||
@@ -46,7 +47,11 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
appHTTP := http.Explorer(db)
|
||||
|
||||
signals.Handler(nil)
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
if err := appHTTP.Shutdown(); err != nil {
|
||||
log.Error().Err(err).Msg("error during shutdown")
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(e.Address)
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/signals"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
)
|
||||
|
||||
type FederatedCLI struct {
|
||||
@@ -20,7 +20,11 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
|
||||
|
||||
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
|
||||
|
||||
signals.Handler(nil)
|
||||
c, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return fs.Start(context.Background())
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
return fs.Start(c)
|
||||
}
|
||||
|
||||
@@ -10,11 +10,11 @@ import (
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
cli_api "github.com/mudler/LocalAI/core/cli/api"
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/signals"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -226,8 +226,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Catch signals from the OS requesting us to exit, and stop all backends
|
||||
signals.Handler(app.ModelLoader())
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
if err := app.ModelLoader().StopAllGRPC(); err != nil {
|
||||
log.Error().Err(err).Msg("error while stopping all grpc backends")
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(r.Address)
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package signals
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Handler(m *model.ModelLoader) {
|
||||
// Catch signals from the OS requesting us to exit, and stop all backends
|
||||
go func(m *model.ModelLoader) {
|
||||
c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
<-c
|
||||
if m != nil {
|
||||
if err := m.StopAllGRPC(); err != nil {
|
||||
log.Error().Err(err).Msg("error while stopping all grpc backends")
|
||||
}
|
||||
}
|
||||
os.Exit(1)
|
||||
}(m)
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/cli/signals"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -85,8 +84,6 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
||||
|
||||
args = append([]string{grpcProcess}, args...)
|
||||
|
||||
signals.Handler(nil)
|
||||
|
||||
return syscall.Exec(
|
||||
grpcProcess,
|
||||
args,
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/cli/signals"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -48,6 +48,9 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
|
||||
address := "127.0.0.1"
|
||||
|
||||
c, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if r.NoRunner {
|
||||
// Let override which port and address to bind if the user
|
||||
// configure the llama-cpp service on its own
|
||||
@@ -59,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
p = r.RunnerPort
|
||||
}
|
||||
|
||||
_, err = p2p.ExposeService(context.Background(), address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||
_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -101,13 +104,15 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = p2p.ExposeService(context.Background(), address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
signals.Handler(nil)
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
for {
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -2,43 +2,66 @@ package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/tmc/langchaingo/jsonschema"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func ToolsFromMCPConfig(ctx context.Context, remote config.MCPGenericConfig[config.MCPRemoteServers], stdio config.MCPGenericConfig[config.MCPSTDIOServers]) ([]*MCPTool, error) {
|
||||
allTools := []*MCPTool{}
|
||||
type sessionCache struct {
|
||||
mu sync.Mutex
|
||||
cache map[string][]*mcp.ClientSession
|
||||
}
|
||||
|
||||
var (
|
||||
cache = sessionCache{
|
||||
cache: make(map[string][]*mcp.ClientSession),
|
||||
}
|
||||
|
||||
client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)
|
||||
)
|
||||
|
||||
func SessionsFromMCPConfig(
|
||||
name string,
|
||||
remote config.MCPGenericConfig[config.MCPRemoteServers],
|
||||
stdio config.MCPGenericConfig[config.MCPSTDIOServers],
|
||||
) ([]*mcp.ClientSession, error) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
sessions, exists := cache.cache[name]
|
||||
if exists {
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
allSessions := []*mcp.ClientSession{}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Get the list of all the tools that the Agent will be exposed to
|
||||
for _, server := range remote.Servers {
|
||||
|
||||
log.Debug().Msgf("[MCP remote server] Configuration : %+v", server)
|
||||
// Create HTTP client with custom roundtripper for bearer token injection
|
||||
client := &http.Client{
|
||||
httpClient := &http.Client{
|
||||
Timeout: 360 * time.Second,
|
||||
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
||||
}
|
||||
|
||||
tools, err := mcpToolsFromTransport(ctx,
|
||||
&mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: client},
|
||||
)
|
||||
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
|
||||
mcpSession, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Error().Err(err).Msgf("Failed to connect to MCP server %s", server.URL)
|
||||
continue
|
||||
}
|
||||
|
||||
allTools = append(allTools, tools...)
|
||||
log.Debug().Msgf("[MCP remote server] Connected to MCP server %s", server.URL)
|
||||
cache.cache[name] = append(cache.cache[name], mcpSession)
|
||||
}
|
||||
|
||||
for _, server := range stdio.Servers {
|
||||
@@ -48,18 +71,24 @@ func ToolsFromMCPConfig(ctx context.Context, remote config.MCPGenericConfig[conf
|
||||
for key, value := range server.Env {
|
||||
command.Env = append(command.Env, key+"="+value)
|
||||
}
|
||||
tools, err := mcpToolsFromTransport(ctx,
|
||||
&mcp.CommandTransport{
|
||||
Command: command},
|
||||
)
|
||||
transport := &mcp.CommandTransport{Command: command}
|
||||
mcpSession, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Error().Err(err).Msgf("Failed to start MCP server %s", command)
|
||||
continue
|
||||
}
|
||||
|
||||
allTools = append(allTools, tools...)
|
||||
log.Debug().Msgf("[MCP stdio server] Connected to MCP server %s", command)
|
||||
cache.cache[name] = append(cache.cache[name], mcpSession)
|
||||
}
|
||||
|
||||
return allTools, nil
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
for _, session := range allSessions {
|
||||
session.Close()
|
||||
}
|
||||
cancel()
|
||||
})
|
||||
|
||||
return allSessions, nil
|
||||
}
|
||||
|
||||
// bearerTokenRoundTripper is a custom roundtripper that injects a bearer token
|
||||
@@ -87,146 +116,3 @@ func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.Round
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
type MCPTool struct {
|
||||
name, description string
|
||||
inputSchema ToolInputSchema
|
||||
session *mcp.ClientSession
|
||||
ctx context.Context
|
||||
props map[string]jsonschema.Definition
|
||||
}
|
||||
|
||||
func (t *MCPTool) Run(args map[string]any) (string, error) {
|
||||
|
||||
// Call a tool on the server.
|
||||
params := &mcp.CallToolParams{
|
||||
Name: t.name,
|
||||
Arguments: args,
|
||||
}
|
||||
res, err := t.session.CallTool(t.ctx, params)
|
||||
if err != nil {
|
||||
log.Error().Msgf("CallTool failed: %v", err)
|
||||
return "", err
|
||||
}
|
||||
if res.IsError {
|
||||
log.Error().Msgf("tool failed")
|
||||
return "", errors.New("tool failed")
|
||||
}
|
||||
|
||||
result := ""
|
||||
for _, c := range res.Content {
|
||||
result += c.(*mcp.TextContent).Text
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *MCPTool) Tool() openai.Tool {
|
||||
|
||||
return openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: t.name,
|
||||
Description: t.description,
|
||||
Parameters: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: t.props,
|
||||
Required: t.inputSchema.Required,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) Close() {
|
||||
t.session.Close()
|
||||
}
|
||||
|
||||
// ToolInputSchema is the part of a tool's JSON input schema that is decoded
// here: the schema type, its properties, and the names of the required
// properties. Empty properties and required lists are omitted when encoding.
type ToolInputSchema struct {
	Type       string         `json:"type"`
	Properties map[string]any `json:"properties,omitempty"`
	Required   []string       `json:"required,omitempty"`
}
|
||||
|
||||
// probe the MCP remote and generate tools that are compliant with cogito
|
||||
// TODO: Maybe move this to cogito?
|
||||
func mcpToolsFromTransport(ctx context.Context, transport mcp.Transport) ([]*MCPTool, error) {
|
||||
allTools := []*MCPTool{}
|
||||
|
||||
// Create a new client, with no features.
|
||||
client := mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error connecting to MCP server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tools, err := session.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error listing tools: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tool := range tools.Tools {
|
||||
dat, err := json.Marshal(tool.InputSchema)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error marshalling input schema: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// XXX: This is a wild guess, to verify (data types might be incompatible)
|
||||
var inputSchema ToolInputSchema
|
||||
err = json.Unmarshal(dat, &inputSchema)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error unmarshalling input schema: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
props := map[string]jsonschema.Definition{}
|
||||
dat, err = json.Marshal(inputSchema.Properties)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error marshalling input schema: %v", err)
|
||||
continue
|
||||
}
|
||||
err = json.Unmarshal(dat, &props)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error unmarshalling input schema properties: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
allTools = append(allTools, &MCPTool{
|
||||
name: tool.Name,
|
||||
description: tool.Description,
|
||||
session: session,
|
||||
ctx: ctx,
|
||||
props: props,
|
||||
inputSchema: inputSchema,
|
||||
})
|
||||
}
|
||||
|
||||
// We make sure we run Close on signal
|
||||
handleSignal(allTools)
|
||||
|
||||
return allTools, nil
|
||||
}
|
||||
|
||||
func handleSignal(tools []*MCPTool) {
|
||||
|
||||
// Create a channel to receive OS signals
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
|
||||
// Register for interrupt and terminate signals
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Handle signals in a separate goroutine
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
log.Printf("Received signal %v, shutting down gracefully...", sig)
|
||||
|
||||
for _, t := range tools {
|
||||
t.Close()
|
||||
}
|
||||
|
||||
// Exit the application
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -5,11 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -27,10 +26,6 @@ import (
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /mcp/v1/completions [post]
|
||||
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
toolsCache := map[string][]*mcp.MCPTool{}
|
||||
mu := sync.Mutex{}
|
||||
|
||||
// We do not support streaming mode (Yet?)
|
||||
return func(c *fiber.Ctx) error {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -54,37 +49,17 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
|
||||
allTools := []*mcp.MCPTool{}
|
||||
|
||||
// Get MCP config from model config
|
||||
remote, stdio := config.MCP.MCPConfigFromYAML()
|
||||
|
||||
// Check if we have tools in cache, or we have to have an initial connection
|
||||
mu.Lock()
|
||||
tools, exists := toolsCache[config.Name]
|
||||
if exists {
|
||||
allTools = append(allTools, tools...)
|
||||
} else {
|
||||
tools, err := mcp.ToolsFromMCPConfig(ctx, remote, stdio)
|
||||
if err != nil {
|
||||
mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
toolsCache[config.Name] = tools
|
||||
|
||||
allTools = append(allTools, tools...)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
cogitoTools := []cogito.Tool{}
|
||||
for _, tool := range allTools {
|
||||
cogitoTools = append(cogitoTools, tool)
|
||||
// defer tool.Close()
|
||||
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(cogitoTools) == 0 {
|
||||
return fmt.Errorf("no tools found in the specified MCP servers")
|
||||
if len(sessions) == 0 {
|
||||
return fmt.Errorf("no working MCP servers found")
|
||||
}
|
||||
|
||||
fragment := cogito.NewEmptyFragment()
|
||||
@@ -109,7 +84,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithContext(ctx),
|
||||
cogito.WithTools(cogitoTools...),
|
||||
cogito.WithMCPs(sessions...),
|
||||
cogito.WithIterations(3), // default to 3 iterations
|
||||
cogito.WithMaxAttempts(3), // default to 3 attempts
|
||||
}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -34,7 +34,7 @@ require (
|
||||
github.com/mholt/archiver/v3 v3.5.1
|
||||
github.com/microcosm-cc/bluemonday v1.0.27
|
||||
github.com/modelcontextprotocol/go-sdk v1.0.0
|
||||
github.com/mudler/cogito v0.1.0
|
||||
github.com/mudler/cogito v0.2.0
|
||||
github.com/mudler/edgevpn v0.31.0
|
||||
github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82
|
||||
github.com/nikolalohinski/gonja/v2 v2.4.1
|
||||
@@ -60,6 +60,7 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.38.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
google.golang.org/grpc v1.67.1
|
||||
google.golang.org/protobuf v1.36.8
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
oras.land/oras-go/v2 v2.6.0
|
||||
@@ -149,7 +150,6 @@ require (
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/image v0.25.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
4
go.sum
4
go.sum
@@ -510,8 +510,8 @@ github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7P
|
||||
github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
github.com/mudler/cogito v0.1.0 h1:RybskLSPuLkBlR9Z+y4LJgIU5wVscYoHuF9+ubXsHgM=
|
||||
github.com/mudler/cogito v0.1.0/go.mod h1:MiipcWbTr+fcW3HiirQRrYYjEIamZFCLkpqvdgk/Nfw=
|
||||
github.com/mudler/cogito v0.2.0 h1:UzowMlP6kiDLnuwQikac9yUOhI6Qe2tW1jZP5gHQvaY=
|
||||
github.com/mudler/cogito v0.2.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU=
|
||||
github.com/mudler/edgevpn v0.31.0 h1:CXwxQ2ZygzE7iKGl1J+vq9pL5PvsW2uc3qI/zgpNpp4=
|
||||
github.com/mudler/edgevpn v0.31.0/go.mod h1:DKgh9Wu/NM3UbZoPyheMXFvpu1dSLkXrqAOy3oKJN3I=
|
||||
github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA=
|
||||
|
||||
@@ -4,14 +4,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/hpcloud/tail"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
process "github.com/mudler/go-processmanager"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -130,16 +129,13 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
|
||||
}
|
||||
|
||||
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
|
||||
// clean up process
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
err := grpcControlProcess.Stop()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error while shutting down grpc process")
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
go func() {
|
||||
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
|
||||
|
||||
40
pkg/signals/handler.go
Normal file
40
pkg/signals/handler.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package signals
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var (
	// signalHandlers holds the registered cleanup callbacks in registration order.
	signalHandlers []func()
	// signalHandlersMutex guards signalHandlers.
	signalHandlersMutex sync.Mutex
	// signalHandlersOnce ensures the signal listener is installed a single time.
	signalHandlersOnce sync.Once
)

// RegisterGracefulTerminationHandler registers fn to be run when the process
// receives SIGINT or SIGTERM. Handlers run sequentially, in registration
// order, before the process exits. A nil fn is ignored so that it cannot cause
// a panic at shutdown time.
//
// It is safe to call from multiple goroutines, including from inside a handler
// that is currently running.
func RegisterGracefulTerminationHandler(fn func()) {
	if fn == nil {
		return
	}
	signalHandlersMutex.Lock()
	defer signalHandlersMutex.Unlock()
	signalHandlers = append(signalHandlers, fn)
}

// init installs the process-wide listener for SIGINT and SIGTERM.
func init() {
	signalHandlersOnce.Do(func() {
		// Buffered by one so the signal is not dropped if it arrives before
		// the goroutine below is receiving.
		c := make(chan os.Signal, 1)
		signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
		go signalHandler(c)
	})
}

// signalHandler blocks until a signal arrives on c, runs every registered
// termination handler and exits the process with status 0.
func signalHandler(c chan os.Signal) {
	<-c

	runTerminationHandlers()

	os.Exit(0)
}

// runTerminationHandlers invokes the registered handlers in order. The mutex
// is held only while reading the registry, never while a handler runs, so a
// handler may itself register further handlers without deadlocking; handlers
// added that way are run as well.
func runTerminationHandlers() {
	for next := 0; ; next++ {
		signalHandlersMutex.Lock()
		if next >= len(signalHandlers) {
			signalHandlersMutex.Unlock()
			return
		}
		fn := signalHandlers[next]
		signalHandlersMutex.Unlock()

		fn()
	}
}
|
||||
Reference in New Issue
Block a user