From 6a0e31dcf9e6f87b7a305f621d9b0808330acc98 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Thu, 18 Dec 2025 08:35:47 -0600 Subject: [PATCH] Add object classification attributes to Tracked Object Details (#21348) * attributes endpoint * event endpoints * add attributes to more filters * add to suggestions and query in explore * support attributes in search input * i18n * add object type filter to endpoint * add attributes to tracked object details pane * add generic multi select dialog * save object attributes endpoint * add group by param to fetch attributes endpoint * add attribute editing to tracked object details * docs * fix docs * update openapi spec to match python --- .../object_classification.md | 2 +- .../state_classification.md | 2 +- docs/static/frigate-api.yaml | 74 +++++++ frigate/api/classification.py | 54 ++++++ .../api/defs/query/events_query_parameters.py | 3 + frigate/api/defs/request/events_body.py | 7 + frigate/api/event.py | 150 +++++++++++++++ web/public/locales/en/components/filter.json | 4 + web/public/locales/en/views/explore.json | 7 + web/public/locales/en/views/search.json | 1 + web/src/components/input/InputWithTags.tsx | 2 +- .../overlay/detail/SearchDetailDialog.tsx | 182 +++++++++++++++++- .../overlay/dialog/AttributeSelectDialog.tsx | 123 ++++++++++++ .../overlay/dialog/MultiSelectDialog.tsx | 96 +++++++++ .../overlay/dialog/SearchFilterDialog.tsx | 90 ++++++++- web/src/pages/Explore.tsx | 3 + web/src/types/search.ts | 4 + web/src/views/search/SearchView.tsx | 10 + 18 files changed, 808 insertions(+), 6 deletions(-) create mode 100644 web/src/components/overlay/dialog/AttributeSelectDialog.tsx create mode 100644 web/src/components/overlay/dialog/MultiSelectDialog.tsx diff --git a/docs/docs/configuration/custom_classification/object_classification.md b/docs/docs/configuration/custom_classification/object_classification.md index 37d908285..52056a007 100644 --- a/docs/docs/configuration/custom_classification/object_classification.md +++ b/docs/docs/configuration/custom_classification/object_classification.md @@ -3,7 +3,7 @@ id: object_classification title: Object Classification --- -Object classification allows you to train a custom MobileNetV2 classification model to run on tracked objects (persons, cars, animals, etc.) to identify a finer category or attribute for that object. +Object classification allows you to train a custom MobileNetV2 classification model to run on tracked objects (persons, cars, animals, etc.) to identify a finer category or attribute for that object. Classification results are visible in the Tracked Object Details pane in Explore, through the `frigate/tracked_object_details` MQTT topic, in Home Assistant sensors via the official Frigate integration, or through the event endpoints in the HTTP API. ## Minimum System Requirements diff --git a/docs/docs/configuration/custom_classification/state_classification.md b/docs/docs/configuration/custom_classification/state_classification.md index 6b95a2567..196ec78de 100644 --- a/docs/docs/configuration/custom_classification/state_classification.md +++ b/docs/docs/configuration/custom_classification/state_classification.md @@ -3,7 +3,7 @@ id: state_classification title: State Classification --- -State classification allows you to train a custom MobileNetV2 classification model on a fixed region of your camera frame(s) to determine a current state. The model can be configured to run on a schedule and/or when motion is detected in that region. +State classification allows you to train a custom MobileNetV2 classification model on a fixed region of your camera frame(s) to determine a current state. The model can be configured to run on a schedule and/or when motion is detected in that region. Classification results are available through the `frigate//classification/` MQTT topic and in Home Assistant sensors via the official Frigate integration. ## Minimum System Requirements diff --git a/docs/static/frigate-api.yaml b/docs/static/frigate-api.yaml index 624688965..1cfe1b91f 100644 --- a/docs/static/frigate-api.yaml +++ b/docs/static/frigate-api.yaml @@ -616,6 +616,32 @@ paths: application/json: schema: $ref: "#/components/schemas/HTTPValidationError" + /classification/attributes: + get: + tags: + - Classification + summary: Get custom classification attributes + description: |- + Returns custom classification attributes for a given object type. + Only includes models with classification_type set to 'attribute'. + By default returns a flat sorted list of all attribute labels. + If group_by_model is true, returns attributes grouped by model name. + operationId: get_custom_attributes_classification_attributes_get + parameters: + - name: object_type + in: query + schema: + type: string + - name: group_by_model + in: query + schema: + type: boolean + default: false + responses: + "200": + description: Successful Response + "422": + description: Validation Error /classification/{name}/dataset: get: tags: @@ -2912,6 +2938,42 @@ paths: application/json: schema: $ref: "#/components/schemas/HTTPValidationError" + /events/{event_id}/attributes: + post: + tags: + - Events + summary: Set custom classification attributes + description: |- + Sets an event's custom classification attributes for all attribute-type + models that apply to the event's object type. + Returns a success message or an error if the event is not found. + operationId: set_attributes_events__event_id__attributes_post + parameters: + - name: event_id + in: path + required: true + schema: + type: string + title: Event Id + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/EventsAttributesBody" + responses: + "200": + description: Successful Response + content: + application/json: + schema: + $ref: "#/components/schemas/GenericResponse" + "422": + description: Validation Error + content: + application/json: + schema: + $ref: "#/components/schemas/HTTPValidationError" /events/{event_id}/description: post: tags: @@ -4959,6 +5021,18 @@ components: required: - subLabel title: EventsSubLabelBody + EventsAttributesBody: + properties: + attributes: + type: object + title: Attributes + description: Object with model names as keys and attribute values + additionalProperties: + type: string + type: object + required: + - attributes + title: EventsAttributesBody ExportModel: properties: id: diff --git a/frigate/api/classification.py b/frigate/api/classification.py index deafaf956..18e590ce1 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -31,6 +31,7 @@ from frigate.api.defs.response.generic_response import GenericResponse from frigate.api.defs.tags import Tags from frigate.config import FrigateConfig from frigate.config.camera import DetectConfig +from frigate.config.classification import ObjectClassificationType from frigate.const import CLIPS_DIR, FACE_DIR, MODEL_CACHE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event @@ -622,6 +623,59 @@ def get_classification_dataset(name: str): ) +@router.get( + "/classification/attributes", + summary="Get custom classification attributes", + description="""Returns custom classification attributes for a given object type. + Only includes models with classification_type set to 'attribute'. + By default returns a flat sorted list of all attribute labels. + If group_by_model is true, returns attributes grouped by model name.""", +) +def get_custom_attributes( + request: Request, object_type: str = None, group_by_model: bool = False +): + models_with_attributes = {} + + for ( + model_key, + model_config, + ) in request.app.frigate_config.classification.custom.items(): + if ( + not model_config.enabled + or not model_config.object_config + or model_config.object_config.classification_type + != ObjectClassificationType.attribute + ): + continue + + model_objects = getattr(model_config.object_config, "objects", []) or [] + if object_type is not None and object_type not in model_objects: + continue + + dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(model_key), "dataset") + if not os.path.exists(dataset_dir): + continue + + attributes = [] + for category_name in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, category_name) + if os.path.isdir(category_dir) and category_name != "none": + attributes.append(category_name) + + if attributes: + model_name = model_config.name or model_key + models_with_attributes[model_name] = sorted(attributes) + + if group_by_model: + return JSONResponse(content=models_with_attributes) + else: + # Flatten to a unique sorted list + all_attributes = set() + for attributes in models_with_attributes.values(): + all_attributes.update(attributes) + return JSONResponse(content=sorted(list(all_attributes))) + + @router.get( "/classification/{name}/train", summary="Get classification train images", diff --git a/frigate/api/defs/query/events_query_parameters.py b/frigate/api/defs/query/events_query_parameters.py index 187dd3f91..8e5a5391a 100644 --- a/frigate/api/defs/query/events_query_parameters.py +++ b/frigate/api/defs/query/events_query_parameters.py @@ -12,6 +12,7 @@ class EventsQueryParams(BaseModel): labels: Optional[str] = "all" sub_label: Optional[str] = "all" sub_labels: Optional[str] = "all" + attributes: Optional[str] = "all" zone: Optional[str] = "all" zones: Optional[str] = "all" limit: Optional[int] = 100 @@ -58,6 +59,8 @@ class EventsSearchQueryParams(BaseModel): limit: Optional[int] = 50 cameras: Optional[str] = "all" labels: Optional[str] = "all" + sub_labels: Optional[str] = "all" + attributes: Optional[str] = "all" zones: Optional[str] = "all" after: Optional[float] = None before: Optional[float] = None diff --git a/frigate/api/defs/request/events_body.py b/frigate/api/defs/request/events_body.py index 6110e34f5..50754e92a 100644 --- a/frigate/api/defs/request/events_body.py +++ b/frigate/api/defs/request/events_body.py @@ -24,6 +24,13 @@ class EventsLPRBody(BaseModel): ) +class EventsAttributesBody(BaseModel): + attributes: List[str] = Field( + title="Selected classification attributes for the event", + default_factory=list, + ) + + class EventsDescriptionBody(BaseModel): description: Union[str, None] = Field(title="The description of the event") diff --git a/frigate/api/event.py b/frigate/api/event.py index fc78ac0e5..ea5cfb29c 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -37,6 +37,7 @@ from frigate.api.defs.query.regenerate_query_parameters import ( RegenerateQueryParameters, ) from frigate.api.defs.request.events_body import ( + EventsAttributesBody, EventsCreateBody, EventsDeleteBody, EventsDescriptionBody, @@ -55,6 +56,7 @@ from frigate.api.defs.response.event_response import ( from frigate.api.defs.response.generic_response import GenericResponse from frigate.api.defs.tags import Tags from frigate.comms.event_metadata_updater import EventMetadataTypeEnum +from frigate.config.classification import ObjectClassificationType from frigate.const import CLIPS_DIR, TRIGGER_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event, ReviewSegment, Timeline, Trigger @@ -99,6 +101,8 @@ def events( if sub_labels == "all" and sub_label != "all": sub_labels = sub_label + attributes = unquote(params.attributes) + zone = params.zone zones = params.zones @@ -187,6 +191,17 @@ def events( sub_label_clause = reduce(operator.or_, sub_label_clauses) clauses.append((sub_label_clause)) + if attributes != "all": + # Custom classification results are stored as data[model_name] = result_value + filtered_attributes = attributes.split(",") + attribute_clauses = [] + + for attr in filtered_attributes: + attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*') + + attribute_clause = reduce(operator.or_, attribute_clauses) + clauses.append(attribute_clause) + if recognized_license_plate != "all": filtered_recognized_license_plates = recognized_license_plate.split(",") @@ -492,6 +507,8 @@ def events_search( # Filters cameras = params.cameras labels = params.labels + sub_labels = params.sub_labels + attributes = params.attributes zones = params.zones after = params.after before = params.before @@ -566,6 +583,38 @@ def events_search( if labels != "all": event_filters.append((Event.label << labels.split(","))) + if sub_labels != "all": + # use matching so joined sub labels are included + # for example a sub label 'bob' would get events + # with sub labels 'bob' and 'bob, john' + sub_label_clauses = [] + filtered_sub_labels = sub_labels.split(",") + + if "None" in filtered_sub_labels: + filtered_sub_labels.remove("None") + sub_label_clauses.append((Event.sub_label.is_null())) + + for label in filtered_sub_labels: + sub_label_clauses.append( + (Event.sub_label.cast("text") == label) + ) # include exact matches + + # include this label when part of a list + sub_label_clauses.append((Event.sub_label.cast("text") % f"*{label},*")) + sub_label_clauses.append((Event.sub_label.cast("text") % f"*, {label}*")) + + event_filters.append((reduce(operator.or_, sub_label_clauses))) + + if attributes != "all": + # Custom classification results are stored as data[model_name] = result_value + filtered_attributes = attributes.split(",") + attribute_clauses = [] + + for attr in filtered_attributes: + attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*') + + event_filters.append(reduce(operator.or_, attribute_clauses)) + if zones != "all": zone_clauses = [] filtered_zones = zones.split(",") @@ -1351,6 +1400,107 @@ async def set_plate( ) +@router.post( + "/events/{event_id}/attributes", + response_model=GenericResponse, + dependencies=[Depends(require_role(["admin"]))], + summary="Set custom classification attributes", + description=( + "Sets an event's custom classification attributes for all attribute-type " + "models that apply to the event's object type." + ), +) +async def set_attributes( + request: Request, + event_id: str, + body: EventsAttributesBody, +): + try: + event: Event = Event.get(Event.id == event_id) + await require_camera_access(event.camera, request=request) + except DoesNotExist: + return JSONResponse( + content=({"success": False, "message": f"Event {event_id} not found."}), + status_code=404, + ) + + object_type = event.label + selected_attributes = set(body.attributes or []) + applied_updates: list[dict[str, str | float | None]] = [] + + for ( + model_key, + model_config, + ) in request.app.frigate_config.classification.custom.items(): + # Only apply to enabled attribute classifiers that target this object type + if ( + not model_config.enabled + or not model_config.object_config + or model_config.object_config.classification_type + != ObjectClassificationType.attribute + or object_type not in (model_config.object_config.objects or []) + ): + continue + + # Get available labels from dataset directory + dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(model_key), "dataset") + available_labels = set() + + if os.path.exists(dataset_dir): + for category_name in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, category_name) + if os.path.isdir(category_dir): + available_labels.add(category_name) + + if not available_labels: + logger.warning( + "No dataset found for custom attribute model %s at %s", + model_key, + dataset_dir, + ) + continue + + # Find all selected attributes that apply to this model + model_name = model_config.name or model_key + matching_attrs = selected_attributes & available_labels + + if matching_attrs: + # Publish updates for each selected attribute + for attr in matching_attrs: + request.app.event_metadata_updater.publish( + (event_id, model_name, attr, 1.0), + EventMetadataTypeEnum.attribute.value, + ) + applied_updates.append( + {"model": model_name, "label": attr, "score": 1.0} + ) + else: + # Clear this model's attribute + request.app.event_metadata_updater.publish( + (event_id, model_name, None, None), + EventMetadataTypeEnum.attribute.value, + ) + applied_updates.append({"model": model_name, "label": None, "score": None}) + + if len(applied_updates) == 0: + return JSONResponse( + content={ + "success": False, + "message": "No matching attributes found for this object type.", + }, + status_code=400, + ) + + return JSONResponse( + content={ + "success": True, + "message": f"Updated {len(applied_updates)} attribute(s)", + "applied": applied_updates, + }, + status_code=200, + ) + + @router.post( "/events/{event_id}/description", response_model=GenericResponse, diff --git a/web/public/locales/en/components/filter.json b/web/public/locales/en/components/filter.json index 177234bed..e9ae5c769 100644 --- a/web/public/locales/en/components/filter.json +++ b/web/public/locales/en/components/filter.json @@ -38,6 +38,10 @@ "label": "Sub Labels", "all": "All Sub Labels" }, + "attributes": { + "label": "Classification Attributes", + "all": "All Attributes" + }, "score": "Score", "estimatedSpeed": "Estimated Speed ({{unit}})", "features": { diff --git a/web/public/locales/en/views/explore.json b/web/public/locales/en/views/explore.json index 6c938c109..ff95e2fc6 100644 --- a/web/public/locales/en/views/explore.json +++ b/web/public/locales/en/views/explore.json @@ -104,12 +104,14 @@ "regenerate": "A new description has been requested from {{provider}}. Depending on the speed of your provider, the new description may take some time to regenerate.", "updatedSublabel": "Successfully updated sub label.", "updatedLPR": "Successfully updated license plate.", + "updatedAttributes": "Successfully updated attributes.", "audioTranscription": "Successfully requested audio transcription. Depending on the speed of your Frigate server, the transcription may take some time to complete." }, "error": { "regenerate": "Failed to call {{provider}} for a new description: {{errorMessage}}", "updatedSublabelFailed": "Failed to update sub label: {{errorMessage}}", "updatedLPRFailed": "Failed to update license plate: {{errorMessage}}", + "updatedAttributesFailed": "Failed to update attributes: {{errorMessage}}", "audioTranscription": "Failed to request audio transcription: {{errorMessage}}" } } @@ -125,6 +127,10 @@ "desc": "Enter a new license plate value for this {{label}}", "descNoLabel": "Enter a new license plate value for this tracked object" }, + "editAttributes": { + "title": "Edit attributes", + "desc": "Select classification attributes for this {{label}}" + }, "snapshotScore": { "label": "Snapshot Score" }, @@ -136,6 +142,7 @@ "label": "Score" }, "recognizedLicensePlate": "Recognized License Plate", + "attributes": "Classification Attributes", "estimatedSpeed": "Estimated Speed", "objects": "Objects", "camera": "Camera", diff --git a/web/public/locales/en/views/search.json b/web/public/locales/en/views/search.json index 22da7721f..dae622c70 100644 --- a/web/public/locales/en/views/search.json +++ b/web/public/locales/en/views/search.json @@ -16,6 +16,7 @@ "labels": "Labels", "zones": "Zones", "sub_labels": "Sub Labels", + "attributes": "Attributes", "search_type": "Search Type", "time_range": "Time Range", "before": "Before", diff --git a/web/src/components/input/InputWithTags.tsx b/web/src/components/input/InputWithTags.tsx index 298537136..70f1cd0c9 100755 --- a/web/src/components/input/InputWithTags.tsx +++ b/web/src/components/input/InputWithTags.tsx @@ -399,7 +399,7 @@ export default function InputWithTags({ newFilters.sort = value as SearchSortType; break; default: - // Handle array types (cameras, labels, subLabels, zones) + // Handle array types (cameras, labels, sub_labels, attributes, zones) if (!newFilters[type]) newFilters[type] = []; if (Array.isArray(newFilters[type])) { if (!(newFilters[type] as string[]).includes(value)) { diff --git a/web/src/components/overlay/detail/SearchDetailDialog.tsx b/web/src/components/overlay/detail/SearchDetailDialog.tsx index 1c46213df..392e929eb 100644 --- a/web/src/components/overlay/detail/SearchDetailDialog.tsx +++ b/web/src/components/overlay/detail/SearchDetailDialog.tsx @@ -84,6 +84,7 @@ import { LuInfo } from "react-icons/lu"; import { TooltipPortal } from "@radix-ui/react-tooltip"; import { FaPencilAlt } from "react-icons/fa"; import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog"; +import AttributeSelectDialog from "@/components/overlay/dialog/AttributeSelectDialog"; import { Trans, useTranslation } from "react-i18next"; import { useIsAdmin } from "@/hooks/use-is-admin"; import { getTranslatedLabel } from "@/utils/i18n"; @@ -297,6 +298,7 @@ type DialogContentComponentProps = { isPopoverOpen: boolean; setIsPopoverOpen: (open: boolean) => void; dialogContainer: HTMLDivElement | null; + setShowNavigationButtons: React.Dispatch>; }; function DialogContentComponent({ @@ -314,6 +316,7 @@ function DialogContentComponent({ isPopoverOpen, setIsPopoverOpen, dialogContainer, + setShowNavigationButtons, }: DialogContentComponentProps) { if (page === "tracking_details") { return ( @@ -399,6 +402,7 @@ function DialogContentComponent({ config={config} setSearch={setSearch} setInputFocused={setInputFocused} + setShowNavigationButtons={setShowNavigationButtons} /> @@ -415,6 +419,7 @@ function DialogContentComponent({ config={config} setSearch={setSearch} setInputFocused={setInputFocused} + setShowNavigationButtons={setShowNavigationButtons} /> ); @@ -459,6 +464,7 @@ export default function SearchDetailDialog({ const [isOpen, setIsOpen] = useState(search != undefined); const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [showNavigationButtons, setShowNavigationButtons] = useState(false); const dialogContentRef = useRef(null); const [dialogContainer, setDialogContainer] = useState( null, @@ -540,9 +546,9 @@ export default function SearchDetailDialog({ onOpenChange={handleOpenChange} enableHistoryBack={true} > - {isDesktop && onPrevious && onNext && ( + {isDesktop && onPrevious && onNext && showNavigationButtons && ( -
+
@@ -664,12 +671,14 @@ type ObjectDetailsTabProps = { config?: FrigateConfig; setSearch: (search: SearchResult | undefined) => void; setInputFocused: React.Dispatch>; + setShowNavigationButtons?: React.Dispatch>; }; function ObjectDetailsTab({ search, config, setSearch, setInputFocused, + setShowNavigationButtons, }: ObjectDetailsTabProps) { const { t, i18n } = useTranslation([ "views/explore", @@ -678,6 +687,15 @@ function ObjectDetailsTab({ ]); const apiHost = useApiHost(); + const hasCustomClassificationModels = useMemo( + () => Object.keys(config?.classification?.custom ?? {}).length > 0, + [config], + ); + const { data: modelAttributes } = useSWR>( + hasCustomClassificationModels && search + ? `classification/attributes?object_type=${encodeURIComponent(search.label)}&group_by_model=true` + : null, + ); // mutation / revalidation @@ -708,6 +726,7 @@ function ObjectDetailsTab({ const [desc, setDesc] = useState(search?.data.description); const [isSubLabelDialogOpen, setIsSubLabelDialogOpen] = useState(false); const [isLPRDialogOpen, setIsLPRDialogOpen] = useState(false); + const [isAttributesDialogOpen, setIsAttributesDialogOpen] = useState(false); const [isEditingDesc, setIsEditingDesc] = useState(false); const originalDescRef = useRef(null); @@ -722,6 +741,19 @@ function ObjectDetailsTab({ // we have to make sure the current selected search item stays in sync useEffect(() => setDesc(search?.data.description ?? ""), [search]); + useEffect(() => setIsAttributesDialogOpen(false), [search?.id]); + + useEffect(() => { + const anyDialogOpen = + isSubLabelDialogOpen || isLPRDialogOpen || isAttributesDialogOpen; + setShowNavigationButtons?.(!anyDialogOpen); + }, [ + isSubLabelDialogOpen, + isLPRDialogOpen, + isAttributesDialogOpen, + setShowNavigationButtons, + ]); + const formattedDate = useFormattedTimestamp( search?.start_time ?? 0, config?.ui.time_format == "24hour" @@ -807,6 +839,41 @@ function ObjectDetailsTab({ } }, [search]); + // Extract current attribute selections grouped by model + const selectedAttributesByModel = useMemo(() => { + if (!search || !modelAttributes) { + return {}; + } + + const dataAny = search.data as Record; + const selections: Record = {}; + + // Initialize all models with null + Object.keys(modelAttributes).forEach((modelName) => { + selections[modelName] = null; + }); + + // Find which attribute is selected for each model + Object.keys(modelAttributes).forEach((modelName) => { + const value = dataAny[modelName]; + if ( + typeof value === "string" && + modelAttributes[modelName].includes(value) + ) { + selections[modelName] = value; + } + }); + + return selections; + }, [search, modelAttributes]); + + // Get flat list of selected attributes for display + const eventAttributes = useMemo(() => { + return Object.values(selectedAttributesByModel) + .filter((attr): attr is string => attr !== null) + .sort((a, b) => a.localeCompare(b)); + }, [selectedAttributesByModel]); + const isEventsKey = useCallback((key: unknown): boolean => { const candidate = Array.isArray(key) ? key[0] : key; const EVENTS_KEY_PATTERNS = ["events", "events/search", "events/explore"]; @@ -1048,6 +1115,74 @@ function ObjectDetailsTab({ [search, apiHost, mutate, setSearch, t, mapSearchResults, isEventsKey], ); + const handleAttributesSave = useCallback( + (selectedAttributes: string[]) => { + if (!search) return; + + axios + .post(`${apiHost}api/events/${search.id}/attributes`, { + attributes: selectedAttributes, + }) + .then((response) => { + const applied = Array.isArray(response.data?.applied) + ? (response.data.applied as { + model?: string; + label?: string | null; + score?: number | null; + }[]) + : []; + + toast.success(t("details.item.toast.success.updatedAttributes"), { + position: "top-center", + }); + + const applyUpdatedAttributes = (event: SearchResult) => { + if (event.id !== search.id) return event; + + const updatedData: Record = { ...event.data }; + + applied.forEach(({ model, label, score }) => { + if (!model) return; + updatedData[model] = label ?? null; + updatedData[`${model}_score`] = score ?? null; + }); + + return { ...event, data: updatedData } as SearchResult; + }; + + mutate( + (key) => isEventsKey(key), + (currentData: SearchResult[][] | SearchResult[] | undefined) => + mapSearchResults(currentData, applyUpdatedAttributes), + { + optimisticData: true, + rollbackOnError: true, + revalidate: false, + }, + ); + + setSearch(applyUpdatedAttributes(search)); + setIsAttributesDialogOpen(false); + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + + toast.error( + t("details.item.toast.error.updatedAttributesFailed", { + errorMessage, + }), + { + position: "top-center", + }, + ); + }); + }, + [search, apiHost, mutate, t, mapSearchResults, isEventsKey, setSearch], + ); + // speech transcription const onTranscribe = useCallback(() => { @@ -1295,6 +1430,38 @@ function ObjectDetailsTab({
)} + + {hasCustomClassificationModels && + modelAttributes && + Object.keys(modelAttributes).length > 0 && ( +
+
+ {t("details.attributes")} + {isAdmin && ( + + + + setIsAttributesDialogOpen(true)} + /> + + + + + {t("button.edit", { ns: "common" })} + + + + )} +
+
+ {eventAttributes.length > 0 + ? eventAttributes.join(", ") + : t("label.none", { ns: "common" })} +
+
+ )}
@@ -1595,6 +1762,17 @@ function ObjectDetailsTab({ defaultValue={search?.data.recognized_license_plate || ""} allowEmpty={true} /> + ); diff --git a/web/src/components/overlay/dialog/AttributeSelectDialog.tsx b/web/src/components/overlay/dialog/AttributeSelectDialog.tsx new file mode 100644 index 000000000..b2ddc48ea --- /dev/null +++ b/web/src/components/overlay/dialog/AttributeSelectDialog.tsx @@ -0,0 +1,123 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Label } from "@/components/ui/label"; +import { Switch } from "@/components/ui/switch"; +import { cn } from "@/lib/utils"; +import { useCallback, useEffect, useState } from "react"; +import { isDesktop } from "react-device-detect"; +import { useTranslation } from "react-i18next"; + +type AttributeSelectDialogProps = { + open: boolean; + setOpen: (open: boolean) => void; + title: string; + description: string; + onSave: (selectedAttributes: string[]) => void; + selectedAttributes: Record; // model -> selected attribute + modelAttributes: Record; // model -> available attributes + className?: string; +}; + +export default function AttributeSelectDialog({ + open, + setOpen, + title, + description, + onSave, + selectedAttributes, + modelAttributes, + className, +}: AttributeSelectDialogProps) { + const { t } = useTranslation(); + const [internalSelection, setInternalSelection] = useState< + Record + >({}); + + useEffect(() => { + if (open) { + setInternalSelection({ ...selectedAttributes }); + } + }, [open, selectedAttributes]); + + const handleSave = useCallback(() => { + // Convert from model->attribute map to flat list of attributes + const attributes = Object.values(internalSelection).filter( + (attr): attr is string => attr !== null, + ); + onSave(attributes); + }, [internalSelection, onSave]); + + const handleToggle = useCallback((modelName: string, attribute: string) => { + setInternalSelection((prev) => { + const currentSelection = prev[modelName]; + // If clicking the currently selected attribute, deselect it + if (currentSelection === attribute) { + return { ...prev, [modelName]: null }; + } + // Otherwise, select this attribute for this model + return { ...prev, [modelName]: attribute }; + }); + }, []); + + return ( + + e.preventDefault()} + > + + {title} + {description} + +
+
+ {Object.entries(modelAttributes).map(([modelName, attributes]) => ( +
+
+ {modelName} +
+
+ {attributes.map((attribute) => ( +
+ + + handleToggle(modelName, attribute) + } + /> +
+ ))} +
+
+ ))} +
+
+ + + + +
+
+ ); +} diff --git a/web/src/components/overlay/dialog/MultiSelectDialog.tsx b/web/src/components/overlay/dialog/MultiSelectDialog.tsx new file mode 100644 index 000000000..b30ac9cf5 --- /dev/null +++ b/web/src/components/overlay/dialog/MultiSelectDialog.tsx @@ -0,0 +1,96 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { cn } from "@/lib/utils"; +import { useState } from "react"; +import { isMobile } from "react-device-detect"; +import { useTranslation } from "react-i18next"; +import FilterSwitch from "@/components/filter/FilterSwitch"; + +type MultiSelectDialogProps = { + open: boolean; + title: string; + description?: string; + setOpen: (open: boolean) => void; + onSave: (selectedItems: string[]) => void; + selectedItems: string[]; + availableItems: string[]; + allowEmpty?: boolean; +}; + +export default function MultiSelectDialog({ + open, + title, + description, + setOpen, + onSave, + selectedItems = [], + availableItems = [], + allowEmpty = false, +}: MultiSelectDialogProps) { + const { t } = useTranslation("common"); + const [internalSelection, setInternalSelection] = + useState(selectedItems); + + // Reset internal selection when dialog opens + const handleOpenChange = (isOpen: boolean) => { + if (isOpen) { + setInternalSelection(selectedItems); + } + setOpen(isOpen); + }; + + const toggleItem = (item: string) => { + setInternalSelection((prev) => + prev.includes(item) ? prev.filter((i) => i !== item) : [...prev, item], + ); + }; + + const handleSave = () => { + if (!allowEmpty && internalSelection.length === 0) { + return; + } + onSave(internalSelection); + setOpen(false); + }; + + return ( + + + + {title} + {description && {description}} + +
+ {availableItems.map((item) => ( + toggleItem(item)} + /> + ))} +
+ + + + +
+
+ ); +} diff --git a/web/src/components/overlay/dialog/SearchFilterDialog.tsx b/web/src/components/overlay/dialog/SearchFilterDialog.tsx index 3ee2052d0..eb1188257 100644 --- a/web/src/components/overlay/dialog/SearchFilterDialog.tsx +++ b/web/src/components/overlay/dialog/SearchFilterDialog.tsx @@ -65,6 +65,13 @@ export default function SearchFilterDialog({ const { t } = useTranslation(["components/filter"]); const [currentFilter, setCurrentFilter] = useState(filter ?? {}); const { data: allSubLabels } = useSWR(["sub_labels", { split_joined: 1 }]); + const hasCustomClassificationModels = useMemo( + () => Object.keys(config?.classification?.custom ?? {}).length > 0, + [config], + ); + const { data: allAttributes } = useSWR( + hasCustomClassificationModels ? "classification/attributes" : null, + ); const { data: allRecognizedLicensePlates } = useSWR( "recognized_license_plates", ); @@ -91,8 +98,10 @@ export default function SearchFilterDialog({ (currentFilter.max_speed ?? 150) < 150 || (currentFilter.zones?.length ?? 0) > 0 || (currentFilter.sub_labels?.length ?? 0) > 0 || + (hasCustomClassificationModels && + (currentFilter.attributes?.length ?? 0) > 0) || (currentFilter.recognized_license_plate?.length ?? 0) > 0), - [currentFilter], + [currentFilter, hasCustomClassificationModels], ); const trigger = ( @@ -133,6 +142,15 @@ export default function SearchFilterDialog({ setCurrentFilter({ ...currentFilter, sub_labels: newSubLabels }) } /> + {hasCustomClassificationModels && ( + + setCurrentFilter({ ...currentFilter, attributes: newAttributes }) + } + /> + )} ); } + +type AttributeFilterContentProps = { + allAttributes?: string[]; + attributes: string[] | undefined; + setAttributes: (labels: string[] | undefined) => void; +}; +export function AttributeFilterContent({ + allAttributes, + attributes, + setAttributes, +}: AttributeFilterContentProps) { + const { t } = useTranslation(["components/filter"]); + const sortedAttributes = useMemo( + () => + [...(allAttributes || [])].sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()), + ), + [allAttributes], + ); + return ( +
+ +
{t("attributes.label")}
+
+ + { + if (isChecked) { + setAttributes(undefined); + } + }} + /> +
+
+ {sortedAttributes.map((item) => ( + { + if (isChecked) { + const updatedAttributes = attributes ? [...attributes] : []; + + updatedAttributes.push(item); + setAttributes(updatedAttributes); + } else { + const updatedAttributes = attributes ? [...attributes] : []; + + // can not deselect the last item + if (updatedAttributes.length > 1) { + updatedAttributes.splice(updatedAttributes.indexOf(item), 1); + setAttributes(updatedAttributes); + } + } + }} + /> + ))} +
+
+ ); +} diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index 53ebd0401..8f50e982e 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -31,6 +31,7 @@ const SEARCH_FILTER_ARRAY_KEYS = [ "cameras", "labels", "sub_labels", + "attributes", "recognized_license_plate", "zones", ]; @@ -122,6 +123,7 @@ export default function Explore() { cameras: searchSearchParams["cameras"], labels: searchSearchParams["labels"], sub_labels: searchSearchParams["sub_labels"], + attributes: searchSearchParams["attributes"], recognized_license_plate: searchSearchParams["recognized_license_plate"], zones: searchSearchParams["zones"], @@ -158,6 +160,7 @@ export default function Explore() { cameras: searchSearchParams["cameras"], labels: searchSearchParams["labels"], sub_labels: searchSearchParams["sub_labels"], + attributes: searchSearchParams["attributes"], recognized_license_plate: searchSearchParams["recognized_license_plate"], zones: searchSearchParams["zones"], diff --git a/web/src/types/search.ts b/web/src/types/search.ts index 8fb81dfc6..d47e95584 100644 --- a/web/src/types/search.ts +++ b/web/src/types/search.ts @@ -5,6 +5,7 @@ const SEARCH_FILTERS = [ "general", "zone", "sub", + "attribute", "source", "sort", ] as const; @@ -16,6 +17,7 @@ export const DEFAULT_SEARCH_FILTERS: SearchFilters[] = [ "general", "zone", "sub", + "attribute", "source", "sort", ]; @@ -71,6 +73,7 @@ export type SearchFilter = { cameras?: string[]; labels?: string[]; sub_labels?: string[]; + attributes?: string[]; recognized_license_plate?: string[]; zones?: string[]; before?: number; @@ -95,6 +98,7 @@ export type SearchQueryParams = { cameras?: string[]; labels?: string[]; sub_labels?: string[]; + attributes?: string[]; recognized_license_plate?: string[]; zones?: string[]; before?: string; diff --git a/web/src/views/search/SearchView.tsx b/web/src/views/search/SearchView.tsx index 426b7e209..a373acc82 100644 --- a/web/src/views/search/SearchView.tsx +++ b/web/src/views/search/SearchView.tsx @@ -143,6 +143,13 @@ export default function SearchView({ }, [config, searchFilter, allowedCameras]); const { data: allSubLabels } = useSWR("sub_labels"); + const hasCustomClassificationModels = useMemo( + () => Object.keys(config?.classification?.custom ?? {}).length > 0, + [config], + ); + const { data: allAttributes } = useSWR( + hasCustomClassificationModels ? "classification/attributes" : null, + ); const { data: allRecognizedLicensePlates } = useSWR( "recognized_license_plates", ); @@ -182,6 +189,7 @@ export default function SearchView({ labels: Object.values(allLabels || {}), zones: Object.values(allZones || {}), sub_labels: allSubLabels, + ...(hasCustomClassificationModels && { attributes: allAttributes }), search_type: ["thumbnail", "description"] as SearchSource[], time_range: config?.ui.time_format == "24hour" @@ -204,9 +212,11 @@ export default function SearchView({ allLabels, allZones, allSubLabels, + allAttributes, allRecognizedLicensePlates, searchFilter, allowedCameras, + hasCustomClassificationModels, ], );