Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,11 @@ def get_queryset(self) -> QuerySet:
qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)
if self.action == "retrieve":
qs = self.get_taxa_observed(
qs, project, include_unobserved=include_unobserved, apply_default_filters=False
qs,
project,
include_unobserved=include_unobserved,
apply_default_score_filter=True,
apply_default_taxa_filter=False,
)
qs = qs.prefetch_related(
Prefetch(
Expand All @@ -1519,7 +1523,12 @@ def get_queryset(self) -> QuerySet:
return qs

def get_taxa_observed(
self, qs: QuerySet, project: Project, include_unobserved=False, apply_default_filters=True
self,
qs: QuerySet,
project: Project,
include_unobserved=False,
apply_default_score_filter=True,
apply_default_taxa_filter=True,
) -> QuerySet:
"""
If a project is passed, only return taxa that have been observed.
Expand All @@ -1537,15 +1546,21 @@ def get_taxa_observed(
# Default filters are applied unconditionally here; the apply_default_score_filter
# and apply_default_taxa_filter flags are forwarded to build_occurrence_default_filters_q,
# which decides internally whether each filter contributes to the Q object.
from ami.main.models_future.filters import build_occurrence_default_filters_q

default_filters_q = build_occurrence_default_filters_q(project, self.request, occurrence_accessor="")
default_filters_q = build_occurrence_default_filters_q(
project,
self.request,
occurrence_accessor="",
apply_default_score_filter=apply_default_score_filter,
apply_default_taxa_filter=apply_default_taxa_filter,
)

# Combine base occurrence filters with default filters
base_filter = models.Q(
occurrence_filters,
determination_id=models.OuterRef("id"),
)
if apply_default_filters:
base_filter = base_filter & default_filters_q

base_filter = base_filter & default_filters_q

# Count occurrences - uses composite index (determination_id, project_id, event_id, determination_score)
occurrences_count_subquery = models.Subquery(
Expand Down
119 changes: 102 additions & 17 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,41 @@ def get_or_create_default_collection(project: "Project") -> "SourceImageCollecti
return collection


def get_project_default_filters():
    """
    Resolve the system-wide default taxa filters into Taxon objects.

    The taxa names come from Django settings (populated from environment
    variables); any name with no matching Taxon row is silently dropped.

    Returns a dict with two lists: ``default_include_taxa`` and
    ``default_exclude_taxa``.
    """

    def _taxa_named(names):
        # Materialize the queryset so callers get plain lists.
        return list(Taxon.objects.filter(name__in=names))

    return {
        "default_include_taxa": _taxa_named(settings.DEFAULT_INCLUDE_TAXA),
        "default_exclude_taxa": _taxa_named(settings.DEFAULT_EXCLUDE_TAXA),
    }


def get_or_create_default_project(user: User) -> "Project":
    """
    Return the user's "Scratch Project", creating it on first use.

    If the project already exists, it will be returned without modification.
    When a new project is created, default related objects (device, site,
    deployment, collection, processing service) and default taxa filters are
    initialized explicitly. ``get_or_create`` bypasses ``ProjectManager.create``,
    so we call ``create_related_defaults`` here instead of relying on the manager.
    """
    project, created = Project.objects.get_or_create(name="Scratch Project", owner=user)
    if created:
        logger.info(f"Created default project for user {user}")
        Project.objects.create_related_defaults(project)
        defaults = get_project_default_filters()

        # Only touch the M2M fields when there is something to set, so we
        # don't trigger m2m_changed signal handlers with empty changes.
        if defaults["default_include_taxa"]:
            project.default_filters_include_taxa.set(defaults["default_include_taxa"])
            logger.info(f"Set {len(defaults['default_include_taxa'])} default include taxa for project {project}")
        if defaults["default_exclude_taxa"]:
            project.default_filters_exclude_taxa.set(defaults["default_exclude_taxa"])
            logger.info(f"Set {len(defaults['default_exclude_taxa'])} default exclude taxa for project {project}")
        project.save()
    else:
        logger.info(f"Loaded existing default project for user {user}")
    return project


Expand Down Expand Up @@ -317,7 +342,7 @@ def summary_data(self):

def update_related_calculated_fields(self):
"""
Update calculated fields for all related events and deployments.
Update calculated fields for all related events, deployments, and source images.
"""
# Update events
for event in self.events.all():
Expand All @@ -327,6 +352,10 @@ def update_related_calculated_fields(self):
for deployment in self.deployments.all():
deployment.update_calculated_fields(save=True)

# Update source image cached detection counts using the project's default filters
# so SourceImage.detections_count stays consistent with get_detections_count().
update_detection_counts(qs=SourceImage.objects.filter(project=self), project=self)

def save(self, *args, **kwargs):
super().save(*args, **kwargs)
# Add owner to members
Expand Down Expand Up @@ -767,6 +796,23 @@ def get_first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.dat
)
return (first, last)

def get_detections_count(self) -> int | None:
    """
    Return detections count filtered by project default filters.

    Excludes null-bbox placeholder detections (records indicating an image
    was processed and no detections were found) to stay consistent with
    ``SourceImage.get_detections_count`` and ``Event.get_detections_count``.
    """
    detections = Detection.objects.filter(source_image__deployment=self).exclude(NULL_DETECTIONS_FILTER)
    default_filters = build_occurrence_default_filters_q(
        project=self.project,
        request=None,
        occurrence_accessor="occurrence",
    )
    return detections.filter(default_filters).distinct().count()

def first_date(self) -> datetime.date | None:
return self.first_capture_timestamp.date() if self.first_capture_timestamp else None

Expand Down Expand Up @@ -999,7 +1045,7 @@ def update_calculated_fields(self, save=False):

self.events_count = self.events.count()
self.captures_count = self.data_source_total_files or self.captures.count()
self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count()
self.detections_count = self.get_detections_count()
occ_qs = self.occurrences.filter(event__isnull=False).apply_default_filters( # type: ignore
project=self.project,
request=None,
Expand Down Expand Up @@ -1164,7 +1210,20 @@ def get_captures_count(self) -> int:
return self.captures.distinct().count()

def get_detections_count(self) -> int | None:
    """
    Return detections count filtered by project default filters.

    Excludes null-bbox placeholder detections to stay consistent with
    ``SourceImage.get_detections_count`` and ``Deployment.get_detections_count``.
    """
    qs = Detection.objects.filter(source_image__event=self).exclude(NULL_DETECTIONS_FILTER)
    filter_q = build_occurrence_default_filters_q(
        project=self.project,
        request=None,
        occurrence_accessor="occurrence",
    )

    return qs.filter(filter_q).distinct().count()

def get_occurrences_count(self, classification_threshold: float = 0) -> int:
"""
Expand Down Expand Up @@ -1889,9 +1948,23 @@ def size_display(self) -> str:
return filesizeformat(self.size)

def get_detections_count(self) -> int:
    """
    Return detections count filtered by project default filters.

    Excludes detections without bounding boxes — those are placeholder records
    indicating the image was successfully processed and no detections were found.
    """
    qs = self.detections.exclude(NULL_DETECTIONS_FILTER)
    project = self.project
    if not project:
        # Without a project there are no default filters to apply.
        return qs.distinct().count()

    q = build_occurrence_default_filters_q(
        project=project,
        request=None,
        occurrence_accessor="occurrence",
    )
    return qs.filter(q).distinct().count()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def get_was_processed(self, algorithm_key: str | None = None) -> bool:
"""
Expand Down Expand Up @@ -2069,22 +2142,34 @@ class Meta:
]


def update_detection_counts(qs: models.QuerySet[SourceImage] | None = None, null_only=False) -> int:
def update_detection_counts(
qs: models.QuerySet[SourceImage] | None = None,
null_only=False,
project: "Project | None" = None,
) -> int:
"""
Update the detection count for all source images using a bulk update query.

When ``project`` is provided, the count is filtered by that project's default
filters so the cached ``SourceImage.detections_count`` stays consistent with
``SourceImage.get_detections_count()``.

@TODO Needs testing.
"""
qs = qs or SourceImage.objects.all()
if null_only:
qs = qs.filter(detections_count__isnull=True)

detection_qs = Detection.objects.filter(source_image_id=models.OuterRef("pk")).exclude(NULL_DETECTIONS_FILTER)
if project is not None:
filter_q = build_occurrence_default_filters_q(
project=project,
request=None,
occurrence_accessor="occurrence",
)
detection_qs = detection_qs.filter(filter_q)
subquery = models.Subquery(
Detection.objects.filter(source_image_id=models.OuterRef("pk"))
.exclude(NULL_DETECTIONS_FILTER)
.values("source_image_id")
.annotate(count=models.Count("id"))
.values("count"),
detection_qs.values("source_image_id").annotate(count=models.Count("id")).values("count"),
output_field=models.IntegerField(),
)
start_time = time.time()
Expand Down
30 changes: 16 additions & 14 deletions ami/main/models_future/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def build_occurrence_default_filters_q(
project: "Project | None" = None,
request: "Request | None" = None,
occurrence_accessor: str = "",
apply_default_score_filter: bool = True,
apply_default_taxa_filter: bool = True,
) -> Q:
"""
Build a Q filter that applies default filters (score threshold + taxa) for Occurrence relationships.
Expand Down Expand Up @@ -194,19 +196,19 @@ def build_occurrence_default_filters_q(
return Q()

filter_q = Q()

# Build score threshold filter
score_threshold = get_default_classification_threshold(project, request)
filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)

# Build taxa inclusion/exclusion filter
# For taxa filtering, we need to append "__determination" to the occurrence accessor
prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
taxon_accessor = f"{prefix}determination"
include_taxa = project.default_filters_include_taxa.all()
exclude_taxa = project.default_filters_exclude_taxa.all()
taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
if taxa_q:
filter_q &= taxa_q
if apply_default_score_filter:
# Build score threshold filter
score_threshold = get_default_classification_threshold(project, request)
filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)
if apply_default_taxa_filter:
# Build taxa inclusion/exclusion filter
# For taxa filtering, we need to append "__determination" to the occurrence accessor
prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
taxon_accessor = f"{prefix}determination"
include_taxa = project.default_filters_include_taxa.all()
exclude_taxa = project.default_filters_exclude_taxa.all()
taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
if taxa_q:
filter_q &= taxa_q

return filter_q
89 changes: 88 additions & 1 deletion ami/main/signals.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import logging

from django.contrib.auth.models import Group
from django.db import transaction
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver
from guardian.shortcuts import assign_perm

from ami.main.models import Project
from ami.main.tasks import refresh_project_cached_counts
from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project

from .models import Project, User
from .models import User

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,3 +113,87 @@ def delete_project_groups(sender, instance, **kwargs):
prefix = f"{instance.pk}_"
# Find and delete all groups that start with {project_id}_
Group.objects.filter(name__startswith=prefix).delete()


# ============================================================================
# Project Default Filters Update Signals
# ============================================================================
# These signals handle efficient updates to calculated fields for project-related
# objects (such as Deployments and Events) whenever a project's default filter
# values change.
#
# Specifically, they trigger recalculation of cached counts when:
# - The project's default score threshold is updated
# - The project's default include taxa are modified
# - The project's default exclude taxa are modified
#
# This ensures that cached counts (e.g., occurrences_count, taxa_count) remain
# accurate and consistent with the active filter configuration for each project.
# ============================================================================


def refresh_cached_counts_for_project(project: Project):
    """
    Schedule an asynchronous refresh of cached counts for a project's
    Deployments and Events once the surrounding transaction commits.

    The recalculation can fan out over hundreds of events and dozens of
    deployments, so it is delegated to a Celery task rather than run inline
    in the request/save path. Registering the enqueue via
    ``transaction.on_commit`` guarantees the task fires only if the
    triggering save actually commits.
    """
    logger.info(f"Scheduling cached-count refresh for project {project.pk} ({project.name})")

    def _enqueue():
        refresh_project_cached_counts.delay(project.pk)

    transaction.on_commit(_enqueue)


@receiver(pre_save, sender=Project)
def cache_old_threshold(sender, instance, **kwargs):
    """
    Cache the previous default score threshold before saving the Project.

    We do this because:
    - In post_save, the instance already contains the NEW value.
    - To detect whether the threshold actually changed, we must read the OLD
      value from the database before the update happens.
    - This allows us to accurately detect threshold changes and then trigger
      recalculation of cached filtered counts (Events, Deployments, etc.).

    The cached value is stored on the instance as `_old_threshold` so it can be
    safely accessed in the post_save handler.
    """
    if instance.pk:
        # Fetch only the one field we need. Using filter(...).first() instead of
        # .get() also avoids raising Project.DoesNotExist if the row was deleted
        # concurrently; in that case the cached value is simply None.
        instance._old_threshold = (
            Project.objects.filter(pk=instance.pk)
            .values_list("default_filters_score_threshold", flat=True)
            .first()
        )
    else:
        instance._old_threshold = None


@receiver(post_save, sender=Project)
def threshold_updated(sender, instance, **kwargs):
    """
    After saving the Project, compare the previously cached threshold with the new value.
    If the default score threshold changed, we refresh all cached counts using the new filters.

    This two-step (pre_save + post_save) pattern is required because:
    - post_save instances already contain the updated value
    - so the old threshold would be lost without caching it in pre_save
    """
    # getattr guards against save paths where the pre_save handler did not run
    # (e.g. the handler was connected mid-process), which would otherwise raise
    # AttributeError on the raw attribute access.
    old_threshold = getattr(instance, "_old_threshold", None)
    new_threshold = instance.default_filters_score_threshold
    if old_threshold is not None and old_threshold != new_threshold:
        refresh_cached_counts_for_project(instance)


@receiver(m2m_changed, sender=Project.default_filters_include_taxa.through)
def include_taxa_updated(sender, instance: Project, action, **kwargs):
    """Refresh cached counts when the project's default include-taxa set changes."""
    # Only react once the M2M change has been committed to the relation.
    if action not in ("post_add", "post_remove", "post_clear"):
        return
    logger.info(f"Include taxa updated for project {instance.pk} (action={action})")
    refresh_cached_counts_for_project(instance)


@receiver(m2m_changed, sender=Project.default_filters_exclude_taxa.through)
def exclude_taxa_updated(sender, instance: Project, action, **kwargs):
    """Refresh cached counts when the project's default exclude-taxa set changes."""
    # Only react once the M2M change has been committed to the relation.
    if action not in ("post_add", "post_remove", "post_clear"):
        return
    logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})")
    refresh_cached_counts_for_project(instance)
Loading
Loading