mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-21 23:29:04 -04:00
feat(nemo): enable word-level timestamps for ASR models (#10297)
* feat(nemo): enable word-level timestamps for ASR models
The nemo backend ignored timestamp_granularities and always returned a
single segment with start=0 end=0, making word-level timestamps
impossible to obtain even though the NeMo models (parakeet-tdt, etc.)
fully support them.
Changes:
- Add _get_stride_seconds() to compute frame duration from the model's
preprocessor window_stride and encoder subsampling_factor.
- Add _build_segments_with_words() that extracts word offsets from the
NeMo Hypothesis.timestamp dict and converts frame indices to
nanosecond timestamps.
- Support 'word' granularity (one segment per word) and 'segment'
granularity (merge at time-gap boundaries using a dynamic threshold).
- Populate TranscriptSegment.words with TranscriptWord entries so
callers get both segment-level and word-level timing.
- Only request timestamps from NeMo when the caller actually asks for
them (timestamp_granularities is non-empty), keeping the fast path
unchanged for callers that don't need timestamps.
Tested with nvidia/parakeet-tdt-0.6b-v3 on the JFK "ask not" clip:
curl -X POST /v1/audio/transcriptions \
-F file=@jfk.wav -F model=nemo-parakeet-tdt-0.6b \
-F 'timestamp_granularities[]=word' -F response_format=verbose_json
→ each word has correct start/end times in seconds.
Signed-off-by: fqscfqj <fqscfqj@outlook.com>
* fix(nemo): address Copilot review feedback
- Narrow exception handling in _get_stride_seconds to catch only
AttributeError, KeyError, TypeError instead of bare Exception, and
emit a warning when falling back to the hardcoded stride.
- Remove explicit return_hypotheses=False when timestamps are requested;
timestamps=True already forces NeMo to return Hypothesis objects.
- Add a warning when NeMo does not return Hypothesis objects despite
timestamps being requested.
Signed-off-by: fqscfqj <fqscfqj@outlook.com>
---------
Signed-off-by: fqscfqj <fqscfqj@outlook.com>
This commit is contained in:
@@ -84,6 +84,135 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _get_stride_seconds(self):
    """Return the duration, in seconds, of one encoder output frame.

    The value is the preprocessor config's ``window_stride`` (0.01 when
    the key is absent) multiplied by the encoder's ``subsampling_factor``
    (8 when the attribute is absent).

    If the loaded model does not expose these fields in the expected way,
    a warning is printed to stderr and 0.08 is returned instead.
    """
    try:
        frame_hop = self.model.preprocessor._cfg.get('window_stride', 0.01)
        downsample = getattr(self.model.encoder, 'subsampling_factor', 8)
        # The product stays inside the try so that a non-numeric config
        # value (TypeError) also takes the fallback path.
        seconds_per_frame = frame_hop * downsample
    except (AttributeError, KeyError, TypeError) as err:
        print(
            f"Warning: could not compute stride from model config ({err}), "
            f"falling back to 0.08s/frame",
            file=sys.stderr,
        )
        seconds_per_frame = 0.08
    return seconds_per_frame
|
||||
|
||||
def _build_segments_with_words(self, hypothesis, stride, timestamp_granularities=None):
    """Build a list of TranscriptSegment messages from a NeMo hypothesis.

    Two granularity modes are supported:
      - "word": one TranscriptSegment per word, each carrying a single
        TranscriptWord entry.
      - "segment" (default, used whenever "word" is not requested):
        consecutive words are merged into one segment, and a new segment
        is started wherever the gap between two words reaches a dynamic
        threshold.

    Args:
        hypothesis: Object whose ``timestamp`` attribute is a dict holding
            a ``'word'`` list of dicts with ``'word'``, ``'start_offset'``
            and ``'end_offset'`` keys. Anything else (None, a plain
            string, an object without a dict ``timestamp``) yields ``[]``.
        stride: Multiplier that converts an offset into seconds.
        timestamp_granularities: Optional iterable of requested
            granularity names.

    Returns:
        A list of ``backend_pb2.TranscriptSegment`` with start/end values
        in nanoseconds, or an empty list when no usable word offsets are
        available.
    """
    if not hypothesis:
        return []

    # Read the attribute defensively: the caller may hand over a result
    # that has no ``timestamp`` attribute at all (for example a plain
    # string), and the contract for that case is "no segments", not an
    # AttributeError.
    timestamps = getattr(hypothesis, 'timestamp', None)
    if not isinstance(timestamps, dict):
        return []

    word_offsets = timestamps.get('word', [])
    if not word_offsets:
        return []

    requested = list(timestamp_granularities) if timestamp_granularities else []
    granularity = "word" if "word" in requested else "segment"

    ns_per_second = 1_000_000_000

    # Flat list of {'text', 'start', 'end'} with times in nanoseconds;
    # entries with empty text are dropped.
    timed_words = []
    for offset in word_offsets:
        word_text = offset.get('word', '')
        if not word_text:
            continue
        start_offset = offset.get('start_offset', 0)
        # A missing end offset collapses the word to zero duration.
        end_offset = offset.get('end_offset', start_offset)
        timed_words.append({
            'text': word_text,
            'start': int(start_offset * stride * ns_per_second),
            'end': int(end_offset * stride * ns_per_second),
        })

    if not timed_words:
        return []

    def make_segment(seg_id, members):
        # A segment spans from its first word's start to its last word's end.
        return backend_pb2.TranscriptSegment(
            id=seg_id,
            start=members[0]['start'],
            end=members[-1]['end'],
            text=' '.join(member['text'] for member in members),
            words=[
                backend_pb2.TranscriptWord(
                    start=member['start'], end=member['end'], text=member['text']
                )
                for member in members
            ],
        )

    if granularity == "word":
        return [make_segment(idx, [tw]) for idx, tw in enumerate(timed_words)]

    # Segment mode. Gap between each word and its predecessor, in ns.
    neighbours = list(zip(timed_words, timed_words[1:]))
    positive_gaps = sorted(
        gap_s
        for gap_s in (
            (cur['start'] - prev['end']) / ns_per_second for prev, cur in neighbours
        )
        if gap_s > 0
    )
    if positive_gaps:
        # Threshold: 3x the median positive gap, clamped to [0.3, 2.0] seconds.
        median_gap = positive_gaps[len(positive_gaps) // 2]
        threshold_ns = int(max(0.3, min(median_gap * 3, 2.0)) * ns_per_second)
    else:
        threshold_ns = int(0.5 * ns_per_second)

    # Start a new group whenever the pause before a word reaches the threshold.
    groups = [[timed_words[0]]]
    for prev, cur in neighbours:
        if cur['start'] - prev['end'] >= threshold_ns:
            groups.append([cur])
        else:
            groups[-1].append(cur)

    return [make_segment(idx, members) for idx, members in enumerate(groups)]
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
result_segments = []
|
||||
text = ""
|
||||
@@ -93,26 +222,67 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
|
||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
|
||||
# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
|
||||
results = self.model.transcribe([audio_path])
|
||||
# Determine requested timestamp granularity
|
||||
timestamp_granularities = list(request.timestamp_granularities) if request.timestamp_granularities else []
|
||||
want_timestamps = bool(timestamp_granularities)
|
||||
|
||||
if not results or len(results) == 0:
|
||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
if want_timestamps:
|
||||
# Request timestamps from NeMo.
|
||||
# timestamps=True forces NeMo to return Hypothesis objects with
|
||||
# the timestamp dict populated, so we omit return_hypotheses to
|
||||
# let NeMo choose the correct return type.
|
||||
results = self.model.transcribe([audio_path], timestamps=True)
|
||||
|
||||
# Get the transcript text from the first result.
|
||||
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
|
||||
# where the actual text lives in Hypothesis.text.
|
||||
result = results[0]
|
||||
if isinstance(result, str):
|
||||
text = result
|
||||
if results and len(results) > 0:
|
||||
hypotheses = results[0] if isinstance(results[0], list) else results
|
||||
if hypotheses and len(hypotheses) > 0:
|
||||
hypothesis = hypotheses[0]
|
||||
|
||||
# Hypothesis object should have .timestamp populated
|
||||
if not hasattr(hypothesis, 'timestamp') or not isinstance(hypothesis.timestamp, dict):
|
||||
print(
|
||||
"Warning: timestamps were requested but NeMo did not return "
|
||||
"Hypothesis objects; falling back to untimestamped output",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Extract text
|
||||
if hasattr(hypothesis, 'text'):
|
||||
text = hypothesis.text or ""
|
||||
elif isinstance(hypothesis, str):
|
||||
text = hypothesis
|
||||
|
||||
# Build segments with word-level timestamps
|
||||
stride = self._get_stride_seconds()
|
||||
result_segments = self._build_segments_with_words(
|
||||
hypothesis, stride, timestamp_granularities
|
||||
)
|
||||
|
||||
# If no word offsets but we have text, fall back to single segment
|
||||
if not result_segments and text:
|
||||
result_segments.append(backend_pb2.TranscriptSegment(
|
||||
id=0, start=0, end=0, text=text
|
||||
))
|
||||
else:
|
||||
text = getattr(result, 'text', None) or ""
|
||||
# Simple transcription without timestamps
|
||||
# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
|
||||
results = self.model.transcribe([audio_path])
|
||||
|
||||
if text:
|
||||
# Create a single segment with the full transcription
|
||||
result_segments.append(backend_pb2.TranscriptSegment(
|
||||
id=0, start=0, end=0, text=text
|
||||
))
|
||||
if results and len(results) > 0:
|
||||
# Get the transcript text from the first result.
|
||||
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
|
||||
# where the actual text lives in Hypothesis.text.
|
||||
result = results[0]
|
||||
if isinstance(result, str):
|
||||
text = result
|
||||
else:
|
||||
text = getattr(result, 'text', None) or ""
|
||||
|
||||
if text:
|
||||
# Create a single segment with the full transcription
|
||||
result_segments.append(backend_pb2.TranscriptSegment(
|
||||
id=0, start=0, end=0, text=text
|
||||
))
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error in AudioTranscription: {err}", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user