diff --git a/.gitignore b/.gitignore index 2ee2ab858..f9f861c22 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,8 @@ LocalAI models/* test-models/ test-dir/ +tests/e2e-aio/backends +tests/e2e-aio/models release/ diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index 047f0ab88..4a6ab6162 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -130,8 +130,9 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR segments := []*pb.TranscriptSegment{} text := "" for i := range int(segsLen) { - s := CppGetSegmentStart(i) - t := CppGetSegmentEnd(i) + // segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895 + s := CppGetSegmentStart(i) * (10000000) + t := CppGetSegmentEnd(i) * (10000000) txt := strings.Clone(CppGetSegmentText(i)) tokens := make([]int32, CppNTokens(i)) diff --git a/backend/python/faster-whisper/backend.py b/backend/python/faster-whisper/backend.py index df259420c..c94665b2b 100755 --- a/backend/python/faster-whisper/backend.py +++ b/backend/python/faster-whisper/backend.py @@ -40,7 +40,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): device = "mps" try: print("Preparing models, please wait", file=sys.stderr) - self.model = WhisperModel(request.Model, device=device, compute_type="float16") + self.model = WhisperModel(request.Model, device=device, compute_type="default") except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service @@ -55,11 +55,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): id = 0 for segment in segments: print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text)) - resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=segment.start, end=segment.end, text=segment.text)) + resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=int(segment.start*1e9), end=int(segment.end*1e9), text=segment.text)) text += segment.text - id += 1 + id += 1 except Exception as err: print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr) + raise err return backend_pb2.TranscriptResult(segments=resultSegments, text=text) diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 66e687813..62b04874c 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -12,8 +12,7 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { - +func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { if modelConfig.Backend == "" { modelConfig.Backend = model.WhisperBackend } diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 07da19893..78c232558 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -2,32 +2,42 @@ package cli import ( "context" + "encoding/json" "errors" "fmt" + "strings" "github.com/mudler/LocalAI/core/backend" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/schema" +
"github.com/mudler/LocalAI/pkg/format" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type TranscriptCMD struct { - Filename string `arg:""` + Filename string `arg:"" name:"file" help:"Audio file to transcribe" type:"path"` - Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"` - Model string `short:"m" required:"" help:"Model name to run the TTS"` - Language string `short:"l" help:"Language of the audio file"` - Translate bool `short:"c" help:"Translate the transcription to english"` - Diarize bool `short:"d" help:"Mark speaker turns"` - Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"` - ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` - Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"` + Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"` + Model string `short:"m" required:"" help:"Model name to run the TTS"` + Language string `short:"l" help:"Language of the audio file"` + Translate bool `short:"c" help:"Translate the transcription to English"` + Diarize bool `short:"d" help:"Mark speaker turns"` + Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"` + BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` + Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"` + ResponseFormat schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, json_verbose)"` + PrettyPrint bool `help:"Used with response_format json or json_verbose for pretty printing"` } func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { systemState, err := system.GetSystemState( + system.WithBackendPath(t.BackendsPath), system.WithModelPath(t.ModelsPath), ) if err != nil { @@ -40,6 +50,11 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { cl := config.NewModelConfigLoader(t.ModelsPath) ml := model.NewModelLoader(systemState) + + if err := gallery.RegisterBackends(systemState, ml); err != nil { + xlog.Error("error registering external backends", "error", err) + } + if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil { return err } @@ -62,8 +77,29 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { if err != nil { return err } - for _, segment := range tr.Segments { - fmt.Println(segment.Start.String(), "-", segment.Text) + + switch t.ResponseFormat { + case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText: + fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat)) + case schema.TranscriptionResponseFormatJson: + tr.Segments = nil + fallthrough + case schema.TranscriptionResponseFormatJsonVerbose: + var mtr []byte 
+ var err error + if t.PrettyPrint { + mtr, err = json.MarshalIndent(tr, "", " ") + } else { + mtr, err = json.Marshal(tr) + } + if err != nil { + return err + } + fmt.Println(string(mtr)) + default: + for _, segment := range tr.Segments { + fmt.Println(segment.Start.String(), "-", strings.TrimSpace(segment.Text)) + } } return nil } diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 2c5f98d5c..c52fe1914 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -1,6 +1,7 @@ package openai import ( + "errors" "io" "net/http" "os" @@ -12,6 +13,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/format" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" @@ -38,6 +40,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app diarize := c.FormValue("diarize") != "false" prompt := c.FormValue("prompt") + responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format")) // retrieve the file data from the request file, err := c.FormFile("file") @@ -76,7 +79,17 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app } xlog.Debug("Transcribed", "transcription", tr) - // TODO: handle different outputs here - return c.JSON(http.StatusOK, tr) + + switch responseFormat { + case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt: + return c.String(http.StatusOK, format.TranscriptionResponse(tr, responseFormat)) + case schema.TranscriptionResponseFormatJson: + tr.Segments = nil + fallthrough + case schema.TranscriptionResponseFormatJsonVerbose, "": // maintain backwards compatibility + return c.JSON(http.StatusOK, tr) + default: + return errors.New("invalid response_format") + } } } diff --git a/core/schema/openai.go b/core/schema/openai.go index 74ed2859e..e94b99ba7 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -107,6 +107,17 @@ type ImageGenerationResponseFormat string type ChatCompletionResponseFormatType string +type TranscriptionResponseFormatType string + +const ( + TranscriptionResponseFormatText = TranscriptionResponseFormatType("txt") + TranscriptionResponseFormatSrt = TranscriptionResponseFormatType("srt") + TranscriptionResponseFormatVtt = TranscriptionResponseFormatType("vtt") + TranscriptionResponseFormatLrc = TranscriptionResponseFormatType("lrc") + TranscriptionResponseFormatJson = TranscriptionResponseFormatType("json") + TranscriptionResponseFormatJsonVerbose = TranscriptionResponseFormatType("json_verbose") +) + type ChatCompletionResponseFormat struct { Type ChatCompletionResponseFormatType `json:"type,omitempty"` } diff --git a/core/schema/transcription.go b/core/schema/transcription.go index 492030e56..d843a9d98 100644 --- a/core/schema/transcription.go +++ b/core/schema/transcription.go @@ -11,6 +11,6 @@ type TranscriptionSegment struct { } type TranscriptionResult struct { - Segments []TranscriptionSegment `json:"segments"` + Segments []TranscriptionSegment `json:"segments,omitempty"` Text string `json:"text"` } diff --git a/core/startup/model_preload.go b/core/startup/model_preload.go index 985276cf4..933114518 100644 --- a/core/startup/model_preload.go +++ b/core/startup/model_preload.go @@ -18,10 +18,6 @@ import ( 
"github.com/mudler/xlog" ) -const ( - YAML_EXTENSION = ".yaml" -) - // InstallModels will preload models from the given list of URLs and galleries // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration diff --git a/docs/content/features/audio-to-text.md b/docs/content/features/audio-to-text.md index 2b91a8071..2b5e9b2cb 100644 --- a/docs/content/features/audio-to-text.md +++ b/docs/content/features/audio-to-text.md @@ -7,7 +7,9 @@ url = "/features/audio-to-text/" Audio to text models are models that can generate text from an audio file. -The transcription endpoint allows to convert audio files to text. The endpoint is based on [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ library for audio transcription. The endpoint input supports all the audio formats supported by `ffmpeg`. +The transcription endpoint allows to convert audio files to text. The endpoint is based +on [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ library for audio transcription. The endpoint input +supports all the audio formats supported by `ffmpeg`. ## Usage @@ -21,7 +23,8 @@ curl http://localhost:8080/v1/audio/transcriptions -H "Content-Type: multipart/f ## Example -Download one of the models from [here](https://huggingface.co/ggerganov/whisper.cpp/tree/main) in the `models` folder, and create a YAML file for your model: +Download one of the models from [here](https://huggingface.co/ggerganov/whisper.cpp/tree/main) in the `models` folder, +and create a YAML file for your model: ```yaml name: whisper-1 @@ -38,7 +41,48 @@ wget --quiet --show-progress -O gb1.ogg https://upload.wikimedia.org/wikipedia/c ## Send the example audio file to the transcriptions endpoint curl http://localhost:8080/v1/audio/transcriptions -H "Content-Type: multipart/form-data" -F file="@$PWD/gb1.ogg" -F model="whisper-1" +``` -## Result -{"text":"My fellow Americans, this day has brought terrible news and great sadness to our country.At nine o'clock this morning, Mission Control in Houston lost contact with our Space ShuttleColumbia.A short time later, debris was seen falling from the skies above Texas.The Columbia's lost.There are no survivors.One board was a crew of seven.Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain DavidBrown, Commander William McCool, Dr. 
Kultna Shavla, and Elon Ramon, a colonel in the IsraeliAir Force.These men and women assumed great risk in the service to all humanity.In an age when spaceflight has come to seem almost routine, it is easy to overlook thedangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere ofthe Earth.These astronauts knew the dangers, and they faced them willingly, knowing they had a highand noble purpose in life.Because of their courage and daring and idealism, we will miss them all the more.All Americans today are thinking as well of the families of these men and women who havebeen given this sudden shock and grief.You're not alone.Our entire nation agrees with you, and those you loved will always have the respect andgratitude of this country.The cause in which they died will continue.Mankind has led into the darkness beyond our world by the inspiration of discovery andthe longing to understand.Our journey into space will go on.In the skies today, we saw destruction and tragedy.As farther than we can see, there is comfort and hope.In the words of the prophet Isaiah, \"Lift your eyes and look to the heavens who createdall these, he who brings out the starry hosts one by one and calls them each by name.\"Because of his great power and mighty strength, not one of them is missing.The same creator who names the stars also knows the names of the seven souls we mourntoday.The crew of the shuttle Columbia did not return safely to Earth yet we can pray that all aresafely home.May God bless the grieving families and may God continue to bless America.[BLANK_AUDIO]"} -``` \ No newline at end of file +Result: + +```json +{ + "segments":[{"id":0,"start":0,"end":9640000000,"text":" My fellow Americans, this day has brought terrible news and great sadness to our country.","tokens":[50364,1222,7177,6280,11,341,786,575,3038,6237,2583,293,869,22462,281,527,1941,13,50846]},{"id":1,"start":9640000000,"end":15960000000,"text":" At 9 o'clock this morning, Mission Control and Houston lost contact with our Space Shuttle","tokens":[1711,1722,277,6,9023,341,2446,11,20170,12912,293,18717,2731,3385,365,527,8705,13870,10972,51162]},{"id":2,"start":15960000000,"end":16960000000,"text":" Columbia.","tokens":[17339,13,51212]},{"id":3,"start":16960000000,"end":24640000000,"text":" A short time later, debris was seen falling from the skies above Texas.","tokens":[316,2099,565,1780,11,21942,390,1612,7440,490,264,25861,3673,7885,13,51596]},{"id":4,"start":24640000000,"end":27200000000,"text":" The Columbia's lost.","tokens":[440,17339,311,2731,13,51724]},{"id":5,"start":27200000000,"end":29920000000,"text":" There are no survivors.","tokens":[821,366,572,18369,13,51860]},{"id":6,"start":29920000000,"end":32920000000,"text":" And board was a crew of seven.","tokens":[50364,400,3150,390,257,7260,295,3407,13,50514]},{"id":7,"start":32920000000,"end":39780000000,"text":" Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain","tokens":[28478,11224,21282,4235,11,28412,28478,5116,18768,11,20857,27270,75,18572,11,10873,50857]},{"id":8,"start":39780000000,"end":50020000000,"text":" David Brown, Commander William McCool, Dr. 
Cooltna Chavla, and Elon Ramon, a Colonel","tokens":[4389,8030,11,20857,6740,4050,34,1092,11,2491,13,8561,83,629,761,706,875,11,293,28498,9078,266,11,257,28478,51369]},{"id":9,"start":50020000000,"end":52800000000,"text":" in the Israeli Air Force.","tokens":[294,264,19974,5774,10580,13,51508]},{"id":10,"start":52800000000,"end":58480000000,"text":" These men and women assumed great risk in the service to all humanity.","tokens":[1981,1706,293,2266,15895,869,3148,294,264,2643,281,439,10243,13,51792]},{"id":11,"start":58480000000,"end":63120000000,"text":" And an age when Space Flight has come to seem almost routine.","tokens":[50364,400,364,3205,562,8705,28954,575,808,281,1643,1920,9927,13,50596]},{"id":12,"start":63120000000,"end":68800000000,"text":" It is easy to overlook the dangers of travel by rocket and the difficulties of navigating","tokens":[467,307,1858,281,37826,264,27701,295,3147,538,13012,293,264,14399,295,32054,50880]},{"id":13,"start":68800000000,"end":72640000000,"text":" the fierce outer atmosphere of the Earth.","tokens":[264,25341,10847,8018,295,264,4755,13,51072]},{"id":14,"start":72640000000,"end":78040000000,"text":" These astronauts knew the dangers and they faced them willingly.","tokens":[1981,28273,2586,264,27701,293,436,11446,552,44675,13,51342]},{"id":15,"start":78040000000,"end":83040000000,"text":" Knowing they had a high and noble purpose in life.","tokens":[25499,436,632,257,1090,293,20171,4334,294,993,13,51592]},{"id":16,"start":83040000000,"end":90800000000,"text":" Because of their courage and daring and idealism, we will miss them all the more.","tokens":[50364,1436,295,641,9892,293,43128,293,7157,1434,11,321,486,1713,552,439,264,544,13,50752]},{"id":17,"start":90800000000,"end":96560000000,"text":" All Americans today are thinking as well of the families of these men and women who have","tokens":[1057,6280,965,366,1953,382,731,295,264,4466,295,613,1706,293,2266,567,362,51040]},{"id":18,"start":96560000000,"end":100440000000,"text":" been given this sudden shock in grief.","tokens":[668,2212,341,3990,5588,294,18998,13,51234]},{"id":19,"start":100440000000,"end":102400000000,"text":" You're not alone.","tokens":[509,434,406,3312,13,51332]},{"id":20,"start":102400000000,"end":105440000000,"text":" Our entire nation agrees with you.","tokens":[2621,2302,4790,26383,365,291,13,51484]},{"id":21,"start":105440000000,"end":112360000000,"text":" And those you loved will always have the respect and gratitude of this country.","tokens":[400,729,291,4333,486,1009,362,264,3104,293,16935,295,341,1941,13,51830]},{"id":22,"start":112360000000,"end":116600000000,"text":" The cause in which they died will continue.","tokens":[50364,440,3082,294,597,436,4539,486,2354,13,50576]},{"id":23,"start":116600000000,"end":124240000000,"text":" Man kind is led into the darkness beyond our world by the inspiration of discovery and the","tokens":[2458,733,307,4684,666,264,11262,4399,527,1002,538,264,10249,295,12114,293,264,50958]},{"id":24,"start":124240000000,"end":127000000000,"text":" longing to understand.","tokens":[35050,281,1223,13,51096]},{"id":25,"start":127000000000,"end":131160000000,"text":" Our journey into space will go on.","tokens":[2621,4671,666,1901,486,352,322,13,51304]},{"id":26,"start":131160000000,"end":136480000000,"text":" In the skies today, we saw destruction and tragedy.","tokens":[682,264,25861,965,11,321,1866,13563,293,18563,13,51570]},{"id":27,"start":136480000000,"end":142080000000,"text":" As farther than we can see, there is comfort and 
hope.","tokens":[1018,20344,813,321,393,536,11,456,307,3400,293,1454,13,51850]},{"id":28,"start":142080000000,"end":149800000000,"text":" In the words of the prophet Isaiah, lift your eyes and look to the heavens who created","tokens":[50364,682,264,2283,295,264,18566,27263,11,5533,428,2575,293,574,281,264,26011,567,2942,50750]},{"id":29,"start":149800000000,"end":151640000000,"text":" all these.","tokens":[439,613,13,50842]},{"id":30,"start":151640000000,"end":159960000000,"text":" He who brings out the story hosts one by one and calls them each by name because of his great","tokens":[634,567,5607,484,264,1657,21573,472,538,472,293,5498,552,1184,538,1315,570,295,702,869,51258]},{"id":31,"start":159960000000,"end":163400000000,"text":" power and mighty strength.","tokens":[1347,293,21556,3800,13,51430]},{"id":32,"start":163400000000,"end":166400000000,"text":" Not one of them is missing.","tokens":[1726,472,295,552,307,5361,13,51580]},{"id":33,"start":166400000000,"end":173600000000,"text":" The same creator who names the stars also knows the names of the seven souls we mourn","tokens":[50364,440,912,14181,567,5288,264,6105,611,3255,264,5288,295,264,3407,16588,321,22235,77,50724]},{"id":34,"start":173600000000,"end":175600000000,"text":" today.","tokens":[965,13,50824]},{"id":35,"start":175600000000,"end":183160000000,"text":" The crew of the shuttle Columbia did not return safely to earth yet we can pray that all","tokens":[440,7260,295,264,26728,17339,630,406,2736,11750,281,4120,1939,321,393,3690,300,439,51202]},{"id":36,"start":183160000000,"end":185840000000,"text":" are safely home.","tokens":[366,11750,1280,13,51336]},{"id":37,"start":185840000000,"end":192600000000,"text":" May God bless the grieving families and may God continue to bless America.","tokens":[1891,1265,5227,264,48454,4466,293,815,1265,2354,281,5227,3374,13,51674]},{"id":38,"start":196400000000,"end":206400000000,"text":" [BLANK_AUDIO]","tokens":[50364,542,37592,62,29937,60,50864]}], + "text":"My fellow Americans, this day has brought terrible news and great sadness to our country. At 9 o'clock this morning, Mission Control and Houston lost contact with our Space Shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. And board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Cooltna Chavla, and Elon Ramon, a Colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. And an age when Space Flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the Earth. These astronauts knew the dangers and they faced them willingly. Knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. All Americans today are thinking as well of the families of these men and women who have been given this sudden shock in grief. You're not alone. Our entire nation agrees with you. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Man kind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. 
As farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens who created all these. He who brings out the story hosts one by one and calls them each by name because of his great power and mighty strength. Not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to earth yet we can pray that all are safely home. May God bless the grieving families and may God continue to bless America. [BLANK_AUDIO]" +} +``` + +--- + +You can also set the `response_format` parameter to one of `lrc`, `srt`, `vtt`, `txt`, `json` or `json_verbose` (the default): +```bash +## Send the example audio file to the transcriptions endpoint +curl http://localhost:8080/v1/audio/transcriptions -H "Content-Type: multipart/form-data" -F file="@$PWD/gb1.ogg" -F model="whisper-1" -F response_format="srt" +``` + +Result (first few lines): +```text +1 +00:00:00,000 --> 00:00:09,640 +My fellow Americans, this day has brought terrible news and great sadness to our country. + +2 +00:00:09,640 --> 00:00:15,960 +At 9 o'clock this morning, Mission Control and Houston lost contact with our Space Shuttle + +3 +00:00:15,960 --> 00:00:16,960 +Columbia. + +4 +00:00:16,960 --> 00:00:24,640 +A short time later, debris was seen falling from the skies above Texas. + +5 +00:00:24,640 --> 00:00:27,200 +The Columbia's lost. + +6 +00:00:27,200 --> 00:00:29,920 +There are no survivors. +``` diff --git a/go.mod b/go.mod index c94f80b7d..24529b74b 100644 --- a/go.mod +++ b/go.mod @@ -67,9 +67,10 @@ require ( require ( github.com/ghodss/yaml v1.0.0 // indirect github.com/labstack/gommon v0.4.2 // indirect + github.com/openai/openai-go/v3 v3.17.0 // indirect github.com/swaggo/files/v2 v2.0.2 // indirect github.com/tidwall/gjson v1.18.0 // indirect - github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect diff --git a/go.sum b/go.sum index cf3ed33bf..142b8af9e 100644 --- a/go.sum +++ b/go.sum @@ -565,6 +565,8 @@ github.com/onsi/ginkgo/v2 v2.27.5 h1:ZeVgZMx2PDMdJm/+w5fE/OyG6ILo1Y3e+QX4zSR0zTE github.com/onsi/ginkgo/v2 v2.27.5/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= github.com/onsi/gomega v1.39.0 h1:y2ROC3hKFmQZJNFeGAMeHZKkjBL65mIZcvrLQBF9k6Q= github.com/onsi/gomega v1.39.0/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4= +github.com/openai/openai-go/v3 v3.17.0 h1:CfTkmQoItolSyW+bHOUF190KuX5+1Zv6MC0Gb4wAwy8= +github.com/openai/openai-go/v3 v3.17.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -769,6 +771,8 @@ github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= diff --git a/pkg/format/transcription.go b/pkg/format/transcription.go new file mode 100644 index 000000000..1625d02bb --- /dev/null +++ b/pkg/format/transcription.go @@ -0,0 +1,41 @@ +package format + +import ( + "fmt" + "strings" + "time" + + "github.com/mudler/LocalAI/core/schema" +) + +func TranscriptionResponse(tr *schema.TranscriptionResult, resFmt schema.TranscriptionResponseFormatType) string { + var out string + if resFmt == schema.TranscriptionResponseFormatLrc { + out = "[by:LocalAI]\n[re:LocalAI]\n" + } else if resFmt == schema.TranscriptionResponseFormatVtt { + out = "WEBVTT" + } + + for i, s := range tr.Segments { + switch resFmt { + case schema.TranscriptionResponseFormatLrc: + m := s.Start.Milliseconds() + out += fmt.Sprintf("\n[%02d:%02d.%02d] %s", m/60000, (m/1000)%60, (m%1000)/10, strings.TrimSpace(s.Text)) + case schema.TranscriptionResponseFormatSrt: + out += fmt.Sprintf("\n\n%d\n%s --> %s\n%s", i+1, durationStr(s.Start, ','), durationStr(s.End, ','), strings.TrimSpace(s.Text)) + case schema.TranscriptionResponseFormatVtt: + out += fmt.Sprintf("\n\n%s --> %s\n%s\n", durationStr(s.Start, '.'), durationStr(s.End, '.'), strings.TrimSpace(s.Text)) + case schema.TranscriptionResponseFormatText: + fallthrough + default: + out += fmt.Sprintf("\n%s", strings.TrimSpace(s.Text)) + } + } + + return out +} + +func durationStr(d time.Duration, millisSeparator rune) string { + m := d.Milliseconds() + return fmt.Sprintf("%02d:%02d:%02d%c%03d", m/3600000, (m/60000)%60, int(d.Seconds())%60, millisSeparator, m%1000) +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 78910f637..45bd3b6af 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -17,7 +17,7 @@ const ( LLamaCPP = "llama-cpp" ) -var Aliases map[string]string = map[string]string{ +var Aliases = map[string]string{ "go-llama": LLamaCPP, "llama": LLamaCPP, "embedded-store": LocalStoreBackend, @@ -29,7 +29,7 @@ var Aliases map[string]string = map[string]string{ "stablediffusion": StableDiffusionGGMLBackend, } -var TypeAlias map[string]string = map[string]string{ +var TypeAlias = map[string]string{ "sentencetransformers": "SentenceTransformer", "huggingface-embeddings": "SentenceTransformer", "mamba": "Mamba", @@ -75,7 +75,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Check if the backend is provided as external if uri, ok := ml.GetAllExternalBackends(o)[backend]; ok { xlog.Debug("Loading external backend", "uri", uri) - // check if uri is a file or a address + // check if uri is a file or an address if fi, err := os.Stat(uri); err == nil { xlog.Debug("external backend is file", "file", fi) serverAddress, err := getFreeAddress() diff --git a/tests/e2e-aio/e2e_suite_test.go b/tests/e2e-aio/e2e_suite_test.go index 0aee7b81f..85c8b0e3e 100644 --- a/tests/e2e-aio/e2e_suite_test.go +++ b/tests/e2e-aio/e2e_suite_test.go @@ -11,13 +11,14 @@ import ( "github.com/docker/go-connections/nat" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" - "github.com/sashabaranov/go-openai" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) var container testcontainers.Container -var client *openai.Client +var client openai.Client var containerImage = os.Getenv("LOCALAI_IMAGE") var containerImageTag = os.Getenv("LOCALAI_IMAGE_TAG") @@ -37,26 +38,22 @@ func TestLocalAI(t *testing.T) { var _ = BeforeSuite(func() { - var defaultConfig openai.ClientConfig if apiEndpoint == "" { startDockerImage() - apiPort, err := container.MappedPort(context.Background(), nat.Port(defaultApiPort)) + apiPort, err := container.MappedPort(context.Background(), defaultApiPort) Expect(err).To(Not(HaveOccurred())) - defaultConfig = openai.DefaultConfig(apiKey) apiEndpoint = "http://localhost:" + apiPort.Port() + "/v1" // So that other tests can reference this value safely. - defaultConfig.BaseURL = apiEndpoint } else { GinkgoWriter.Printf("docker apiEndpoint set from env: %q\n", apiEndpoint) - defaultConfig = openai.DefaultConfig(apiKey) - defaultConfig.BaseURL = apiEndpoint } + opts := []option.RequestOption{option.WithAPIKey(apiKey), option.WithBaseURL(apiEndpoint)} // Wait for API to be ready - client = openai.NewClientWithConfig(defaultConfig) + client = openai.NewClient(opts...) Eventually(func() error { - _, err := client.ListModels(context.TODO()) + _, err := client.Models.List(context.TODO()) return err }, "50m").ShouldNot(HaveOccurred()) }) diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go index 8421772a9..c3113a9a6 100644 --- a/tests/e2e-aio/e2e_test.go +++ b/tests/e2e-aio/e2e_test.go @@ -12,8 +12,8 @@ import ( "github.com/mudler/LocalAI/core/schema" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/jsonschema" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" ) var _ = Describe("E2E test", func() { @@ -30,14 +30,13 @@ var _ = Describe("E2E test", func() { Context("text", func() { It("correctly", func() { model := "gpt-4" - resp, err := client.CreateChatCompletion(context.TODO(), - openai.ChatCompletionRequest{ - Model: model, Messages: []openai.ChatCompletionMessage{ - { - Role: "user", - Content: "How much is 2+2?", - }, - }}) + resp, err := client.Chat.Completions.New(context.TODO(), + openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("How much is 2+2?"), + }, + }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content)) @@ -46,39 +45,36 @@ var _ = Describe("E2E test", func() { Context("function calls", func() { It("correctly invoke", func() { - params := jsonschema.Definition{ - Type: jsonschema.Object, - Properties: map[string]jsonschema.Definition{ - "location": { - Type: jsonschema.String, - Description: "The city and state, e.g. San Francisco, CA", + params := openai.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "location": map[string]string{ + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", }, - "unit": { - Type: jsonschema.String, - Enum: []string{"celsius", "fahrenheit"}, + "unit": map[string]any{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, }, }, - Required: []string{"location"}, + "required": []string{"location"}, } - f := openai.FunctionDefinition{ - Name: "get_current_weather", - Description: "Get the current weather in a given location", - Parameters: params, - } - t := openai.Tool{ - Type: openai.ToolTypeFunction, - Function: &f, + tool := openai.ChatCompletionToolUnionParam{ + OfFunction: &openai.ChatCompletionFunctionToolParam{ + Function: openai.FunctionDefinitionParam{ + Name: "get_current_weather", + Description: openai.String("Get the current weather in a given location"), + Parameters: params, + }, + }, } - dialogue := []openai.ChatCompletionMessage{ - {Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, - } - resp, err := client.CreateChatCompletion(context.TODO(), - openai.ChatCompletionRequest{ - Model: openai.GPT4, - Messages: dialogue, - Tools: []openai.Tool{t}, + resp, err := client.Chat.Completions.New(context.TODO(), + openai.ChatCompletionNewParams{ + Model: openai.ChatModelGPT4, + Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("What is the weather in Boston today?")}, + Tools: []openai.ChatCompletionToolUnionParam{tool}, }, ) Expect(err).ToNot(HaveOccurred()) @@ -90,23 +86,21 @@ var _ = Describe("E2E test", func() { Expect(msg.ToolCalls[0].Function.Arguments).To(ContainSubstring("Boston"), fmt.Sprint(msg.ToolCalls[0].Function.Arguments)) }) }) + Context("json", func() { It("correctly", func() { model := "gpt-4" - req := openai.ChatCompletionRequest{ - ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}, - Model: model, - Messages: []openai.ChatCompletionMessage{ - { - - Role: "user", - Content: "Generate a JSON object of an animal with 'name', 'gender' and 'legs' fields", + resp, err := client.Chat.Completions.New(context.TODO(), + openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Generate a JSON object of an animal with 'name', 'gender' and 'legs' fields"), }, - }, - } - - resp, err := client.CreateChatCompletion(context.TODO(), req) + ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{ + OfJSONObject: &openai.ResponseFormatJSONObjectParam{}, + }, + }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) @@ -121,23 +115,23 @@ var _ = Describe("E2E test", func() { Context("images", func() { It("correctly", func() { - req := openai.ImageRequest{ - Prompt: "test", - Quality: "1", - Size: openai.CreateImageSize256x256, - } - resp, err := client.CreateImage(context.TODO(), req) - Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request %+v", req)) + resp, err := client.Images.Generate(context.TODO(), + openai.ImageGenerateParams{ + Prompt: "test", + Size: openai.ImageGenerateParamsSize256x256, + Quality: openai.ImageGenerateParamsQualityLow, + }) + Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request")) Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL)) }) It("correctly changes the response format to url", func() { - resp, err := client.CreateImage(context.TODO(), - openai.ImageRequest{ + resp, err := client.Images.Generate(context.TODO(), + 
openai.ImageGenerateParams{ Prompt: "test", - Size: openai.CreateImageSize256x256, - Quality: "1", - ResponseFormat: openai.CreateImageResponseFormatURL, + Size: openai.ImageGenerateParamsSize256x256, + ResponseFormat: openai.ImageGenerateParamsResponseFormatURL, + Quality: openai.ImageGenerateParamsQualityLow, }, ) Expect(err).ToNot(HaveOccurred()) @@ -145,12 +139,11 @@ var _ = Describe("E2E test", func() { Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL)) }) It("correctly changes the response format to base64", func() { - resp, err := client.CreateImage(context.TODO(), - openai.ImageRequest{ + resp, err := client.Images.Generate(context.TODO(), + openai.ImageGenerateParams{ Prompt: "test", - Size: openai.CreateImageSize256x256, - Quality: "1", - ResponseFormat: openai.CreateImageResponseFormatB64JSON, + Size: openai.ImageGenerateParamsSize256x256, + ResponseFormat: openai.ImageGenerateParamsResponseFormatB64JSON, }, ) Expect(err).ToNot(HaveOccurred()) @@ -158,22 +151,27 @@ var _ = Describe("E2E test", func() { Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON)) }) }) + Context("embeddings", func() { It("correctly", func() { - resp, err := client.CreateEmbeddings(context.TODO(), - openai.EmbeddingRequestStrings{ - Input: []string{"doc"}, - Model: openai.AdaEmbeddingV2, + resp, err := client.Embeddings.New(context.TODO(), + openai.EmbeddingNewParams{ + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"doc"}, + }, + Model: openai.EmbeddingModelTextEmbeddingAda002, }, ) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) Expect(resp.Data[0].Embedding).ToNot(BeEmpty()) - resp2, err := client.CreateEmbeddings(context.TODO(), - openai.EmbeddingRequestStrings{ - Input: []string{"cat"}, - Model: openai.AdaEmbeddingV2, + resp2, err := client.Embeddings.New(context.TODO(), + openai.EmbeddingNewParams{ + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"cat"}, + }, + Model: openai.EmbeddingModelTextEmbeddingAda002, }, ) Expect(err).ToNot(HaveOccurred()) @@ -181,10 +179,12 @@ var _ = Describe("E2E test", func() { Expect(resp2.Data[0].Embedding).ToNot(BeEmpty()) Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding)) - resp3, err := client.CreateEmbeddings(context.TODO(), - openai.EmbeddingRequestStrings{ - Input: []string{"doc", "cat"}, - Model: openai.AdaEmbeddingV2, + resp3, err := client.Embeddings.New(context.TODO(), + openai.EmbeddingNewParams{ + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"doc", "cat"}, + }, + Model: openai.EmbeddingModelTextEmbeddingAda002, }, ) Expect(err).ToNot(HaveOccurred()) @@ -195,66 +195,101 @@ var _ = Describe("E2E test", func() { Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding)) }) }) + Context("vision", func() { It("correctly", func() { model := "gpt-4o" - resp, err := client.CreateChatCompletion(context.TODO(), - openai.ChatCompletionRequest{ - Model: model, Messages: []openai.ChatCompletionMessage{ + resp, err := client.Chat.Completions.New(context.TODO(), + openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ { - - Role: "user", - MultiContent: []openai.ChatMessagePart{ - { - Type: openai.ChatMessagePartTypeText, - Text: "What is in the image?", - }, - { - Type: openai.ChatMessagePartTypeImageURL, - ImageURL: &openai.ChatMessageImageURL{ - URL: "https://picsum.photos/id/22/4434/3729", - Detail: 
openai.ImageURLDetailLow, + OfUser: &openai.ChatCompletionUserMessageParam{ + Role: "user", + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfArrayOfContentParts: []openai.ChatCompletionContentPartUnionParam{ + { + OfText: &openai.ChatCompletionContentPartTextParam{ + Type: "text", + Text: "What is in the image?", + }, + }, + { + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://picsum.photos/id/22/4434/3729", + Detail: "low", + }, + }, + }, }, }, }, }, - }}) + }, + }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("man"), ContainSubstring("road")), fmt.Sprint(resp.Choices[0].Message.Content)) }) }) + Context("text to audio", func() { It("correctly", func() { - res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, + res, err := client.Audio.Speech.New(context.Background(), openai.AudioSpeechNewParams{ + Model: openai.SpeechModelTTS1, Input: "Hello!", - Voice: openai.VoiceAlloy, + Voice: openai.AudioSpeechNewParamsVoiceAlloy, }) Expect(err).ToNot(HaveOccurred()) - defer res.Close() + defer res.Body.Close() - _, err = io.ReadAll(res) + _, err = io.ReadAll(res.Body) Expect(err).ToNot(HaveOccurred()) - }) }) + Context("audio to text", func() { It("correctly", func() { - downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav" file, err := downloadHttpFile(downloadURL) Expect(err).ToNot(HaveOccurred()) - req := openai.AudioRequest{ - Model: openai.Whisper1, - FilePath: file, - } - resp, err := client.CreateTranscription(context.Background(), req) + fileHandle, err := os.Open(file) Expect(err).ToNot(HaveOccurred()) + defer fileHandle.Close() + + transcriptionResp, err := client.Audio.Transcriptions.New(context.Background(), openai.AudioTranscriptionNewParams{ + Model: openai.AudioModelWhisper1, + File: fileHandle, + }) + Expect(err).ToNot(HaveOccurred()) + resp := transcriptionResp.AsTranscription() Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text)) }) + + It("with VTT format", func() { + downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav" + file, err := downloadHttpFile(downloadURL) + Expect(err).ToNot(HaveOccurred()) + + fileHandle, err := os.Open(file) + Expect(err).ToNot(HaveOccurred()) + defer fileHandle.Close() + + var resp string + _, err = client.Audio.Transcriptions.New(context.Background(), openai.AudioTranscriptionNewParams{ + Model: openai.AudioModelWhisper1, + File: fileHandle, + ResponseFormat: openai.AudioResponseFormatVTT, + }, option.WithResponseBodyInto(&resp)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp).To(ContainSubstring("This is the"), resp) + Expect(resp).To(ContainSubstring("WEBVTT"), resp) + Expect(resp).To(ContainSubstring("00:00:00.000 -->"), resp) + }) }) + Context("vad", func() { It("correctly", func() { modelName := "silero-vad" @@ -283,6 +318,7 @@ var _ = Describe("E2E test", func() { Expect(deserializedResponse.Segments).ToNot(BeZero()) }) }) + Context("reranker", func() { It("correctly", func() { modelName := "jina-reranker-v1-base-en" @@ -317,7 +353,6 @@ var _ = Describe("E2E test", func() { Expect(err).To(BeNil()) Expect(deserializedResponse).ToNot(BeZero()) Expect(deserializedResponse.Model).To(Equal(modelName)) - //Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0)) 
Expect(len(deserializedResponse.Results)).To(Equal(expectResults)) // Assert that relevance scores are in decreasing order for i := 1; i < len(deserializedResponse.Results); i++ { diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go index 85d136d3e..375d4d7c0 100644 --- a/tests/e2e/e2e_suite_test.go +++ b/tests/e2e/e2e_suite_test.go @@ -17,14 +17,14 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/phayes/freeport" - "github.com/sashabaranov/go-openai" "gopkg.in/yaml.v3" "github.com/mudler/xlog" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" ) var ( - localAIURL string anthropicBaseURL string tmpDir string backendPath string @@ -33,7 +33,7 @@ var ( app *echo.Echo appCtx context.Context appCancel context.CancelFunc - client *openai.Client + client openai.Client apiPort int apiURL string mockBackendPath string @@ -129,7 +129,6 @@ var _ = BeforeSuite(func() { Expect(err).ToNot(HaveOccurred()) apiPort = port apiURL = fmt.Sprintf("http://127.0.0.1:%d/v1", apiPort) - localAIURL = apiURL // Anthropic SDK appends /v1/messages to base URL; use base without /v1 so requests go to /v1/messages anthropicBaseURL = fmt.Sprintf("http://127.0.0.1:%d", apiPort) @@ -141,12 +140,10 @@ var _ = BeforeSuite(func() { }() // Wait for server to be ready - defaultConfig := openai.DefaultConfig("") - defaultConfig.BaseURL = apiURL - client = openai.NewClientWithConfig(defaultConfig) + client = openai.NewClient(option.WithBaseURL(apiURL)) Eventually(func() error { - _, err := client.ListModels(context.TODO()) + _, err := client.Models.List(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) diff --git a/tests/e2e/mock_backend_test.go b/tests/e2e/mock_backend_test.go index 0585209dd..1e2f739b8 100644 --- a/tests/e2e/mock_backend_test.go +++ b/tests/e2e/mock_backend_test.go @@ -9,22 +9,19 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" - "github.com/sashabaranov/go-openai" + "github.com/openai/openai-go/v3" ) var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { Describe("Text Generation APIs", func() { Context("Predict (Chat Completions)", func() { It("should return mocked response", func() { - resp, err := client.CreateChatCompletion( + resp, err := client.Chat.Completions.New( context.TODO(), - openai.ChatCompletionRequest{ + openai.ChatCompletionNewParams{ Model: "mock-model", - Messages: []openai.ChatCompletionMessage{ - { - Role: "user", - Content: "Hello", - }, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Hello"), }, }, ) @@ -36,31 +33,23 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { Context("PredictStream (Streaming Chat Completions)", func() { It("should stream mocked tokens", func() { - stream, err := client.CreateChatCompletionStream( + stream := client.Chat.Completions.NewStreaming( context.TODO(), - openai.ChatCompletionRequest{ + openai.ChatCompletionNewParams{ Model: "mock-model", - Messages: []openai.ChatCompletionMessage{ - { - Role: "user", - Content: "Hello", - }, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Hello"), }, }, ) - Expect(err).ToNot(HaveOccurred()) - defer stream.Close() - hasContent := false - for { - response, err := stream.Recv() - if err != nil { - break - } + for stream.Next() { + response := stream.Current() if len(response.Choices) > 0 && response.Choices[0].Delta.Content != "" { hasContent = true } } + Expect(stream.Err()).ToNot(HaveOccurred()) Expect(hasContent).To(BeTrue()) }) }) @@ -68,11 +57,13 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { Describe("Embeddings API", func() { It("should return mocked embeddings", func() { - resp, err := client.CreateEmbeddings( + resp, err := client.Embeddings.New( context.TODO(), - openai.EmbeddingRequest{ + openai.EmbeddingNewParams{ Model: "mock-model", - Input: []string{"test"}, + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test"}, + }, }, ) Expect(err).ToNot(HaveOccurred())