diff --git a/backend/backend.proto b/backend/backend.proto
index 369218010..31e40c4f3 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -32,6 +32,8 @@ service Backend {
   rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
 
   rpc VAD(VADRequest) returns (VADResponse) {}
+
+  rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
 }
 
 // Define the empty request
@@ -410,3 +412,8 @@ message Detection {
 message DetectResponse {
   repeated Detection Detections = 1;
 }
+
+message ModelMetadataResponse {
+  bool supports_thinking = 1;
+  string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
+}
diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
index 9ff62316e..0f29cc755 100644
--- a/backend/cpp/llama-cpp/grpc-server.cpp
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -2476,6 +2476,47 @@ public:
 
         response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total);
 
+        return grpc::Status::OK;
+    }
+
+    grpc::Status ModelMetadata(ServerContext* /*context*/, const backend::ModelOptions* /*request*/, backend::ModelMetadataResponse* response) override {
+        // Check if model is loaded
+        if (params_base.model.path.empty()) {
+            return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
+        }
+
+        // Check if chat templates are initialized
+        if (ctx_server.impl->chat_params.tmpls == nullptr) {
+            // If templates are not initialized, we can't detect thinking support
+            // Return false as default
+            response->set_supports_thinking(false);
+            response->set_rendered_template("");
+            return grpc::Status::OK;
+        }
+
+        // Detect thinking support using llama.cpp's function
+        bool supports_thinking = common_chat_templates_support_enable_thinking(ctx_server.impl->chat_params.tmpls.get());
+        response->set_supports_thinking(supports_thinking);
+
+        // Render the template with enable_thinking=true so Go code can detect thinking tokens
+        // This allows reusing existing detection functions in Go
+        std::string rendered_template = "";
+        if (params_base.use_jinja) {
+            // Render the template with enable_thinking=true to see what the actual prompt looks like
+            common_chat_templates_inputs dummy_inputs;
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "test";
+            dummy_inputs.messages = {msg};
+            dummy_inputs.enable_thinking = true;
+            dummy_inputs.use_jinja = params_base.use_jinja;
+
+            const auto rendered = common_chat_templates_apply(ctx_server.impl->chat_params.tmpls.get(), dummy_inputs);
+            rendered_template = rendered.prompt;
+        }
+
+        response->set_rendered_template(rendered_template);
+
         return grpc::Status::OK;
     }
 };
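For orientation, a minimal Go sketch (not part of the patch) of how the new RPC can be exercised once the protobuf stubs are regenerated from backend.proto. The address is a placeholder, and the llama-cpp backend is assumed to already have a model loaded; otherwise the handler above returns FAILED_PRECONDITION.

package main

import (
	"context"
	"fmt"
	"log"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Placeholder address of a running llama-cpp backend process.
	conn, err := grpc.Dial("127.0.0.1:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// ModelOptions is reused as the request type; the handler only inspects the
	// model that is already loaded, so no fields are strictly required here.
	meta, err := pb.NewBackendClient(conn).ModelMetadata(context.Background(), &pb.ModelOptions{})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("supports_thinking:", meta.SupportsThinking)
	fmt.Println("rendered_template:", meta.RenderedTemplate)
}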
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 06b9d2d44..18367cb2f 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -61,6 +61,18 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
 		return nil, err
 	}
 
+	// Detect thinking support after model load (only if not already detected)
+	// This needs to happen after LoadModel succeeds so the backend can render templates
+	if (c.ReasoningConfig.DisableReasoning == nil && c.ReasoningConfig.DisableReasoningTagPrefill == nil) && c.TemplateConfig.UseTokenizerTemplate {
+		modelOpts := grpcModelOpts(*c, o.SystemState.Model.ModelsPath)
+		config.DetectThinkingSupportFromBackend(ctx, c, inferenceModel, modelOpts)
+		// Update the config in the loader so it persists for future requests
+		cl.UpdateModelConfig(c.Name, func(cfg *config.ModelConfig) {
+			cfg.ReasoningConfig.DisableReasoning = c.ReasoningConfig.DisableReasoning
+			cfg.ReasoningConfig.DisableReasoningTagPrefill = c.ReasoningConfig.DisableReasoningTagPrefill
+		})
+	}
+
 	var protoMessages []*proto.Message
 	// if we are using the tokenizer template, we need to convert the messages to proto messages
 	// unless the prompt has already been tokenized (non-chat endpoints + functions)
diff --git a/core/config/gguf.go b/core/config/gguf.go
index 507466d60..0d788dad4 100644
--- a/core/config/gguf.go
+++ b/core/config/gguf.go
@@ -1,10 +1,16 @@
 package config
 
 import (
+	"context"
+
+	"github.com/mudler/LocalAI/pkg/grpc"
+	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+	"github.com/mudler/LocalAI/pkg/reasoning"
 	"github.com/mudler/LocalAI/pkg/xsysinfo"
 	"github.com/mudler/xlog"
 
 	gguf "github.com/gpustack/gguf-parser-go"
+	"github.com/gpustack/gguf-parser-go/util/ptr"
 )
 
 const (
@@ -71,6 +77,8 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
 		cfg.modelTemplate = chatTemplate.ValueString()
 	}
 
+	// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
+
 	// template estimations
 	if cfg.HasTemplate() {
 		// nothing to guess here
@@ -92,3 +100,47 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
 
 	cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
 }
+
+// DetectThinkingSupportFromBackend calls the ModelMetadata gRPC method to detect
+// if the model supports thinking mode and if the template ends with a thinking start token.
+// This should be called after the model is loaded.
+// The results are stored in cfg.ReasoningConfig.DisableReasoning and cfg.ReasoningConfig.DisableReasoningTagPrefill.
+func DetectThinkingSupportFromBackend(ctx context.Context, cfg *ModelConfig, backendClient grpc.Backend, modelOptions *pb.ModelOptions) {
+	if backendClient == nil {
+		xlog.Debug("[gguf] DetectThinkingSupportFromBackend: backend client is nil, skipping detection")
+		return
+	}
+
+	if modelOptions == nil {
+		xlog.Debug("[gguf] DetectThinkingSupportFromBackend: model options is nil, skipping detection")
+		return
+	}
+
+	// Only detect for llama-cpp backend when using tokenizer templates
+	if cfg.Backend != "llama-cpp" || !cfg.TemplateConfig.UseTokenizerTemplate {
+		xlog.Debug("[gguf] DetectThinkingSupportFromBackend: skipping detection", "backend", cfg.Backend, "useTokenizerTemplate", cfg.TemplateConfig.UseTokenizerTemplate)
+		return
+	}
+
+	metadata, err := backendClient.ModelMetadata(ctx, modelOptions)
+	if err != nil {
+		xlog.Warn("[gguf] DetectThinkingSupportFromBackend: failed to get model metadata", "error", err)
+		return
+	}
+
+	if metadata != nil {
+		cfg.ReasoningConfig.DisableReasoning = ptr.To(!metadata.SupportsThinking)
+
+		// Use the rendered template to detect if thinking token is at the end
+		// This reuses the existing DetectThinkingStartToken function
+		if metadata.RenderedTemplate != "" {
+			thinkingStartToken := reasoning.DetectThinkingStartToken(metadata.RenderedTemplate, &cfg.ReasoningConfig)
+			thinkingForcedOpen := thinkingStartToken != ""
+			cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(!thinkingForcedOpen)
+			xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", thinkingForcedOpen, "thinking_start_token", thinkingStartToken)
+		} else {
+			cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(true)
+			xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", false)
+		}
+	}
+}
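As a usage note, the detection writes its results into the two pointer flags on ReasoningConfig rather than returning them. A hypothetical helper (a sketch, not part of the patch) that runs the detection once after load and reads the outcome back could look like this:

package example

import (
	"context"
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/grpc"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

// detectFlags is a hypothetical wrapper around the new helper: it runs the
// detection against an already-loaded backend and prints what was derived.
// A nil flag means detection was skipped (wrong backend, no metadata, or an
// error such as the "unimplemented" reply from backends that embed base.Base).
func detectFlags(ctx context.Context, cfg *config.ModelConfig, client grpc.Backend, opts *pb.ModelOptions) {
	config.DetectThinkingSupportFromBackend(ctx, cfg, client, opts)

	if cfg.ReasoningConfig.DisableReasoning != nil {
		fmt.Println("supports thinking:", !*cfg.ReasoningConfig.DisableReasoning)
	}
	if cfg.ReasoningConfig.DisableReasoningTagPrefill != nil {
		fmt.Println("thinking forced open:", !*cfg.ReasoningConfig.DisableReasoningTagPrefill)
	}
}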
diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go
index 1a8c64230..02724a5d6 100644
--- a/core/config/model_config_loader.go
+++ b/core/config/model_config_loader.go
@@ -246,6 +246,17 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
 	delete(bcl.configs, m)
 }
 
+// UpdateModelConfig updates an existing model config in the loader.
+// This is useful for updating runtime-detected properties like thinking support.
+func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) {
+	bcl.Lock()
+	defer bcl.Unlock()
+	if cfg, exists := bcl.configs[m]; exists {
+		updater(&cfg)
+		bcl.configs[m] = cfg
+	}
+}
+
 // Preload prepare models if they are not local but url or huggingface repositories
 func (bcl *ModelConfigLoader) Preload(modelPath string) error {
 	bcl.Lock()
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index 846f7231d..8ecb818a0 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -57,4 +57,6 @@ type Backend interface {
 	GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)
 
 	VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
+
+	ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error)
 }
diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go
index e59db2e15..2d0ebc555 100644
--- a/pkg/grpc/base/base.go
+++ b/pkg/grpc/base/base.go
@@ -77,6 +77,10 @@ func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationRespons
 	return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
 }
 
+func (llm *Base) ModelMetadata(opts *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
+	return nil, fmt.Errorf("unimplemented")
+}
+
 // backends may wish to call this to capture the gopsutil info, then enhance with additional memory usage details?
 func (llm *Base) Status() (pb.StatusResponse, error) {
 	return pb.StatusResponse{
diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go
index ff5dccb41..dbdeeab24 100644
--- a/pkg/grpc/client.go
+++ b/pkg/grpc/client.go
@@ -537,3 +537,25 @@ func (c *Client) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.
 	client := pb.NewBackendClient(conn)
 	return client.Detect(ctx, in, opts...)
 }
+
+func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
+	if !c.parallel {
+		c.opMutex.Lock()
+		defer c.opMutex.Unlock()
+	}
+	c.setBusy(true)
+	defer c.setBusy(false)
+	c.wdMark()
+	defer c.wdUnMark()
+	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithDefaultCallOptions(
+			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
+			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
+		))
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	client := pb.NewBackendClient(conn)
+	return client.ModelMetadata(ctx, in, opts...)
+}
diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go
index 3369ce0fc..03cac344f 100644
--- a/pkg/grpc/embed.go
+++ b/pkg/grpc/embed.go
@@ -99,6 +99,10 @@ func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.
 	return e.s.VAD(ctx, in)
 }
 
+func (e *embedBackend) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
+	return e.s.ModelMetadata(ctx, in)
+}
+
 func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
 	return e.s.GetMetrics(ctx, in)
 }
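The loader stores configs by value, which is why UpdateModelConfig takes an updater callback instead of returning a pointer: the callback mutates a private copy under the loader's lock and the copy is written back in one step. A small hypothetical caller (a sketch under that assumption, not part of the patch):

package example

import (
	"github.com/gpustack/gguf-parser-go/util/ptr"
	"github.com/mudler/LocalAI/core/config"
)

// markReasoningDetected persists runtime-detected reasoning flags back into the
// loader so that later requests for the same model see them, mirroring the call
// added to ModelInference above. The model name is a placeholder argument.
func markReasoningDetected(cl *config.ModelConfigLoader, name string, supportsThinking, forcedOpen bool) {
	cl.UpdateModelConfig(name, func(cfg *config.ModelConfig) {
		cfg.ReasoningConfig.DisableReasoning = ptr.To(!supportsThinking)
		cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(!forcedOpen)
	})
}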
diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go
index 66c38f430..bb22af55c 100644
--- a/pkg/grpc/interface.go
+++ b/pkg/grpc/interface.go
@@ -28,6 +28,8 @@ type AIModel interface {
 	StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
 
 	VAD(*pb.VADRequest) (pb.VADResponse, error)
+
+	ModelMetadata(*pb.ModelOptions) (*pb.ModelMetadataResponse, error)
 }
 
 func newReply(s string) *pb.Reply {
diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go
index 30962e8c8..8cc6ee43e 100644
--- a/pkg/grpc/server.go
+++ b/pkg/grpc/server.go
@@ -263,6 +263,18 @@ func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, e
 	return &res, nil
 }
 
+func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
+	if s.llm.Locking() {
+		s.llm.Lock()
+		defer s.llm.Unlock()
+	}
+	res, err := s.llm.ModelMetadata(in)
+	if err != nil {
+		return nil, err
+	}
+	return res, nil
+}
+
 func StartServer(address string, model AIModel) error {
 	lis, err := net.Listen("tcp", address)
 	if err != nil {
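Finally, a hedged sketch of how a Go-side backend can pick up the extended AIModel interface: base.Base now provides a default ModelMetadata that answers "unimplemented", so existing backends that embed it keep compiling, and a backend that actually knows its template can override the method. The struct and values below are hypothetical; server.ModelMetadata above simply takes the model lock when Locking() asks for it and forwards the call.

package example

import (
	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

// fakeLLM is a hypothetical Go backend used only for illustration.
// base.Base supplies the default "unimplemented" ModelMetadata, and the
// override below replaces it with real metadata.
type fakeLLM struct {
	base.Base
}

func (f *fakeLLM) ModelMetadata(opts *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
	return &pb.ModelMetadataResponse{
		SupportsThinking: true,
		RenderedTemplate: "", // empty when no rendered prompt is available to inspect
	}, nil
}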