diff --git a/pkg/model/filters.go b/pkg/model/filters.go index 79b72d5bf..9981e2771 100644 --- a/pkg/model/filters.go +++ b/pkg/model/filters.go @@ -15,3 +15,9 @@ func allExcept(s string) GRPCProcessFilter { return id != s } } + +func only(s string) GRPCProcessFilter { + return func(id string, p *process.Process) bool { + return id == s + } +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 98e424885..9178a5265 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -173,6 +173,10 @@ func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err e model, err := ml.LoadModel(o.modelID, o.model, ml.grpcModel(backend, o)) if err != nil { + err := ml.StopGRPC(only(o.modelID)) + if err != nil { + log.Error().Err(err).Str("model", o.modelID).Msg("error stopping model") + } log.Error().Str("modelID", o.modelID).Err(err).Msgf("Failed to load model %s with backend %s", o.modelID, o.backendString) return nil, err } @@ -180,8 +184,8 @@ func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err e return model.GRPC(o.parallelRequests, ml.wd), nil } -func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bool) { - if !singleActiveBackend { +func (ml *ModelLoader) stopActiveBackends(modelID string) { + if !ml.singletonMode { return } @@ -218,15 +222,19 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { // (avoid looping through all the backends) if m := ml.CheckIsLoaded(o.modelID); m != nil { log.Debug().Msgf("Model '%s' already loaded", o.modelID) - return m.GRPC(o.parallelRequests, ml.wd), nil } - ml.stopActiveBackends(o.modelID, ml.singletonMode) + ml.stopActiveBackends(o.modelID) // if a backend is defined, return the loader directly if o.backendString != "" { - return ml.backendLoader(opts...) + client, err := ml.backendLoader(opts...) + if err != nil { + ml.Close() + return nil, err + } + return client, nil } // Otherwise scan for backends in the asset directory @@ -242,6 +250,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { if len(autoLoadBackends) == 0 { log.Error().Msg("No backends found") + ml.Close() return nil, fmt.Errorf("no backends found") }