fix: Refactor update method to handle is_public cascading for related items

This commit is contained in:
Sean Morley
2026-01-11 13:01:00 -05:00
parent bc8bc4b487
commit fda1d039fd

View File

@@ -243,69 +243,6 @@ class CollectionViewSet(viewsets.ModelViewSet):
return Response(data)
# this make the is_public field of the collection cascade to the locations
@transaction.atomic
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
if 'collection' in serializer.validated_data:
new_collection = serializer.validated_data['collection']
# if the new collection is different from the old one and the user making the request is not the owner of the new collection return an error
if new_collection != instance.collection and new_collection.user != request.user:
return Response({"error": "User does not own the new collection"}, status=400)
# Check if the 'is_public' field is present in the update data
if 'is_public' in serializer.validated_data:
new_public_status = serializer.validated_data['is_public']
# if is_public has changed and the user is not the owner of the collection return an error
if new_public_status != instance.is_public and instance.user != request.user:
print(f"User {request.user.id} does not own the collection {instance.id} that is owned by {instance.user}")
return Response({"error": "User does not own the collection"}, status=400)
# Get all locations in this collection
locations_in_collection = Location.objects.filter(collections=instance)
if new_public_status:
# If collection becomes public, make all locations public
locations_in_collection.update(is_public=True)
else:
# If collection becomes private, check each location
# Only set a location to private if ALL of its collections are private
# Collect locations that do NOT belong to any other public collection (excluding the current one)
location_ids_to_set_private = []
for location in locations_in_collection:
has_public_collection = location.collections.filter(is_public=True).exclude(id=instance.id).exists()
if not has_public_collection:
location_ids_to_set_private.append(location.id)
# Bulk update those locations
Location.objects.filter(id__in=location_ids_to_set_private).update(is_public=False)
# Update transportations, notes, checklists, and lodgings related to this collection
# These still use direct ForeignKey relationships
Transportation.objects.filter(collection=instance).update(is_public=new_public_status)
Note.objects.filter(collection=instance).update(is_public=new_public_status)
Checklist.objects.filter(collection=instance).update(is_public=new_public_status)
Lodging.objects.filter(collection=instance).update(is_public=new_public_status)
# Log the action (optional)
action = "public" if new_public_status else "private"
print(f"Collection {instance.id} and its related objects were set to {action}")
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
return Response(serializer.data)
# make an action to retreive all locations that are shared with the user
@action(detail=False, methods=['get'])
def shared(self, request):
@@ -882,8 +819,9 @@ class CollectionViewSet(viewsets.ModelViewSet):
@transaction.atomic
def update(self, request, *args, **kwargs):
"""Override update to clean up out-of-range itinerary items when dates change."""
"""Override update to handle is_public cascading and clean up out-of-range itinerary items when dates change."""
instance = self.get_object()
old_is_public = instance.is_public
old_start_date = instance.start_date
old_end_date = instance.end_date
@@ -893,6 +831,59 @@ class CollectionViewSet(viewsets.ModelViewSet):
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
# Check if is_public changed
new_is_public = serializer.instance.is_public
is_public_changed = old_is_public != new_is_public
# Handle is_public cascading
if is_public_changed:
if new_is_public:
# Collection is being made public, update all linked items to public
serializer.instance.locations.filter(is_public=False).update(is_public=True)
serializer.instance.transportation_set.filter(is_public=False).update(is_public=True)
serializer.instance.note_set.filter(is_public=False).update(is_public=True)
serializer.instance.checklist_set.filter(is_public=False).update(is_public=True)
serializer.instance.lodging_set.filter(is_public=False).update(is_public=True)
else:
# Collection is being made private, check each linked item
# Only set an item to private if it doesn't belong to any other public collection
# Handle locations (many-to-many relationship)
locations_in_collection = serializer.instance.locations.filter(is_public=True)
for location in locations_in_collection:
# Check if this location belongs to any other public collection
has_other_public_collection = location.collections.filter(
is_public=True
).exclude(id=serializer.instance.id).exists()
if not has_other_public_collection:
location.is_public = False
location.save(update_fields=['is_public'])
# Handle transportations, notes, checklists, lodging (foreign key relationships)
# Transportation
transportations_to_check = serializer.instance.transportation_set.filter(is_public=True)
for transportation in transportations_to_check:
transportation.is_public = False
transportation.save(update_fields=['is_public'])
# Notes
notes_to_check = serializer.instance.note_set.filter(is_public=True)
for note in notes_to_check:
note.is_public = False
note.save(update_fields=['is_public'])
# Checklists
checklists_to_check = serializer.instance.checklist_set.filter(is_public=True)
for checklist in checklists_to_check:
checklist.is_public = False
checklist.save(update_fields=['is_public'])
# Lodging
lodging_to_check = serializer.instance.lodging_set.filter(is_public=True)
for lodging in lodging_to_check:
lodging.is_public = False
lodging.save(update_fields=['is_public'])
# Check if dates changed
new_start_date = serializer.instance.start_date
new_end_date = serializer.instance.end_date