Mirror of https://github.com/blakeblackshear/frigate.git (synced 2026-02-20 07:46:46 -05:00)

Compare commits: llama.cpp-... → ai-chat

25 Commits
aec7c15b6f, a5755c806b, ce33ae0bbc, 361dcc94c8, d4630c62ca,
994a5acc52, 2ddd55c470, 7f69233f71, 7551332c01, 76409f79e0,
730bf3c0b7, 0371f55321, 858367c98a, e09e9a0b7a, 3b3edc481b,
b0bcf45245, 910122913a, 99e97850c9, 59a38aa67c, 9125eff794,
e934910616, a21cabae2d, 79d7d20866, e7e806b135, f29adca9ce
@@ -76,40 +76,6 @@ Switching between V1 and V2 requires reindexing your embeddings. The embeddings

:::

### GenAI Provider (llama.cpp)

Frigate can use a GenAI provider for semantic search embeddings when that provider has the `embeddings` role. Currently, only **llama.cpp** supports multimodal embeddings (both text and images).

To use llama.cpp for semantic search:

1. Configure a GenAI provider in your config with `embeddings` in its `roles`.
2. Set `semantic_search.model` to the GenAI config key (e.g. `default`).
3. Start the llama.cpp server with `--embeddings` and `--mmproj` for image support:

```yaml
genai:
  default:
    provider: llamacpp
    base_url: http://localhost:8080
    model: your-model-name
    roles:
      - embeddings
      - vision
      - tools

semantic_search:
  enabled: True
  model: default
```

The llama.cpp server must be started with `--embeddings` for the embeddings API, and `--mmproj <mmproj.gguf>` when using image embeddings. See the [llama.cpp server documentation](https://github.com/ggml-org/llama.cpp/blob/master/tools/server/README.md) for details.

:::note

Switching between Jina models and a GenAI provider requires reindexing. Embeddings from different backends are incompatible.

:::
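For reference, a minimal sketch of the raw embeddings request Frigate sends to llama.cpp (matching the client implementation further down in this diff): text inputs use `prompt_string`, and images are sent as base64 `multimodal_data` with a `<__media__>` placeholder. The base URL and image file below are placeholders.

```python
import base64

import requests

BASE_URL = "http://localhost:8080"  # placeholder: llama.cpp server started with --embeddings --mmproj

with open("thumbnail.jpg", "rb") as f:  # placeholder image; llama.cpp expects JPEG, not WebP
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "content": [
        {"prompt_string": "person walking a dog"},  # text input
        {
            "prompt_string": "<__media__>\n",  # image placeholder token
            "multimodal_data": [image_b64],
        },
    ]
}

resp = requests.post(f"{BASE_URL}/embeddings", json=payload, timeout=30)
resp.raise_for_status()
result = resp.json()
# The server may return either a bare list or {"data": [...]}
items = result.get("data", result) if isinstance(result, dict) else result
for item in items:
    print(len(item["embedding"]))  # dimension depends on the loaded model; Frigate normalizes to 768
```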
### GPU Acceleration

The CLIP models are downloaded in ONNX format, and the `large` model can be accelerated using GPU hardware, when available. This depends on the Docker build that is used. You can also target a specific device in a multi-GPU installation.
@@ -3,12 +3,13 @@
import base64
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
import time
from datetime import datetime
from typing import Any, Dict, Generator, List, Optional

import cv2
from fastapi import APIRouter, Body, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from frigate.api.auth import (
@@ -20,15 +21,60 @@ from frigate.api.defs.request.chat_body import ChatCompletionRequest
|
||||
from frigate.api.defs.response.chat_response import (
|
||||
ChatCompletionResponse,
|
||||
ChatMessageResponse,
|
||||
ToolCall,
|
||||
)
|
||||
from frigate.api.defs.tags import Tags
|
||||
from frigate.api.event import events
|
||||
from frigate.genai.utils import build_assistant_message_for_conversation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=[Tags.chat])
|
||||
|
||||
|
||||
def _chunk_content(content: str, chunk_size: int = 80) -> Generator[str, None, None]:
|
||||
"""Yield content in word-aware chunks for streaming."""
|
||||
if not content:
|
||||
return
|
||||
words = content.split(" ")
|
||||
current: List[str] = []
|
||||
current_len = 0
|
||||
for w in words:
|
||||
current.append(w)
|
||||
current_len += len(w) + 1
|
||||
if current_len >= chunk_size:
|
||||
yield " ".join(current) + " "
|
||||
current = []
|
||||
current_len = 0
|
||||
if current:
|
||||
yield " ".join(current)
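As a quick illustration of the chunker (hypothetical input and a small `chunk_size` for readability), every yielded piece ends on a word boundary:

```python
# Hypothetical usage of _chunk_content; with the default chunk_size=80 the
# pieces are larger, but the word-boundary behavior is the same.
text = "a quick illustration of word aware chunking for streaming responses"
for part in _chunk_content(text, chunk_size=20):
    print(repr(part))
# 'a quick illustration '
# 'of word aware chunking '
# 'for streaming responses '
```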
|
||||
|
||||
|
||||
def _format_events_with_local_time(
|
||||
events_list: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Add human-readable local start/end times to each event for the LLM."""
|
||||
result = []
|
||||
for evt in events_list:
|
||||
if not isinstance(evt, dict):
|
||||
result.append(evt)
|
||||
continue
|
||||
copy_evt = dict(evt)
|
||||
try:
|
||||
start_ts = evt.get("start_time")
|
||||
end_ts = evt.get("end_time")
|
||||
if start_ts is not None:
|
||||
dt_start = datetime.fromtimestamp(start_ts)
|
||||
copy_evt["start_time_local"] = dt_start.strftime("%Y-%m-%d %I:%M:%S %p")
|
||||
if end_ts is not None:
|
||||
dt_end = datetime.fromtimestamp(end_ts)
|
||||
copy_evt["end_time_local"] = dt_end.strftime("%Y-%m-%d %I:%M:%S %p")
|
||||
except (TypeError, ValueError, OSError):
|
||||
pass
|
||||
result.append(copy_evt)
|
||||
return result
|
||||
|
||||
|
||||
class ToolExecuteRequest(BaseModel):
|
||||
"""Request model for tool execution."""
|
||||
|
||||
@@ -52,19 +98,25 @@ def get_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"Search for detected objects in Frigate by camera, object label, time range, "
|
||||
"zones, and other filters. Use this to answer questions about when "
|
||||
"objects were detected, what objects appeared, or to find specific object detections. "
|
||||
"An 'object' in Frigate represents a tracked detection (e.g., a person, package, car)."
|
||||
"An 'object' in Frigate represents a tracked detection (e.g., a person, package, car). "
|
||||
"When the user asks about a specific name (person, delivery company, animal, etc.), "
|
||||
"filter by sub_label only and do not set label."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"camera": {
|
||||
"type": "string",
|
||||
"description": "Camera name to filter by (optional). Use 'all' for all cameras.",
|
||||
"description": "Camera name to filter by (optional).",
|
||||
},
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Object label to filter by (e.g., 'person', 'package', 'car').",
|
||||
},
|
||||
"sub_label": {
|
||||
"type": "string",
|
||||
"description": "Name of a person, delivery company, animal, etc. When filtering by a specific name, use only sub_label; do not set label.",
|
||||
},
|
||||
"after": {
|
||||
"type": "string",
|
||||
"description": "Start time in ISO 8601 format (e.g., '2024-01-01T00:00:00Z').",
|
||||
@@ -80,8 +132,8 @@ def get_tool_definitions() -> List[Dict[str, Any]]:
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of objects to return (default: 10).",
|
||||
"default": 10,
|
||||
"description": "Maximum number of objects to return (default: 25).",
|
||||
"default": 25,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -119,14 +171,13 @@ def get_tool_definitions() -> List[Dict[str, Any]]:
|
||||
summary="Get available tools",
|
||||
description="Returns OpenAI-compatible tool definitions for function calling.",
|
||||
)
|
||||
def get_tools(request: Request) -> JSONResponse:
|
||||
def get_tools() -> JSONResponse:
|
||||
"""Get list of available tools for LLM function calling."""
|
||||
tools = get_tool_definitions()
|
||||
return JSONResponse(content={"tools": tools})
|
||||
|
||||
|
||||
async def _execute_search_objects(
|
||||
request: Request,
|
||||
arguments: Dict[str, Any],
|
||||
allowed_cameras: List[str],
|
||||
) -> JSONResponse:
|
||||
@@ -136,23 +187,26 @@ async def _execute_search_objects(
|
||||
This searches for detected objects (events) in Frigate using the same
|
||||
logic as the events API endpoint.
|
||||
"""
|
||||
# Parse ISO 8601 timestamps to Unix timestamps if provided
|
||||
# Parse after/before as server local time; convert to Unix timestamp
|
||||
after = arguments.get("after")
|
||||
before = arguments.get("before")
|
||||
|
||||
def _parse_as_local_timestamp(s: str):
|
||||
s = s.replace("Z", "").strip()[:19]
|
||||
dt = datetime.strptime(s, "%Y-%m-%dT%H:%M:%S")
|
||||
return time.mktime(dt.timetuple())
|
||||
|
||||
if after:
|
||||
try:
|
||||
after_dt = datetime.fromisoformat(after.replace("Z", "+00:00"))
|
||||
after = after_dt.timestamp()
|
||||
except (ValueError, AttributeError):
|
||||
after = _parse_as_local_timestamp(after)
|
||||
except (ValueError, AttributeError, TypeError):
|
||||
logger.warning(f"Invalid 'after' timestamp format: {after}")
|
||||
after = None
|
||||
|
||||
if before:
|
||||
try:
|
||||
before_dt = datetime.fromisoformat(before.replace("Z", "+00:00"))
|
||||
before = before_dt.timestamp()
|
||||
except (ValueError, AttributeError):
|
||||
before = _parse_as_local_timestamp(before)
|
||||
except (ValueError, AttributeError, TypeError):
|
||||
logger.warning(f"Invalid 'before' timestamp format: {before}")
|
||||
before = None
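A small standalone sketch of the local-time interpretation used here: the trailing `Z` is dropped and the first 19 characters are parsed, so an ISO string like `2024-01-01T00:00:00Z` is converted with `time.mktime` as server-local midnight rather than UTC.

```python
import time
from datetime import datetime


def parse_as_local_timestamp(s: str) -> float:
    # Mirrors _parse_as_local_timestamp above: strip "Z", keep YYYY-MM-DDTHH:MM:SS,
    # and interpret the result in the server's local timezone.
    s = s.replace("Z", "").strip()[:19]
    dt = datetime.strptime(s, "%Y-%m-%dT%H:%M:%S")
    return time.mktime(dt.timetuple())


print(parse_as_local_timestamp("2024-01-01T00:00:00Z"))
# Unix timestamp for 2024-01-01 00:00:00 in the server's local timezone
```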
|
||||
|
||||
@@ -165,15 +219,14 @@ async def _execute_search_objects(
|
||||
|
||||
# Build query parameters compatible with EventsQueryParams
|
||||
query_params = EventsQueryParams(
|
||||
camera=arguments.get("camera", "all"),
|
||||
cameras=arguments.get("camera", "all"),
|
||||
label=arguments.get("label", "all"),
|
||||
labels=arguments.get("label", "all"),
|
||||
sub_labels=arguments.get("sub_label", "all").lower(),
|
||||
zones=zones,
|
||||
zone=zones,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=arguments.get("limit", 10),
|
||||
limit=arguments.get("limit", 25),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -202,7 +255,6 @@ async def _execute_search_objects(
|
||||
description="Execute a tool function call from an LLM.",
|
||||
)
|
||||
async def execute_tool(
|
||||
request: Request,
|
||||
body: ToolExecuteRequest = Body(...),
|
||||
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter),
|
||||
) -> JSONResponse:
|
||||
@@ -218,7 +270,7 @@ async def execute_tool(
|
||||
logger.debug(f"Executing tool: {tool_name} with arguments: {arguments}")
|
||||
|
||||
if tool_name == "search_objects":
|
||||
return await _execute_search_objects(request, arguments, allowed_cameras)
|
||||
return await _execute_search_objects(arguments, allowed_cameras)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -334,7 +386,7 @@ async def _execute_tool_internal(
|
||||
This is used by the chat completion endpoint to execute tools.
|
||||
"""
|
||||
if tool_name == "search_objects":
|
||||
response = await _execute_search_objects(request, arguments, allowed_cameras)
|
||||
response = await _execute_search_objects(arguments, allowed_cameras)
|
||||
try:
|
||||
if hasattr(response, "body"):
|
||||
body_str = response.body.decode("utf-8")
|
||||
@@ -349,15 +401,109 @@ async def _execute_tool_internal(
|
||||
elif tool_name == "get_live_context":
|
||||
camera = arguments.get("camera")
|
||||
if not camera:
|
||||
logger.error(
|
||||
"Tool get_live_context failed: camera parameter is required. "
|
||||
"Arguments: %s",
|
||||
json.dumps(arguments),
|
||||
)
|
||||
return {"error": "Camera parameter is required"}
|
||||
return await _execute_get_live_context(request, camera, allowed_cameras)
|
||||
else:
|
||||
logger.error(
|
||||
"Tool call failed: unknown tool %r. Expected one of: search_objects, get_live_context. "
|
||||
"Arguments received: %s",
|
||||
tool_name,
|
||||
json.dumps(arguments),
|
||||
)
|
||||
return {"error": f"Unknown tool: {tool_name}"}
|
||||
|
||||
|
||||
async def _execute_pending_tools(
|
||||
pending_tool_calls: List[Dict[str, Any]],
|
||||
request: Request,
|
||||
allowed_cameras: List[str],
|
||||
) -> tuple[List[ToolCall], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Execute a list of tool calls; return (ToolCall list for API response, tool result dicts for conversation).
|
||||
"""
|
||||
tool_calls_out: List[ToolCall] = []
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
for tool_call in pending_tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call.get("arguments") or {}
|
||||
tool_call_id = tool_call["id"]
|
||||
logger.debug(
|
||||
f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
|
||||
)
|
||||
try:
|
||||
tool_result = await _execute_tool_internal(
|
||||
tool_name, tool_args, request, allowed_cameras
|
||||
)
|
||||
if isinstance(tool_result, dict) and tool_result.get("error"):
|
||||
logger.error(
|
||||
"Tool call %s (id: %s) returned error: %s. Arguments: %s",
|
||||
tool_name,
|
||||
tool_call_id,
|
||||
tool_result.get("error"),
|
||||
json.dumps(tool_args),
|
||||
)
|
||||
if tool_name == "search_objects" and isinstance(tool_result, list):
|
||||
tool_result = _format_events_with_local_time(tool_result)
|
||||
_keys = {
|
||||
"id",
|
||||
"camera",
|
||||
"label",
|
||||
"zones",
|
||||
"start_time_local",
|
||||
"end_time_local",
|
||||
"sub_label",
|
||||
"event_count",
|
||||
}
|
||||
tool_result = [
|
||||
{k: evt[k] for k in _keys if k in evt}
|
||||
for evt in tool_result
|
||||
if isinstance(evt, dict)
|
||||
]
|
||||
result_content = (
|
||||
json.dumps(tool_result)
|
||||
if isinstance(tool_result, (dict, list))
|
||||
else (tool_result if isinstance(tool_result, str) else str(tool_result))
|
||||
)
|
||||
tool_calls_out.append(
|
||||
ToolCall(name=tool_name, arguments=tool_args, response=result_content)
|
||||
)
|
||||
tool_results.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error executing tool %s (id: %s): %s. Arguments: %s",
|
||||
tool_name,
|
||||
tool_call_id,
|
||||
e,
|
||||
json.dumps(tool_args),
|
||||
exc_info=True,
|
||||
)
|
||||
error_content = json.dumps({"error": f"Tool execution failed: {str(e)}"})
|
||||
tool_calls_out.append(
|
||||
ToolCall(name=tool_name, arguments=tool_args, response=error_content)
|
||||
)
|
||||
tool_results.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": error_content,
|
||||
}
|
||||
)
|
||||
return (tool_calls_out, tool_results)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/completion",
|
||||
response_model=ChatCompletionResponse,
|
||||
dependencies=[Depends(allow_any_authenticated())],
|
||||
summary="Chat completion with tool calling",
|
||||
description=(
|
||||
@@ -369,7 +515,7 @@ async def chat_completion(
|
||||
request: Request,
|
||||
body: ChatCompletionRequest = Body(...),
|
||||
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter),
|
||||
) -> JSONResponse:
|
||||
):
|
||||
"""
|
||||
Chat completion endpoint with tool calling support.
|
||||
|
||||
@@ -394,9 +540,9 @@ async def chat_completion(
|
||||
tools = get_tool_definitions()
|
||||
conversation = []
|
||||
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
current_datetime = datetime.now()
|
||||
current_date_str = current_datetime.strftime("%Y-%m-%d")
|
||||
current_time_str = current_datetime.strftime("%H:%M:%S %Z")
|
||||
current_time_str = current_datetime.strftime("%I:%M:%S %p")
|
||||
|
||||
cameras_info = []
|
||||
config = request.app.frigate_config
|
||||
@@ -429,9 +575,12 @@ async def chat_completion(
|
||||
|
||||
system_prompt = f"""You are a helpful assistant for Frigate, a security camera NVR system. You help users answer questions about their cameras, detected objects, and events.
|
||||
|
||||
Current date and time: {current_date_str} at {current_time_str} (UTC)
|
||||
Current server local date and time: {current_date_str} at {current_time_str}
|
||||
|
||||
When users ask questions about "today", "yesterday", "this week", etc., use the current date above as reference.
|
||||
Do not start your response with phrases like "I will check...", "Let me see...", or "Let me look...". Answer directly.
|
||||
|
||||
Always present times to the user in the server's local timezone. When tool results include start_time_local and end_time_local, use those exact strings when listing or describing detection times—do not convert or invent timestamps. Do not use UTC or ISO format with Z for the user-facing answer unless the tool result only provides Unix timestamps without local time fields.
|
||||
When users ask about "today", "yesterday", "this week", etc., use the current date above as reference.
|
||||
When searching for objects or events, use ISO 8601 format for dates (e.g., {current_date_str}T00:00:00Z for the start of today).
|
||||
Always be accurate with time calculations based on the current date provided.{cameras_section}{live_image_note}"""
|
||||
|
||||
@@ -471,6 +620,7 @@ Always be accurate with time calculations based on the current date provided.{ca
|
||||
conversation.append(msg_dict)
|
||||
|
||||
tool_iterations = 0
|
||||
tool_calls: List[ToolCall] = []
|
||||
max_iterations = body.max_tool_iterations
|
||||
|
||||
logger.debug(
|
||||
@@ -478,6 +628,81 @@ Always be accurate with time calculations based on the current date provided.{ca
|
||||
f"{len(tools)} tool(s) available, max_iterations={max_iterations}"
|
||||
)
|
||||
|
||||
# True LLM streaming when client supports it and stream requested
|
||||
if body.stream and hasattr(genai_client, "chat_with_tools_stream"):
|
||||
stream_tool_calls: List[ToolCall] = []
|
||||
stream_iterations = 0
|
||||
|
||||
async def stream_body_llm():
|
||||
nonlocal conversation, stream_tool_calls, stream_iterations
|
||||
while stream_iterations < max_iterations:
|
||||
logger.debug(
|
||||
f"Streaming LLM (iteration {stream_iterations + 1}/{max_iterations}) "
|
||||
f"with {len(conversation)} message(s)"
|
||||
)
|
||||
async for event in genai_client.chat_with_tools_stream(
|
||||
messages=conversation,
|
||||
tools=tools if tools else None,
|
||||
tool_choice="auto",
|
||||
):
|
||||
kind, value = event
|
||||
if kind == "content_delta":
|
||||
yield (
|
||||
json.dumps({"type": "content", "delta": value}).encode(
|
||||
"utf-8"
|
||||
)
|
||||
+ b"\n"
|
||||
)
|
||||
elif kind == "message":
|
||||
msg = value
|
||||
if msg.get("finish_reason") == "error":
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "An error occurred while processing your request.",
|
||||
}
|
||||
).encode("utf-8")
|
||||
+ b"\n"
|
||||
)
|
||||
return
|
||||
pending = msg.get("tool_calls")
|
||||
if pending:
|
||||
stream_iterations += 1
|
||||
conversation.append(
|
||||
build_assistant_message_for_conversation(
|
||||
msg.get("content"), pending
|
||||
)
|
||||
)
|
||||
executed_calls, tool_results = await _execute_pending_tools(
|
||||
pending, request, allowed_cameras
|
||||
)
|
||||
stream_tool_calls.extend(executed_calls)
|
||||
conversation.extend(tool_results)
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
"type": "tool_calls",
|
||||
"tool_calls": [
|
||||
tc.model_dump() for tc in stream_tool_calls
|
||||
],
|
||||
}
|
||||
).encode("utf-8")
|
||||
+ b"\n"
|
||||
)
|
||||
break
|
||||
else:
|
||||
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
|
||||
return
|
||||
else:
|
||||
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_body_llm(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"X-Accel-Buffering": "no"},
|
||||
)
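The streaming branch emits newline-delimited JSON events of type `content`, `tool_calls`, `error`, and `done`. A minimal client sketch for consuming the stream (host, port, and auth handling are placeholders):

```python
import json

import requests

# Placeholder URL; the route is registered as POST /chat/completion.
url = "http://frigate.local:5000/api/chat/completion"
body = {
    "messages": [{"role": "user", "content": "What was detected today?"}],
    "stream": True,
}

with requests.post(url, json=body, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        event = json.loads(line)
        if event["type"] == "content":
            print(event["delta"], end="", flush=True)  # incremental assistant text
        elif event["type"] == "tool_calls":
            pass  # executed tool calls, each with name, arguments, and response
        elif event["type"] == "error":
            print(event["error"])
            break
        elif event["type"] == "done":
            break
```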
|
||||
|
||||
try:
|
||||
while tool_iterations < max_iterations:
|
||||
logger.debug(
|
||||
@@ -499,119 +724,71 @@ Always be accurate with time calculations based on the current date provided.{ca
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": response.get("content"),
|
||||
}
|
||||
if response.get("tool_calls"):
|
||||
assistant_message["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": json.dumps(tc["arguments"]),
|
||||
},
|
||||
}
|
||||
for tc in response["tool_calls"]
|
||||
]
|
||||
conversation.append(assistant_message)
|
||||
conversation.append(
|
||||
build_assistant_message_for_conversation(
|
||||
response.get("content"), response.get("tool_calls")
|
||||
)
|
||||
)
|
||||
|
||||
tool_calls = response.get("tool_calls")
|
||||
if not tool_calls:
|
||||
pending_tool_calls = response.get("tool_calls")
|
||||
if not pending_tool_calls:
|
||||
logger.debug(
|
||||
f"Chat completion finished with final answer (iterations: {tool_iterations})"
|
||||
)
|
||||
final_content = response.get("content") or ""
|
||||
|
||||
if body.stream:
|
||||
|
||||
async def stream_body() -> Any:
|
||||
if tool_calls:
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
"type": "tool_calls",
|
||||
"tool_calls": [
|
||||
tc.model_dump() for tc in tool_calls
|
||||
],
|
||||
}
|
||||
).encode("utf-8")
|
||||
+ b"\n"
|
||||
)
|
||||
# Stream content in word-sized chunks for smooth UX
|
||||
for part in _chunk_content(final_content):
|
||||
yield (
|
||||
json.dumps({"type": "content", "delta": part}).encode(
|
||||
"utf-8"
|
||||
)
|
||||
+ b"\n"
|
||||
)
|
||||
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_body(),
|
||||
media_type="application/x-ndjson",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content=ChatCompletionResponse(
|
||||
message=ChatMessageResponse(
|
||||
role="assistant",
|
||||
content=response.get("content"),
|
||||
content=final_content,
|
||||
tool_calls=None,
|
||||
),
|
||||
finish_reason=response.get("finish_reason", "stop"),
|
||||
tool_iterations=tool_iterations,
|
||||
tool_calls=tool_calls,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
# Execute tools
|
||||
tool_iterations += 1
|
||||
logger.debug(
|
||||
f"Tool calls detected (iteration {tool_iterations}/{max_iterations}): "
|
||||
f"{len(tool_calls)} tool(s) to execute"
|
||||
f"{len(pending_tool_calls)} tool(s) to execute"
|
||||
)
|
||||
tool_results = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["arguments"]
|
||||
tool_call_id = tool_call["id"]
|
||||
|
||||
logger.debug(
|
||||
f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
|
||||
)
|
||||
|
||||
try:
|
||||
tool_result = await _execute_tool_internal(
|
||||
tool_name, tool_args, request, allowed_cameras
|
||||
)
|
||||
|
||||
if isinstance(tool_result, dict):
|
||||
result_content = json.dumps(tool_result)
|
||||
result_summary = tool_result
|
||||
if isinstance(tool_result, dict) and isinstance(
|
||||
tool_result.get("content"), list
|
||||
):
|
||||
result_count = len(tool_result.get("content", []))
|
||||
result_summary = {
|
||||
"count": result_count,
|
||||
"sample": tool_result.get("content", [])[:2]
|
||||
if result_count > 0
|
||||
else [],
|
||||
}
|
||||
logger.debug(
|
||||
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
|
||||
f"Result: {json.dumps(result_summary, indent=2)}"
|
||||
)
|
||||
elif isinstance(tool_result, str):
|
||||
result_content = tool_result
|
||||
logger.debug(
|
||||
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
|
||||
f"Result length: {len(result_content)} characters"
|
||||
)
|
||||
else:
|
||||
result_content = str(tool_result)
|
||||
logger.debug(
|
||||
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
|
||||
f"Result type: {type(tool_result).__name__}"
|
||||
)
|
||||
|
||||
tool_results.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error executing tool {tool_name} (id: {tool_call_id}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
error_content = json.dumps(
|
||||
{"error": f"Tool execution failed: {str(e)}"}
|
||||
)
|
||||
tool_results.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": error_content,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Tool {tool_name} (id: {tool_call_id}) failed. Error result added to conversation."
|
||||
)
|
||||
|
||||
executed_calls, tool_results = await _execute_pending_tools(
|
||||
pending_tool_calls, request, allowed_cameras
|
||||
)
|
||||
tool_calls.extend(executed_calls)
|
||||
conversation.extend(tool_results)
|
||||
logger.debug(
|
||||
f"Added {len(tool_results)} tool result(s) to conversation. "
|
||||
@@ -630,6 +807,7 @@ Always be accurate with time calculations based on the current date provided.{ca
|
||||
),
|
||||
finish_reason="length",
|
||||
tool_iterations=tool_iterations,
|
||||
tool_calls=tool_calls,
|
||||
).model_dump(),
|
||||
)
|
||||
@@ -39,3 +39,7 @@ class ChatCompletionRequest(BaseModel):
            "user message as multimodal content. Use with get_live_context for detection info."
        ),
    )
    stream: bool = Field(
        default=False,
        description="If true, stream the final assistant response in the body as newline-delimited JSON.",
    )
@@ -5,8 +5,8 @@ from typing import Any, Optional
from pydantic import BaseModel, Field


class ToolCall(BaseModel):
    """A tool call from the LLM."""
class ToolCallInvocation(BaseModel):
    """A tool call requested by the LLM (before execution)."""

    id: str = Field(description="Unique identifier for this tool call")
    name: str = Field(description="Tool name to call")
@@ -20,11 +20,24 @@ class ChatMessageResponse(BaseModel):
    content: Optional[str] = Field(
        default=None, description="Message content (None if tool calls present)"
    )
    tool_calls: Optional[list[ToolCall]] = Field(
    tool_calls: Optional[list[ToolCallInvocation]] = Field(
        default=None, description="Tool calls if LLM wants to call tools"
    )


class ToolCall(BaseModel):
    """A tool that was executed during the completion, with its response."""

    name: str = Field(description="Tool name that was called")
    arguments: dict[str, Any] = Field(
        default_factory=dict, description="Arguments passed to the tool"
    )
    response: str = Field(
        default="",
        description="The response or result returned from the tool execution",
    )


class ChatCompletionResponse(BaseModel):
    """Response from chat completion."""

@@ -35,3 +48,7 @@ class ChatCompletionResponse(BaseModel):
    tool_iterations: int = Field(
        default=0, description="Number of tool call iterations performed"
    )
    tool_calls: list[ToolCall] = Field(
        default_factory=list,
        description="List of tool calls that were executed during this completion",
    )
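Taken together, these models imply a (non-streaming) response body along these lines; the values below are purely illustrative:

```python
# Hypothetical ChatCompletionResponse payload (illustrative values only)
{
    "message": {
        "role": "assistant",
        "content": "A person was detected at the front door at 03:15:42 PM.",
        "tool_calls": None,
    },
    "finish_reason": "stop",
    "tool_iterations": 1,
    "tool_calls": [
        {
            "name": "search_objects",
            "arguments": {"label": "person", "limit": 25},
            "response": "[{\"id\": \"1700000000.0-abc123\", \"camera\": \"front_door\"}]",
        }
    ],
}
```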
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import ConfigDict, Field

@@ -128,10 +128,9 @@ class SemanticSearchConfig(FrigateBaseModel):
    reindex: Optional[bool] = Field(
        default=False, title="Reindex all tracked objects on startup."
    )
    model: Optional[Union[SemanticSearchModelEnum, str]] = Field(
    model: Optional[SemanticSearchModelEnum] = Field(
        default=SemanticSearchModelEnum.jinav1,
        title="The CLIP model or GenAI provider name for semantic search.",
        description="Use 'jinav1', 'jinav2' for ONNX models, or a GenAI config key (e.g. 'default') when that provider has the embeddings role.",
        title="The CLIP model to use for semantic search.",
    )
    model_size: str = Field(
        default="small", title="The size of the embeddings model used."
@@ -443,22 +443,6 @@ class FrigateConfig(FrigateBaseModel):
|
||||
)
|
||||
role_to_name[role] = name
|
||||
|
||||
# validate semantic_search.model when it is a GenAI provider name
|
||||
if self.semantic_search.enabled and isinstance(
|
||||
self.semantic_search.model, str
|
||||
):
|
||||
if self.semantic_search.model not in self.genai:
|
||||
raise ValueError(
|
||||
f"semantic_search.model '{self.semantic_search.model}' is not a "
|
||||
"valid GenAI config key. Must match a key in genai config."
|
||||
)
|
||||
genai_cfg = self.genai[self.semantic_search.model]
|
||||
if GenAIRoleEnum.embeddings not in genai_cfg.roles:
|
||||
raise ValueError(
|
||||
f"GenAI provider '{self.semantic_search.model}' must have "
|
||||
"'embeddings' in its roles for semantic search."
|
||||
)
|
||||
|
||||
# set default min_score for object attributes
|
||||
for attribute in self.model.all_attributes:
|
||||
if not self.objects.filters.get(attribute):
|
||||
|
||||
@@ -28,7 +28,6 @@ from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, serialize
|
||||
from frigate.util.file import get_event_thumbnail_bytes
|
||||
|
||||
from .genai_embedding import GenAIEmbedding
|
||||
from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding
|
||||
from .onnx.jina_v2_embedding import JinaV2Embedding
|
||||
|
||||
@@ -74,13 +73,11 @@ class Embeddings:
|
||||
config: FrigateConfig,
|
||||
db: SqliteVecQueueDatabase,
|
||||
metrics: DataProcessorMetrics,
|
||||
genai_manager=None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.db = db
|
||||
self.metrics = metrics
|
||||
self.requestor = InterProcessRequestor()
|
||||
self.genai_manager = genai_manager
|
||||
|
||||
self.image_inference_speed = InferenceSpeed(self.metrics.image_embeddings_speed)
|
||||
self.image_eps = EventsPerSecond()
|
||||
@@ -107,27 +104,7 @@ class Embeddings:
|
||||
},
|
||||
)
|
||||
|
||||
model_cfg = self.config.semantic_search.model
|
||||
is_genai_model = isinstance(model_cfg, str)
|
||||
|
||||
if is_genai_model:
|
||||
embeddings_client = (
|
||||
genai_manager.embeddings_client if genai_manager else None
|
||||
)
|
||||
if not embeddings_client:
|
||||
raise ValueError(
|
||||
f"semantic_search.model is '{model_cfg}' (GenAI provider) but "
|
||||
"no embeddings client is configured. Ensure the GenAI provider "
|
||||
"has 'embeddings' in its roles."
|
||||
)
|
||||
self.embedding = GenAIEmbedding(embeddings_client)
|
||||
self.text_embedding = lambda input_data: self.embedding(
|
||||
input_data, embedding_type="text"
|
||||
)
|
||||
self.vision_embedding = lambda input_data: self.embedding(
|
||||
input_data, embedding_type="vision"
|
||||
)
|
||||
elif model_cfg == SemanticSearchModelEnum.jinav2:
|
||||
if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2:
|
||||
# Single JinaV2Embedding instance for both text and vision
|
||||
self.embedding = JinaV2Embedding(
|
||||
model_size=self.config.semantic_search.model_size,
|
||||
@@ -141,8 +118,7 @@ class Embeddings:
|
||||
self.vision_embedding = lambda input_data: self.embedding(
|
||||
input_data, embedding_type="vision"
|
||||
)
|
||||
else:
|
||||
# Default to jinav1
|
||||
else: # Default to jinav1
|
||||
self.text_embedding = JinaV1TextEmbedding(
|
||||
model_size=config.semantic_search.model_size,
|
||||
requestor=self.requestor,
|
||||
@@ -160,11 +136,8 @@ class Embeddings:
|
||||
self.metrics.text_embeddings_eps.value = self.text_eps.eps()
|
||||
|
||||
def get_model_definitions(self):
|
||||
model_cfg = self.config.semantic_search.model
|
||||
if isinstance(model_cfg, str):
|
||||
# GenAI provider: no ONNX models to download
|
||||
models = []
|
||||
elif model_cfg == SemanticSearchModelEnum.jinav2:
|
||||
# Version-specific models
|
||||
if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2:
|
||||
models = [
|
||||
"jinaai/jina-clip-v2-tokenizer",
|
||||
"jinaai/jina-clip-v2-model_fp16.onnx"
|
||||
@@ -251,14 +224,6 @@ class Embeddings:
|
||||
|
||||
embeddings = self.vision_embedding(valid_thumbs)
|
||||
|
||||
if len(embeddings) != len(valid_ids):
|
||||
logger.warning(
|
||||
"Batch embed returned %d embeddings for %d thumbnails; skipping batch",
|
||||
len(embeddings),
|
||||
len(valid_ids),
|
||||
)
|
||||
return []
|
||||
|
||||
if upsert:
|
||||
items = []
|
||||
for i in range(len(valid_ids)):
|
||||
@@ -281,15 +246,9 @@ class Embeddings:
|
||||
|
||||
def embed_description(
|
||||
self, event_id: str, description: str, upsert: bool = True
|
||||
) -> np.ndarray | None:
|
||||
) -> np.ndarray:
|
||||
start = datetime.datetime.now().timestamp()
|
||||
embeddings = self.text_embedding([description])
|
||||
if not embeddings:
|
||||
logger.warning(
|
||||
"Failed to generate description embedding for event %s", event_id
|
||||
)
|
||||
return None
|
||||
embedding = embeddings[0]
|
||||
embedding = self.text_embedding([description])[0]
|
||||
|
||||
if upsert:
|
||||
self.db.execute_sql(
|
||||
@@ -312,32 +271,8 @@ class Embeddings:
|
||||
# upsert embeddings one by one to avoid token limit
|
||||
embeddings = []
|
||||
|
||||
for eid, desc in event_descriptions.items():
|
||||
result = self.text_embedding([desc])
|
||||
if not result:
|
||||
logger.warning(
|
||||
"Failed to generate description embedding for event %s", eid
|
||||
)
|
||||
continue
|
||||
embeddings.append(result[0])
|
||||
|
||||
if not embeddings:
|
||||
logger.warning("No description embeddings generated in batch")
|
||||
return np.array([])
|
||||
|
||||
# Build ids list for only successful embeddings - we need to track which succeeded
|
||||
ids = list(event_descriptions.keys())
|
||||
if len(embeddings) != len(ids):
|
||||
# Rebuild ids/embeddings for only successful ones (match by order)
|
||||
ids = []
|
||||
embeddings_filtered = []
|
||||
for eid, desc in event_descriptions.items():
|
||||
result = self.text_embedding([desc])
|
||||
if result:
|
||||
ids.append(eid)
|
||||
embeddings_filtered.append(result[0])
|
||||
ids = ids
|
||||
embeddings = embeddings_filtered
|
||||
for desc in event_descriptions.values():
|
||||
embeddings.append(self.text_embedding([desc])[0])
|
||||
|
||||
if upsert:
|
||||
ids = list(event_descriptions.keys())
|
||||
@@ -379,10 +314,7 @@ class Embeddings:
|
||||
|
||||
batch_size = (
|
||||
4
|
||||
if (
|
||||
isinstance(self.config.semantic_search.model, str)
|
||||
or self.config.semantic_search.model == SemanticSearchModelEnum.jinav2
|
||||
)
|
||||
if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2
|
||||
else 32
|
||||
)
|
||||
current_page = 1
|
||||
@@ -669,8 +601,6 @@ class Embeddings:
|
||||
if trigger.type == "description":
|
||||
logger.debug(f"Generating embedding for trigger description {trigger_name}")
|
||||
embedding = self.embed_description(None, trigger.data, upsert=False)
|
||||
if embedding is None:
|
||||
return b""
|
||||
return embedding.astype(np.float32).tobytes()
|
||||
|
||||
elif trigger.type == "thumbnail":
|
||||
@@ -706,8 +636,6 @@ class Embeddings:
|
||||
embedding = self.embed_thumbnail(
|
||||
str(trigger.data), thumbnail, upsert=False
|
||||
)
|
||||
if embedding is None:
|
||||
return b""
|
||||
return embedding.astype(np.float32).tobytes()
|
||||
|
||||
else:
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
"""GenAI-backed embeddings for semantic search."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from frigate.genai import GenAIClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMBEDDING_DIM = 768
|
||||
|
||||
|
||||
class GenAIEmbedding:
|
||||
"""Embedding adapter that delegates to a GenAI provider's embed API.
|
||||
|
||||
Provides the same interface as JinaV2Embedding for semantic search:
|
||||
__call__(inputs, embedding_type) -> list[np.ndarray]. Output embeddings are
|
||||
normalized to 768 dimensions for Frigate's sqlite-vec schema.
|
||||
"""
|
||||
|
||||
def __init__(self, client: "GenAIClient") -> None:
|
||||
self.client = client
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: list[str] | list[bytes] | list[Image.Image],
|
||||
embedding_type: str = "text",
|
||||
) -> list[np.ndarray]:
|
||||
"""Generate embeddings for text or images.
|
||||
|
||||
Args:
|
||||
inputs: List of strings (text) or bytes/PIL images (vision).
|
||||
embedding_type: "text" or "vision".
|
||||
|
||||
Returns:
|
||||
List of 768-dim numpy float32 arrays.
|
||||
"""
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
if embedding_type == "text":
|
||||
texts = [str(x) for x in inputs]
|
||||
embeddings = self.client.embed(texts=texts)
|
||||
elif embedding_type == "vision":
|
||||
images: list[bytes] = []
|
||||
for inp in inputs:
|
||||
if isinstance(inp, bytes):
|
||||
images.append(inp)
|
||||
elif isinstance(inp, Image.Image):
|
||||
buf = io.BytesIO()
|
||||
inp.convert("RGB").save(buf, format="JPEG")
|
||||
images.append(buf.getvalue())
|
||||
else:
|
||||
logger.warning(
|
||||
"GenAIEmbedding: skipping unsupported vision input type %s",
|
||||
type(inp).__name__,
|
||||
)
|
||||
if not images:
|
||||
return []
|
||||
embeddings = self.client.embed(images=images)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid embedding_type '{embedding_type}'. Must be 'text' or 'vision'."
|
||||
)
|
||||
|
||||
result = []
|
||||
for emb in embeddings:
|
||||
arr = np.asarray(emb, dtype=np.float32).flatten()
|
||||
if arr.size != EMBEDDING_DIM:
|
||||
if arr.size > EMBEDDING_DIM:
|
||||
arr = arr[:EMBEDDING_DIM]
|
||||
else:
|
||||
arr = np.pad(
|
||||
arr,
|
||||
(0, EMBEDDING_DIM - arr.size),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
result.append(arr)
|
||||
return result
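A minimal usage sketch of the adapter, assuming an already-initialized GenAI client whose provider implements `embed()` (currently llama.cpp); the thumbnail path is a placeholder:

```python
import numpy as np

embedder = GenAIEmbedding(client)  # `client` assumed to be a configured GenAIClient

# Text embeddings: list[str] in, list of 768-dim float32 arrays out.
text_vecs = embedder(["person at the front door"], embedding_type="text")

# Vision embeddings: raw image bytes (or PIL Images) in.
with open("thumbnail.jpg", "rb") as f:  # placeholder image path
    image_vecs = embedder([f.read()], embedding_type="vision")

assert all(v.dtype == np.float32 and v.size == 768 for v in text_vecs + image_vecs)
```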
|
||||
@@ -116,10 +116,8 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
models = [Event, Recordings, ReviewSegment, Trigger]
|
||||
db.bind(models)
|
||||
|
||||
self.genai_manager = GenAIClientManager(config)
|
||||
|
||||
if config.semantic_search.enabled:
|
||||
self.embeddings = Embeddings(config, db, metrics, self.genai_manager)
|
||||
self.embeddings = Embeddings(config, db, metrics)
|
||||
|
||||
# Check if we need to re-index events
|
||||
if config.semantic_search.reindex:
|
||||
@@ -146,6 +144,7 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self.frame_manager = SharedMemoryFrameManager()
|
||||
|
||||
self.detected_license_plates: dict[str, dict[str, Any]] = {}
|
||||
self.genai_manager = GenAIClientManager(config)
|
||||
|
||||
# model runners to share between realtime and post processors
|
||||
if self.config.lpr.enabled:
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
from frigate.config import CameraConfig, FrigateConfig, GenAIConfig, GenAIProviderEnum
|
||||
@@ -305,25 +304,6 @@ Guidelines:
|
||||
"""Get the context window size for this provider in tokens."""
|
||||
return 4096
|
||||
|
||||
def embed(
|
||||
self,
|
||||
texts: list[str] | None = None,
|
||||
images: list[bytes] | None = None,
|
||||
) -> list[np.ndarray]:
|
||||
"""Generate embeddings for text and/or images.
|
||||
|
||||
Returns list of numpy arrays (one per input). Expected dimension is 768
|
||||
for Frigate semantic search compatibility.
|
||||
|
||||
Providers that support embeddings should override this method.
|
||||
"""
|
||||
logger.warning(
|
||||
"%s does not support embeddings. "
|
||||
"This method should be overridden by the provider implementation.",
|
||||
self.__class__.__name__,
|
||||
)
|
||||
return []
|
||||
|
||||
def chat_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@@ -1,35 +1,20 @@
|
||||
"""llama.cpp Provider for Frigate AI."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import httpx
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from frigate.config import GenAIProviderEnum
|
||||
from frigate.genai import GenAIClient, register_genai_provider
|
||||
from frigate.genai.utils import parse_tool_calls_from_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_jpeg(img_bytes: bytes) -> bytes | None:
|
||||
"""Convert image bytes to JPEG. llama.cpp/STB does not support WebP."""
|
||||
try:
|
||||
img = Image.open(io.BytesIO(img_bytes))
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=85)
|
||||
return buf.getvalue()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to convert image to JPEG: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@register_genai_provider(GenAIProviderEnum.llamacpp)
|
||||
class LlamaCppClient(GenAIClient):
|
||||
"""Generative AI client for Frigate using llama.cpp server."""
|
||||
@@ -116,105 +101,76 @@ class LlamaCppClient(GenAIClient):
|
||||
|
||||
def get_context_size(self) -> int:
|
||||
"""Get the context window size for llama.cpp."""
|
||||
return self.genai_config.provider_options.get("context_size", 4096)
|
||||
return self.provider_options.get("context_size", 4096)
|
||||
|
||||
def embed(
|
||||
def _build_payload(
|
||||
self,
|
||||
texts: list[str] | None = None,
|
||||
images: list[bytes] | None = None,
|
||||
) -> list[np.ndarray]:
|
||||
"""Generate embeddings via llama.cpp /embeddings endpoint.
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
tool_choice: Optional[str],
|
||||
stream: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Build request payload for chat completions (sync or stream)."""
|
||||
openai_tool_choice = None
|
||||
if tool_choice:
|
||||
if tool_choice == "none":
|
||||
openai_tool_choice = "none"
|
||||
elif tool_choice == "auto":
|
||||
openai_tool_choice = "auto"
|
||||
elif tool_choice == "required":
|
||||
openai_tool_choice = "required"
|
||||
|
||||
Supports batch requests. Uses content format with prompt_string and
|
||||
multimodal_data for images (PR #15108). Server must be started with
|
||||
--embeddings and --mmproj for multimodal support.
|
||||
"""
|
||||
if self.provider is None:
|
||||
logger.warning(
|
||||
"llama.cpp provider has not been initialized. Check your llama.cpp configuration."
|
||||
payload: dict[str, Any] = {"messages": messages}
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
if openai_tool_choice is not None:
|
||||
payload["tool_choice"] = openai_tool_choice
|
||||
provider_opts = {
|
||||
k: v for k, v in self.provider_options.items() if k != "context_size"
|
||||
}
|
||||
payload.update(provider_opts)
|
||||
return payload
|
||||
|
||||
def _message_from_choice(self, choice: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Parse OpenAI-style choice into {content, tool_calls, finish_reason}."""
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content")
|
||||
content = content.strip() if content else None
|
||||
tool_calls = parse_tool_calls_from_message(message)
|
||||
finish_reason = choice.get("finish_reason") or (
|
||||
"tool_calls" if tool_calls else "stop" if content else "error"
|
||||
)
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _streamed_tool_calls_to_list(
|
||||
tool_calls_by_index: dict[int, dict[str, Any]],
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
"""Convert streamed tool_calls index map to list of {id, name, arguments}."""
|
||||
if not tool_calls_by_index:
|
||||
return None
|
||||
result = []
|
||||
for idx in sorted(tool_calls_by_index.keys()):
|
||||
t = tool_calls_by_index[idx]
|
||||
args_str = t.get("arguments") or "{}"
|
||||
try:
|
||||
arguments = json.loads(args_str)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
result.append(
|
||||
{
|
||||
"id": t.get("id", ""),
|
||||
"name": t.get("name", ""),
|
||||
"arguments": arguments,
|
||||
}
|
||||
)
|
||||
return []
|
||||
|
||||
texts = texts or []
|
||||
images = images or []
|
||||
if not texts and not images:
|
||||
return []
|
||||
|
||||
EMBEDDING_DIM = 768
|
||||
|
||||
content = []
|
||||
for text in texts:
|
||||
content.append({"prompt_string": text})
|
||||
for img in images:
|
||||
# llama.cpp uses STB which does not support WebP; convert to JPEG
|
||||
jpeg_bytes = _to_jpeg(img)
|
||||
to_encode = jpeg_bytes if jpeg_bytes is not None else img
|
||||
encoded = base64.b64encode(to_encode).decode("utf-8")
|
||||
# prompt_string must contain <__media__> placeholder for image tokenization
|
||||
content.append({
|
||||
"prompt_string": "<__media__>\n",
|
||||
"multimodal_data": [encoded],
|
||||
})
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.provider}/embeddings",
|
||||
json={"content": content},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
items = result.get("data", result) if isinstance(result, dict) else result
|
||||
if not isinstance(items, list):
|
||||
logger.warning("llama.cpp embeddings returned unexpected format")
|
||||
return []
|
||||
|
||||
embeddings = []
|
||||
for item in items:
|
||||
emb = item.get("embedding") if isinstance(item, dict) else None
|
||||
if emb is None:
|
||||
logger.warning("llama.cpp embeddings item missing embedding field")
|
||||
continue
|
||||
arr = np.array(emb, dtype=np.float32)
|
||||
orig_dim = arr.size
|
||||
if orig_dim != EMBEDDING_DIM:
|
||||
if orig_dim > EMBEDDING_DIM:
|
||||
arr = arr[:EMBEDDING_DIM]
|
||||
logger.debug(
|
||||
"Truncated llama.cpp embedding from %d to %d dimensions",
|
||||
orig_dim,
|
||||
EMBEDDING_DIM,
|
||||
)
|
||||
else:
|
||||
arr = np.pad(
|
||||
arr,
|
||||
(0, EMBEDDING_DIM - orig_dim),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
logger.debug(
|
||||
"Padded llama.cpp embedding from %d to %d dimensions",
|
||||
orig_dim,
|
||||
EMBEDDING_DIM,
|
||||
)
|
||||
embeddings.append(arr)
|
||||
return embeddings
|
||||
except requests.exceptions.Timeout:
|
||||
logger.warning("llama.cpp embeddings request timed out")
|
||||
return []
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_detail = str(e)
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
try:
|
||||
error_detail = f"{str(e)} - Response: {e.response.text[:500]}"
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning("llama.cpp embeddings error: %s", error_detail)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning("Unexpected error in llama.cpp embeddings: %s", str(e))
|
||||
return []
|
||||
return result if result else None
|
||||
|
||||
def chat_with_tools(
|
||||
self,
|
||||
@@ -237,31 +193,8 @@ class LlamaCppClient(GenAIClient):
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
|
||||
try:
|
||||
openai_tool_choice = None
|
||||
if tool_choice:
|
||||
if tool_choice == "none":
|
||||
openai_tool_choice = "none"
|
||||
elif tool_choice == "auto":
|
||||
openai_tool_choice = "auto"
|
||||
elif tool_choice == "required":
|
||||
openai_tool_choice = "required"
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
if openai_tool_choice is not None:
|
||||
payload["tool_choice"] = openai_tool_choice
|
||||
|
||||
provider_opts = {
|
||||
k: v for k, v in self.provider_options.items() if k != "context_size"
|
||||
}
|
||||
payload.update(provider_opts)
|
||||
|
||||
payload = self._build_payload(messages, tools, tool_choice, stream=False)
|
||||
response = requests.post(
|
||||
f"{self.provider}/v1/chat/completions",
|
||||
json=payload,
|
||||
@@ -269,60 +202,13 @@ class LlamaCppClient(GenAIClient):
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result is None or "choices" not in result or len(result["choices"]) == 0:
|
||||
return {
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
|
||||
choice = result["choices"][0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
content = message.get("content")
|
||||
if content:
|
||||
content = content.strip()
|
||||
else:
|
||||
content = None
|
||||
|
||||
tool_calls = None
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
tool_calls = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
try:
|
||||
function_data = tool_call.get("function", {})
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
arguments = json.loads(arguments_str)
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse tool call arguments: {e}, "
|
||||
f"tool: {function_data.get('name', 'unknown')}"
|
||||
)
|
||||
arguments = {}
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": tool_call.get("id", ""),
|
||||
"name": function_data.get("name", ""),
|
||||
"arguments": arguments,
|
||||
}
|
||||
)
|
||||
|
||||
finish_reason = "error"
|
||||
if "finish_reason" in choice and choice["finish_reason"]:
|
||||
finish_reason = choice["finish_reason"]
|
||||
elif tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
elif content:
|
||||
finish_reason = "stop"
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
return self._message_from_choice(result["choices"][0])
|
||||
except requests.exceptions.Timeout as e:
|
||||
logger.warning("llama.cpp request timed out: %s", str(e))
|
||||
return {
|
||||
@@ -334,8 +220,7 @@ class LlamaCppClient(GenAIClient):
|
||||
error_detail = str(e)
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
try:
|
||||
error_body = e.response.text
|
||||
error_detail = f"{str(e)} - Response: {error_body[:500]}"
|
||||
error_detail = f"{str(e)} - Response: {e.response.text[:500]}"
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning("llama.cpp returned an error: %s", error_detail)
|
||||
@@ -351,3 +236,111 @@ class LlamaCppClient(GenAIClient):
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
|
||||
async def chat_with_tools_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
"""Stream chat with tools via OpenAI-compatible streaming API."""
|
||||
if self.provider is None:
|
||||
logger.warning(
|
||||
"llama.cpp provider has not been initialized. Check your llama.cpp configuration."
|
||||
)
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
},
|
||||
)
|
||||
return
|
||||
try:
|
||||
payload = self._build_payload(messages, tools, tool_choice, stream=True)
|
||||
content_parts: list[str] = []
|
||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async with httpx.AsyncClient(timeout=float(self.timeout)) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.provider}/v1/chat/completions",
|
||||
json=payload,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
choices = data.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
delta = choices[0].get("delta", {})
|
||||
if choices[0].get("finish_reason"):
|
||||
finish_reason = choices[0]["finish_reason"]
|
||||
if delta.get("content"):
|
||||
content_parts.append(delta["content"])
|
||||
yield ("content_delta", delta["content"])
|
||||
for tc in delta.get("tool_calls") or []:
|
||||
idx = tc.get("index", 0)
|
||||
fn = tc.get("function") or {}
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": tc.get("id", ""),
|
||||
"name": tc.get("name") or fn.get("name", ""),
|
||||
"arguments": "",
|
||||
}
|
||||
t = tool_calls_by_index[idx]
|
||||
if tc.get("id"):
|
||||
t["id"] = tc["id"]
|
||||
name = tc.get("name") or fn.get("name")
|
||||
if name:
|
||||
t["name"] = name
|
||||
arg = tc.get("arguments") or fn.get("arguments")
|
||||
if arg is not None:
|
||||
t["arguments"] += (
|
||||
arg if isinstance(arg, str) else json.dumps(arg)
|
||||
)
|
||||
|
||||
full_content = "".join(content_parts).strip() or None
|
||||
tool_calls_list = self._streamed_tool_calls_to_list(tool_calls_by_index)
|
||||
if tool_calls_list:
|
||||
finish_reason = "tool_calls"
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": full_content,
|
||||
"tool_calls": tool_calls_list,
|
||||
"finish_reason": finish_reason,
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning("llama.cpp streaming HTTP error: %s", e)
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Unexpected error in llama.cpp chat_with_tools_stream: %s", str(e)
|
||||
)
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
},
|
||||
)
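The streaming method yields `("content_delta", str)` events followed by a final `("message", dict)` carrying `content`, `tool_calls`, and `finish_reason`. A small consumer sketch, assuming an initialized `LlamaCppClient`:

```python
import asyncio


async def run_stream(client, messages, tools=None):
    # `client` assumed to be an initialized LlamaCppClient (or any provider
    # exposing chat_with_tools_stream with the same event protocol).
    async for kind, value in client.chat_with_tools_stream(
        messages=messages, tools=tools, tool_choice="auto"
    ):
        if kind == "content_delta":
            print(value, end="", flush=True)  # partial assistant text
        elif kind == "message":
            if value.get("tool_calls"):
                print("\nmodel requested tools:", value["tool_calls"])
            return value


# asyncio.run(run_stream(client, [{"role": "user", "content": "Hi"}]))
```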
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
"""Ollama Provider for Frigate AI."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from httpx import RemoteProtocolError, TimeoutException
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
from ollama import Client as ApiClient
|
||||
from ollama import ResponseError
|
||||
|
||||
from frigate.config import GenAIProviderEnum
|
||||
from frigate.genai import GenAIClient, register_genai_provider
|
||||
from frigate.genai.utils import parse_tool_calls_from_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -88,6 +89,73 @@ class OllamaClient(GenAIClient):
|
||||
"num_ctx", 4096
|
||||
)
|
||||
|
||||
def _build_request_params(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
tool_choice: Optional[str],
|
||||
stream: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Build request_messages and params for chat (sync or stream)."""
|
||||
request_messages = []
|
||||
for msg in messages:
|
||||
msg_dict = {
|
||||
"role": msg.get("role"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
if msg.get("tool_call_id"):
|
||||
msg_dict["tool_call_id"] = msg["tool_call_id"]
|
||||
if msg.get("name"):
|
||||
msg_dict["name"] = msg["name"]
|
||||
if msg.get("tool_calls"):
|
||||
msg_dict["tool_calls"] = msg["tool_calls"]
|
||||
request_messages.append(msg_dict)
|
||||
|
||||
request_params: dict[str, Any] = {
|
||||
"model": self.genai_config.model,
|
||||
"messages": request_messages,
|
||||
**self.provider_options,
|
||||
}
|
||||
if stream:
|
||||
request_params["stream"] = True
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
if tool_choice:
|
||||
request_params["tool_choice"] = (
|
||||
"none"
|
||||
if tool_choice == "none"
|
||||
else "required"
|
||||
if tool_choice == "required"
|
||||
else "auto"
|
||||
)
|
||||
return request_params
|
||||
|
||||
def _message_from_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Parse Ollama chat response into {content, tool_calls, finish_reason}."""
|
||||
if not response or "message" not in response:
|
||||
return {
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
message = response["message"]
|
||||
content = message.get("content", "").strip() if message.get("content") else None
|
||||
tool_calls = parse_tool_calls_from_message(message)
|
||||
finish_reason = "error"
|
||||
if response.get("done"):
|
||||
finish_reason = (
|
||||
"tool_calls" if tool_calls else "stop" if content else "error"
|
||||
)
|
||||
elif tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
elif content:
|
||||
finish_reason = "stop"
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
def chat_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -103,93 +171,12 @@ class OllamaClient(GenAIClient):
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
|
||||
try:
|
||||
request_messages = []
|
||||
for msg in messages:
|
||||
msg_dict = {
|
||||
"role": msg.get("role"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
if msg.get("tool_call_id"):
|
||||
msg_dict["tool_call_id"] = msg["tool_call_id"]
|
||||
if msg.get("name"):
|
||||
msg_dict["name"] = msg["name"]
|
||||
if msg.get("tool_calls"):
|
||||
msg_dict["tool_calls"] = msg["tool_calls"]
|
||||
request_messages.append(msg_dict)
|
||||
|
||||
request_params = {
|
||||
"model": self.genai_config.model,
|
||||
"messages": request_messages,
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
if tool_choice:
|
||||
if tool_choice == "none":
|
||||
request_params["tool_choice"] = "none"
|
||||
elif tool_choice == "required":
|
||||
request_params["tool_choice"] = "required"
|
||||
elif tool_choice == "auto":
|
||||
request_params["tool_choice"] = "auto"
|
||||
|
||||
request_params.update(self.provider_options)
|
||||
|
||||
response = self.provider.chat(**request_params)
|
||||
|
||||
if not response or "message" not in response:
|
||||
return {
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
|
||||
message = response["message"]
|
||||
content = (
|
||||
message.get("content", "").strip() if message.get("content") else None
|
||||
request_params = self._build_request_params(
|
||||
messages, tools, tool_choice, stream=False
|
||||
)
|
||||
|
||||
tool_calls = None
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
tool_calls = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
try:
|
||||
function_data = tool_call.get("function", {})
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
arguments = json.loads(arguments_str)
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse tool call arguments: {e}, "
|
||||
f"tool: {function_data.get('name', 'unknown')}"
|
||||
)
|
||||
arguments = {}
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": tool_call.get("id", ""),
|
||||
"name": function_data.get("name", ""),
|
||||
"arguments": arguments,
|
||||
}
|
||||
)
|
||||
|
||||
finish_reason = "error"
|
||||
if "done" in response and response["done"]:
|
||||
if tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
elif content:
|
||||
finish_reason = "stop"
|
||||
elif tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
elif content:
|
||||
finish_reason = "stop"
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
response = self.provider.chat(**request_params)
|
||||
return self._message_from_response(response)
|
||||
except (TimeoutException, ResponseError, ConnectionError) as e:
|
||||
logger.warning("Ollama returned an error: %s", str(e))
|
||||
return {
|
||||
@@ -204,3 +191,89 @@ class OllamaClient(GenAIClient):
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
}
|
||||
|
||||
async def chat_with_tools_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
"""Stream chat with tools; yields content deltas then final message."""
|
||||
if self.provider is None:
|
||||
logger.warning(
|
||||
"Ollama provider has not been initialized. Check your Ollama configuration."
|
||||
)
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
},
|
||||
)
|
||||
return
|
||||
try:
|
||||
request_params = self._build_request_params(
|
||||
messages, tools, tool_choice, stream=True
|
||||
)
|
||||
async_client = OllamaAsyncClient(
|
||||
host=self.genai_config.base_url,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
content_parts: list[str] = []
|
||||
final_message: dict[str, Any] | None = None
|
||||
try:
|
||||
stream = await async_client.chat(**request_params)
|
||||
async for chunk in stream:
|
||||
if not chunk or "message" not in chunk:
|
||||
continue
|
||||
msg = chunk.get("message", {})
|
||||
delta = msg.get("content") or ""
|
||||
if delta:
|
||||
content_parts.append(delta)
|
||||
yield ("content_delta", delta)
|
||||
if chunk.get("done"):
|
||||
full_content = "".join(content_parts).strip() or None
|
||||
tool_calls = parse_tool_calls_from_message(msg)
|
||||
final_message = {
|
||||
"content": full_content,
|
||||
"tool_calls": tool_calls,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop",
|
||||
}
|
||||
break
|
||||
finally:
|
||||
await async_client.close()
|
||||
|
||||
if final_message is not None:
|
||||
yield ("message", final_message)
|
||||
else:
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": "".join(content_parts).strip() or None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
)
|
||||
except (TimeoutException, ResponseError, ConnectionError) as e:
|
||||
logger.warning("Ollama streaming error: %s", str(e))
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Unexpected error in Ollama chat_with_tools_stream: %s", str(e)
|
||||
)
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
"content": None,
|
||||
"tool_calls": None,
|
||||
"finish_reason": "error",
|
||||
},
|
||||
)
|
||||
|
||||
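For readers skimming the diff, here is a minimal sketch of how a caller might consume the new `chat_with_tools_stream` generator. The `client` object and the example prompt are assumptions (the `OllamaClient` constructor is not part of this hunk); only the yielded tuple shapes, `("content_delta", str)` followed by a final `("message", {...})`, come from the code above.

```python
import asyncio


async def print_stream(client, messages):
    # `client` is assumed to be an already-initialized OllamaClient.
    # The generator yields ("content_delta", str) tuples as tokens arrive,
    # then one final ("message", {"content", "tool_calls", "finish_reason"}).
    async for kind, payload in client.chat_with_tools_stream(messages):
        if kind == "content_delta":
            print(payload, end="", flush=True)
        elif kind == "message":
            print()
            if payload["finish_reason"] == "tool_calls":
                print("model requested tools:", payload["tool_calls"])


# Hypothetical invocation:
# asyncio.run(print_stream(client, [{"role": "user", "content": "Summarize today"}]))
```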
70 frigate/genai/utils.py Normal file
@@ -0,0 +1,70 @@
"""Shared helpers for GenAI providers and chat (OpenAI-style messages, tool call parsing)."""

import json
import logging
from typing import Any, List, Optional

logger = logging.getLogger(__name__)


def parse_tool_calls_from_message(
    message: dict[str, Any],
) -> Optional[list[dict[str, Any]]]:
    """
    Parse tool_calls from an OpenAI-style message dict.

    Message may have "tool_calls" as a list of:
    {"id": str, "function": {"name": str, "arguments": str}, ...}

    Returns a list of {"id", "name", "arguments"} with arguments parsed as dict,
    or None if no tool_calls. Used by Ollama and LlamaCpp (non-stream) responses.
    """
    raw = message.get("tool_calls")
    if not raw or not isinstance(raw, list):
        return None
    result = []
    for tool_call in raw:
        function_data = tool_call.get("function") or {}
        try:
            arguments_str = function_data.get("arguments") or "{}"
            arguments = json.loads(arguments_str)
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logger.warning(
                "Failed to parse tool call arguments: %s, tool: %s",
                e,
                function_data.get("name", "unknown"),
            )
            arguments = {}
        result.append(
            {
                "id": tool_call.get("id", ""),
                "name": function_data.get("name", ""),
                "arguments": arguments,
            }
        )
    return result if result else None


def build_assistant_message_for_conversation(
    content: Any,
    tool_calls_raw: Optional[List[dict[str, Any]]],
) -> dict[str, Any]:
    """
    Build the assistant message dict in OpenAI format for appending to a conversation.

    tool_calls_raw: list of {"id", "name", "arguments"} (arguments as dict), or None.
    """
    msg: dict[str, Any] = {"role": "assistant", "content": content}
    if tool_calls_raw:
        msg["tool_calls"] = [
            {
                "id": tc["id"],
                "type": "function",
                "function": {
                    "name": tc["name"],
                    "arguments": json.dumps(tc.get("arguments") or {}),
                },
            }
            for tc in tool_calls_raw
        ]
    return msg
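As a quick illustration of how these two helpers fit together, the sketch below parses a provider message and then rebuilds the assistant turn for the conversation history. The tool name `search_events` and its arguments are made up for the example; the helper signatures come from the file above.

```python
from frigate.genai.utils import (
    build_assistant_message_for_conversation,
    parse_tool_calls_from_message,
)

# An OpenAI-style provider message; arguments arrive as a JSON string.
# The tool name and arguments here are hypothetical.
message = {
    "content": None,
    "tool_calls": [
        {
            "id": "call_1",
            "function": {"name": "search_events", "arguments": '{"label": "person"}'},
        }
    ],
}

parsed = parse_tool_calls_from_message(message)
# parsed == [{"id": "call_1", "name": "search_events",
#             "arguments": {"label": "person"}}]

# Re-serialize the arguments and append the assistant turn to the conversation.
assistant_msg = build_assistant_message_for_conversation(None, parsed)
conversation = [{"role": "user", "content": "Who was at the door?"}, assistant_msg]
```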
1458 web/package-lock.json generated
File diff suppressed because it is too large

@@ -71,6 +71,8 @@
    "react-icons": "^5.5.0",
    "react-konva": "^18.2.10",
    "react-router-dom": "^6.30.3",
    "react-markdown": "^9.0.1",
    "remark-gfm": "^4.0.0",
    "react-swipeable": "^7.0.2",
    "react-tracked": "^2.0.1",
    "react-transition-group": "^4.4.5",

@@ -127,6 +127,7 @@
    "cancel": "Cancel",
    "close": "Close",
    "copy": "Copy",
    "copiedToClipboard": "Copied to clipboard",
    "back": "Back",
    "history": "History",
    "fullscreen": "Fullscreen",
@@ -245,6 +246,7 @@
    "uiPlayground": "UI Playground",
    "faceLibrary": "Face Library",
    "classification": "Classification",
    "chat": "Chat",
    "user": {
      "title": "User",
      "account": "Account",

13 web/public/locales/en/views/chat.json Normal file
@@ -0,0 +1,13 @@
{
  "placeholder": "Ask anything...",
  "error": "Something went wrong. Please try again.",
  "processing": "Processing...",
  "toolsUsed": "Used: {{tools}}",
  "showTools": "Show tools ({{count}})",
  "hideTools": "Hide tools",
  "call": "Call",
  "result": "Result",
  "arguments": "Arguments:",
  "response": "Response:",
  "send": "Send"
}
@@ -27,6 +27,7 @@ const Settings = lazy(() => import("@/pages/Settings"));
const UIPlayground = lazy(() => import("@/pages/UIPlayground"));
const FaceLibrary = lazy(() => import("@/pages/FaceLibrary"));
const Classification = lazy(() => import("@/pages/ClassificationModel"));
const Chat = lazy(() => import("@/pages/Chat"));
const Logs = lazy(() => import("@/pages/Logs"));
const AccessDenied = lazy(() => import("@/pages/AccessDenied"));

@@ -106,6 +107,7 @@ function DefaultAppView() {
            <Route path="/logs" element={<Logs />} />
            <Route path="/faces" element={<FaceLibrary />} />
            <Route path="/classification" element={<Classification />} />
            <Route path="/chat" element={<Chat />} />
            <Route path="/playground" element={<UIPlayground />} />
          </Route>
          <Route path="/unauthorized" element={<AccessDenied />} />
208 web/src/components/chat/ChatMessage.tsx Normal file
@@ -0,0 +1,208 @@
import { useState, useEffect, useRef } from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
import { useTranslation } from "react-i18next";
import copy from "copy-to-clipboard";
import { toast } from "sonner";
import { FaCopy, FaPencilAlt } from "react-icons/fa";
import { FaArrowUpLong } from "react-icons/fa6";
import { Button } from "@/components/ui/button";
import { Textarea } from "@/components/ui/textarea";
import {
  Tooltip,
  TooltipContent,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";

type MessageBubbleProps = {
  role: "user" | "assistant";
  content: string;
  messageIndex?: number;
  onEditSubmit?: (messageIndex: number, newContent: string) => void;
  isComplete?: boolean;
};

export function MessageBubble({
  role,
  content,
  messageIndex = 0,
  onEditSubmit,
  isComplete = true,
}: MessageBubbleProps) {
  const { t } = useTranslation(["views/chat", "common"]);
  const isUser = role === "user";
  const [isEditing, setIsEditing] = useState(false);
  const [draftContent, setDraftContent] = useState(content);
  const editInputRef = useRef<HTMLTextAreaElement>(null);

  useEffect(() => {
    setDraftContent(content);
  }, [content]);

  useEffect(() => {
    if (isEditing) {
      editInputRef.current?.focus();
      editInputRef.current?.setSelectionRange(
        editInputRef.current.value.length,
        editInputRef.current.value.length,
      );
    }
  }, [isEditing]);

  const handleCopy = () => {
    const text = content?.trim() || "";
    if (!text) return;
    if (copy(text)) {
      toast.success(t("button.copiedToClipboard", { ns: "common" }));
    }
  };

  const handleEditClick = () => {
    setDraftContent(content);
    setIsEditing(true);
  };

  const handleEditSubmit = () => {
    const trimmed = draftContent.trim();
    if (!trimmed || onEditSubmit == null) return;
    onEditSubmit(messageIndex, trimmed);
    setIsEditing(false);
  };

  const handleEditCancel = () => {
    setDraftContent(content);
    setIsEditing(false);
  };

  const handleEditKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
    if (e.key === "Enter" && !e.shiftKey) {
      e.preventDefault();
      handleEditSubmit();
    }
    if (e.key === "Escape") {
      handleEditCancel();
    }
  };

  if (isUser && isEditing) {
    return (
      <div className="flex w-full max-w-full flex-col gap-2 self-end">
        <Textarea
          ref={editInputRef}
          value={draftContent}
          onChange={(e) => setDraftContent(e.target.value)}
          onKeyDown={handleEditKeyDown}
          className="min-h-[80px] w-full resize-y rounded-lg bg-primary px-3 py-2 text-primary-foreground placeholder:text-primary-foreground/60"
          placeholder={t("placeholder")}
          rows={3}
        />
        <div className="flex items-center gap-2 self-end">
          <Button
            variant="ghost"
            size="sm"
            className="text-muted-foreground hover:text-foreground"
            onClick={handleEditCancel}
          >
            {t("button.cancel", { ns: "common" })}
          </Button>
          <Button
            variant="select"
            size="icon"
            className="size-9 rounded-full"
            disabled={!draftContent.trim()}
            onClick={handleEditSubmit}
            aria-label={t("send")}
          >
            <FaArrowUpLong size="16" />
          </Button>
        </div>
      </div>
    );
  }

  return (
    <div
      className={cn(
        "flex flex-col gap-1",
        isUser ? "items-end self-end" : "items-start self-start",
      )}
    >
      <div
        className={cn(
          "rounded-lg px-3 py-2",
          isUser ? "bg-primary text-primary-foreground" : "bg-muted",
        )}
      >
        {isUser ? (
          content
        ) : (
          <ReactMarkdown
            remarkPlugins={[remarkGfm]}
            components={{
              table: ({ node: _n, ...props }) => (
                <table
                  className="my-2 w-full border-collapse border border-border"
                  {...props}
                />
              ),
              th: ({ node: _n, ...props }) => (
                <th
                  className="border border-border bg-muted/50 px-2 py-1 text-left text-sm font-medium"
                  {...props}
                />
              ),
              td: ({ node: _n, ...props }) => (
                <td
                  className="border border-border px-2 py-1 text-sm"
                  {...props}
                />
              ),
            }}
          >
            {content}
          </ReactMarkdown>
        )}
      </div>
      <div className="flex items-center gap-0.5">
        {isUser && onEditSubmit != null && (
          <Tooltip>
            <TooltipTrigger asChild>
              <Button
                variant="ghost"
                size="icon"
                className="size-7 text-muted-foreground hover:text-foreground"
                onClick={handleEditClick}
                aria-label={t("button.edit", { ns: "common" })}
              >
                <FaPencilAlt className="size-3" />
              </Button>
            </TooltipTrigger>
            <TooltipContent>
              {t("button.edit", { ns: "common" })}
            </TooltipContent>
          </Tooltip>
        )}
        {isComplete && (
          <Tooltip>
            <TooltipTrigger asChild>
              <Button
                variant="ghost"
                size="icon"
                className="size-7 text-muted-foreground hover:text-foreground"
                onClick={handleCopy}
                disabled={!content?.trim()}
                aria-label={t("button.copy", { ns: "common" })}
              >
                <FaCopy className="size-3" />
              </Button>
            </TooltipTrigger>
            <TooltipContent>
              {t("button.copy", { ns: "common" })}
            </TooltipContent>
          </Tooltip>
        )}
      </div>
    </div>
  );
}
88 web/src/components/chat/ToolCallBubble.tsx Normal file
@@ -0,0 +1,88 @@
import { useState } from "react";
import { useTranslation } from "react-i18next";
import {
  Collapsible,
  CollapsibleContent,
  CollapsibleTrigger,
} from "@/components/ui/collapsible";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { ChevronDown, ChevronRight } from "lucide-react";

type ToolCallBubbleProps = {
  name: string;
  arguments?: Record<string, unknown>;
  response?: string;
  side: "left" | "right";
};

export function ToolCallBubble({
  name,
  arguments: args,
  response,
  side,
}: ToolCallBubbleProps) {
  const { t } = useTranslation(["views/chat"]);
  const [open, setOpen] = useState(false);
  const isLeft = side === "left";
  const normalizedName = name
    .replace(/_/g, " ")
    .split(" ")
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
    .join(" ");

  return (
    <div
      className={cn(
        "rounded-lg px-3 py-2",
        isLeft
          ? "self-start bg-muted"
          : "self-end bg-primary text-primary-foreground",
      )}
    >
      <Collapsible open={open} onOpenChange={setOpen}>
        <CollapsibleTrigger asChild>
          <Button
            variant="ghost"
            size="sm"
            className={cn(
              "h-auto w-full min-w-0 justify-start gap-2 whitespace-normal p-0 text-left text-xs hover:bg-transparent",
              !isLeft && "hover:text-primary-foreground",
            )}
          >
            {open ? (
              <ChevronDown size={12} className="shrink-0" />
            ) : (
              <ChevronRight size={12} className="shrink-0" />
            )}
            <span className="break-words font-medium">
              {isLeft ? t("call") : t("result")} {normalizedName}
            </span>
          </Button>
        </CollapsibleTrigger>
        <CollapsibleContent>
          <div className="mt-2 space-y-2">
            {isLeft && args && Object.keys(args).length > 0 && (
              <div className="text-xs">
                <div className="font-medium text-muted-foreground">
                  {t("arguments")}
                </div>
                <pre className="scrollbar-container mt-1 max-h-32 overflow-auto whitespace-pre-wrap break-words rounded bg-muted/50 p-2 text-[10px]">
                  {JSON.stringify(args, null, 2)}
                </pre>
              </div>
            )}
            {!isLeft && response && response !== "" && (
              <div className="text-xs">
                <div className="font-medium opacity-80">{t("response")}</div>
                <pre className="scrollbar-container mt-1 max-h-32 overflow-auto whitespace-pre-wrap break-words rounded bg-primary/20 p-2 text-[10px]">
                  {response}
                </pre>
              </div>
            )}
          </div>
        </CollapsibleContent>
      </Collapsible>
    </div>
  );
}
@@ -6,7 +6,7 @@ import { isDesktop } from "react-device-detect";
import { FaCompactDisc, FaVideo } from "react-icons/fa";
import { IoSearch } from "react-icons/io5";
import { LuConstruction } from "react-icons/lu";
import { MdCategory, MdVideoLibrary } from "react-icons/md";
import { MdCategory, MdChat, MdVideoLibrary } from "react-icons/md";
import { TbFaceId } from "react-icons/tb";
import useSWR from "swr";
import { useIsAdmin } from "./use-is-admin";
@@ -18,6 +18,7 @@ export const ID_EXPORT = 4;
export const ID_PLAYGROUND = 5;
export const ID_FACE_LIBRARY = 6;
export const ID_CLASSIFICATION = 7;
export const ID_CHAT = 8;

export default function useNavigation(
  variant: "primary" | "secondary" = "primary",
@@ -82,7 +83,15 @@ export default function useNavigation(
        url: "/classification",
        enabled: isDesktop && isAdmin,
      },
      {
        id: ID_CHAT,
        variant,
        icon: MdChat,
        title: "menu.chat",
        url: "/chat",
        enabled: isDesktop && isAdmin && config?.genai?.model !== "none",
      },
    ] as NavData[],
    [config?.face_recognition?.enabled, variant, isAdmin],
    [config?.face_recognition?.enabled, config?.genai?.model, variant, isAdmin],
  );
}

@@ -1,6 +1,3 @@
/** ONNX embedding models that require local model downloads. GenAI providers are not in this list. */
export const JINA_EMBEDDING_MODELS = ["jinav1", "jinav2"] as const;

export const supportedLanguageKeys = [
  "en",
  "es",
199 web/src/pages/Chat.tsx Normal file
@@ -0,0 +1,199 @@
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { FaArrowUpLong } from "react-icons/fa6";
import { useTranslation } from "react-i18next";
import { useState, useCallback } from "react";
import axios from "axios";
import { MessageBubble } from "@/components/chat/ChatMessage";
import { ToolCallBubble } from "@/components/chat/ToolCallBubble";
import type { ChatMessage } from "@/types/chat";
import { streamChatCompletion } from "@/utils/chatUtil";

export default function ChatPage() {
  const { t } = useTranslation(["views/chat"]);
  const [input, setInput] = useState("");
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const submitConversation = useCallback(
    async (messagesToSend: ChatMessage[]) => {
      if (isLoading) return;
      const last = messagesToSend[messagesToSend.length - 1];
      if (!last || last.role !== "user" || !last.content.trim()) return;

      setError(null);
      const assistantPlaceholder: ChatMessage = {
        role: "assistant",
        content: "",
        toolCalls: undefined,
      };
      setMessages([...messagesToSend, assistantPlaceholder]);
      setIsLoading(true);

      const apiMessages = messagesToSend.map((m) => ({
        role: m.role,
        content: m.content,
      }));

      const baseURL = axios.defaults.baseURL ?? "";
      const url = `${baseURL}chat/completion`;
      const headers: Record<string, string> = {
        "Content-Type": "application/json",
        ...(axios.defaults.headers.common as Record<string, string>),
      };

      await streamChatCompletion(url, headers, apiMessages, {
        updateMessages: (updater) => setMessages(updater),
        onError: (message) => setError(message),
        onDone: () => setIsLoading(false),
        defaultErrorMessage: t("error"),
      });
    },
    [isLoading, t],
  );

  const sendMessage = useCallback(() => {
    const text = input.trim();
    if (!text || isLoading) return;
    setInput("");
    submitConversation([...messages, { role: "user", content: text }]);
  }, [input, isLoading, messages, submitConversation]);

  const handleEditSubmit = useCallback(
    (messageIndex: number, newContent: string) => {
      const newList: ChatMessage[] = [
        ...messages.slice(0, messageIndex),
        { role: "user", content: newContent },
      ];
      submitConversation(newList);
    },
    [messages, submitConversation],
  );

  return (
    <div className="flex size-full justify-center p-2">
      <div className="flex size-full flex-col xl:w-[50%] 3xl:w-[35%]">
        <div className="scrollbar-container flex min-h-0 w-full flex-1 flex-col gap-2 overflow-y-auto">
          {messages.map((msg, i) => {
            const isStreamingPlaceholder =
              i === messages.length - 1 &&
              msg.role === "assistant" &&
              isLoading &&
              !msg.content?.trim() &&
              !(msg.toolCalls && msg.toolCalls.length > 0);
            if (isStreamingPlaceholder) {
              return <div key={i} />;
            }
            return (
              <div key={i} className="flex flex-col gap-2">
                {msg.role === "assistant" && msg.toolCalls && (
                  <>
                    {msg.toolCalls.map((tc, tcIdx) => (
                      <div key={tcIdx} className="flex flex-col gap-2">
                        <ToolCallBubble
                          name={tc.name}
                          arguments={tc.arguments}
                          side="left"
                        />
                        {tc.response && (
                          <ToolCallBubble
                            name={tc.name}
                            response={tc.response}
                            side="right"
                          />
                        )}
                      </div>
                    ))}
                  </>
                )}
                <MessageBubble
                  role={msg.role}
                  content={msg.content}
                  messageIndex={i}
                  onEditSubmit={
                    msg.role === "user" ? handleEditSubmit : undefined
                  }
                  isComplete={
                    msg.role === "user" || !isLoading || i < messages.length - 1
                  }
                />
              </div>
            );
          })}
          {(() => {
            const lastMsg = messages[messages.length - 1];
            const showProcessing =
              isLoading &&
              lastMsg?.role === "assistant" &&
              !lastMsg.content?.trim() &&
              !(lastMsg.toolCalls && lastMsg.toolCalls.length > 0);
            return showProcessing ? (
              <div className="self-start rounded-lg bg-muted px-3 py-2 text-muted-foreground">
                {t("processing")}
              </div>
            ) : null;
          })()}
          {error && (
            <p className="self-start text-sm text-destructive" role="alert">
              {error}
            </p>
          )}
        </div>
        <ChatEntry
          input={input}
          setInput={setInput}
          sendMessage={sendMessage}
          isLoading={isLoading}
          placeholder={t("placeholder")}
        />
      </div>
    </div>
  );
}

type ChatEntryProps = {
  input: string;
  setInput: (value: string) => void;
  sendMessage: () => void;
  isLoading: boolean;
  placeholder: string;
};

function ChatEntry({
  input,
  setInput,
  sendMessage,
  isLoading,
  placeholder,
}: ChatEntryProps) {
  const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
    if (e.key === "Enter" && !e.shiftKey) {
      e.preventDefault();
      sendMessage();
    }
  };

  return (
    <div className="flex w-full flex-col items-center justify-center rounded-xl bg-secondary p-2">
      <div className="flex w-full flex-row items-center gap-2">
        <Input
          className="w-full flex-1 border-transparent bg-transparent shadow-none focus-visible:ring-0 dark:bg-transparent"
          placeholder={placeholder}
          value={input}
          onChange={(e) => setInput(e.target.value)}
          onKeyDown={handleKeyDown}
          aria-busy={isLoading}
        />
        <Button
          variant="select"
          className="size-10 shrink-0 rounded-full"
          disabled={!input.trim() || isLoading}
          onClick={sendMessage}
        >
          <FaArrowUpLong size="16" />
        </Button>
      </div>
    </div>
  );
}
@@ -23,7 +23,6 @@ import { toast } from "sonner";
import useSWR from "swr";
import useSWRInfinite from "swr/infinite";
import { useDocDomain } from "@/hooks/use-doc-domain";
import { JINA_EMBEDDING_MODELS } from "@/lib/const";

const API_LIMIT = 25;

@@ -294,12 +293,7 @@ export default function Explore() {
  const modelVersion = config?.semantic_search.model || "jinav1";
  const modelSize = config?.semantic_search.model_size || "small";

  // GenAI providers have no local models to download
  const isGenaiEmbeddings =
    typeof modelVersion === "string" &&
    !(JINA_EMBEDDING_MODELS as readonly string[]).includes(modelVersion);

  // Text model state (skipped for GenAI - no local models)
  // Text model state
  const { payload: textModelState } = useModelState(
    modelVersion === "jinav1"
      ? "jinaai/jina-clip-v1-text_model_fp16.onnx"
@@ -334,10 +328,6 @@ export default function Explore() {
  );

  const allModelsLoaded = useMemo(() => {
    if (isGenaiEmbeddings) {
      return true;
    }

    return (
      textModelState === "downloaded" &&
      textTokenizerState === "downloaded" &&
@@ -345,7 +335,6 @@ export default function Explore() {
      visionFeatureExtractorState === "downloaded"
    );
  }, [
    isGenaiEmbeddings,
    textModelState,
    textTokenizerState,
    visionModelState,
@@ -369,11 +358,10 @@ export default function Explore() {
    !defaultViewLoaded ||
    (config?.semantic_search.enabled &&
      (!reindexState ||
        (!isGenaiEmbeddings &&
          (!textModelState ||
            !textTokenizerState ||
            !visionModelState ||
            !visionFeatureExtractorState))))
        !textModelState ||
        !textTokenizerState ||
        !visionModelState ||
        !visionFeatureExtractorState))
  ) {
    return (
      <ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
11 web/src/types/chat.ts Normal file
@@ -0,0 +1,11 @@
export type ToolCall = {
  name: string;
  arguments?: Record<string, unknown>;
  response?: string;
};

export type ChatMessage = {
  role: "user" | "assistant";
  content: string;
  toolCalls?: ToolCall[];
};
163 web/src/utils/chatUtil.ts Normal file
@@ -0,0 +1,163 @@
import type { ChatMessage, ToolCall } from "@/types/chat";

export type StreamChatCallbacks = {
  /** Update the messages array (e.g. pass to setState). */
  updateMessages: (updater: (prev: ChatMessage[]) => ChatMessage[]) => void;
  /** Called when the stream sends an error or fetch fails. */
  onError: (message: string) => void;
  /** Called when the stream finishes (success or error). */
  onDone: () => void;
  /** Message used when fetch throws and no server error is available. */
  defaultErrorMessage?: string;
};

type StreamChunk =
  | { type: "error"; error: string }
  | { type: "tool_calls"; tool_calls: ToolCall[] }
  | { type: "content"; delta: string };

/**
 * POST to chat/completion with stream: true, parse NDJSON stream, and invoke
 * callbacks so the caller can update UI (e.g. React state).
 */
export async function streamChatCompletion(
  url: string,
  headers: Record<string, string>,
  apiMessages: { role: string; content: string }[],
  callbacks: StreamChatCallbacks,
): Promise<void> {
  const {
    updateMessages,
    onError,
    onDone,
    defaultErrorMessage = "Something went wrong. Please try again.",
  } = callbacks;

  try {
    const res = await fetch(url, {
      method: "POST",
      headers,
      body: JSON.stringify({ messages: apiMessages, stream: true }),
    });

    if (!res.ok) {
      const errBody = await res.json().catch(() => ({}));
      const message = (errBody as { error?: string }).error ?? res.statusText;
      onError(message);
      onDone();
      return;
    }

    const reader = res.body?.getReader();
    const decoder = new TextDecoder();
    if (!reader) {
      onError("No response body");
      onDone();
      return;
    }

    let buffer = "";
    let hadStreamError = false;

    const applyChunk = (data: StreamChunk) => {
      if (data.type === "error") {
        onError(data.error);
        updateMessages((prev) =>
          prev.filter((m) => !(m.role === "assistant" && m.content === "")),
        );
        return "break";
      }
      if (data.type === "tool_calls" && data.tool_calls?.length) {
        updateMessages((prev) => {
          const next = [...prev];
          const lastMsg = next[next.length - 1];
          if (lastMsg?.role === "assistant")
            next[next.length - 1] = {
              ...lastMsg,
              toolCalls: data.tool_calls,
            };
          return next;
        });
        return "continue";
      }
      if (data.type === "content" && data.delta !== undefined) {
        updateMessages((prev) => {
          const next = [...prev];
          const lastMsg = next[next.length - 1];
          if (lastMsg?.role === "assistant")
            next[next.length - 1] = {
              ...lastMsg,
              content: lastMsg.content + data.delta,
            };
          return next;
        });
        return "continue";
      }
      return "continue";
    };

    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      buffer = lines.pop() ?? "";
      for (const line of lines) {
        const trimmed = line.trim();
        if (!trimmed) continue;
        try {
          const data = JSON.parse(trimmed) as StreamChunk & { type: string };
          const result = applyChunk(data as StreamChunk);
          if (result === "break") {
            hadStreamError = true;
            break;
          }
        } catch {
          // skip malformed JSON lines
        }
      }
      if (hadStreamError) break;
    }

    // Flush remaining buffer
    if (!hadStreamError && buffer.trim()) {
      try {
        const data = JSON.parse(buffer.trim()) as StreamChunk & {
          type: string;
          delta?: string;
        };
        if (data.type === "content" && data.delta !== undefined) {
          updateMessages((prev) => {
            const next = [...prev];
            const lastMsg = next[next.length - 1];
            if (lastMsg?.role === "assistant")
              next[next.length - 1] = {
                ...lastMsg,
                content: lastMsg.content + data.delta!,
              };
            return next;
          });
        }
      } catch {
        // ignore final malformed chunk
      }
    }

    if (!hadStreamError) {
      updateMessages((prev) => {
        const next = [...prev];
        const lastMsg = next[next.length - 1];
        if (lastMsg?.role === "assistant" && lastMsg.content === "")
          next[next.length - 1] = { ...lastMsg, content: " " };
        return next;
      });
    }
  } catch {
    onError(defaultErrorMessage);
    updateMessages((prev) =>
      prev.filter((m) => !(m.role === "assistant" && m.content === "")),
    );
  } finally {
    onDone();
  }
}
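To make the wire format concrete, here is a rough, hypothetical illustration (not taken from the backend code) of the newline-delimited JSON lines that the parser above accepts; each line matches one member of the `StreamChunk` union.

```python
import json

# Hypothetical NDJSON stream body in the shapes streamChatCompletion parses:
# a "tool_calls" chunk, then incremental "content" deltas.
example_chunks = [
    {"type": "tool_calls", "tool_calls": [{"name": "search_events", "arguments": {"label": "person"}}]},
    {"type": "content", "delta": "I found "},
    {"type": "content", "delta": "three events today."},
]

body = "\n".join(json.dumps(chunk) for chunk in example_chunks) + "\n"
print(body)
```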
@@ -46,6 +46,7 @@ i18n
      "components/icons",
      "components/player",
      "views/events",
      "views/chat",
      "views/explore",
      "views/live",
      "views/settings",