Add object classification attributes to Tracked Object Details (#21348)

* attributes endpoint

* event endpoints

* add attributes to more filters

* add to suggestions and query in explore

* support attributes in search input

* i18n

* add object type filter to endpoint

* add attributes to tracked object details pane

* add generic multi select dialog

* save object attributes endpoint

* add group by param to fetch attributes endpoint

* add attribute editing to tracked object details

* docs

* fix docs

* update openapi spec to match python
This commit is contained in:
Josh Hawkins
2025-12-18 08:35:47 -06:00
committed by GitHub
parent 074b060e9c
commit 6a0e31dcf9
18 changed files with 808 additions and 6 deletions

View File

@@ -3,7 +3,7 @@ id: object_classification
title: Object Classification
---
Object classification allows you to train a custom MobileNetV2 classification model to run on tracked objects (persons, cars, animals, etc.) to identify a finer category or attribute for that object.
Object classification allows you to train a custom MobileNetV2 classification model to run on tracked objects (persons, cars, animals, etc.) to identify a finer category or attribute for that object. Classification results are visible in the Tracked Object Details pane in Explore, through the `frigate/tracked_object_details` MQTT topic, in Home Assistant sensors via the official Frigate integration, or through the event endpoints in the HTTP API.
## Minimum System Requirements

View File

@@ -3,7 +3,7 @@ id: state_classification
title: State Classification
---
State classification allows you to train a custom MobileNetV2 classification model on a fixed region of your camera frame(s) to determine a current state. The model can be configured to run on a schedule and/or when motion is detected in that region.
State classification allows you to train a custom MobileNetV2 classification model on a fixed region of your camera frame(s) to determine a current state. The model can be configured to run on a schedule and/or when motion is detected in that region. Classification results are available through the `frigate/<camera_name>/classification/<model_name>` MQTT topic and in Home Assistant sensors via the official Frigate integration.
## Minimum System Requirements

View File

@@ -616,6 +616,32 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/HTTPValidationError"
/classification/attributes:
get:
tags:
- Classification
summary: Get custom classification attributes
description: |-
Returns custom classification attributes for a given object type.
Only includes models with classification_type set to 'attribute'.
By default returns a flat sorted list of all attribute labels.
If group_by_model is true, returns attributes grouped by model name.
operationId: get_custom_attributes_classification_attributes_get
parameters:
- name: object_type
in: query
schema:
type: string
- name: group_by_model
in: query
schema:
type: boolean
default: false
responses:
"200":
description: Successful Response
"422":
description: Validation Error
/classification/{name}/dataset:
get:
tags:
@@ -2912,6 +2938,42 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/HTTPValidationError"
/events/{event_id}/attributes:
post:
tags:
- Events
summary: Set custom classification attributes
description: |-
Sets an event's custom classification attributes for all attribute-type
models that apply to the event's object type.
Returns a success message or an error if the event is not found.
operationId: set_attributes_events__event_id__attributes_post
parameters:
- name: event_id
in: path
required: true
schema:
type: string
title: Event Id
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/EventsAttributesBody"
responses:
"200":
description: Successful Response
content:
application/json:
schema:
$ref: "#/components/schemas/GenericResponse"
"422":
description: Validation Error
content:
application/json:
schema:
$ref: "#/components/schemas/HTTPValidationError"
/events/{event_id}/description:
post:
tags:
@@ -4959,6 +5021,18 @@ components:
required:
- subLabel
title: EventsSubLabelBody
EventsAttributesBody:
properties:
attributes:
type: object
title: Attributes
description: Object with model names as keys and attribute values
additionalProperties:
type: string
type: object
required:
- attributes
title: EventsAttributesBody
ExportModel:
properties:
id:

View File

@@ -31,6 +31,7 @@ from frigate.api.defs.response.generic_response import GenericResponse
from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig
from frigate.config.classification import ObjectClassificationType
from frigate.const import CLIPS_DIR, FACE_DIR, MODEL_CACHE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
@@ -622,6 +623,59 @@ def get_classification_dataset(name: str):
)
@router.get(
"/classification/attributes",
summary="Get custom classification attributes",
description="""Returns custom classification attributes for a given object type.
Only includes models with classification_type set to 'attribute'.
By default returns a flat sorted list of all attribute labels.
If group_by_model is true, returns attributes grouped by model name.""",
)
def get_custom_attributes(
request: Request, object_type: str = None, group_by_model: bool = False
):
models_with_attributes = {}
for (
model_key,
model_config,
) in request.app.frigate_config.classification.custom.items():
if (
not model_config.enabled
or not model_config.object_config
or model_config.object_config.classification_type
!= ObjectClassificationType.attribute
):
continue
model_objects = getattr(model_config.object_config, "objects", []) or []
if object_type is not None and object_type not in model_objects:
continue
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(model_key), "dataset")
if not os.path.exists(dataset_dir):
continue
attributes = []
for category_name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, category_name)
if os.path.isdir(category_dir) and category_name != "none":
attributes.append(category_name)
if attributes:
model_name = model_config.name or model_key
models_with_attributes[model_name] = sorted(attributes)
if group_by_model:
return JSONResponse(content=models_with_attributes)
else:
# Flatten to a unique sorted list
all_attributes = set()
for attributes in models_with_attributes.values():
all_attributes.update(attributes)
return JSONResponse(content=sorted(list(all_attributes)))
@router.get(
"/classification/{name}/train",
summary="Get classification train images",

View File

@@ -12,6 +12,7 @@ class EventsQueryParams(BaseModel):
labels: Optional[str] = "all"
sub_label: Optional[str] = "all"
sub_labels: Optional[str] = "all"
attributes: Optional[str] = "all"
zone: Optional[str] = "all"
zones: Optional[str] = "all"
limit: Optional[int] = 100
@@ -58,6 +59,8 @@ class EventsSearchQueryParams(BaseModel):
limit: Optional[int] = 50
cameras: Optional[str] = "all"
labels: Optional[str] = "all"
sub_labels: Optional[str] = "all"
attributes: Optional[str] = "all"
zones: Optional[str] = "all"
after: Optional[float] = None
before: Optional[float] = None

View File

@@ -24,6 +24,13 @@ class EventsLPRBody(BaseModel):
)
class EventsAttributesBody(BaseModel):
attributes: List[str] = Field(
title="Selected classification attributes for the event",
default_factory=list,
)
class EventsDescriptionBody(BaseModel):
description: Union[str, None] = Field(title="The description of the event")

