mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-04 03:32:40 -05:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5d82ba344 | ||
|
|
efe2883c5d | ||
|
|
47237c7c3c | ||
|
|
697c769b64 | ||
|
|
94261b1717 | ||
|
|
eaf85a30f9 | ||
|
|
6a88b030ea | ||
|
|
f538416fb3 |
2
Makefile
2
Makefile
@@ -8,7 +8,7 @@ GOLLAMA_VERSION?=aeba71ee842819da681ea537e78846dc75949ac0
|
||||
|
||||
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
|
||||
|
||||
CPPLLAMA_VERSION?=381ee195721d8e747ee31a60c0751822b3072f02
|
||||
CPPLLAMA_VERSION?=6f9939d119b2d004c264952eb510bd106455531e
|
||||
|
||||
# gpt4all version
|
||||
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
||||
|
||||
@@ -41,7 +41,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
||||
|
||||
var fn func() ([]float32, error)
|
||||
switch model := inferenceModel.(type) {
|
||||
case *grpc.Client:
|
||||
case grpc.Backend:
|
||||
fn = func() ([]float32, error) {
|
||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
||||
if len(tokens) > 0 {
|
||||
|
||||
@@ -31,7 +31,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
|
||||
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel *grpc.Client
|
||||
var inferenceModel grpc.Backend
|
||||
var err error
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
|
||||
@@ -158,8 +158,8 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
|
||||
//
|
||||
|
||||
enum task_type {
|
||||
COMPLETION_TASK,
|
||||
CANCEL_TASK
|
||||
TASK_TYPE_COMPLETION,
|
||||
TASK_TYPE_CANCEL,
|
||||
};
|
||||
|
||||
struct task_server {
|
||||
@@ -458,8 +458,12 @@ struct llama_client_slot
|
||||
}
|
||||
|
||||
bool has_budget(gpt_params &global_params) {
|
||||
if (params.n_predict == -1 && global_params.n_predict == -1)
|
||||
{
|
||||
return true; // limitless
|
||||
}
|
||||
n_remaining = -1;
|
||||
if(params.n_predict != -1)
|
||||
if (params.n_predict != -1)
|
||||
{
|
||||
n_remaining = params.n_predict - n_decoded;
|
||||
}
|
||||
@@ -467,7 +471,7 @@ struct llama_client_slot
|
||||
{
|
||||
n_remaining = global_params.n_predict - n_decoded;
|
||||
}
|
||||
return n_remaining > 0 || n_remaining == -1; // no budget || limitless
|
||||
return n_remaining > 0; // no budget
|
||||
}
|
||||
|
||||
bool available() const {
|
||||
@@ -1113,7 +1117,7 @@ struct llama_server_context
|
||||
}
|
||||
|
||||
// check the limits
|
||||
if (slot.n_decoded > 2 && slot.has_next_token && !slot.has_budget(params))
|
||||
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params))
|
||||
{
|
||||
slot.stopped_limit = true;
|
||||
slot.has_next_token = false;
|
||||
@@ -1177,8 +1181,9 @@ struct llama_server_context
|
||||
return slot.images.size() > 0;
|
||||
}
|
||||
|
||||
void send_error(task_server& task, std::string error)
|
||||
void send_error(task_server& task, const std::string &error)
|
||||
{
|
||||
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
task_result res;
|
||||
res.id = task.id;
|
||||
@@ -1276,7 +1281,7 @@ struct llama_server_context
|
||||
{
|
||||
std::vector<completion_token_output> probs_output = {};
|
||||
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
|
||||
size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size());
|
||||
size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size());
|
||||
size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size());
|
||||
if (probs_pos < probs_stop_pos)
|
||||
{
|
||||
@@ -1336,7 +1341,7 @@ struct llama_server_context
|
||||
{
|
||||
probs = std::vector<completion_token_output>(
|
||||
slot.generated_token_probs.begin(),
|
||||
slot.generated_token_probs.begin() + slot.sent_token_probs_index);
|
||||
slot.generated_token_probs.end());
|
||||
}
|
||||
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
|
||||
}
|
||||
@@ -1346,6 +1351,11 @@ struct llama_server_context
|
||||
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res.result_json["model"] = slot.oaicompat_model;
|
||||
}
|
||||
queue_results.push_back(res);
|
||||
condition_results.notify_all();
|
||||
|
||||
// done with results, unlock
|
||||
lock.unlock();
|
||||
|
||||
// parent multitask, if any, needs to be updated
|
||||
if (slot.multitask_id != -1)
|
||||
@@ -1353,8 +1363,6 @@ struct llama_server_context
|
||||
update_multi_task(slot.multitask_id, slot.task_id, res);
|
||||
}
|
||||
|
||||
queue_results.push_back(res);
|
||||
condition_results.notify_all();
|
||||
}
|
||||
|
||||
void send_embedding(llama_client_slot &slot)
|
||||
@@ -1399,11 +1407,11 @@ struct llama_server_context
|
||||
task.data = std::move(data);
|
||||
task.infill_mode = infill;
|
||||
task.embedding_mode = embedding;
|
||||
task.type = COMPLETION_TASK;
|
||||
task.type = TASK_TYPE_COMPLETION;
|
||||
task.multitask_id = multitask_id;
|
||||
|
||||
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
||||
if (task.data.at("prompt").size() > 1)
|
||||
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
|
||||
{
|
||||
lock.unlock(); // entering new func scope
|
||||
return split_multiprompt_task(task);
|
||||
@@ -1521,7 +1529,7 @@ struct llama_server_context
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
task_server task;
|
||||
task.id = id_gen++;
|
||||
task.type = CANCEL_TASK;
|
||||
task.type = TASK_TYPE_CANCEL;
|
||||
task.target_id = task_id;
|
||||
queue_tasks.push_back(task);
|
||||
condition_tasks.notify_one();
|
||||
@@ -1551,32 +1559,41 @@ struct llama_server_context
|
||||
void process_tasks()
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
std::vector<task_server> deferred_tasks;
|
||||
while (!queue_tasks.empty())
|
||||
{
|
||||
task_server task = queue_tasks.front();
|
||||
queue_tasks.erase(queue_tasks.begin());
|
||||
switch (task.type)
|
||||
{
|
||||
case COMPLETION_TASK: {
|
||||
case TASK_TYPE_COMPLETION: {
|
||||
llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
|
||||
if (slot == nullptr)
|
||||
{
|
||||
LOG_TEE("slot unavailable\n");
|
||||
// send error result
|
||||
send_error(task, "slot unavailable");
|
||||
return;
|
||||
// if no slot is available, we defer this task for processing later
|
||||
deferred_tasks.push_back(task);
|
||||
break;
|
||||
}
|
||||
|
||||
if (task.data.contains("system_prompt"))
|
||||
{
|
||||
if (!all_slots_are_idle) {
|
||||
send_error(task, "system prompt can only be updated when all slots are idle");
|
||||
break;
|
||||
}
|
||||
process_system_prompt_data(task.data["system_prompt"]);
|
||||
// reset cache_tokens for all slots
|
||||
for (llama_client_slot &slot : slots)
|
||||
{
|
||||
slot.cache_tokens.clear();
|
||||
}
|
||||
}
|
||||
|
||||
slot->reset();
|
||||
|
||||
slot->infill = task.infill_mode;
|
||||
slot->embedding = task.embedding_mode;
|
||||
slot->task_id = task.id;
|
||||
slot->infill = task.infill_mode;
|
||||
slot->embedding = task.embedding_mode;
|
||||
slot->task_id = task.id;
|
||||
slot->multitask_id = task.multitask_id;
|
||||
|
||||
if (!launch_slot_with_data(slot, task.data))
|
||||
@@ -1586,7 +1603,7 @@ struct llama_server_context
|
||||
break;
|
||||
}
|
||||
} break;
|
||||
case CANCEL_TASK: { // release slot linked with the task id
|
||||
case TASK_TYPE_CANCEL: { // release slot linked with the task id
|
||||
for (auto & slot : slots)
|
||||
{
|
||||
if (slot.task_id == task.target_id)
|
||||
@@ -1599,7 +1616,14 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
|
||||
// add all the deferred tasks back to the queue
|
||||
for (task_server &task : deferred_tasks)
|
||||
{
|
||||
queue_tasks.push_back(task);
|
||||
}
|
||||
|
||||
// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
|
||||
std::vector<task_result> agg_results;
|
||||
auto queue_iterator = queue_multitasks.begin();
|
||||
while (queue_iterator != queue_multitasks.end())
|
||||
{
|
||||
@@ -1620,8 +1644,7 @@ struct llama_server_context
|
||||
}
|
||||
aggregate_result.result_json = json{ "results", result_jsons };
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_results);
|
||||
queue_results.push_back(aggregate_result);
|
||||
agg_results.push_back(aggregate_result);
|
||||
condition_results.notify_all();
|
||||
|
||||
queue_iterator = queue_multitasks.erase(queue_iterator);
|
||||
@@ -1631,14 +1654,19 @@ struct llama_server_context
|
||||
++queue_iterator;
|
||||
}
|
||||
}
|
||||
// done with tasks, unlock
|
||||
lock.unlock();
|
||||
|
||||
// copy aggregate results of complete multi-tasks to the results queue
|
||||
std::lock_guard<std::mutex> lock_results(mutex_results);
|
||||
queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
|
||||
}
|
||||
|
||||
bool update_slots() {
|
||||
// attend tasks
|
||||
process_tasks();
|
||||
|
||||
// update the system prompt wait until all slots are idle state
|
||||
if (system_need_update && all_slots_are_idle)
|
||||
if (system_need_update)
|
||||
{
|
||||
LOG_TEE("updating system prompt\n");
|
||||
update_system_prompt();
|
||||
@@ -1714,7 +1742,6 @@ struct llama_server_context
|
||||
|
||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
slot.n_past += 1;
|
||||
}
|
||||
|
||||
@@ -1729,7 +1756,8 @@ struct llama_server_context
|
||||
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
|
||||
// note: infill mode allows empty prompt
|
||||
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
|
||||
{
|
||||
slot.release();
|
||||
slot.print_timings();
|
||||
@@ -1832,7 +1860,7 @@ struct llama_server_context
|
||||
|
||||
slot.cache_tokens = prompt_tokens;
|
||||
|
||||
if (slot.n_past == slot.num_prompt_tokens)
|
||||
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
||||
{
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
||||
@@ -1932,6 +1960,7 @@ struct llama_server_context
|
||||
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
if (slot.n_decoded == 1)
|
||||
{
|
||||
slot.t_start_genereration = ggml_time_us();
|
||||
@@ -2023,7 +2052,7 @@ json oaicompat_completion_params_parse(
|
||||
//
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
llama_sampling_params default_sparams;
|
||||
llama_params["model"] = json_value(body, "model", std::string("uknown"));
|
||||
llama_params["model"] = json_value(body, "model", std::string("unknown"));
|
||||
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
|
||||
llama_params["temperature"] = json_value(body, "temperature", 0.0);
|
||||
@@ -2095,8 +2124,8 @@ static json format_final_response_oaicompat(const json &request, const task_resu
|
||||
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
|
||||
{"usage",
|
||||
json{{"completion_tokens", num_tokens_predicted},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
|
||||
{"id", gen_chatcmplid()}};
|
||||
|
||||
if (server_verbose) {
|
||||
@@ -2436,10 +2465,10 @@ static void params_parse(const backend::ModelOptions* request,
|
||||
const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
|
||||
if (env_parallel != NULL) {
|
||||
params.n_parallel = std::stoi(env_parallel);
|
||||
params.cont_batching = true;
|
||||
} else {
|
||||
params.n_parallel = 1;
|
||||
}
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
if (!request->tensorsplit().empty()) {
|
||||
|
||||
@@ -6,10 +6,6 @@ weight = 14
|
||||
url = "/features/gpt-vision/"
|
||||
+++
|
||||
|
||||
{{% alert note %}}
|
||||
Available only on `master` builds
|
||||
{{% /alert %}}
|
||||
|
||||
LocalAI supports understanding images by using [LLaVA](https://llava.hliu.cc/), and implements the [GPT Vision API](https://platform.openai.com/docs/guides/vision) from OpenAI.
|
||||
|
||||

|
||||
@@ -28,4 +24,4 @@ curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/jso
|
||||
|
||||
### Setup
|
||||
|
||||
To setup the LLaVa models, follow the full example in the [configuration examples](https://github.com/mudler/LocalAI/blob/master/examples/configurations/README.md#llava).
|
||||
To set up the LLaVA models, follow the full example in the [configuration examples](https://github.com/mudler/LocalAI/blob/master/examples/configurations/README.md#llava).
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"version": "v2.5.1"
|
||||
"version": "v2.6.0"
|
||||
}
|
||||
|
||||
46
pkg/grpc/backend.go
Normal file
46
pkg/grpc/backend.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// embeds maps a backend address to an in-process backend registered via
// Provide. NewClient consults it before falling back to a gRPC client.
var embeds = make(map[string]*embedBackend)
|
||||
|
||||
func Provide(addr string, llm LLM) {
|
||||
embeds[addr] = &embedBackend{s: &server{llm: llm}}
|
||||
}
|
||||
|
||||
func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
|
||||
if bc, ok := embeds[address]; ok {
|
||||
return bc
|
||||
}
|
||||
return NewGrpcClient(address, parallel, wd, enableWatchDog)
|
||||
}
|
||||
|
||||
func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
|
||||
if !enableWatchDog {
|
||||
wd = nil
|
||||
}
|
||||
return &Client{
|
||||
address: address,
|
||||
parallel: parallel,
|
||||
wd: wd,
|
||||
}
|
||||
}
|
||||
|
||||
type Backend interface {
|
||||
IsBusy() bool
|
||||
HealthCheck(ctx context.Context) (bool, error)
|
||||
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
|
||||
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
|
||||
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
|
||||
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||
}
|
||||
@@ -27,17 +27,6 @@ type WatchDog interface {
|
||||
UnMark(address string)
|
||||
}
|
||||
|
||||
func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client {
|
||||
if !enableWatchDog {
|
||||
wd = nil
|
||||
}
|
||||
return &Client{
|
||||
address: address,
|
||||
parallel: parallel,
|
||||
wd: wd,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) IsBusy() bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
121
pkg/grpc/embed.go
Normal file
121
pkg/grpc/embed.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ Backend = new(embedBackend)
|
||||
var _ pb.Backend_PredictStreamServer = new(embedBackendServerStream)
|
||||
|
||||
type embedBackend struct {
|
||||
s *server
|
||||
}
|
||||
|
||||
func (e *embedBackend) IsBusy() bool {
|
||||
return e.s.llm.Busy()
|
||||
}
|
||||
|
||||
func (e *embedBackend) HealthCheck(ctx context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (e *embedBackend) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
|
||||
return e.s.Embedding(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
|
||||
return e.s.Predict(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.LoadModel(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
||||
bs := &embedBackendServerStream{
|
||||
ctx: ctx,
|
||||
fn: f,
|
||||
}
|
||||
return e.s.PredictStream(in, bs)
|
||||
}
|
||||
|
||||
func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.GenerateImage(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.TTS(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
|
||||
r, err := e.s.AudioTranscription(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr := &schema.Result{}
|
||||
for _, s := range r.Segments {
|
||||
var tks []int
|
||||
for _, t := range s.Tokens {
|
||||
tks = append(tks, int(t))
|
||||
}
|
||||
tr.Segments = append(tr.Segments,
|
||||
schema.Segment{
|
||||
Text: s.Text,
|
||||
Id: int(s.Id),
|
||||
Start: time.Duration(s.Start),
|
||||
End: time.Duration(s.End),
|
||||
Tokens: tks,
|
||||
})
|
||||
}
|
||||
tr.Text = r.Text
|
||||
return tr, err
|
||||
}
|
||||
|
||||
func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
|
||||
return e.s.TokenizeString(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) Status(ctx context.Context) (*pb.StatusResponse, error) {
|
||||
return e.s.Status(ctx, &pb.HealthMessage{})
|
||||
}
|
||||
|
||||
// embedBackendServerStream is the stream handed to the server's PredictStream
// by embedBackend.PredictStream. It delivers replies to a callback instead of
// writing them to a network stream.
type embedBackendServerStream struct {
	// ctx is returned by Context().
	ctx context.Context
	// fn receives the message bytes of every reply passed to Send.
	fn func(s []byte)
}
|
||||
|
||||
func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
|
||||
e.fn(reply.GetMessage())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) SetHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) SendHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) SetTrailer(md metadata.MD) {
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) Context() context.Context {
|
||||
return e.ctx
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) SendMsg(m any) error {
|
||||
if x, ok := m.(*pb.Reply); ok {
|
||||
return e.Send(x)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *embedBackendServerStream) RecvMsg(m any) error {
|
||||
return nil
|
||||
}
|
||||
@@ -181,3 +181,23 @@ func StartServer(address string, model LLM) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunServer(address string, model LLM) (func() error, error) {
|
||||
lis, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
pb.RegisterBackendServer(s, &server{llm: model})
|
||||
log.Printf("gRPC Server listening at %v", lis.Addr())
|
||||
if err = s.Serve(lis); err != nil {
|
||||
return func() error {
|
||||
return lis.Close()
|
||||
}, err
|
||||
}
|
||||
|
||||
return func() error {
|
||||
s.GracefulStop()
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
|
||||
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) {
|
||||
if parallel {
|
||||
return addr.GRPC(parallel, ml.wd), nil
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.C
|
||||
return ml.grpcClients[string(addr)], nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err error) {
|
||||
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
|
||||
o := NewOptions(opts...)
|
||||
|
||||
if o.model != "" {
|
||||
@@ -220,7 +220,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err e
|
||||
return ml.resolveAddress(addr, o.parallelRequests)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
||||
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
|
||||
o := NewOptions(opts...)
|
||||
|
||||
ml.mu.Lock()
|
||||
|
||||
@@ -59,7 +59,7 @@ type ModelLoader struct {
|
||||
ModelPath string
|
||||
mu sync.Mutex
|
||||
// TODO: this needs generics
|
||||
grpcClients map[string]*grpc.Client
|
||||
grpcClients map[string]grpc.Backend
|
||||
models map[string]ModelAddress
|
||||
grpcProcesses map[string]*process.Process
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
@@ -68,7 +68,7 @@ type ModelLoader struct {
|
||||
|
||||
type ModelAddress string
|
||||
|
||||
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
|
||||
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
|
||||
enableWD := false
|
||||
if wd != nil {
|
||||
enableWD = true
|
||||
@@ -79,7 +79,7 @@ func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
|
||||
func NewModelLoader(modelPath string) *ModelLoader {
|
||||
nml := &ModelLoader{
|
||||
ModelPath: modelPath,
|
||||
grpcClients: make(map[string]*grpc.Client),
|
||||
grpcClients: make(map[string]grpc.Backend),
|
||||
models: make(map[string]ModelAddress),
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
grpcProcesses: make(map[string]*process.Process),
|
||||
@@ -163,7 +163,7 @@ func (ml *ModelLoader) StopModel(modelName string) error {
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
|
||||
var client *grpc.Client
|
||||
var client grpc.Backend
|
||||
if m, ok := ml.models[s]; ok {
|
||||
log.Debug().Msgf("Model already loaded in memory: %s", s)
|
||||
if c, ok := ml.grpcClients[s]; ok {
|
||||
|
||||
Reference in New Issue
Block a user