Files
LocalAI/backend/python/nemo/backend.py
番茄摔成番茄酱 01fa12e0de feat(nemo): enable word-level timestamps for ASR models (#10297)
* feat(nemo): enable word-level timestamps for ASR models

The nemo backend ignored timestamp_granularities and always returned a
single segment with start=0 end=0, making word-level timestamps
impossible to obtain even though the NeMo models (parakeet-tdt, etc.)
fully support them.

Changes:
- Add _get_stride_seconds() to compute frame duration from the model's
  preprocessor window_stride and encoder subsampling_factor.
- Add _build_segments_with_words() that extracts word offsets from the
  NeMo Hypothesis.timestamp dict and converts frame indices to
  nanosecond timestamps.
- Support 'word' granularity (one segment per word) and 'segment'
  granularity (merge at time-gap boundaries using a dynamic threshold).
- Populate TranscriptSegment.words with TranscriptWord entries so
  callers get both segment-level and word-level timing.
- Only request timestamps from NeMo when the caller actually asks for
  them (timestamp_granularities is non-empty), keeping the fast path
  unchanged for callers that don't need timestamps.

Tested with nvidia/parakeet-tdt-0.6b-v3 on the JFK "ask not" clip:
  curl -X POST /v1/audio/transcriptions \
    -F file=@jfk.wav -F model=nemo-parakeet-tdt-0.6b \
    -F 'timestamp_granularities[]=word' -F response_format=verbose_json
  → each word has correct start/end times in seconds.

Signed-off-by: fqscfqj <fqscfqj@outlook.com>

* fix(nemo): address Copilot review feedback

- Narrow exception handling in _get_stride_seconds to catch only
  AttributeError, KeyError, TypeError instead of bare Exception, and
  emit a warning when falling back to the hardcoded stride.
- Remove explicit return_hypotheses=False when timestamps are requested;
  timestamps=True already forces NeMo to return Hypothesis objects.
- Add a warning when NeMo does not return Hypothesis objects despite
  timestamps being requested.

Signed-off-by: fqscfqj <fqscfqj@outlook.com>

---------

Signed-off-by: fqscfqj <fqscfqj@outlook.com>
2026-06-21 17:04:19 +02:00

331 lines
12 KiB
Python

#!/usr/bin/env python3
"""
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc
import torch
import nemo.collections.asr as nemo_asr
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
try:
float(s)
return True
except ValueError:
return False
def is_int(s):
try:
int(s)
return True
except ValueError:
return False
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
if mps_available:
device = "mps"
if not torch.cuda.is_available() and request.CUDA:
return backend_pb2.Result(success=False, message="CUDA is not available")
self.device = device
self.options = {}
for opt in request.Options:
if ":" not in opt:
continue
key, value = opt.split(":", 1)
if is_float(value):
value = float(value)
elif is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
self.options[key] = value
model_name = request.Model or "nvidia/parakeet-tdt-0.6b-v3"
try:
print(f"Loading NEMO ASR model from {model_name}", file=sys.stderr)
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
print("NEMO ASR model loaded successfully", file=sys.stderr)
except Exception as err:
print(f"[ERROR] LoadModel failed: {err}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
return backend_pb2.Result(success=False, message=str(err))
return backend_pb2.Result(message="Model loaded successfully", success=True)
def _get_stride_seconds(self):
"""Compute the seconds-per-frame stride for the loaded model.
stride = preprocessor_window_stride * encoder_subsampling_factor
"""
try:
preprocessor = self.model.preprocessor
window_stride = preprocessor._cfg.get('window_stride', 0.01)
subsampling_factor = getattr(self.model.encoder, 'subsampling_factor', 8)
return window_stride * subsampling_factor
except (AttributeError, KeyError, TypeError) as err:
print(
f"Warning: could not compute stride from model config ({err}), "
f"falling back to 0.08s/frame",
file=sys.stderr,
)
return 0.08
def _build_segments_with_words(self, hypothesis, stride, timestamp_granularities=None):
"""Build TranscriptSegment list from a NeMo Hypothesis with timestamps.
Supports two granularity modes:
- "word": one TranscriptSegment per word, each with a single TranscriptWord entry
- "segment" (default): merge consecutive words into sentence-level segments,
splitting at word-level time gaps that exceed a dynamic threshold.
"""
if not hypothesis or not isinstance(hypothesis.timestamp, dict):
return []
word_offsets = hypothesis.timestamp.get('word', [])
if not word_offsets:
return []
granularities = list(timestamp_granularities) if timestamp_granularities else []
granularity = "word" if "word" in granularities else "segment"
# Build a flat list of (text, start_ns, end_ns) from NeMo word offsets
transcript_words = []
for wo in word_offsets:
word_text = wo.get('word', '')
if not word_text:
continue
start_offset = wo.get('start_offset', 0)
end_offset = wo.get('end_offset', start_offset)
start_ns = int(start_offset * stride * 1_000_000_000)
end_ns = int(end_offset * stride * 1_000_000_000)
transcript_words.append({
'text': word_text,
'start': start_ns,
'end': end_ns,
})
if not transcript_words:
return []
if granularity == "word":
# One segment per word
result = []
for idx, tw in enumerate(transcript_words):
word = backend_pb2.TranscriptWord(
start=tw['start'], end=tw['end'], text=tw['text']
)
result.append(backend_pb2.TranscriptSegment(
id=idx,
start=tw['start'],
end=tw['end'],
text=tw['text'],
words=[word],
))
return result
# segment mode — merge at word-level time-gap boundaries
# Compute gap threshold: median inter-word gap * 3, clamped to [0.3, 2.0]s
gaps = []
for i in range(1, len(transcript_words)):
gap = (transcript_words[i]['start'] - transcript_words[i - 1]['end']) / 1_000_000_000
if gap > 0:
gaps.append(gap)
if gaps:
gaps.sort()
median_gap = gaps[len(gaps) // 2]
threshold_ns = int(max(0.3, min(median_gap * 3, 2.0)) * 1_000_000_000)
else:
threshold_ns = int(0.5 * 1_000_000_000)
result = []
buf_words = [] # list of TranscriptWord protobuf
buf_start = None
buf_end = 0
buf_text = []
prev_end = None
for tw in transcript_words:
# Detect word-level time gap
if prev_end is not None and (tw['start'] - prev_end) >= threshold_ns and buf_text:
seg_text = ' '.join(buf_text)
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=buf_start,
end=buf_end,
text=seg_text,
words=list(buf_words),
))
buf_words = []
buf_text = []
buf_start = None
if buf_start is None:
buf_start = tw['start']
buf_end = tw['end']
buf_text.append(tw['text'])
buf_words.append(backend_pb2.TranscriptWord(
start=tw['start'], end=tw['end'], text=tw['text']
))
prev_end = tw['end']
# flush remaining
if buf_text and buf_start is not None:
seg_text = ' '.join(buf_text)
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=buf_start,
end=buf_end,
text=seg_text,
words=list(buf_words),
))
return result
def AudioTranscription(self, request, context):
result_segments = []
text = ""
try:
audio_path = request.dst
if not audio_path or not os.path.exists(audio_path):
print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
return backend_pb2.TranscriptResult(segments=[], text="")
# Determine requested timestamp granularity
timestamp_granularities = list(request.timestamp_granularities) if request.timestamp_granularities else []
want_timestamps = bool(timestamp_granularities)
if want_timestamps:
# Request timestamps from NeMo.
# timestamps=True forces NeMo to return Hypothesis objects with
# the timestamp dict populated, so we omit return_hypotheses to
# let NeMo choose the correct return type.
results = self.model.transcribe([audio_path], timestamps=True)
if results and len(results) > 0:
hypotheses = results[0] if isinstance(results[0], list) else results
if hypotheses and len(hypotheses) > 0:
hypothesis = hypotheses[0]
# Hypothesis object should have .timestamp populated
if not hasattr(hypothesis, 'timestamp') or not isinstance(hypothesis.timestamp, dict):
print(
"Warning: timestamps were requested but NeMo did not return "
"Hypothesis objects; falling back to untimestamped output",
file=sys.stderr,
)
# Extract text
if hasattr(hypothesis, 'text'):
text = hypothesis.text or ""
elif isinstance(hypothesis, str):
text = hypothesis
# Build segments with word-level timestamps
stride = self._get_stride_seconds()
result_segments = self._build_segments_with_words(
hypothesis, stride, timestamp_granularities
)
# If no word offsets but we have text, fall back to single segment
if not result_segments and text:
result_segments.append(backend_pb2.TranscriptSegment(
id=0, start=0, end=0, text=text
))
else:
# Simple transcription without timestamps
# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
results = self.model.transcribe([audio_path])
if results and len(results) > 0:
# Get the transcript text from the first result.
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
# where the actual text lives in Hypothesis.text.
result = results[0]
if isinstance(result, str):
text = result
else:
text = getattr(result, 'text', None) or ""
if text:
# Create a single segment with the full transcription
result_segments.append(backend_pb2.TranscriptSegment(
id=0, start=0, end=0, text=text
))
except Exception as err:
print(f"Error in AudioTranscription: {err}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
return backend_pb2.TranscriptResult(segments=[], text="")
return backend_pb2.TranscriptResult(segments=result_segments, text=text)
def serve(address):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
args = parser.parse_args()
serve(args.addr)