diff --git a/alembic/versions/20260304_174059_38e60731b8fc_add_generic_task_models.py b/alembic/versions/20260304_174059_38e60731b8fc_add_generic_task_models.py new file mode 100644 index 00000000..793e43d0 --- /dev/null +++ b/alembic/versions/20260304_174059_38e60731b8fc_add_generic_task_models.py @@ -0,0 +1,370 @@ +"""Add generic task models + +Revision ID: 38e60731b8fc +Revises: 523e523531a7 +Create Date: 2026-03-04 17:40:59.423728 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from alembic_postgresql_enum import TableReference +from sqlalchemy.dialects import postgresql + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "38e60731b8fc" +down_revision: Union[str, None] = "523e523531a7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum( + "circuit_simulation__config_generation", + "circuit_simulation__execution", + "circuit_extraction__config_generation", + "circuit_extraction__execution", + "ion_channel_modeling__config_generation", + "ion_channel_modeling__execution", + "skeletonization__config_generation", + "skeletonization__execution", + "ion_channel_simulation__config_generation", + "ion_channel_simulation__execution", + "em_synapse_mapping__config_generation", + "em_synapse_mapping__execution", + name="taskactivitytype", + ).create(op.get_bind()) + sa.Enum( + "circuit_simulation__campaign", + "circuit_simulation__config", + "circuit_extraction__campaign", + "circuit_extraction__config", + "ion_channel_modeling__campaign", + "ion_channel_modeling__config", + "skeletonization__campaign", + "skeletonization__config", + "ion_channel_simulation__campaign", + "ion_channel_simulation__config", + "em_synapse_mapping__campaign", + "em_synapse_mapping__config", + name="taskconfigtype", + ).create(op.get_bind()) + op.create_table( + "task_activity", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "task_activity_type", + postgresql.ENUM( + "circuit_simulation__config_generation", + "circuit_simulation__execution", + "circuit_extraction__config_generation", + "circuit_extraction__execution", + "ion_channel_modeling__config_generation", + "ion_channel_modeling__execution", + "skeletonization__config_generation", + "skeletonization__execution", + "ion_channel_simulation__config_generation", + "ion_channel_simulation__execution", + "em_synapse_mapping__config_generation", + "em_synapse_mapping__execution", + name="taskactivitytype", + create_type=False, + ), + nullable=False, + ), + sa.Column( + "executor", + postgresql.ENUM( + "single_node_job", + "distributed_job", + "jupyter_notebook", + name="executortype", + create_type=False, + ), + nullable=True, + ), + sa.Column("execution_id", sa.Uuid(), nullable=True), + sa.ForeignKeyConstraint(["id"], ["activity.id"], name=op.f("fk_task_activity_id_activity")), + sa.PrimaryKeyConstraint("id", name=op.f("pk_task_activity")), + ) + op.create_index( + op.f("ix_task_activity_task_activity_type"), + "task_activity", + ["task_activity_type"], + unique=False, + ) + op.create_table( + "task_config", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "task_config_type", + postgresql.ENUM( + "circuit_simulation__campaign", + "circuit_simulation__config", + "circuit_extraction__campaign", + "circuit_extraction__config", + "ion_channel_modeling__campaign", + "ion_channel_modeling__config", + "skeletonization__campaign", + "skeletonization__config", + "ion_channel_simulation__campaign", + "ion_channel_simulation__config", + "em_synapse_mapping__campaign", + "em_synapse_mapping__config", + name="taskconfigtype", + create_type=False, + ), + nullable=False, + ), + sa.Column( + "meta", postgresql.JSONB(astext_type=sa.Text()), server_default="{}", nullable=False + ), + sa.Column("task_config_generator_id", sa.Uuid(), nullable=True), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("description_vector", postgresql.TSVECTOR(), nullable=True), + sa.ForeignKeyConstraint(["id"], ["entity.id"], name=op.f("fk_task_config_id_entity")), + sa.ForeignKeyConstraint( + ["task_config_generator_id"], + ["task_config.id"], + name=op.f("fk_task_config_task_config_generator_id_task_config"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_task_config")), + ) + op.create_index( + "ix_task_config_description_vector", + "task_config", + ["description_vector"], + unique=False, + postgresql_using="gin", + ) + op.create_index(op.f("ix_task_config_name"), "task_config", ["name"], unique=False) + op.create_index( + op.f("ix_task_config_task_config_generator_id"), + "task_config", + ["task_config_generator_id"], + unique=False, + ) + op.create_index( + op.f("ix_task_config_task_config_type"), "task_config", ["task_config_type"], unique=False + ) + op.create_table( + "task_config__entity", + sa.Column("task_config_id", sa.Uuid(), nullable=False), + sa.Column("entity_id", sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint( + ["entity_id"], + ["entity.id"], + name=op.f("fk_task_config__entity_entity_id_entity"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["task_config_id"], + ["task_config.id"], + name=op.f("fk_task_config__entity_task_config_id_task_config"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("task_config_id", "entity_id", name=op.f("pk_task_config__entity")), + ) + op.sync_enum_values( + enum_schema="public", + enum_name="activitytype", + new_values=[ + "simulation_execution", + "simulation_generation", + "validation", + "calibration", + "analysis_notebook_execution", + "ion_channel_modeling_execution", + "ion_channel_modeling_config_generation", + "circuit_extraction_config_generation", + "circuit_extraction_execution", + "skeletonization_execution", + "skeletonization_config_generation", + "task_activity", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="activity", column_name="type") + ], + enum_values_to_rename=[], + ) + op.sync_enum_values( + enum_schema="public", + enum_name="entitytype", + new_values=[ + "analysis_software_source_code", + "brain_atlas", + "brain_atlas_region", + "cell_composition", + "cell_morphology", + "cell_morphology_protocol", + "electrical_cell_recording", + "electrical_recording", + "electrical_recording_stimulus", + "emodel", + "experimental_bouton_density", + "experimental_neuron_density", + "experimental_synapses_per_connection", + "external_url", + "ion_channel_model", + "ion_channel_modeling_campaign", + "ion_channel_modeling_config", + "ion_channel_recording", + "memodel", + "memodel_calibration_result", + "me_type_density", + "simulation", + "simulation_campaign", + "simulation_result", + "scientific_artifact", + "single_neuron_simulation", + "single_neuron_synaptome", + "single_neuron_synaptome_simulation", + "subject", + "validation_result", + "circuit", + "circuit_extraction_campaign", + "circuit_extraction_config", + "em_dense_reconstruction_dataset", + "em_cell_mesh", + "analysis_notebook_template", + "analysis_notebook_environment", + "analysis_notebook_result", + "skeletonization_config", + "skeletonization_campaign", + "task_config", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="entity", column_name="type"), + TableReference( + table_schema="public", table_name="measurement_label", column_name="entity_type" + ), + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="entitytype", + new_values=[ + "analysis_software_source_code", + "brain_atlas", + "brain_atlas_region", + "cell_composition", + "cell_morphology", + "cell_morphology_protocol", + "electrical_cell_recording", + "electrical_recording", + "electrical_recording_stimulus", + "emodel", + "experimental_bouton_density", + "experimental_neuron_density", + "experimental_synapses_per_connection", + "external_url", + "ion_channel_model", + "ion_channel_modeling_campaign", + "ion_channel_modeling_config", + "ion_channel_recording", + "memodel", + "memodel_calibration_result", + "me_type_density", + "simulation", + "simulation_campaign", + "simulation_result", + "scientific_artifact", + "single_neuron_simulation", + "single_neuron_synaptome", + "single_neuron_synaptome_simulation", + "subject", + "validation_result", + "circuit", + "circuit_extraction_campaign", + "circuit_extraction_config", + "em_dense_reconstruction_dataset", + "em_cell_mesh", + "analysis_notebook_template", + "analysis_notebook_environment", + "analysis_notebook_result", + "skeletonization_config", + "skeletonization_campaign", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="entity", column_name="type"), + TableReference( + table_schema="public", table_name="measurement_label", column_name="entity_type" + ), + ], + enum_values_to_rename=[], + ) + op.sync_enum_values( + enum_schema="public", + enum_name="activitytype", + new_values=[ + "simulation_execution", + "simulation_generation", + "validation", + "calibration", + "analysis_notebook_execution", + "ion_channel_modeling_execution", + "ion_channel_modeling_config_generation", + "circuit_extraction_config_generation", + "circuit_extraction_execution", + "skeletonization_execution", + "skeletonization_config_generation", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="activity", column_name="type") + ], + enum_values_to_rename=[], + ) + op.drop_table("task_config__entity") + op.drop_index(op.f("ix_task_config_task_config_type"), table_name="task_config") + op.drop_index(op.f("ix_task_config_task_config_generator_id"), table_name="task_config") + op.drop_index(op.f("ix_task_config_name"), table_name="task_config") + op.drop_index( + "ix_task_config_description_vector", table_name="task_config", postgresql_using="gin" + ) + op.drop_table("task_config") + op.drop_index(op.f("ix_task_activity_task_activity_type"), table_name="task_activity") + op.drop_table("task_activity") + sa.Enum( + "circuit_simulation__campaign", + "circuit_simulation__config", + "circuit_extraction__campaign", + "circuit_extraction__config", + "ion_channel_modeling__campaign", + "ion_channel_modeling__config", + "skeletonization__campaign", + "skeletonization__config", + "ion_channel_simulation__campaign", + "ion_channel_simulation__config", + "em_synapse_mapping__campaign", + "em_synapse_mapping__config", + name="taskconfigtype", + ).drop(op.get_bind()) + sa.Enum( + "circuit_simulation__config_generation", + "circuit_simulation__execution", + "circuit_extraction__config_generation", + "circuit_extraction__execution", + "ion_channel_modeling__config_generation", + "ion_channel_modeling__execution", + "skeletonization__config_generation", + "skeletonization__execution", + "ion_channel_simulation__config_generation", + "ion_channel_simulation__execution", + "em_synapse_mapping__config_generation", + "em_synapse_mapping__execution", + name="taskactivitytype", + ).drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/alembic/versions/20260304_174105_1f5cf23383af_update_triggers.py b/alembic/versions/20260304_174105_1f5cf23383af_update_triggers.py new file mode 100644 index 00000000..1640a267 --- /dev/null +++ b/alembic/versions/20260304_174105_1f5cf23383af_update_triggers.py @@ -0,0 +1,85 @@ +"""Update triggers + +Revision ID: 1f5cf23383af +Revises: 38e60731b8fc +Create Date: 2026-03-04 17:41:05.591001 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from alembic_utils.pg_function import PGFunction +from sqlalchemy import text as sql_text +from alembic_utils.pg_trigger import PGTrigger +from sqlalchemy import text as sql_text + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "1f5cf23383af" +down_revision: Union[str, None] = "38e60731b8fc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + public_task_config_task_config_description_vector = PGTrigger( + schema="public", + signature="task_config_description_vector", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION\n tsvector_update_trigger(description_vector, 'pg_catalog.english', description, name)", + ) + op.create_entity(public_task_config_task_config_description_vector) + + public_auth_fnc_task_config_task_config_generator_id = PGFunction( + schema="public", + signature="auth_fnc_task_config_task_config_generator_id()", + definition="RETURNS TRIGGER AS $$\n BEGIN\n IF NEW.task_config_generator_id IS NULL THEN RETURN NEW; END IF;\n IF NOT EXISTS (\n SELECT 1 FROM entity e1\n JOIN entity e2 ON e2.id = NEW.id\n WHERE e1.id = NEW.task_config_generator_id\n AND (e1.authorized_public = TRUE\n OR (e2.authorized_public = FALSE\n AND e1.authorized_project_id = e2.authorized_project_id\n )\n )\n ) THEN\n RAISE EXCEPTION 'unauthorized private reference\\: task_config.task_config_generator_id'\n USING ERRCODE = '42501'; -- Insufficient Privilege\n END IF;\n RETURN NEW;\n END;\n $$ LANGUAGE plpgsql", + ) + op.create_entity(public_auth_fnc_task_config_task_config_generator_id) + + public_task_config_auth_trg_task_config_task_config_generator_id = PGTrigger( + schema="public", + signature="auth_trg_task_config_task_config_generator_id", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION auth_fnc_task_config_task_config_generator_id()", + ) + op.create_entity(public_task_config_auth_trg_task_config_task_config_generator_id) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + public_task_config_auth_trg_task_config_task_config_generator_id = PGTrigger( + schema="public", + signature="auth_trg_task_config_task_config_generator_id", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION auth_fnc_task_config_task_config_generator_id()", + ) + op.drop_entity(public_task_config_auth_trg_task_config_task_config_generator_id) + + public_auth_fnc_task_config_task_config_generator_id = PGFunction( + schema="public", + signature="auth_fnc_task_config_task_config_generator_id()", + definition="RETURNS TRIGGER AS $$\n BEGIN\n IF NEW.task_config_generator_id IS NULL THEN RETURN NEW; END IF;\n IF NOT EXISTS (\n SELECT 1 FROM entity e1\n JOIN entity e2 ON e2.id = NEW.id\n WHERE e1.id = NEW.task_config_generator_id\n AND (e1.authorized_public = TRUE\n OR (e2.authorized_public = FALSE\n AND e1.authorized_project_id = e2.authorized_project_id\n )\n )\n ) THEN\n RAISE EXCEPTION 'unauthorized private reference\\: task_config.task_config_generator_id'\n USING ERRCODE = '42501'; -- Insufficient Privilege\n END IF;\n RETURN NEW;\n END;\n $$ LANGUAGE plpgsql", + ) + op.drop_entity(public_auth_fnc_task_config_task_config_generator_id) + + public_task_config_task_config_description_vector = PGTrigger( + schema="public", + signature="task_config_description_vector", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION\n tsvector_update_trigger(description_vector, 'pg_catalog.english', description, name)", + ) + op.drop_entity(public_task_config_task_config_description_vector) + + # ### end Alembic commands ### diff --git a/app/db/auth.py b/app/db/auth.py index 352903ec..6971d3c1 100644 --- a/app/db/auth.py +++ b/app/db/auth.py @@ -4,11 +4,9 @@ from pydantic import UUID4 from sqlalchemy import Delete, Select, and_, false, not_, or_, select, true -from sqlalchemy.orm import Query, Session +from sqlalchemy.orm import Query -from app.config import settings -from app.db.model import Entity, Identifiable -from app.queries.utils import get_user +from app.db.model import Entity from app.schemas.auth import UserContext @@ -93,34 +91,3 @@ def select_unauthorized_entities(ids: list[UUID4], project_id: UUID4 | None) -> ), ) ) - - -def is_user_authorized_for_deletion( # noqa: PLR0911 - db: Session, user_context: UserContext, obj: Identifiable -) -> bool: - if settings.APP_DISABLE_AUTH: - return True - - # if there is no authorized_project_id it is a global resource - if not (project_id := getattr(obj, "authorized_project_id", None)): - return False - - # Service maintainers may delete public/private entities within their projects. - if user_context.is_service_maintainer: - return project_id in user_context.user_project_ids - - # from here and below public entities cannot be deleted - if obj.authorized_public: # pyright: ignore [reportAttributeAccessIssue] - return False - - # Project admins may delete private entities within their projects - if project_id in user_context.admin_project_ids: - return True - - # Project members may delete only the private entities they themselves created - if project_id in user_context.member_project_ids and ( - db_user := get_user(db, user_context.profile.subject) - ): - return db_user.created_by_id == obj.created_by_id - - return False diff --git a/app/db/model.py b/app/db/model.py index a1d7a9f0..384c0548 100644 --- a/app/db/model.py +++ b/app/db/model.py @@ -72,6 +72,8 @@ StainingType, StorageType, StructuralDomain, + TaskActivityType, + TaskConfigType, ValidationStatus, ) from app.schemas.publication import Author @@ -2249,4 +2251,70 @@ class SkeletonizationExecution(Activity, ExecutionActivityMixin): __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 +class TaskConfigToEntity(Base): + """Represents the many-to-many associations between task configs and entities used as input.""" + + __tablename__ = "task_config__entity" + + task_config_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey(f"{EntityType.task_config}.id", ondelete="CASCADE"), + primary_key=True, + ) + entity_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("entity.id", ondelete="CASCADE"), + primary_key=True, + ) + + +class TaskConfig(NameDescriptionVectorMixin, Entity): + """Represents the configuration of a generic task. + + Assets: + - task configuration file. + + Attributes: + id (uuid.UUID): Primary key referencing the entity ID. + task_config_type: Type of task config. + meta (JSON_DICT): Meta parameters for the task. + task_config_generator_id: id of the task that generated this task config. + inputs: entities used as input for the task. + """ + + __tablename__ = EntityType.task_config.value + + id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True) + task_config_type: Mapped[TaskConfigType] = mapped_column(index=True) + meta: Mapped[JSON_DICT] = mapped_column(default={}, server_default="{}") + task_config_generator_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey(f"{EntityType.task_config}.id"), index=True + ) + inputs: Mapped[list[Entity]] = relationship( + primaryjoin="TaskConfig.id == TaskConfigToEntity.task_config_id", + secondary="task_config__entity", + ) + + __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 + + +class TaskActivity(Activity, ExecutionActivityMixin): + """Represents a generic task activity. + + Inputs (used): + - TaskConfig (one) + Outputs (generated): + - Entity (many) + + Attributes: + id (uuid.UUID): Primary key referencing the activity ID. + task_activity_type: Type of task config. + """ + + __tablename__ = ActivityType.task_activity.value + + id: Mapped[uuid.UUID] = mapped_column(ForeignKey("activity.id"), primary_key=True) + task_activity_type: Mapped[TaskActivityType] = mapped_column(index=True) + + __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 + + register_model_events() diff --git a/app/db/triggers.py b/app/db/triggers.py index 452a417e..383aae37 100644 --- a/app/db/triggers.py +++ b/app/db/triggers.py @@ -31,6 +31,7 @@ SingleNeuronSynaptome, SingleNeuronSynaptomeSimulation, SkeletonizationConfig, + TaskConfig, ValidationResult, ) @@ -188,6 +189,7 @@ def unauthorized_private_reference_trigger(model: type[Entity], field_name: str) (IonChannelModelingConfig, "ion_channel_modeling_campaign_id"), (SkeletonizationConfig, "skeletonization_campaign_id"), (SkeletonizationConfig, "em_cell_mesh_id"), + (TaskConfig, "task_config_generator_id"), ] entities = [ diff --git a/app/db/types.py b/app/db/types.py index 1cd7a3f2..0ccd7ffe 100644 --- a/app/db/types.py +++ b/app/db/types.py @@ -141,6 +141,41 @@ class EntityType(StrEnum): analysis_notebook_result = auto() skeletonization_config = auto() skeletonization_campaign = auto() + task_config = auto() + + +class TaskConfigType(StrEnum): + """Task config types.""" + + circuit_simulation__campaign = auto() + circuit_simulation__config = auto() + circuit_extraction__campaign = auto() + circuit_extraction__config = auto() + ion_channel_modeling__campaign = auto() + ion_channel_modeling__config = auto() + skeletonization__campaign = auto() + skeletonization__config = auto() + ion_channel_simulation__campaign = auto() + ion_channel_simulation__config = auto() + em_synapse_mapping__campaign = auto() + em_synapse_mapping__config = auto() + + +class TaskActivityType(StrEnum): + """Task activity types.""" + + circuit_simulation__config_generation = auto() + circuit_simulation__execution = auto() + circuit_extraction__config_generation = auto() + circuit_extraction__execution = auto() + ion_channel_modeling__config_generation = auto() + ion_channel_modeling__execution = auto() + skeletonization__config_generation = auto() + skeletonization__execution = auto() + ion_channel_simulation__config_generation = auto() + ion_channel_simulation__execution = auto() + em_synapse_mapping__config_generation = auto() + em_synapse_mapping__execution = auto() class AgentType(StrEnum): @@ -197,6 +232,7 @@ class ActivityType(StrEnum): circuit_extraction_execution = auto() skeletonization_execution = auto() skeletonization_config_generation = auto() + task_activity = auto() class DerivationType(StrEnum): @@ -404,6 +440,7 @@ class AssetLabel(StrEnum): ion_channel_model_thumbnail = auto() circuit_extraction_config = auto() skeletonization_config = auto() + task_config = auto() class LabelRequirements(BaseModel): @@ -922,6 +959,15 @@ class LabelRequirements(BaseModel): ) ], }, + EntityType.task_config: { + AssetLabel.task_config: [ + LabelRequirements( + content_type=ContentType.json, + is_directory=False, + description="Generic task configuration.", + ) + ], + }, } ALLOWED_ASSET_LABELS_PER_ENTITY |= { k: None for k in EntityType if k not in ALLOWED_ASSET_LABELS_PER_ENTITY diff --git a/app/filters/task_activity.py b/app/filters/task_activity.py new file mode 100644 index 00000000..7067320b --- /dev/null +++ b/app/filters/task_activity.py @@ -0,0 +1,19 @@ +from typing import Annotated + +from app.db.model import TaskActivity +from app.db.types import TaskActivityType +from app.dependencies.filter import FilterDepends +from app.filters.activity import ActivityFilterMixin, ExecutionActivityFilterMixin +from app.filters.base import CustomFilter + + +class TaskActivityFilter(CustomFilter, ActivityFilterMixin, ExecutionActivityFilterMixin): + task_activity_type: TaskActivityType | None = None + order_by: list[str] = ["-creation_date"] # noqa: RUF012 + + class Constants(CustomFilter.Constants): + model = TaskActivity + ordering_model_fields = ["creation_date", "update_date"] # noqa: RUF012 + + +TaskActivityFilterDep = Annotated[TaskActivityFilter, FilterDepends(TaskActivityFilter)] diff --git a/app/filters/task_config.py b/app/filters/task_config.py new file mode 100644 index 00000000..8456a340 --- /dev/null +++ b/app/filters/task_config.py @@ -0,0 +1,35 @@ +import uuid +from typing import Annotated + +from fastapi_filter import with_prefix + +from app.db.model import TaskConfig +from app.db.types import TaskConfigType +from app.dependencies.filter import FilterDepends +from app.filters.base import CustomFilter +from app.filters.common import IdFilterMixin, ILikeSearchFilterMixin, NameFilterMixin +from app.filters.entity import EntityFilterMixin + + +class TaskConfigFilterBase(NameFilterMixin, IdFilterMixin, CustomFilter): + task_config_type: TaskConfigType | None = None + + +class NestedTaskConfigFilter(TaskConfigFilterBase): + class Constants(CustomFilter.Constants): + model = TaskConfig + + +class TaskConfigFilter(EntityFilterMixin, TaskConfigFilterBase, ILikeSearchFilterMixin): + task_config_generator_id: uuid.UUID | None = None + task_config_generator_id__in: list[uuid.UUID] | None = None + + order_by: list[str] = ["-creation_date"] # noqa: RUF012 + + class Constants(CustomFilter.Constants): + model = TaskConfig + ordering_model_fields = ["creation_date", "update_date", "name"] # noqa: RUF012 + + +TaskConfigFilterDep = Annotated[TaskConfigFilter, FilterDepends(TaskConfigFilter)] +NestedTaskConfigFilterDep = FilterDepends(with_prefix("task_config", NestedTaskConfigFilter)) diff --git a/app/queries/common.py b/app/queries/common.py index 7114ef02..8c6e9b07 100644 --- a/app/queries/common.py +++ b/app/queries/common.py @@ -2,7 +2,6 @@ from http import HTTPStatus import sqlalchemy as sa -from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session from sqlalchemy.sql import operators @@ -10,10 +9,8 @@ from app.db.auth import ( constrain_to_readable_entities, constrain_to_writable_entities, - is_user_authorized_for_deletion, - select_unauthorized_entities, ) -from app.db.model import Activity, Generation, Identifiable, Usage +from app.db.model import Activity, Identifiable from app.db.utils import get_declaring_class, load_db_model_from_pydantic, update_model from app.dependencies.common import ( FacetQueryParams, @@ -32,14 +29,22 @@ ) from app.filters.base import Aliases, CustomFilter from app.queries import crud +from app.queries.constants import NESTED_ACTIVITY_RELATIONSHIPS from app.queries.filter import filter_from_db -from app.queries.types import ApplyOperations, SupportsModelValidate -from app.queries.utils import get_or_create_user_agent +from app.queries.types import ( + ApplyOperations, + NestedRelationships, + SupportsModelValidate, +) +from app.queries.utils import ( + create_associations_to_entities, + get_or_create_user_agent, + is_user_authorized_for_deletion, +) from app.schemas.activity import ActivityCreate, ActivityUpdate from app.schemas.auth import UserContext, UserContextWithProjectId from app.schemas.routers import DeleteResponse from app.schemas.types import ListResponse, PaginationResponse -from app.schemas.utils import NOT_SET def router_read_one[T: BaseModel, I: Identifiable]( @@ -103,7 +108,7 @@ def router_create_activity_one[T: BaseModel, I: Activity]( created_by_id=created_by_id, updated_by_id=updated_by_id, authorized_project_id=project_id, - ignore_attributes={"used_ids", "generated_ids"}, + ignore_attributes=set(NESTED_ACTIVITY_RELATIONSHIPS), ) with ( @@ -117,30 +122,14 @@ def router_create_activity_one[T: BaseModel, I: Activity]( db.add(db_model_instance) db.flush() - if associated_ids := json_model.used_ids + json_model.generated_ids: - if ( - unaccessible_entities := db.execute( - select_unauthorized_entities(associated_ids, user_context.project_id) - ) - .scalars() - .all() - ): - raise HTTPException( - status_code=404, - detail=f"Cannot access entities {', '.join(str(e) for e in unaccessible_entities)}", - ) - - for entity_id in json_model.used_ids: - db.add(Usage(usage_entity_id=entity_id, usage_activity_id=db_model_instance.id)) - - for entity_id in json_model.generated_ids: - db.add( - Generation( - generation_entity_id=entity_id, generation_activity_id=db_model_instance.id - ) - ) - - db.flush() + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=NESTED_ACTIVITY_RELATIONSHIPS, + project_id=db_model_instance.authorized_project_id, + action="create", + ) if apply_operations: q = sa.select(db_model_class).where(db_model_class.id == db_model_instance.id) @@ -160,6 +149,7 @@ def router_create_one[T: BaseModel, I: Identifiable]( response_schema_class: SupportsModelValidate[T], apply_operations: ApplyOperations | None = None, embedding: list[float] | None = None, + nested_relationships: NestedRelationships | None = None, ) -> T: """Create a model in the database. @@ -171,6 +161,7 @@ def router_create_one[T: BaseModel, I: Identifiable]( response_schema_class: Pydantic schema class for the returned data. apply_operations: transformer function that modifies the select query. embedding: optional embedding vector to attach to the model. + nested_relationships: mapping of nested relationships that can be set automatically. Returns: the written model data as a Pydantic model. @@ -192,6 +183,7 @@ def router_create_one[T: BaseModel, I: Identifiable]( created_by_id=created_by_id, updated_by_id=updated_by_id, authorized_project_id=project_id, + ignore_attributes=set(nested_relationships) if nested_relationships else None, ) if embedding is not None and hasattr(db_model_instance, "embedding"): @@ -208,6 +200,16 @@ def router_create_one[T: BaseModel, I: Identifiable]( db.add(db_model_instance) db.flush() + if nested_relationships: + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=nested_relationships, + project_id=getattr(db_model_instance, "authorized_project_id", None), + action="create", + ) + if apply_operations: q = sa.select(db_model_class).where(db_model_class.id == db_model_instance.id) q = apply_operations(q) @@ -395,6 +397,7 @@ def router_update_one[T: BaseModel, I: Identifiable]( json_model: BaseModel, response_schema_class: SupportsModelValidate[T], apply_operations: ApplyOperations | None = None, + nested_relationships: NestedRelationships | None = None, ): query = ( sa.select(db_model_class).where(db_model_class.id == id_).with_for_update(of=db_model_class) @@ -407,14 +410,29 @@ def router_update_one[T: BaseModel, I: Identifiable]( query = apply_operations(query) with ensure_result(error_message=f"{db_model_class.__name__} not found"): - obj = db.execute(query).unique().scalar_one() + db_model_instance = db.execute(query).unique().scalar_one() - obj = update_model(model=obj, data=json_model.model_dump()) + db_model_instance = update_model( + model=db_model_instance, + data=json_model.model_dump( + exclude=set(nested_relationships) if nested_relationships else None + ), + ) db.flush() - db.refresh(obj) + db.refresh(db_model_instance) + + if nested_relationships: + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=nested_relationships, + project_id=getattr(db_model_instance, "authorized_project_id", None), + action="update", + ) - return response_schema_class.model_validate(obj) + return response_schema_class.model_validate(db_model_instance) def router_user_delete_one[T: BaseModel, I: Identifiable]( @@ -485,44 +503,28 @@ def router_update_activity_one[T: BaseModel, I: Activity]( query = apply_operations(query) with ensure_result(error_message=f"{db_model_class.__name__} not found"): - obj = db.execute(query).unique().scalar_one() + db_model_instance = db.execute(query).unique().scalar_one() update_data = json_model.model_dump( exclude_unset=True, exclude_none=True, - exclude={"used_ids", "generated_ids"}, + exclude=set(NESTED_ACTIVITY_RELATIONSHIPS), exclude_defaults=True, # ignore NOT_SET default values ) for key, value in update_data.items(): - setattr(obj, key, value) - - # ignore NOT_SET values - generated_ids = json_model.generated_ids if json_model.generated_ids != NOT_SET else [] - - if generated_ids: - if obj.generated: - raise HTTPException( - status_code=404, - detail="It is forbidden to update generated_ids if they exist.", - ) - - if user_context and ( - unaccessible_entities := db.execute( - select_unauthorized_entities(generated_ids, user_context.project_id) - ) - .scalars() - .all() - ): - raise HTTPException( - status_code=404, - detail=f"Cannot access entities {', '.join(str(e) for e in unaccessible_entities)}", - ) - - for entity_id in generated_ids: - db.add(Generation(generation_entity_id=entity_id, generation_activity_id=obj.id)) + setattr(db_model_instance, key, value) + + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=NESTED_ACTIVITY_RELATIONSHIPS, + project_id=db_model_instance.authorized_project_id, + action="update", + ) db.flush() - db.refresh(obj) + db.refresh(db_model_instance) - return response_schema_class.model_validate(obj) + return response_schema_class.model_validate(db_model_instance) diff --git a/app/queries/constants.py b/app/queries/constants.py new file mode 100644 index 00000000..a4fa7e59 --- /dev/null +++ b/app/queries/constants.py @@ -0,0 +1,30 @@ +from app.db.model import Generation, TaskConfigToEntity, Usage +from app.queries.types import NestedRelationships + +NESTED_ACTIVITY_RELATIONSHIPS: NestedRelationships = { + "used_ids": { + "relationship_name": "used", + "db_model_factory": lambda *, left_id, right_id: Usage( + usage_activity_id=left_id, usage_entity_id=right_id + ), + "nested_id_getter": lambda *, items: items, # used_ids is already the list of ids + }, + "generated_ids": { + "relationship_name": "generated", + "db_model_factory": lambda *, left_id, right_id: Generation( + generation_activity_id=left_id, generation_entity_id=right_id + ), + "nested_id_getter": lambda *, items: items, # generated_ids is already the list of ids + }, +} + +NESTED_TASK_CONFIG_RELATIONSHIPS: NestedRelationships = { + "inputs": { + "relationship_name": "inputs", + "db_model_factory": lambda *, left_id, right_id: TaskConfigToEntity( + task_config_id=left_id, entity_id=right_id + ), + # inputs contains a list of NestedEntityCreate validated from dicts {"id": } + "nested_id_getter": lambda *, items: [item.id for item in items], + }, +} diff --git a/app/queries/factory.py b/app/queries/factory.py index 923b6dac..680c56c6 100644 --- a/app/queries/factory.py +++ b/app/queries/factory.py @@ -33,6 +33,7 @@ Species, Strain, Subject, + TaskConfig, Usage, ) from app.dependencies.common import FacetQueryParams @@ -92,6 +93,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T: ion_channel_model_alias = _get_alias(IonChannelModel, "ion_channel_model") ion_channel_modeling_config_alias = _get_alias(IonChannelModelingConfig) skeletonization_config_alias = _get_alias(SkeletonizationConfig) + task_config_alias = _get_alias(TaskConfig) name_to_facet_query_params: dict[str, FacetQueryParams] = { "agent": { @@ -153,6 +155,10 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T: "id": skeletonization_config_alias.id, "label": skeletonization_config_alias.name, }, + "task_config": { + "id": task_config_alias.id, + "label": task_config_alias.name, + }, } filter_joins = { "species": lambda q: q.join(Species, db_model_class.species_id == Species.id), @@ -268,6 +274,10 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T: skeletonization_config_alias, db_model_class.id == skeletonization_config_alias.skeletonization_campaign_id, ), + "task_config": lambda q: q.join( + task_config_alias, + db_model_class.id == task_config_alias.task_config_generator_id, + ), } name_to_facet_query_params = {k: name_to_facet_query_params[k] for k in facet_keys} filter_joins = {k: filter_joins[k] for k in filter_keys} diff --git a/app/queries/types.py b/app/queries/types.py index 902ed202..03ebb416 100644 --- a/app/queries/types.py +++ b/app/queries/types.py @@ -1,5 +1,6 @@ +import uuid from collections.abc import Callable -from typing import Any, Protocol +from typing import Any, Protocol, TypedDict import sqlalchemy as sa from pydantic import BaseModel @@ -11,3 +12,29 @@ class SupportsModelValidate[T: BaseModel](Protocol): @classmethod def model_validate(cls, obj: Any, *args, **kwargs) -> T: ... + + +class AssociationCallable(Protocol): + """Callable that should accept left_id and right_id and return a valid db model instance.""" + + def __call__(self, *, left_id: uuid.UUID, right_id: uuid.UUID) -> DeclarativeBase: ... + + +class NestedIdGetter(Protocol): + """Callable that should return the list of ids from the json model.""" + + def __call__(self, *, items: list) -> list[uuid.UUID]: ... + + +class NestedRelationship(TypedDict): + """Nested relationship dict, used for creating relationships in entities and activities.""" + + relationship_name: str # name of the relationship in the db model + db_model_factory: AssociationCallable # callable that should return a new db model instance + nested_id_getter: NestedIdGetter # callable that should return the list of ids from json + + +# mapping relationship_key -> relationship, where: +# - relationship_key is the key in the Create schema of the resource +# - relationship is a dict of type NestedRelationship +NestedRelationships = dict[str, NestedRelationship] diff --git a/app/queries/utils.py b/app/queries/utils.py index 842799e7..d43a0757 100644 --- a/app/queries/utils.py +++ b/app/queries/utils.py @@ -1,10 +1,18 @@ import uuid +from itertools import chain +from typing import Literal, cast import sqlalchemy as sa +from fastapi import HTTPException +from pydantic import BaseModel from sqlalchemy.orm import Session -from app.db.model import Person -from app.schemas.auth import UserProfile +from app.config import settings +from app.db.auth import select_unauthorized_entities +from app.db.model import Identifiable, Person +from app.queries.types import NestedRelationships +from app.schemas.auth import UserContext, UserProfile +from app.schemas.utils import NOT_SET from app.utils.uuid import create_uuid @@ -33,3 +41,103 @@ def get_or_create_user_agent(db: Session, user_profile: UserProfile) -> Person: def get_user(db: Session, subject_id: uuid.UUID) -> Person | None: query = sa.select(Person).where(Person.sub_id == subject_id) return db.execute(query).scalars().first() + + +def is_user_authorized_for_deletion( # noqa: PLR0911 + db: Session, user_context: UserContext, obj: Identifiable +) -> bool: + if settings.APP_DISABLE_AUTH: + return True + + # if there is no authorized_project_id it is a global resource + if not (project_id := getattr(obj, "authorized_project_id", None)): + return False + + # Service maintainers may delete public/private entities within their projects. + if user_context.is_service_maintainer: + return project_id in user_context.user_project_ids + + # from here and below public entities cannot be deleted + if obj.authorized_public: # pyright: ignore [reportAttributeAccessIssue] + return False + + # Project admins may delete private entities within their projects + if project_id in user_context.admin_project_ids: + return True + + # Project members may delete only the private entities they themselves created + if project_id in user_context.member_project_ids and ( + db_user := get_user(db, user_context.profile.subject) + ): + return db_user.created_by_id == obj.created_by_id + + return False + + +def create_associations_to_entities( + db: Session, + *, + left: Identifiable, + json_model: BaseModel, + nested_relationships: NestedRelationships, + project_id: uuid.UUID | None, + action: Literal["create", "update"], +) -> None: + """Create association records between the left Identifiable and each of the right_id passed. + + Args: + db: Database session. + left: Identifiable on the left side of the associations. + json_model: Pydantic model of the left resourece. + nested_relationships: Mapping of relationship keys to relationship dicts. + project_id: Optional project ID for authorization checks. + action: create or update the relationships. + + Raises: + HTTPException: If any of the associated entities are not public or not in the same project, + or if trying to update associations when it's not allowed. + """ + # map relationship keys to lists of entity IDs to associate, ignoring NOT_SET values + nested_relationship_ids: dict[str, list[uuid.UUID]] = cast( + "dict[str, list[uuid.UUID]]", + { + relationship_key: relationship["nested_id_getter"](items=items) # type: ignore[misc] + for relationship_key, relationship in nested_relationships.items() + if (items := getattr(json_model, relationship_key, NOT_SET)) != NOT_SET + }, + ) + + associated_ids = list(chain.from_iterable(nested_relationship_ids.values())) + + # skip if all the nested_relationship_ids are empty + if not associated_ids: + return + + # the associated entities should be public, or in the same given project + if ( + unaccessible_entities := db.execute( + select_unauthorized_entities(associated_ids, project_id) + ) + .scalars() + .all() + ): + raise HTTPException( + status_code=404, + detail=f"Cannot access entities {', '.join(str(e) for e in unaccessible_entities)}", + ) + + for relationship_key, relationship in nested_relationships.items(): + # ignore empty ids + if not (right_ids := nested_relationship_ids.get(relationship_key)): + continue + if action == "update" and getattr(left, relationship["relationship_name"]): + raise HTTPException( + status_code=409, + detail=f"It is forbidden to update {relationship_key} if they exist.", + ) + factory = relationship["db_model_factory"] + for right_id in right_ids: + db_instance = factory(left_id=left.id, right_id=right_id) + db.add(db_instance) + + db.flush() diff --git a/app/routers/__init__.py b/app/routers/__init__.py index a62c2b03..fd8f696f 100644 --- a/app/routers/__init__.py +++ b/app/routers/__init__.py @@ -74,6 +74,8 @@ species, strain, subject, + task_activity, + task_config, validation, validation_result, ) @@ -157,6 +159,8 @@ species.router, strain.router, subject.router, + task_activity.router, + task_config.router, validation.router, validation_result.router, ] diff --git a/app/routers/task_activity.py b/app/routers/task_activity.py new file mode 100644 index 00000000..acbbdb9b --- /dev/null +++ b/app/routers/task_activity.py @@ -0,0 +1,16 @@ +from fastapi import APIRouter + +from app.routers.admin import router as admin_router +from app.service import task_activity as service + +ROUTE = "task-activity" +router = APIRouter(prefix=f"/{ROUTE}", tags=[ROUTE]) + +read_many = router.get("")(service.read_many) +read_one = router.get("/{id_}")(service.read_one) +create_one = router.post("")(service.create_one) +delete_one = router.delete("/{id_}")(service.delete_one) +update_one = router.patch("/{id_}")(service.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(service.admin_read_one) +admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(service.admin_update_one) diff --git a/app/routers/task_config.py b/app/routers/task_config.py new file mode 100644 index 00000000..e70f9239 --- /dev/null +++ b/app/routers/task_config.py @@ -0,0 +1,16 @@ +from fastapi import APIRouter + +from app.routers.admin import router as admin_router +from app.service import task_config as service + +ROUTE = "task-config" +router = APIRouter(prefix=f"/{ROUTE}", tags=[ROUTE]) + +read_many = router.get("")(service.read_many) +read_one = router.get("/{id_}")(service.read_one) +create_one = router.post("")(service.create_one) +delete_one = router.delete("/{id_}")(service.delete_one) +update_one = router.patch("/{id_}")(service.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(service.admin_read_one) +admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(service.admin_update_one) diff --git a/app/schemas/asset.py b/app/schemas/asset.py index 7954b762..c40f8ef1 100644 --- a/app/schemas/asset.py +++ b/app/schemas/asset.py @@ -14,6 +14,7 @@ AssetStatus, ContentType, EntityType, + LabelRequirements, StorageType, ) @@ -96,7 +97,7 @@ class AssetReadWithUploadMeta(AssetRead): upload_meta: UploadMetaRead | None = None -def _raise_on_label_requirement(asset, label_reqs): +def _raise_on_label_requirement(asset: AssetBase, label_reqs: list[LabelRequirements]) -> None: content_type_success = [ label_req.content_type == asset.content_type for label_req in label_reqs diff --git a/app/schemas/entity.py b/app/schemas/entity.py index 02a4383e..39997c1f 100644 --- a/app/schemas/entity.py +++ b/app/schemas/entity.py @@ -7,7 +7,16 @@ from app.schemas.agent import CreatedByUpdatedByMixin +class NestedEntityCreate(BaseModel): + """Entity model to be used for bare nested entities in create endpoints.""" + + model_config = ConfigDict(from_attributes=True) + id: uuid.UUID + + class NestedEntityRead(BaseModel): + """Entity model to be used for bare nested entities in read endpoints.""" + model_config = ConfigDict(from_attributes=True) id: uuid.UUID type: str @@ -16,8 +25,8 @@ class NestedEntityRead(BaseModel): class EntityRead(NestedEntityRead, CreatedByUpdatedByMixin): - pass + """Entity model that includes created_by and updated_by information.""" class EntityCountRead(RootModel[dict[str, int]]): - pass + """Entity count model that contains the number of entities by type.""" diff --git a/app/schemas/task_activity.py b/app/schemas/task_activity.py new file mode 100644 index 00000000..fdfe748d --- /dev/null +++ b/app/schemas/task_activity.py @@ -0,0 +1,31 @@ +from app.db.types import TaskActivityType +from app.schemas.activity import ( + ActivityCreate, + ActivityRead, + ActivityUpdate, + ExecutionActivityMixin, +) +from app.schemas.utils import make_update_schema + + +class TaskActivityBase(ExecutionActivityMixin): + task_activity_type: TaskActivityType | None = None + + +class TaskActivityCreate(ActivityCreate, TaskActivityBase): + pass + + +class TaskActivityRead(ActivityRead, TaskActivityBase): + pass + + +class TaskActivityUserUpdate(ActivityUpdate, TaskActivityBase): + pass + + +TaskActivityAdminUpdate = make_update_schema( + TaskActivityCreate, + "TaskActivityAdminUpdate", + excluded_fields={"task_activity_type"}, +) # pyright : ignore [reportInvalidTypeForm] diff --git a/app/schemas/task_config.py b/app/schemas/task_config.py new file mode 100644 index 00000000..8f904cd4 --- /dev/null +++ b/app/schemas/task_config.py @@ -0,0 +1,63 @@ +import uuid +from typing import Annotated, Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db.types import TaskConfigType +from app.schemas.agent import CreatedByUpdatedByMixin +from app.schemas.asset import AssetsMixin +from app.schemas.base import ( + AuthorizationMixin, + AuthorizationOptionalPublicMixin, + CreationMixin, + EntityTypeMixin, + IdentifiableMixin, + NameDescriptionMixin, +) +from app.schemas.contribution import ContributionReadWithoutEntityMixin +from app.schemas.entity import NestedEntityCreate, NestedEntityRead +from app.schemas.utils import make_update_schema + + +class TaskConfigBase(BaseModel, NameDescriptionMixin): + model_config = ConfigDict(from_attributes=True) + task_config_type: TaskConfigType + meta: dict[str, Any] + task_config_generator_id: uuid.UUID | None = None + + +class TaskConfigCreate(TaskConfigBase, AuthorizationOptionalPublicMixin): + inputs: Annotated[ + list[NestedEntityCreate], + Field(description="List of input entities (only ids)."), + ] = [] + + +TaskConfigUserUpdate = make_update_schema( + TaskConfigCreate, + "TaskConfigUserUpdate", +) # pyright: ignore [reportInvalidTypeForm] + +TaskConfigAdminUpdate = make_update_schema( + TaskConfigCreate, + "TaskConfigAdminUpdate", + excluded_fields=set(), +) # pyright : ignore [reportInvalidTypeForm] + + +class NestedTaskConfigRead(TaskConfigBase, EntityTypeMixin, IdentifiableMixin): + pass + + +class TaskConfigRead( + NestedTaskConfigRead, + AssetsMixin, + CreatedByUpdatedByMixin, + CreationMixin, + AuthorizationMixin, + ContributionReadWithoutEntityMixin, +): + inputs: Annotated[ + list[NestedEntityRead], + Field(description="List of input entities."), + ] = [] diff --git a/app/service/asset.py b/app/service/asset.py index 5515deab..13b48351 100644 --- a/app/service/asset.py +++ b/app/service/asset.py @@ -8,7 +8,6 @@ from types_boto3_s3 import S3Client from app.config import StorageUnion, storages -from app.db.auth import is_user_authorized_for_deletion from app.db.model import Asset, Entity from app.db.types import AssetLabel, AssetStatus, ContentType, EntityType, StorageType from app.dependencies.common import PaginationQuery @@ -16,6 +15,7 @@ from app.filters.asset import AssetFilterDep from app.queries import crud from app.queries.common import get_or_create_user_agent, router_read_many +from app.queries.utils import is_user_authorized_for_deletion from app.repository.group import RepositoryGroup from app.schemas.asset import ( AssetCreate, diff --git a/app/service/task_activity.py b/app/service/task_activity.py new file mode 100644 index 00000000..8a041ac2 --- /dev/null +++ b/app/service/task_activity.py @@ -0,0 +1,196 @@ +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import aliased, joinedload, raiseload + +from app.db.model import Entity, Person, TaskActivity +from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep +from app.dependencies.common import ( + FacetsDep, + PaginationQuery, + SearchDep, +) +from app.dependencies.db import SessionDep +from app.filters.task_activity import TaskActivityFilterDep +from app.queries.common import ( + router_create_activity_one, + router_read_many, + router_read_one, + router_update_activity_one, + router_user_delete_one, +) +from app.queries.factory import query_params_factory +from app.schemas.routers import DeleteResponse +from app.schemas.task_activity import ( + TaskActivityAdminUpdate, + TaskActivityCreate, + TaskActivityRead, + TaskActivityUserUpdate, +) +from app.schemas.types import ListResponse + +if TYPE_CHECKING: + from app.filters.base import Aliases + + +DBModel = TaskActivity +ReadSchema = TaskActivityRead +CreateSchema = TaskActivityCreate +UserUpdateSchema = TaskActivityUserUpdate +AdminUpdateSchema = TaskActivityAdminUpdate +FilterDep = TaskActivityFilterDep + + +def _load(query: sa.Select): + return query.options( + joinedload(DBModel.used), + joinedload(DBModel.generated), + joinedload(DBModel.created_by), + joinedload(DBModel.updated_by), + raiseload("*"), + ) + + +def read_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=user_context, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=None, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def create_one( + db: SessionDep, + json_model: CreateSchema, + user_context: UserContextWithProjectIdDep, +) -> ReadSchema: + return router_create_activity_one( + db=db, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def read_many( + user_context: UserContextDep, + db: SessionDep, + pagination_request: PaginationQuery, + filter_model: FilterDep, + with_search: SearchDep, + facets: FacetsDep, +) -> ListResponse[ReadSchema]: + created_by_alias = aliased(Person, flat=True) + updated_by_alias = aliased(Person, flat=True) + used_alias = aliased(Entity, flat=True) + generated_alias = aliased(Entity, flat=True) + + aliases: Aliases = { + Person: { + "created_by": created_by_alias, + "updated_by": updated_by_alias, + }, + Entity: { + "used": used_alias, + "generated": generated_alias, + }, + } + facet_keys = [] + filter_keys = [ + "created_by", + "updated_by", + "used", + "generated", + ] + name_to_facet_query_params, filter_joins = query_params_factory( + db_model_class=DBModel, + facet_keys=facet_keys, + filter_keys=filter_keys, + aliases=aliases, + ) + return router_read_many( + db=db, + filter_model=filter_model, + db_model_class=DBModel, + with_search=with_search, + with_in_brain_region=None, + facets=facets, + name_to_facet_query_params=name_to_facet_query_params, + apply_filter_query_operations=None, + apply_data_query_operations=_load, + aliases=aliases, + pagination_request=pagination_request, + response_schema_class=ReadSchema, + authorized_project_id=user_context.project_id, + filter_joins=filter_joins, + ) + + +def delete_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> DeleteResponse: + return router_user_delete_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + ) + + +def update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: UserUpdateSchema, # pyright: ignore [reportInvalidTypeForm] + user_context: UserContextDep, +) -> ReadSchema: + return router_update_activity_one( + db=db, + id_=id_, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: AdminUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_activity_one( + db=db, + id_=id_, + json_model=json_model, + user_context=None, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) diff --git a/app/service/task_config.py b/app/service/task_config.py new file mode 100644 index 00000000..76c45fbc --- /dev/null +++ b/app/service/task_config.py @@ -0,0 +1,200 @@ +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import aliased, joinedload, raiseload, selectinload + +from app.db.model import ( + Agent, + Person, + TaskConfig, +) +from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep +from app.dependencies.common import ( + FacetsDep, + PaginationQuery, + SearchDep, +) +from app.dependencies.db import SessionDep +from app.filters.task_config import TaskConfigFilterDep +from app.queries.common import ( + router_create_one, + router_read_many, + router_read_one, + router_update_one, + router_user_delete_one, +) +from app.queries.constants import NESTED_TASK_CONFIG_RELATIONSHIPS +from app.queries.factory import query_params_factory +from app.schemas.routers import DeleteResponse +from app.schemas.task_config import ( + TaskConfigAdminUpdate, + TaskConfigCreate, + TaskConfigRead, + TaskConfigUserUpdate, +) +from app.schemas.types import ListResponse + +if TYPE_CHECKING: + from app.filters.base import Aliases + +DBModel = TaskConfig +ReadSchema = TaskConfigRead +CreateSchema = TaskConfigCreate +UserUpdateSchema = TaskConfigUserUpdate +AdminUpdateSchema = TaskConfigAdminUpdate +FilterDep = TaskConfigFilterDep + + +def _load(query: sa.Select): + return query.options( + joinedload(DBModel.created_by), + joinedload(DBModel.updated_by), + selectinload(DBModel.assets), + selectinload(DBModel.contributions), + selectinload(DBModel.inputs), + raiseload("*"), + ) + + +def read_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=user_context, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=None, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def create_one( + db: SessionDep, + json_model: CreateSchema, + user_context: UserContextWithProjectIdDep, +) -> ReadSchema: + return router_create_one( + db=db, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_TASK_CONFIG_RELATIONSHIPS, + ) + + +def update_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, + json_model: UserUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + json_model=json_model, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_TASK_CONFIG_RELATIONSHIPS, + ) + + +def admin_update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: AdminUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=None, + json_model=json_model, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_TASK_CONFIG_RELATIONSHIPS, + ) + + +def read_many( + user_context: UserContextDep, + db: SessionDep, + pagination_request: PaginationQuery, + filter_model: FilterDep, + with_search: SearchDep, + facets: FacetsDep, +) -> ListResponse[ReadSchema]: + agent_alias = aliased(Agent, flat=True) + created_by_alias = aliased(Person, flat=True) + updated_by_alias = aliased(Person, flat=True) + + aliases: Aliases = { + Agent: { + "contribution": agent_alias, + }, + Person: { + "created_by": created_by_alias, + "updated_by": updated_by_alias, + }, + } + facet_keys = filter_keys = [ + "created_by", + "updated_by", + "contribution", + ] + name_to_facet_query_params, filter_joins = query_params_factory( + db_model_class=DBModel, + facet_keys=facet_keys, + filter_keys=filter_keys, + aliases=aliases, + ) + return router_read_many( + db=db, + filter_model=filter_model, + db_model_class=DBModel, + with_search=with_search, + with_in_brain_region=None, + facets=facets, + name_to_facet_query_params=name_to_facet_query_params, + apply_filter_query_operations=None, + apply_data_query_operations=_load, + aliases=aliases, + pagination_request=pagination_request, + response_schema_class=ReadSchema, + authorized_project_id=user_context.project_id, + filter_joins=filter_joins, + ) + + +def delete_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> DeleteResponse: + return router_user_delete_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + ) diff --git a/docs/asset-labels.md b/docs/asset-labels.md index 5c0014af..5bcbd1be 100644 --- a/docs/asset-labels.md +++ b/docs/asset-labels.md @@ -61,3 +61,4 @@ | circuit_extraction_config | circuit_extraction_config | application/json | .json | Single circuit extraction configuration. | | skeletonization_campaign | campaign_generation_config | application/json | .json | Skeletonization campaign configuration. | | skeletonization_config | skeletonization_config | application/json | .json | Single skeletonization configuration. | +| task_config | task_config | application/json | .json | Generic task configuration. | diff --git a/docs/task-models-diagram.md b/docs/task-models-diagram.md new file mode 100644 index 00000000..64261463 --- /dev/null +++ b/docs/task-models-diagram.md @@ -0,0 +1,126 @@ +# Task Models Relationship Diagram + +```mermaid +flowchart TD + subgraph Inputs["Input Entities"] + EIn1[Entity - Campaign Inputs] + EIn2[Entity - Config Inputs] + end + + subgraph Task_Layer["Task Layer"] + TC1[TaskConfig - circuit_simulation__campaign] + TC2[TaskConfig - circuit_simulation__config] + end + + subgraph Activities + TA1[TaskActivity - circuit_simulation__config_generation] + TA2[TaskActivity - circuit_simulation__execution] + end + + subgraph Output["Output Entities"] + EOut[Entity - Output] + end + + EIn1 -->|TaskConfigToEntity| TC1 + EIn2 -->|TaskConfigToEntity| TC2 + TC1 -->|task_config_generator_id FK| TC2 + + TC1 -->|Usage| TA1 + TA1 -->|Generation| TC2 + + TC2 -->|Usage| TA2 + TA2 -->|Generation| EOut + + style TC1 fill:#e1f5ff + style TC2 fill:#e1f5ff + style EIn1 fill:#d4edda + style EIn2 fill:#d4edda + style EOut fill:#f8d7da + style TA1 fill:#fff4e1 + style TA2 fill:#fff4e1 +``` + +## ER Diagram + +```mermaid +erDiagram + Campaign_Entity_Input ||--o{ TaskConfigToEntity_Campaign : "inputs" + TaskConfig_Campaign ||--o{ TaskConfigToEntity_Campaign : "has" + TaskConfig_Campaign ||--o{ TaskConfig_Config : "task_config_generator_id" + + Config_Entity_Inputs ||--o{ TaskConfigToEntity_Config : "inputs" + TaskConfig_Config ||--o{ TaskConfigToEntity_Config : "has" + + TaskConfig_Campaign ||--o{ Usage_CG : "used by config_generation" + Usage_CG }o--|| TaskActivity_ConfigGeneration : "" + TaskActivity_ConfigGeneration ||--o{ Generation_CG : "" + Generation_CG }o--|| TaskConfig_Config : "generated by config_generation" + + TaskConfig_Config ||--o{ Usage_TE : "used by execution" + Usage_TE }o--|| TaskActivity_Execution : "" + TaskActivity_Execution ||--o{ Generation_TE : "" + Generation_TE }o--|| Entity_Output : "generated by execution" + + Entity { + uuid id PK + } + + TaskConfigToEntity { + uuid entity_id PK,FK + uuid task_config_id PK,FK + } + + TaskConfig { + uuid id PK,FK + JSON_DICT scan_parameters + TaskConfigType task_config_type + uuid task_config_generator_id FK + } + + Usage { + uuid usage_entity_id PK,FK + uuid usage_activity_id PK,FK + } + + Generation { + uuid generation_entity_id PK,FK + uuid generation_activity_id PK,FK + } + + TaskActivity { + uuid id PK,FK + TaskActivityType task_activity_type + } +``` + +## Relationships Explained + +### Flowchart (Primary) +Shows the workflow clearly: +- **Green boxes**: Input entities + - Campaign inputs (linked via TaskConfigToEntity) + - Config inputs (linked via TaskConfigToEntity) +- **Blue boxes**: TaskConfig with different types + - circuit_simulation__campaign + - circuit_simulation__config +- **Yellow boxes**: TaskActivity (processes) + - circuit_simulation__config_generation + - circuit_simulation__execution +- **Red box**: Output entities (generated by execution) +- Input and output entities are different instances, though all are Entity type + +### ER Diagram (Detailed) +Shows the same relationships with Usage and Generation tables split by activity: +- **TaskConfigToEntity**: Junction table linking input entities to TaskConfig (both campaign and config types) +- **Usage_CG**: Usage records for config_generation activity +- **Generation_CG**: Generation records for config_generation activity +- **Usage_TE**: Usage records for execution activity +- **Generation_TE**: Generation records for execution activity + +Notes: +- TaskConfig[circuit_simulation__campaign] has many input entities (via TaskConfigToEntity) +- TaskConfig[circuit_simulation__config] has many input entities (via TaskConfigToEntity) +- One TaskConfig[campaign] is used by TaskActivity[config_generation] to generate many TaskConfig[config] +- One TaskConfig[config] can be used by many TaskActivity[execution], each generating many Entity +- TaskActivity has task_activity_type field (enum: circuit_simulation__config_generation, circuit_simulation__execution, etc.) +- TaskConfig has task_config_type field (enum: circuit_simulation__campaign, circuit_simulation__config, etc.) diff --git a/scripts/export/build_database_archive.sh b/scripts/export/build_database_archive.sh index 5c92d897..97e7c805 100755 --- a/scripts/export/build_database_archive.sh +++ b/scripts/export/build_database_archive.sh @@ -2,7 +2,7 @@ # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="523e523531a7" +SCRIPT_DB_VERSION="1f5cf23383af" echo "DB dump (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" @@ -245,6 +245,12 @@ SET TRANSACTION READ ONLY; \copy (SELECT t0.* FROM strain AS t0 WHERE TRUE) TO '$DATA_DIR/strain.csv' WITH CSV HEADER; \echo Dumping table subject \copy (SELECT t0.* FROM subject AS t0 JOIN entity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/subject.csv' WITH CSV HEADER; +\echo Dumping table task_activity +\copy (SELECT t0.* FROM task_activity AS t0 JOIN activity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/task_activity.csv' WITH CSV HEADER; +\echo Dumping table task_config +\copy (SELECT t0.* FROM task_config AS t0 JOIN entity AS t1 ON t1.id=t0.id LEFT JOIN entity AS t2 ON t2.id=t0.task_config_generator_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/task_config.csv' WITH CSV HEADER; +\echo Dumping table task_config__entity +\copy (SELECT t0.* FROM task_config__entity AS t0 JOIN entity AS t1 ON t1.id=t0.task_config_id JOIN entity AS t2 ON t2.id=t0.entity_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/task_config__entity.csv' WITH CSV HEADER; \echo Dumping table usage \copy (SELECT t0.* FROM usage AS t0 JOIN entity AS t1 ON t1.id=t0.usage_entity_id JOIN activity AS t2 ON t2.id=t0.usage_activity_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/usage.csv' WITH CSV HEADER; \echo Dumping table validation @@ -263,7 +269,7 @@ install -m 755 /dev/stdin "$WORK_DIR/load.sh" <<'EOF_LOAD_SCRIPT' # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="523e523531a7" +SCRIPT_DB_VERSION="1f5cf23383af" echo "DB load (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" diff --git a/scripts/extract_traces/run.py b/scripts/extract_traces/run.py index fcc261b6..37ef8ee6 100644 --- a/scripts/extract_traces/run.py +++ b/scripts/extract_traces/run.py @@ -110,6 +110,7 @@ def extract(source: Path, component: str | None, output: Path) -> None: head == "admin" or "/assets" in tail or "/regions" in tail + or "/test-authenticated-endpoint" in tail or tail.endswith(("counts", "hierarchy", "derived-from")) ): # ignore undesired endpoints diff --git a/tests/conftest.py b/tests/conftest.py index b7b7eb4f..879995e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1919,3 +1919,74 @@ def skeletonization_config_id(client, skeletonization_config_json_data): json=skeletonization_config_json_data, ).json() return data["id"] + + +@pytest.fixture +def campaign_json_data(): + return { + "name": "campaign", + "description": "campaign-description", + "meta": {"foo": "bar"}, + "task_config_type": "skeletonization__campaign", + } + + +@pytest.fixture +def campaign_with_nested_relationships_json_data(campaign_json_data, em_cell_mesh): + return campaign_json_data | { + "inputs": [{"id": str(em_cell_mesh.id)}], + } + + +@pytest.fixture +def campaign_id(client, campaign_json_data): + data = assert_request( + client.post, + url="/task-config", + json=campaign_json_data | {"authorized_public": False}, + ).json() + return data["id"] + + +@pytest.fixture +def public_campaign_id(client, campaign_json_data): + data = assert_request( + client.post, + url="/task-config", + json=campaign_json_data | {"authorized_public": True}, + ).json() + return data["id"] + + +@pytest.fixture +def task_config_json_data(): + return { + "name": "task-config", + "description": "task-config-description", + "meta": {"foo": "bar"}, + "task_config_type": "skeletonization__config", + } + + +@pytest.fixture +def task_config_with_parent_json_data(task_config_json_data, public_campaign_id): + return task_config_json_data | { + "task_config_generator_id": public_campaign_id, + } + + +@pytest.fixture +def task_config_with_nested_relationships_json_data(task_config_json_data, em_cell_mesh): + return task_config_json_data | { + "inputs": [{"id": str(em_cell_mesh.id)}], + } + + +@pytest.fixture +def task_config_id(client, task_config_json_data): + data = assert_request( + client.post, + url="/task-config", + json=task_config_json_data, + ).json() + return data["id"] diff --git a/tests/test_task_activity.py b/tests/test_task_activity.py new file mode 100644 index 00000000..26f0691e --- /dev/null +++ b/tests/test_task_activity.py @@ -0,0 +1,350 @@ +from datetime import UTC, datetime + +import pytest +from pydantic import TypeAdapter + +from app.db.model import ( + Generation, + TaskActivity, + TaskConfig, + Usage, +) +from app.db.types import ActivityType, ExecutorType + +from .utils import ( + PROJECT_ID, + USER_SUB_ID_1, + assert_request, + check_activity_create_one__unauthorized_entities, + check_activity_delete_one, + check_activity_update_one, + check_activity_update_one__fail_if_generated_ids_exists, + check_activity_update_one__fail_if_generated_ids_unauthorized, + check_creation_fields, + check_missing, + check_pagination, + create_cell_morphology_id, +) + +DateTimeAdapter = TypeAdapter(datetime) + +ROUTE = "task-activity" +ADMIN_ROUTE = "/admin/task-activity" + + +@pytest.fixture +def json_data(task_config_id, morphology_id): + return { + "task_activity_type": "skeletonization__execution", + "start_time": str(datetime.now(UTC)), + "end_time": str(datetime.now(UTC)), + "used_ids": [str(task_config_id)], + "generated_ids": [str(morphology_id)], + "status": "done", + "executor": str(ExecutorType.single_node_job), + "execution_id": "1739b817-26bb-4dad-93f4-0279a1b2cf6e", + } + + +@pytest.fixture +def create_id(client, json_data): + def _create_id(**kwargs): + return assert_request(client.post, url=ROUTE, json=json_data | kwargs).json()["id"] + + return _create_id + + +@pytest.fixture +def model_id(create_id): + return create_id() + + +def _assert_read_response(data, json_data, *, empty_ids=False): + assert "id" in data + + if empty_ids: + assert len(data["used"]) == 0 + assert len(data["generated"]) == 0 + else: + assert data["used"] == [ + { + "id": json_data["used_ids"][0], + "type": "task_config", + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + } + ] + assert data["generated"] == [ + { + "id": json_data["generated_ids"][0], + "type": "cell_morphology", + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + } + ] + check_creation_fields(data) + assert DateTimeAdapter.validate_python(data["start_time"]) == DateTimeAdapter.validate_python( + json_data["start_time"] + ) + assert DateTimeAdapter.validate_python(data["end_time"]) == DateTimeAdapter.validate_python( + json_data["end_time"] + ) + assert data["type"] == ActivityType.task_activity + + +def test_create_one(clients, json_data): + data = assert_request(clients.user_1.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + +def test_read_one(clients, json_data, model_id): + data = assert_request(clients.user_1.get, url=f"{ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + data = assert_request(clients.user_1.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data) + + data = assert_request(clients.admin.get, url=f"{ADMIN_ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + +def test_create_one__empty_ids(client, client_admin, json_data): + json_data = {k: v for k, v in json_data.items() if k not in {"used_ids", "generated_ids"}} + + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client.get, url=f"{ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client_admin.get, url=f"{ADMIN_ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data, empty_ids=True) + + +def test_create_one__unauthorized_entities( + db, + client_user_1, + client_user_2, + json_data, + subject_id, + brain_region_id, +): + """Do not allow associations with entities that are not authorized to the user.""" + + user1_private_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + user2_private_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + user2_public_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=True, + ) + check_activity_create_one__unauthorized_entities( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_private_generated_id, + u2_private_entity_id=user2_private_generated_id, + u2_public_entity_id=user2_public_generated_id, + ) + + +def test_missing(client): + check_missing(ROUTE, client) + + +def test_pagination(client, create_id): + check_pagination(ROUTE, client, create_id) + + +@pytest.fixture +def morphology_id_1(client, subject_id, brain_region_id): + return create_cell_morphology_id( + client=client, + subject_id=subject_id, + brain_region_id=brain_region_id, + ) + + +@pytest.fixture +def morphology_id_2(client, subject_id, brain_region_id): + return create_cell_morphology_id( + client=client, + subject_id=subject_id, + brain_region_id=brain_region_id, + ) + + +@pytest.fixture +def models(create_id, task_config_id, morphology_id_1, morphology_id_2): + return [ + create_id( + used_ids=[task_config_id], + generated_ids=[], + ), + create_id( + used_ids=[task_config_id], + generated_ids=[morphology_id_1], + ), + create_id( + used_ids=[task_config_id], + generated_ids=[morphology_id_1, morphology_id_2], + ), + create_id( + used_ids=[], + generated_ids=[], + ), + ] + + +def test_filtering(client, models, task_config_id, morphology_id_1, morphology_id_2): + data = assert_request(client.get, url=ROUTE).json()["data"] + assert len(data) == len(models) + + data = assert_request( + client.get, + url=ROUTE, + params={"used__id": task_config_id}, + ).json()["data"] + assert len(data) == 3 + + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id": morphology_id_1}, + ).json()["data"] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={ + "used__id": task_config_id, + "generated__id": morphology_id_1, + }, + ).json()["data"] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={"used__id__in": [task_config_id]}, + ).json()["data"] + assert len(data) == 3 + + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": [morphology_id_2]}, + ).json()["data"] + assert len(data) == 1 + + data = assert_request( + client.get, + url=ROUTE, + params={"created_by__sub_id": USER_SUB_ID_1, "updated_by__sub_id": USER_SUB_ID_1}, + ).json()["data"] + assert len(data) == len(models) + + for executor, count in ( + (ExecutorType.single_node_job, len(models)), + (ExecutorType.distributed_job, 0), + ): + data = assert_request( + client.get, + url=ROUTE, + params={"executor": str(executor)}, + ).json()["data"] + assert len(data) == count + + +def test_delete_one(db, clients, json_data): + check_activity_delete_one( + db=db, + clients=clients, + json_data=json_data, + route=ROUTE, + admin_route=ADMIN_ROUTE, + expected_counts_before={ + TaskConfig: 1, + TaskActivity: 1, + Usage: 1, + Generation: 1, + }, + expected_counts_after={ + TaskConfig: 1, + TaskActivity: 0, + Usage: 0, + Generation: 0, + }, + ) + + +def test_update_one( + client, + client_admin, + task_config_id, + morphology_id, + create_id, +): + check_activity_update_one( + client=client, + client_admin=client_admin, + route=ROUTE, + admin_route=ADMIN_ROUTE, + used_id=task_config_id, + generated_id=morphology_id, + constructor_func=create_id, + ) + + +def test_update_one__fail_if_generated_ids_unauthorized( + db, client_user_1, client_user_2, json_data, subject_id, brain_region_id +): + """Test that it is not allowed to update generated_ids with unauthorized entities.""" + + user1_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + user2_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_activity_update_one__fail_if_generated_ids_unauthorized( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_generated_id, + u2_private_entity_id=user2_generated_id, + ) + + +def test_update_one__fail_if_generated_ids_exists(client, morphology_id, task_config_id, create_id): + """Test activity Generation associations cannot be updated if they already exist.""" + check_activity_update_one__fail_if_generated_ids_exists( + client=client, + route=ROUTE, + entity_id_1=task_config_id, + entity_id_2=morphology_id, + constructor_func=create_id, + ) diff --git a/tests/test_task_config.py b/tests/test_task_config.py new file mode 100644 index 00000000..347523f6 --- /dev/null +++ b/tests/test_task_config.py @@ -0,0 +1,230 @@ +import pytest + +from app.db.model import TaskConfig +from app.db.types import EntityType + +from .utils import ( + PROJECT_ID, + USER_SUB_ID_1, + assert_request, + check_authorization, + check_entity_delete_one, + check_entity_read_response, + check_entity_update_one, + check_entity_update_one__fail_if_nested_ids_exists, + check_entity_update_one__fail_if_nested_ids_unauthorized, + check_missing, + check_pagination, + create_cell_morphology_id, +) + +ROUTE = "task-config" +ADMIN_ROUTE = "/admin/task-config" +TASK_CONFIG_TYPE = "skeletonization__config" + + +@pytest.fixture +def json_data(task_config_json_data): + return task_config_json_data + + +@pytest.fixture +def public_json_data(json_data): + return json_data | {"authorized_public": True} + + +@pytest.fixture +def create_id(client, task_config_with_parent_json_data): + def _create_id(**kwargs): + return assert_request( + client.post, + url=ROUTE, + json=task_config_with_parent_json_data | kwargs, + ).json()["id"] + + return _create_id + + +@pytest.fixture +def model_id(task_config_id): + return task_config_id + + +def _assert_read_response(data, json_data): + check_entity_read_response(data, json_data, EntityType.task_config) + assert "inputs" in data + assert "task_config_generator_id" in data + assert data["meta"] == json_data["meta"] + + +def test_create_one(client, json_data): + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + +def test_create_one_with_parent(client, task_config_with_parent_json_data): + json_data = task_config_with_parent_json_data + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + assert ( + json_data["task_config_generator_id"] + == task_config_with_parent_json_data["task_config_generator_id"] + ) + + +def test_create_one_with_nested_relationships( + client, task_config_with_nested_relationships_json_data +): + json_data = task_config_with_nested_relationships_json_data + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + inputs = json_data["inputs"] + assert data["inputs"] == [ + { + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + "id": inputs[0]["id"], + "type": "em_cell_mesh", + }, + ] + + +def test_read_one(clients, model_id, json_data): + data = assert_request(clients.user_1.get, url=f"{ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + data = assert_request(clients.user_1.get, url=f"{ROUTE}").json()["data"] + assert len(data) == 1 + _assert_read_response(data[0], json_data) + + data = assert_request(clients.admin.get, url=f"{ADMIN_ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + +def test_delete_one(db, clients, public_json_data): + check_entity_delete_one( + db=db, + route=ROUTE, + admin_route=ADMIN_ROUTE, + clients=clients, + json_data=public_json_data, + expected_counts_before={ + TaskConfig: 1, + }, + expected_counts_after={ + TaskConfig: 0, + }, + ) + + +def test_update_one(clients, json_data): + check_entity_update_one( + route=ROUTE, + admin_route=ADMIN_ROUTE, + clients=clients, + json_data=json_data, + patch_payload={ + "name": "name", + "description": "description", + }, + optional_payload=None, + ) + + +def test_update_one__fail_if_nested_ids_unauthorized( + db, client_user_1, client_user_2, json_data, subject_id, brain_region_id +): + """Test that it is not allowed to update the nested ids with unauthorized entities.""" + + user2_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_entity_update_one__fail_if_nested_ids_unauthorized( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u2_private_entity_id=user2_generated_id, + relationship_key="inputs", + ) + + +def test_update_one__fail_if_nested_ids_exists( + db, client_user_1, json_data, subject_id, brain_region_id +): + """Test that nested ids cannot be updated if they already exist.""" + user1_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_entity_update_one__fail_if_nested_ids_exists( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_generated_id, + relationship_key="inputs", + ) + + +def test_missing(client): + check_missing(ROUTE, client) + + +def test_authorization(clients, public_json_data): + check_authorization(ROUTE, clients.user_1, clients.user_2, clients.no_project, public_json_data) + + +def test_pagination(client, create_id): + check_pagination(ROUTE, client, create_id) + + +@pytest.fixture +def models(create_id): + return [create_id(name=f"config-{i}") for i in range(3)] + + +def test_filtering_ordering(client, models, public_campaign_id): + def _req(query): + # always filter by task_config_type to exclude the parent campaign + return assert_request( + client.get, + url=ROUTE, + params=query | {"task_config_type": TASK_CONFIG_TYPE}, + ).json()["data"] + + data = _req({}) + assert len(data) == len(models) + + data = _req({"name__ilike": "config"}) + assert len(data) == len(models) + + data = _req({"name": "config-0"}) + assert len(data) == 1 + assert data[0]["name"] == "config-0" + + data = _req({"task_config_generator_id": public_campaign_id}) + assert len(data) == len(models) + + data = _req({"task_config_generator_id__in": [public_campaign_id]}) + assert len(data) == len(models) + + data = _req({"order_by": "-name"}) + assert [d["name"] for d in data] == ["config-2", "config-1", "config-0"] + + data = _req({"order_by": "-name", "name__in": ["config-1", "config-2"]}) + assert [d["name"] for d in data] == ["config-2", "config-1"] + + data = _req({"created_by__sub_id": USER_SUB_ID_1, "updated_by__sub_id": USER_SUB_ID_1}) + assert len(data) == len(models) + + data = _req({"ilike_search": "*description*"}) + assert len(data) == len(models) + + data = _req({"ilike_search": "config-1"}) + assert len(data) == 1 diff --git a/tests/utils.py b/tests/utils.py index f022964a..1e043946 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1646,11 +1646,68 @@ def check_activity_update_one__fail_if_generated_ids_exists( "generated_ids": [str(entity_id_1)], } data = assert_request( - client.patch, url=f"{route}/{gen1}", json=update_json, expected_status_code=404 + client.patch, url=f"{route}/{gen1}", json=update_json, expected_status_code=409 ).json() assert data["details"] == "It is forbidden to update generated_ids if they exist." +def check_entity_update_one__fail_if_nested_ids_unauthorized( + db, + route, + client_user_1, + json_data, + u2_private_entity_id, + relationship_key, +): + """Test that it is not allowed to update nested ids with unauthorized entities.""" + # sanity check to ensure that authorized_project_id and authorized_public are consistent + e2 = _get_entity(db, entity_id=u2_private_entity_id) + assert e2.authorized_public is False + + # create an entity without relationships + json_data |= { + relationship_key: [], + } + data = assert_request(client_user_1.post, url=route, json=json_data).json() + + # update the entity with invalid relationships + update_json = { + relationship_key: [{"id": str(u2_private_entity_id)}], + } + data = assert_request( + client_user_1.patch, url=f"{route}/{data['id']}", json=update_json, expected_status_code=404 + ).json() + assert data["details"] == f"Cannot access entities {u2_private_entity_id}" + + +def check_entity_update_one__fail_if_nested_ids_exists( + db, + route, + client_user_1, + json_data, + u1_private_entity_id, + relationship_key, +): + # sanity check to ensure that authorized_project_id and authorized_public are consistent + e1 = _get_entity(db, entity_id=u1_private_entity_id) + assert e1.authorized_public is False + + # create an entity with valid relationships + json_data |= { + relationship_key: [{"id": str(u1_private_entity_id)}], + } + data = assert_request(client_user_1.post, url=route, json=json_data).json() + + # update the entity when the nested relationships exist already + update_json = { + relationship_key: [{"id": str(u1_private_entity_id)}], + } + data = assert_request( + client_user_1.patch, url=f"{route}/{data['id']}", json=update_json, expected_status_code=409 + ).json() + assert data["details"] == f"It is forbidden to update {relationship_key} if they exist." + + def s3_key_exists(s3_client, key: str, storage_type=StorageType.aws_s3_internal) -> bool: bucket = storages[storage_type].bucket