View File

@@ -37,6 +37,7 @@ from frigate.api.defs.query.regenerate_query_parameters import (
RegenerateQueryParameters,
)
from frigate.api.defs.request.events_body import (
EventsAttributesBody,
EventsCreateBody,
EventsDeleteBody,
EventsDescriptionBody,
@@ -55,6 +56,7 @@ from frigate.api.defs.response.event_response import (
from frigate.api.defs.response.generic_response import GenericResponse
from frigate.api.defs.tags import Tags
from frigate.comms.event_metadata_updater import EventMetadataTypeEnum
from frigate.config.classification import ObjectClassificationType
from frigate.const import CLIPS_DIR, TRIGGER_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event, ReviewSegment, Timeline, Trigger
@@ -99,6 +101,8 @@ def events(
if sub_labels == "all" and sub_label != "all":
sub_labels = sub_label
attributes = unquote(params.attributes)
zone = params.zone
zones = params.zones
@@ -187,6 +191,17 @@ def events(
sub_label_clause = reduce(operator.or_, sub_label_clauses)
clauses.append((sub_label_clause))
if attributes != "all":
# Custom classification results are stored as data[model_name] = result_value
filtered_attributes = attributes.split(",")
attribute_clauses = []
for attr in filtered_attributes:
attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*')
attribute_clause = reduce(operator.or_, attribute_clauses)
clauses.append(attribute_clause)
if recognized_license_plate != "all":
filtered_recognized_license_plates = recognized_license_plate.split(",")
@@ -492,6 +507,8 @@ def events_search(
# Filters
cameras = params.cameras
labels = params.labels
sub_labels = params.sub_labels
attributes = params.attributes
zones = params.zones
after = params.after
before = params.before
@@ -566,6 +583,38 @@ def events_search(
if labels != "all":
event_filters.append((Event.label << labels.split(",")))
if sub_labels != "all":
# use matching so joined sub labels are included
# for example a sub label 'bob' would get events
# with sub labels 'bob' and 'bob, john'
sub_label_clauses = []
filtered_sub_labels = sub_labels.split(",")
if "None" in filtered_sub_labels:
filtered_sub_labels.remove("None")
sub_label_clauses.append((Event.sub_label.is_null()))
for label in filtered_sub_labels:
sub_label_clauses.append(
(Event.sub_label.cast("text") == label)
) # include exact matches
# include this label when part of a list
sub_label_clauses.append((Event.sub_label.cast("text") % f"*{label},*"))
sub_label_clauses.append((Event.sub_label.cast("text") % f"*, {label}*"))
event_filters.append((reduce(operator.or_, sub_label_clauses)))
if attributes != "all":
# Custom classification results are stored as data[model_name] = result_value
filtered_attributes = attributes.split(",")
attribute_clauses = []
for attr in filtered_attributes:
attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*')
event_filters.append(reduce(operator.or_, attribute_clauses))
if zones != "all":
zone_clauses = []
filtered_zones = zones.split(",")
@@ -1351,6 +1400,107 @@ async def set_plate(
)
@router.post(
"/events/{event_id}/attributes",
response_model=GenericResponse,
dependencies=[Depends(require_role(["admin"]))],
summary="Set custom classification attributes",
description=(
"Sets an event's custom classification attributes for all attribute-type "
"models that apply to the event's object type."
),
)
async def set_attributes(
request: Request,
event_id: str,
body: EventsAttributesBody,
):
try:
event: Event = Event.get(Event.id == event_id)
await require_camera_access(event.camera, request=request)
except DoesNotExist:
return JSONResponse(
content=({"success": False, "message": f"Event {event_id} not found."}),
status_code=404,
)
object_type = event.label
selected_attributes = set(body.attributes or [])
applied_updates: list[dict[str, str | float | None]] = []
for (
model_key,
model_config,
) in request.app.frigate_config.classification.custom.items():
# Only apply to enabled attribute classifiers that target this object type
if (
not model_config.enabled
or not model_config.object_config
or model_config.object_config.classification_type
!= ObjectClassificationType.attribute
or object_type not in (model_config.object_config.objects or [])
):
continue
# Get available labels from dataset directory
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(model_key), "dataset")
available_labels = set()
if os.path.exists(dataset_dir):
for category_name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, category_name)
if os.path.isdir(category_dir):
available_labels.add(category_name)
if not available_labels:
logger.warning(
"No dataset found for custom attribute model %s at %s",
model_key,
dataset_dir,
)
continue
# Find all selected attributes that apply to this model
model_name = model_config.name or model_key
matching_attrs = selected_attributes & available_labels
if matching_attrs:
# Publish updates for each selected attribute
for attr in matching_attrs:
request.app.event_metadata_updater.publish(
(event_id, model_name, attr, 1.0),
EventMetadataTypeEnum.attribute.value,
)
applied_updates.append(
{"model": model_name, "label": attr, "score": 1.0}
)
else:
# Clear this model's attribute
request.app.event_metadata_updater.publish(
(event_id, model_name, None, None),
EventMetadataTypeEnum.attribute.value,
)
applied_updates.append({"model": model_name, "label": None, "score": None})
if len(applied_updates) == 0:
return JSONResponse(
content={
"success": False,
"message": "No matching attributes found for this object type.",
},
status_code=400,
)
return JSONResponse(
content={
"success": True,
"message": f"Updated {len(applied_updates)} attribute(s)",
"applied": applied_updates,
},
status_code=200,
)
@router.post(
"/events/{event_id}/description",
response_model=GenericResponse,

View File

@@ -38,6 +38,10 @@
"label": "Sub Labels",
"all": "All Sub Labels"
},
"attributes": {
"label": "Classification Attributes",
"all": "All Attributes"
},
"score": "Score",
"estimatedSpeed": "Estimated Speed ({{unit}})",
"features": {

View File

@@ -104,12 +104,14 @@
"regenerate": "A new description has been requested from {{provider}}. Depending on the speed of your provider, the new description may take some time to regenerate.",
"updatedSublabel": "Successfully updated sub label.",
"updatedLPR": "Successfully updated license plate.",
"updatedAttributes": "Successfully updated attributes.",
"audioTranscription": "Successfully requested audio transcription. Depending on the speed of your Frigate server, the transcription may take some time to complete."
},
"error": {
"regenerate": "Failed to call {{provider}} for a new description: {{errorMessage}}",
"updatedSublabelFailed": "Failed to update sub label: {{errorMessage}}",
"updatedLPRFailed": "Failed to update license plate: {{errorMessage}}",
"updatedAttributesFailed": "Failed to update attributes: {{errorMessage}}",
"audioTranscription": "Failed to request audio transcription: {{errorMessage}}"
}
}
@@ -125,6 +127,10 @@
"desc": "Enter a new license plate value for this {{label}}",
"descNoLabel": "Enter a new license plate value for this tracked object"
},
"editAttributes": {
"title": "Edit attributes",
"desc": "Select classification attributes for this {{label}}"
},
"snapshotScore": {
"label": "Snapshot Score"
},
@@ -136,6 +142,7 @@
"label": "Score"
},
"recognizedLicensePlate": "Recognized License Plate",
"attributes": "Classification Attributes",
"estimatedSpeed": "Estimated Speed",
"objects": "Objects",
"camera": "Camera",

View File

@@ -16,6 +16,7 @@
"labels": "Labels",
"zones": "Zones",
"sub_labels": "Sub Labels",
"attributes": "Attributes",
"search_type": "Search Type",
"time_range": "Time Range",
"before": "Before",

View File

@@ -399,7 +399,7 @@ export default function InputWithTags({
newFilters.sort = value as SearchSortType;
break;
default:
// Handle array types (cameras, labels, subLabels, zones)
// Handle array types (cameras, labels, sub_labels, attributes, zones)
if (!newFilters[type]) newFilters[type] = [];
if (Array.isArray(newFilters[type])) {
if (!(newFilters[type] as string[]).includes(value)) {

View File

@@ -84,6 +84,7 @@ import { LuInfo } from "react-icons/lu";
import { TooltipPortal } from "@radix-ui/react-tooltip";
import { FaPencilAlt } from "react-icons/fa";
import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog";
import AttributeSelectDialog from "@/components/overlay/dialog/AttributeSelectDialog";
import { Trans, useTranslation } from "react-i18next";
import { useIsAdmin } from "@/hooks/use-is-admin";
import { getTranslatedLabel } from "@/utils/i18n";
@@ -297,6 +298,7 @@ type DialogContentComponentProps = {
isPopoverOpen: boolean;
setIsPopoverOpen: (open: boolean) => void;
dialogContainer: HTMLDivElement | null;
setShowNavigationButtons: React.Dispatch<React.SetStateAction<boolean>>;
};
function DialogContentComponent({
@@ -314,6 +316,7 @@ function DialogContentComponent({
isPopoverOpen,
setIsPopoverOpen,
dialogContainer,
setShowNavigationButtons,
}: DialogContentComponentProps) {
if (page === "tracking_details") {
return (
@@ -399,6 +402,7 @@ function DialogContentComponent({
config={config}
setSearch={setSearch}
setInputFocused={setInputFocused}
setShowNavigationButtons={setShowNavigationButtons}
/>
</div>
</div>
@@ -415,6 +419,7 @@ function DialogContentComponent({
config={config}
setSearch={setSearch}
setInputFocused={setInputFocused}
setShowNavigationButtons={setShowNavigationButtons}
/>
</>
);
@@ -459,6 +464,7 @@ export default function SearchDetailDialog({
const [isOpen, setIsOpen] = useState(search != undefined);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const [showNavigationButtons, setShowNavigationButtons] = useState(false);
const dialogContentRef = useRef<HTMLDivElement | null>(null);
const [dialogContainer, setDialogContainer] = useState<HTMLDivElement | null>(
null,
@@ -540,9 +546,9 @@ export default function SearchDetailDialog({
onOpenChange={handleOpenChange}
enableHistoryBack={true}
>
{isDesktop && onPrevious && onNext && (
{isDesktop && onPrevious && onNext && showNavigationButtons && (
<DialogPortal>
<div className="pointer-events-none fixed inset-0 z-[200] flex items-center justify-center">
<div className="pointer-events-none fixed inset-0 z-[51] flex items-center justify-center">
<div
className={cn(
"relative flex items-center justify-between",
@@ -652,6 +658,7 @@ export default function SearchDetailDialog({
isPopoverOpen={isPopoverOpen}
setIsPopoverOpen={setIsPopoverOpen}
dialogContainer={dialogContainer}
setShowNavigationButtons={setShowNavigationButtons}
/>
</Content>
</Overlay>
@@ -664,12 +671,14 @@ type ObjectDetailsTabProps = {
config?: FrigateConfig;
setSearch: (search: SearchResult | undefined) => void;
setInputFocused: React.Dispatch<React.SetStateAction<boolean>>;
setShowNavigationButtons?: React.Dispatch<React.SetStateAction<boolean>>;
};
function ObjectDetailsTab({
search,
config,
setSearch,
setInputFocused,
setShowNavigationButtons,
}: ObjectDetailsTabProps) {
const { t, i18n } = useTranslation([
"views/explore",
@@ -678,6 +687,15 @@ function ObjectDetailsTab({
]);
const apiHost = useApiHost();
const hasCustomClassificationModels = useMemo(
() => Object.keys(config?.classification?.custom ?? {}).length > 0,
[config],
);
const { data: modelAttributes } = useSWR<Record<string, string[]>>(
hasCustomClassificationModels && search
? `classification/attributes?object_type=${encodeURIComponent(search.label)}&group_by_model=true`
: null,
);
// mutation / revalidation
@@ -708,6 +726,7 @@ function ObjectDetailsTab({
const [desc, setDesc] = useState(search?.data.description);
const [isSubLabelDialogOpen, setIsSubLabelDialogOpen] = useState(false);
const [isLPRDialogOpen, setIsLPRDialogOpen] = useState(false);
const [isAttributesDialogOpen, setIsAttributesDialogOpen] = useState(false);
const [isEditingDesc, setIsEditingDesc] = useState(false);
const originalDescRef = useRef<string | null>(null);
@@ -722,6 +741,19 @@ function ObjectDetailsTab({
// we have to make sure the current selected search item stays in sync
useEffect(() => setDesc(search?.data.description ?? ""), [search]);
useEffect(() => setIsAttributesDialogOpen(false), [search?.id]);
useEffect(() => {
const anyDialogOpen =
isSubLabelDialogOpen || isLPRDialogOpen || isAttributesDialogOpen;
setShowNavigationButtons?.(!anyDialogOpen);
}, [
isSubLabelDialogOpen,
isLPRDialogOpen,
isAttributesDialogOpen,
setShowNavigationButtons,
]);
const formattedDate = useFormattedTimestamp(
search?.start_time ?? 0,
config?.ui.time_format == "24hour"
@@ -807,6 +839,41 @@ function ObjectDetailsTab({
}
}, [search]);
// Extract current attribute selections grouped by model
const selectedAttributesByModel = useMemo(() => {
if (!search || !modelAttributes) {
return {};
}
const dataAny = search.data as Record<string, unknown>;
const selections: Record<string, string | null> = {};
// Initialize all models with null
Object.keys(modelAttributes).forEach((modelName) => {
selections[modelName] = null;
});
// Find which attribute is selected for each model
Object.keys(modelAttributes).forEach((modelName) => {
const value = dataAny[modelName];
if (
typeof value === "string" &&
modelAttributes[modelName].includes(value)
) {
selections[modelName] = value;
}
});
return selections;
}, [search, modelAttributes]);
// Get flat list of selected attributes for display
const eventAttributes = useMemo(() => {
return Object.values(selectedAttributesByModel)
.filter((attr): attr is string => attr !== null)
.sort((a, b) => a.localeCompare(b));
}, [selectedAttributesByModel]);
const isEventsKey = useCallback((key: unknown): boolean => {
const candidate = Array.isArray(key) ? key[0] : key;
const EVENTS_KEY_PATTERNS = ["events", "events/search", "events/explore"];
@@ -1048,6 +1115,74 @@ function ObjectDetailsTab({
[search, apiHost, mutate, setSearch, t, mapSearchResults, isEventsKey],
);
const handleAttributesSave = useCallback(
(selectedAttributes: string[]) => {
if (!search) return;
axios
.post(`${apiHost}api/events/${search.id}/attributes`, {
attributes: selectedAttributes,
})
.then((response) => {
const applied = Array.isArray(response.data?.applied)
? (response.data.applied as {
model?: string;
label?: string | null;
score?: number | null;
}[])
: [];
toast.success(t("details.item.toast.success.updatedAttributes"), {
position: "top-center",
});
const applyUpdatedAttributes = (event: SearchResult) => {
if (event.id !== search.id) return event;
const updatedData: Record<string, unknown> = { ...event.data };
applied.forEach(({ model, label, score }) => {
if (!model) return;
updatedData[model] = label ?? null;
updatedData[`${model}_score`] = score ?? null;
});
return { ...event, data: updatedData } as SearchResult;
};
mutate(
(key) => isEventsKey(key),
(currentData: SearchResult[][] | SearchResult[] | undefined) =>
mapSearchResults(currentData, applyUpdatedAttributes),
{
optimisticData: true,
rollbackOnError: true,
revalidate: false,
},
);
setSearch(applyUpdatedAttributes(search));
setIsAttributesDialogOpen(false);
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(
t("details.item.toast.error.updatedAttributesFailed", {
errorMessage,
}),
{
position: "top-center",
},
);
});
},
[search, apiHost, mutate, t, mapSearchResults, isEventsKey, setSearch],
);
// speech transcription
const onTranscribe = useCallback(() => {
@@ -1295,6 +1430,38 @@ function ObjectDetailsTab({
</div>
</div>
)}
{hasCustomClassificationModels &&
modelAttributes &&
Object.keys(modelAttributes).length > 0 && (
<div className="flex flex-col gap-1.5">
<div className="flex items-center gap-2 text-sm text-primary/40">
{t("details.attributes")}
{isAdmin && (
<Tooltip>
<TooltipTrigger asChild>
<span>
<FaPencilAlt
className="size-4 cursor-pointer text-primary/40 hover:text-primary/80"
onClick={() => setIsAttributesDialogOpen(true)}
/>
</span>
</TooltipTrigger>
<TooltipPortal>
<TooltipContent>
{t("button.edit", { ns: "common" })}
</TooltipContent>
</TooltipPortal>
</Tooltip>
)}
</div>
<div className="text-sm">
{eventAttributes.length > 0
? eventAttributes.join(", ")
: t("label.none", { ns: "common" })}
</div>
</div>
)}
</div>
</div>
@@ -1595,6 +1762,17 @@ function ObjectDetailsTab({
defaultValue={search?.data.recognized_license_plate || ""}
allowEmpty={true}
/>
<AttributeSelectDialog
open={isAttributesDialogOpen}
setOpen={setIsAttributesDialogOpen}
title={t("details.editAttributes.title")}
description={t("details.editAttributes.desc", {
label: search.label,
})}
onSave={handleAttributesSave}
selectedAttributes={selectedAttributesByModel}
modelAttributes={modelAttributes ?? {}}
/>
</div>
</div>
);

View File

@@ -0,0 +1,123 @@
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Label } from "@/components/ui/label";
import { Switch } from "@/components/ui/switch";
import { cn } from "@/lib/utils";
import { useCallback, useEffect, useState } from "react";
import { isDesktop } from "react-device-detect";
import { useTranslation } from "react-i18next";
type AttributeSelectDialogProps = {
open: boolean;
setOpen: (open: boolean) => void;
title: string;
description: string;
onSave: (selectedAttributes: string[]) => void;
selectedAttributes: Record<string, string | null>; // model -> selected attribute
modelAttributes: Record<string, string[]>; // model -> available attributes
className?: string;
};
export default function AttributeSelectDialog({
open,
setOpen,
title,
description,
onSave,
selectedAttributes,
modelAttributes,
className,
}: AttributeSelectDialogProps) {
const { t } = useTranslation();
const [internalSelection, setInternalSelection] = useState<
Record<string, string | null>
>({});
useEffect(() => {
if (open) {
setInternalSelection({ ...selectedAttributes });
}
}, [open, selectedAttributes]);
const handleSave = useCallback(() => {
// Convert from model->attribute map to flat list of attributes
const attributes = Object.values(internalSelection).filter(
(attr): attr is string => attr !== null,
);
onSave(attributes);
}, [internalSelection, onSave]);
const handleToggle = useCallback((modelName: string, attribute: string) => {
setInternalSelection((prev) => {
const currentSelection = prev[modelName];
// If clicking the currently selected attribute, deselect it
if (currentSelection === attribute) {
return { ...prev, [modelName]: null };
}
// Otherwise, select this attribute for this model
return { ...prev, [modelName]: attribute };
});
}, []);
return (
<Dialog open={open} onOpenChange={setOpen}>
<DialogContent
className={cn(className, isDesktop ? "max-w-md" : "max-w-[90%]")}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<DialogHeader>
<DialogTitle>{title}</DialogTitle>
<DialogDescription>{description}</DialogDescription>
</DialogHeader>
<div className="scrollbar-container overflow-y-auto">
<div className="max-h-[80dvh] space-y-6 py-2">
{Object.entries(modelAttributes).map(([modelName, attributes]) => (
<div key={modelName} className="space-y-3">
<div className="text-sm font-semibold text-primary-variant">
{modelName}
</div>
<div className="space-y-2 pl-2">
{attributes.map((attribute) => (
<div
key={attribute}
className="flex items-center justify-between gap-2"
>
<Label
htmlFor={`${modelName}-${attribute}`}
className="cursor-pointer text-sm text-primary"
>
{attribute}
</Label>
<Switch
id={`${modelName}-${attribute}`}
checked={internalSelection[modelName] === attribute}
onCheckedChange={() =>
handleToggle(modelName, attribute)
}
/>
</div>
))}
</div>
</div>
))}
</div>
</div>
<DialogFooter>
<Button type="button" onClick={() => setOpen(false)}>
{t("button.cancel")}
</Button>
<Button variant="select" onClick={handleSave}>
{t("button.save", { ns: "common" })}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
);
}

View File

@@ -0,0 +1,96 @@
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { cn } from "@/lib/utils";
import { useState } from "react";
import { isMobile } from "react-device-detect";
import { useTranslation } from "react-i18next";
import FilterSwitch from "@/components/filter/FilterSwitch";
type MultiSelectDialogProps = {
open: boolean;
title: string;
description?: string;
setOpen: (open: boolean) => void;
onSave: (selectedItems: string[]) => void;
selectedItems: string[];
availableItems: string[];
allowEmpty?: boolean;
};
export default function MultiSelectDialog({
open,
title,
description,
setOpen,
onSave,
selectedItems = [],
availableItems = [],
allowEmpty = false,
}: MultiSelectDialogProps) {
const { t } = useTranslation("common");
const [internalSelection, setInternalSelection] =
useState<string[]>(selectedItems);
// Reset internal selection when dialog opens
const handleOpenChange = (isOpen: boolean) => {
if (isOpen) {
setInternalSelection(selectedItems);
}
setOpen(isOpen);
};
const toggleItem = (item: string) => {
setInternalSelection((prev) =>
prev.includes(item) ? prev.filter((i) => i !== item) : [...prev, item],
);
};
const handleSave = () => {
if (!allowEmpty && internalSelection.length === 0) {
return;
}
onSave(internalSelection);
setOpen(false);
};
return (
<Dialog open={open} defaultOpen={false} onOpenChange={handleOpenChange}>
<DialogContent>
<DialogHeader>
<DialogTitle>{title}</DialogTitle>
{description && <DialogDescription>{description}</DialogDescription>}
</DialogHeader>
<div className="max-h-[80dvh] space-y-3 overflow-y-auto py-4">
{availableItems.map((item) => (
<FilterSwitch
key={item}
label={item}
isChecked={internalSelection.includes(item)}
onCheckedChange={() => toggleItem(item)}
/>
))}
</div>
<DialogFooter className={cn("pt-4", isMobile && "gap-2")}>
<Button type="button" onClick={() => setOpen(false)}>
{t("button.cancel")}
</Button>
<Button
variant="select"
type="button"
onClick={handleSave}
disabled={!allowEmpty && internalSelection.length === 0}
>
{t("button.save")}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
);
}

View File

@@ -65,6 +65,13 @@ export default function SearchFilterDialog({
const { t } = useTranslation(["components/filter"]);
const [currentFilter, setCurrentFilter] = useState(filter ?? {});
const { data: allSubLabels } = useSWR(["sub_labels", { split_joined: 1 }]);
const hasCustomClassificationModels = useMemo(
() => Object.keys(config?.classification?.custom ?? {}).length > 0,
[config],
);
const { data: allAttributes } = useSWR(
hasCustomClassificationModels ? "classification/attributes" : null,
);
const { data: allRecognizedLicensePlates } = useSWR<string[]>(
"recognized_license_plates",
);
@@ -91,8 +98,10 @@ export default function SearchFilterDialog({
(currentFilter.max_speed ?? 150) < 150 ||
(currentFilter.zones?.length ?? 0) > 0 ||
(currentFilter.sub_labels?.length ?? 0) > 0 ||
(hasCustomClassificationModels &&
(currentFilter.attributes?.length ?? 0) > 0) ||
(currentFilter.recognized_license_plate?.length ?? 0) > 0),
[currentFilter],
[currentFilter, hasCustomClassificationModels],
);
const trigger = (
@@ -133,6 +142,15 @@ export default function SearchFilterDialog({
setCurrentFilter({ ...currentFilter, sub_labels: newSubLabels })
}
/>
{hasCustomClassificationModels && (
<AttributeFilterContent
allAttributes={allAttributes}
attributes={currentFilter.attributes}
setAttributes={(newAttributes) =>
setCurrentFilter({ ...currentFilter, attributes: newAttributes })
}
/>
)}
<RecognizedLicensePlatesFilterContent
allRecognizedLicensePlates={allRecognizedLicensePlates}
recognizedLicensePlates={currentFilter.recognized_license_plate}
@@ -216,6 +234,7 @@ export default function SearchFilterDialog({
max_speed: undefined,
has_snapshot: undefined,
has_clip: undefined,
...(hasCustomClassificationModels && { attributes: undefined }),
recognized_license_plate: undefined,
}));
}}
@@ -1087,3 +1106,72 @@ export function RecognizedLicensePlatesFilterContent({
</div>
);
}
type AttributeFilterContentProps = {
allAttributes?: string[];
attributes: string[] | undefined;
setAttributes: (labels: string[] | undefined) => void;
};
export function AttributeFilterContent({
allAttributes,
attributes,
setAttributes,
}: AttributeFilterContentProps) {
const { t } = useTranslation(["components/filter"]);
const sortedAttributes = useMemo(
() =>
[...(allAttributes || [])].sort((a, b) =>
a.toLowerCase().localeCompare(b.toLowerCase()),
),
[allAttributes],
);
return (
<div className="overflow-x-hidden">
<DropdownMenuSeparator className="mb-3" />
<div className="text-lg">{t("attributes.label")}</div>
<div className="mb-5 mt-2.5 flex items-center justify-between">
<Label
className="mx-2 cursor-pointer text-primary"
htmlFor="allAttributes"
>
{t("attributes.all")}
</Label>
<Switch
className="ml-1"
id="allAttributes"
checked={attributes == undefined}
onCheckedChange={(isChecked) => {
if (isChecked) {
setAttributes(undefined);
}
}}
/>
</div>
<div className="mt-2.5 flex flex-col gap-2.5">
{sortedAttributes.map((item) => (
<FilterSwitch
key={item}
label={item.replaceAll("_", " ")}
isChecked={attributes?.includes(item) ?? false}
onCheckedChange={(isChecked) => {
if (isChecked) {
const updatedAttributes = attributes ? [...attributes] : [];
updatedAttributes.push(item);
setAttributes(updatedAttributes);
} else {
const updatedAttributes = attributes ? [...attributes] : [];
// can not deselect the last item
if (updatedAttributes.length > 1) {
updatedAttributes.splice(updatedAttributes.indexOf(item), 1);
setAttributes(updatedAttributes);
}
}
}}
/>
))}
</div>
</div>
);
}

View File

@@ -31,6 +31,7 @@ const SEARCH_FILTER_ARRAY_KEYS = [
"cameras",
"labels",
"sub_labels",
"attributes",
"recognized_license_plate",
"zones",
];
@@ -122,6 +123,7 @@ export default function Explore() {
cameras: searchSearchParams["cameras"],
labels: searchSearchParams["labels"],
sub_labels: searchSearchParams["sub_labels"],
attributes: searchSearchParams["attributes"],
recognized_license_plate:
searchSearchParams["recognized_license_plate"],
zones: searchSearchParams["zones"],
@@ -158,6 +160,7 @@ export default function Explore() {
cameras: searchSearchParams["cameras"],
labels: searchSearchParams["labels"],
sub_labels: searchSearchParams["sub_labels"],
attributes: searchSearchParams["attributes"],
recognized_license_plate:
searchSearchParams["recognized_license_plate"],
zones: searchSearchParams["zones"],

View File

@@ -5,6 +5,7 @@ const SEARCH_FILTERS = [
"general",
"zone",
"sub",
"attribute",
"source",
"sort",
] as const;
@@ -16,6 +17,7 @@ export const DEFAULT_SEARCH_FILTERS: SearchFilters[] = [
"general",
"zone",
"sub",
"attribute",
"source",
"sort",
];
@@ -71,6 +73,7 @@ export type SearchFilter = {
cameras?: string[];
labels?: string[];
sub_labels?: string[];
attributes?: string[];
recognized_license_plate?: string[];
zones?: string[];
before?: number;
@@ -95,6 +98,7 @@ export type SearchQueryParams = {
cameras?: string[];
labels?: string[];
sub_labels?: string[];
attributes?: string[];
recognized_license_plate?: string[];
zones?: string[];
before?: string;

View File

@@ -143,6 +143,13 @@ export default function SearchView({
}, [config, searchFilter, allowedCameras]);
const { data: allSubLabels } = useSWR("sub_labels");
const hasCustomClassificationModels = useMemo(
() => Object.keys(config?.classification?.custom ?? {}).length > 0,
[config],
);
const { data: allAttributes } = useSWR(
hasCustomClassificationModels ? "classification/attributes" : null,
);
const { data: allRecognizedLicensePlates } = useSWR(
"recognized_license_plates",
);
@@ -182,6 +189,7 @@ export default function SearchView({
labels: Object.values(allLabels || {}),
zones: Object.values(allZones || {}),
sub_labels: allSubLabels,
...(hasCustomClassificationModels && { attributes: allAttributes }),
search_type: ["thumbnail", "description"] as SearchSource[],
time_range:
config?.ui.time_format == "24hour"
@@ -204,9 +212,11 @@ export default function SearchView({
allLabels,
allZones,
allSubLabels,
allAttributes,
allRecognizedLicensePlates,
searchFilter,
allowedCameras,
hasCustomClassificationModels,
],
);