feat(nemo): enable word-level timestamps for ASR models (#10297)

* feat(nemo): enable word-level timestamps for ASR models

The nemo backend ignored timestamp_granularities and always returned a
single segment with start=0 end=0, making word-level timestamps
impossible to obtain even though the NeMo models (parakeet-tdt, etc.)
fully support them.

Changes:
- Add _get_stride_seconds() to compute frame duration from the model's
  preprocessor window_stride and encoder subsampling_factor.
- Add _build_segments_with_words() that extracts word offsets from the
  NeMo Hypothesis.timestamp dict and converts frame indices to
  nanosecond timestamps.
- Support 'word' granularity (one segment per word) and 'segment'
  granularity (merge at time-gap boundaries using a dynamic threshold).
- Populate TranscriptSegment.words with TranscriptWord entries so
  callers get both segment-level and word-level timing.
- Only request timestamps from NeMo when the caller actually asks for
  them (timestamp_granularities is non-empty), keeping the fast path
  unchanged for callers that don't need timestamps.

Tested with nvidia/parakeet-tdt-0.6b-v3 on the JFK "ask not" clip:
  curl -X POST /v1/audio/transcriptions \
    -F file=@jfk.wav -F model=nemo-parakeet-tdt-0.6b \
    -F 'timestamp_granularities[]=word' -F response_format=verbose_json
  → each word has correct start/end times in seconds.

Signed-off-by: fqscfqj <fqscfqj@outlook.com>

* fix(nemo): address Copilot review feedback

- Narrow exception handling in _get_stride_seconds to catch only
  AttributeError, KeyError, TypeError instead of bare Exception, and
  emit a warning when falling back to the hardcoded stride.
- Remove explicit return_hypotheses=False when timestamps are requested;
  timestamps=True already forces NeMo to return Hypothesis objects.
- Add a warning when NeMo does not return Hypothesis objects despite
  timestamps being requested.

Signed-off-by: fqscfqj <fqscfqj@outlook.com>

---------

Signed-off-by: fqscfqj <fqscfqj@outlook.com>
This commit is contained in:
番茄摔成番茄酱
2026-06-21 23:04:19 +08:00
committed by GitHub
parent cf7f9573a2
commit 01fa12e0de

View File

@@ -84,6 +84,135 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Model loaded successfully", success=True)
def _get_stride_seconds(self):
    """Return the duration, in seconds, of one encoder output frame.

    The value is the product of the preprocessor's ``window_stride``
    (0.01 when the key is absent) and the encoder's ``subsampling_factor``
    (8 when the attribute is absent).

    If the model objects cannot be read or multiplied (AttributeError,
    KeyError or TypeError), a warning is written to stderr and 0.08 is
    returned instead.
    """
    try:
        feature_cfg = self.model.preprocessor._cfg
        hop_seconds = feature_cfg.get('window_stride', 0.01)
        downsampling = getattr(self.model.encoder, 'subsampling_factor', 8)
        # Multiply inside the try so a non-numeric config value also
        # triggers the fallback below.
        seconds_per_frame = hop_seconds * downsampling
    except (AttributeError, KeyError, TypeError) as err:
        print(
            f"Warning: could not compute stride from model config ({err}), "
            f"falling back to 0.08s/frame",
            file=sys.stderr,
        )
        seconds_per_frame = 0.08
    return seconds_per_frame
def _build_segments_with_words(self, hypothesis, stride, timestamp_granularities=None):
    """Build TranscriptSegment messages from a NeMo hypothesis with word timings.

    Args:
        hypothesis: Transcription result for one audio file. It is expected
            to expose a ``timestamp`` dict whose ``'word'`` entry is a list
            of dicts with ``word``, ``start_offset`` and ``end_offset``
            keys, the offsets being frame indices. Any other value
            (``None``, a plain string, an object without a ``timestamp``
            dict) yields an empty list instead of raising.
        stride: Seconds per frame, used to convert frame offsets to time.
        timestamp_granularities: Optional iterable of requested
            granularities. If it contains ``"word"``, one segment is
            emitted per word; otherwise consecutive words are merged into
            segments.

    Returns:
        A list of ``backend_pb2.TranscriptSegment`` whose ``start``/``end``
        are integer nanoseconds and whose ``words`` hold the matching
        ``backend_pb2.TranscriptWord`` entries. Empty when there are no
        usable word offsets.

    In segment mode a new segment starts whenever the gap between one
    word's end and the next word's start reaches a threshold: three times
    the median positive inter-word gap (upper middle element), clamped to
    [0.3, 2.0] seconds, or 0.5 seconds when no positive gap exists.
    """
    ns_per_second = 1_000_000_000

    # A plain-string result has no ``timestamp`` attribute; read it
    # defensively so the caller can fall back to untimestamped output.
    timestamp = getattr(hypothesis, 'timestamp', None)
    if not hypothesis or not isinstance(timestamp, dict):
        return []

    word_offsets = timestamp.get('word', [])
    if not word_offsets:
        return []

    # Flat list of (text, start_ns, end_ns); entries with empty text are dropped.
    timed_words = []
    for offset in word_offsets:
        word_text = offset.get('word', '')
        if not word_text:
            continue
        start_offset = offset.get('start_offset', 0)
        end_offset = offset.get('end_offset', start_offset)
        timed_words.append((
            word_text,
            int(start_offset * stride * ns_per_second),
            int(end_offset * stride * ns_per_second),
        ))
    if not timed_words:
        return []

    def make_segment(segment_id, members):
        # A segment spans from its first word's start to its last word's end.
        return backend_pb2.TranscriptSegment(
            id=segment_id,
            start=members[0][1],
            end=members[-1][2],
            text=' '.join(text for text, _, _ in members),
            words=[
                backend_pb2.TranscriptWord(start=start_ns, end=end_ns, text=text)
                for text, start_ns, end_ns in members
            ],
        )

    granularities = list(timestamp_granularities) if timestamp_granularities else []
    if "word" in granularities:
        # One segment per word.
        return [make_segment(idx, [word]) for idx, word in enumerate(timed_words)]

    # Segment mode: derive the split threshold from the positive gaps only,
    # so overlapping or back-to-back words do not drag the median to zero.
    consecutive = list(zip(timed_words, timed_words[1:]))
    positive_gaps = sorted(
        gap
        for gap in (
            (cur_start - prev_end) / ns_per_second
            for (_, _, prev_end), (_, cur_start, _) in consecutive
        )
        if gap > 0
    )
    if positive_gaps:
        median_gap = positive_gaps[len(positive_gaps) // 2]
        threshold_ns = int(max(0.3, min(median_gap * 3, 2.0)) * ns_per_second)
    else:
        threshold_ns = int(0.5 * ns_per_second)

    # Walk the words once, opening a new group at every large-enough gap.
    groups = [[timed_words[0]]]
    for previous, current in consecutive:
        if current[1] - previous[2] >= threshold_ns:
            groups.append([])
        groups[-1].append(current)

    return [make_segment(idx, members) for idx, members in enumerate(groups)]
def AudioTranscription(self, request, context):
result_segments = []
text = ""
@@ -93,26 +222,67 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
return backend_pb2.TranscriptResult(segments=[], text="")
# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
results = self.model.transcribe([audio_path])
# Determine requested timestamp granularity
timestamp_granularities = list(request.timestamp_granularities) if request.timestamp_granularities else []
want_timestamps = bool(timestamp_granularities)
if not results or len(results) == 0:
return backend_pb2.TranscriptResult(segments=[], text="")
if want_timestamps:
# Request timestamps from NeMo.
# timestamps=True forces NeMo to return Hypothesis objects with
# the timestamp dict populated, so we omit return_hypotheses to
# let NeMo choose the correct return type.
results = self.model.transcribe([audio_path], timestamps=True)
# Get the transcript text from the first result.
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
# where the actual text lives in Hypothesis.text.
result = results[0]
if isinstance(result, str):
text = result
if results and len(results) > 0:
hypotheses = results[0] if isinstance(results[0], list) else results
if hypotheses and len(hypotheses) > 0:
hypothesis = hypotheses[0]
# Hypothesis object should have .timestamp populated
if not hasattr(hypothesis, 'timestamp') or not isinstance(hypothesis.timestamp, dict):
print(
"Warning: timestamps were requested but NeMo did not return "
"Hypothesis objects; falling back to untimestamped output",
file=sys.stderr,
)
# Extract text
if hasattr(hypothesis, 'text'):
text = hypothesis.text or ""
elif isinstance(hypothesis, str):
text = hypothesis
# Build segments with word-level timestamps
stride = self._get_stride_seconds()
result_segments = self._build_segments_with_words(
hypothesis, stride, timestamp_granularities
)
# If no word offsets but we have text, fall back to single segment
if not result_segments and text:
result_segments.append(backend_pb2.TranscriptSegment(
id=0, start=0, end=0, text=text
))
else:
text = getattr(result, 'text', None) or ""
# Simple transcription without timestamps
# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
results = self.model.transcribe([audio_path])
if text:
# Create a single segment with the full transcription
result_segments.append(backend_pb2.TranscriptSegment(
id=0, start=0, end=0, text=text
))
if results and len(results) > 0:
# Get the transcript text from the first result.
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
# where the actual text lives in Hypothesis.text.
result = results[0]
if isinstance(result, str):
text = result
else:
text = getattr(result, 'text', None) or ""
if text:
# Create a single segment with the full transcription
result_segments.append(backend_pb2.TranscriptSegment(
id=0, start=0, end=0, text=text
))
except Exception as err:
print(f"Error in AudioTranscription: {err}", file=sys.stderr)