mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
feat(realtime): WebRTC support (#8790)
* feat(realtime): WebRTC support Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(tracing): Show full LLM opts and deltas Signed-off-by: Richard Palethorpe <io@richiejp.com> --------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
4e3bf2752d
commit
f9a850c02a
18
.github/workflows/backend.yml
vendored
18
.github/workflows/backend.yml
vendored
@@ -2014,6 +2014,20 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
#opus
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-opus'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "opus"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
#silero-vad
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -2347,6 +2361,10 @@ jobs:
|
||||
tag-suffix: "-metal-darwin-arm64-piper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "opus"
|
||||
tag-suffix: "-metal-darwin-arm64-opus"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "silero-vad"
|
||||
tag-suffix: "-metal-darwin-arm64-silero-vad"
|
||||
build-type: "metal"
|
||||
|
||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -93,7 +93,7 @@ jobs:
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install curl ffmpeg
|
||||
sudo apt-get install curl ffmpeg libopus-dev
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
@@ -195,7 +195,7 @@ jobs:
|
||||
run: go version
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
|
||||
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus
|
||||
pip install --user --no-cache-dir grpcio-tools grpcio
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
|
||||
2
.github/workflows/tests-e2e.yml
vendored
2
.github/workflows/tests-e2e.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
sudo apt-get install -y build-essential libopus-dev
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -38,6 +38,7 @@ test-models/
|
||||
test-dir/
|
||||
tests/e2e-aio/backends
|
||||
tests/e2e-aio/models
|
||||
mock-backend
|
||||
|
||||
release/
|
||||
|
||||
@@ -69,3 +70,6 @@ docs/static/gallery.html
|
||||
# React UI build artifacts (keep placeholder dist/index.html)
|
||||
core/http/react-ui/node_modules/
|
||||
core/http/react-ui/dist
|
||||
|
||||
# Extracted backend binaries for container-based testing
|
||||
local-backends/
|
||||
|
||||
@@ -10,7 +10,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates curl wget espeak-ng libgomp1 \
|
||||
ffmpeg libopenblas0 libopenblas-dev sox && \
|
||||
ffmpeg libopenblas0 libopenblas-dev libopus0 sox && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -190,6 +190,7 @@ RUN apt-get update && \
|
||||
curl libssl-dev \
|
||||
git \
|
||||
git-lfs \
|
||||
libopus-dev pkg-config \
|
||||
unzip upx-ucl python3 python-is-python3 && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
@@ -378,6 +379,9 @@ COPY ./entrypoint.sh .
|
||||
|
||||
# Copy the binary
|
||||
COPY --from=builder /build/local-ai ./
|
||||
# Copy the opus shim if it was built
|
||||
RUN --mount=from=builder,src=/build/,dst=/mnt/build \
|
||||
if [ -f /mnt/build/libopusshim.so ]; then cp /mnt/build/libopusshim.so ./; fi
|
||||
|
||||
# Make sure the models directory exists
|
||||
RUN mkdir -p /models /backends
|
||||
|
||||
88
Makefile
88
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -106,6 +106,7 @@ react-ui-docker:
|
||||
core/http/react-ui/dist: react-ui
|
||||
|
||||
## Build:
|
||||
|
||||
build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project
|
||||
$(info ${GREEN}I local-ai build info:${RESET})
|
||||
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
||||
@@ -163,6 +164,7 @@ test: test-models/testmodel.ggml protogen-go
|
||||
@echo 'Running tests'
|
||||
export GO_TAGS="debug"
|
||||
$(MAKE) prepare-test
|
||||
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
|
||||
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/transformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
||||
$(MAKE) test-llama-gguf
|
||||
@@ -250,6 +252,88 @@ test-stablediffusion: prepare-test
|
||||
test-stores:
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration
|
||||
|
||||
test-opus:
|
||||
@echo 'Running opus backend tests'
|
||||
$(MAKE) -C backend/go/opus libopusshim.so
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./backend/go/opus/...
|
||||
|
||||
test-opus-docker:
|
||||
@echo 'Running opus backend tests in Docker'
|
||||
docker build --target builder \
|
||||
--build-arg BUILD_TYPE=$(or $(BUILD_TYPE),) \
|
||||
--build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \
|
||||
--build-arg BACKEND=opus \
|
||||
-t localai-opus-test -f backend/Dockerfile.golang .
|
||||
docker run --rm localai-opus-test \
|
||||
bash -c 'cd /LocalAI && go run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./backend/go/opus/...'
|
||||
|
||||
test-realtime: build-mock-backend
|
||||
@echo 'Running realtime e2e tests (mock backend)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime && !real-models" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
|
||||
# Real-model realtime tests. Set REALTIME_TEST_MODEL to use your own pipeline,
|
||||
# or leave unset to auto-build one from the component env vars below.
|
||||
REALTIME_VAD?=silero-vad-ggml
|
||||
REALTIME_STT?=whisper-1
|
||||
REALTIME_LLM?=qwen3-0.6b
|
||||
REALTIME_TTS?=tts-1
|
||||
REALTIME_BACKENDS_PATH?=$(abspath ./)/backends
|
||||
|
||||
test-realtime-models: build-mock-backend
|
||||
@echo 'Running realtime e2e tests (real models)'
|
||||
REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \
|
||||
REALTIME_VAD=$(REALTIME_VAD) \
|
||||
REALTIME_STT=$(REALTIME_STT) \
|
||||
REALTIME_LLM=$(REALTIME_LLM) \
|
||||
REALTIME_TTS=$(REALTIME_TTS) \
|
||||
REALTIME_BACKENDS_PATH=$(REALTIME_BACKENDS_PATH) \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
|
||||
# --- Container-based real-model testing ---
|
||||
|
||||
REALTIME_BACKEND_NAMES ?= silero-vad whisper llama-cpp kokoro
|
||||
REALTIME_MODELS_DIR ?= $(abspath ./models)
|
||||
REALTIME_BACKENDS_DIR ?= $(abspath ./local-backends)
|
||||
REALTIME_DOCKER_FLAGS ?= --gpus all
|
||||
|
||||
local-backends:
|
||||
mkdir -p local-backends
|
||||
|
||||
extract-backend-%: docker-build-% local-backends
|
||||
@echo "Extracting backend $*..."
|
||||
@CID=$$(docker create local-ai-backend:$*) && \
|
||||
rm -rf local-backends/$* && mkdir -p local-backends/$* && \
|
||||
docker cp $$CID:/ - | tar -xf - -C local-backends/$* && \
|
||||
docker rm $$CID > /dev/null
|
||||
|
||||
extract-realtime-backends: $(addprefix extract-backend-,$(REALTIME_BACKEND_NAMES))
|
||||
|
||||
test-realtime-models-docker: build-mock-backend
|
||||
docker build --target build-requirements \
|
||||
--build-arg BUILD_TYPE=$(or $(BUILD_TYPE),cublas) \
|
||||
--build-arg CUDA_MAJOR_VERSION=$(or $(CUDA_MAJOR_VERSION),13) \
|
||||
--build-arg CUDA_MINOR_VERSION=$(or $(CUDA_MINOR_VERSION),0) \
|
||||
-t localai-test-runner .
|
||||
docker run --rm \
|
||||
$(REALTIME_DOCKER_FLAGS) \
|
||||
-v $(abspath ./):/build \
|
||||
-v $(REALTIME_MODELS_DIR):/models:ro \
|
||||
-v $(REALTIME_BACKENDS_DIR):/backends \
|
||||
-v localai-go-cache:/root/go/pkg/mod \
|
||||
-v localai-go-build-cache:/root/.cache/go-build \
|
||||
-e REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \
|
||||
-e REALTIME_VAD=$(REALTIME_VAD) \
|
||||
-e REALTIME_STT=$(REALTIME_STT) \
|
||||
-e REALTIME_LLM=$(REALTIME_LLM) \
|
||||
-e REALTIME_TTS=$(REALTIME_TTS) \
|
||||
-e REALTIME_BACKENDS_PATH=/backends \
|
||||
-e REALTIME_MODELS_PATH=/models \
|
||||
-w /build \
|
||||
localai-test-runner \
|
||||
bash -c 'git config --global --add safe.directory /build && \
|
||||
make protogen-go && make build-mock-backend && \
|
||||
go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e'
|
||||
|
||||
test-container:
|
||||
docker build --target requirements -t local-ai-test-container .
|
||||
docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
|
||||
@@ -477,6 +561,7 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_OPUS = opus|golang|.|false|true
|
||||
|
||||
# Python backends with root context
|
||||
BACKEND_RERANKERS = rerankers|python|.|false|true
|
||||
@@ -534,6 +619,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRANSFORMERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OUTETTS)))
|
||||
|
||||
@@ -180,6 +180,11 @@ RUN <<EOT bash
|
||||
fi
|
||||
EOT
|
||||
|
||||
RUN if [ "${BACKEND}" = "opus" ]; then \
|
||||
apt-get update && apt-get install -y --no-install-recommends libopus-dev pkg-config && \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
RUN git config --global --add safe.directory /LocalAI
|
||||
|
||||
@@ -35,6 +35,9 @@ service Backend {
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
|
||||
rpc AudioEncode(AudioEncodeRequest) returns (AudioEncodeResult) {}
|
||||
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
||||
|
||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||
}
|
||||
|
||||
@@ -496,6 +499,30 @@ message ToolFormatMarkers {
|
||||
string call_id_suffix = 31; // e.g., ""
|
||||
}
|
||||
|
||||
message AudioEncodeRequest {
|
||||
bytes pcm_data = 1;
|
||||
int32 sample_rate = 2;
|
||||
int32 channels = 3;
|
||||
map<string, string> options = 4;
|
||||
}
|
||||
|
||||
message AudioEncodeResult {
|
||||
repeated bytes frames = 1;
|
||||
int32 sample_rate = 2;
|
||||
int32 samples_per_frame = 3;
|
||||
}
|
||||
|
||||
message AudioDecodeRequest {
|
||||
repeated bytes frames = 1;
|
||||
map<string, string> options = 2;
|
||||
}
|
||||
|
||||
message AudioDecodeResult {
|
||||
bytes pcm_data = 1;
|
||||
int32 sample_rate = 2;
|
||||
int32 samples_per_frame = 3;
|
||||
}
|
||||
|
||||
message ModelMetadataResponse {
|
||||
bool supports_thinking = 1;
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
|
||||
19
backend/go/opus/Makefile
Normal file
19
backend/go/opus/Makefile
Normal file
@@ -0,0 +1,19 @@
|
||||
GOCMD?=go
GO_TAGS?=

# Compiler and linker flags for libopus, resolved via pkg-config
# (requires libopus-dev / the opus development package to be installed).
OPUS_CFLAGS := $(shell pkg-config --cflags opus)
OPUS_LIBS := $(shell pkg-config --libs opus)

# Shared-library shim providing non-variadic wrappers for opus_encoder_ctl;
# it is dlopen'ed at runtime by the Go code (see codec.go).
libopusshim.so: csrc/opus_shim.c
	$(CC) -shared -fPIC -o $@ $< $(OPUS_CFLAGS) $(OPUS_LIBS)

# Build the backend binary; the shim must exist first so tests/packaging
# can find it next to the binary.
opus: libopusshim.so
	$(GOCMD) build -tags "$(GO_TAGS)" -o opus ./

# Bundle the binary, shim, libopus and base system libraries into package/.
package: opus
	bash package.sh

build: package

clean:
	rm -f opus libopusshim.so
|
||||
256
backend/go/opus/codec.go
Normal file
256
backend/go/opus/codec.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
// Opus application modes, numerically identical to the OPUS_APPLICATION_*
// constants in <opus_defines.h>. One of these is passed to
// opus_encoder_create to tune the codec's psychoacoustic trade-offs.
const (
	ApplicationVoIP               = 2048
	ApplicationAudio              = 2049
	ApplicationRestrictedLowDelay = 2051
)
|
||||
|
||||
var (
	// initOnce guards the one-time loading of the native libraries;
	// initErr caches the outcome so every caller sees the same result.
	initOnce sync.Once
	initErr  error

	// dlopen handles for libopus and the local ctl shim.
	opusLib uintptr
	shimLib uintptr

	// libopus functions, bound by purego.RegisterLibFunc in doInit.
	cEncoderCreate func(fs int32, channels int32, application int32, errPtr *int32) uintptr
	cEncode        func(st uintptr, pcm *int16, frameSize int32, data *byte, maxBytes int32) int32
	cEncoderDestroy func(st uintptr)

	cDecoderCreate func(fs int32, channels int32, errPtr *int32) uintptr
	cDecode        func(st uintptr, data *byte, dataLen int32, pcm *int16, frameSize int32, decodeFec int32) int32
	cDecoderDestroy func(st uintptr)

	// shim functions (non-variadic wrappers for opus_encoder_ctl, which is
	// variadic and therefore cannot be called through purego directly)
	cSetBitrate    func(st uintptr, bitrate int32) int32
	cSetComplexity func(st uintptr, complexity int32) int32
)
|
||||
|
||||
func loadLib(names []string) (uintptr, error) {
|
||||
var firstErr error
|
||||
for _, name := range names {
|
||||
h, err := purego.Dlopen(name, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err == nil {
|
||||
return h, nil
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return 0, firstErr
|
||||
}
|
||||
|
||||
func ensureInit() error {
|
||||
initOnce.Do(func() {
|
||||
initErr = doInit()
|
||||
})
|
||||
return initErr
|
||||
}
|
||||
|
||||
// shimHint is appended to library-load errors to tell the operator how to
// make libopus / libopusshim resolvable.
const shimHint = "ensure libopus-dev is installed and rebuild, or set OPUS_LIBRARY / OPUS_SHIM_LIBRARY env vars"
|
||||
|
||||
func doInit() error {
|
||||
opusNames := opusSearchPaths()
|
||||
var err error
|
||||
opusLib, err = loadLib(opusNames)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opus: failed to load libopus (%s): %w", shimHint, err)
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cEncoderCreate, opusLib, "opus_encoder_create")
|
||||
purego.RegisterLibFunc(&cEncode, opusLib, "opus_encode")
|
||||
purego.RegisterLibFunc(&cEncoderDestroy, opusLib, "opus_encoder_destroy")
|
||||
purego.RegisterLibFunc(&cDecoderCreate, opusLib, "opus_decoder_create")
|
||||
purego.RegisterLibFunc(&cDecode, opusLib, "opus_decode")
|
||||
purego.RegisterLibFunc(&cDecoderDestroy, opusLib, "opus_decoder_destroy")
|
||||
|
||||
shimNames := shimSearchPaths()
|
||||
shimLib, err = loadLib(shimNames)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opus: failed to load libopusshim (%s): %w", shimHint, err)
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cSetBitrate, shimLib, "opus_shim_encoder_set_bitrate")
|
||||
purego.RegisterLibFunc(&cSetComplexity, shimLib, "opus_shim_encoder_set_complexity")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// opusSearchPaths returns candidate libopus names/paths in priority order:
// the OPUS_LIBRARY override, copies sitting next to the executable, plain
// sonames resolved by the dynamic loader, and finally the common Homebrew
// install locations on macOS.
func opusSearchPaths() []string {
	var candidates []string

	if override := os.Getenv("OPUS_LIBRARY"); override != "" {
		candidates = append(candidates, override)
	}

	if exe, err := os.Executable(); err == nil {
		exeDir := filepath.Dir(exe)
		candidates = append(candidates,
			filepath.Join(exeDir, "libopus.so.0"),
			filepath.Join(exeDir, "libopus.so"),
		)
		if runtime.GOOS == "darwin" {
			candidates = append(candidates, filepath.Join(exeDir, "libopus.dylib"))
		}
	}

	// Bare names let the dynamic loader search its standard paths.
	candidates = append(candidates, "libopus.so.0", "libopus.so", "libopus.dylib", "opus.dll")

	if runtime.GOOS != "darwin" {
		return candidates
	}
	return append(candidates,
		"/opt/homebrew/lib/libopus.dylib",
		"/usr/local/lib/libopus.dylib",
	)
}
|
||||
|
||||
// shimSearchPaths returns candidate names/paths for the opus ctl shim in
// priority order: the OPUS_SHIM_LIBRARY override, the executable's own
// directory, then the working directory and bare sonames.
func shimSearchPaths() []string {
	var candidates []string

	if override := os.Getenv("OPUS_SHIM_LIBRARY"); override != "" {
		candidates = append(candidates, override)
	}

	if exe, err := os.Executable(); err == nil {
		exeDir := filepath.Dir(exe)
		candidates = append(candidates, filepath.Join(exeDir, "libopusshim.so"))
		if runtime.GOOS == "darwin" {
			candidates = append(candidates, filepath.Join(exeDir, "libopusshim.dylib"))
		}
	}

	candidates = append(candidates, "./libopusshim.so", "libopusshim.so")
	if runtime.GOOS != "darwin" {
		return candidates
	}
	return append(candidates, "./libopusshim.dylib", "libopusshim.dylib")
}
|
||||
|
||||
// Encoder wraps a libopus OpusEncoder via purego.
type Encoder struct {
	st uintptr // opaque *OpusEncoder handle from opus_encoder_create; 0 once closed
}
|
||||
|
||||
func NewEncoder(sampleRate, channels, application int) (*Encoder, error) {
|
||||
if err := ensureInit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opusErr int32
|
||||
st := cEncoderCreate(int32(sampleRate), int32(channels), int32(application), &opusErr)
|
||||
if opusErr != 0 || st == 0 {
|
||||
return nil, fmt.Errorf("opus_encoder_create failed: error %d", opusErr)
|
||||
}
|
||||
return &Encoder{st: st}, nil
|
||||
}
|
||||
|
||||
// Encode encodes a frame of PCM int16 samples. It returns the number of bytes
|
||||
// written to out, or a negative error code.
|
||||
func (e *Encoder) Encode(pcm []int16, frameSize int, out []byte) (int, error) {
|
||||
if len(pcm) == 0 || len(out) == 0 {
|
||||
return 0, errors.New("opus encode: empty input or output buffer")
|
||||
}
|
||||
n := cEncode(e.st, &pcm[0], int32(frameSize), &out[0], int32(len(out)))
|
||||
if n < 0 {
|
||||
return 0, fmt.Errorf("opus_encode failed: error %d", n)
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
func (e *Encoder) SetBitrate(bitrate int) error {
|
||||
if ret := cSetBitrate(e.st, int32(bitrate)); ret != 0 {
|
||||
return fmt.Errorf("opus set bitrate: error %d", ret)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Encoder) SetComplexity(complexity int) error {
|
||||
if ret := cSetComplexity(e.st, int32(complexity)); ret != 0 {
|
||||
return fmt.Errorf("opus set complexity: error %d", ret)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Encoder) Close() {
|
||||
if e.st != 0 {
|
||||
cEncoderDestroy(e.st)
|
||||
e.st = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Decoder wraps a libopus OpusDecoder via purego.
type Decoder struct {
	st uintptr // opaque *OpusDecoder handle from opus_decoder_create; 0 once closed
}
|
||||
|
||||
func NewDecoder(sampleRate, channels int) (*Decoder, error) {
|
||||
if err := ensureInit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opusErr int32
|
||||
st := cDecoderCreate(int32(sampleRate), int32(channels), &opusErr)
|
||||
if opusErr != 0 || st == 0 {
|
||||
return nil, fmt.Errorf("opus_decoder_create failed: error %d", opusErr)
|
||||
}
|
||||
return &Decoder{st: st}, nil
|
||||
}
|
||||
|
||||
// Decode decodes an Opus packet into pcm. frameSize is the max number of
|
||||
// samples per channel that pcm can hold. Returns the number of decoded samples
|
||||
// per channel.
|
||||
func (d *Decoder) Decode(data []byte, pcm []int16, frameSize int, fec bool) (int, error) {
|
||||
if len(pcm) == 0 {
|
||||
return 0, errors.New("opus decode: empty output buffer")
|
||||
}
|
||||
|
||||
var dataPtr *byte
|
||||
var dataLen int32
|
||||
if len(data) > 0 {
|
||||
dataPtr = &data[0]
|
||||
dataLen = int32(len(data))
|
||||
}
|
||||
|
||||
decodeFec := int32(0)
|
||||
if fec {
|
||||
decodeFec = 1
|
||||
}
|
||||
|
||||
n := cDecode(d.st, dataPtr, dataLen, &pcm[0], int32(frameSize), decodeFec)
|
||||
if n < 0 {
|
||||
return 0, fmt.Errorf("opus_decode failed: error %d", n)
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
func (d *Decoder) Close() {
|
||||
if d.st != 0 {
|
||||
cDecoderDestroy(d.st)
|
||||
d.st = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Init eagerly loads the opus libraries, returning any error.
// Calling this is optional; the libraries are loaded lazily on first use.
func Init() error {
	return ensureInit()
}

// Reset allows re-initialization (for testing).
// NOTE(review): it does not dlclose the previously loaded libraries and is
// not safe to call concurrently with ensureInit — assumed test-only use.
func Reset() {
	initOnce = sync.Once{}
	initErr = nil
	opusLib = 0
	shimLib = 0
}
|
||||
9
backend/go/opus/csrc/opus_shim.c
Normal file
9
backend/go/opus/csrc/opus_shim.c
Normal file
@@ -0,0 +1,9 @@
|
||||
#include <opus.h>

/* Non-variadic wrappers around opus_encoder_ctl(), which is variadic and so
 * cannot be bound directly from Go via purego. Each returns the OPUS_* error
 * code from the underlying ctl request (OPUS_OK == 0 on success). */

int opus_shim_encoder_set_bitrate(OpusEncoder *st, opus_int32 bitrate) {
    return opus_encoder_ctl(st, OPUS_SET_BITRATE(bitrate));
}

int opus_shim_encoder_set_complexity(OpusEncoder *st, opus_int32 complexity) {
    return opus_encoder_ctl(st, OPUS_SET_COMPLEXITY(complexity));
}
|
||||
16
backend/go/opus/main.go
Normal file
16
backend/go/opus/main.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
if err := grpc.StartServer(*addr, &Opus{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
184
backend/go/opus/opus.go
Normal file
184
backend/go/opus/opus.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/sound"
|
||||
)
|
||||
|
||||
const (
	// Fixed stream parameters: this backend encodes/decodes 48kHz mono.
	opusSampleRate = 48000
	opusChannels   = 1
	opusFrameSize  = 960 // 20ms at 48kHz
	// opusMaxPacketSize bounds the scratch buffer for one encoded packet.
	opusMaxPacketSize = 4000
	opusMaxFrameSize  = 5760 // 120ms at 48kHz

	// Per-session decoders idle longer than decoderIdleTTL are closed;
	// the eviction loop wakes every decoderEvictTick.
	decoderIdleTTL   = 60 * time.Second
	decoderEvictTick = 30 * time.Second
)
|
||||
|
||||
// cachedDecoder is one entry in the per-session decoder cache.
type cachedDecoder struct {
	mu       sync.Mutex // serializes use of dec across concurrent AudioDecode calls
	dec      *Decoder
	lastUsed time.Time // guarded by Opus.decodersMu; drives idle eviction
}
|
||||
|
||||
// Opus implements the backend's AudioEncode/AudioDecode RPCs on top of
// libopus, loaded dynamically via purego (see codec.go).
type Opus struct {
	base.Base

	// decodersMu guards the decoders map and each entry's lastUsed field.
	decodersMu sync.Mutex
	// decoders caches one persistent decoder per session ID so Opus
	// prediction state survives across AudioDecode batches.
	decoders map[string]*cachedDecoder
}
|
||||
|
||||
func (o *Opus) Load(opts *pb.ModelOptions) error {
|
||||
o.decoders = make(map[string]*cachedDecoder)
|
||||
go o.evictLoop()
|
||||
return Init()
|
||||
}
|
||||
|
||||
func (o *Opus) evictLoop() {
|
||||
ticker := time.NewTicker(decoderEvictTick)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
o.decodersMu.Lock()
|
||||
now := time.Now()
|
||||
for id, cd := range o.decoders {
|
||||
if now.Sub(cd.lastUsed) > decoderIdleTTL {
|
||||
cd.dec.Close()
|
||||
delete(o.decoders, id)
|
||||
}
|
||||
}
|
||||
o.decodersMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateDecoder returns a cached decoder for the given session ID,
|
||||
// creating one if it doesn't exist yet.
|
||||
func (o *Opus) getOrCreateDecoder(sessionID string) (*cachedDecoder, error) {
|
||||
o.decodersMu.Lock()
|
||||
defer o.decodersMu.Unlock()
|
||||
|
||||
if cd, ok := o.decoders[sessionID]; ok {
|
||||
cd.lastUsed = time.Now()
|
||||
return cd, nil
|
||||
}
|
||||
|
||||
dec, err := NewDecoder(opusSampleRate, opusChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cd := &cachedDecoder{dec: dec, lastUsed: time.Now()}
|
||||
o.decoders[sessionID] = cd
|
||||
return cd, nil
|
||||
}
|
||||
|
||||
// AudioEncode encodes little-endian int16 PCM (req.PcmData) into a sequence
// of 20ms Opus frames at 48kHz mono. A fresh encoder is created per request
// (64kbps, complexity 10), so encoder prediction state does not carry across
// calls.
func (o *Opus) AudioEncode(req *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
	enc, err := NewEncoder(opusSampleRate, opusChannels, ApplicationAudio)
	if err != nil {
		return nil, fmt.Errorf("opus encoder create: %w", err)
	}
	defer enc.Close()

	if err := enc.SetBitrate(64000); err != nil {
		return nil, fmt.Errorf("opus set bitrate: %w", err)
	}
	if err := enc.SetComplexity(10); err != nil {
		return nil, fmt.Errorf("opus set complexity: %w", err)
	}

	samples := sound.BytesToInt16sLE(req.PcmData)
	if len(samples) == 0 {
		// No input audio: return stream metadata with no frames.
		return &pb.AudioEncodeResult{
			SampleRate:      opusSampleRate,
			SamplesPerFrame: opusFrameSize,
		}, nil
	}

	// Resample to 48kHz when the caller's rate differs.
	// NOTE(review): req.Channels is ignored — input is treated as mono, so
	// interleaved stereo would be mis-encoded. Confirm callers send mono.
	if req.SampleRate != 0 && int(req.SampleRate) != opusSampleRate {
		samples = sound.ResampleInt16(samples, int(req.SampleRate), opusSampleRate)
	}

	var frames [][]byte
	// packet is a reusable scratch buffer for one encoded packet.
	packet := make([]byte, opusMaxPacketSize)

	// Encode only complete 20ms frames.
	// NOTE(review): a trailing partial frame (<960 samples) is silently
	// dropped — confirm callers batch in whole frames or expect this.
	for offset := 0; offset+opusFrameSize <= len(samples); offset += opusFrameSize {
		frame := samples[offset : offset+opusFrameSize]
		n, err := enc.Encode(frame, opusFrameSize, packet)
		if err != nil {
			return nil, fmt.Errorf("opus encode: %w", err)
		}
		// Copy out of the scratch buffer before the next iteration reuses it.
		out := make([]byte, n)
		copy(out, packet[:n])
		frames = append(frames, out)
	}

	return &pb.AudioEncodeResult{
		Frames:          frames,
		SampleRate:      opusSampleRate,
		SamplesPerFrame: opusFrameSize,
	}, nil
}
|
||||
|
||||
// AudioDecode decodes a batch of Opus frames back into 48kHz mono PCM,
// returned as little-endian int16 bytes.
//
// When req.Options["session_id"] is set, a persistent per-session decoder is
// used so Opus prediction state carries across batches; otherwise a throwaway
// decoder is created for this call (backward-compatible path).
func (o *Opus) AudioDecode(req *pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) {
	if len(req.Frames) == 0 {
		// Nothing to decode: return an empty result with stream metadata.
		return &pb.AudioDecodeResult{
			SampleRate:      opusSampleRate,
			SamplesPerFrame: opusFrameSize,
		}, nil
	}

	// Use a persistent decoder when a session ID is provided so that Opus
	// prediction state carries across batches. Fall back to a fresh decoder
	// for backward compatibility.
	sessionID := req.Options["session_id"]

	var cd *cachedDecoder
	var ownedDec *Decoder

	if sessionID != "" && o.decoders != nil {
		var err error
		cd, err = o.getOrCreateDecoder(sessionID)
		if err != nil {
			return nil, fmt.Errorf("opus decoder create: %w", err)
		}
		// Hold the per-decoder lock for the whole batch so concurrent calls
		// for the same session cannot interleave frames.
		cd.mu.Lock()
		defer cd.mu.Unlock()
	} else {
		dec, err := NewDecoder(opusSampleRate, opusChannels)
		if err != nil {
			return nil, fmt.Errorf("opus decoder create: %w", err)
		}
		ownedDec = dec
		defer ownedDec.Close()
	}

	// Resolve whichever decoder we ended up with (cached or owned).
	dec := ownedDec
	if cd != nil {
		dec = cd.dec
	}

	var allSamples []int16
	var samplesPerFrame int32

	// Scratch buffer sized for the longest legal Opus frame (120ms).
	pcm := make([]int16, opusMaxFrameSize)
	for _, frame := range req.Frames {
		n, err := dec.Decode(frame, pcm, opusMaxFrameSize, false)
		if err != nil {
			return nil, fmt.Errorf("opus decode: %w", err)
		}
		if samplesPerFrame == 0 {
			// Only the first frame's size is reported.
			// NOTE(review): later frames may legitimately differ in size.
			samplesPerFrame = int32(n)
		}
		allSamples = append(allSamples, pcm[:n]...)
	}

	return &pb.AudioDecodeResult{
		PcmData:         sound.Int16toBytesLE(allSamples),
		SampleRate:      opusSampleRate,
		SamplesPerFrame: samplesPerFrame,
	}, nil
}
|
||||
1346
backend/go/opus/opus_test.go
Normal file
1346
backend/go/opus/opus_test.go
Normal file
File diff suppressed because it is too large
Load Diff
47
backend/go/opus/package.sh
Normal file
47
backend/go/opus/package.sh
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
set -e

# Bundle the opus backend binary with its shim library, libopus, and the
# minimal system runtime libraries so the package runs on hosts with a
# different (or older) userland. All path expansions are quoted so the
# script survives build directories containing spaces.
CURDIR=$(dirname "$(realpath "$0")")

mkdir -p "$CURDIR/package/lib"

cp -avf "$CURDIR/opus" "$CURDIR/package/"
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"

# Copy the opus shim library
cp -avf "$CURDIR/libopusshim.so" "$CURDIR/package/lib/"

# Copy system libopus (glob stays outside quotes so libopus.so.* all match)
if command -v pkg-config >/dev/null 2>&1 && pkg-config --exists opus; then
    LIBOPUS_DIR=$(pkg-config --variable=libdir opus)
    cp -avfL "$LIBOPUS_DIR"/libopus.so* "$CURDIR/package/lib/" 2>/dev/null || true
fi

# Base runtime libraries bundled for both architectures.
RUNTIME_LIBS="libc.so.6 libgcc_s.so.1 libstdc++.so.6 libm.so.6 libdl.so.2 librt.so.1 libpthread.so.0"

# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
    echo "Detected x86_64 architecture, copying x86_64 libraries..."
    cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
    for lib in $RUNTIME_LIBS; do
        cp -arfLv "/lib/x86_64-linux-gnu/$lib" "$CURDIR/package/lib/$lib"
    done
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    echo "Detected ARM64 architecture, copying ARM64 libraries..."
    cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
    for lib in $RUNTIME_LIBS; do
        cp -arfLv "/lib/aarch64-linux-gnu/$lib" "$CURDIR/package/lib/$lib"
    done
else
    echo "Warning: Could not detect architecture for system library bundling"
fi

echo "Packaging completed successfully"
ls -liah "$CURDIR/package/"
ls -liah "$CURDIR/package/lib/"
|
||||
15
backend/go/opus/run.sh
Normal file
15
backend/go/opus/run.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
set -ex

# Launcher for the packaged opus backend: prefer the bundled libraries
# (libopus, libopusshim, libc, ...) over the host's, and tell the backend
# where to find the ctl shim explicitly. Quoting fixes breakage when the
# install path contains spaces.
CURDIR=$(dirname "$(realpath "$0")")

export LD_LIBRARY_PATH="$CURDIR/lib:$LD_LIBRARY_PATH"
export OPUS_SHIM_LIBRARY="$CURDIR/lib/libopusshim.so"

# If there is a lib/ld.so, use it
if [ -f "$CURDIR/lib/ld.so" ]; then
    echo "Using lib/ld.so"
    exec "$CURDIR/lib/ld.so" "$CURDIR/opus" "$@"
fi

exec "$CURDIR/opus" "$@"
|
||||
@@ -724,6 +724,23 @@
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
- &opus
|
||||
name: "opus"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus"
|
||||
urls:
|
||||
- https://opus-codec.org/
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-opus
|
||||
license: BSD-3-Clause
|
||||
description: |
|
||||
Opus audio codec backend for encoding and decoding audio.
|
||||
Required for WebRTC transport in the Realtime API.
|
||||
tags:
|
||||
- audio-codec
|
||||
- opus
|
||||
- WebRTC
|
||||
- realtime
|
||||
- CPU
|
||||
- &silero-vad
|
||||
name: "silero-vad"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-silero-vad"
|
||||
@@ -1088,6 +1105,21 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-local-store"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-local-store
|
||||
- !!merge <<: *opus
|
||||
name: "opus-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-opus"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-opus
|
||||
- !!merge <<: *opus
|
||||
name: "metal-opus"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-opus"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-opus
|
||||
- !!merge <<: *opus
|
||||
name: "metal-opus-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-opus"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-opus
|
||||
- !!merge <<: *silero-vad
|
||||
name: "silero-vad-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-silero-vad"
|
||||
|
||||
@@ -84,6 +84,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
|
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
var capturedPredictOpts *proto.PredictOptions
|
||||
fn := func() (LLMResponse, error) {
|
||||
opts := gRPCPredictOpts(*c, loader.ModelPath)
|
||||
// Merge request-level metadata (overrides config defaults)
|
||||
@@ -111,6 +112,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
opts.LogitBias = string(logitBiasJSON)
|
||||
}
|
||||
}
|
||||
capturedPredictOpts = opts
|
||||
|
||||
tokenUsage := TokenUsage{}
|
||||
|
||||
@@ -245,16 +247,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"prompt": s,
|
||||
"use_tokenizer_template": c.TemplateConfig.UseTokenizerTemplate,
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"function_template": c.TemplateConfig.Functions,
|
||||
"grammar": c.Grammar,
|
||||
"stop_words": c.StopWords,
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"function_template": c.TemplateConfig.Functions,
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
@@ -262,12 +260,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
traceData["messages"] = string(msgJSON)
|
||||
}
|
||||
}
|
||||
if tools != "" {
|
||||
traceData["tools"] = tools
|
||||
}
|
||||
if toolChoice != "" {
|
||||
traceData["tool_choice"] = toolChoice
|
||||
}
|
||||
if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil {
|
||||
traceData["reasoning_config"] = string(reasoningJSON)
|
||||
}
|
||||
@@ -277,15 +269,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
"mixed_mode": c.FunctionsConfig.GrammarConfig.MixedMode,
|
||||
"xml_format_preset": c.FunctionsConfig.XMLFormatPreset,
|
||||
}
|
||||
if c.Temperature != nil {
|
||||
traceData["temperature"] = *c.Temperature
|
||||
}
|
||||
if c.TopP != nil {
|
||||
traceData["top_p"] = *c.TopP
|
||||
}
|
||||
if c.Maxtokens != nil {
|
||||
traceData["max_tokens"] = *c.Maxtokens
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
originalFn := fn
|
||||
@@ -299,6 +282,42 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
"completion": resp.Usage.Completion,
|
||||
}
|
||||
|
||||
if len(resp.ChatDeltas) > 0 {
|
||||
chatDeltasInfo := map[string]any{
|
||||
"total_deltas": len(resp.ChatDeltas),
|
||||
}
|
||||
var contentParts, reasoningParts []string
|
||||
toolCallCount := 0
|
||||
for _, d := range resp.ChatDeltas {
|
||||
if d.Content != "" {
|
||||
contentParts = append(contentParts, d.Content)
|
||||
}
|
||||
if d.ReasoningContent != "" {
|
||||
reasoningParts = append(reasoningParts, d.ReasoningContent)
|
||||
}
|
||||
toolCallCount += len(d.ToolCalls)
|
||||
}
|
||||
if len(contentParts) > 0 {
|
||||
chatDeltasInfo["content"] = strings.Join(contentParts, "")
|
||||
}
|
||||
if len(reasoningParts) > 0 {
|
||||
chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||
}
|
||||
if toolCallCount > 0 {
|
||||
chatDeltasInfo["tool_call_count"] = toolCallCount
|
||||
}
|
||||
traceData["chat_deltas"] = chatDeltasInfo
|
||||
}
|
||||
|
||||
if capturedPredictOpts != nil {
|
||||
if optsJSON, err := json.Marshal(capturedPredictOpts); err == nil {
|
||||
var optsMap map[string]any
|
||||
if err := json.Unmarshal(optsJSON, &optsMap); err == nil {
|
||||
traceData["predict_options"] = optsMap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
|
||||
@@ -3,11 +3,12 @@ package backend
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -30,9 +31,12 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
var audioSnippet map[string]any
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
// Capture audio before the backend call — the backend may delete the file.
|
||||
audioSnippet = trace.AudioSnippet(audio)
|
||||
}
|
||||
|
||||
r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
@@ -45,6 +49,16 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
})
|
||||
if err != nil {
|
||||
if appConfig.EnableTracing {
|
||||
errData := map[string]any{
|
||||
"audio_file": audio,
|
||||
"language": language,
|
||||
"translate": translate,
|
||||
"diarize": diarize,
|
||||
"prompt": prompt,
|
||||
}
|
||||
if audioSnippet != nil {
|
||||
maps.Copy(errData, audioSnippet)
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
@@ -53,13 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(audio, 200),
|
||||
Error: err.Error(),
|
||||
Data: map[string]any{
|
||||
"audio_file": audio,
|
||||
"language": language,
|
||||
"translate": translate,
|
||||
"diarize": diarize,
|
||||
"prompt": prompt,
|
||||
},
|
||||
Data: errData,
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
@@ -84,6 +92,18 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
data := map[string]any{
|
||||
"audio_file": audio,
|
||||
"language": language,
|
||||
"translate": translate,
|
||||
"diarize": diarize,
|
||||
"prompt": prompt,
|
||||
"result_text": tr.Text,
|
||||
"segments_count": len(tr.Segments),
|
||||
}
|
||||
if audioSnippet != nil {
|
||||
maps.Copy(data, audioSnippet)
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
@@ -91,15 +111,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(audio+" -> "+tr.Text, 200),
|
||||
Data: map[string]any{
|
||||
"audio_file": audio,
|
||||
"language": language,
|
||||
"translate": translate,
|
||||
"diarize": diarize,
|
||||
"prompt": prompt,
|
||||
"result_text": tr.Text,
|
||||
"segments_count": len(tr.Segments),
|
||||
},
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -84,6 +85,16 @@ func ModelTTS(
|
||||
errStr = fmt.Sprintf("TTS error: %s", res.Message)
|
||||
}
|
||||
|
||||
data := map[string]any{
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
}
|
||||
if err == nil && res.Success {
|
||||
if snippet := trace.AudioSnippet(filePath); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
@@ -92,11 +103,7 @@ func ModelTTS(
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
},
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -158,6 +165,11 @@ func ModelTTSStream(
|
||||
headerSent := false
|
||||
var callbackErr error
|
||||
|
||||
// Collect up to 30s of audio for tracing
|
||||
var snippetPCM []byte
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
@@ -166,7 +178,7 @@ func ModelTTSStream(
|
||||
}, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]interface{}
|
||||
var info map[string]any
|
||||
if json.Unmarshal(reply.Message, &info) == nil {
|
||||
if sr, ok := info["sample_rate"].(float64); ok {
|
||||
sampleRate = uint32(sr)
|
||||
@@ -207,6 +219,22 @@ func ModelTTSStream(
|
||||
if writeErr := audioCallback(reply.Audio); writeErr != nil {
|
||||
callbackErr = writeErr
|
||||
}
|
||||
// Accumulate PCM for tracing snippet
|
||||
totalPCMBytes += len(reply.Audio)
|
||||
if appConfig.EnableTracing && !snippetCapped {
|
||||
maxBytes := int(sampleRate) * 2 * trace.MaxSnippetSeconds // 16-bit mono
|
||||
if len(snippetPCM)+len(reply.Audio) <= maxBytes {
|
||||
snippetPCM = append(snippetPCM, reply.Audio...)
|
||||
} else {
|
||||
remaining := maxBytes - len(snippetPCM)
|
||||
if remaining > 0 {
|
||||
// Align to sample boundary (2 bytes per sample)
|
||||
remaining = remaining &^ 1
|
||||
snippetPCM = append(snippetPCM, reply.Audio[:remaining]...)
|
||||
}
|
||||
snippetCapped = true
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -221,6 +249,17 @@ func ModelTTSStream(
|
||||
errStr = resultErr.Error()
|
||||
}
|
||||
|
||||
data := map[string]any{
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
"streaming": true,
|
||||
}
|
||||
if resultErr == nil && len(snippetPCM) > 0 {
|
||||
if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
@@ -229,12 +268,7 @@ func ModelTTSStream(
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
"streaming": true,
|
||||
},
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -7,17 +7,17 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
|
||||
func makeMultipartRequest(fields map[string]string, files map[string][]byte) (*http.Request, string) {
|
||||
b := &bytes.Buffer{}
|
||||
w := multipart.NewWriter(b)
|
||||
for k, v := range fields {
|
||||
@@ -25,83 +25,73 @@ func makeMultipartRequest(t *testing.T, fields map[string]string, files map[stri
|
||||
}
|
||||
for fname, content := range files {
|
||||
fw, err := w.CreateFormFile(fname, fname+".png")
|
||||
require.NoError(t, err)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = fw.Write(content)
|
||||
require.NoError(t, err)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
require.NoError(t, w.Close())
|
||||
Expect(w.Close()).To(Succeed())
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
return req, w.FormDataContentType()
|
||||
}
|
||||
|
||||
func TestInpainting_MissingFiles(t *testing.T) {
|
||||
e := echo.New()
|
||||
// handler requires cl, ml, appConfig but this test verifies missing files early
|
||||
h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())
|
||||
var _ = Describe("Inpainting", func() {
|
||||
It("returns error for missing files", func() {
|
||||
e := echo.New()
|
||||
h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := h(c)
|
||||
require.Error(t, err)
|
||||
}
|
||||
err := h(c)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
func TestInpainting_HappyPath(t *testing.T) {
|
||||
// Setup temp generated content dir
|
||||
tmpDir, err := os.MkdirTemp("", "gencontent")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
It("handles the happy path", func() {
|
||||
tmpDir, err := os.MkdirTemp("", "gencontent")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
DeferCleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
|
||||
appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
|
||||
|
||||
// stub the backend.ImageGenerationFunc
|
||||
orig := backend.ImageGenerationFunc
|
||||
backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
fn := func() error {
|
||||
// write a fake png file to dst
|
||||
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
|
||||
orig := backend.ImageGenerationFunc
|
||||
backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
fn := func() error {
|
||||
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
|
||||
}
|
||||
return fn, nil
|
||||
}
|
||||
return fn, nil
|
||||
}
|
||||
defer func() { backend.ImageGenerationFunc = orig }()
|
||||
DeferCleanup(func() { backend.ImageGenerationFunc = orig })
|
||||
|
||||
// prepare multipart request with image and mask
|
||||
fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
|
||||
files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
|
||||
reqBuf, _ := makeMultipartRequest(t, fields, files)
|
||||
fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
|
||||
files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
|
||||
reqBuf, _ := makeMultipartRequest(fields, files)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
e := echo.New()
|
||||
c := e.NewContext(reqBuf, rec)
|
||||
rec := httptest.NewRecorder()
|
||||
e := echo.New()
|
||||
c := e.NewContext(reqBuf, rec)
|
||||
|
||||
// set a minimal model config in context as handler expects
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"})
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"})
|
||||
|
||||
h := InpaintingEndpoint(nil, nil, appConf)
|
||||
h := InpaintingEndpoint(nil, nil, appConf)
|
||||
|
||||
// call handler
|
||||
err = h(c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
err = h(c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// verify response body contains generated-images path
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "generated-images")
|
||||
body := rec.Body.String()
|
||||
Expect(body).To(ContainSubstring("generated-images"))
|
||||
|
||||
// confirm the file was created in tmpDir
|
||||
// parse out filename from response (naive search)
|
||||
// find "generated-images/" and extract until closing quote or brace
|
||||
idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
|
||||
require.True(t, idx >= 0)
|
||||
rest := rec.Body.Bytes()[idx:]
|
||||
end := bytes.IndexAny(rest, "\",}\n")
|
||||
if end == -1 {
|
||||
end = len(rest)
|
||||
}
|
||||
fname := string(rest[len("generated-images/"):end])
|
||||
// ensure file exists
|
||||
_, err = os.Stat(filepath.Join(tmpDir, fname))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
|
||||
Expect(idx).To(BeNumerically(">=", 0))
|
||||
rest := rec.Body.Bytes()[idx:]
|
||||
end := bytes.IndexAny(rest, "\",}\n")
|
||||
if end == -1 {
|
||||
end = len(rest)
|
||||
}
|
||||
fname := string(rest[len("generated-images/"):end])
|
||||
_, err = os.Stat(filepath.Join(tmpDir, fname))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
23
core/http/endpoints/openai/realtime_transport.go
Normal file
23
core/http/endpoints/openai/realtime_transport.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
)
|
||||
|
||||
// Transport abstracts event and audio I/O so the same session logic
// can serve both WebSocket and WebRTC connections.
//
// NOTE(review): both known implementations serialize their writes
// internally (mutex for WebSocket, a sender goroutine for WebRTC);
// confirm before relying on concurrent SendEvent calls.
type Transport interface {
	// SendEvent marshals and sends a server event to the client.
	SendEvent(event types.ServerEvent) error
	// ReadEvent reads the next raw client event (JSON bytes).
	ReadEvent() ([]byte, error)
	// SendAudio sends raw PCM audio to the client at the given sample rate.
	// For WebSocket this is a no-op (audio is sent via JSON events).
	// For WebRTC this encodes to Opus and writes to the media track.
	// The context allows cancellation for barge-in support.
	SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error
	// Close tears down the underlying connection.
	Close() error
}
|
||||
251
core/http/endpoints/openai/realtime_transport_webrtc.go
Normal file
251
core/http/endpoints/openai/realtime_transport_webrtc.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// WebRTCTransport implements Transport over a pion/webrtc PeerConnection.
// Events travel via the "oai-events" DataChannel; audio goes over an RTP track.
type WebRTCTransport struct {
	pc          *webrtc.PeerConnection
	dc          *webrtc.DataChannel
	audioTrack  *webrtc.TrackLocalStaticRTP
	opusBackend grpc.Backend
	inEvents    chan []byte   // raw client events received on the data channel
	outEvents   chan []byte   // buffered outbound event queue
	closed      chan struct{} // closed exactly once (via closeOnce) on shutdown
	closeOnce   sync.Once
	flushed     chan struct{} // closed when sender goroutine has drained outEvents
	dcReady     chan struct{} // closed when data channel is open
	dcReadyOnce sync.Once
	sessionCh   chan *Session // delivers session from runRealtimeSession to handleIncomingAudioTrack

	// RTP state for outbound audio — protected by rtpMu
	rtpMu        sync.Mutex
	rtpSeqNum    uint16
	rtpTimestamp uint32
	rtpMarker    bool // true → next packet gets marker bit set
}
|
||||
|
||||
// NewWebRTCTransport wires a Transport around the given peer connection
// and outbound audio track: it registers the data-channel and
// connection-state callbacks and starts the event sender goroutine.
func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocalStaticRTP, opusBackend grpc.Backend) *WebRTCTransport {
	t := &WebRTCTransport{
		pc:          pc,
		audioTrack:  audioTrack,
		opusBackend: opusBackend,
		inEvents:    make(chan []byte, 256),
		outEvents:   make(chan []byte, 256),
		closed:      make(chan struct{}),
		flushed:     make(chan struct{}),
		dcReady:     make(chan struct{}),
		sessionCh:   make(chan *Session, 1),
		// Randomized starting sequence number and timestamp for the RTP stream.
		rtpSeqNum:    uint16(rand.UintN(65536)),
		rtpTimestamp: rand.Uint32(),
		rtpMarker:    true, // first packet of the stream gets marker
	}

	// The client creates the "oai-events" data channel (so m=application is
	// included in the SDP offer). We receive it here via OnDataChannel.
	pc.OnDataChannel(func(dc *webrtc.DataChannel) {
		if dc.Label() != "oai-events" {
			return
		}
		// t.dc is assigned before dcReady is closed; the channel close gives
		// the sender goroutine a happens-before edge for reading t.dc.
		t.dc = dc
		dc.OnOpen(func() {
			t.dcReadyOnce.Do(func() { close(t.dcReady) })
		})
		dc.OnMessage(func(msg webrtc.DataChannelMessage) {
			// Forward the raw client event; give up if the transport closes.
			select {
			case t.inEvents <- msg.Data:
			case <-t.closed:
			}
		})
		// The channel may already be open by the time OnDataChannel fires
		if dc.ReadyState() == webrtc.DataChannelStateOpen {
			t.dcReadyOnce.Do(func() { close(t.dcReady) })
		}
	})

	pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
		xlog.Debug("WebRTC connection state", "state", state.String())
		// Failed/closed/disconnected are all treated as end-of-transport.
		if state == webrtc.PeerConnectionStateFailed ||
			state == webrtc.PeerConnectionStateClosed ||
			state == webrtc.PeerConnectionStateDisconnected {
			t.closeOnce.Do(func() { close(t.closed) })
		}
	})

	go t.sendLoop()

	return t
}
|
||||
|
||||
// sendLoop is a dedicated goroutine that drains outEvents and sends them
// over the data channel. It waits for the data channel to open before
// sending, and drains any remaining events when closed is signalled.
// On return it closes t.flushed, which Close waits on before tearing
// down the peer connection.
func (t *WebRTCTransport) sendLoop() {
	defer close(t.flushed)

	// Wait for data channel to be ready
	select {
	case <-t.dcReady:
	case <-t.closed:
		return
	}

	for {
		select {
		case data, ok := <-t.outEvents:
			if !ok {
				return
			}
			if err := t.dc.SendText(string(data)); err != nil {
				xlog.Error("data channel send failed", "error", err)
				return
			}
		case <-t.closed:
			// Drain any remaining queued events before exiting
			for {
				select {
				case data := <-t.outEvents:
					if err := t.dc.SendText(string(data)); err != nil {
						return
					}
				default:
					// Queue empty — all pending events flushed.
					return
				}
			}
		}
	}
}
|
||||
|
||||
func (t *WebRTCTransport) SendEvent(event types.ServerEvent) error {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal event: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case t.outEvents <- data:
|
||||
return nil
|
||||
case <-t.closed:
|
||||
return fmt.Errorf("transport closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebRTCTransport) ReadEvent() ([]byte, error) {
|
||||
select {
|
||||
case msg := <-t.inEvents:
|
||||
return msg, nil
|
||||
case <-t.closed:
|
||||
return nil, fmt.Errorf("transport closed")
|
||||
}
|
||||
}
|
||||
|
||||
// SendAudio encodes raw PCM int16 LE to Opus and writes RTP packets to the
// audio track. The encoder resamples from the given sampleRate to 48kHz
// internally. Frames are paced at real-time intervals (20ms per frame) to
// avoid overwhelming the browser's jitter buffer with a burst of packets.
//
// The context allows callers to cancel mid-stream for barge-in support.
// When cancelled, the marker bit is set so the next audio segment starts
// cleanly in the browser's jitter buffer.
//
// RTP packets are constructed manually (rather than via WriteSample) so we
// can control the marker bit. pion's WriteSample sets the marker bit on
// every Opus packet, which causes Chrome's NetEq jitter buffer to reset
// its timing estimation for each frame, producing severe audio distortion.
func (t *WebRTCTransport) SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error {
	// Delegate Opus encoding (including resampling to 48kHz) to the backend.
	result, err := t.opusBackend.AudioEncode(ctx, &pb.AudioEncodeRequest{
		PcmData:    pcmData,
		SampleRate: int32(sampleRate),
		Channels:   1,
	})
	if err != nil {
		return fmt.Errorf("opus encode: %w", err)
	}

	frames := result.Frames
	const frameDuration = 20 * time.Millisecond
	const samplesPerFrame = 960 // 20ms at 48kHz

	ticker := time.NewTicker(frameDuration)
	defer ticker.Stop()

	for i, frame := range frames {
		// Build the packet under rtpMu: sequence number, timestamp and the
		// marker flag are shared across SendAudio calls.
		t.rtpMu.Lock()
		pkt := &rtp.Packet{
			Header: rtp.Header{
				Version:        2,
				Marker:         t.rtpMarker,
				SequenceNumber: t.rtpSeqNum,
				Timestamp:      t.rtpTimestamp,
				// SSRC and PayloadType are overridden by pion's writeRTP
			},
			Payload: frame,
		}
		t.rtpSeqNum++
		t.rtpTimestamp += samplesPerFrame
		t.rtpMarker = false // only the first packet gets marker
		t.rtpMu.Unlock()

		if err := t.audioTrack.WriteRTP(pkt); err != nil {
			return fmt.Errorf("write rtp: %w", err)
		}

		// Pace output at ~real-time so the browser's jitter buffer
		// receives packets at the expected rate. Skip wait after last frame.
		if i < len(frames)-1 {
			select {
			case <-ticker.C:
			case <-ctx.Done():
				// Barge-in: mark the next packet so the browser knows
				// a new audio segment is starting after the interruption.
				t.rtpMu.Lock()
				t.rtpMarker = true
				t.rtpMu.Unlock()
				return ctx.Err()
			case <-t.closed:
				return fmt.Errorf("transport closed during audio send")
			}
		}
	}
	return nil
}
|
||||
|
||||
// SetSession delivers the session to any goroutine waiting in WaitForSession.
|
||||
func (t *WebRTCTransport) SetSession(s *Session) {
|
||||
select {
|
||||
case t.sessionCh <- s:
|
||||
case <-t.closed:
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForSession blocks until the session is available or the transport closes.
|
||||
func (t *WebRTCTransport) WaitForSession() *Session {
|
||||
select {
|
||||
case s := <-t.sessionCh:
|
||||
return s
|
||||
case <-t.closed:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts the transport down: it signals closure (idempotently),
// waits for the sender goroutine to flush any queued events, and then
// closes the underlying peer connection.
func (t *WebRTCTransport) Close() error {
	// Signal no more events and unblock the sender if it's waiting
	t.closeOnce.Do(func() { close(t.closed) })
	// Wait for the sender to drain any remaining queued events
	<-t.flushed
	return t.pc.Close()
}
|
||||
47
core/http/endpoints/openai/realtime_transport_ws.go
Normal file
47
core/http/endpoints/openai/realtime_transport_ws.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// WebSocketTransport implements Transport over a gorilla/websocket connection.
type WebSocketTransport struct {
	conn *websocket.Conn
	mu   sync.Mutex // serializes writes to conn
}
|
||||
|
||||
func NewWebSocketTransport(conn *websocket.Conn) *WebSocketTransport {
|
||||
return &WebSocketTransport{conn: conn}
|
||||
}
|
||||
|
||||
func (t *WebSocketTransport) SendEvent(event types.ServerEvent) error {
|
||||
eventBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
xlog.Error("failed to marshal event", "error", err)
|
||||
return err
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.conn.WriteMessage(websocket.TextMessage, eventBytes)
|
||||
}
|
||||
|
||||
func (t *WebSocketTransport) ReadEvent() ([]byte, error) {
|
||||
_, msg, err := t.conn.ReadMessage()
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// SendAudio is a no-op for WebSocket — audio is delivered via JSON events
// (base64-encoded in response.audio.delta), so there is no separate media
// path to write to.
func (t *WebSocketTransport) SendAudio(_ context.Context, _ []byte, _ int) error {
	return nil
}
|
||||
|
||||
// Close tears down the underlying websocket connection.
func (t *WebSocketTransport) Close() error {
	return t.conn.Close()
}
|
||||
206
core/http/endpoints/openai/realtime_webrtc.go
Normal file
206
core/http/endpoints/openai/realtime_webrtc.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// RealtimeCallRequest is the JSON body for POST /v1/realtime/calls.
type RealtimeCallRequest struct {
	// SDP is the client's WebRTC session description offer.
	SDP string `json:"sdp"`
	// Model selects which model serves the realtime session.
	Model string `json:"model"`
}
|
||||
|
||||
// RealtimeCallResponse is the JSON response for POST /v1/realtime/calls.
type RealtimeCallResponse struct {
	// SDP is the server's WebRTC session description answer.
	SDP string `json:"sdp"`
	// SessionID identifies the created realtime session.
	SessionID string `json:"session_id"`
}
|
||||
|
||||
// RealtimeCalls handles POST /v1/realtime/calls for WebRTC signaling.
|
||||
func RealtimeCalls(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req RealtimeCallRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||
}
|
||||
if req.SDP == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "sdp is required"})
|
||||
}
|
||||
if req.Model == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "model is required"})
|
||||
}
|
||||
|
||||
// Create a MediaEngine with Opus support
|
||||
m := &webrtc.MediaEngine{}
|
||||
if err := m.RegisterDefaultCodecs(); err != nil {
|
||||
xlog.Error("failed to register codecs", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "codec registration failed"})
|
||||
}
|
||||
|
||||
api := webrtc.NewAPI(webrtc.WithMediaEngine(m))
|
||||
|
||||
pc, err := api.NewPeerConnection(webrtc.Configuration{})
|
||||
if err != nil {
|
||||
xlog.Error("failed to create peer connection", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create peer connection"})
|
||||
}
|
||||
|
||||
// Create outbound audio track (Opus, 48kHz).
|
||||
// We use TrackLocalStaticRTP (not TrackLocalStaticSample) so that
|
||||
// SendAudio can construct RTP packets directly and control the marker
|
||||
// bit. pion's WriteSample sets the marker bit on every Opus packet,
|
||||
// which causes Chrome's NetEq jitter buffer to reset for each frame.
|
||||
audioTrack, err := webrtc.NewTrackLocalStaticRTP(
|
||||
webrtc.RTPCodecCapability{
|
||||
MimeType: webrtc.MimeTypeOpus,
|
||||
ClockRate: 48000,
|
||||
Channels: 2, // Opus in WebRTC is always signaled as 2 channels per RFC 7587
|
||||
},
|
||||
"audio",
|
||||
"localai",
|
||||
)
|
||||
if err != nil {
|
||||
pc.Close()
|
||||
xlog.Error("failed to create audio track", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create audio track"})
|
||||
}
|
||||
|
||||
rtpSender, err := pc.AddTrack(audioTrack)
|
||||
if err != nil {
|
||||
pc.Close()
|
||||
xlog.Error("failed to add audio track", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to add audio track"})
|
||||
}
|
||||
|
||||
// Drain RTCP (control protocol) packets we don't have anyting useful to do with
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
if _, _, err := rtpSender.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Load the Opus backend
|
||||
opusBackend, err := application.ModelLoader().Load(
|
||||
model.WithBackendString("opus"),
|
||||
model.WithModelID("__opus_codec__"),
|
||||
model.WithModel("opus"),
|
||||
)
|
||||
if err != nil {
|
||||
pc.Close()
|
||||
xlog.Error("failed to load opus backend", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "opus backend not available"})
|
||||
}
|
||||
|
||||
// Create the transport (the data channel is created by the client and
|
||||
// received via pc.OnDataChannel inside NewWebRTCTransport)
|
||||
transport := NewWebRTCTransport(pc, audioTrack, opusBackend)
|
||||
|
||||
// Handle incoming audio track from the client
|
||||
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
codec := track.Codec()
|
||||
if codec.MimeType != webrtc.MimeTypeOpus {
|
||||
xlog.Warn("unexpected track codec, ignoring", "mime", codec.MimeType)
|
||||
return
|
||||
}
|
||||
xlog.Debug("Received audio track from client",
|
||||
"codec", codec.MimeType,
|
||||
"clock_rate", codec.ClockRate,
|
||||
"channels", codec.Channels,
|
||||
"sdp_fmtp", codec.SDPFmtpLine,
|
||||
"payload_type", codec.PayloadType,
|
||||
)
|
||||
|
||||
handleIncomingAudioTrack(track, transport)
|
||||
})
|
||||
|
||||
// Set the remote SDP (client's offer)
|
||||
if err := pc.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
SDP: req.SDP,
|
||||
}); err != nil {
|
||||
transport.Close()
|
||||
xlog.Error("failed to set remote description", "error", err)
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid SDP offer"})
|
||||
}
|
||||
|
||||
// Create answer
|
||||
answer, err := pc.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
transport.Close()
|
||||
xlog.Error("failed to create answer", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create answer"})
|
||||
}
|
||||
|
||||
if err := pc.SetLocalDescription(answer); err != nil {
|
||||
transport.Close()
|
||||
xlog.Error("failed to set local description", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to set local description"})
|
||||
}
|
||||
|
||||
// Wait for ICE gathering to complete (with timeout)
|
||||
gatherDone := webrtc.GatheringCompletePromise(pc)
|
||||
select {
|
||||
case <-gatherDone:
|
||||
case <-time.After(10 * time.Second):
|
||||
xlog.Warn("ICE gathering timed out, using partial candidates")
|
||||
}
|
||||
|
||||
localDesc := pc.LocalDescription()
|
||||
if localDesc == nil {
|
||||
transport.Close()
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "no local description"})
|
||||
}
|
||||
|
||||
sessionID := generateSessionID()
|
||||
|
||||
// Start the realtime session in a goroutine
|
||||
evaluator := application.TemplatesEvaluator()
|
||||
go func() {
|
||||
defer transport.Close()
|
||||
runRealtimeSession(application, transport, req.Model, evaluator)
|
||||
}()
|
||||
|
||||
return c.JSON(http.StatusCreated, RealtimeCallResponse{
|
||||
SDP: localDesc.SDP,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// handleIncomingAudioTrack reads RTP packets from a remote WebRTC track
|
||||
// and buffers the raw Opus payloads on the session. Decoding is done in
|
||||
// batches by decodeOpusLoop in realtime.go.
|
||||
func handleIncomingAudioTrack(track *webrtc.TrackRemote, transport *WebRTCTransport) {
|
||||
session := transport.WaitForSession()
|
||||
if session == nil {
|
||||
xlog.Error("could not find session for incoming audio track (transport closed)")
|
||||
sendError(transport, "session_error", "Session failed to start — check server logs", "", "")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
pkt, _, err := track.ReadRTP()
|
||||
if err != nil {
|
||||
xlog.Debug("audio track read ended", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the payload — pion's ReadRTP may back it by a reusable buffer
|
||||
payload := make([]byte, len(pkt.Payload))
|
||||
copy(payload, pkt.Payload)
|
||||
|
||||
session.OpusFramesLock.Lock()
|
||||
session.OpusFrames = append(session.OpusFrames, payload)
|
||||
session.OpusFramesLock.Unlock()
|
||||
}
|
||||
}
|
||||
@@ -712,17 +712,39 @@ type SessionAudioInput struct {
|
||||
// Configuration for input audio noise reduction. This can be set to null to turn off. Noise reduction filters audio added to the input audio buffer before it is sent to VAD and the model. Filtering the audio can improve VAD and turn detection accuracy (reducing false positives) and model performance by improving perception of the input audio.
|
||||
NoiseReduction *AudioNoiseReduction `json:"noise_reduction,omitempty"`
|
||||
|
||||
// Configuration for input audio transcription, defaults to off and can be set to null to turn off once on. Input audio transcription is not native to the model, since the model consumes audio directly. Transcription runs asynchronously through the /audio/transcriptions endpoint and should be treated as guidance of input audio content rather than precisely what the model heard. The client can optionally set the language and prompt for transcription, these offer additional guidance to the transcription service.
|
||||
// Configuration for turn detection: Server VAD or Semantic VAD. Set to null
|
||||
// to turn off, in which case the client must manually trigger model response.
|
||||
TurnDetection *TurnDetectionUnion `json:"turn_detection,omitempty"`
|
||||
|
||||
// Configuration for turn detection, ether Server VAD or Semantic VAD. This can be set to null to turn off, in which case the client must manually trigger model response.
|
||||
//
|
||||
// Server VAD means that the model will detect the start and end of speech based on audio volume and respond at the end of user speech.
|
||||
//
|
||||
// Semantic VAD is more advanced and uses a turn detection model (in conjunction with VAD) to semantically estimate whether the user has finished speaking, then dynamically sets a timeout based on this probability. For example, if user audio trails off with "uhhm", the model will score a low probability of turn end and wait longer for the user to continue speaking. This can be useful for more natural conversations, but may have a higher latency.
|
||||
// True when the JSON payload explicitly included "turn_detection" (even as null).
|
||||
// Standard Go JSON can't distinguish absent from null for pointer fields.
|
||||
TurnDetectionSet bool `json:"-"`
|
||||
|
||||
// Configuration for input audio transcription, defaults to off and can be
|
||||
// set to null to turn off once on.
|
||||
Transcription *AudioTranscription `json:"transcription,omitempty"`
|
||||
}
|
||||
|
||||
func (s *SessionAudioInput) UnmarshalJSON(data []byte) error {
|
||||
// Check whether turn_detection key exists in the raw JSON.
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
type alias SessionAudioInput
|
||||
var a alias
|
||||
if err := json.Unmarshal(data, &a); err != nil {
|
||||
return err
|
||||
}
|
||||
*s = SessionAudioInput(a)
|
||||
|
||||
if _, ok := raw["turn_detection"]; ok {
|
||||
s.TurnDetectionSet = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SessionAudioOutput struct {
|
||||
Format *AudioFormatUnion `json:"format,omitempty"`
|
||||
Speed float32 `json:"speed,omitempty"`
|
||||
@@ -1012,10 +1034,13 @@ func (r *SessionUnion) UnmarshalJSON(data []byte) error {
|
||||
return err
|
||||
}
|
||||
switch SessionType(t.Type) {
|
||||
case SessionTypeRealtime:
|
||||
return json.Unmarshal(data, &r.Realtime)
|
||||
case SessionTypeRealtime, "":
|
||||
// Default to realtime when no type field is present (e.g. session.update events).
|
||||
r.Realtime = &RealtimeSession{}
|
||||
return json.Unmarshal(data, r.Realtime)
|
||||
case SessionTypeTranscription:
|
||||
return json.Unmarshal(data, &r.Transcription)
|
||||
r.Transcription = &TranscriptionSession{}
|
||||
return json.Unmarshal(data, r.Transcription)
|
||||
default:
|
||||
return fmt.Errorf("unknown session type: %s", t.Type)
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func GetTraces() []APIExchange {
|
||||
mu.Unlock()
|
||||
|
||||
sort.Slice(traces, func(i, j int) bool {
|
||||
return traces[i].Timestamp.Before(traces[j].Timestamp)
|
||||
return traces[i].Timestamp.After(traces[j].Timestamp)
|
||||
})
|
||||
|
||||
return traces
|
||||
|
||||
@@ -55,6 +55,7 @@ const SECTIONS = [
|
||||
{ id: 'memory', icon: 'fa-memory', color: 'var(--color-accent)', label: 'Memory' },
|
||||
{ id: 'backends', icon: 'fa-cogs', color: 'var(--color-accent)', label: 'Backends' },
|
||||
{ id: 'performance', icon: 'fa-gauge-high', color: 'var(--color-success)', label: 'Performance' },
|
||||
{ id: 'tracing', icon: 'fa-bug', color: 'var(--color-warning)', label: 'Tracing' },
|
||||
{ id: 'api', icon: 'fa-globe', color: 'var(--color-warning)', label: 'API & CORS' },
|
||||
{ id: 'p2p', icon: 'fa-network-wired', color: 'var(--color-accent)', label: 'P2P' },
|
||||
{ id: 'galleries', icon: 'fa-images', color: 'var(--color-accent)', label: 'Galleries' },
|
||||
@@ -327,10 +328,19 @@ export default function Settings() {
|
||||
<SettingRow label="Debug Mode" description="Enable verbose debug logging">
|
||||
<Toggle checked={settings.debug} onChange={(v) => update('debug', v)} />
|
||||
</SettingRow>
|
||||
<SettingRow label="Enable Tracing" description="Enable request/response tracing for debugging">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Tracing */}
|
||||
<div ref={el => sectionRefs.current.tracing = el} style={{ marginBottom: 'var(--spacing-xl)' }}>
|
||||
<h3 style={{ fontSize: '1rem', fontWeight: 700, display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)', marginBottom: 'var(--spacing-md)' }}>
|
||||
<i className="fas fa-bug" style={{ color: 'var(--color-warning)' }} /> Tracing
|
||||
</h3>
|
||||
<div className="card">
|
||||
<SettingRow label="Enable Tracing" description="Record API requests, responses, and backend operations for debugging">
|
||||
<Toggle checked={settings.enable_tracing} onChange={(v) => update('enable_tracing', v)} />
|
||||
</SettingRow>
|
||||
<SettingRow label="Tracing Max Items" description="Maximum number of trace items to retain">
|
||||
<SettingRow label="Max Items" description="Maximum number of trace items to retain (0 = unlimited)">
|
||||
<input className="input" type="number" style={{ width: 120 }} value={settings.tracing_max_items ?? ''} onChange={(e) => update('tracing_max_items', parseInt(e.target.value) || 0)} placeholder="100" disabled={!settings.enable_tracing} />
|
||||
</SettingRow>
|
||||
</div>
|
||||
|
||||
@@ -1,196 +1,688 @@
|
||||
import { useState, useRef, useCallback } from 'react'
|
||||
import { useState, useRef, useEffect, useCallback } from 'react'
|
||||
import { useOutletContext } from 'react-router-dom'
|
||||
import ModelSelector from '../components/ModelSelector'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import { chatApi, ttsApi, audioApi } from '../utils/api'
|
||||
import { realtimeApi } from '../utils/api'
|
||||
|
||||
const STATUS_STYLES = {
|
||||
disconnected: { icon: 'fa-solid fa-circle', color: 'var(--color-text-secondary)', bg: 'transparent' },
|
||||
connecting: { icon: 'fa-solid fa-spinner fa-spin', color: 'var(--color-primary)', bg: 'var(--color-primary-light)' },
|
||||
connected: { icon: 'fa-solid fa-circle', color: 'var(--color-success)', bg: 'rgba(34,197,94,0.1)' },
|
||||
listening: { icon: 'fa-solid fa-microphone', color: 'var(--color-success)', bg: 'rgba(34,197,94,0.1)' },
|
||||
thinking: { icon: 'fa-solid fa-brain fa-beat', color: 'var(--color-primary)', bg: 'var(--color-primary-light)' },
|
||||
speaking: { icon: 'fa-solid fa-volume-high fa-beat-fade', color: 'var(--color-accent)', bg: 'rgba(168,85,247,0.1)' },
|
||||
error: { icon: 'fa-solid fa-circle', color: 'var(--color-error)', bg: 'var(--color-error-light)' },
|
||||
}
|
||||
|
||||
export default function Talk() {
|
||||
const { addToast } = useOutletContext()
|
||||
const [llmModel, setLlmModel] = useState('')
|
||||
const [whisperModel, setWhisperModel] = useState('')
|
||||
const [ttsModel, setTtsModel] = useState('')
|
||||
const [isRecording, setIsRecording] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [status, setStatus] = useState('Press the record button to start talking.')
|
||||
const [audioUrl, setAudioUrl] = useState(null)
|
||||
const [conversationHistory, setConversationHistory] = useState([])
|
||||
const mediaRecorderRef = useRef(null)
|
||||
const chunksRef = useRef([])
|
||||
const audioRef = useRef(null)
|
||||
|
||||
const startRecording = async () => {
|
||||
if (!navigator.mediaDevices) {
|
||||
addToast('MediaDevices API not supported', 'error')
|
||||
// Pipeline models
|
||||
const [pipelineModels, setPipelineModels] = useState([])
|
||||
const [selectedModel, setSelectedModel] = useState('')
|
||||
const [modelsLoading, setModelsLoading] = useState(true)
|
||||
|
||||
// Connection state
|
||||
const [status, setStatus] = useState('disconnected')
|
||||
const [statusText, setStatusText] = useState('Disconnected')
|
||||
const [isConnected, setIsConnected] = useState(false)
|
||||
|
||||
// Transcript
|
||||
const [transcript, setTranscript] = useState([])
|
||||
const streamingRef = useRef(null) // tracks the index of the in-progress assistant message
|
||||
|
||||
// Session settings
|
||||
const [instructions, setInstructions] = useState(
|
||||
'You are a helpful voice assistant. Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. Speak naturally as you would in a phone conversation.'
|
||||
)
|
||||
const [voice, setVoice] = useState('')
|
||||
const [voiceEdited, setVoiceEdited] = useState(false)
|
||||
const [language, setLanguage] = useState('')
|
||||
|
||||
// Diagnostics
|
||||
const [diagVisible, setDiagVisible] = useState(false)
|
||||
|
||||
// Refs for WebRTC / audio
|
||||
const pcRef = useRef(null)
|
||||
const dcRef = useRef(null)
|
||||
const localStreamRef = useRef(null)
|
||||
const audioRef = useRef(null)
|
||||
const hasErrorRef = useRef(false)
|
||||
|
||||
// Diagnostics refs
|
||||
const audioCtxRef = useRef(null)
|
||||
const analyserRef = useRef(null)
|
||||
const diagFrameRef = useRef(null)
|
||||
const statsIntervalRef = useRef(null)
|
||||
const waveCanvasRef = useRef(null)
|
||||
const specCanvasRef = useRef(null)
|
||||
const transcriptEndRef = useRef(null)
|
||||
|
||||
// Diagnostics stats (not worth re-rendering for every frame)
|
||||
const [diagStats, setDiagStats] = useState({
|
||||
peakFreq: '--', thd: '--', rms: '--', sampleRate: '--',
|
||||
packetsRecv: '--', packetsLost: '--', jitter: '--', concealed: '--', raw: '',
|
||||
})
|
||||
|
||||
// Fetch pipeline models on mount
|
||||
useEffect(() => {
|
||||
realtimeApi.pipelineModels()
|
||||
.then(models => {
|
||||
setPipelineModels(models || [])
|
||||
if (models?.length > 0) {
|
||||
setSelectedModel(models[0].name)
|
||||
if (!voiceEdited) setVoice(models[0].voice || '')
|
||||
}
|
||||
})
|
||||
.catch(err => addToast(`Failed to load pipeline models: ${err.message}`, 'error'))
|
||||
.finally(() => setModelsLoading(false))
|
||||
}, [])
|
||||
|
||||
// Auto-scroll transcript
|
||||
useEffect(() => {
|
||||
transcriptEndRef.current?.scrollIntoView({ behavior: 'smooth' })
|
||||
}, [transcript])
|
||||
|
||||
const selectedModelInfo = pipelineModels.find(m => m.name === selectedModel)
|
||||
|
||||
// ── Status helper ──
|
||||
const updateStatus = useCallback((state, text) => {
|
||||
setStatus(state)
|
||||
setStatusText(text || state)
|
||||
}, [])
|
||||
|
||||
// ── Session update ──
|
||||
const sendSessionUpdate = useCallback(() => {
|
||||
const dc = dcRef.current
|
||||
if (!dc || dc.readyState !== 'open') return
|
||||
if (!instructions.trim() && !voice.trim() && !language.trim()) return
|
||||
|
||||
const session = {}
|
||||
if (instructions.trim()) session.instructions = instructions.trim()
|
||||
if (voice.trim() || language.trim()) {
|
||||
session.audio = {}
|
||||
if (voice.trim()) session.audio.output = { voice: voice.trim() }
|
||||
if (language.trim()) session.audio.input = { transcription: { language: language.trim() } }
|
||||
}
|
||||
|
||||
dc.send(JSON.stringify({ type: 'session.update', session }))
|
||||
}, [instructions, voice, language])
|
||||
|
||||
// ── Server event handler ──
|
||||
const handleServerEvent = useCallback((event) => {
|
||||
switch (event.type) {
|
||||
case 'session.created':
|
||||
sendSessionUpdate()
|
||||
updateStatus('listening', 'Listening...')
|
||||
break
|
||||
case 'session.updated':
|
||||
break
|
||||
case 'input_audio_buffer.speech_started':
|
||||
updateStatus('listening', 'Hearing you speak...')
|
||||
break
|
||||
case 'input_audio_buffer.speech_stopped':
|
||||
updateStatus('thinking', 'Processing...')
|
||||
break
|
||||
case 'conversation.item.input_audio_transcription.completed':
|
||||
if (event.transcript) {
|
||||
streamingRef.current = null
|
||||
setTranscript(prev => [...prev, { role: 'user', text: event.transcript }])
|
||||
}
|
||||
updateStatus('thinking', 'Generating response...')
|
||||
break
|
||||
case 'response.output_audio_transcript.delta':
|
||||
if (event.delta) {
|
||||
setTranscript(prev => {
|
||||
if (streamingRef.current !== null) {
|
||||
const updated = [...prev]
|
||||
updated[streamingRef.current] = {
|
||||
...updated[streamingRef.current],
|
||||
text: updated[streamingRef.current].text + event.delta,
|
||||
}
|
||||
return updated
|
||||
}
|
||||
streamingRef.current = prev.length
|
||||
return [...prev, { role: 'assistant', text: event.delta }]
|
||||
})
|
||||
}
|
||||
break
|
||||
case 'response.output_audio_transcript.done':
|
||||
if (event.transcript) {
|
||||
setTranscript(prev => {
|
||||
if (streamingRef.current !== null) {
|
||||
const updated = [...prev]
|
||||
updated[streamingRef.current] = { ...updated[streamingRef.current], text: event.transcript }
|
||||
return updated
|
||||
}
|
||||
return [...prev, { role: 'assistant', text: event.transcript }]
|
||||
})
|
||||
}
|
||||
streamingRef.current = null
|
||||
break
|
||||
case 'response.output_audio.delta':
|
||||
updateStatus('speaking', 'Speaking...')
|
||||
break
|
||||
case 'response.done':
|
||||
updateStatus('listening', 'Listening...')
|
||||
break
|
||||
case 'error':
|
||||
hasErrorRef.current = true
|
||||
updateStatus('error', 'Error: ' + (event.error?.message || 'Unknown error'))
|
||||
break
|
||||
}
|
||||
}, [sendSessionUpdate, updateStatus])
|
||||
|
||||
// ── Connect ──
|
||||
const connect = useCallback(async () => {
|
||||
if (!selectedModel) {
|
||||
addToast('Please select a pipeline model first.', 'warning')
|
||||
return
|
||||
}
|
||||
if (!navigator.mediaDevices?.getUserMedia) {
|
||||
updateStatus('error', 'Microphone access requires HTTPS or localhost.')
|
||||
return
|
||||
}
|
||||
|
||||
updateStatus('connecting', 'Connecting...')
|
||||
setIsConnected(true)
|
||||
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
|
||||
const recorder = new MediaRecorder(stream)
|
||||
chunksRef.current = []
|
||||
recorder.ondataavailable = (e) => chunksRef.current.push(e.data)
|
||||
recorder.start()
|
||||
mediaRecorderRef.current = recorder
|
||||
setIsRecording(true)
|
||||
setStatus('Recording... Click to stop.')
|
||||
} catch (err) {
|
||||
addToast(`Microphone error: ${err.message}`, 'error')
|
||||
}
|
||||
}
|
||||
const localStream = await navigator.mediaDevices.getUserMedia({ audio: true })
|
||||
localStreamRef.current = localStream
|
||||
|
||||
const stopRecording = useCallback(() => {
|
||||
if (!mediaRecorderRef.current) return
|
||||
const pc = new RTCPeerConnection({})
|
||||
pcRef.current = pc
|
||||
|
||||
mediaRecorderRef.current.onstop = async () => {
|
||||
setIsRecording(false)
|
||||
setLoading(true)
|
||||
|
||||
const audioBlob = new Blob(chunksRef.current, { type: 'audio/webm' })
|
||||
|
||||
try {
|
||||
// 1. Transcribe
|
||||
setStatus('Transcribing audio...')
|
||||
const formData = new FormData()
|
||||
formData.append('file', audioBlob)
|
||||
formData.append('model', whisperModel)
|
||||
const transcription = await audioApi.transcribe(formData)
|
||||
const userText = transcription.text
|
||||
|
||||
setStatus(`You said: "${userText}". Generating response...`)
|
||||
|
||||
// 2. Chat completion
|
||||
const newHistory = [...conversationHistory, { role: 'user', content: userText }]
|
||||
const chatResponse = await chatApi.complete({
|
||||
model: llmModel,
|
||||
messages: newHistory,
|
||||
})
|
||||
const assistantText = chatResponse?.choices?.[0]?.message?.content || ''
|
||||
const updatedHistory = [...newHistory, { role: 'assistant', content: assistantText }]
|
||||
setConversationHistory(updatedHistory)
|
||||
|
||||
setStatus(`Response: "${assistantText}". Generating speech...`)
|
||||
|
||||
// 3. TTS
|
||||
const ttsBlob = await ttsApi.generateV1({ input: assistantText, model: ttsModel })
|
||||
const url = URL.createObjectURL(ttsBlob)
|
||||
setAudioUrl(url)
|
||||
setStatus('Press the record button to continue.')
|
||||
|
||||
// Auto-play
|
||||
setTimeout(() => audioRef.current?.play(), 100)
|
||||
} catch (err) {
|
||||
addToast(`Error: ${err.message}`, 'error')
|
||||
setStatus('Error occurred. Try again.')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
for (const track of localStream.getAudioTracks()) {
|
||||
pc.addTrack(track, localStream)
|
||||
}
|
||||
|
||||
pc.ontrack = (event) => {
|
||||
if (audioRef.current) audioRef.current.srcObject = event.streams[0]
|
||||
if (diagVisible) startDiagnostics()
|
||||
}
|
||||
|
||||
const dc = pc.createDataChannel('oai-events')
|
||||
dcRef.current = dc
|
||||
dc.onmessage = (msg) => {
|
||||
try {
|
||||
const text = typeof msg.data === 'string' ? msg.data : new TextDecoder().decode(msg.data)
|
||||
handleServerEvent(JSON.parse(text))
|
||||
} catch (e) {
|
||||
console.error('Failed to parse server event:', e)
|
||||
}
|
||||
}
|
||||
dc.onclose = () => console.log('Data channel closed')
|
||||
|
||||
pc.onconnectionstatechange = () => {
|
||||
if (pc.connectionState === 'connected') {
|
||||
updateStatus('connected', 'Connected, waiting for session...')
|
||||
} else if (pc.connectionState === 'failed' || pc.connectionState === 'closed') {
|
||||
disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
const offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
await new Promise((resolve) => {
|
||||
if (pc.iceGatheringState === 'complete') return resolve()
|
||||
pc.onicegatheringstatechange = () => {
|
||||
if (pc.iceGatheringState === 'complete') resolve()
|
||||
}
|
||||
setTimeout(resolve, 5000)
|
||||
})
|
||||
|
||||
const data = await realtimeApi.call({
|
||||
sdp: pc.localDescription.sdp,
|
||||
model: selectedModel,
|
||||
})
|
||||
|
||||
await pc.setRemoteDescription({ type: 'answer', sdp: data.sdp })
|
||||
} catch (err) {
|
||||
hasErrorRef.current = true
|
||||
updateStatus('error', 'Connection failed: ' + err.message)
|
||||
disconnect()
|
||||
}
|
||||
}, [selectedModel, diagVisible, handleServerEvent, updateStatus, addToast])
|
||||
|
||||
// ── Disconnect ──
|
||||
const disconnect = useCallback(() => {
|
||||
stopDiagnostics()
|
||||
if (dcRef.current) { dcRef.current.close(); dcRef.current = null }
|
||||
if (pcRef.current) { pcRef.current.close(); pcRef.current = null }
|
||||
if (localStreamRef.current) {
|
||||
localStreamRef.current.getTracks().forEach(t => t.stop())
|
||||
localStreamRef.current = null
|
||||
}
|
||||
if (audioRef.current) audioRef.current.srcObject = null
|
||||
|
||||
if (!hasErrorRef.current) updateStatus('disconnected', 'Disconnected')
|
||||
hasErrorRef.current = false
|
||||
setIsConnected(false)
|
||||
}, [updateStatus])
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
stopDiagnostics()
|
||||
if (dcRef.current) dcRef.current.close()
|
||||
if (pcRef.current) pcRef.current.close()
|
||||
if (localStreamRef.current) localStreamRef.current.getTracks().forEach(t => t.stop())
|
||||
}
|
||||
}, [])
|
||||
|
||||
// ── Test tone ──
|
||||
const sendTestTone = useCallback(() => {
|
||||
const dc = dcRef.current
|
||||
if (!dc || dc.readyState !== 'open') return
|
||||
dc.send(JSON.stringify({ type: 'test_tone' }))
|
||||
setTranscript(prev => [...prev, { role: 'assistant', text: '(Test tone requested)' }])
|
||||
}, [])
|
||||
|
||||
// ── Diagnostics ──
|
||||
function startDiagnostics() {
|
||||
const audioEl = audioRef.current
|
||||
if (!audioEl?.srcObject) return
|
||||
|
||||
if (!audioCtxRef.current) {
|
||||
const ctx = new AudioContext()
|
||||
const source = ctx.createMediaStreamSource(audioEl.srcObject)
|
||||
const analyser = ctx.createAnalyser()
|
||||
analyser.fftSize = 8192
|
||||
analyser.smoothingTimeConstant = 0.3
|
||||
source.connect(analyser)
|
||||
audioCtxRef.current = ctx
|
||||
analyserRef.current = analyser
|
||||
setDiagStats(prev => ({ ...prev, sampleRate: ctx.sampleRate + ' Hz' }))
|
||||
}
|
||||
|
||||
mediaRecorderRef.current.stop()
|
||||
mediaRecorderRef.current.stream?.getTracks().forEach(t => t.stop())
|
||||
}, [whisperModel, llmModel, ttsModel, conversationHistory])
|
||||
|
||||
const resetConversation = () => {
|
||||
setConversationHistory([])
|
||||
setAudioUrl(null)
|
||||
setStatus('Conversation reset. Press record to start.')
|
||||
addToast('Conversation reset', 'info')
|
||||
if (!diagFrameRef.current) drawDiagnostics()
|
||||
if (!statsIntervalRef.current) {
|
||||
pollWebRTCStats()
|
||||
statsIntervalRef.current = setInterval(pollWebRTCStats, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
const allModelsSet = llmModel && whisperModel && ttsModel
|
||||
function stopDiagnostics() {
|
||||
if (diagFrameRef.current) { cancelAnimationFrame(diagFrameRef.current); diagFrameRef.current = null }
|
||||
if (statsIntervalRef.current) { clearInterval(statsIntervalRef.current); statsIntervalRef.current = null }
|
||||
if (audioCtxRef.current) { audioCtxRef.current.close(); audioCtxRef.current = null; analyserRef.current = null }
|
||||
}
|
||||
|
||||
function drawDiagnostics() {
|
||||
const analyser = analyserRef.current
|
||||
if (!analyser) { diagFrameRef.current = null; return }
|
||||
|
||||
diagFrameRef.current = requestAnimationFrame(drawDiagnostics)
|
||||
|
||||
// Waveform
|
||||
const waveCanvas = waveCanvasRef.current
|
||||
if (waveCanvas) {
|
||||
const wCtx = waveCanvas.getContext('2d')
|
||||
const timeData = new Float32Array(analyser.fftSize)
|
||||
analyser.getFloatTimeDomainData(timeData)
|
||||
const w = waveCanvas.width, h = waveCanvas.height
|
||||
wCtx.fillStyle = '#000'; wCtx.fillRect(0, 0, w, h)
|
||||
wCtx.strokeStyle = '#0f0'; wCtx.lineWidth = 1; wCtx.beginPath()
|
||||
const sliceWidth = w / timeData.length
|
||||
let x = 0
|
||||
for (let i = 0; i < timeData.length; i++) {
|
||||
const y = (1 - timeData[i]) * h / 2
|
||||
i === 0 ? wCtx.moveTo(x, y) : wCtx.lineTo(x, y)
|
||||
x += sliceWidth
|
||||
}
|
||||
wCtx.stroke()
|
||||
|
||||
let sumSq = 0
|
||||
for (let i = 0; i < timeData.length; i++) sumSq += timeData[i] * timeData[i]
|
||||
const rms = Math.sqrt(sumSq / timeData.length)
|
||||
const rmsDb = rms > 0 ? (20 * Math.log10(rms)).toFixed(1) : '-Inf'
|
||||
setDiagStats(prev => ({ ...prev, rms: rmsDb + ' dBFS' }))
|
||||
}
|
||||
|
||||
// Spectrum
|
||||
const specCanvas = specCanvasRef.current
|
||||
if (specCanvas && audioCtxRef.current) {
|
||||
const sCtx = specCanvas.getContext('2d')
|
||||
const freqData = new Float32Array(analyser.frequencyBinCount)
|
||||
analyser.getFloatFrequencyData(freqData)
|
||||
const sw = specCanvas.width, sh = specCanvas.height
|
||||
sCtx.fillStyle = '#000'; sCtx.fillRect(0, 0, sw, sh)
|
||||
|
||||
const sampleRate = audioCtxRef.current.sampleRate
|
||||
const binHz = sampleRate / analyser.fftSize
|
||||
const maxFreqDisplay = 4000
|
||||
const maxBin = Math.min(Math.ceil(maxFreqDisplay / binHz), freqData.length)
|
||||
const barWidth = sw / maxBin
|
||||
|
||||
sCtx.fillStyle = '#0cf'
|
||||
let peakBin = 0, peakVal = -Infinity
|
||||
for (let i = 0; i < maxBin; i++) {
|
||||
const db = freqData[i]
|
||||
if (db > peakVal) { peakVal = db; peakBin = i }
|
||||
const barH = Math.max(0, ((db + 100) / 100) * sh)
|
||||
sCtx.fillRect(i * barWidth, sh - barH, Math.max(1, barWidth - 0.5), barH)
|
||||
}
|
||||
|
||||
// Frequency labels
|
||||
sCtx.fillStyle = '#888'; sCtx.font = '10px monospace'
|
||||
for (let f = 500; f <= maxFreqDisplay; f += 500) {
|
||||
sCtx.fillText(f + '', (f / binHz) * barWidth - 10, sh - 2)
|
||||
}
|
||||
|
||||
// 440 Hz marker
|
||||
const bin440 = Math.round(440 / binHz)
|
||||
const x440 = bin440 * barWidth
|
||||
sCtx.strokeStyle = '#f00'; sCtx.lineWidth = 1
|
||||
sCtx.beginPath(); sCtx.moveTo(x440, 0); sCtx.lineTo(x440, sh); sCtx.stroke()
|
||||
sCtx.fillStyle = '#f00'; sCtx.fillText('440', x440 + 2, 10)
|
||||
|
||||
const peakFreq = peakBin * binHz
|
||||
const fundamentalBin = Math.round(440 / binHz)
|
||||
const fundamentalPower = Math.pow(10, freqData[fundamentalBin] / 10)
|
||||
let harmonicPower = 0
|
||||
for (let h = 2; h <= 10; h++) {
|
||||
const hBin = Math.round(440 * h / binHz)
|
||||
if (hBin < freqData.length) harmonicPower += Math.pow(10, freqData[hBin] / 10)
|
||||
}
|
||||
const thd = fundamentalPower > 0
|
||||
? (Math.sqrt(harmonicPower / fundamentalPower) * 100).toFixed(1) + '%'
|
||||
: '--%'
|
||||
|
||||
setDiagStats(prev => ({
|
||||
...prev,
|
||||
peakFreq: peakFreq.toFixed(0) + ' Hz (' + peakVal.toFixed(1) + ' dB)',
|
||||
thd,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
async function pollWebRTCStats() {
|
||||
const pc = pcRef.current
|
||||
if (!pc) return
|
||||
try {
|
||||
const stats = await pc.getStats()
|
||||
const raw = []
|
||||
stats.forEach((report) => {
|
||||
if (report.type === 'inbound-rtp' && report.kind === 'audio') {
|
||||
setDiagStats(prev => ({
|
||||
...prev,
|
||||
packetsRecv: report.packetsReceived ?? '--',
|
||||
packetsLost: report.packetsLost ?? '--',
|
||||
jitter: report.jitter !== undefined ? (report.jitter * 1000).toFixed(1) + ' ms' : '--',
|
||||
concealed: report.concealedSamples ?? '--',
|
||||
}))
|
||||
raw.push('-- inbound-rtp (audio) --')
|
||||
raw.push(' packetsReceived: ' + report.packetsReceived)
|
||||
raw.push(' packetsLost: ' + report.packetsLost)
|
||||
raw.push(' jitter: ' + (report.jitter !== undefined ? (report.jitter * 1000).toFixed(2) + ' ms' : 'N/A'))
|
||||
raw.push(' bytesReceived: ' + report.bytesReceived)
|
||||
raw.push(' concealedSamples: ' + report.concealedSamples)
|
||||
raw.push(' totalSamplesReceived: ' + report.totalSamplesReceived)
|
||||
}
|
||||
})
|
||||
setDiagStats(prev => ({ ...prev, raw: raw.join('\n') }))
|
||||
} catch (_e) { /* stats polling error */ }
|
||||
}
|
||||
|
||||
const toggleDiagnostics = useCallback(() => {
|
||||
setDiagVisible(prev => {
|
||||
const next = !prev
|
||||
if (next) {
|
||||
setTimeout(startDiagnostics, 0)
|
||||
} else {
|
||||
stopDiagnostics()
|
||||
}
|
||||
return next
|
||||
})
|
||||
}, [])
|
||||
|
||||
const statusStyle = STATUS_STYLES[status] || STATUS_STYLES.disconnected
|
||||
|
||||
// ── Render ──
|
||||
return (
|
||||
<div className="page" style={{ display: 'flex', flexDirection: 'column', alignItems: 'center' }}>
|
||||
<div style={{ width: '100%', maxWidth: '40rem' }}>
|
||||
<div style={{ width: '100%', maxWidth: '48rem' }}>
|
||||
<div style={{ textAlign: 'center', marginBottom: 'var(--spacing-lg)' }}>
|
||||
<h1 className="page-title">Talk</h1>
|
||||
<p className="page-subtitle">Voice conversation with AI</p>
|
||||
<p className="page-subtitle">Real-time voice conversation via WebRTC</p>
|
||||
</div>
|
||||
|
||||
{/* Main interaction area */}
|
||||
<div className="card" style={{ padding: 'var(--spacing-lg)', textAlign: 'center', marginBottom: 'var(--spacing-md)' }}>
|
||||
{/* Big record button */}
|
||||
<button
|
||||
onClick={isRecording ? stopRecording : startRecording}
|
||||
disabled={loading || !allModelsSet}
|
||||
style={{
|
||||
width: 96, height: 96, borderRadius: '50%', border: 'none', cursor: loading || !allModelsSet ? 'not-allowed' : 'pointer',
|
||||
background: isRecording ? 'var(--color-error)' : 'var(--color-primary)',
|
||||
color: '#fff', fontSize: '2rem', display: 'inline-flex', alignItems: 'center', justifyContent: 'center',
|
||||
boxShadow: isRecording ? '0 0 0 8px rgba(239,68,68,0.2)' : '0 0 0 8px var(--color-primary-light)',
|
||||
transition: 'all 200ms', opacity: loading || !allModelsSet ? 0.5 : 1,
|
||||
margin: '0 auto var(--spacing-md)',
|
||||
}}
|
||||
>
|
||||
<i className={`fas ${isRecording ? 'fa-stop' : 'fa-microphone'}`} />
|
||||
</button>
|
||||
<div className="card" style={{ padding: 'var(--spacing-lg)', marginBottom: 'var(--spacing-md)' }}>
|
||||
{/* Connection status */}
|
||||
<div style={{
|
||||
display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)',
|
||||
padding: 'var(--spacing-sm) var(--spacing-md)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
background: statusStyle.bg,
|
||||
border: '1px solid color-mix(in srgb, ' + statusStyle.color + ' 30%, transparent)',
|
||||
marginBottom: 'var(--spacing-md)',
|
||||
}}>
|
||||
<i className={statusStyle.icon} style={{ color: statusStyle.color }} />
|
||||
<span style={{ fontWeight: 500, color: statusStyle.color }}>{statusText}</span>
|
||||
</div>
|
||||
|
||||
{/* Status */}
|
||||
<p style={{ color: 'var(--color-text-secondary)', fontSize: '0.875rem', marginBottom: 'var(--spacing-md)' }}>
|
||||
{loading ? <LoadingSpinner size="sm" /> : null}
|
||||
{' '}{status}
|
||||
</p>
|
||||
{/* Info note */}
|
||||
<div style={{
|
||||
background: 'var(--color-primary-light)',
|
||||
border: '1px solid color-mix(in srgb, var(--color-primary) 20%, transparent)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
padding: 'var(--spacing-sm) var(--spacing-md)',
|
||||
marginBottom: 'var(--spacing-md)',
|
||||
display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-sm)',
|
||||
}}>
|
||||
<i className="fas fa-info-circle" style={{ color: 'var(--color-primary)', marginTop: 2, flexShrink: 0 }} />
|
||||
<p style={{ color: 'var(--color-text-secondary)', fontSize: '0.8125rem', margin: 0 }}>
|
||||
<strong style={{ color: 'var(--color-primary)' }}>Note:</strong> Select a pipeline model and click Connect.
|
||||
Your microphone streams continuously; the server detects speech and responds automatically.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Recording indicator */}
|
||||
{isRecording && (
|
||||
{/* Pipeline model selector */}
|
||||
<div style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<label className="form-label" style={{ fontSize: '0.8125rem' }}>
|
||||
<i className="fas fa-brain" style={{ color: 'var(--color-primary)', marginRight: 4 }} /> Pipeline Model
|
||||
</label>
|
||||
<select
|
||||
className="model-selector"
|
||||
value={selectedModel}
|
||||
onChange={(e) => {
|
||||
setSelectedModel(e.target.value)
|
||||
const m = pipelineModels.find(p => p.name === e.target.value)
|
||||
if (m && !voiceEdited) setVoice(m.voice || '')
|
||||
}}
|
||||
disabled={modelsLoading || isConnected}
|
||||
style={{ width: '100%' }}
|
||||
>
|
||||
{modelsLoading && <option>Loading models...</option>}
|
||||
{!modelsLoading && pipelineModels.length === 0 && <option>No pipeline models available</option>}
|
||||
{pipelineModels.map(m => (
|
||||
<option key={m.name} value={m.name}>{m.name}</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* Pipeline details */}
|
||||
{selectedModelInfo && (
|
||||
<div style={{
|
||||
background: 'rgba(239, 68, 68, 0.1)', border: '1px solid rgba(239, 68, 68, 0.3)',
|
||||
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-xs) var(--spacing-sm)',
|
||||
display: 'inline-flex', alignItems: 'center', gap: 'var(--spacing-xs)',
|
||||
color: 'var(--color-error)', fontSize: '0.8125rem', marginBottom: 'var(--spacing-md)',
|
||||
display: 'grid', gridTemplateColumns: 'repeat(4, 1fr)', gap: 'var(--spacing-xs)',
|
||||
marginBottom: 'var(--spacing-md)', fontSize: '0.75rem',
|
||||
}}>
|
||||
<i className="fas fa-circle" style={{ fontSize: '0.5rem', animation: 'pulse 1s infinite' }} />
|
||||
Recording...
|
||||
{[
|
||||
{ label: 'VAD', value: selectedModelInfo.vad },
|
||||
{ label: 'Transcription', value: selectedModelInfo.transcription },
|
||||
{ label: 'LLM', value: selectedModelInfo.llm },
|
||||
{ label: 'TTS', value: selectedModelInfo.tts },
|
||||
].map(item => (
|
||||
<div key={item.label} style={{
|
||||
background: 'var(--color-bg-secondary)', borderRadius: 'var(--radius-sm)',
|
||||
padding: 'var(--spacing-xs)', border: '1px solid var(--color-border)',
|
||||
}}>
|
||||
<div style={{ color: 'var(--color-text-secondary)', marginBottom: 2 }}>{item.label}</div>
|
||||
<div style={{ fontFamily: 'monospace', overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>{item.value}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Audio playback */}
|
||||
{audioUrl && (
|
||||
<div style={{ marginTop: 'var(--spacing-sm)' }}>
|
||||
<audio ref={audioRef} controls src={audioUrl} style={{ width: '100%' }} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Model selectors */}
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-md)' }}>
|
||||
<h3 style={{ fontSize: '0.875rem', fontWeight: 600, color: 'var(--color-text-secondary)' }}>
|
||||
<i className="fas fa-sliders-h" style={{ marginRight: 'var(--spacing-xs)' }} /> Models
|
||||
</h3>
|
||||
<button className="btn btn-secondary btn-sm" onClick={resetConversation} style={{ fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-rotate-right" /> Reset
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-sm)' }}>
|
||||
<div className="form-group" style={{ margin: 0 }}>
|
||||
<label className="form-label" style={{ fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-brain" style={{ color: 'var(--color-primary)', marginRight: 4 }} /> LLM
|
||||
</label>
|
||||
<ModelSelector value={llmModel} onChange={setLlmModel} capability="FLAG_CHAT" />
|
||||
</div>
|
||||
<div className="form-group" style={{ margin: 0 }}>
|
||||
<label className="form-label" style={{ fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-ear-listen" style={{ color: 'var(--color-accent)', marginRight: 4 }} /> Speech-to-Text
|
||||
</label>
|
||||
<ModelSelector value={whisperModel} onChange={setWhisperModel} capability="FLAG_TRANSCRIPT" />
|
||||
</div>
|
||||
<div className="form-group" style={{ margin: 0 }}>
|
||||
<label className="form-label" style={{ fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-volume-high" style={{ color: 'var(--color-success)', marginRight: 4 }} /> Text-to-Speech
|
||||
</label>
|
||||
<ModelSelector value={ttsModel} onChange={setTtsModel} capability="FLAG_TTS" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{!allModelsSet && (
|
||||
<div style={{
|
||||
background: 'var(--color-info-light)', border: '1px solid rgba(56, 189, 248, 0.2)',
|
||||
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-xs) var(--spacing-sm)',
|
||||
marginTop: 'var(--spacing-sm)', fontSize: '0.75rem', color: 'var(--color-text-secondary)',
|
||||
{/* Session settings */}
|
||||
<details style={{
|
||||
marginBottom: 'var(--spacing-md)', border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
}}>
|
||||
<summary style={{
|
||||
cursor: 'pointer', padding: 'var(--spacing-sm) var(--spacing-md)',
|
||||
fontWeight: 500, color: 'var(--color-text-secondary)', fontSize: '0.875rem',
|
||||
}}>
|
||||
<i className="fas fa-info-circle" style={{ color: 'var(--color-info)', marginRight: 4 }} />
|
||||
Select all three models to start talking.
|
||||
<i className="fas fa-sliders" style={{ color: 'var(--color-primary)', marginRight: 'var(--spacing-xs)' }} />
|
||||
Session Settings
|
||||
</summary>
|
||||
<div style={{ padding: 'var(--spacing-md)', paddingTop: 'var(--spacing-xs)', display: 'flex', flexDirection: 'column', gap: 'var(--spacing-sm)' }}>
|
||||
<div className="form-group" style={{ margin: 0 }}>
|
||||
<label className="form-label" style={{ fontSize: '0.75rem' }}>Instructions</label>
|
||||
<textarea
|
||||
className="textarea"
|
||||
rows={3}
|
||||
value={instructions}
|
||||
onChange={e => setInstructions(e.target.value)}
|
||||
placeholder="System instructions for the model"
|
||||
style={{ fontSize: '0.8125rem' }}
|
||||
/>
|
||||
</div>
|
||||
<div className="form-group" style={{ margin: 0 }}>
|
||||
<label className="form-label" style={{ fontSize: '0.75rem' }}>Voice</label>
|
||||
<input
|
||||
className="input"
|
||||
value={voice}
|
||||
onChange={e => { setVoice(e.target.value); setVoiceEdited(true) }}
|
||||
placeholder="Voice name (leave blank for model default)"
|
||||
style={{ fontSize: '0.8125rem' }}
|
||||
/>
|
||||
</div>
|
||||
<div className="form-group" style={{ margin: 0 }}>
|
||||
<label className="form-label" style={{ fontSize: '0.75rem' }}>Transcription Language</label>
|
||||
<input
|
||||
className="input"
|
||||
value={language}
|
||||
onChange={e => setLanguage(e.target.value)}
|
||||
placeholder="Language code (e.g. 'en') — leave blank for auto-detect"
|
||||
style={{ fontSize: '0.8125rem' }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
{/* Transcript */}
|
||||
<div style={{
|
||||
marginBottom: 'var(--spacing-md)',
|
||||
maxHeight: '24rem', overflowY: 'auto', minHeight: '6rem',
|
||||
padding: 'var(--spacing-sm)',
|
||||
background: 'var(--color-bg-secondary)',
|
||||
border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
display: 'flex', flexDirection: 'column', gap: 'var(--spacing-xs)',
|
||||
}}>
|
||||
{transcript.length === 0 && (
|
||||
<p style={{ color: 'var(--color-text-secondary)', fontStyle: 'italic', margin: 0 }}>
|
||||
Conversation will appear here...
|
||||
</p>
|
||||
)}
|
||||
{transcript.map((entry, i) => (
|
||||
<div key={i} style={{ display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-xs)' }}>
|
||||
<i className={entry.role === 'user' ? 'fa-solid fa-user' : 'fa-solid fa-robot'}
|
||||
style={{
|
||||
color: entry.role === 'user' ? 'var(--color-primary)' : 'var(--color-accent)',
|
||||
marginTop: 3, flexShrink: 0, fontSize: '0.75rem',
|
||||
}} />
|
||||
<p style={{ margin: 0 }}>{entry.text}</p>
|
||||
</div>
|
||||
))}
|
||||
<div ref={transcriptEndRef} />
|
||||
</div>
|
||||
|
||||
{/* Buttons */}
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)' }}>
|
||||
{!isConnected ? (
|
||||
<button className="btn btn-primary" onClick={connect} disabled={modelsLoading || !selectedModel}>
|
||||
<i className="fas fa-plug" style={{ marginRight: 'var(--spacing-xs)' }} /> Connect
|
||||
</button>
|
||||
) : (
|
||||
<>
|
||||
<button className="btn" onClick={sendTestTone}
|
||||
style={{ background: 'var(--color-accent)', color: '#fff', border: 'none' }}>
|
||||
<i className="fas fa-wave-square" style={{ marginRight: 'var(--spacing-xs)' }} /> Test Tone
|
||||
</button>
|
||||
<button className="btn btn-secondary" onClick={toggleDiagnostics}>
|
||||
<i className="fas fa-chart-line" style={{ marginRight: 'var(--spacing-xs)' }} /> Diag
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
{isConnected && (
|
||||
<button className="btn" onClick={disconnect}
|
||||
style={{ background: 'var(--color-error)', color: '#fff', border: 'none' }}>
|
||||
<i className="fas fa-plug-circle-xmark" style={{ marginRight: 'var(--spacing-xs)' }} /> Disconnect
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Hidden audio element for WebRTC playback */}
|
||||
<audio ref={audioRef} autoPlay style={{ display: 'none' }} />
|
||||
|
||||
{/* Diagnostics panel */}
|
||||
{diagVisible && (
|
||||
<div style={{
|
||||
marginTop: 'var(--spacing-md)',
|
||||
border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
padding: 'var(--spacing-md)',
|
||||
}}>
|
||||
<h3 style={{ fontSize: '0.875rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)' }}>
|
||||
<i className="fas fa-chart-line" style={{ color: 'var(--color-primary)', marginRight: 'var(--spacing-xs)' }} />
|
||||
Audio Diagnostics
|
||||
</h3>
|
||||
|
||||
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 'var(--spacing-sm)', marginBottom: 'var(--spacing-sm)' }}>
|
||||
<div>
|
||||
<p style={{ fontSize: '0.6875rem', color: 'var(--color-text-secondary)', marginBottom: 2 }}>Waveform</p>
|
||||
<canvas ref={waveCanvasRef} width={400} height={120}
|
||||
style={{ width: '100%', border: '1px solid var(--color-border)', borderRadius: 'var(--radius-sm)', background: '#000' }} />
|
||||
</div>
|
||||
<div>
|
||||
<p style={{ fontSize: '0.6875rem', color: 'var(--color-text-secondary)', marginBottom: 2 }}>Spectrum (FFT)</p>
|
||||
<canvas ref={specCanvasRef} width={400} height={120}
|
||||
style={{ width: '100%', border: '1px solid var(--color-border)', borderRadius: 'var(--radius-sm)', background: '#000' }} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(4, 1fr)', gap: 'var(--spacing-xs)', marginBottom: 'var(--spacing-sm)', fontSize: '0.75rem' }}>
|
||||
{[
|
||||
{ label: 'Peak Freq', value: diagStats.peakFreq },
|
||||
{ label: 'THD', value: diagStats.thd },
|
||||
{ label: 'RMS Level', value: diagStats.rms },
|
||||
{ label: 'Sample Rate', value: diagStats.sampleRate },
|
||||
{ label: 'Packets Recv', value: diagStats.packetsRecv },
|
||||
{ label: 'Packets Lost', value: diagStats.packetsLost },
|
||||
{ label: 'Jitter', value: diagStats.jitter },
|
||||
{ label: 'Concealed', value: diagStats.concealed },
|
||||
].map(item => (
|
||||
<div key={item.label} style={{
|
||||
background: 'var(--color-bg-secondary)', borderRadius: 'var(--radius-sm)', padding: 'var(--spacing-xs)',
|
||||
}}>
|
||||
<div style={{ color: 'var(--color-text-secondary)', fontSize: '0.6875rem' }}>{item.label}</div>
|
||||
<div style={{ fontFamily: 'monospace' }}>{item.value}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<pre style={{
|
||||
fontSize: '0.6875rem', color: 'var(--color-text-secondary)',
|
||||
background: 'var(--color-bg-secondary)', borderRadius: 'var(--radius-sm)',
|
||||
padding: 'var(--spacing-xs)', maxHeight: '8rem', overflowY: 'auto',
|
||||
fontFamily: 'monospace', whiteSpace: 'pre-wrap', margin: 0,
|
||||
}}>
|
||||
{diagStats.raw || 'Waiting for stats...'}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,30 +1,294 @@
|
||||
import React, { useState, useEffect, useCallback } from 'react'
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { useOutletContext } from 'react-router-dom'
|
||||
import { tracesApi } from '../utils/api'
|
||||
import { Link } from 'react-router-dom'
|
||||
import { tracesApi, settingsApi } from '../utils/api'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
|
||||
const AUDIO_DATA_KEYS = new Set([
|
||||
'audio_wav_base64', 'audio_duration_s', 'audio_snippet_s',
|
||||
'audio_sample_rate', 'audio_samples', 'audio_rms_dbfs',
|
||||
'audio_peak_dbfs', 'audio_dc_offset',
|
||||
])
|
||||
|
||||
function formatDuration(ns) {
|
||||
if (!ns && ns !== 0) return '-'
|
||||
if (ns < 1000) return `${ns}ns`
|
||||
if (ns < 1_000_000) return `${(ns / 1000).toFixed(1)}µs`
|
||||
if (ns < 1_000_000) return `${(ns / 1000).toFixed(1)}\u00b5s`
|
||||
if (ns < 1_000_000_000) return `${(ns / 1_000_000).toFixed(1)}ms`
|
||||
return `${(ns / 1_000_000_000).toFixed(2)}s`
|
||||
}
|
||||
|
||||
function formatTimestamp(ts) {
|
||||
if (!ts) return '-'
|
||||
const d = new Date(ts)
|
||||
return d.toLocaleTimeString() + '.' + String(d.getMilliseconds()).padStart(3, '0')
|
||||
}
|
||||
|
||||
function decodeTraceBody(body) {
|
||||
if (!body) return ''
|
||||
try {
|
||||
const bin = atob(body)
|
||||
const bytes = new Uint8Array(bin.length)
|
||||
for (let i = 0; i < bin.length; i++) bytes[i] = bin.charCodeAt(i)
|
||||
const text = new TextDecoder().decode(bytes)
|
||||
try { return JSON.stringify(JSON.parse(text), null, 2) } catch { return text }
|
||||
} catch {
|
||||
return body
|
||||
}
|
||||
}
|
||||
|
||||
function formatValue(value) {
|
||||
if (value === null || value === undefined) return 'null'
|
||||
if (typeof value === 'boolean') return value ? 'true' : 'false'
|
||||
if (typeof value === 'object') return JSON.stringify(value)
|
||||
return String(value)
|
||||
}
|
||||
|
||||
function formatLargeValue(value) {
|
||||
if (typeof value === 'string') {
|
||||
try { return JSON.stringify(JSON.parse(value), null, 2) } catch { return value }
|
||||
}
|
||||
if (typeof value === 'object') return JSON.stringify(value, null, 2)
|
||||
return String(value)
|
||||
}
|
||||
|
||||
function isLargeValue(value) {
|
||||
if (typeof value === 'string') return value.length > 120
|
||||
if (typeof value === 'object') return JSON.stringify(value).length > 120
|
||||
return false
|
||||
}
|
||||
|
||||
function truncateValue(value, maxLen) {
|
||||
const str = typeof value === 'object' ? JSON.stringify(value) : String(value)
|
||||
if (str.length <= maxLen) return str
|
||||
return str.substring(0, maxLen) + '...'
|
||||
}
|
||||
|
||||
const TYPE_COLORS = {
|
||||
llm: { bg: 'rgba(59,130,246,0.15)', color: '#60a5fa' },
|
||||
embedding: { bg: 'rgba(168,85,247,0.15)', color: '#c084fc' },
|
||||
transcription: { bg: 'rgba(234,179,8,0.15)', color: '#facc15' },
|
||||
image_generation: { bg: 'rgba(34,197,94,0.15)', color: '#4ade80' },
|
||||
video_generation: { bg: 'rgba(236,72,153,0.15)', color: '#f472b6' },
|
||||
tts: { bg: 'rgba(249,115,22,0.15)', color: '#fb923c' },
|
||||
sound_generation: { bg: 'rgba(20,184,166,0.15)', color: '#2dd4bf' },
|
||||
rerank: { bg: 'rgba(99,102,241,0.15)', color: '#818cf8' },
|
||||
tokenize: { bg: 'rgba(107,114,128,0.15)', color: '#9ca3af' },
|
||||
}
|
||||
|
||||
function typeBadgeStyle(type) {
|
||||
const c = TYPE_COLORS[type] || TYPE_COLORS.tokenize
|
||||
return { background: c.bg, color: c.color, padding: '2px 8px', borderRadius: 'var(--radius-sm)', fontSize: '0.75rem', fontWeight: 500 }
|
||||
}
|
||||
|
||||
// Audio player + metrics for transcription traces
|
||||
function AudioSnippet({ data }) {
|
||||
if (!data?.audio_wav_base64) return null
|
||||
const metrics = [
|
||||
{ label: 'Duration', value: data.audio_duration_s + 's' },
|
||||
{ label: 'Sample Rate', value: data.audio_sample_rate + ' Hz' },
|
||||
{ label: 'RMS Level', value: data.audio_rms_dbfs + ' dBFS' },
|
||||
{ label: 'Peak Level', value: data.audio_peak_dbfs + ' dBFS' },
|
||||
{ label: 'Samples', value: data.audio_samples },
|
||||
{ label: 'Snippet', value: data.audio_snippet_s + 's' },
|
||||
{ label: 'DC Offset', value: data.audio_dc_offset },
|
||||
]
|
||||
return (
|
||||
<div style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<h4 style={{ fontSize: '0.8125rem', fontWeight: 600, marginBottom: 'var(--spacing-xs)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)' }}>
|
||||
<i className="fas fa-headphones" style={{ color: 'var(--color-primary)' }} /> Audio Snippet
|
||||
</h4>
|
||||
<div style={{ background: 'var(--color-bg-primary)', border: '1px solid var(--color-border)', borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm)' }}>
|
||||
<audio controls style={{ width: '100%', marginBottom: 'var(--spacing-sm)' }} src={`data:audio/wav;base64,${data.audio_wav_base64}`} />
|
||||
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(auto-fill, minmax(120px, 1fr))', gap: 'var(--spacing-xs)', fontSize: '0.75rem' }}>
|
||||
{metrics.map(m => (
|
||||
<div key={m.label} style={{ background: 'var(--color-bg-secondary)', borderRadius: 'var(--radius-sm)', padding: 'var(--spacing-xs)' }}>
|
||||
<div style={{ color: 'var(--color-text-secondary)' }}>{m.label}</div>
|
||||
<div style={{ fontFamily: 'monospace' }}>{m.value}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function isPlainObject(value) {
|
||||
return value !== null && typeof value === 'object' && !Array.isArray(value)
|
||||
}
|
||||
|
||||
function fieldSummary(value) {
|
||||
const count = Object.keys(value).length
|
||||
return `{${count} field${count !== 1 ? 's' : ''}}`
|
||||
}
|
||||
|
||||
// Expandable data fields for backend traces (recursive for nested objects)
|
||||
function DataFields({ data, nested }) {
|
||||
const [expandedFields, setExpandedFields] = useState({})
|
||||
const filtered = Object.entries(data).filter(([key]) => !AUDIO_DATA_KEYS.has(key))
|
||||
if (filtered.length === 0) return null
|
||||
|
||||
const toggleField = (key) => {
|
||||
setExpandedFields(prev => ({ ...prev, [key]: !prev[key] }))
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
{!nested && <h4 style={{ fontSize: '0.8125rem', fontWeight: 600, marginBottom: 'var(--spacing-xs)' }}>Data Fields</h4>}
|
||||
<div style={{ border: '1px solid var(--color-border)', borderRadius: 'var(--radius-md)', overflow: 'hidden' }}>
|
||||
{filtered.map(([key, value]) => {
|
||||
const objValue = isPlainObject(value)
|
||||
const large = !objValue && isLargeValue(value)
|
||||
const expandable = objValue || large
|
||||
const expanded = expandedFields[key]
|
||||
return (
|
||||
<div key={key} style={{ borderBottom: '1px solid var(--color-border)' }}>
|
||||
<div
|
||||
onClick={expandable ? () => toggleField(key) : undefined}
|
||||
style={{
|
||||
display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)',
|
||||
padding: 'var(--spacing-xs) var(--spacing-sm)',
|
||||
cursor: expandable ? 'pointer' : 'default',
|
||||
fontSize: '0.8125rem',
|
||||
}}
|
||||
>
|
||||
{expandable ? (
|
||||
<i className={`fas fa-chevron-${expanded ? 'down' : 'right'}`} style={{ fontSize: '0.6rem', color: 'var(--color-text-secondary)', width: 12, flexShrink: 0 }} />
|
||||
) : (
|
||||
<span style={{ width: 12, flexShrink: 0 }} />
|
||||
)}
|
||||
<span style={{ fontFamily: 'monospace', color: 'var(--color-primary)', flexShrink: 0 }}>{key}</span>
|
||||
{objValue && !expanded && <span style={{ fontSize: '0.75rem', color: 'var(--color-text-secondary)' }}>{fieldSummary(value)}</span>}
|
||||
{!objValue && !large && <span style={{ fontFamily: 'monospace', fontSize: '0.75rem', color: 'var(--color-text-secondary)' }}>{formatValue(value)}</span>}
|
||||
{!objValue && large && !expanded && <span style={{ fontSize: '0.75rem', color: 'var(--color-text-secondary)', overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>{truncateValue(value, 120)}</span>}
|
||||
</div>
|
||||
{expanded && objValue && (
|
||||
<div style={{ padding: '0 0 var(--spacing-xs) var(--spacing-md)' }}>
|
||||
<DataFields data={value} nested />
|
||||
</div>
|
||||
)}
|
||||
{expanded && large && (
|
||||
<div style={{ padding: '0 var(--spacing-sm) var(--spacing-sm)' }}>
|
||||
<pre style={{
|
||||
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-sm)', padding: 'var(--spacing-sm)',
|
||||
fontSize: '0.75rem', fontFamily: 'monospace', whiteSpace: 'pre-wrap', wordBreak: 'break-word',
|
||||
overflow: 'auto', maxHeight: '50vh', margin: 0,
|
||||
}}>
|
||||
{formatLargeValue(value)}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Expanded detail for a backend trace row
|
||||
function BackendTraceDetail({ trace }) {
|
||||
const infoItems = [
|
||||
{ label: 'Type', value: trace.type },
|
||||
{ label: 'Model', value: trace.model_name || '-' },
|
||||
{ label: 'Backend', value: trace.backend || '-' },
|
||||
{ label: 'Duration', value: formatDuration(trace.duration) },
|
||||
]
|
||||
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-md)', background: 'var(--color-bg-secondary)', borderBottom: '1px solid var(--color-border)' }}>
|
||||
{/* Summary cards */}
|
||||
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(4, 1fr)', gap: 'var(--spacing-xs)', marginBottom: 'var(--spacing-md)', fontSize: '0.75rem' }}>
|
||||
{infoItems.map(item => (
|
||||
<div key={item.label} style={{ background: 'var(--color-bg-primary)', borderRadius: 'var(--radius-sm)', padding: 'var(--spacing-xs)', border: '1px solid var(--color-border)' }}>
|
||||
<div style={{ color: 'var(--color-text-secondary)' }}>{item.label}</div>
|
||||
<div style={{ fontWeight: 500 }}>{item.label === 'Type' ? <span style={typeBadgeStyle(item.value)}>{item.value}</span> : item.value}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Error banner */}
|
||||
{trace.error && (
|
||||
<div style={{
|
||||
background: 'rgba(239,68,68,0.1)', border: '1px solid rgba(239,68,68,0.3)',
|
||||
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm)', marginBottom: 'var(--spacing-md)',
|
||||
display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)',
|
||||
}}>
|
||||
<i className="fas fa-exclamation-triangle" style={{ color: 'var(--color-error)' }} />
|
||||
<span style={{ color: 'var(--color-error)', fontSize: '0.8125rem' }}>{trace.error}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Audio snippet */}
|
||||
{trace.data && <AudioSnippet data={trace.data} />}
|
||||
|
||||
{/* Data fields */}
|
||||
{trace.data && Object.keys(trace.data).length > 0 && <DataFields data={trace.data} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Expanded detail for an API trace row
|
||||
function ApiTraceDetail({ trace }) {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-md)', background: 'var(--color-bg-secondary)', borderBottom: '1px solid var(--color-border)' }}>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 'var(--spacing-md)' }}>
|
||||
<div>
|
||||
<h4 style={{ fontSize: '0.8125rem', fontWeight: 600, marginBottom: 'var(--spacing-xs)' }}>Request Body</h4>
|
||||
<pre style={{
|
||||
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-sm)', padding: 'var(--spacing-sm)',
|
||||
fontSize: '0.75rem', fontFamily: 'monospace', whiteSpace: 'pre-wrap', wordBreak: 'break-word',
|
||||
overflow: 'auto', maxHeight: '50vh', margin: 0,
|
||||
}}>
|
||||
{decodeTraceBody(trace.request?.body)}
|
||||
</pre>
|
||||
</div>
|
||||
<div>
|
||||
<h4 style={{ fontSize: '0.8125rem', fontWeight: 600, marginBottom: 'var(--spacing-xs)' }}>Response Body</h4>
|
||||
<pre style={{
|
||||
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border)',
|
||||
borderRadius: 'var(--radius-sm)', padding: 'var(--spacing-sm)',
|
||||
fontSize: '0.75rem', fontFamily: 'monospace', whiteSpace: 'pre-wrap', wordBreak: 'break-word',
|
||||
overflow: 'auto', maxHeight: '50vh', margin: 0,
|
||||
}}>
|
||||
{decodeTraceBody(trace.response?.body)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default function Traces() {
|
||||
const { addToast } = useOutletContext()
|
||||
const [activeTab, setActiveTab] = useState('api')
|
||||
const [traces, setTraces] = useState([])
|
||||
const [apiCount, setApiCount] = useState(0)
|
||||
const [backendCount, setBackendCount] = useState(0)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [expandedRow, setExpandedRow] = useState(null)
|
||||
const [tracingEnabled, setTracingEnabled] = useState(null)
|
||||
const refreshRef = useRef(null)
|
||||
|
||||
useEffect(() => {
|
||||
settingsApi.get()
|
||||
.then(data => setTracingEnabled(!!data.enable_tracing))
|
||||
.catch(() => {})
|
||||
}, [])
|
||||
|
||||
const fetchTraces = useCallback(async () => {
|
||||
try {
|
||||
setLoading(true)
|
||||
const data = activeTab === 'api'
|
||||
? await tracesApi.get()
|
||||
: await tracesApi.getBackend()
|
||||
setTraces(Array.isArray(data) ? data : [])
|
||||
const [apiData, backendData] = await Promise.all([
|
||||
tracesApi.get(),
|
||||
tracesApi.getBackend(),
|
||||
])
|
||||
const api = Array.isArray(apiData) ? apiData : []
|
||||
const backend = Array.isArray(backendData) ? backendData : []
|
||||
setApiCount(api.length)
|
||||
setBackendCount(backend.length)
|
||||
setTraces(activeTab === 'api' ? api : backend)
|
||||
} catch (err) {
|
||||
addToast(`Failed to load traces: ${err.message}`, 'error')
|
||||
} finally {
|
||||
@@ -33,14 +297,23 @@ export default function Traces() {
|
||||
}, [activeTab, addToast])
|
||||
|
||||
useEffect(() => {
|
||||
setLoading(true)
|
||||
setExpandedRow(null)
|
||||
fetchTraces()
|
||||
}, [fetchTraces])
|
||||
|
||||
// Auto-refresh every 5 seconds
|
||||
useEffect(() => {
|
||||
refreshRef.current = setInterval(fetchTraces, 5000)
|
||||
return () => clearInterval(refreshRef.current)
|
||||
}, [fetchTraces])
|
||||
|
||||
const handleClear = async () => {
|
||||
try {
|
||||
if (activeTab === 'api') await tracesApi.clear()
|
||||
else await tracesApi.clearBackend()
|
||||
setTraces([])
|
||||
setExpandedRow(null)
|
||||
addToast('Traces cleared', 'success')
|
||||
} catch (err) {
|
||||
addToast(`Failed to clear: ${err.message}`, 'error')
|
||||
@@ -61,12 +334,20 @@ export default function Traces() {
|
||||
<div className="page">
|
||||
<div className="page-header">
|
||||
<h1 className="page-title">Traces</h1>
|
||||
<p className="page-subtitle">Debug API and backend traces</p>
|
||||
<p className="page-subtitle">View logged API requests, responses, and backend operations</p>
|
||||
</div>
|
||||
|
||||
<div className="tabs">
|
||||
<button className={`tab ${activeTab === 'api' ? 'tab-active' : ''}`} onClick={() => setActiveTab('api')}>API Traces</button>
|
||||
<button className={`tab ${activeTab === 'backend' ? 'tab-active' : ''}`} onClick={() => setActiveTab('backend')}>Backend Traces</button>
|
||||
<button className={`tab ${activeTab === 'api' ? 'tab-active' : ''}`} onClick={() => setActiveTab('api')}>
|
||||
<i className="fas fa-exchange-alt" style={{ marginRight: 'var(--spacing-xs)', fontSize: '0.75rem' }} />
|
||||
API Traces
|
||||
<span style={{ marginLeft: 'var(--spacing-xs)', opacity: 0.6, fontSize: '0.75rem' }}>({apiCount})</span>
|
||||
</button>
|
||||
<button className={`tab ${activeTab === 'backend' ? 'tab-active' : ''}`} onClick={() => setActiveTab('backend')}>
|
||||
<i className="fas fa-cogs" style={{ marginRight: 'var(--spacing-xs)', fontSize: '0.75rem' }} />
|
||||
Backend Traces
|
||||
<span style={{ marginLeft: 'var(--spacing-xs)', opacity: 0.6, fontSize: '0.75rem' }}>({backendCount})</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)', marginBottom: 'var(--spacing-md)' }}>
|
||||
@@ -75,6 +356,33 @@ export default function Traces() {
|
||||
<button className="btn btn-secondary btn-sm" onClick={handleExport} disabled={traces.length === 0}><i className="fas fa-download" /> Export</button>
|
||||
</div>
|
||||
|
||||
{tracingEnabled === false && (
|
||||
<div style={{
|
||||
background: 'rgba(234,179,8,0.1)', border: '1px solid rgba(234,179,8,0.3)',
|
||||
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm) var(--spacing-md)',
|
||||
marginBottom: 'var(--spacing-md)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)',
|
||||
}}>
|
||||
<i className="fas fa-exclamation-triangle" style={{ color: '#facc15', flexShrink: 0 }} />
|
||||
<span style={{ fontSize: '0.8125rem' }}>
|
||||
Tracing is currently <strong>disabled</strong>. New requests will not be recorded.{' '}
|
||||
<Link to="/settings" style={{ color: 'var(--color-primary)' }}>Enable in Settings</Link>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{tracingEnabled === true && (
|
||||
<div style={{
|
||||
background: 'rgba(34,197,94,0.08)', border: '1px solid rgba(34,197,94,0.25)',
|
||||
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm) var(--spacing-md)',
|
||||
marginBottom: 'var(--spacing-md)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)',
|
||||
}}>
|
||||
<i className="fas fa-circle-check" style={{ color: 'var(--color-success)', flexShrink: 0 }} />
|
||||
<span style={{ fontSize: '0.8125rem' }}>
|
||||
Tracing is <strong>enabled</strong>. Requests are being recorded.{' '}
|
||||
<Link to="/settings" style={{ color: 'var(--color-primary)' }}>Manage in Settings</Link>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{loading ? (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', padding: 'var(--spacing-xl)' }}><LoadingSpinner size="lg" /></div>
|
||||
) : traces.length === 0 ? (
|
||||
@@ -89,7 +397,6 @@ export default function Traces() {
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: '30px' }}></th>
|
||||
<th>Time</th>
|
||||
<th>Method</th>
|
||||
<th>Path</th>
|
||||
<th>Status</th>
|
||||
@@ -100,17 +407,14 @@ export default function Traces() {
|
||||
<React.Fragment key={i}>
|
||||
<tr onClick={() => setExpandedRow(expandedRow === i ? null : i)} style={{ cursor: 'pointer' }}>
|
||||
<td><i className={`fas fa-chevron-${expandedRow === i ? 'down' : 'right'}`} style={{ fontSize: '0.7rem' }} /></td>
|
||||
<td>{trace.timestamp ? new Date(trace.timestamp).toLocaleTimeString() : '-'}</td>
|
||||
<td><span className="badge badge-info">{trace.request?.method || '-'}</span></td>
|
||||
<td style={{ fontFamily: 'JetBrains Mono, monospace', fontSize: '0.8125rem' }}>{trace.request?.path || '-'}</td>
|
||||
<td><span className={`badge ${(trace.response?.status || 0) < 400 ? 'badge-success' : 'badge-error'}`}>{trace.response?.status || '-'}</span></td>
|
||||
</tr>
|
||||
{expandedRow === i && (
|
||||
<tr>
|
||||
<td colSpan="5">
|
||||
<pre style={{ background: 'var(--color-bg-primary)', padding: 'var(--spacing-sm)', borderRadius: 'var(--radius-md)', fontSize: '0.75rem', overflow: 'auto', maxHeight: '300px' }}>
|
||||
{JSON.stringify(trace, null, 2)}
|
||||
</pre>
|
||||
<td colSpan="4" style={{ padding: 0 }}>
|
||||
<ApiTraceDetail trace={trace} />
|
||||
</td>
|
||||
</tr>
|
||||
)}
|
||||
@@ -125,12 +429,12 @@ export default function Traces() {
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: '30px' }}></th>
|
||||
<th>Time</th>
|
||||
<th>Type</th>
|
||||
<th>Time</th>
|
||||
<th>Model</th>
|
||||
<th>Backend</th>
|
||||
<th>Duration</th>
|
||||
<th>Summary</th>
|
||||
<th>Duration</th>
|
||||
<th style={{ width: '40px' }}>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
@@ -138,21 +442,23 @@ export default function Traces() {
|
||||
<React.Fragment key={i}>
|
||||
<tr onClick={() => setExpandedRow(expandedRow === i ? null : i)} style={{ cursor: 'pointer' }}>
|
||||
<td><i className={`fas fa-chevron-${expandedRow === i ? 'down' : 'right'}`} style={{ fontSize: '0.7rem' }} /></td>
|
||||
<td>{trace.timestamp ? new Date(trace.timestamp).toLocaleTimeString() : '-'}</td>
|
||||
<td><span className="badge badge-info">{trace.type || '-'}</span></td>
|
||||
<td><span style={typeBadgeStyle(trace.type)}>{trace.type || '-'}</span></td>
|
||||
<td style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)' }}>{formatTimestamp(trace.timestamp)}</td>
|
||||
<td style={{ fontFamily: 'JetBrains Mono, monospace', fontSize: '0.8125rem' }}>{trace.model_name || '-'}</td>
|
||||
<td>{trace.backend || '-'}</td>
|
||||
<td>{formatDuration(trace.duration)}</td>
|
||||
<td style={{ maxWidth: '300px', overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>
|
||||
{trace.error ? <span style={{ color: 'var(--color-error)' }}>{trace.error}</span> : (trace.summary || '-')}
|
||||
{trace.summary || '-'}
|
||||
</td>
|
||||
<td style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)' }}>{formatDuration(trace.duration)}</td>
|
||||
<td style={{ textAlign: 'center' }}>
|
||||
{trace.error
|
||||
? <i className="fas fa-times-circle" style={{ color: 'var(--color-error)' }} title={trace.error} />
|
||||
: <i className="fas fa-check-circle" style={{ color: 'var(--color-success)' }} />}
|
||||
</td>
|
||||
</tr>
|
||||
{expandedRow === i && (
|
||||
<tr>
|
||||
<td colSpan="7">
|
||||
<pre style={{ background: 'var(--color-bg-primary)', padding: 'var(--spacing-sm)', borderRadius: 'var(--radius-md)', fontSize: '0.75rem', overflow: 'auto', maxHeight: '300px' }}>
|
||||
{JSON.stringify(trace, null, 2)}
|
||||
</pre>
|
||||
<td colSpan="7" style={{ padding: 0 }}>
|
||||
<BackendTraceDetail trace={trace} />
|
||||
</td>
|
||||
</tr>
|
||||
)}
|
||||
|
||||
6
core/http/react-ui/src/utils/api.js
vendored
6
core/http/react-ui/src/utils/api.js
vendored
@@ -238,6 +238,12 @@ export const audioApi = {
|
||||
},
|
||||
}
|
||||
|
||||
// Realtime / WebRTC
|
||||
export const realtimeApi = {
|
||||
call: (body) => postJSON(API_CONFIG.endpoints.realtimeCalls, body),
|
||||
pipelineModels: () => fetchJSON(API_CONFIG.endpoints.pipelineModels),
|
||||
}
|
||||
|
||||
// Backend control
|
||||
export const backendControlApi = {
|
||||
shutdown: (body) => postJSON(API_CONFIG.endpoints.backendShutdown, body),
|
||||
|
||||
4
core/http/react-ui/src/utils/config.js
vendored
4
core/http/react-ui/src/utils/config.js
vendored
@@ -65,6 +65,10 @@ export const API_CONFIG = {
|
||||
modelsList: '/v1/models',
|
||||
modelsCapabilities: '/api/models/capabilities',
|
||||
|
||||
// Realtime / WebRTC
|
||||
realtimeCalls: '/v1/realtime/calls',
|
||||
pipelineModels: '/api/pipeline-models',
|
||||
|
||||
// LocalAI-specific
|
||||
tts: '/tts',
|
||||
video: '/video',
|
||||
|
||||
@@ -21,6 +21,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
app.GET("/v1/realtime", openai.Realtime(application))
|
||||
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application), traceMiddleware)
|
||||
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application), traceMiddleware)
|
||||
app.POST("/v1/realtime/calls", openai.RealtimeCalls(application), traceMiddleware)
|
||||
|
||||
// chat
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -18,6 +21,41 @@ func RegisterUIRoutes(app *echo.Echo,
|
||||
// SPA routes are handled by the 404 fallback in app.go which serves
|
||||
// index.html for any unmatched HTML request, enabling client-side routing.
|
||||
|
||||
// Pipeline models API (for the Talk page WebRTC interface)
|
||||
app.GET("/api/pipeline-models", func(c echo.Context) error {
|
||||
type pipelineModelInfo struct {
|
||||
Name string `json:"name"`
|
||||
VAD string `json:"vad"`
|
||||
Transcription string `json:"transcription"`
|
||||
LLM string `json:"llm"`
|
||||
TTS string `json:"tts"`
|
||||
Voice string `json:"voice"`
|
||||
}
|
||||
|
||||
pipelineModels := cl.GetModelConfigsByFilter(func(_ string, cfg *config.ModelConfig) bool {
|
||||
p := cfg.Pipeline
|
||||
return p.VAD != "" && p.Transcription != "" && p.LLM != "" && p.TTS != ""
|
||||
})
|
||||
|
||||
slices.SortFunc(pipelineModels, func(a, b config.ModelConfig) int {
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var models []pipelineModelInfo
|
||||
for _, cfg := range pipelineModels {
|
||||
models = append(models, pipelineModelInfo{
|
||||
Name: cfg.Name,
|
||||
VAD: cfg.Pipeline.VAD,
|
||||
Transcription: cfg.Pipeline.Transcription,
|
||||
LLM: cfg.Pipeline.LLM,
|
||||
TTS: cfg.Pipeline.TTS,
|
||||
Voice: cfg.TTSConfig.Voice,
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(200, models)
|
||||
})
|
||||
|
||||
app.GET("/api/traces", func(c echo.Context) error {
|
||||
return c.JSON(200, middleware.GetTraces())
|
||||
})
|
||||
|
||||
@@ -1,159 +1,606 @@
|
||||
|
||||
const recordButton = document.getElementById('recordButton');
|
||||
const connectButton = document.getElementById('connectButton');
|
||||
const disconnectButton = document.getElementById('disconnectButton');
|
||||
const testToneButton = document.getElementById('testToneButton');
|
||||
const diagnosticsButton = document.getElementById('diagnosticsButton');
|
||||
const audioPlayback = document.getElementById('audioPlayback');
|
||||
const resetButton = document.getElementById('resetButton');
|
||||
const transcript = document.getElementById('transcript');
|
||||
const statusIcon = document.getElementById('statusIcon');
|
||||
const statusLabel = document.getElementById('statusLabel');
|
||||
const connectionStatus = document.getElementById('connectionStatus');
|
||||
const modelSelect = document.getElementById('modelSelect');
|
||||
|
||||
let mediaRecorder;
|
||||
let audioChunks = [];
|
||||
let isRecording = false;
|
||||
let conversationHistory = [];
|
||||
let resetTimer;
|
||||
let pc = null;
|
||||
let dc = null;
|
||||
let localStream = null;
|
||||
let hasError = false;
|
||||
|
||||
// Audio diagnostics state
|
||||
let audioCtx = null;
|
||||
let analyser = null;
|
||||
let diagAnimFrame = null;
|
||||
let statsInterval = null;
|
||||
let diagVisible = false;
|
||||
|
||||
connectButton.addEventListener('click', connect);
|
||||
disconnectButton.addEventListener('click', disconnect);
|
||||
testToneButton.addEventListener('click', sendTestTone);
|
||||
diagnosticsButton.addEventListener('click', toggleDiagnostics);
|
||||
|
||||
// Show pipeline details when a model is selected
|
||||
modelSelect.addEventListener('change', function() {
|
||||
const opt = this.options[this.selectedIndex];
|
||||
const details = document.getElementById('pipelineDetails');
|
||||
if (!opt || !opt.value) {
|
||||
details.classList.add('hidden');
|
||||
return;
|
||||
}
|
||||
document.getElementById('pipelineVAD').textContent = opt.dataset.vad || '--';
|
||||
document.getElementById('pipelineSTT').textContent = opt.dataset.stt || '--';
|
||||
document.getElementById('pipelineLLM').textContent = opt.dataset.llm || '--';
|
||||
document.getElementById('pipelineTTS').textContent = opt.dataset.tts || '--';
|
||||
details.classList.remove('hidden');
|
||||
|
||||
// Pre-fill voice from model default if the user hasn't typed anything
|
||||
const voiceInput = document.getElementById('voiceInput');
|
||||
if (!voiceInput.dataset.userEdited) {
|
||||
voiceInput.value = opt.dataset.voice || '';
|
||||
}
|
||||
});
|
||||
|
||||
// Track if user manually edited the voice field
|
||||
document.getElementById('voiceInput').addEventListener('input', function() {
|
||||
this.dataset.userEdited = 'true';
|
||||
});
|
||||
|
||||
// Auto-select first model on page load
|
||||
if (modelSelect.options.length > 1) {
|
||||
modelSelect.selectedIndex = 1;
|
||||
modelSelect.dispatchEvent(new Event('change'));
|
||||
}
|
||||
|
||||
function getModel() {
|
||||
return document.getElementById('modelSelect').value;
|
||||
return modelSelect.value;
|
||||
}
|
||||
|
||||
function getSTTModel() {
|
||||
return document.getElementById('sttModelSelect').value;
|
||||
function setStatus(state, text) {
|
||||
statusLabel.textContent = text || state;
|
||||
statusIcon.className = 'fa-solid fa-circle';
|
||||
connectionStatus.className = 'rounded-lg p-4 mb-4 flex items-center space-x-3';
|
||||
|
||||
switch (state) {
|
||||
case 'disconnected':
|
||||
statusIcon.classList.add('text-[var(--color-text-secondary)]');
|
||||
connectionStatus.classList.add('bg-[var(--color-bg-primary)]/50', 'border', 'border-[var(--color-border-subtle)]');
|
||||
statusLabel.classList.add('text-[var(--color-text-secondary)]');
|
||||
break;
|
||||
case 'connecting':
|
||||
statusIcon.className = 'fa-solid fa-spinner fa-spin text-[var(--color-primary)]';
|
||||
connectionStatus.classList.add('bg-[var(--color-primary-light)]', 'border', 'border-[var(--color-primary)]/30');
|
||||
statusLabel.className = 'font-medium text-[var(--color-primary)]';
|
||||
break;
|
||||
case 'connected':
|
||||
statusIcon.classList.add('text-[var(--color-success)]');
|
||||
connectionStatus.classList.add('bg-[var(--color-success)]/10', 'border', 'border-[var(--color-success)]/30');
|
||||
statusLabel.className = 'font-medium text-[var(--color-success)]';
|
||||
break;
|
||||
case 'listening':
|
||||
statusIcon.className = 'fa-solid fa-microphone text-[var(--color-success)]';
|
||||
connectionStatus.classList.add('bg-[var(--color-success)]/10', 'border', 'border-[var(--color-success)]/30');
|
||||
statusLabel.className = 'font-medium text-[var(--color-success)]';
|
||||
break;
|
||||
case 'thinking':
|
||||
statusIcon.className = 'fa-solid fa-brain fa-beat text-[var(--color-primary)]';
|
||||
connectionStatus.classList.add('bg-[var(--color-primary-light)]', 'border', 'border-[var(--color-primary)]/30');
|
||||
statusLabel.className = 'font-medium text-[var(--color-primary)]';
|
||||
break;
|
||||
case 'speaking':
|
||||
statusIcon.className = 'fa-solid fa-volume-high fa-beat-fade text-[var(--color-accent)]';
|
||||
connectionStatus.classList.add('bg-[var(--color-accent)]/10', 'border', 'border-[var(--color-accent)]/30');
|
||||
statusLabel.className = 'font-medium text-[var(--color-accent)]';
|
||||
break;
|
||||
case 'error':
|
||||
statusIcon.classList.add('text-[var(--color-error)]');
|
||||
connectionStatus.classList.add('bg-[var(--color-error-light)]', 'border', 'border-[var(--color-error)]/30');
|
||||
statusLabel.className = 'font-medium text-[var(--color-error)]';
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
function getTTSModel() {
|
||||
return document.getElementById('ttsModelSelect').value;
|
||||
// Currently streaming assistant message element (for incremental updates)
|
||||
let streamingEntry = null;
|
||||
|
||||
function addTranscript(role, text) {
|
||||
// Remove the placeholder if present
|
||||
const placeholder = transcript.querySelector('.italic');
|
||||
if (placeholder) placeholder.remove();
|
||||
|
||||
const entry = document.createElement('div');
|
||||
entry.className = 'flex items-start space-x-2';
|
||||
|
||||
const icon = document.createElement('i');
|
||||
const msg = document.createElement('p');
|
||||
msg.className = 'text-[var(--color-text-primary)]';
|
||||
msg.textContent = text;
|
||||
|
||||
if (role === 'user') {
|
||||
icon.className = 'fa-solid fa-user text-[var(--color-primary)] mt-1 flex-shrink-0';
|
||||
} else {
|
||||
icon.className = 'fa-solid fa-robot text-[var(--color-accent)] mt-1 flex-shrink-0';
|
||||
}
|
||||
|
||||
entry.appendChild(icon);
|
||||
entry.appendChild(msg);
|
||||
transcript.appendChild(entry);
|
||||
transcript.scrollTop = transcript.scrollHeight;
|
||||
return entry;
|
||||
}
|
||||
|
||||
function resetConversation() {
|
||||
conversationHistory = [];
|
||||
console.log("Conversation has been reset.");
|
||||
clearTimeout(resetTimer);
|
||||
function updateStreamingTranscript(role, delta) {
|
||||
if (!streamingEntry) {
|
||||
streamingEntry = addTranscript(role, delta);
|
||||
} else {
|
||||
const msg = streamingEntry.querySelector('p');
|
||||
if (msg) msg.textContent += delta;
|
||||
transcript.scrollTop = transcript.scrollHeight;
|
||||
}
|
||||
}
|
||||
|
||||
function setResetTimer() {
|
||||
clearTimeout(resetTimer);
|
||||
resetTimer = setTimeout(resetConversation, 300000); // Reset after 5 minutes
|
||||
function finalizeStreamingTranscript(role, fullText) {
|
||||
if (streamingEntry) {
|
||||
const msg = streamingEntry.querySelector('p');
|
||||
if (msg) msg.textContent = fullText;
|
||||
streamingEntry = null;
|
||||
} else {
|
||||
addTranscript(role, fullText);
|
||||
}
|
||||
transcript.scrollTop = transcript.scrollHeight;
|
||||
}
|
||||
|
||||
recordButton.addEventListener('click', toggleRecording);
|
||||
resetButton.addEventListener('click', resetConversation);
|
||||
// Send a session.update event with the user's settings
|
||||
function sendSessionUpdate() {
|
||||
if (!dc || dc.readyState !== 'open') return;
|
||||
|
||||
function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
startRecording();
|
||||
} else {
|
||||
stopRecording();
|
||||
const instructions = document.getElementById('instructionsInput').value.trim();
|
||||
const voice = document.getElementById('voiceInput').value.trim();
|
||||
const language = document.getElementById('languageInput').value.trim();
|
||||
|
||||
// Only send if the user configured something
|
||||
if (!instructions && !voice && !language) return;
|
||||
|
||||
const session = {};
|
||||
|
||||
if (instructions) {
|
||||
session.instructions = instructions;
|
||||
}
|
||||
|
||||
if (voice || language) {
|
||||
session.audio = {};
|
||||
if (voice) {
|
||||
session.audio.output = { voice: voice };
|
||||
}
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
document.getElementById("recording").style.display = "block";
|
||||
document.getElementById("resetButton").style.display = "none";
|
||||
if (!navigator.mediaDevices) {
|
||||
alert('MediaDevices API not supported!');
|
||||
return;
|
||||
if (language) {
|
||||
session.audio.input = {
|
||||
transcription: { language: language }
|
||||
};
|
||||
}
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
mediaRecorder = new MediaRecorder(stream);
|
||||
audioChunks = [];
|
||||
mediaRecorder.ondataavailable = (event) => {
|
||||
audioChunks.push(event.data);
|
||||
}
|
||||
|
||||
const event = {
|
||||
type: 'session.update',
|
||||
session: session,
|
||||
};
|
||||
|
||||
console.log('[session.update]', event);
|
||||
dc.send(JSON.stringify(event));
|
||||
}
|
||||
|
||||
function handleServerEvent(event) {
|
||||
console.log('[event]', event.type, event);
|
||||
|
||||
switch (event.type) {
|
||||
case 'session.created':
|
||||
// Session is ready — send any user settings
|
||||
sendSessionUpdate();
|
||||
setStatus('listening', 'Listening...');
|
||||
break;
|
||||
|
||||
case 'session.updated':
|
||||
console.log('[session.updated] Session settings applied', event.session);
|
||||
break;
|
||||
|
||||
case 'input_audio_buffer.speech_started':
|
||||
setStatus('listening', 'Hearing you speak...');
|
||||
break;
|
||||
|
||||
case 'input_audio_buffer.speech_stopped':
|
||||
setStatus('thinking', 'Processing...');
|
||||
break;
|
||||
|
||||
case 'conversation.item.input_audio_transcription.completed':
|
||||
if (event.transcript) {
|
||||
addTranscript('user', event.transcript);
|
||||
}
|
||||
setStatus('thinking', 'Generating response...');
|
||||
break;
|
||||
|
||||
case 'response.output_audio_transcript.delta':
|
||||
// Incremental transcript — update the in-progress assistant message
|
||||
if (event.delta) {
|
||||
updateStreamingTranscript('assistant', event.delta);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'response.output_audio_transcript.done':
|
||||
if (event.transcript) {
|
||||
finalizeStreamingTranscript('assistant', event.transcript);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'response.output_audio.delta':
|
||||
setStatus('speaking', 'Speaking...');
|
||||
break;
|
||||
|
||||
case 'response.done':
|
||||
setStatus('listening', 'Listening...');
|
||||
break;
|
||||
|
||||
case 'error':
|
||||
console.error('Server error:', event.error);
|
||||
hasError = true;
|
||||
setStatus('error', 'Error: ' + (event.error?.message || 'Unknown error'));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
async function connect() {
|
||||
const model = getModel();
|
||||
if (!model) {
|
||||
alert('Please select a pipeline model first.');
|
||||
return;
|
||||
}
|
||||
|
||||
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
|
||||
setStatus('error', 'Microphone access requires HTTPS or localhost.');
|
||||
return;
|
||||
}
|
||||
|
||||
setStatus('connecting', 'Connecting...');
|
||||
connectButton.style.display = 'none';
|
||||
disconnectButton.style.display = '';
|
||||
testToneButton.style.display = '';
|
||||
diagnosticsButton.style.display = '';
|
||||
|
||||
try {
|
||||
// Get microphone access
|
||||
localStream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
|
||||
// Create peer connection
|
||||
pc = new RTCPeerConnection({});
|
||||
|
||||
// Add local audio track
|
||||
for (const track of localStream.getAudioTracks()) {
|
||||
pc.addTrack(track, localStream);
|
||||
}
|
||||
|
||||
// Handle remote audio track (server's TTS output)
|
||||
pc.ontrack = (event) => {
|
||||
audioPlayback.srcObject = event.streams[0];
|
||||
// If diagnostics panel is open, start analyzing the new stream
|
||||
if (diagVisible) startDiagnostics();
|
||||
};
|
||||
mediaRecorder.start();
|
||||
recordButton.textContent = 'Stop Recording';
|
||||
// add class bg-red-500 to recordButton
|
||||
recordButton.classList.add("bg-gray-500");
|
||||
|
||||
isRecording = true;
|
||||
}
|
||||
|
||||
function stopRecording() {
|
||||
mediaRecorder.stop();
|
||||
mediaRecorder.onstop = async () => {
|
||||
document.getElementById("recording").style.display = "none";
|
||||
document.getElementById("recordButton").style.display = "none";
|
||||
|
||||
document.getElementById("loader").style.display = "block";
|
||||
const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
|
||||
document.getElementById("statustext").textContent = "Processing audio...";
|
||||
const transcript = await sendAudioToSTT(audioBlob);
|
||||
console.log("Transcript:", transcript);
|
||||
document.getElementById("statustext").textContent = "Seems you said: " + transcript+ ". Generating response...";
|
||||
const responseText = await sendTextToChatGPT(transcript);
|
||||
|
||||
console.log("Response:", responseText);
|
||||
document.getElementById("statustext").textContent = "Response generated: '" + responseText + "'. Generating audio response...";
|
||||
|
||||
const ttsAudio = await getTextToSpeechAudio(responseText);
|
||||
playAudioResponse(ttsAudio);
|
||||
|
||||
recordButton.textContent = 'Record';
|
||||
// remove class bg-red-500 from recordButton
|
||||
recordButton.classList.remove("bg-gray-500");
|
||||
isRecording = false;
|
||||
document.getElementById("loader").style.display = "none";
|
||||
document.getElementById("recordButton").style.display = "block";
|
||||
document.getElementById("resetButton").style.display = "block";
|
||||
document.getElementById("statustext").textContent = "Press the record button to start recording.";
|
||||
// Create the events data channel (client must create it so m=application
|
||||
// is included in the SDP offer — the answerer cannot add new m-lines)
|
||||
dc = pc.createDataChannel('oai-events');
|
||||
dc.onmessage = (msg) => {
|
||||
try {
|
||||
const text = typeof msg.data === 'string'
|
||||
? msg.data
|
||||
: new TextDecoder().decode(msg.data);
|
||||
const event = JSON.parse(text);
|
||||
handleServerEvent(event);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse server event:', e);
|
||||
}
|
||||
};
|
||||
dc.onclose = () => {
|
||||
console.log('Data channel closed');
|
||||
};
|
||||
}
|
||||
|
||||
async function sendAudioToSTT(audioBlob) {
|
||||
const formData = new FormData();
|
||||
formData.append('file', audioBlob);
|
||||
formData.append('model', getSTTModel());
|
||||
pc.onconnectionstatechange = () => {
|
||||
console.log('Connection state:', pc.connectionState);
|
||||
if (pc.connectionState === 'connected') {
|
||||
setStatus('connected', 'Connected, waiting for session...');
|
||||
} else if (pc.connectionState === 'failed' || pc.connectionState === 'closed') {
|
||||
disconnect();
|
||||
}
|
||||
};
|
||||
|
||||
const response = await fetch('v1/audio/transcriptions', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
// Create offer
|
||||
const offer = await pc.createOffer();
|
||||
await pc.setLocalDescription(offer);
|
||||
|
||||
// Wait for ICE gathering
|
||||
await new Promise((resolve) => {
|
||||
if (pc.iceGatheringState === 'complete') {
|
||||
resolve();
|
||||
} else {
|
||||
pc.onicegatheringstatechange = () => {
|
||||
if (pc.iceGatheringState === 'complete') resolve();
|
||||
};
|
||||
// Timeout after 5s
|
||||
setTimeout(resolve, 5000);
|
||||
}
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
console.log("STT result:", result)
|
||||
return result.text;
|
||||
}
|
||||
|
||||
async function sendTextToChatGPT(text) {
|
||||
conversationHistory.push({ role: "user", content: text });
|
||||
|
||||
const response = await fetch('v1/chat/completions', {
|
||||
method: 'POST',
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
model: getModel(),
|
||||
messages: conversationHistory
|
||||
})
|
||||
// Send offer to server
|
||||
const response = await fetch('v1/realtime/calls', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
sdp: pc.localDescription.sdp,
|
||||
model: model,
|
||||
}),
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
const responseText = result.choices[0].message.content;
|
||||
conversationHistory.push({ role: "assistant", content: responseText });
|
||||
if (!response.ok) {
|
||||
const err = await response.json().catch(() => ({ error: 'Unknown error' }));
|
||||
throw new Error(err.error || `HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
setResetTimer();
|
||||
const data = await response.json();
|
||||
|
||||
return responseText;
|
||||
}
|
||||
|
||||
async function getTextToSpeechAudio(text) {
|
||||
const response = await fetch('v1/audio/speech', {
|
||||
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
// "backend": "string",
|
||||
input: text,
|
||||
model: getTTSModel(),
|
||||
// "voice": "string"
|
||||
})
|
||||
// Set remote description (server's answer)
|
||||
await pc.setRemoteDescription({
|
||||
type: 'answer',
|
||||
sdp: data.sdp,
|
||||
});
|
||||
|
||||
const audioBlob = await response.blob();
|
||||
return audioBlob; // Return the blob directly
|
||||
console.log('WebRTC connection established, session:', data.session_id);
|
||||
} catch (err) {
|
||||
console.error('Connection failed:', err);
|
||||
hasError = true;
|
||||
setStatus('error', 'Connection failed: ' + err.message);
|
||||
disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
function playAudioResponse(audioBlob) {
|
||||
const audioUrl = URL.createObjectURL(audioBlob);
|
||||
audioPlayback.src = audioUrl;
|
||||
audioPlayback.hidden = false;
|
||||
audioPlayback.play();
|
||||
function sendTestTone() {
|
||||
if (!dc || dc.readyState !== 'open') {
|
||||
console.warn('Data channel not open');
|
||||
return;
|
||||
}
|
||||
console.log('[test-tone] Requesting server test tone...');
|
||||
dc.send(JSON.stringify({ type: 'test_tone' }));
|
||||
addTranscript('assistant', '(Test tone requested — you should hear a 440 Hz beep)');
|
||||
}
|
||||
|
||||
function disconnect() {
|
||||
stopDiagnostics();
|
||||
if (dc) {
|
||||
dc.close();
|
||||
dc = null;
|
||||
}
|
||||
if (pc) {
|
||||
pc.close();
|
||||
pc = null;
|
||||
}
|
||||
if (localStream) {
|
||||
localStream.getTracks().forEach(t => t.stop());
|
||||
localStream = null;
|
||||
}
|
||||
audioPlayback.srcObject = null;
|
||||
|
||||
if (!hasError) {
|
||||
setStatus('disconnected', 'Disconnected');
|
||||
}
|
||||
hasError = false;
|
||||
connectButton.style.display = '';
|
||||
disconnectButton.style.display = 'none';
|
||||
testToneButton.style.display = 'none';
|
||||
diagnosticsButton.style.display = 'none';
|
||||
}
|
||||
|
||||
// ── Audio Diagnostics ──
|
||||
|
||||
function toggleDiagnostics() {
|
||||
const panel = document.getElementById('diagnosticsPanel');
|
||||
diagVisible = !diagVisible;
|
||||
panel.style.display = diagVisible ? '' : 'none';
|
||||
if (diagVisible) {
|
||||
startDiagnostics();
|
||||
} else {
|
||||
stopDiagnostics();
|
||||
}
|
||||
}
|
||||
|
||||
// Begin audio diagnostics on the remote stream: lazily wire an
// AnalyserNode into a new AudioContext, kick off the canvas render
// loop, and poll WebRTC inbound stats once per second.
function startDiagnostics() {
    if (!audioPlayback.srcObject) return;

    // Build the AudioContext + analyser chain only once; subsequent
    // calls reuse it.
    if (!audioCtx) {
        audioCtx = new AudioContext();
        analyser = audioCtx.createAnalyser();
        analyser.fftSize = 8192;
        analyser.smoothingTimeConstant = 0.3;
        audioCtx.createMediaStreamSource(audioPlayback.srcObject).connect(analyser);

        document.getElementById('statSampleRate').textContent = audioCtx.sampleRate + ' Hz';
    }

    // Render loop (no-op if one is already scheduled).
    if (!diagAnimFrame) {
        drawDiagnostics();
    }

    // Stats poller: fire once immediately, then every second.
    if (!statsInterval) {
        pollWebRTCStats();
        statsInterval = setInterval(pollWebRTCStats, 1000);
    }
}
|
||||
|
||||
// Halt all diagnostics activity: the stats poller, the canvas render
// loop, and the AudioContext/analyser pair.
function stopDiagnostics() {
    if (statsInterval) {
        clearInterval(statsInterval);
        statsInterval = null;
    }
    if (diagAnimFrame) {
        cancelAnimationFrame(diagAnimFrame);
        diagAnimFrame = null;
    }
    if (audioCtx) {
        // Closing the context releases the media-stream source as well.
        audioCtx.close();
        audioCtx = null;
        analyser = null;
    }
}
|
||||
|
||||
// One frame of the diagnostics display: renders the time-domain
// waveform and FFT spectrum of the remote audio onto their canvases,
// and updates the RMS, peak-frequency, and THD stat fields. Reschedules
// itself via requestAnimationFrame while the panel is visible.
function drawDiagnostics() {
    // Stop the loop when diagnostics were turned off or torn down.
    if (!analyser || !diagVisible) {
        diagAnimFrame = null;
        return;
    }

    diagAnimFrame = requestAnimationFrame(drawDiagnostics);

    // ── Waveform ──
    const waveCanvas = document.getElementById('waveformCanvas');
    const wCtx = waveCanvas.getContext('2d');
    const timeData = new Float32Array(analyser.fftSize);
    analyser.getFloatTimeDomainData(timeData);

    const w = waveCanvas.width;
    const h = waveCanvas.height;
    wCtx.fillStyle = '#000';
    wCtx.fillRect(0, 0, w, h);
    wCtx.strokeStyle = '#0f0';
    wCtx.lineWidth = 1;
    wCtx.beginPath();
    const sliceWidth = w / timeData.length;
    let x = 0;
    for (let i = 0; i < timeData.length; i++) {
        // Samples are in [-1, 1]; map to canvas y with +1 at the top.
        const y = (1 - timeData[i]) * h / 2;
        if (i === 0) wCtx.moveTo(x, y);
        else wCtx.lineTo(x, y);
        x += sliceWidth;
    }
    wCtx.stroke();

    // Compute RMS
    let sumSq = 0;
    for (let i = 0; i < timeData.length; i++) sumSq += timeData[i] * timeData[i];
    const rms = Math.sqrt(sumSq / timeData.length);
    // Full scale is 1.0, so dBFS is just 20*log10(rms); silence shows '-Inf'.
    const rmsDb = rms > 0 ? (20 * Math.log10(rms)).toFixed(1) : '-Inf';
    document.getElementById('statRMS').textContent = rmsDb + ' dBFS';

    // ── FFT Spectrum ──
    const specCanvas = document.getElementById('spectrumCanvas');
    const sCtx = specCanvas.getContext('2d');
    const freqData = new Float32Array(analyser.frequencyBinCount);
    analyser.getFloatFrequencyData(freqData);

    const sw = specCanvas.width;
    const sh = specCanvas.height;
    sCtx.fillStyle = '#000';
    sCtx.fillRect(0, 0, sw, sh);

    // Draw spectrum (0 to 4kHz range for speech/tone analysis)
    const sampleRate = audioCtx.sampleRate;
    // Width of one FFT bin in Hz.
    const binHz = sampleRate / analyser.fftSize;
    const maxFreqDisplay = 4000;
    const maxBin = Math.min(Math.ceil(maxFreqDisplay / binHz), freqData.length);
    const barWidth = sw / maxBin;

    sCtx.fillStyle = '#0cf';
    // Track the loudest bin while drawing, for the "Peak Freq" stat.
    let peakBin = 0;
    let peakVal = -Infinity;
    for (let i = 0; i < maxBin; i++) {
        const db = freqData[i];
        if (db > peakVal) {
            peakVal = db;
            peakBin = i;
        }
        // Map dB (-100 to 0) to pixel height
        const barH = Math.max(0, ((db + 100) / 100) * sh);
        sCtx.fillRect(i * barWidth, sh - barH, Math.max(1, barWidth - 0.5), barH);
    }

    // Draw frequency labels
    sCtx.fillStyle = '#888';
    sCtx.font = '10px monospace';
    for (let f = 500; f <= maxFreqDisplay; f += 500) {
        const xPos = (f / binHz) * barWidth;
        sCtx.fillText(f + '', xPos - 10, sh - 2);
    }

    // Mark 440 Hz — the frequency of the server test tone.
    const bin440 = Math.round(440 / binHz);
    const x440 = bin440 * barWidth;
    sCtx.strokeStyle = '#f00';
    sCtx.lineWidth = 1;
    sCtx.beginPath();
    sCtx.moveTo(x440, 0);
    sCtx.lineTo(x440, sh);
    sCtx.stroke();
    sCtx.fillStyle = '#f00';
    sCtx.fillText('440', x440 + 2, 10);

    const peakFreq = peakBin * binHz;
    document.getElementById('statPeakFreq').textContent =
        peakFreq.toFixed(0) + ' Hz (' + peakVal.toFixed(1) + ' dB)';

    // Compute THD (Total Harmonic Distortion) relative to 440 Hz
    // THD = sqrt(sum of harmonic powers / fundamental power)
    const fundamentalBin = Math.round(440 / binHz);
    // freqData values are in dB; convert to linear power before summing.
    const fundamentalPower = Math.pow(10, freqData[fundamentalBin] / 10);
    let harmonicPower = 0;
    // Harmonics 2..10 of 440 Hz (880, 1320, ... up to 4400 Hz).
    for (let h = 2; h <= 10; h++) {
        const hBin = Math.round(440 * h / binHz);
        if (hBin < freqData.length) {
            harmonicPower += Math.pow(10, freqData[hBin] / 10);
        }
    }
    const thd = fundamentalPower > 0
        ? (Math.sqrt(harmonicPower / fundamentalPower) * 100).toFixed(1)
        : '--';
    document.getElementById('statTHD').textContent = thd + '%';
}
|
||||
|
||||
// Fetch one round of WebRTC receiver statistics and render the audio
// inbound-rtp report into the diagnostics stat fields and the raw
// stats panel.
async function pollWebRTCStats() {
    if (!pc) return;
    try {
        const stats = await pc.getStats();
        const lines = [];
        stats.forEach((report) => {
            // Only the inbound audio RTP report is of interest here.
            if (report.type !== 'inbound-rtp' || report.kind !== 'audio') return;

            document.getElementById('statPacketsRecv').textContent = report.packetsReceived ?? '--';
            document.getElementById('statPacketsLost').textContent = report.packetsLost ?? '--';
            document.getElementById('statJitter').textContent =
                report.jitter !== undefined ? (report.jitter * 1000).toFixed(1) + ' ms' : '--';
            document.getElementById('statConcealed').textContent = report.concealedSamples ?? '--';

            lines.push('── inbound-rtp (audio) ──');
            lines.push('  packetsReceived: ' + report.packetsReceived);
            lines.push('  packetsLost: ' + report.packetsLost);
            lines.push('  jitter: ' + (report.jitter !== undefined ? (report.jitter * 1000).toFixed(2) + ' ms' : 'N/A'));
            lines.push('  bytesReceived: ' + report.bytesReceived);
            lines.push('  concealedSamples: ' + report.concealedSamples);
            lines.push('  silentConcealedSamples: ' + report.silentConcealedSamples);
            lines.push('  totalSamplesReceived: ' + report.totalSamplesReceived);
            lines.push('  insertedSamplesForDecel: ' + report.insertedSamplesForDeceleration);
            lines.push('  removedSamplesForAccel: ' + report.removedSamplesForAcceleration);
            lines.push('  jitterBufferDelay: ' + (report.jitterBufferDelay !== undefined ? report.jitterBufferDelay.toFixed(3) + ' s' : 'N/A'));
            lines.push('  jitterBufferTargetDelay: ' + (report.jitterBufferTargetDelay !== undefined ? report.jitterBufferTargetDelay.toFixed(3) + ' s' : 'N/A'));
            lines.push('  jitterBufferEmittedCount: ' + report.jitterBufferEmittedCount);
        });
        document.getElementById('statsRaw').textContent = lines.join('\n');
    } catch (e) {
        console.warn('Stats polling error:', e);
    }
}
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
<html lang="en">
|
||||
{{template "views/partials/head" .}}
|
||||
<script defer src="static/talk.js"></script>
|
||||
<body class="bg-[var(--color-bg-primary)] text-[var(--color-text-primary)]" x-data="{ key: $store.chat.key }">
|
||||
<body class="bg-[var(--color-bg-primary)] text-[var(--color-text-primary)]">
|
||||
<div class="app-layout">
|
||||
{{template "views/partials/navbar" .}}
|
||||
|
||||
|
||||
<main class="main-content">
|
||||
<div class="main-content-inner">
|
||||
|
||||
@@ -16,107 +16,206 @@
|
||||
<h1 class="hero-title">
|
||||
<i class="fas fa-comments mr-2"></i>Talk Interface
|
||||
</h1>
|
||||
<p class="hero-subtitle">Speak with your AI models using voice interaction</p>
|
||||
<p class="hero-subtitle">Real-time voice conversation with your AI models via WebRTC</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Talk Interface -->
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<div class="card overflow-hidden">
|
||||
<!-- Talk Interface Body -->
|
||||
<div class="p-6">
|
||||
<!-- Recording Status -->
|
||||
<div id="recording" class="bg-[var(--color-error-light)] border border-[var(--color-error)]/30 rounded-lg p-4 mb-4 flex items-center space-x-3" style="display: none;">
|
||||
<i class="fa-solid fa-microphone text-2xl text-[var(--color-error)]"></i>
|
||||
<span class="text-[var(--color-error)] font-medium">Recording... press "Stop recording" to stop</span>
|
||||
<!-- Connection Status -->
|
||||
<div id="connectionStatus" class="rounded-lg p-4 mb-4 flex items-center space-x-3 bg-[var(--color-bg-primary)]/50 border border-[var(--color-border-subtle)]">
|
||||
<i id="statusIcon" class="fa-solid fa-circle text-[var(--color-text-secondary)]"></i>
|
||||
<span id="statusLabel" class="font-medium text-[var(--color-text-secondary)]">Disconnected</span>
|
||||
</div>
|
||||
|
||||
<!-- Loader -->
|
||||
<div id="loader" class="my-4 flex justify-center" style="display: none;">
|
||||
<div class="animate-spin rounded-full h-10 w-10 border-t-2 border-b-2 border-[var(--color-primary)]"></div>
|
||||
</div>
|
||||
|
||||
<!-- Status Text -->
|
||||
<div id="statustext" class="my-4 p-3 bg-[var(--color-bg-primary)]/50 border border-[var(--color-border-subtle)] rounded-lg text-[var(--color-text-primary)]" style="min-height: 3rem;">Press the record button to start recording.</div>
|
||||
|
||||
|
||||
<!-- Note -->
|
||||
<div class="bg-[var(--color-primary-light)] border border-[var(--color-primary)]/20 rounded-lg p-4 mb-6">
|
||||
<div class="flex items-start">
|
||||
<i class="fas fa-info-circle text-[var(--color-primary)] mt-1 mr-3 flex-shrink-0"></i>
|
||||
<p class="text-[var(--color-text-secondary)]">
|
||||
<strong class="text-[var(--color-primary)]">Note:</strong> You need an LLM, an audio-transcription (whisper), and a TTS model installed for this to work. Select the appropriate models below and click 'Talk' to start recording. The recording will continue until you click 'Stop recording'. Make sure your microphone is set up and enabled.
|
||||
<strong class="text-[var(--color-primary)]">Note:</strong> Select a pipeline model below and click 'Connect' to start a real-time voice conversation. The pipeline model includes VAD, transcription, LLM, and TTS components. Your microphone audio streams continuously; the server detects speech and responds automatically.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Selectors -->
|
||||
<div class="grid grid-cols-1 md:grid-cols-3 gap-6 mb-6">
|
||||
<!-- LLM Model -->
|
||||
<div class="space-y-2">
|
||||
<label for="modelSelect" class="flex items-center text-[var(--color-text-secondary)] font-medium">
|
||||
<i class="fas fa-brain text-[var(--color-primary)] mr-2"></i>LLM Model
|
||||
</label>
|
||||
<select id="modelSelect"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-primary)] focus:ring-2 focus:ring-[var(--color-primary)]/50 rounded-lg shadow-sm p-2.5 appearance-none">
|
||||
<option value="" disabled class="text-[var(--color-text-secondary)]">Select a model</option>
|
||||
{{ range .ModelsConfig }}
|
||||
<option value="{{.}}" class="bg-[var(--color-bg-primary)] text-[var(--color-text-primary)]">{{.}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- STT Model -->
|
||||
<div class="space-y-2">
|
||||
<label for="sttModelSelect" class="flex items-center text-[var(--color-text-secondary)] font-medium">
|
||||
<i class="fas fa-ear-listen text-[var(--color-accent)] mr-2"></i>STT Model
|
||||
</label>
|
||||
<select id="sttModelSelect"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-accent)] focus:ring-2 focus:ring-[var(--color-accent)]/50 rounded-lg shadow-sm p-2.5 appearance-none">
|
||||
<option value="" disabled class="text-[var(--color-text-secondary)]">Select a model</option>
|
||||
{{ range .ModelsConfig }}
|
||||
<option value="{{.}}" class="bg-[var(--color-bg-primary)] text-[var(--color-text-primary)]">{{.}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- TTS Model -->
|
||||
<div class="space-y-2">
|
||||
<label for="ttsModelSelect" class="flex items-center text-[var(--color-text-secondary)] font-medium">
|
||||
<i class="fas fa-volume-high text-[var(--color-success)] mr-2"></i>TTS Model
|
||||
</label>
|
||||
<select id="ttsModelSelect"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-success)] focus:ring-2 focus:ring-[var(--color-success)]/50 rounded-lg shadow-sm p-2.5 appearance-none">
|
||||
<option value="" disabled class="text-[var(--color-text-secondary)]">Select a model</option>
|
||||
{{ range .ModelsConfig }}
|
||||
<option value="{{.}}" class="bg-[var(--color-bg-primary)] text-[var(--color-text-primary)]">{{.}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
|
||||
<!-- Model Selector -->
|
||||
<div class="mb-4 space-y-2">
|
||||
<label for="modelSelect" class="flex items-center text-[var(--color-text-secondary)] font-medium">
|
||||
<i class="fas fa-brain text-[var(--color-primary)] mr-2"></i>Pipeline Model
|
||||
</label>
|
||||
<select id="modelSelect"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-primary)] focus:ring-2 focus:ring-[var(--color-primary)]/50 rounded-lg shadow-sm p-2.5 appearance-none">
|
||||
<option value="" disabled class="text-[var(--color-text-secondary)]">Select a pipeline model</option>
|
||||
{{ range .PipelineModels }}
|
||||
<option value="{{.Name}}"
|
||||
data-vad="{{.VAD}}"
|
||||
data-stt="{{.Transcription}}"
|
||||
data-llm="{{.LLM}}"
|
||||
data-tts="{{.TTS}}"
|
||||
data-voice="{{.Voice}}"
|
||||
class="bg-[var(--color-bg-primary)] text-[var(--color-text-primary)]">{{.Name}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- Pipeline Details (shown when a model is selected) -->
|
||||
<div id="pipelineDetails" class="mb-6 hidden">
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-2 text-xs">
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2 border border-[var(--color-border-subtle)]">
|
||||
<p class="text-[var(--color-text-secondary)] mb-0.5">VAD</p>
|
||||
<p id="pipelineVAD" class="font-mono text-[var(--color-text-primary)] truncate"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2 border border-[var(--color-border-subtle)]">
|
||||
<p class="text-[var(--color-text-secondary)] mb-0.5">Transcription</p>
|
||||
<p id="pipelineSTT" class="font-mono text-[var(--color-text-primary)] truncate"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2 border border-[var(--color-border-subtle)]">
|
||||
<p class="text-[var(--color-text-secondary)] mb-0.5">LLM</p>
|
||||
<p id="pipelineLLM" class="font-mono text-[var(--color-text-primary)] truncate"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2 border border-[var(--color-border-subtle)]">
|
||||
<p class="text-[var(--color-text-secondary)] mb-0.5">TTS</p>
|
||||
<p id="pipelineTTS" class="font-mono text-[var(--color-text-primary)] truncate"></p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Session Settings (collapsible) -->
|
||||
<details class="mb-6 border border-[var(--color-border-subtle)] rounded-lg">
|
||||
<summary class="cursor-pointer p-3 flex items-center text-[var(--color-text-secondary)] font-medium hover:bg-[var(--color-bg-primary)]/50 rounded-lg">
|
||||
<i class="fas fa-sliders text-[var(--color-primary)] mr-2"></i>Session Settings
|
||||
</summary>
|
||||
<div class="p-4 pt-2 space-y-4">
|
||||
<!-- Instructions -->
|
||||
<div class="space-y-1">
|
||||
<label for="instructionsInput" class="text-sm text-[var(--color-text-secondary)]">Instructions</label>
|
||||
<textarea id="instructionsInput" rows="3"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-primary)] focus:ring-2 focus:ring-[var(--color-primary)]/50 rounded-lg shadow-sm p-2.5 text-sm"
|
||||
placeholder="System instructions for the model (e.g. 'be extremely succinct', 'talk quickly')">You are a helpful voice assistant. Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. Speak naturally as you would in a phone conversation. Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized.</textarea>
|
||||
</div>
|
||||
|
||||
<!-- Voice -->
|
||||
<div class="space-y-1">
|
||||
<label for="voiceInput" class="text-sm text-[var(--color-text-secondary)]">Voice</label>
|
||||
<input id="voiceInput" type="text"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-primary)] focus:ring-2 focus:ring-[var(--color-primary)]/50 rounded-lg shadow-sm p-2.5 text-sm"
|
||||
placeholder="Voice name (leave blank for model default)">
|
||||
</div>
|
||||
|
||||
<!-- Language -->
|
||||
<div class="space-y-1">
|
||||
<label for="languageInput" class="text-sm text-[var(--color-text-secondary)]">Transcription Language</label>
|
||||
<input id="languageInput" type="text"
|
||||
class="w-full bg-[var(--color-bg-primary)] text-[var(--color-text-primary)] border border-[var(--color-border-subtle)] focus:border-[var(--color-primary)] focus:ring-2 focus:ring-[var(--color-primary)]/50 rounded-lg shadow-sm p-2.5 text-sm"
|
||||
placeholder="Language code (e.g. 'en', 'es') — leave blank for auto-detect">
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Conversation Transcript -->
|
||||
<div id="transcript" class="mb-6 space-y-3 max-h-96 overflow-y-auto p-3 bg-[var(--color-bg-primary)]/50 border border-[var(--color-border-subtle)] rounded-lg" style="min-height: 6rem;">
|
||||
<p class="text-[var(--color-text-secondary)] italic">Conversation will appear here...</p>
|
||||
</div>
|
||||
|
||||
<!-- Buttons -->
|
||||
<div class="flex items-center justify-between mt-8">
|
||||
<button id="recordButton"
|
||||
class="inline-flex items-center bg-[var(--color-error)] hover:bg-[var(--color-error)]/90 text-white font-semibold py-2 px-6 rounded-lg transition-colors">
|
||||
<i class="fas fa-microphone mr-2"></i>
|
||||
<span>Talk</span>
|
||||
<div class="flex items-center space-x-3">
|
||||
<button id="connectButton"
|
||||
class="inline-flex items-center bg-[var(--color-success)] hover:bg-[var(--color-success)]/90 text-white font-semibold py-2 px-6 rounded-lg transition-colors">
|
||||
<i class="fas fa-plug mr-2"></i>
|
||||
<span>Connect</span>
|
||||
</button>
|
||||
|
||||
<button id="testToneButton"
|
||||
class="inline-flex items-center bg-[var(--color-accent)] hover:bg-[var(--color-accent)]/90 text-white font-semibold py-2 px-6 rounded-lg transition-colors"
|
||||
style="display: none;">
|
||||
<i class="fas fa-wave-square mr-2"></i>
|
||||
<span>Test Tone</span>
|
||||
</button>
|
||||
|
||||
<button id="diagnosticsButton"
|
||||
class="inline-flex items-center bg-[var(--color-bg-primary)] hover:bg-[var(--color-bg-primary)]/80 text-[var(--color-text-secondary)] font-semibold py-2 px-4 rounded-lg transition-colors border border-[var(--color-border-subtle)]"
|
||||
style="display: none;">
|
||||
<i class="fas fa-chart-line mr-2"></i>
|
||||
<span>Diag</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<button id="disconnectButton"
|
||||
class="inline-flex items-center bg-[var(--color-error)] hover:bg-[var(--color-error)]/90 text-white font-semibold py-2 px-6 rounded-lg transition-colors"
|
||||
style="display: none;">
|
||||
<i class="fas fa-plug-circle-xmark mr-2"></i>
|
||||
<span>Disconnect</span>
|
||||
</button>
|
||||
|
||||
<a id="resetButton"
|
||||
class="flex items-center text-[var(--color-primary)] hover:text-[var(--color-accent)] transition-colors"
|
||||
href="#">
|
||||
<i class="fas fa-rotate-left mr-2"></i>
|
||||
<span>Reset conversation</span>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<!-- Audio Playback -->
|
||||
<audio id="audioPlayback" controls hidden></audio>
|
||||
|
||||
<!-- Audio element for WebRTC playback -->
|
||||
<audio id="audioPlayback" autoplay style="display:none;"></audio>
|
||||
|
||||
<!-- Audio Diagnostics (toggled by button) -->
|
||||
<div id="diagnosticsPanel" style="display: none;" class="mt-6 border border-[var(--color-border-subtle)] rounded-lg p-4">
|
||||
<h3 class="font-semibold text-[var(--color-text-primary)] mb-3">
|
||||
<i class="fas fa-chart-line text-[var(--color-primary)] mr-2"></i>Audio Diagnostics
|
||||
</h3>
|
||||
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-4 mb-4">
|
||||
<div>
|
||||
<p class="text-xs text-[var(--color-text-secondary)] mb-1">Waveform (time domain)</p>
|
||||
<canvas id="waveformCanvas" width="400" height="120" class="w-full border border-[var(--color-border-subtle)] rounded bg-black"></canvas>
|
||||
</div>
|
||||
<div>
|
||||
<p class="text-xs text-[var(--color-text-secondary)] mb-1">Spectrum (FFT)</p>
|
||||
<canvas id="spectrumCanvas" width="400" height="120" class="w-full border border-[var(--color-border-subtle)] rounded bg-black"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-3 mb-3">
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">Peak Freq</p>
|
||||
<p id="statPeakFreq" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">THD</p>
|
||||
<p id="statTHD" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">RMS Level</p>
|
||||
<p id="statRMS" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">Sample Rate</p>
|
||||
<p id="statSampleRate" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-3 mb-3">
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">Packets Recv</p>
|
||||
<p id="statPacketsRecv" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">Packets Lost</p>
|
||||
<p id="statPacketsLost" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">Jitter</p>
|
||||
<p id="statJitter" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-primary)]/50 rounded p-2">
|
||||
<p class="text-xs text-[var(--color-text-secondary)]">Concealed</p>
|
||||
<p id="statConcealed" class="font-mono text-sm text-[var(--color-text-primary)]">--</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<pre id="statsRaw" class="text-xs text-[var(--color-text-secondary)] bg-[var(--color-bg-primary)]/50 rounded p-2 max-h-32 overflow-y-auto font-mono" style="white-space: pre-wrap;"></pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
{{template "views/partials/footer" .}}
|
||||
</div>
|
||||
</main>
|
||||
|
||||
@@ -254,12 +254,54 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Audio Player & Metrics (transcription traces) -->
|
||||
<template x-if="trace.data && trace.data.audio_wav_base64">
|
||||
<div class="mb-4">
|
||||
<h4 class="text-sm font-semibold text-[var(--color-text-primary)] mb-2">
|
||||
<i class="fas fa-headphones text-[var(--color-primary)] mr-1.5"></i>Audio Snippet
|
||||
</h4>
|
||||
<div class="bg-[var(--color-bg-primary)] border border-[var(--color-border-subtle)] rounded-lg p-3">
|
||||
<audio controls class="w-full mb-3" :src="'data:audio/wav;base64,' + trace.data.audio_wav_base64"></audio>
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-2 text-xs">
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">Duration</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_duration_s + 's'"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">Sample Rate</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_sample_rate + ' Hz'"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">RMS Level</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_rms_dbfs + ' dBFS'"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">Peak Level</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_peak_dbfs + ' dBFS'"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">Samples</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_samples"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">Snippet</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_snippet_s + 's'"></p>
|
||||
</div>
|
||||
<div class="bg-[var(--color-bg-secondary)]/50 rounded p-2">
|
||||
<p class="text-[var(--color-text-secondary)]">DC Offset</p>
|
||||
<p class="font-mono text-[var(--color-text-primary)]" x-text="trace.data.audio_dc_offset"></p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Data fields as nested accordions -->
|
||||
<template x-if="trace.data && Object.keys(trace.data).length > 0">
|
||||
<div>
|
||||
<h4 class="text-sm font-semibold text-[var(--color-text-primary)] mb-2">Data Fields</h4>
|
||||
<div class="border border-[var(--color-border-subtle)] rounded-lg overflow-hidden">
|
||||
<template x-for="[key, value] in Object.entries(trace.data)" :key="key">
|
||||
<template x-for="[key, value] in filterDataFields(trace.data)" :key="key">
|
||||
<div class="border-b border-[var(--color-border-subtle)] last:border-b-0">
|
||||
<!-- Field header row -->
|
||||
<div @click="isLargeValue(value) && toggleBackendField(index, key)"
|
||||
@@ -552,6 +594,15 @@ function tracesApp() {
|
||||
if (typeof value === 'boolean') return value ? 'true' : 'false';
|
||||
if (typeof value === 'object') return JSON.stringify(value);
|
||||
return String(value);
|
||||
},
|
||||
|
||||
filterDataFields(data) {
|
||||
const audioKeys = new Set([
|
||||
'audio_wav_base64', 'audio_duration_s', 'audio_snippet_s',
|
||||
'audio_sample_rate', 'audio_samples', 'audio_rms_dbfs',
|
||||
'audio_peak_dbfs', 'audio_dc_offset'
|
||||
]);
|
||||
return Object.entries(data).filter(([key]) => !audioKeys.has(key));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
102
core/trace/audio_snippet.go
Normal file
102
core/trace/audio_snippet.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/audio"
|
||||
"github.com/mudler/LocalAI/pkg/sound"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// MaxSnippetSeconds is the maximum number of seconds of audio captured per trace.
|
||||
const MaxSnippetSeconds = 30
|
||||
|
||||
// AudioSnippet captures the first MaxSnippetSeconds of a WAV file and computes
|
||||
// quality metrics. The result is a map suitable for merging into a BackendTrace
|
||||
// Data field.
|
||||
func AudioSnippet(wavPath string) map[string]any {
|
||||
raw, err := os.ReadFile(wavPath)
|
||||
if err != nil {
|
||||
xlog.Warn("audio snippet: read failed", "path", wavPath, "error", err)
|
||||
return nil
|
||||
}
|
||||
// Only process WAV files (RIFF header)
|
||||
if len(raw) <= audio.WAVHeaderSize || string(raw[:4]) != "RIFF" {
|
||||
xlog.Debug("audio snippet: not a WAV file or too small", "path", wavPath, "bytes", len(raw))
|
||||
return nil
|
||||
}
|
||||
|
||||
pcm, sampleRate := audio.ParseWAV(raw)
|
||||
if sampleRate == 0 {
|
||||
sampleRate = 16000
|
||||
}
|
||||
|
||||
return AudioSnippetFromPCM(pcm, sampleRate, len(pcm))
|
||||
}
|
||||
|
||||
// AudioSnippetFromPCM builds an audio snippet from raw PCM bytes (int16 LE mono).
|
||||
// totalPCMBytes is the full audio size before truncation (used to compute total duration).
|
||||
func AudioSnippetFromPCM(pcm []byte, sampleRate int, totalPCMBytes int) map[string]any {
|
||||
if len(pcm) == 0 || len(pcm)%2 != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
samples := sound.BytesToInt16sLE(pcm)
|
||||
totalSamples := totalPCMBytes / 2
|
||||
durationS := float64(totalSamples) / float64(sampleRate)
|
||||
|
||||
// Truncate to first MaxSnippetSeconds
|
||||
maxSamples := MaxSnippetSeconds * sampleRate
|
||||
if len(samples) > maxSamples {
|
||||
samples = samples[:maxSamples]
|
||||
}
|
||||
|
||||
snippetDuration := float64(len(samples)) / float64(sampleRate)
|
||||
|
||||
rms := sound.CalculateRMS16(samples)
|
||||
rmsDBFS := -math.Inf(1)
|
||||
if rms > 0 {
|
||||
rmsDBFS = 20 * math.Log10(rms/32768.0)
|
||||
}
|
||||
|
||||
var peak int16
|
||||
var dcSum int64
|
||||
for _, s := range samples {
|
||||
if s < 0 && -s > peak {
|
||||
peak = -s
|
||||
} else if s > peak {
|
||||
peak = s
|
||||
}
|
||||
dcSum += int64(s)
|
||||
}
|
||||
peakDBFS := -math.Inf(1)
|
||||
if peak > 0 {
|
||||
peakDBFS = 20 * math.Log10(float64(peak) / 32768.0)
|
||||
}
|
||||
dcOffset := float64(dcSum) / float64(len(samples)) / 32768.0
|
||||
|
||||
// Encode the snippet as WAV
|
||||
snippetPCM := sound.Int16toBytesLE(samples)
|
||||
hdr := audio.NewWAVHeaderWithRate(uint32(len(snippetPCM)), uint32(sampleRate))
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(audio.WAVHeaderSize + len(snippetPCM))
|
||||
if err := hdr.Write(&buf); err != nil {
|
||||
xlog.Warn("audio snippet: write header failed", "error", err)
|
||||
return nil
|
||||
}
|
||||
buf.Write(snippetPCM)
|
||||
|
||||
return map[string]any{
|
||||
"audio_wav_base64": base64.StdEncoding.EncodeToString(buf.Bytes()),
|
||||
"audio_duration_s": math.Round(durationS*100) / 100,
|
||||
"audio_snippet_s": math.Round(snippetDuration*100) / 100,
|
||||
"audio_sample_rate": sampleRate,
|
||||
"audio_samples": totalSamples,
|
||||
"audio_rms_dbfs": math.Round(rmsDBFS*10) / 10,
|
||||
"audio_peak_dbfs": math.Round(peakDBFS*10) / 10,
|
||||
"audio_dc_offset": math.Round(dcOffset*10000) / 10000,
|
||||
}
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func GetBackendTraces() []BackendTrace {
|
||||
}
|
||||
|
||||
sort.Slice(traces, func(i, j int) bool {
|
||||
return traces[i].Timestamp.Before(traces[j].Timestamp)
|
||||
return traces[i].Timestamp.After(traces[j].Timestamp)
|
||||
})
|
||||
|
||||
return traces
|
||||
|
||||
@@ -31,12 +31,49 @@ This configuration links the following components:
|
||||
|
||||
Make sure all referenced models (`silero-vad-ggml`, `whisper-large-turbo`, `qwen3-4b`, `tts-1`) are also installed or defined in your LocalAI instance.
|
||||
|
||||
## Transports

The Realtime API supports two transports: **WebSocket** and **WebRTC**.
|
||||
|
||||
### WebSocket
|
||||
|
||||
Connect to the WebSocket endpoint:
|
||||
|
||||
```
|
||||
ws://localhost:8080/v1/realtime?model=gpt-realtime
|
||||
```
|
||||
|
||||
Audio is sent and received as raw PCM in the WebSocket messages, following the OpenAI Realtime API protocol.
|
||||
|
||||
### WebRTC
|
||||
|
||||
The WebRTC transport enables browser-based voice conversations with lower latency. Connect by POSTing an SDP offer to the REST endpoint:
|
||||
|
||||
```
|
||||
POST http://localhost:8080/v1/realtime?model=gpt-realtime
|
||||
Content-Type: application/sdp
|
||||
|
||||
<SDP offer body>
|
||||
```
|
||||
|
||||
The response contains the SDP answer to complete the WebRTC handshake.
|
||||
|
||||
#### Opus backend requirement
|
||||
|
||||
WebRTC uses the Opus audio codec for encoding and decoding audio on RTP tracks. The **opus** backend must be installed for WebRTC to work. Install it from the model gallery:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d '{"id": "opus"}'
|
||||
```
|
||||
|
||||
Or set the `EXTERNAL_GRPC_BACKENDS` environment variable if running a local build:
|
||||
|
||||
```bash
|
||||
EXTERNAL_GRPC_BACKENDS=opus:/path/to/backend/go/opus/opus
|
||||
```
|
||||
|
||||
The opus backend is loaded automatically when a WebRTC session starts. It does not require any model configuration file — just the backend binary.
|
||||
|
||||
## Protocol
|
||||
|
||||
The API follows the OpenAI Realtime API protocol for handling sessions, audio buffers, and conversation items.
|
||||
|
||||
30
go.mod
30
go.mod
@@ -132,6 +132,7 @@ require (
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
|
||||
github.com/philippgille/chromem-go v0.7.0 // indirect
|
||||
github.com/pion/transport/v4 v4.0.1 // indirect
|
||||
github.com/pjbgf/sha1cd v0.3.2 // indirect
|
||||
github.com/rs/zerolog v1.31.0 // indirect
|
||||
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
|
||||
@@ -209,25 +210,24 @@ require (
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
|
||||
github.com/otiai10/mint v1.6.3 // indirect
|
||||
github.com/pion/datachannel v1.5.10 // indirect
|
||||
github.com/pion/datachannel v1.6.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.6 // indirect
|
||||
github.com/pion/ice/v4 v4.0.10 // indirect
|
||||
github.com/pion/interceptor v0.1.40 // indirect
|
||||
github.com/pion/logging v0.2.3 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
github.com/pion/dtls/v3 v3.1.2 // indirect
|
||||
github.com/pion/ice/v4 v4.2.1 // indirect
|
||||
github.com/pion/interceptor v0.1.44 // indirect
|
||||
github.com/pion/logging v0.2.4 // indirect
|
||||
github.com/pion/mdns/v2 v2.1.0 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/rtcp v1.2.15 // indirect
|
||||
github.com/pion/rtp v1.8.19 // indirect
|
||||
github.com/pion/sctp v1.8.39 // indirect
|
||||
github.com/pion/sdp/v3 v3.0.13 // indirect
|
||||
github.com/pion/srtp/v3 v3.0.6 // indirect
|
||||
github.com/pion/rtcp v1.2.16 // indirect
|
||||
github.com/pion/rtp v1.10.1
|
||||
github.com/pion/sctp v1.9.2 // indirect
|
||||
github.com/pion/sdp/v3 v3.0.18 // indirect
|
||||
github.com/pion/srtp/v3 v3.0.10 // indirect
|
||||
github.com/pion/stun v0.6.1 // indirect
|
||||
github.com/pion/stun/v3 v3.0.0 // indirect
|
||||
github.com/pion/stun/v3 v3.1.1 // indirect
|
||||
github.com/pion/transport/v2 v2.2.10 // indirect
|
||||
github.com/pion/transport/v3 v3.0.7 // indirect
|
||||
github.com/pion/turn/v4 v4.0.2 // indirect
|
||||
github.com/pion/webrtc/v4 v4.1.2 // indirect
|
||||
github.com/pion/turn/v4 v4.1.4 // indirect
|
||||
github.com/pion/webrtc/v4 v4.2.9
|
||||
github.com/prometheus/otlptranslator v1.0.0 // indirect
|
||||
github.com/rymdport/portal v0.4.2 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
|
||||
62
go.sum
62
go.sum
@@ -748,48 +748,50 @@ github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxc
|
||||
github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
|
||||
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o=
|
||||
github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M=
|
||||
github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0=
|
||||
github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk=
|
||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||
github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk=
|
||||
github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||
github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E=
|
||||
github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU=
|
||||
github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4=
|
||||
github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
|
||||
github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4=
|
||||
github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic=
|
||||
github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc=
|
||||
github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo=
|
||||
github.com/pion/ice/v4 v4.2.1 h1:XPRYXaLiFq3LFDG7a7bMrmr3mFr27G/gtXN3v/TVfxY=
|
||||
github.com/pion/ice/v4 v4.2.1/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c=
|
||||
github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I=
|
||||
github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
|
||||
github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
|
||||
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
|
||||
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY=
|
||||
github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A=
|
||||
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
|
||||
github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
|
||||
github.com/pion/rtp v1.8.19 h1:jhdO/3XhL/aKm/wARFVmvTfq0lC/CvN1xwYKmduly3c=
|
||||
github.com/pion/rtp v1.8.19/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
|
||||
github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE=
|
||||
github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
|
||||
github.com/pion/sdp/v3 v3.0.13 h1:uN3SS2b+QDZnWXgdr69SM8KB4EbcnPnPf2Laxhty/l4=
|
||||
github.com/pion/sdp/v3 v3.0.13/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
|
||||
github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4=
|
||||
github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY=
|
||||
github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo=
|
||||
github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo=
|
||||
github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA=
|
||||
github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM=
|
||||
github.com/pion/sctp v1.9.2 h1:HxsOzEV9pWoeggv7T5kewVkstFNcGvhMPx0GvUOUQXo=
|
||||
github.com/pion/sctp v1.9.2/go.mod h1:OTOlsQ5EDQ6mQ0z4MUGXt2CgQmKyafBEXhUVqLRB6G8=
|
||||
github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI=
|
||||
github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8=
|
||||
github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ=
|
||||
github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M=
|
||||
github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4=
|
||||
github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8=
|
||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
||||
github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw=
|
||||
github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM=
|
||||
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
||||
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
||||
github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q=
|
||||
github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
|
||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
||||
github.com/pion/turn/v4 v4.0.2 h1:ZqgQ3+MjP32ug30xAbD6Mn+/K4Sxi3SdNOTFf+7mpps=
|
||||
github.com/pion/turn/v4 v4.0.2/go.mod h1:pMMKP/ieNAG/fN5cZiN4SDuyKsXtNTr0ccN7IToA1zs=
|
||||
github.com/pion/webrtc/v4 v4.1.2 h1:mpuUo/EJ1zMNKGE79fAdYNFZBX790KE7kQQpLMjjR54=
|
||||
github.com/pion/webrtc/v4 v4.1.2/go.mod h1:xsCXiNAmMEjIdFxAYU0MbB3RwRieJsegSB2JZsGN+8U=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o=
|
||||
github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM=
|
||||
github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ=
|
||||
github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ=
|
||||
github.com/pion/webrtc/v4 v4.2.9 h1:DZIh1HAhPIL3RvwEDFsmL5hfPSLEpxsQk9/Jir2vkJE=
|
||||
github.com/pion/webrtc/v4 v4.2.9/go.mod h1:9EmLZve0H76eTzf8v2FmchZ6tcBXtDgpfTEu+drW6SY=
|
||||
github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4=
|
||||
github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -53,3 +53,46 @@ func NewWAVHeader(pcmLen uint32) WAVHeader {
|
||||
func (h *WAVHeader) Write(writer io.Writer) error {
|
||||
return binary.Write(writer, binary.LittleEndian, h)
|
||||
}
|
||||
|
||||
// NewWAVHeaderWithRate creates a WAV header for mono 16-bit PCM at the given sample rate.
|
||||
func NewWAVHeaderWithRate(pcmLen, sampleRate uint32) WAVHeader {
|
||||
header := WAVHeader{
|
||||
ChunkID: [4]byte{'R', 'I', 'F', 'F'},
|
||||
Format: [4]byte{'W', 'A', 'V', 'E'},
|
||||
Subchunk1ID: [4]byte{'f', 'm', 't', ' '},
|
||||
Subchunk1Size: 16,
|
||||
AudioFormat: 1,
|
||||
NumChannels: 1,
|
||||
SampleRate: sampleRate,
|
||||
ByteRate: sampleRate * 2,
|
||||
BlockAlign: 2,
|
||||
BitsPerSample: 16,
|
||||
Subchunk2ID: [4]byte{'d', 'a', 't', 'a'},
|
||||
Subchunk2Size: pcmLen,
|
||||
}
|
||||
header.ChunkSize = 36 + header.Subchunk2Size
|
||||
return header
|
||||
}
|
||||
|
||||
// WAVHeaderSize is the size of a standard PCM WAV header in bytes.
|
||||
const WAVHeaderSize = 44
|
||||
|
||||
// StripWAVHeader removes a WAV header from audio data, returning raw PCM.
|
||||
// If the data is too short to contain a header, it is returned unchanged.
|
||||
func StripWAVHeader(data []byte) []byte {
|
||||
if len(data) > WAVHeaderSize {
|
||||
return data[WAVHeaderSize:]
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// ParseWAV strips the WAV header and returns the raw PCM along with the
|
||||
// sample rate read from the header. If the data is too short to contain a
|
||||
// valid header the PCM is returned as-is with sampleRate=0.
|
||||
func ParseWAV(data []byte) (pcm []byte, sampleRate int) {
|
||||
if len(data) <= WAVHeaderSize {
|
||||
return data, 0
|
||||
}
|
||||
sr := int(binary.LittleEndian.Uint32(data[24:28]))
|
||||
return data[WAVHeaderSize:], sr
|
||||
}
|
||||
|
||||
13
pkg/audio/audio_suite_test.go
Normal file
13
pkg/audio/audio_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestAudio wires Ginkgo's BDD-style specs into the standard `go test`
// runner for the audio package.
func TestAudio(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Audio Suite")
}
|
||||
99
pkg/audio/audio_test.go
Normal file
99
pkg/audio/audio_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Specs for the WAV header helpers: byte-level header layout, custom
// sample rates, header stripping, and sample-rate parsing.
var _ = Describe("WAV utilities", func() {
	Describe("NewWAVHeader", func() {
		It("produces a valid 44-byte header", func() {
			hdr := NewWAVHeader(3200)
			var buf bytes.Buffer
			Expect(hdr.Write(&buf)).To(Succeed())
			Expect(buf.Len()).To(Equal(WAVHeaderSize))

			// Verify the canonical RIFF/WAVE magic bytes and field offsets.
			b := buf.Bytes()
			Expect(string(b[0:4])).To(Equal("RIFF"))
			Expect(string(b[8:12])).To(Equal("WAVE"))
			Expect(string(b[12:16])).To(Equal("fmt "))

			Expect(binary.LittleEndian.Uint16(b[20:22])).To(Equal(uint16(1))) // PCM
			Expect(binary.LittleEndian.Uint16(b[22:24])).To(Equal(uint16(1))) // mono
			Expect(binary.LittleEndian.Uint32(b[24:28])).To(Equal(uint32(16000)))
			Expect(binary.LittleEndian.Uint32(b[28:32])).To(Equal(uint32(32000)))
			Expect(string(b[36:40])).To(Equal("data"))
			Expect(binary.LittleEndian.Uint32(b[40:44])).To(Equal(uint32(3200)))
		})
	})

	Describe("NewWAVHeaderWithRate", func() {
		It("uses the custom sample rate", func() {
			hdr := NewWAVHeaderWithRate(4800, 24000)
			var buf bytes.Buffer
			Expect(hdr.Write(&buf)).To(Succeed())
			b := buf.Bytes()

			// Sample rate at offset 24, byte rate (rate*2 for 16-bit mono) at 28.
			Expect(binary.LittleEndian.Uint32(b[24:28])).To(Equal(uint32(24000)))
			Expect(binary.LittleEndian.Uint32(b[28:32])).To(Equal(uint32(48000)))
		})
	})

	Describe("StripWAVHeader", func() {
		It("strips the 44-byte header", func() {
			pcm := []byte{0xDE, 0xAD, 0xBE, 0xEF}
			hdr := NewWAVHeader(uint32(len(pcm)))
			var buf bytes.Buffer
			Expect(hdr.Write(&buf)).To(Succeed())
			buf.Write(pcm)

			got := StripWAVHeader(buf.Bytes())
			Expect(got).To(Equal(pcm))
		})

		It("returns short data unchanged", func() {
			short := []byte{0x01, 0x02, 0x03}
			Expect(StripWAVHeader(short)).To(Equal(short))

			// Data of exactly header size is also returned unchanged.
			exact := make([]byte, WAVHeaderSize)
			Expect(StripWAVHeader(exact)).To(Equal(exact))
		})
	})

	Describe("ParseWAV", func() {
		It("returns sample rate and PCM data", func() {
			pcm := make([]byte, 100)
			for i := range pcm {
				pcm[i] = byte(i)
			}

			hdr24 := NewWAVHeaderWithRate(uint32(len(pcm)), 24000)
			var buf24 bytes.Buffer
			hdr24.Write(&buf24)
			buf24.Write(pcm)

			gotPCM, gotRate := ParseWAV(buf24.Bytes())
			Expect(gotRate).To(Equal(24000))
			Expect(gotPCM).To(Equal(pcm))

			// NewWAVHeader presumably defaults to 16 kHz (per the first spec).
			hdr16 := NewWAVHeader(uint32(len(pcm)))
			var buf16 bytes.Buffer
			hdr16.Write(&buf16)
			buf16.Write(pcm)

			gotPCM, gotRate = ParseWAV(buf16.Bytes())
			Expect(gotRate).To(Equal(16000))
			Expect(gotPCM).To(Equal(pcm))
		})

		It("returns zero rate for short data", func() {
			short := []byte{0x01, 0x02, 0x03}
			gotPCM, gotRate := ParseWAV(short)
			Expect(gotRate).To(Equal(0))
			Expect(gotPCM).To(Equal(short))
		})
	})
})
|
||||
@@ -59,5 +59,8 @@ type Backend interface {
|
||||
|
||||
VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
|
||||
|
||||
AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error)
|
||||
AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
|
||||
|
||||
ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error)
|
||||
}
|
||||
|
||||
@@ -112,6 +112,14 @@ func (llm *Base) VAD(*pb.VADRequest) (pb.VADResponse, error) {
|
||||
return pb.VADResponse{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// AudioEncode is a default stub returning "unimplemented"; backends that
// support audio encoding override this method.
func (llm *Base) AudioEncode(*pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
	return nil, fmt.Errorf("unimplemented")
}

// AudioDecode is a default stub returning "unimplemented"; backends that
// support audio decoding override this method.
func (llm *Base) AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) {
	return nil, fmt.Errorf("unimplemented")
}
|
||||
|
||||
func memoryUsage() *pb.MemoryUsageData {
|
||||
mud := pb.MemoryUsageData{
|
||||
Breakdown: make(map[string]uint64),
|
||||
|
||||
@@ -588,6 +588,50 @@ func (c *Client) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.
|
||||
return client.Detect(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// AudioEncode forwards an audio-encode request to the backend over gRPC.
// When the backend does not support parallel requests the call is
// serialized via opMutex; the busy flag and watchdog mark bracket the call.
func (c *Client) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	// A fresh connection per call, matching the other client methods in
	// this file (see ModelMetadata below).
	// NOTE(review): grpc.Dial is deprecated in newer grpc-go releases in
	// favour of grpc.NewClient — kept here for consistency; migrate all
	// methods together.
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.AudioEncode(ctx, in, opts...)
}

// AudioDecode forwards an audio-decode request to the backend over gRPC,
// with the same serialization, busy-flag, and watchdog handling as
// AudioEncode.
func (c *Client) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.AudioDecode(ctx, in, opts...)
}
|
||||
|
||||
func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
|
||||
@@ -107,6 +107,14 @@ func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.
|
||||
return e.s.VAD(ctx, in)
|
||||
}
|
||||
|
||||
// AudioEncode forwards the encode request straight to the in-process
// server implementation; the grpc.CallOption arguments are accepted for
// interface compatibility but ignored.
func (e *embedBackend) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) {
	return e.s.AudioEncode(ctx, in)
}

// AudioDecode forwards the decode request straight to the in-process
// server implementation; call options are ignored.
func (e *embedBackend) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error) {
	return e.s.AudioDecode(ctx, in)
}

// ModelMetadata forwards the metadata request straight to the in-process
// server implementation; call options are ignored.
func (e *embedBackend) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
	return e.s.ModelMetadata(ctx, in)
}
|
||||
|
||||
@@ -31,6 +31,9 @@ type AIModel interface {
|
||||
|
||||
VAD(*pb.VADRequest) (pb.VADResponse, error)
|
||||
|
||||
AudioEncode(*pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error)
|
||||
AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error)
|
||||
|
||||
ModelMetadata(*pb.ModelOptions) (*pb.ModelMetadataResponse, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -284,6 +284,30 @@ func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, e
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// AudioEncode implements the AudioEncode RPC by delegating to the loaded
// backend. When the backend requires serialized access (Locking), the
// model lock is held for the duration of the call.
func (s *server) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.AudioEncode(in)
	if err != nil {
		return nil, err
	}
	return res, nil
}

// AudioDecode implements the AudioDecode RPC, with the same optional
// model locking as AudioEncode.
func (s *server) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.AudioDecode(in)
	if err != nil {
		return nil, err
	}
	return res, nil
}
|
||||
|
||||
func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
|
||||
@@ -25,17 +25,29 @@ func CalculateRMS16(buffer []int16) float64 {
|
||||
}
|
||||
|
||||
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
if inputRate == outputRate {
|
||||
out := make([]int16, len(input))
|
||||
copy(out, input)
|
||||
return out
|
||||
}
|
||||
|
||||
// Calculate the resampling ratio
|
||||
ratio := float64(inputRate) / float64(outputRate)
|
||||
|
||||
// Calculate the length of the resampled output
|
||||
outputLength := int(float64(len(input)) / ratio)
|
||||
if outputLength <= 0 {
|
||||
return []int16{input[0]}
|
||||
}
|
||||
|
||||
// Allocate a slice for the resampled output
|
||||
output := make([]int16, outputLength)
|
||||
|
||||
// Perform linear interpolation for resampling
|
||||
for i := 0; i < outputLength-1; i++ {
|
||||
for i := 0; i < outputLength; i++ {
|
||||
// Calculate the corresponding position in the input
|
||||
pos := float64(i) * ratio
|
||||
|
||||
@@ -53,9 +65,6 @@ func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
|
||||
output[i] = int16((1-frac)*float64(input[indexBefore]) + frac*float64(input[indexAfter]))
|
||||
}
|
||||
|
||||
// Handle the last sample explicitly to avoid index out of range
|
||||
output[outputLength-1] = input[len(input)-1]
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
|
||||
213
pkg/sound/int16_test.go
Normal file
213
pkg/sound/int16_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package sound
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Specs for the int16 PCM helpers: byte/int16 conversions, linear-interpolation
// resampling (including batch-boundary continuity), and RMS computation.
var _ = Describe("Int16 utilities", func() {
	Describe("BytesToInt16sLE / Int16toBytesLE", func() {
		It("round-trips correctly", func() {
			// Includes the extremes of the int16 range.
			values := []int16{0, 1, -1, 32767, -32768}
			b := Int16toBytesLE(values)
			got := BytesToInt16sLE(b)

			Expect(got).To(Equal(values))
		})

		It("panics on odd-length input", func() {
			Expect(func() {
				BytesToInt16sLE([]byte{0x01, 0x02, 0x03})
			}).To(Panic())
		})

		It("returns empty slice for empty bytes input", func() {
			got := BytesToInt16sLE([]byte{})
			Expect(got).To(BeEmpty())
		})

		It("returns empty slice for empty int16 input", func() {
			got := Int16toBytesLE([]int16{})
			Expect(got).To(BeEmpty())
		})
	})

	Describe("ResampleInt16", func() {
		It("returns identical output for same rate", func() {
			src := generateSineWave(440, 16000, 320)
			dst := ResampleInt16(src, 16000, 16000)

			Expect(dst).To(Equal(src))
		})

		It("downsamples 48k to 16k", func() {
			src := generateSineWave(440, 48000, 960)
			dst := ResampleInt16(src, 48000, 16000)

			Expect(dst).To(HaveLen(320))

			// Downsampling must preserve the dominant frequency.
			freq := estimateFrequency(dst, 16000)
			Expect(freq).To(BeNumerically("~", 440, 50))
		})

		It("upsamples 16k to 48k", func() {
			src := generateSineWave(440, 16000, 320)
			dst := ResampleInt16(src, 16000, 48000)

			Expect(dst).To(HaveLen(960))

			freq := estimateFrequency(dst, 48000)
			Expect(freq).To(BeNumerically("~", 440, 50))
		})

		It("preserves quality through double resampling", func() {
			src := generateSineWave(440, 48000, 4800) // 100ms

			direct := ResampleInt16(src, 48000, 16000)

			// 48k -> 24k -> 16k should stay close to the direct 48k -> 16k path.
			step1 := ResampleInt16(src, 48000, 24000)
			double := ResampleInt16(step1, 24000, 16000)

			minLen := len(direct)
			if len(double) < minLen {
				minLen = len(double)
			}

			corr := computeCorrelation(direct[:minLen], double[:minLen])
			Expect(corr).To(BeNumerically(">=", 0.95))
		})

		It("handles single sample", func() {
			src := []int16{1000}
			got := ResampleInt16(src, 48000, 16000)
			Expect(got).NotTo(BeEmpty())
			Expect(got[0]).To(Equal(int16(1000)))
		})

		It("returns nil for empty input", func() {
			got := ResampleInt16(nil, 48000, 16000)
			Expect(got).To(BeNil())
		})

		It("produces no discontinuity at batch boundaries (48k->16k)", func() {
			// Generate 900ms of 440Hz sine at 48kHz (simulating 3 decode batches)
			fullSine := generateSineWave(440, 48000, 48000*900/1000) // 43200 samples

			// One-shot resample (ground truth)
			oneShot := ResampleInt16(fullSine, 48000, 16000)

			// Batched resample: split into 3 batches of 300ms (14400 samples each)
			batchSize := 48000 * 300 / 1000 // 14400
			var batched []int16
			for offset := 0; offset < len(fullSine); offset += batchSize {
				end := offset + batchSize
				if end > len(fullSine) {
					end = len(fullSine)
				}
				chunk := ResampleInt16(fullSine[offset:end], 48000, 16000)
				batched = append(batched, chunk...)
			}

			// Lengths should match
			Expect(len(batched)).To(Equal(len(oneShot)))

			// Check discontinuity at each batch boundary
			batchOutSize := len(ResampleInt16(fullSine[:batchSize], 48000, 16000))
			for b := 1; b < 3; b++ {
				boundaryIdx := b * batchOutSize
				if boundaryIdx >= len(batched) || boundaryIdx < 1 {
					continue
				}
				// The sample-to-sample delta at the boundary
				jump := math.Abs(float64(batched[boundaryIdx]) - float64(batched[boundaryIdx-1]))
				// Compare with the average delta in the interior (excluding boundary)
				var avgDelta float64
				count := 0
				start := boundaryIdx - 10
				if start < 1 {
					start = 1
				}
				stop := boundaryIdx + 10
				if stop >= len(batched) {
					stop = len(batched) - 1
				}
				for i := start; i < stop; i++ {
					if i == boundaryIdx-1 || i == boundaryIdx {
						continue
					}
					avgDelta += math.Abs(float64(batched[i+1]) - float64(batched[i]))
					count++
				}
				avgDelta /= float64(count)

				GinkgoWriter.Printf("Batch boundary %d (idx %d): jump=%.0f, avg_delta=%.0f, ratio=%.1f\n",
					b, boundaryIdx, jump, avgDelta, jump/avgDelta)

				// The boundary jump should not be more than 3x the average delta
				Expect(jump).To(BeNumerically("<=", avgDelta*3),
					fmt.Sprintf("discontinuity at batch boundary %d: jump=%.0f vs avg=%.0f", b, jump, avgDelta))
			}

			// Overall correlation should be very high
			minLen := len(oneShot)
			if len(batched) < minLen {
				minLen = len(batched)
			}
			corr := computeCorrelation(oneShot[:minLen], batched[:minLen])
			Expect(corr).To(BeNumerically(">=", 0.999),
				"batched resample differs significantly from one-shot")
		})

		It("interpolates the last sample instead of using raw input value", func() {
			// Create a ramp signal where each value is unique
			input := make([]int16, 14400) // 300ms at 48kHz
			for i := range input {
				input[i] = int16(i % 32000)
			}

			output := ResampleInt16(input, 48000, 16000) // ratio 3.0

			// The last output sample should be at interpolated position (len(output)-1)*3.0
			lastIdx := len(output) - 1
			expectedPos := float64(lastIdx) * 3.0
			expectedInputIdx := int(expectedPos)
			// At integer position with frac=0, the interpolated value equals input[expectedInputIdx]
			expectedVal := input[expectedInputIdx]

			GinkgoWriter.Printf("Last output[%d]: %d, expected (interpolated at input[%d]): %d, raw last input[%d]: %d\n",
				lastIdx, output[lastIdx], expectedInputIdx, expectedVal, len(input)-1, input[len(input)-1])

			Expect(output[lastIdx]).To(Equal(expectedVal),
				"last sample should be interpolated, not raw input[last]")
		})
	})

	Describe("CalculateRMS16", func() {
		It("computes correct RMS for constant signal", func() {
			buf := make([]int16, 1000)
			for i := range buf {
				buf[i] = 1000
			}
			rms := CalculateRMS16(buf)
			Expect(rms).To(BeNumerically("~", 1000, 0.01))
		})

		It("returns zero for silence", func() {
			buf := make([]int16, 1000)
			rms := CalculateRMS16(buf)
			Expect(rms).To(BeZero())
		})

		It("computes correct RMS for known sine wave", func() {
			amplitude := float64(math.MaxInt16 / 2)
			buf := generateSineWave(440, 16000, 16000) // 1 second
			rms := CalculateRMS16(buf)
			// RMS of a full-cycle sine is amplitude / sqrt(2).
			expectedRMS := amplitude / math.Sqrt(2)

			Expect(rms).To(BeNumerically("~", expectedRMS, expectedRMS*0.02))
		})
	})
})
|
||||
13
pkg/sound/sound_suite_test.go
Normal file
13
pkg/sound/sound_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package sound
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestSound wires Ginkgo's BDD-style specs into the standard `go test`
// runner for the sound package.
func TestSound(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Sound Suite")
}
|
||||
72
pkg/sound/testutil_test.go
Normal file
72
pkg/sound/testutil_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package sound
|
||||
|
||||
import "math"
|
||||
|
||||
// generateSineWave produces a sine wave of the given frequency at the given sample rate.
|
||||
func generateSineWave(freq float64, sampleRate, numSamples int) []int16 {
|
||||
out := make([]int16, numSamples)
|
||||
for i := range out {
|
||||
t := float64(i) / float64(sampleRate)
|
||||
out[i] = int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// computeCorrelation returns the normalised Pearson correlation between two
|
||||
// equal-length int16 slices. Returns 0 when either signal has zero energy.
|
||||
func computeCorrelation(a, b []int16) float64 {
|
||||
n := len(a)
|
||||
if n == 0 || n != len(b) {
|
||||
return 0
|
||||
}
|
||||
var sumAB, sumA2, sumB2 float64
|
||||
for i := 0; i < n; i++ {
|
||||
fa, fb := float64(a[i]), float64(b[i])
|
||||
sumAB += fa * fb
|
||||
sumA2 += fa * fa
|
||||
sumB2 += fb * fb
|
||||
}
|
||||
denom := math.Sqrt(sumA2 * sumB2)
|
||||
if denom == 0 {
|
||||
return 0
|
||||
}
|
||||
return sumAB / denom
|
||||
}
|
||||
|
||||
// estimateFrequency estimates the dominant frequency of a mono int16 signal
|
||||
// using zero-crossing count.
|
||||
func estimateFrequency(samples []int16, sampleRate int) float64 {
|
||||
if len(samples) < 2 {
|
||||
return 0
|
||||
}
|
||||
crossings := 0
|
||||
for i := 1; i < len(samples); i++ {
|
||||
if (samples[i-1] >= 0 && samples[i] < 0) || (samples[i-1] < 0 && samples[i] >= 0) {
|
||||
crossings++
|
||||
}
|
||||
}
|
||||
duration := float64(len(samples)) / float64(sampleRate)
|
||||
// Each full cycle has 2 zero crossings.
|
||||
return float64(crossings) / (2 * duration)
|
||||
}
|
||||
|
||||
// computeRMS returns the root-mean-square of an int16 slice.
|
||||
func computeRMS(samples []int16) float64 {
|
||||
if len(samples) == 0 {
|
||||
return 0
|
||||
}
|
||||
var sum float64
|
||||
for _, s := range samples {
|
||||
v := float64(s)
|
||||
sum += v * v
|
||||
}
|
||||
return math.Sqrt(sum / float64(len(samples)))
|
||||
}
|
||||
|
||||
// generatePCMBytes creates a little-endian int16 PCM byte slice containing a
|
||||
// sine wave of the given frequency at the given sample rate and duration.
|
||||
func generatePCMBytes(freq float64, sampleRate, durationMs int) []byte {
|
||||
numSamples := sampleRate * durationMs / 1000
|
||||
samples := generateSineWave(freq, sampleRate, numSamples)
|
||||
return Int16toBytesLE(samples)
|
||||
}
|
||||
@@ -89,10 +89,10 @@ var _ = BeforeSuite(func() {
|
||||
Expect(os.Chmod(mockBackendPath, 0755)).To(Succeed())
|
||||
|
||||
// Create model config YAML
|
||||
modelConfig := map[string]interface{}{
|
||||
modelConfig := map[string]any{
|
||||
"name": "mock-model",
|
||||
"backend": "mock-backend",
|
||||
"parameters": map[string]interface{}{
|
||||
"parameters": map[string]any{
|
||||
"model": "mock-model.bin",
|
||||
},
|
||||
}
|
||||
@@ -109,11 +109,92 @@ var _ = BeforeSuite(func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(mcpConfigPath, mcpConfigYAML, 0644)).To(Succeed())
|
||||
|
||||
// Set up system state
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
// Create pipeline model configs for realtime API tests.
|
||||
// Each component model uses the same mock-backend binary.
|
||||
for _, name := range []string{"mock-vad", "mock-stt", "mock-llm", "mock-tts"} {
|
||||
cfg := map[string]any{
|
||||
"name": name,
|
||||
"backend": "mock-backend",
|
||||
"parameters": map[string]any{
|
||||
"model": name + ".bin",
|
||||
},
|
||||
}
|
||||
data, err := yaml.Marshal(cfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(modelsPath, name+".yaml"), data, 0644)).To(Succeed())
|
||||
}
|
||||
|
||||
// Pipeline model that wires the component models together.
|
||||
pipelineCfg := map[string]any{
|
||||
"name": "realtime-pipeline",
|
||||
"pipeline": map[string]any{
|
||||
"vad": "mock-vad",
|
||||
"transcription": "mock-stt",
|
||||
"llm": "mock-llm",
|
||||
"tts": "mock-tts",
|
||||
},
|
||||
}
|
||||
pipelineData, err := yaml.Marshal(pipelineCfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(modelsPath, "realtime-pipeline.yaml"), pipelineData, 0644)).To(Succeed())
|
||||
|
||||
// If REALTIME_TEST_MODEL=realtime-test-pipeline, auto-create a pipeline
|
||||
// config from the REALTIME_VAD/STT/LLM/TTS env vars so real-model tests
|
||||
// can run without the user having to write a YAML file manually.
|
||||
if os.Getenv("REALTIME_TEST_MODEL") == "realtime-test-pipeline" {
|
||||
rtVAD := os.Getenv("REALTIME_VAD")
|
||||
rtSTT := os.Getenv("REALTIME_STT")
|
||||
rtLLM := os.Getenv("REALTIME_LLM")
|
||||
rtTTS := os.Getenv("REALTIME_TTS")
|
||||
|
||||
if rtVAD != "" && rtSTT != "" && rtLLM != "" && rtTTS != "" {
|
||||
testPipeline := map[string]any{
|
||||
"name": "realtime-test-pipeline",
|
||||
"pipeline": map[string]any{
|
||||
"vad": rtVAD,
|
||||
"transcription": rtSTT,
|
||||
"llm": rtLLM,
|
||||
"tts": rtTTS,
|
||||
},
|
||||
}
|
||||
data, writeErr := yaml.Marshal(testPipeline)
|
||||
Expect(writeErr).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(modelsPath, "realtime-test-pipeline.yaml"), data, 0644)).To(Succeed())
|
||||
xlog.Info("created realtime-test-pipeline",
|
||||
"vad", rtVAD, "stt", rtSTT, "llm", rtLLM, "tts", rtTTS)
|
||||
}
|
||||
}
|
||||
|
||||
// Import model configs from an external directory (e.g. real model YAMLs
|
||||
// and weights mounted into a container). Symlinks avoid copying large files.
|
||||
if rtModels := os.Getenv("REALTIME_MODELS_PATH"); rtModels != "" {
|
||||
entries, err := os.ReadDir(rtModels)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for _, entry := range entries {
|
||||
src := filepath.Join(rtModels, entry.Name())
|
||||
dst := filepath.Join(modelsPath, entry.Name())
|
||||
if _, err := os.Stat(dst); err == nil {
|
||||
continue // don't overwrite mock configs
|
||||
}
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
Expect(os.Symlink(src, dst)).To(Succeed())
|
||||
}
|
||||
}
|
||||
|
||||
// Set up system state. When REALTIME_BACKENDS_PATH is set, use it so the
|
||||
// application can discover real backend binaries for real-model tests.
|
||||
systemOpts := []system.SystemStateOptions{
|
||||
system.WithModelPath(modelsPath),
|
||||
)
|
||||
}
|
||||
if realBackends := os.Getenv("REALTIME_BACKENDS_PATH"); realBackends != "" {
|
||||
systemOpts = append(systemOpts, system.WithBackendPath(realBackends))
|
||||
} else {
|
||||
systemOpts = append(systemOpts, system.WithBackendPath(backendPath))
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(systemOpts...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create application
|
||||
@@ -130,8 +211,9 @@ var _ = BeforeSuite(func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Register backend with application's model loader
|
||||
// Register mock backend (always available for non-realtime tests).
|
||||
application.ModelLoader().SetExternalBackend("mock-backend", mockBackendPath)
|
||||
application.ModelLoader().SetExternalBackend("opus", mockBackendPath)
|
||||
|
||||
// Create HTTP app
|
||||
app, err = httpapi.API(application)
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -193,12 +195,28 @@ func (m *MockBackend) SoundGeneration(ctx context.Context, in *pb.SoundGeneratio
|
||||
}, nil
|
||||
}
|
||||
|
||||
// writeMinimalWAV writes a minimal valid WAV file (short silence) so the HTTP handler can send it.
|
||||
// ttsSampleRate returns the sample rate to use for TTS output, configurable
|
||||
// via the MOCK_TTS_SAMPLE_RATE environment variable (default 16000).
|
||||
func ttsSampleRate() int {
|
||||
if s := os.Getenv("MOCK_TTS_SAMPLE_RATE"); s != "" {
|
||||
if v, err := strconv.Atoi(s); err == nil && v > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 16000
|
||||
}
|
||||
|
||||
// writeMinimalWAV writes a WAV file containing a 440Hz sine wave (0.5s)
|
||||
// so that tests can verify audio integrity end-to-end. The sample rate
|
||||
// is configurable via MOCK_TTS_SAMPLE_RATE to test rate mismatch bugs.
|
||||
func writeMinimalWAV(path string) error {
|
||||
const sampleRate = 16000
|
||||
sampleRate := ttsSampleRate()
|
||||
const numChannels = 1
|
||||
const bitsPerSample = 16
|
||||
const numSamples = 1600 // 0.1s
|
||||
const freq = 440.0
|
||||
const durationSec = 0.5
|
||||
numSamples := int(float64(sampleRate) * durationSec)
|
||||
|
||||
dataSize := numSamples * numChannels * (bitsPerSample / 8)
|
||||
const headerLen = 44
|
||||
f, err := os.Create(path)
|
||||
@@ -219,23 +237,56 @@ func writeMinimalWAV(path string) error {
|
||||
_ = binary.Write(f, binary.LittleEndian, uint32(sampleRate*numChannels*(bitsPerSample/8)))
|
||||
_ = binary.Write(f, binary.LittleEndian, uint16(numChannels*(bitsPerSample/8)))
|
||||
_ = binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
|
||||
// data chunk
|
||||
// data chunk — 440Hz sine wave
|
||||
_, _ = f.Write([]byte("data"))
|
||||
_ = binary.Write(f, binary.LittleEndian, uint32(dataSize))
|
||||
_, _ = f.Write(make([]byte, dataSize))
|
||||
for i := range numSamples {
|
||||
t := float64(i) / float64(sampleRate)
|
||||
sample := int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t))
|
||||
_ = binary.Write(f, binary.LittleEndian, sample)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
|
||||
xlog.Debug("AudioTranscription called")
|
||||
dst := in.GetDst()
|
||||
wavSR := 0
|
||||
dataLen := 0
|
||||
rms := 0.0
|
||||
|
||||
if dst != "" {
|
||||
if data, err := os.ReadFile(dst); err == nil {
|
||||
if len(data) >= 44 {
|
||||
wavSR = int(binary.LittleEndian.Uint32(data[24:28]))
|
||||
dataLen = int(binary.LittleEndian.Uint32(data[40:44]))
|
||||
|
||||
// Compute RMS of the PCM payload (16-bit LE samples)
|
||||
pcm := data[44:]
|
||||
var sumSq float64
|
||||
nSamples := len(pcm) / 2
|
||||
for i := range nSamples {
|
||||
s := int16(pcm[2*i]) | int16(pcm[2*i+1])<<8
|
||||
v := float64(s)
|
||||
sumSq += v * v
|
||||
}
|
||||
if nSamples > 0 {
|
||||
rms = math.Sqrt(sumSq / float64(nSamples))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Debug("AudioTranscription called", "dst", dst, "wav_sample_rate", wavSR, "data_len", dataLen, "rms", rms)
|
||||
|
||||
text := fmt.Sprintf("transcribed: rms=%.1f samples=%d sr=%d", rms, dataLen/2, wavSR)
|
||||
return &pb.TranscriptResult{
|
||||
Text: "This is a mocked transcription.",
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{
|
||||
{
|
||||
Id: 0,
|
||||
Start: 0,
|
||||
End: 3000,
|
||||
Text: "This is a mocked transcription.",
|
||||
Text: text,
|
||||
Tokens: []int32{1, 2, 3, 4, 5, 6},
|
||||
},
|
||||
},
|
||||
@@ -365,21 +416,65 @@ func (m *MockBackend) GetMetrics(ctx context.Context, in *pb.MetricsRequest) (*p
|
||||
}
|
||||
|
||||
func (m *MockBackend) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
|
||||
xlog.Debug("VAD called", "audio_length", len(in.Audio))
|
||||
// Compute RMS of the received float32 audio to decide whether speech is present.
|
||||
var sumSq float64
|
||||
for _, s := range in.Audio {
|
||||
v := float64(s)
|
||||
sumSq += v * v
|
||||
}
|
||||
rms := 0.0
|
||||
if len(in.Audio) > 0 {
|
||||
rms = math.Sqrt(sumSq / float64(len(in.Audio)))
|
||||
}
|
||||
xlog.Debug("VAD called", "audio_length", len(in.Audio), "rms", rms)
|
||||
|
||||
// If audio is near-silence, return no segments (no speech detected).
|
||||
if rms < 0.001 {
|
||||
return &pb.VADResponse{}, nil
|
||||
}
|
||||
|
||||
// Audio has signal — return a single segment covering the duration.
|
||||
duration := float64(len(in.Audio)) / 16000.0
|
||||
return &pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{
|
||||
{
|
||||
Start: 0.0,
|
||||
End: 1.5,
|
||||
},
|
||||
{
|
||||
Start: 2.0,
|
||||
End: 3.5,
|
||||
End: float32(duration),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockBackend) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
|
||||
xlog.Debug("AudioEncode called", "pcm_len", len(in.PcmData), "sample_rate", in.SampleRate)
|
||||
// Return a single mock Opus frame per 960-sample chunk (20ms at 48kHz).
|
||||
numSamples := len(in.PcmData) / 2 // 16-bit samples
|
||||
frameSize := 960
|
||||
var frames [][]byte
|
||||
for offset := 0; offset+frameSize <= numSamples; offset += frameSize {
|
||||
// Minimal mock frame — just enough bytes to be non-empty.
|
||||
frames = append(frames, []byte{0xFC, 0xFF, 0xFE})
|
||||
}
|
||||
return &pb.AudioEncodeResult{
|
||||
Frames: frames,
|
||||
SampleRate: 48000,
|
||||
SamplesPerFrame: int32(frameSize),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockBackend) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) {
|
||||
xlog.Debug("AudioDecode called", "frames", len(in.Frames))
|
||||
// Return silent PCM (960 samples per frame at 48kHz, 16-bit LE).
|
||||
samplesPerFrame := 960
|
||||
totalSamples := len(in.Frames) * samplesPerFrame
|
||||
pcm := make([]byte, totalSamples*2)
|
||||
return &pb.AudioDecodeResult{
|
||||
PcmData: pcm,
|
||||
SampleRate: 48000,
|
||||
SamplesPerFrame: int32(samplesPerFrame),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockBackend) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
|
||||
xlog.Debug("ModelMetadata called", "model", in.Model)
|
||||
return &pb.ModelMetadataResponse{
|
||||
|
||||
459
tests/e2e/realtime_webrtc_test.go
Normal file
459
tests/e2e/realtime_webrtc_test.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package e2e_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"github.com/pion/webrtc/v4/pkg/media"
|
||||
)
|
||||
|
||||
// --- WebRTC test client ---
|
||||
|
||||
type webrtcTestClient struct {
|
||||
pc *webrtc.PeerConnection
|
||||
dc *webrtc.DataChannel
|
||||
sendTrack *webrtc.TrackLocalStaticSample
|
||||
|
||||
events chan map[string]any
|
||||
audioData chan []byte // raw Opus frames received
|
||||
|
||||
dcOpen chan struct{} // closed when data channel opens
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newWebRTCTestClient() *webrtcTestClient {
|
||||
m := &webrtc.MediaEngine{}
|
||||
Expect(m.RegisterDefaultCodecs()).To(Succeed())
|
||||
|
||||
api := webrtc.NewAPI(webrtc.WithMediaEngine(m))
|
||||
|
||||
pc, err := api.NewPeerConnection(webrtc.Configuration{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create outbound audio track (Opus)
|
||||
sendTrack, err := webrtc.NewTrackLocalStaticSample(
|
||||
webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus},
|
||||
"audio-client",
|
||||
"test-client",
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
rtpSender, err := pc.AddTrack(sendTrack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Drain RTCP
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
if _, _, err := rtpSender.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create the "oai-events" data channel (must be created by client)
|
||||
dc, err := pc.CreateDataChannel("oai-events", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
c := &webrtcTestClient{
|
||||
pc: pc,
|
||||
dc: dc,
|
||||
sendTrack: sendTrack,
|
||||
events: make(chan map[string]any, 256),
|
||||
audioData: make(chan []byte, 4096),
|
||||
dcOpen: make(chan struct{}),
|
||||
}
|
||||
|
||||
dc.OnOpen(func() {
|
||||
close(c.dcOpen)
|
||||
})
|
||||
|
||||
dc.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||
var evt map[string]any
|
||||
if err := json.Unmarshal(msg.Data, &evt); err == nil {
|
||||
c.events <- evt
|
||||
}
|
||||
})
|
||||
|
||||
// Collect incoming audio tracks
|
||||
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
for {
|
||||
pkt, _, err := track.ReadRTP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.audioData <- pkt.Payload
|
||||
}
|
||||
})
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// connect performs SDP exchange with the server and waits for the data channel to open.
|
||||
func (c *webrtcTestClient) connect(model string) {
|
||||
offer, err := c.pc.CreateOffer(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.pc.SetLocalDescription(offer)).To(Succeed())
|
||||
|
||||
// Wait for ICE gathering
|
||||
gatherDone := webrtc.GatheringCompletePromise(c.pc)
|
||||
select {
|
||||
case <-gatherDone:
|
||||
case <-time.After(10 * time.Second):
|
||||
Fail("ICE gathering timed out")
|
||||
}
|
||||
|
||||
localDesc := c.pc.LocalDescription()
|
||||
Expect(localDesc).ToNot(BeNil())
|
||||
|
||||
// POST to /v1/realtime/calls
|
||||
reqBody, err := json.Marshal(map[string]string{
|
||||
"sdp": localDesc.SDP,
|
||||
"model": model,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("http://127.0.0.1:%d/v1/realtime/calls", apiPort),
|
||||
"application/json",
|
||||
bytes.NewReader(reqBody),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusCreated),
|
||||
"expected 201, got %d: %s", resp.StatusCode, string(body))
|
||||
|
||||
var callResp struct {
|
||||
SDP string `json:"sdp"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
Expect(json.Unmarshal(body, &callResp)).To(Succeed())
|
||||
Expect(callResp.SDP).ToNot(BeEmpty())
|
||||
|
||||
// Set the answer
|
||||
Expect(c.pc.SetRemoteDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeAnswer,
|
||||
SDP: callResp.SDP,
|
||||
})).To(Succeed())
|
||||
|
||||
// Wait for data channel to open
|
||||
Eventually(c.dcOpen, 15*time.Second).Should(BeClosed())
|
||||
}
|
||||
|
||||
// sendEvent sends a JSON event via the data channel.
|
||||
func (c *webrtcTestClient) sendEvent(event any) {
|
||||
data, err := json.Marshal(event)
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, c.dc.Send(data)).To(Succeed())
|
||||
}
|
||||
|
||||
// readEvent reads the next event from the data channel with timeout.
|
||||
func (c *webrtcTestClient) readEvent(timeout time.Duration) map[string]any {
|
||||
select {
|
||||
case evt := <-c.events:
|
||||
return evt
|
||||
case <-time.After(timeout):
|
||||
Fail("timed out reading event from data channel")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// drainUntilEvent reads events until one with the given type appears.
|
||||
func (c *webrtcTestClient) drainUntilEvent(eventType string, timeout time.Duration) map[string]any {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
break
|
||||
}
|
||||
evt := c.readEvent(remaining)
|
||||
if evt["type"] == eventType {
|
||||
return evt
|
||||
}
|
||||
}
|
||||
Fail("timed out waiting for event: " + eventType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendSineWave encodes a sine wave to Opus and sends it over the audio track.
|
||||
// This is a simplified version that sends raw PCM wrapped as Opus-compatible
|
||||
// media samples. In a real client the Opus encoder would be used.
|
||||
func (c *webrtcTestClient) sendSilence(durationMs int) {
|
||||
// Send silence as zero-filled PCM samples via track.
|
||||
// We use 20ms Opus frames at 48kHz.
|
||||
framesNeeded := durationMs / 20
|
||||
// Minimal valid Opus silence frame (Opus DTX/silence)
|
||||
silenceFrame := make([]byte, 3)
|
||||
silenceFrame[0] = 0xF8 // Config: CELT-only, no VAD, 20ms frame
|
||||
silenceFrame[1] = 0xFF
|
||||
silenceFrame[2] = 0xFE
|
||||
|
||||
for range framesNeeded {
|
||||
_ = c.sendTrack.WriteSample(media.Sample{
|
||||
Data: silenceFrame,
|
||||
Duration: 20 * time.Millisecond,
|
||||
})
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *webrtcTestClient) close() {
|
||||
if c.pc != nil {
|
||||
c.pc.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
var _ = Describe("Realtime WebRTC API", Label("Realtime"), func() {
|
||||
Context("Signaling", func() {
|
||||
It("should complete SDP exchange and receive session.created", func() {
|
||||
client := newWebRTCTestClient()
|
||||
defer client.close()
|
||||
|
||||
client.connect(pipelineModel())
|
||||
|
||||
evt := client.readEvent(30 * time.Second)
|
||||
Expect(evt["type"]).To(Equal("session.created"))
|
||||
|
||||
session, ok := evt["session"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(session["id"]).ToNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Event exchange via DataChannel", func() {
|
||||
It("should handle session.update", func() {
|
||||
client := newWebRTCTestClient()
|
||||
defer client.close()
|
||||
|
||||
client.connect(pipelineModel())
|
||||
|
||||
// Read session.created
|
||||
created := client.readEvent(30 * time.Second)
|
||||
Expect(created["type"]).To(Equal("session.created"))
|
||||
|
||||
// Disable VAD
|
||||
client.sendEvent(disableVADEvent())
|
||||
|
||||
updated := client.drainUntilEvent("session.updated", 10*time.Second)
|
||||
Expect(updated).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("should handle conversation.item.create and response.create", func() {
|
||||
client := newWebRTCTestClient()
|
||||
defer client.close()
|
||||
|
||||
client.connect(pipelineModel())
|
||||
|
||||
created := client.readEvent(30 * time.Second)
|
||||
Expect(created["type"]).To(Equal("session.created"))
|
||||
|
||||
// Disable VAD
|
||||
client.sendEvent(disableVADEvent())
|
||||
client.drainUntilEvent("session.updated", 10*time.Second)
|
||||
|
||||
// Create text item
|
||||
client.sendEvent(map[string]any{
|
||||
"type": "conversation.item.create",
|
||||
"item": map[string]any{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Hello from WebRTC",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
added := client.drainUntilEvent("conversation.item.added", 10*time.Second)
|
||||
Expect(added).ToNot(BeNil())
|
||||
|
||||
// Trigger response
|
||||
client.sendEvent(map[string]any{
|
||||
"type": "response.create",
|
||||
})
|
||||
|
||||
done := client.drainUntilEvent("response.done", 60*time.Second)
|
||||
Expect(done).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Audio track", func() {
|
||||
It("should receive audio on the incoming track after TTS", Label("real-models"), func() {
|
||||
if os.Getenv("REALTIME_TEST_MODEL") == "" {
|
||||
Skip("REALTIME_TEST_MODEL not set")
|
||||
}
|
||||
|
||||
client := newWebRTCTestClient()
|
||||
defer client.close()
|
||||
|
||||
client.connect(pipelineModel())
|
||||
|
||||
created := client.readEvent(30 * time.Second)
|
||||
Expect(created["type"]).To(Equal("session.created"))
|
||||
|
||||
// Disable VAD
|
||||
client.sendEvent(disableVADEvent())
|
||||
client.drainUntilEvent("session.updated", 10*time.Second)
|
||||
|
||||
// Send text and trigger response with TTS
|
||||
client.sendEvent(map[string]any{
|
||||
"type": "conversation.item.create",
|
||||
"item": map[string]any{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Say hello",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
client.drainUntilEvent("conversation.item.added", 10*time.Second)
|
||||
|
||||
client.sendEvent(map[string]any{
|
||||
"type": "response.create",
|
||||
})
|
||||
|
||||
// Collect audio frames while waiting for response.done
|
||||
var audioFrames [][]byte
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
loop:
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case frame := <-client.audioData:
|
||||
audioFrames = append(audioFrames, frame)
|
||||
case evt := <-client.events:
|
||||
if evt["type"] == "response.done" {
|
||||
break loop
|
||||
}
|
||||
case <-time.After(time.Until(deadline)):
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
// We should have received some audio frames
|
||||
Expect(len(audioFrames)).To(BeNumerically(">", 0),
|
||||
"expected to receive audio frames on the WebRTC track")
|
||||
})
|
||||
})
|
||||
|
||||
Context("Disconnect cleanup", func() {
|
||||
It("should handle repeated connect/disconnect cycles", func() {
|
||||
for i := range 3 {
|
||||
By(fmt.Sprintf("Cycle %d", i+1))
|
||||
client := newWebRTCTestClient()
|
||||
client.connect(pipelineModel())
|
||||
|
||||
evt := client.readEvent(30 * time.Second)
|
||||
Expect(evt["type"]).To(Equal("session.created"))
|
||||
|
||||
client.close()
|
||||
// Brief pause to let server clean up
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("Audio integrity", Label("real-models"), func() {
|
||||
It("should receive recognizable audio from TTS through WebRTC", func() {
|
||||
if os.Getenv("REALTIME_TEST_MODEL") == "" {
|
||||
Skip("REALTIME_TEST_MODEL not set")
|
||||
}
|
||||
|
||||
client := newWebRTCTestClient()
|
||||
defer client.close()
|
||||
|
||||
client.connect(pipelineModel())
|
||||
|
||||
created := client.readEvent(30 * time.Second)
|
||||
Expect(created["type"]).To(Equal("session.created"))
|
||||
|
||||
// Disable VAD
|
||||
client.sendEvent(disableVADEvent())
|
||||
client.drainUntilEvent("session.updated", 10*time.Second)
|
||||
|
||||
// Create text item and trigger response
|
||||
client.sendEvent(map[string]any{
|
||||
"type": "conversation.item.create",
|
||||
"item": map[string]any{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Say hello",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
client.drainUntilEvent("conversation.item.added", 10*time.Second)
|
||||
|
||||
client.sendEvent(map[string]any{
|
||||
"type": "response.create",
|
||||
})
|
||||
|
||||
// Collect Opus frames and decode them
|
||||
var totalBytes int
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
loop:
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case frame := <-client.audioData:
|
||||
totalBytes += len(frame)
|
||||
case evt := <-client.events:
|
||||
if evt["type"] == "response.done" {
|
||||
// Drain any remaining audio
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
drainAudio:
|
||||
for {
|
||||
select {
|
||||
case frame := <-client.audioData:
|
||||
totalBytes += len(frame)
|
||||
default:
|
||||
break drainAudio
|
||||
}
|
||||
}
|
||||
break loop
|
||||
}
|
||||
case <-time.After(time.Until(deadline)):
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we received meaningful audio data
|
||||
Expect(totalBytes).To(BeNumerically(">", 100),
|
||||
"expected to receive meaningful audio data")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// computeRMSInt16 computes RMS of int16 samples (used by audio integrity tests).
|
||||
func computeRMSInt16(samples []int16) float64 {
|
||||
if len(samples) == 0 {
|
||||
return 0
|
||||
}
|
||||
var sum float64
|
||||
for _, s := range samples {
|
||||
v := float64(s)
|
||||
sum += v * v
|
||||
}
|
||||
return math.Sqrt(sum / float64(len(samples)))
|
||||
}
|
||||
269
tests/e2e/realtime_ws_test.go
Normal file
269
tests/e2e/realtime_ws_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package e2e_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// --- WebSocket test helpers ---
|
||||
|
||||
func connectWS(model string) *websocket.Conn {
|
||||
u := url.URL{
|
||||
Scheme: "ws",
|
||||
Host: fmt.Sprintf("127.0.0.1:%d", apiPort),
|
||||
Path: "/v1/realtime",
|
||||
RawQuery: "model=" + url.QueryEscape(model),
|
||||
}
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred(), "websocket dial failed")
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
func readServerEvent(conn *websocket.Conn, timeout time.Duration) map[string]any {
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred(), "read server event")
|
||||
var evt map[string]any
|
||||
ExpectWithOffset(1, json.Unmarshal(msg, &evt)).To(Succeed())
|
||||
return evt
|
||||
}
|
||||
|
||||
func sendClientEvent(conn *websocket.Conn, event any) {
|
||||
data, err := json.Marshal(event)
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, conn.WriteMessage(websocket.TextMessage, data)).To(Succeed())
|
||||
}
|
||||
|
||||
// drainUntil reads events until it finds one with the given type, or times out.
|
||||
func drainUntil(conn *websocket.Conn, eventType string, timeout time.Duration) map[string]any {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
evt := readServerEvent(conn, time.Until(deadline))
|
||||
if evt["type"] == eventType {
|
||||
return evt
|
||||
}
|
||||
}
|
||||
Fail("timed out waiting for event: " + eventType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// generatePCMBase64 creates base64-encoded 16-bit LE PCM of a sine wave.
|
||||
func generatePCMBase64(freq float64, sampleRate, durationMs int) string {
|
||||
numSamples := sampleRate * durationMs / 1000
|
||||
pcm := make([]byte, numSamples*2)
|
||||
for i := range numSamples {
|
||||
t := float64(i) / float64(sampleRate)
|
||||
sample := int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t))
|
||||
pcm[2*i] = byte(sample)
|
||||
pcm[2*i+1] = byte(sample >> 8)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(pcm)
|
||||
}
|
||||
|
||||
// pipelineModel returns the model name to use for realtime tests.
|
||||
func pipelineModel() string {
|
||||
if m := os.Getenv("REALTIME_TEST_MODEL"); m != "" {
|
||||
return m
|
||||
}
|
||||
return "realtime-pipeline"
|
||||
}
|
||||
|
||||
// disableVADEvent returns a session.update event that disables server VAD.
|
||||
func disableVADEvent() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "session.update",
|
||||
"session": map[string]any{
|
||||
"audio": map[string]any{
|
||||
"input": map[string]any{
|
||||
"turn_detection": nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests ---

// Realtime WebSocket API integration suite.
//
// NOTE(review): connectWS, readServerEvent, sendClientEvent and drainUntil
// are helpers defined elsewhere in this file; a running LocalAI server is
// presumably required for these tests — confirm against the suite setup.
var _ = Describe("Realtime WebSocket API", Label("Realtime"), func() {
	Context("Session management", func() {
		It("should return session.created on connect", func() {
			conn := connectWS(pipelineModel())
			defer conn.Close()

			// The server must greet every new connection with session.created.
			evt := readServerEvent(conn, 30*time.Second)
			Expect(evt["type"]).To(Equal("session.created"))

			// The event carries a session object with a non-empty id.
			session, ok := evt["session"].(map[string]any)
			Expect(ok).To(BeTrue(), "session field should be an object")
			Expect(session["id"]).ToNot(BeEmpty())
		})

		It("should return session.updated after session.update", func() {
			conn := connectWS(pipelineModel())
			defer conn.Close()

			// Read session.created
			created := readServerEvent(conn, 30*time.Second)
			Expect(created["type"]).To(Equal("session.created"))

			// Send session.update to disable VAD
			sendClientEvent(conn, disableVADEvent())

			// The server should acknowledge the update with session.updated.
			evt := drainUntil(conn, "session.updated", 10*time.Second)
			Expect(evt["type"]).To(Equal("session.updated"))
		})
	})

	Context("Manual audio commit", func() {
		It("should produce a response with audio when audio is committed", func() {
			conn := connectWS(pipelineModel())
			defer conn.Close()

			// Read session.created
			created := readServerEvent(conn, 30*time.Second)
			Expect(created["type"]).To(Equal("session.created"))

			// Disable server VAD so we can manually commit
			sendClientEvent(conn, disableVADEvent())
			drainUntil(conn, "session.updated", 10*time.Second)

			// Append 1 second of 440Hz sine wave at 24kHz (the default remote sample rate)
			audio := generatePCMBase64(440, 24000, 1000)
			sendClientEvent(conn, map[string]any{
				"type":  "input_audio_buffer.append",
				"audio": audio,
			})

			// Commit the audio buffer
			sendClientEvent(conn, map[string]any{
				"type": "input_audio_buffer.commit",
			})

			// We should receive the response event sequence.
			// The exact events depend on the pipeline, but we expect at least:
			// - input_audio_buffer.committed
			// - conversation.item.input_audio_transcription.completed
			// - response.output_audio.delta (with base64 audio)
			// - response.done

			committed := drainUntil(conn, "input_audio_buffer.committed", 30*time.Second)
			Expect(committed).ToNot(BeNil())

			// Wait for the full response cycle to complete
			done := drainUntil(conn, "response.done", 60*time.Second)
			Expect(done).ToNot(BeNil())
		})
	})

	Context("Text conversation item", func() {
		It("should create a text item and trigger a response", func() {
			conn := connectWS(pipelineModel())
			defer conn.Close()

			// Read session.created
			created := readServerEvent(conn, 30*time.Second)
			Expect(created["type"]).To(Equal("session.created"))

			// Disable VAD
			sendClientEvent(conn, disableVADEvent())
			drainUntil(conn, "session.updated", 10*time.Second)

			// Create a text conversation item
			sendClientEvent(conn, map[string]any{
				"type": "conversation.item.create",
				"item": map[string]any{
					"type": "message",
					"role": "user",
					"content": []map[string]any{
						{
							"type": "input_text",
							"text": "Hello, how are you?",
						},
					},
				},
			})

			// Wait for item to be added
			added := drainUntil(conn, "conversation.item.added", 10*time.Second)
			Expect(added).ToNot(BeNil())

			// Trigger a response
			sendClientEvent(conn, map[string]any{
				"type": "response.create",
			})

			// Wait for response to complete
			done := drainUntil(conn, "response.done", 60*time.Second)
			Expect(done).ToNot(BeNil())
		})
	})

	Context("Audio integrity", func() {
		// Only runs against a real model; skipped unless REALTIME_TEST_MODEL is set.
		It("should return non-empty audio data in response.output_audio.delta", Label("real-models"), func() {
			if os.Getenv("REALTIME_TEST_MODEL") == "" {
				Skip("REALTIME_TEST_MODEL not set")
			}

			conn := connectWS(pipelineModel())
			defer conn.Close()

			created := readServerEvent(conn, 30*time.Second)
			Expect(created["type"]).To(Equal("session.created"))

			// Disable VAD
			sendClientEvent(conn, disableVADEvent())
			drainUntil(conn, "session.updated", 10*time.Second)

			// Create a text item and trigger response
			sendClientEvent(conn, map[string]any{
				"type": "conversation.item.create",
				"item": map[string]any{
					"type": "message",
					"role": "user",
					"content": []map[string]any{
						{
							"type": "input_text",
							"text": "Say hello",
						},
					},
				},
			})
			drainUntil(conn, "conversation.item.added", 10*time.Second)

			sendClientEvent(conn, map[string]any{
				"type": "response.create",
			})

			// Collect audio deltas: sum decoded payload sizes until the
			// response completes or the overall deadline expires.
			var totalAudioBytes int
			deadline := time.Now().Add(60 * time.Second)
			for time.Now().Before(deadline) {
				evt := readServerEvent(conn, time.Until(deadline))
				if evt["type"] == "response.output_audio.delta" {
					if delta, ok := evt["delta"].(string); ok {
						// Every delta payload must be valid base64.
						decoded, err := base64.StdEncoding.DecodeString(delta)
						Expect(err).ToNot(HaveOccurred())
						totalAudioBytes += len(decoded)
					}
				}
				if evt["type"] == "response.done" {
					break
				}
			}

			Expect(totalAudioBytes).To(BeNumerically(">", 0), "expected non-empty audio in response")
		})
	})
})
|
||||
Reference in New Issue
Block a user