diff --git a/alembic/versions/20260227_153548_a873e3bb7ed3_add_generic_campaign_models.py b/alembic/versions/20260227_153548_a873e3bb7ed3_add_generic_campaign_models.py new file mode 100644 index 00000000..29e3f6f3 --- /dev/null +++ b/alembic/versions/20260227_153548_a873e3bb7ed3_add_generic_campaign_models.py @@ -0,0 +1,365 @@ +"""Add generic campaign models + +Revision ID: a873e3bb7ed3 +Revises: 523e523531a7 +Create Date: 2026-02-27 15:35:48.007178 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from alembic_postgresql_enum import TableReference +from sqlalchemy.dialects import postgresql + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "a873e3bb7ed3" +down_revision: Union[str, None] = "523e523531a7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum( + "circuit_simulation", + "circuit_extraction", + "ion_channel_modeling", + "skeletonization", + "ion_channel_simulation", + "em_synapse_mapping", + name="tasktype", + ).create(op.get_bind()) + op.create_table( + "campaign", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "task_type", + postgresql.ENUM( + "circuit_simulation", + "circuit_extraction", + "ion_channel_modeling", + "skeletonization", + "ion_channel_simulation", + "em_synapse_mapping", + name="tasktype", + create_type=False, + ), + nullable=False, + ), + sa.Column( + "scan_parameters", + postgresql.JSONB(astext_type=sa.Text()), + server_default="{}", + nullable=False, + ), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("description_vector", postgresql.TSVECTOR(), nullable=True), + sa.ForeignKeyConstraint(["id"], ["entity.id"], name=op.f("fk_campaign_id_entity")), + sa.PrimaryKeyConstraint("id", name=op.f("pk_campaign")), + ) + op.create_index( + "ix_campaign_description_vector", + "campaign", + ["description_vector"], + unique=False, + postgresql_using="gin", + ) + op.create_index(op.f("ix_campaign_name"), "campaign", ["name"], unique=False) + op.create_index(op.f("ix_campaign_task_type"), "campaign", ["task_type"], unique=False) + op.create_table( + "task_config_generation", + sa.Column("id", sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint( + ["id"], ["activity.id"], name=op.f("fk_config_generation_id_activity") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_config_generation")), + ) + op.create_table( + "task_execution", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "executor", + postgresql.ENUM( + "single_node_job", + "distributed_job", + "jupyter_notebook", + name="executortype", + create_type=False, + ), + nullable=True, + ), + sa.Column("execution_id", sa.Uuid(), nullable=True), + sa.ForeignKeyConstraint( + ["id"], ["activity.id"], name=op.f("fk_task_execution_id_activity") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_task_execution")), + ) + op.create_table( + "campaign__entity", + sa.Column("campaign_id", sa.Uuid(), nullable=False), + sa.Column("entity_id", sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint( + ["campaign_id"], + ["campaign.id"], + name=op.f("fk_campaign__entity_campaign_id_campaign"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["entity_id"], + ["entity.id"], + name=op.f("fk_campaign__entity_entity_id_entity"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("campaign_id", "entity_id", name=op.f("pk_campaign__entity")), + ) + op.create_table( + "task_config", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "task_type", + postgresql.ENUM( + "circuit_simulation", + "circuit_extraction", + "ion_channel_modeling", + "skeletonization", + "ion_channel_simulation", + "em_synapse_mapping", + name="tasktype", + create_type=False, + ), + nullable=False, + ), + sa.Column( + "scan_parameters", + postgresql.JSONB(astext_type=sa.Text()), + server_default="{}", + nullable=False, + ), + sa.Column("campaign_id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("description_vector", postgresql.TSVECTOR(), nullable=True), + sa.ForeignKeyConstraint( + ["campaign_id"], ["campaign.id"], name=op.f("fk_task_config_campaign_id_campaign") + ), + sa.ForeignKeyConstraint(["id"], ["entity.id"], name=op.f("fk_task_config_id_entity")), + sa.PrimaryKeyConstraint("id", name=op.f("pk_task_config")), + ) + op.create_index( + op.f("ix_task_config_campaign_id"), "task_config", ["campaign_id"], unique=False + ) + op.create_index( + "ix_task_config_description_vector", + "task_config", + ["description_vector"], + unique=False, + postgresql_using="gin", + ) + op.create_index(op.f("ix_task_config_name"), "task_config", ["name"], unique=False) + op.create_index(op.f("ix_task_config_task_type"), "task_config", ["task_type"], unique=False) + op.create_table( + "task_config__entity", + sa.Column("task_config_id", sa.Uuid(), nullable=False), + sa.Column("entity_id", sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint( + ["entity_id"], + ["entity.id"], + name=op.f("fk_task_config__entity_entity_id_entity"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["task_config_id"], + ["task_config.id"], + name=op.f("fk_task_config__entity_task_config_id_task_config"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("task_config_id", "entity_id", name=op.f("pk_task_config__entity")), + ) + op.sync_enum_values( + enum_schema="public", + enum_name="activitytype", + new_values=[ + "simulation_execution", + "simulation_generation", + "validation", + "calibration", + "analysis_notebook_execution", + "ion_channel_modeling_execution", + "ion_channel_modeling_config_generation", + "circuit_extraction_config_generation", + "circuit_extraction_execution", + "skeletonization_execution", + "skeletonization_config_generation", + "task_config_generation", + "task_execution", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="activity", column_name="type") + ], + enum_values_to_rename=[], + ) + op.sync_enum_values( + enum_schema="public", + enum_name="entitytype", + new_values=[ + "analysis_software_source_code", + "brain_atlas", + "brain_atlas_region", + "cell_composition", + "cell_morphology", + "cell_morphology_protocol", + "electrical_cell_recording", + "electrical_recording", + "electrical_recording_stimulus", + "emodel", + "experimental_bouton_density", + "experimental_neuron_density", + "experimental_synapses_per_connection", + "external_url", + "ion_channel_model", + "ion_channel_modeling_campaign", + "ion_channel_modeling_config", + "ion_channel_recording", + "memodel", + "memodel_calibration_result", + "me_type_density", + "simulation", + "simulation_campaign", + "simulation_result", + "scientific_artifact", + "single_neuron_simulation", + "single_neuron_synaptome", + "single_neuron_synaptome_simulation", + "subject", + "validation_result", + "circuit", + "circuit_extraction_campaign", + "circuit_extraction_config", + "em_dense_reconstruction_dataset", + "em_cell_mesh", + "analysis_notebook_template", + "analysis_notebook_environment", + "analysis_notebook_result", + "skeletonization_config", + "skeletonization_campaign", + "campaign", + "task_config", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="entity", column_name="type"), + TableReference( + table_schema="public", table_name="measurement_label", column_name="entity_type" + ), + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="entitytype", + new_values=[ + "analysis_software_source_code", + "brain_atlas", + "brain_atlas_region", + "cell_composition", + "cell_morphology", + "cell_morphology_protocol", + "electrical_cell_recording", + "electrical_recording", + "electrical_recording_stimulus", + "emodel", + "experimental_bouton_density", + "experimental_neuron_density", + "experimental_synapses_per_connection", + "external_url", + "ion_channel_model", + "ion_channel_modeling_campaign", + "ion_channel_modeling_config", + "ion_channel_recording", + "memodel", + "memodel_calibration_result", + "me_type_density", + "simulation", + "simulation_campaign", + "simulation_result", + "scientific_artifact", + "single_neuron_simulation", + "single_neuron_synaptome", + "single_neuron_synaptome_simulation", + "subject", + "validation_result", + "circuit", + "circuit_extraction_campaign", + "circuit_extraction_config", + "em_dense_reconstruction_dataset", + "em_cell_mesh", + "analysis_notebook_template", + "analysis_notebook_environment", + "analysis_notebook_result", + "skeletonization_config", + "skeletonization_campaign", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="entity", column_name="type"), + TableReference( + table_schema="public", table_name="measurement_label", column_name="entity_type" + ), + ], + enum_values_to_rename=[], + ) + op.sync_enum_values( + enum_schema="public", + enum_name="activitytype", + new_values=[ + "simulation_execution", + "simulation_generation", + "validation", + "calibration", + "analysis_notebook_execution", + "ion_channel_modeling_execution", + "ion_channel_modeling_config_generation", + "circuit_extraction_config_generation", + "circuit_extraction_execution", + "skeletonization_execution", + "skeletonization_config_generation", + ], + affected_columns=[ + TableReference(table_schema="public", table_name="activity", column_name="type") + ], + enum_values_to_rename=[], + ) + op.drop_table("task_config__entity") + op.drop_index(op.f("ix_task_config_task_type"), table_name="task_config") + op.drop_index(op.f("ix_task_config_name"), table_name="task_config") + op.drop_index( + "ix_task_config_description_vector", table_name="task_config", postgresql_using="gin" + ) + op.drop_index(op.f("ix_task_config_campaign_id"), table_name="task_config") + op.drop_table("task_config") + op.drop_table("campaign__entity") + op.drop_table("task_execution") + op.drop_table("task_config_generation") + op.drop_index(op.f("ix_campaign_task_type"), table_name="campaign") + op.drop_index(op.f("ix_campaign_name"), table_name="campaign") + op.drop_index("ix_campaign_description_vector", table_name="campaign", postgresql_using="gin") + op.drop_table("campaign") + sa.Enum( + "circuit_simulation", + "circuit_extraction", + "ion_channel_modeling", + "skeletonization", + "ion_channel_simulation", + "em_synapse_mapping", + name="tasktype", + ).drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/alembic/versions/20260227_153555_b21b11d836a8_update_triggers.py b/alembic/versions/20260227_153555_b21b11d836a8_update_triggers.py new file mode 100644 index 00000000..5918e58c --- /dev/null +++ b/alembic/versions/20260227_153555_b21b11d836a8_update_triggers.py @@ -0,0 +1,103 @@ +"""Update triggers + +Revision ID: b21b11d836a8 +Revises: a873e3bb7ed3 +Create Date: 2026-02-27 15:35:55.189216 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from alembic_utils.pg_function import PGFunction +from sqlalchemy import text as sql_text +from alembic_utils.pg_trigger import PGTrigger +from sqlalchemy import text as sql_text + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "b21b11d836a8" +down_revision: Union[str, None] = "a873e3bb7ed3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + public_task_config_task_config_description_vector = PGTrigger( + schema="public", + signature="task_config_description_vector", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION\n tsvector_update_trigger(description_vector, 'pg_catalog.english', description, name)", + ) + op.create_entity(public_task_config_task_config_description_vector) + + public_campaign_campaign_description_vector = PGTrigger( + schema="public", + signature="campaign_description_vector", + on_entity="public.campaign", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON campaign\n FOR EACH ROW EXECUTE FUNCTION\n tsvector_update_trigger(description_vector, 'pg_catalog.english', description, name)", + ) + op.create_entity(public_campaign_campaign_description_vector) + + public_auth_fnc_task_config_campaign_id = PGFunction( + schema="public", + signature="auth_fnc_task_config_campaign_id()", + definition="RETURNS TRIGGER AS $$\n BEGIN\n \n IF NOT EXISTS (\n SELECT 1 FROM entity e1\n JOIN entity e2 ON e2.id = NEW.id\n WHERE e1.id = NEW.campaign_id\n AND (e1.authorized_public = TRUE\n OR (e2.authorized_public = FALSE\n AND e1.authorized_project_id = e2.authorized_project_id\n )\n )\n ) THEN\n RAISE EXCEPTION 'unauthorized private reference\\: task_config.campaign_id'\n USING ERRCODE = '42501'; -- Insufficient Privilege\n END IF;\n RETURN NEW;\n END;\n $$ LANGUAGE plpgsql", + ) + op.create_entity(public_auth_fnc_task_config_campaign_id) + + public_task_config_auth_trg_task_config_campaign_id = PGTrigger( + schema="public", + signature="auth_trg_task_config_campaign_id", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION auth_fnc_task_config_campaign_id()", + ) + op.create_entity(public_task_config_auth_trg_task_config_campaign_id) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + public_task_config_auth_trg_task_config_campaign_id = PGTrigger( + schema="public", + signature="auth_trg_task_config_campaign_id", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION auth_fnc_task_config_campaign_id()", + ) + op.drop_entity(public_task_config_auth_trg_task_config_campaign_id) + + public_auth_fnc_task_config_campaign_id = PGFunction( + schema="public", + signature="auth_fnc_task_config_campaign_id()", + definition="RETURNS TRIGGER AS $$\n BEGIN\n \n IF NOT EXISTS (\n SELECT 1 FROM entity e1\n JOIN entity e2 ON e2.id = NEW.id\n WHERE e1.id = NEW.campaign_id\n AND (e1.authorized_public = TRUE\n OR (e2.authorized_public = FALSE\n AND e1.authorized_project_id = e2.authorized_project_id\n )\n )\n ) THEN\n RAISE EXCEPTION 'unauthorized private reference\\: task_config.campaign_id'\n USING ERRCODE = '42501'; -- Insufficient Privilege\n END IF;\n RETURN NEW;\n END;\n $$ LANGUAGE plpgsql", + ) + op.drop_entity(public_auth_fnc_task_config_campaign_id) + + public_campaign_campaign_description_vector = PGTrigger( + schema="public", + signature="campaign_description_vector", + on_entity="public.campaign", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON campaign\n FOR EACH ROW EXECUTE FUNCTION\n tsvector_update_trigger(description_vector, 'pg_catalog.english', description, name)", + ) + op.drop_entity(public_campaign_campaign_description_vector) + + public_task_config_task_config_description_vector = PGTrigger( + schema="public", + signature="task_config_description_vector", + on_entity="public.task_config", + is_constraint=False, + definition="BEFORE INSERT OR UPDATE ON task_config\n FOR EACH ROW EXECUTE FUNCTION\n tsvector_update_trigger(description_vector, 'pg_catalog.english', description, name)", + ) + op.drop_entity(public_task_config_task_config_description_vector) + + # ### end Alembic commands ### diff --git a/app/db/auth.py b/app/db/auth.py index 352903ec..6971d3c1 100644 --- a/app/db/auth.py +++ b/app/db/auth.py @@ -4,11 +4,9 @@ from pydantic import UUID4 from sqlalchemy import Delete, Select, and_, false, not_, or_, select, true -from sqlalchemy.orm import Query, Session +from sqlalchemy.orm import Query -from app.config import settings -from app.db.model import Entity, Identifiable -from app.queries.utils import get_user +from app.db.model import Entity from app.schemas.auth import UserContext @@ -93,34 +91,3 @@ def select_unauthorized_entities(ids: list[UUID4], project_id: UUID4 | None) -> ), ) ) - - -def is_user_authorized_for_deletion( # noqa: PLR0911 - db: Session, user_context: UserContext, obj: Identifiable -) -> bool: - if settings.APP_DISABLE_AUTH: - return True - - # if there is no authorized_project_id it is a global resource - if not (project_id := getattr(obj, "authorized_project_id", None)): - return False - - # Service maintainers may delete public/private entities within their projects. - if user_context.is_service_maintainer: - return project_id in user_context.user_project_ids - - # from here and below public entities cannot be deleted - if obj.authorized_public: # pyright: ignore [reportAttributeAccessIssue] - return False - - # Project admins may delete private entities within their projects - if project_id in user_context.admin_project_ids: - return True - - # Project members may delete only the private entities they themselves created - if project_id in user_context.member_project_ids and ( - db_user := get_user(db, user_context.profile.subject) - ): - return db_user.created_by_id == obj.created_by_id - - return False diff --git a/app/db/model.py b/app/db/model.py index a1d7a9f0..e62df16d 100644 --- a/app/db/model.py +++ b/app/db/model.py @@ -72,6 +72,7 @@ StainingType, StorageType, StructuralDomain, + TaskType, ValidationStatus, ) from app.schemas.publication import Author @@ -2249,4 +2250,155 @@ class SkeletonizationExecution(Activity, ExecutionActivityMixin): __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 +class CampaignToEntity(Base): + """Represents the many-to-many associations between campaigns and entities used as input.""" + + __tablename__ = "campaign__entity" + + campaign_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey(f"{EntityType.campaign}.id", ondelete="CASCADE"), + primary_key=True, + ) + entity_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("entity.id", ondelete="CASCADE"), + primary_key=True, + ) + + +class TaskConfigToEntity(Base): + """Represents the many-to-many associations between task configs and entities used as input.""" + + __tablename__ = "task_config__entity" + + task_config_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey(f"{EntityType.task_config}.id", ondelete="CASCADE"), + primary_key=True, + ) + entity_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("entity.id", ondelete="CASCADE"), + primary_key=True, + ) + + +class Campaign(NameDescriptionVectorMixin, Entity): + """Represents a generic campaign in the database. + + Assets: + - campaign configuration file. + + Attributes: + id (uuid.UUID): Primary key referencing the entity ID. + task_type: Type of task. + scan_parameters (JSON_DICT): Scan parameters for the campaign. + input: entities used as input for the campaign. + + Potential mappings for existing campaigns: + SimulationCampaign: + entity_id/entity -> input[0] + simulations -> task_config_generation.used[] + SkeletonizationCampaign: + input_meshes -> input + skeletonization_configs -> task_config_generation.used[] + CircuitExtractionCampaign: + IonChannelModelingCampaign: + input_recordings -> input + ion_channel_modeling_configs -> task_config_generation.used[] + """ + + __tablename__ = EntityType.campaign.value + id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True) + task_type: Mapped[TaskType] = mapped_column(index=True) + scan_parameters: Mapped[JSON_DICT] = mapped_column(default={}, server_default="{}") + + input: Mapped[list[Entity]] = relationship( + primaryjoin="Campaign.id == CampaignToEntity.campaign_id", + secondary="campaign__entity", + ) + + __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 + + +class TaskConfig(NameDescriptionVectorMixin, Entity): + """Represents the configuration of a task of the campaign in the database. + + Assets: + - task configuration file. + + Attributes: + id (uuid.UUID): Primary key referencing the entity ID. + task_type: Type of task. + scan_parameters (JSON_DICT): Scan parameters for the task. + campaign_id: id of the campaign that generated the config. + input: entities used as input for the task. + + Potential mappings for existing campaigns: + Simulation: + simulation_campaign_id -> campaign_id + entity/entity_id -> input[0] + number_neurons -> MISSING + SkeletonizationConfig: + skeletonization_campaign_id -> campaign_id + em_cell_mesh_id -> input[0] + CircuitExtractionConfig: + circuit_id -> input[0] + NOTHING -> campaign_id + IonChannelModelingConfig: + ion_channel_modeling_campaign_id -> campaign_id + """ + + __tablename__ = EntityType.task_config.value + + id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True) + task_type: Mapped[TaskType] = mapped_column(index=True) + scan_parameters: Mapped[JSON_DICT] = mapped_column(default={}, server_default="{}") + campaign_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey(f"{EntityType.campaign}.id"), index=True + ) + input: Mapped[list[Entity]] = relationship( + primaryjoin="TaskConfig.id == TaskConfigToEntity.task_config_id", + secondary="task_config__entity", + ) + + __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 + + +class TaskConfigGeneration(Activity): + """Represents an activity generating the task configurations in a campaign. + + Inputs (used): + - Campaign (one) + Outputs (generated): + - TaskConfig (many) + + Attributes: + id (uuid.UUID): Primary key referencing the activity ID. + """ + + __tablename__ = ActivityType.task_config_generation.value + id: Mapped[uuid.UUID] = mapped_column(ForeignKey("activity.id"), primary_key=True) + # task_type: Mapped[TaskType] = mapped_column(index=True) + + __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 + + +class TaskExecution(Activity, ExecutionActivityMixin): + """Represents an activity executing a task in a campaign. + + Inputs (used): + - TaskConfig (one) + Outputs (generated): + - Entity (many) + + Attributes: + id (uuid.UUID): Primary key referencing the activity ID. + """ + + __tablename__ = ActivityType.task_execution.value + + id: Mapped[uuid.UUID] = mapped_column(ForeignKey("activity.id"), primary_key=True) + # task_type: Mapped[TaskType] = mapped_column(index=True) + + __mapper_args__ = {"polymorphic_identity": __tablename__} # noqa: RUF012 + + register_model_events() diff --git a/app/db/triggers.py b/app/db/triggers.py index 452a417e..c4e6435b 100644 --- a/app/db/triggers.py +++ b/app/db/triggers.py @@ -31,6 +31,7 @@ SingleNeuronSynaptome, SingleNeuronSynaptomeSimulation, SkeletonizationConfig, + TaskConfig, ValidationResult, ) @@ -188,6 +189,7 @@ def unauthorized_private_reference_trigger(model: type[Entity], field_name: str) (IonChannelModelingConfig, "ion_channel_modeling_campaign_id"), (SkeletonizationConfig, "skeletonization_campaign_id"), (SkeletonizationConfig, "em_cell_mesh_id"), + (TaskConfig, "campaign_id"), ] entities = [ diff --git a/app/db/types.py b/app/db/types.py index 1cd7a3f2..6b6efb38 100644 --- a/app/db/types.py +++ b/app/db/types.py @@ -141,6 +141,19 @@ class EntityType(StrEnum): analysis_notebook_result = auto() skeletonization_config = auto() skeletonization_campaign = auto() + campaign = auto() + task_config = auto() + + +class TaskType(StrEnum): + """Task types for campaigns.""" + + circuit_simulation = auto() + circuit_extraction = auto() + ion_channel_modeling = auto() + skeletonization = auto() + ion_channel_simulation = auto() + em_synapse_mapping = auto() class AgentType(StrEnum): @@ -197,6 +210,8 @@ class ActivityType(StrEnum): circuit_extraction_execution = auto() skeletonization_execution = auto() skeletonization_config_generation = auto() + task_config_generation = auto() + task_execution = auto() class DerivationType(StrEnum): diff --git a/app/filters/campaign.py b/app/filters/campaign.py new file mode 100644 index 00000000..b23dad6c --- /dev/null +++ b/app/filters/campaign.py @@ -0,0 +1,26 @@ +from typing import Annotated + +from app.db.model import Campaign +from app.db.types import TaskType +from app.dependencies.filter import FilterDepends +from app.filters.base import CustomFilter +from app.filters.common import ILikeSearchFilterMixin, NameFilterMixin +from app.filters.entity import EntityFilterMixin +from app.filters.task_config import NestedTaskConfigFilter, NestedTaskConfigFilterDep + + +class CampaignFilter(CustomFilter, EntityFilterMixin, NameFilterMixin, ILikeSearchFilterMixin): + task_type: TaskType | None = None + task_config: Annotated[ + NestedTaskConfigFilter | None, + NestedTaskConfigFilterDep, + ] = None + + order_by: list[str] = ["-creation_date"] # noqa: RUF012 + + class Constants(CustomFilter.Constants): + model = Campaign + ordering_model_fields = ["creation_date", "update_date", "name"] # noqa: RUF012 + + +CampaignFilterDep = Annotated[CampaignFilter, FilterDepends(CampaignFilter)] diff --git a/app/filters/task_config.py b/app/filters/task_config.py new file mode 100644 index 00000000..cfe6bea0 --- /dev/null +++ b/app/filters/task_config.py @@ -0,0 +1,35 @@ +import uuid +from typing import Annotated + +from fastapi_filter import with_prefix + +from app.db.model import TaskConfig +from app.db.types import TaskType +from app.dependencies.filter import FilterDepends +from app.filters.base import CustomFilter +from app.filters.common import IdFilterMixin, ILikeSearchFilterMixin, NameFilterMixin +from app.filters.entity import EntityFilterMixin + + +class TaskConfigFilterBase(NameFilterMixin, IdFilterMixin, CustomFilter): + task_type: TaskType | None = None + + +class NestedTaskConfigFilter(TaskConfigFilterBase): + class Constants(CustomFilter.Constants): + model = TaskConfig + + +class TaskConfigFilter(EntityFilterMixin, TaskConfigFilterBase, ILikeSearchFilterMixin): + campaign_id: uuid.UUID | None = None + campaign_id__in: list[uuid.UUID] | None = None + + order_by: list[str] = ["-creation_date"] # noqa: RUF012 + + class Constants(CustomFilter.Constants): + model = TaskConfig + ordering_model_fields = ["creation_date", "update_date", "name"] # noqa: RUF012 + + +TaskConfigFilterDep = Annotated[TaskConfigFilter, FilterDepends(TaskConfigFilter)] +NestedTaskConfigFilterDep = FilterDepends(with_prefix("task_config", NestedTaskConfigFilter)) diff --git a/app/filters/task_config_generation.py b/app/filters/task_config_generation.py new file mode 100644 index 00000000..52c5056d --- /dev/null +++ b/app/filters/task_config_generation.py @@ -0,0 +1,19 @@ +from typing import Annotated + +from app.db.model import TaskConfigGeneration +from app.dependencies.filter import FilterDepends +from app.filters.activity import ActivityFilterMixin +from app.filters.base import CustomFilter + + +class TaskConfigGenerationFilter(CustomFilter, ActivityFilterMixin): + order_by: list[str] = ["-creation_date"] # noqa: RUF012 + + class Constants(CustomFilter.Constants): + model = TaskConfigGeneration + ordering_model_fields = ["creation_date", "update_date"] # noqa: RUF012 + + +TaskConfigGenerationFilterDep = Annotated[ + TaskConfigGenerationFilter, FilterDepends(TaskConfigGenerationFilter) +] diff --git a/app/filters/task_execution.py b/app/filters/task_execution.py new file mode 100644 index 00000000..00f3e968 --- /dev/null +++ b/app/filters/task_execution.py @@ -0,0 +1,17 @@ +from typing import Annotated + +from app.db.model import TaskExecution +from app.dependencies.filter import FilterDepends +from app.filters.activity import ActivityFilterMixin, ExecutionActivityFilterMixin +from app.filters.base import CustomFilter + + +class TaskExecutionFilter(CustomFilter, ActivityFilterMixin, ExecutionActivityFilterMixin): + order_by: list[str] = ["-creation_date"] # noqa: RUF012 + + class Constants(CustomFilter.Constants): + model = TaskExecution + ordering_model_fields = ["creation_date", "update_date"] # noqa: RUF012 + + +TaskExecutionFilterDep = Annotated[TaskExecutionFilter, FilterDepends(TaskExecutionFilter)] diff --git a/app/queries/common.py b/app/queries/common.py index 7114ef02..8c6e9b07 100644 --- a/app/queries/common.py +++ b/app/queries/common.py @@ -2,7 +2,6 @@ from http import HTTPStatus import sqlalchemy as sa -from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session from sqlalchemy.sql import operators @@ -10,10 +9,8 @@ from app.db.auth import ( constrain_to_readable_entities, constrain_to_writable_entities, - is_user_authorized_for_deletion, - select_unauthorized_entities, ) -from app.db.model import Activity, Generation, Identifiable, Usage +from app.db.model import Activity, Identifiable from app.db.utils import get_declaring_class, load_db_model_from_pydantic, update_model from app.dependencies.common import ( FacetQueryParams, @@ -32,14 +29,22 @@ ) from app.filters.base import Aliases, CustomFilter from app.queries import crud +from app.queries.constants import NESTED_ACTIVITY_RELATIONSHIPS from app.queries.filter import filter_from_db -from app.queries.types import ApplyOperations, SupportsModelValidate -from app.queries.utils import get_or_create_user_agent +from app.queries.types import ( + ApplyOperations, + NestedRelationships, + SupportsModelValidate, +) +from app.queries.utils import ( + create_associations_to_entities, + get_or_create_user_agent, + is_user_authorized_for_deletion, +) from app.schemas.activity import ActivityCreate, ActivityUpdate from app.schemas.auth import UserContext, UserContextWithProjectId from app.schemas.routers import DeleteResponse from app.schemas.types import ListResponse, PaginationResponse -from app.schemas.utils import NOT_SET def router_read_one[T: BaseModel, I: Identifiable]( @@ -103,7 +108,7 @@ def router_create_activity_one[T: BaseModel, I: Activity]( created_by_id=created_by_id, updated_by_id=updated_by_id, authorized_project_id=project_id, - ignore_attributes={"used_ids", "generated_ids"}, + ignore_attributes=set(NESTED_ACTIVITY_RELATIONSHIPS), ) with ( @@ -117,30 +122,14 @@ def router_create_activity_one[T: BaseModel, I: Activity]( db.add(db_model_instance) db.flush() - if associated_ids := json_model.used_ids + json_model.generated_ids: - if ( - unaccessible_entities := db.execute( - select_unauthorized_entities(associated_ids, user_context.project_id) - ) - .scalars() - .all() - ): - raise HTTPException( - status_code=404, - detail=f"Cannot access entities {', '.join(str(e) for e in unaccessible_entities)}", - ) - - for entity_id in json_model.used_ids: - db.add(Usage(usage_entity_id=entity_id, usage_activity_id=db_model_instance.id)) - - for entity_id in json_model.generated_ids: - db.add( - Generation( - generation_entity_id=entity_id, generation_activity_id=db_model_instance.id - ) - ) - - db.flush() + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=NESTED_ACTIVITY_RELATIONSHIPS, + project_id=db_model_instance.authorized_project_id, + action="create", + ) if apply_operations: q = sa.select(db_model_class).where(db_model_class.id == db_model_instance.id) @@ -160,6 +149,7 @@ def router_create_one[T: BaseModel, I: Identifiable]( response_schema_class: SupportsModelValidate[T], apply_operations: ApplyOperations | None = None, embedding: list[float] | None = None, + nested_relationships: NestedRelationships | None = None, ) -> T: """Create a model in the database. @@ -171,6 +161,7 @@ def router_create_one[T: BaseModel, I: Identifiable]( response_schema_class: Pydantic schema class for the returned data. apply_operations: transformer function that modifies the select query. embedding: optional embedding vector to attach to the model. + nested_relationships: mapping of nested relationships that can be set automatically. Returns: the written model data as a Pydantic model. @@ -192,6 +183,7 @@ def router_create_one[T: BaseModel, I: Identifiable]( created_by_id=created_by_id, updated_by_id=updated_by_id, authorized_project_id=project_id, + ignore_attributes=set(nested_relationships) if nested_relationships else None, ) if embedding is not None and hasattr(db_model_instance, "embedding"): @@ -208,6 +200,16 @@ def router_create_one[T: BaseModel, I: Identifiable]( db.add(db_model_instance) db.flush() + if nested_relationships: + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=nested_relationships, + project_id=getattr(db_model_instance, "authorized_project_id", None), + action="create", + ) + if apply_operations: q = sa.select(db_model_class).where(db_model_class.id == db_model_instance.id) q = apply_operations(q) @@ -395,6 +397,7 @@ def router_update_one[T: BaseModel, I: Identifiable]( json_model: BaseModel, response_schema_class: SupportsModelValidate[T], apply_operations: ApplyOperations | None = None, + nested_relationships: NestedRelationships | None = None, ): query = ( sa.select(db_model_class).where(db_model_class.id == id_).with_for_update(of=db_model_class) @@ -407,14 +410,29 @@ def router_update_one[T: BaseModel, I: Identifiable]( query = apply_operations(query) with ensure_result(error_message=f"{db_model_class.__name__} not found"): - obj = db.execute(query).unique().scalar_one() + db_model_instance = db.execute(query).unique().scalar_one() - obj = update_model(model=obj, data=json_model.model_dump()) + db_model_instance = update_model( + model=db_model_instance, + data=json_model.model_dump( + exclude=set(nested_relationships) if nested_relationships else None + ), + ) db.flush() - db.refresh(obj) + db.refresh(db_model_instance) + + if nested_relationships: + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=nested_relationships, + project_id=getattr(db_model_instance, "authorized_project_id", None), + action="update", + ) - return response_schema_class.model_validate(obj) + return response_schema_class.model_validate(db_model_instance) def router_user_delete_one[T: BaseModel, I: Identifiable]( @@ -485,44 +503,28 @@ def router_update_activity_one[T: BaseModel, I: Activity]( query = apply_operations(query) with ensure_result(error_message=f"{db_model_class.__name__} not found"): - obj = db.execute(query).unique().scalar_one() + db_model_instance = db.execute(query).unique().scalar_one() update_data = json_model.model_dump( exclude_unset=True, exclude_none=True, - exclude={"used_ids", "generated_ids"}, + exclude=set(NESTED_ACTIVITY_RELATIONSHIPS), exclude_defaults=True, # ignore NOT_SET default values ) for key, value in update_data.items(): - setattr(obj, key, value) - - # ignore NOT_SET values - generated_ids = json_model.generated_ids if json_model.generated_ids != NOT_SET else [] - - if generated_ids: - if obj.generated: - raise HTTPException( - status_code=404, - detail="It is forbidden to update generated_ids if they exist.", - ) - - if user_context and ( - unaccessible_entities := db.execute( - select_unauthorized_entities(generated_ids, user_context.project_id) - ) - .scalars() - .all() - ): - raise HTTPException( - status_code=404, - detail=f"Cannot access entities {', '.join(str(e) for e in unaccessible_entities)}", - ) - - for entity_id in generated_ids: - db.add(Generation(generation_entity_id=entity_id, generation_activity_id=obj.id)) + setattr(db_model_instance, key, value) + + create_associations_to_entities( + db=db, + left=db_model_instance, + json_model=json_model, + nested_relationships=NESTED_ACTIVITY_RELATIONSHIPS, + project_id=db_model_instance.authorized_project_id, + action="update", + ) db.flush() - db.refresh(obj) + db.refresh(db_model_instance) - return response_schema_class.model_validate(obj) + return response_schema_class.model_validate(db_model_instance) diff --git a/app/queries/constants.py b/app/queries/constants.py new file mode 100644 index 00000000..4429267d --- /dev/null +++ b/app/queries/constants.py @@ -0,0 +1,35 @@ +from app.db.model import CampaignToEntity, Generation, TaskConfigToEntity, Usage +from app.queries.types import NestedRelationships + +NESTED_ACTIVITY_RELATIONSHIPS: NestedRelationships = { + "used_ids": { + "relationship_name": "used", + "db_model_factory": lambda *, left_id, right_id: Usage( + usage_activity_id=left_id, usage_entity_id=right_id + ), + }, + "generated_ids": { + "relationship_name": "generated", + "db_model_factory": lambda *, left_id, right_id: Generation( + generation_activity_id=left_id, generation_entity_id=right_id + ), + }, +} + +NESTED_CAMPAIGN_RELATIONSHIPS: NestedRelationships = { + "input_ids": { + "relationship_name": "input", + "db_model_factory": lambda *, left_id, right_id: CampaignToEntity( + campaign_id=left_id, entity_id=right_id + ), + }, +} + +NESTED_TASK_CONFIG_RELATIONSHIPS: NestedRelationships = { + "input_ids": { + "relationship_name": "input", + "db_model_factory": lambda *, left_id, right_id: TaskConfigToEntity( + task_config_id=left_id, entity_id=right_id + ), + }, +} diff --git a/app/queries/factory.py b/app/queries/factory.py index 923b6dac..1bce3762 100644 --- a/app/queries/factory.py +++ b/app/queries/factory.py @@ -33,6 +33,7 @@ Species, Strain, Subject, + TaskConfig, Usage, ) from app.dependencies.common import FacetQueryParams @@ -92,6 +93,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T: ion_channel_model_alias = _get_alias(IonChannelModel, "ion_channel_model") ion_channel_modeling_config_alias = _get_alias(IonChannelModelingConfig) skeletonization_config_alias = _get_alias(SkeletonizationConfig) + task_config_alias = _get_alias(TaskConfig) name_to_facet_query_params: dict[str, FacetQueryParams] = { "agent": { @@ -153,6 +155,10 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T: "id": skeletonization_config_alias.id, "label": skeletonization_config_alias.name, }, + "task_config": { + "id": task_config_alias.id, + "label": task_config_alias.name, + }, } filter_joins = { "species": lambda q: q.join(Species, db_model_class.species_id == Species.id), @@ -268,6 +274,10 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T: skeletonization_config_alias, db_model_class.id == skeletonization_config_alias.skeletonization_campaign_id, ), + "task_config": lambda q: q.join( + task_config_alias, + db_model_class.id == task_config_alias.campaign_id, + ), } name_to_facet_query_params = {k: name_to_facet_query_params[k] for k in facet_keys} filter_joins = {k: filter_joins[k] for k in filter_keys} diff --git a/app/queries/types.py b/app/queries/types.py index 902ed202..8e97f31d 100644 --- a/app/queries/types.py +++ b/app/queries/types.py @@ -1,5 +1,6 @@ +import uuid from collections.abc import Callable -from typing import Any, Protocol +from typing import Any, Protocol, TypedDict import sqlalchemy as sa from pydantic import BaseModel @@ -11,3 +12,22 @@ class SupportsModelValidate[T: BaseModel](Protocol): @classmethod def model_validate(cls, obj: Any, *args, **kwargs) -> T: ... + + +class AssociationCallable(Protocol): + """Callable that should accept left_id and right_id and return a valid db model instance.""" + + def __call__(self, *, left_id: uuid.UUID, right_id: uuid.UUID) -> DeclarativeBase: ... + + +class NestedRelationship(TypedDict): + """Nested relationship dict, used for creating relationships in entities and activities.""" + + relationship_name: str # name of the relationship in the db model + db_model_factory: AssociationCallable # callable that should return a new db model instance + + +# mapping relationship_key -> relationship, where: +# - relationship_key is the key used to pass ids in the Create schema of the resource +# - relationship is a dict of type NestedRelationship +NestedRelationships = dict[str, NestedRelationship] diff --git a/app/queries/utils.py b/app/queries/utils.py index 842799e7..8c863313 100644 --- a/app/queries/utils.py +++ b/app/queries/utils.py @@ -1,10 +1,18 @@ import uuid +from itertools import chain +from typing import Literal, cast import sqlalchemy as sa +from fastapi import HTTPException +from pydantic import BaseModel from sqlalchemy.orm import Session -from app.db.model import Person -from app.schemas.auth import UserProfile +from app.config import settings +from app.db.auth import select_unauthorized_entities +from app.db.model import Identifiable, Person +from app.queries.types import NestedRelationships +from app.schemas.auth import UserContext, UserProfile +from app.schemas.utils import NOT_SET from app.utils.uuid import create_uuid @@ -33,3 +41,103 @@ def get_or_create_user_agent(db: Session, user_profile: UserProfile) -> Person: def get_user(db: Session, subject_id: uuid.UUID) -> Person | None: query = sa.select(Person).where(Person.sub_id == subject_id) return db.execute(query).scalars().first() + + +def is_user_authorized_for_deletion( # noqa: PLR0911 + db: Session, user_context: UserContext, obj: Identifiable +) -> bool: + if settings.APP_DISABLE_AUTH: + return True + + # if there is no authorized_project_id it is a global resource + if not (project_id := getattr(obj, "authorized_project_id", None)): + return False + + # Service maintainers may delete public/private entities within their projects. + if user_context.is_service_maintainer: + return project_id in user_context.user_project_ids + + # from here and below public entities cannot be deleted + if obj.authorized_public: # pyright: ignore [reportAttributeAccessIssue] + return False + + # Project admins may delete private entities within their projects + if project_id in user_context.admin_project_ids: + return True + + # Project members may delete only the private entities they themselves created + if project_id in user_context.member_project_ids and ( + db_user := get_user(db, user_context.profile.subject) + ): + return db_user.created_by_id == obj.created_by_id + + return False + + +def create_associations_to_entities( + db: Session, + *, + left: Identifiable, + json_model: BaseModel, + nested_relationships: NestedRelationships, + project_id: uuid.UUID | None, + action: Literal["create", "update"], +) -> None: + """Create association records between the left Identifiable and each of the right_id passed. + + Args: + db: Database session. + left: Identifiable on the left side of the associations. + json_model: Pydantic model of the left resourece. + nested_relationships: Mapping of relationship keys to relationship dicts. + project_id: Optional project ID for authorization checks. + action: create or update the relationships. + + Raises: + HTTPException: If any of the associated entities are not public or not in the same project, + or if trying to update associations when it's not allowed. + """ + # map relationship keys to lists of entity IDs to associate, ignoring NOT_SET values + nested_relationship_ids: dict[str, list[uuid.UUID]] = cast( + "dict[str, list[uuid.UUID]]", + { + relationship_key: ids + for relationship_key in nested_relationships + if (ids := getattr(json_model, relationship_key, NOT_SET)) != NOT_SET + }, + ) + + associated_ids = list(chain.from_iterable(nested_relationship_ids.values())) + + # skip if all the nested_relationship_ids are empty + if not associated_ids: + return + + # the associated entities should be public, or in the same given project + if ( + unaccessible_entities := db.execute( + select_unauthorized_entities(associated_ids, project_id) + ) + .scalars() + .all() + ): + raise HTTPException( + status_code=404, + detail=f"Cannot access entities {', '.join(str(e) for e in unaccessible_entities)}", + ) + + for relationship_key, relationship in nested_relationships.items(): + # ignore empty ids + if not (right_ids := nested_relationship_ids.get(relationship_key)): + continue + if action == "update" and getattr(left, relationship["relationship_name"]): + raise HTTPException( + status_code=409, + detail=f"It is forbidden to update {relationship_key} if they exist.", + ) + factory = relationship["db_model_factory"] + for right_id in right_ids: + db_instance = factory(left_id=left.id, right_id=right_id) + db.add(db_instance) + + db.flush() diff --git a/app/routers/__init__.py b/app/routers/__init__.py index a62c2b03..ea407259 100644 --- a/app/routers/__init__.py +++ b/app/routers/__init__.py @@ -15,6 +15,7 @@ brain_region, brain_region_hierarchy, calibration, + campaign, cell_composition, cell_morphology, cell_morphology_protocol, @@ -74,6 +75,9 @@ species, strain, subject, + task_config, + task_config_generation, + task_execution, validation, validation_result, ) @@ -99,6 +103,7 @@ brain_region.router, brain_region_hierarchy.router, calibration.router, + campaign.router, cell_composition.router, cell_morphology.router, cell_morphology_protocol.router, @@ -157,6 +162,9 @@ species.router, strain.router, subject.router, + task_config.router, + task_config_generation.router, + task_execution.router, validation.router, validation_result.router, ] diff --git a/app/routers/campaign.py b/app/routers/campaign.py new file mode 100644 index 00000000..5cf9ffeb --- /dev/null +++ b/app/routers/campaign.py @@ -0,0 +1,16 @@ +from fastapi import APIRouter + +from app.routers.admin import router as admin_router +from app.service import campaign as service + +ROUTE = "campaign" +router = APIRouter(prefix=f"/{ROUTE}", tags=[ROUTE]) + +read_many = router.get("")(service.read_many) +read_one = router.get("/{id_}")(service.read_one) +create_one = router.post("")(service.create_one) +delete_one = router.delete("/{id_}")(service.delete_one) +update_one = router.patch("/{id_}")(service.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(service.admin_read_one) +admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(service.admin_update_one) diff --git a/app/routers/task_config.py b/app/routers/task_config.py new file mode 100644 index 00000000..e70f9239 --- /dev/null +++ b/app/routers/task_config.py @@ -0,0 +1,16 @@ +from fastapi import APIRouter + +from app.routers.admin import router as admin_router +from app.service import task_config as service + +ROUTE = "task-config" +router = APIRouter(prefix=f"/{ROUTE}", tags=[ROUTE]) + +read_many = router.get("")(service.read_many) +read_one = router.get("/{id_}")(service.read_one) +create_one = router.post("")(service.create_one) +delete_one = router.delete("/{id_}")(service.delete_one) +update_one = router.patch("/{id_}")(service.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(service.admin_read_one) +admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(service.admin_update_one) diff --git a/app/routers/task_config_generation.py b/app/routers/task_config_generation.py new file mode 100644 index 00000000..eecf33ac --- /dev/null +++ b/app/routers/task_config_generation.py @@ -0,0 +1,16 @@ +from fastapi import APIRouter + +from app.routers.admin import router as admin_router +from app.service import task_config_generation as service + +ROUTE = "task-config-generation" +router = APIRouter(prefix=f"/{ROUTE}", tags=[ROUTE]) + +read_many = router.get("")(service.read_many) +read_one = router.get("/{id_}")(service.read_one) +create_one = router.post("")(service.create_one) +delete_one = router.delete("/{id_}")(service.delete_one) +update_one = router.patch("/{id_}")(service.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(service.admin_read_one) +admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(service.admin_update_one) diff --git a/app/routers/task_execution.py b/app/routers/task_execution.py new file mode 100644 index 00000000..beb45474 --- /dev/null +++ b/app/routers/task_execution.py @@ -0,0 +1,16 @@ +from fastapi import APIRouter + +from app.routers.admin import router as admin_router +from app.service import task_execution as service + +ROUTE = "task-execution" +router = APIRouter(prefix=f"/{ROUTE}", tags=[ROUTE]) + +read_many = router.get("")(service.read_many) +read_one = router.get("/{id_}")(service.read_one) +create_one = router.post("")(service.create_one) +delete_one = router.delete("/{id_}")(service.delete_one) +update_one = router.patch("/{id_}")(service.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(service.admin_read_one) +admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(service.admin_update_one) diff --git a/app/schemas/campaign.py b/app/schemas/campaign.py new file mode 100644 index 00000000..10229111 --- /dev/null +++ b/app/schemas/campaign.py @@ -0,0 +1,65 @@ +import uuid + +from pydantic import BaseModel, ConfigDict + +from app.db.types import JSON_DICT, TaskType +from app.schemas.agent import CreatedByUpdatedByMixin +from app.schemas.asset import AssetsMixin +from app.schemas.base import ( + AuthorizationMixin, + AuthorizationOptionalPublicMixin, + CreationMixin, + EntityTypeMixin, + IdentifiableMixin, + NameDescriptionMixin, +) +from app.schemas.contribution import ContributionReadWithoutEntityMixin +from app.schemas.entity import NestedEntityRead +from app.schemas.utils import make_update_schema + + +class CampaignBase( + BaseModel, + NameDescriptionMixin, +): + model_config = ConfigDict(from_attributes=True) + task_type: TaskType + scan_parameters: JSON_DICT + + +class CampaignCreate( + CampaignBase, + AuthorizationOptionalPublicMixin, +): + input_ids: list[uuid.UUID] = [] + + +CampaignUserUpdate = make_update_schema( + CampaignCreate, + "CampaignUserUpdate", +) # pyright: ignore [reportInvalidTypeForm] + +CampaignAdminUpdate = make_update_schema( + CampaignCreate, + "CampaignAdminUpdate", + excluded_fields=set(), +) # pyright : ignore [reportInvalidTypeForm] + + +class NestedCampaignRead( + CampaignBase, + EntityTypeMixin, + IdentifiableMixin, +): + pass + + +class CampaignRead( + NestedCampaignRead, + AssetsMixin, + CreatedByUpdatedByMixin, + CreationMixin, + AuthorizationMixin, + ContributionReadWithoutEntityMixin, +): + input: list[NestedEntityRead] diff --git a/app/schemas/task_config.py b/app/schemas/task_config.py new file mode 100644 index 00000000..1871ddd1 --- /dev/null +++ b/app/schemas/task_config.py @@ -0,0 +1,56 @@ +import uuid + +from pydantic import BaseModel, ConfigDict + +from app.db.types import JSON_DICT, TaskType +from app.schemas.agent import CreatedByUpdatedByMixin +from app.schemas.asset import AssetsMixin +from app.schemas.base import ( + AuthorizationMixin, + AuthorizationOptionalPublicMixin, + CreationMixin, + EntityTypeMixin, + IdentifiableMixin, + NameDescriptionMixin, +) +from app.schemas.contribution import ContributionReadWithoutEntityMixin +from app.schemas.entity import NestedEntityRead +from app.schemas.utils import make_update_schema + + +class TaskConfigBase(BaseModel, NameDescriptionMixin): + model_config = ConfigDict(from_attributes=True) + task_type: TaskType + scan_parameters: JSON_DICT + campaign_id: uuid.UUID + + +class TaskConfigCreate(TaskConfigBase, AuthorizationOptionalPublicMixin): + input_ids: list[uuid.UUID] = [] + + +TaskConfigUserUpdate = make_update_schema( + TaskConfigCreate, + "TaskConfigUserUpdate", +) # pyright: ignore [reportInvalidTypeForm] + +TaskConfigAdminUpdate = make_update_schema( + TaskConfigCreate, + "TaskConfigAdminUpdate", + excluded_fields=set(), +) # pyright : ignore [reportInvalidTypeForm] + + +class NestedTaskConfigRead(TaskConfigBase, EntityTypeMixin, IdentifiableMixin): + pass + + +class TaskConfigRead( + NestedTaskConfigRead, + AssetsMixin, + CreatedByUpdatedByMixin, + CreationMixin, + AuthorizationMixin, + ContributionReadWithoutEntityMixin, +): + input: list[NestedEntityRead] diff --git a/app/schemas/task_config_generation.py b/app/schemas/task_config_generation.py new file mode 100644 index 00000000..292bf2ce --- /dev/null +++ b/app/schemas/task_config_generation.py @@ -0,0 +1,21 @@ +from app.schemas.activity import ActivityCreate, ActivityRead, ActivityUpdate +from app.schemas.utils import make_update_schema + + +class TaskConfigGenerationCreate(ActivityCreate): + pass + + +class TaskConfigGenerationRead(ActivityRead): + pass + + +class TaskConfigGenerationUserUpdate(ActivityUpdate): + pass + + +TaskConfigGenerationAdminUpdate = make_update_schema( + TaskConfigGenerationCreate, + "TaskConfigGenerationAdminUpdate", + excluded_fields=set(), +) # pyright : ignore [reportInvalidTypeForm] diff --git a/app/schemas/task_execution.py b/app/schemas/task_execution.py new file mode 100644 index 00000000..91385abd --- /dev/null +++ b/app/schemas/task_execution.py @@ -0,0 +1,26 @@ +from app.schemas.activity import ( + ActivityCreate, + ActivityRead, + ActivityUpdate, + ExecutionActivityMixin, +) +from app.schemas.utils import make_update_schema + + +class TaskExecutionCreate(ActivityCreate, ExecutionActivityMixin): + pass + + +class TaskExecutionRead(ActivityRead, ExecutionActivityMixin): + pass + + +class TaskExecutionUserUpdate(ActivityUpdate, ExecutionActivityMixin): + pass + + +TaskExecutionAdminUpdate = make_update_schema( + TaskExecutionCreate, + "TaskExecutionAdminUpdate", + excluded_fields=set(), +) # pyright : ignore [reportInvalidTypeForm] diff --git a/app/service/asset.py b/app/service/asset.py index 5515deab..13b48351 100644 --- a/app/service/asset.py +++ b/app/service/asset.py @@ -8,7 +8,6 @@ from types_boto3_s3 import S3Client from app.config import StorageUnion, storages -from app.db.auth import is_user_authorized_for_deletion from app.db.model import Asset, Entity from app.db.types import AssetLabel, AssetStatus, ContentType, EntityType, StorageType from app.dependencies.common import PaginationQuery @@ -16,6 +15,7 @@ from app.filters.asset import AssetFilterDep from app.queries import crud from app.queries.common import get_or_create_user_agent, router_read_many +from app.queries.utils import is_user_authorized_for_deletion from app.repository.group import RepositoryGroup from app.schemas.asset import ( AssetCreate, diff --git a/app/service/campaign.py b/app/service/campaign.py new file mode 100644 index 00000000..6839962e --- /dev/null +++ b/app/service/campaign.py @@ -0,0 +1,204 @@ +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import aliased, joinedload, raiseload, selectinload + +from app.db.model import ( + Agent, + Campaign, + Person, + TaskConfig, +) +from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep +from app.dependencies.common import ( + FacetsDep, + PaginationQuery, + SearchDep, +) +from app.dependencies.db import SessionDep +from app.filters.campaign import CampaignFilterDep +from app.queries.common import ( + router_create_one, + router_read_many, + router_read_one, + router_update_one, + router_user_delete_one, +) +from app.queries.constants import NESTED_CAMPAIGN_RELATIONSHIPS +from app.queries.factory import query_params_factory +from app.schemas.campaign import ( + CampaignAdminUpdate, + CampaignCreate, + CampaignRead, + CampaignUserUpdate, +) +from app.schemas.routers import DeleteResponse +from app.schemas.types import ListResponse + +if TYPE_CHECKING: + from app.filters.base import Aliases + + +DBModel = Campaign +ReadSchema = CampaignRead +CreateSchema = CampaignCreate +UserUpdateSchema = CampaignUserUpdate +AdminUpdateSchema = CampaignAdminUpdate +FilterDep = CampaignFilterDep + + +def _load(query: sa.Select): + return query.options( + joinedload(DBModel.created_by), + joinedload(DBModel.updated_by), + selectinload(DBModel.assets), + selectinload(DBModel.contributions), + selectinload(DBModel.input), + raiseload("*"), + ) + + +def read_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=user_context, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=None, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def create_one( + db: SessionDep, + json_model: CreateSchema, + user_context: UserContextWithProjectIdDep, +) -> ReadSchema: + return router_create_one( + db=db, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_CAMPAIGN_RELATIONSHIPS, + ) + + +def update_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, + json_model: UserUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + json_model=json_model, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_CAMPAIGN_RELATIONSHIPS, + ) + + +def admin_update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: AdminUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=None, + json_model=json_model, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_CAMPAIGN_RELATIONSHIPS, + ) + + +def read_many( + user_context: UserContextDep, + db: SessionDep, + pagination_request: PaginationQuery, + filter_model: FilterDep, + with_search: SearchDep, + facets: FacetsDep, +) -> ListResponse[ReadSchema]: + agent_alias = aliased(Agent, flat=True) + created_by_alias = aliased(Person, flat=True) + updated_by_alias = aliased(Person, flat=True) + task_config_alias = aliased(TaskConfig, flat=True) + aliases: Aliases = { + Agent: { + "contribution": agent_alias, + }, + Person: { + "created_by": created_by_alias, + "updated_by": updated_by_alias, + }, + TaskConfig: task_config_alias, + } + facet_keys = filter_keys = [ + "created_by", + "updated_by", + "contribution", + "task_config", + ] + name_to_facet_query_params, filter_joins = query_params_factory( + db_model_class=DBModel, + facet_keys=facet_keys, + filter_keys=filter_keys, + aliases=aliases, + ) + return router_read_many( + db=db, + filter_model=filter_model, + db_model_class=DBModel, + with_search=with_search, + with_in_brain_region=None, + facets=facets, + name_to_facet_query_params=name_to_facet_query_params, + apply_filter_query_operations=None, + apply_data_query_operations=_load, + aliases=aliases, + pagination_request=pagination_request, + response_schema_class=ReadSchema, + authorized_project_id=user_context.project_id, + filter_joins=filter_joins, + ) + + +def delete_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> DeleteResponse: + return router_user_delete_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + ) diff --git a/app/service/task_config.py b/app/service/task_config.py new file mode 100644 index 00000000..ac3cfd4f --- /dev/null +++ b/app/service/task_config.py @@ -0,0 +1,200 @@ +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import aliased, joinedload, raiseload, selectinload + +from app.db.model import ( + Agent, + Person, + TaskConfig, +) +from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep +from app.dependencies.common import ( + FacetsDep, + PaginationQuery, + SearchDep, +) +from app.dependencies.db import SessionDep +from app.filters.task_config import TaskConfigFilterDep +from app.queries.common import ( + router_create_one, + router_read_many, + router_read_one, + router_update_one, + router_user_delete_one, +) +from app.queries.constants import NESTED_TASK_CONFIG_RELATIONSHIPS +from app.queries.factory import query_params_factory +from app.schemas.routers import DeleteResponse +from app.schemas.task_config import ( + TaskConfigAdminUpdate, + TaskConfigCreate, + TaskConfigRead, + TaskConfigUserUpdate, +) +from app.schemas.types import ListResponse + +if TYPE_CHECKING: + from app.filters.base import Aliases + +DBModel = TaskConfig +ReadSchema = TaskConfigRead +CreateSchema = TaskConfigCreate +UserUpdateSchema = TaskConfigUserUpdate +AdminUpdateSchema = TaskConfigAdminUpdate +FilterDep = TaskConfigFilterDep + + +def _load(query: sa.Select): + return query.options( + joinedload(DBModel.created_by), + joinedload(DBModel.updated_by), + selectinload(DBModel.assets), + selectinload(DBModel.contributions), + selectinload(DBModel.input), + raiseload("*"), + ) + + +def read_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=user_context, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=None, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def create_one( + db: SessionDep, + json_model: CreateSchema, + user_context: UserContextWithProjectIdDep, +) -> ReadSchema: + return router_create_one( + db=db, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_TASK_CONFIG_RELATIONSHIPS, + ) + + +def update_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, + json_model: UserUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + json_model=json_model, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_TASK_CONFIG_RELATIONSHIPS, + ) + + +def admin_update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: AdminUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=None, + json_model=json_model, + response_schema_class=ReadSchema, + apply_operations=_load, + nested_relationships=NESTED_TASK_CONFIG_RELATIONSHIPS, + ) + + +def read_many( + user_context: UserContextDep, + db: SessionDep, + pagination_request: PaginationQuery, + filter_model: FilterDep, + with_search: SearchDep, + facets: FacetsDep, +) -> ListResponse[ReadSchema]: + agent_alias = aliased(Agent, flat=True) + created_by_alias = aliased(Person, flat=True) + updated_by_alias = aliased(Person, flat=True) + + aliases: Aliases = { + Agent: { + "contribution": agent_alias, + }, + Person: { + "created_by": created_by_alias, + "updated_by": updated_by_alias, + }, + } + facet_keys = filter_keys = [ + "created_by", + "updated_by", + "contribution", + ] + name_to_facet_query_params, filter_joins = query_params_factory( + db_model_class=DBModel, + facet_keys=facet_keys, + filter_keys=filter_keys, + aliases=aliases, + ) + return router_read_many( + db=db, + filter_model=filter_model, + db_model_class=DBModel, + with_search=with_search, + with_in_brain_region=None, + facets=facets, + name_to_facet_query_params=name_to_facet_query_params, + apply_filter_query_operations=None, + apply_data_query_operations=_load, + aliases=aliases, + pagination_request=pagination_request, + response_schema_class=ReadSchema, + authorized_project_id=user_context.project_id, + filter_joins=filter_joins, + ) + + +def delete_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> DeleteResponse: + return router_user_delete_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + ) diff --git a/app/service/task_config_generation.py b/app/service/task_config_generation.py new file mode 100644 index 00000000..5116551e --- /dev/null +++ b/app/service/task_config_generation.py @@ -0,0 +1,197 @@ +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import aliased, joinedload, raiseload + +from app.db.model import Entity, Person, TaskConfigGeneration +from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep +from app.dependencies.common import ( + FacetsDep, + PaginationQuery, + SearchDep, +) +from app.dependencies.db import SessionDep +from app.filters.task_config_generation import ( + TaskConfigGenerationFilterDep, +) +from app.queries.common import ( + router_create_activity_one, + router_read_many, + router_read_one, + router_update_activity_one, + router_user_delete_one, +) +from app.queries.factory import query_params_factory +from app.schemas.routers import DeleteResponse +from app.schemas.task_config_generation import ( + TaskConfigGenerationAdminUpdate, + TaskConfigGenerationCreate, + TaskConfigGenerationRead, + TaskConfigGenerationUserUpdate, +) +from app.schemas.types import ListResponse + +if TYPE_CHECKING: + from app.filters.base import Aliases + +DBModel = TaskConfigGeneration +ReadSchema = TaskConfigGenerationRead +CreateSchema = TaskConfigGenerationCreate +UserUpdateSchema = TaskConfigGenerationUserUpdate +AdminUpdateSchema = TaskConfigGenerationAdminUpdate +FilterDep = TaskConfigGenerationFilterDep + + +def _load(query: sa.Select): + return query.options( + joinedload(DBModel.used), + joinedload(DBModel.generated), + joinedload(DBModel.created_by), + joinedload(DBModel.updated_by), + raiseload("*"), + ) + + +def read_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=user_context, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=None, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def create_one( + db: SessionDep, + json_model: CreateSchema, + user_context: UserContextWithProjectIdDep, +) -> ReadSchema: + return router_create_activity_one( + db=db, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def read_many( + user_context: UserContextDep, + db: SessionDep, + pagination_request: PaginationQuery, + filter_model: FilterDep, + with_search: SearchDep, + facets: FacetsDep, +) -> ListResponse[ReadSchema]: + created_by_alias = aliased(Person, flat=True) + updated_by_alias = aliased(Person, flat=True) + used_alias = aliased(Entity, flat=True) + generated_alias = aliased(Entity, flat=True) + + aliases: Aliases = { + Person: { + "created_by": created_by_alias, + "updated_by": updated_by_alias, + }, + Entity: { + "used": used_alias, + "generated": generated_alias, + }, + } + facet_keys = [] + filter_keys = [ + "created_by", + "updated_by", + "used", + "generated", + ] + name_to_facet_query_params, filter_joins = query_params_factory( + db_model_class=DBModel, + facet_keys=facet_keys, + filter_keys=filter_keys, + aliases=aliases, + ) + return router_read_many( + db=db, + filter_model=filter_model, + db_model_class=DBModel, + with_search=with_search, + with_in_brain_region=None, + facets=facets, + name_to_facet_query_params=name_to_facet_query_params, + apply_filter_query_operations=None, + apply_data_query_operations=_load, + aliases=aliases, + pagination_request=pagination_request, + response_schema_class=ReadSchema, + authorized_project_id=user_context.project_id, + filter_joins=filter_joins, + ) + + +def delete_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> DeleteResponse: + return router_user_delete_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + ) + + +def update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: UserUpdateSchema, # pyright: ignore [reportInvalidTypeForm] + user_context: UserContextDep, +) -> ReadSchema: + return router_update_activity_one( + db=db, + id_=id_, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: AdminUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_activity_one( + db=db, + id_=id_, + json_model=json_model, + user_context=None, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) diff --git a/app/service/task_execution.py b/app/service/task_execution.py new file mode 100644 index 00000000..410d7ff3 --- /dev/null +++ b/app/service/task_execution.py @@ -0,0 +1,196 @@ +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import aliased, joinedload, raiseload + +from app.db.model import Entity, Person, TaskExecution +from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep +from app.dependencies.common import ( + FacetsDep, + PaginationQuery, + SearchDep, +) +from app.dependencies.db import SessionDep +from app.filters.task_execution import TaskExecutionFilterDep +from app.queries.common import ( + router_create_activity_one, + router_read_many, + router_read_one, + router_update_activity_one, + router_user_delete_one, +) +from app.queries.factory import query_params_factory +from app.schemas.routers import DeleteResponse +from app.schemas.task_execution import ( + TaskExecutionAdminUpdate, + TaskExecutionCreate, + TaskExecutionRead, + TaskExecutionUserUpdate, +) +from app.schemas.types import ListResponse + +if TYPE_CHECKING: + from app.filters.base import Aliases + + +DBModel = TaskExecution +ReadSchema = TaskExecutionRead +CreateSchema = TaskExecutionCreate +UserUpdateSchema = TaskExecutionUserUpdate +AdminUpdateSchema = TaskExecutionAdminUpdate +FilterDep = TaskExecutionFilterDep + + +def _load(query: sa.Select): + return query.options( + joinedload(DBModel.used), + joinedload(DBModel.generated), + joinedload(DBModel.created_by), + joinedload(DBModel.updated_by), + raiseload("*"), + ) + + +def read_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=user_context, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> ReadSchema: + return router_read_one( + db=db, + id_=id_, + db_model_class=DBModel, + user_context=None, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def create_one( + db: SessionDep, + json_model: CreateSchema, + user_context: UserContextWithProjectIdDep, +) -> ReadSchema: + return router_create_activity_one( + db=db, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def read_many( + user_context: UserContextDep, + db: SessionDep, + pagination_request: PaginationQuery, + filter_model: FilterDep, + with_search: SearchDep, + facets: FacetsDep, +) -> ListResponse[ReadSchema]: + created_by_alias = aliased(Person, flat=True) + updated_by_alias = aliased(Person, flat=True) + used_alias = aliased(Entity, flat=True) + generated_alias = aliased(Entity, flat=True) + + aliases: Aliases = { + Person: { + "created_by": created_by_alias, + "updated_by": updated_by_alias, + }, + Entity: { + "used": used_alias, + "generated": generated_alias, + }, + } + facet_keys = [] + filter_keys = [ + "created_by", + "updated_by", + "used", + "generated", + ] + name_to_facet_query_params, filter_joins = query_params_factory( + db_model_class=DBModel, + facet_keys=facet_keys, + filter_keys=filter_keys, + aliases=aliases, + ) + return router_read_many( + db=db, + filter_model=filter_model, + db_model_class=DBModel, + with_search=with_search, + with_in_brain_region=None, + facets=facets, + name_to_facet_query_params=name_to_facet_query_params, + apply_filter_query_operations=None, + apply_data_query_operations=_load, + aliases=aliases, + pagination_request=pagination_request, + response_schema_class=ReadSchema, + authorized_project_id=user_context.project_id, + filter_joins=filter_joins, + ) + + +def delete_one( + user_context: UserContextDep, + db: SessionDep, + id_: uuid.UUID, +) -> DeleteResponse: + return router_user_delete_one( + id_=id_, + db=db, + db_model_class=DBModel, + user_context=user_context, + ) + + +def update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: UserUpdateSchema, # pyright: ignore [reportInvalidTypeForm] + user_context: UserContextDep, +) -> ReadSchema: + return router_update_activity_one( + db=db, + id_=id_, + json_model=json_model, + user_context=user_context, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) + + +def admin_update_one( + db: SessionDep, + id_: uuid.UUID, + json_model: AdminUpdateSchema, # pyright: ignore [reportInvalidTypeForm] +) -> ReadSchema: + return router_update_activity_one( + db=db, + id_=id_, + json_model=json_model, + user_context=None, + db_model_class=DBModel, + response_schema_class=ReadSchema, + apply_operations=_load, + ) diff --git a/docs/campaign-models-diagram.md b/docs/campaign-models-diagram.md new file mode 100644 index 00000000..d6fc2fb9 --- /dev/null +++ b/docs/campaign-models-diagram.md @@ -0,0 +1,136 @@ +# Campaign Models Relationship Diagram + +```mermaid +flowchart TD + subgraph Input["Input Entities"] + EIn1[Entity - Campaign Input] + EIn2[Entity - TaskConfig Input] + end + + subgraph Campaign_Layer["Campaign Layer"] + C[Campaign] + TC[TaskConfig] + end + + subgraph Activities + CG[TaskConfigGeneration] + TE[TaskExecution] + end + + subgraph Output["Output Entities"] + EOut[Entity - Output] + end + + EIn1 -->|CampaignToEntity| C + EIn2 -->|TaskConfigToEntity| TC + C -->|campaign_id FK| TC + + C -->|Usage| CG + CG -->|Generation| TC + + TC -->|Usage| TE + TE -->|Generation| EOut + + style C fill:#e1f5ff + style TC fill:#e1f5ff + style EIn1 fill:#d4edda + style EIn2 fill:#d4edda + style EOut fill:#f8d7da + style CG fill:#fff4e1 + style TE fill:#fff4e1 +``` + +## ER Diagram + +```mermaid +erDiagram + Campaign_Entity_Input ||--o{ CampaignToEntity : "input" + Campaign ||--o{ CampaignToEntity : "has" + Campaign ||--o{ TaskConfig : "campaign_id" + + TaskConfig_Entity_Input ||--o{ TaskConfigToEntity : "input" + TaskConfig ||--o{ TaskConfigToEntity : "has" + + Campaign ||--o{ Usage_CG : "used by TaskConfigGeneration" + Usage_CG }o--|| TaskConfigGeneration : "" + TaskConfigGeneration ||--o{ Generation_CG : "" + Generation_CG }o--|| TaskConfig : "generated by TaskConfigGeneration" + + TaskConfig ||--o{ Usage_TE : "used by TaskExecution" + Usage_TE }o--|| TaskExecution : "" + TaskExecution ||--o{ Generation_TE : "" + Generation_TE }o--|| Entity_Output : "generated by TaskExecution" + + Entity { + uuid id PK + } + + CampaignToEntity { + uuid entity_id PK,FK + uuid campaign_id PK,FK + } + + TaskConfigToEntity { + uuid entity_id PK,FK + uuid task_config_id PK,FK + } + + Campaign { + uuid id PK,FK + JSON_DICT scan_parameters + TaskType task_type + } + + Usage { + uuid usage_entity_id PK,FK + uuid usage_activity_id PK,FK + } + + Generation { + uuid generation_entity_id PK,FK + uuid generation_activity_id PK,FK + } + + TaskConfig { + uuid id PK,FK + JSON_DICT scan_parameters + TaskType task_type + uuid campaign_id FK + } + + TaskConfigGeneration { + uuid id PK,FK + } + + TaskExecution { + uuid id PK,FK + } +``` + +## Relationships Explained + +### Flowchart (Primary) +Shows the workflow clearly: +- **Green boxes**: Input entities + - Campaign input (linked via CampaignToEntity) + - TaskConfig input (linked via TaskConfigToEntity) +- **Blue boxes**: Campaign entities (Campaign and TaskConfig) +- **Yellow boxes**: Activities (processes) +- **Red box**: Output entities (generated by TaskExecution) +- Input and output entities are different instances, though all are Entity type + +### ER Diagram (Detailed) +Shows the same relationships with Usage and Generation tables split by activity: +- **CampaignToEntity**: Junction table linking input entities to Campaign +- **TaskConfigToEntity**: Junction table linking input entities to TaskConfig +- **Usage_CG**: Usage records for TaskConfigGeneration activity +- **Generation_CG**: Generation records for TaskConfigGeneration activity +- **Usage_TE**: Usage records for TaskExecution activity +- **Generation_TE**: Generation records for TaskExecution activity + +Notes: +- Campaign has many input entities (via CampaignToEntity) +- TaskConfig has many input entities (via TaskConfigToEntity) +- One Campaign is used by TaskConfigGeneration to generate many TaskConfig +- One TaskConfig can be used by many TaskExecution, each generating many Entity +- Activities (TaskConfigGeneration, TaskExecution) do not have task_type field diff --git a/scripts/export/build_database_archive.sh b/scripts/export/build_database_archive.sh index 5c92d897..1d3af03f 100755 --- a/scripts/export/build_database_archive.sh +++ b/scripts/export/build_database_archive.sh @@ -2,7 +2,7 @@ # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="523e523531a7" +SCRIPT_DB_VERSION="b21b11d836a8" echo "DB dump (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" @@ -105,6 +105,10 @@ SET TRANSACTION READ ONLY; \copy (SELECT t0.* FROM brain_region_hierarchy AS t0 WHERE TRUE) TO '$DATA_DIR/brain_region_hierarchy.csv' WITH CSV HEADER; \echo Dumping table calibration \copy (SELECT t0.* FROM calibration AS t0 JOIN activity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/calibration.csv' WITH CSV HEADER; +\echo Dumping table campaign +\copy (SELECT t0.* FROM campaign AS t0 JOIN entity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/campaign.csv' WITH CSV HEADER; +\echo Dumping table campaign__entity +\copy (SELECT t0.* FROM campaign__entity AS t0 JOIN entity AS t1 ON t1.id=t0.campaign_id JOIN entity AS t2 ON t2.id=t0.entity_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/campaign__entity.csv' WITH CSV HEADER; \echo Dumping table cell_composition \copy (SELECT t0.* FROM cell_composition AS t0 JOIN entity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/cell_composition.csv' WITH CSV HEADER; \echo Dumping table cell_morphology @@ -245,6 +249,14 @@ SET TRANSACTION READ ONLY; \copy (SELECT t0.* FROM strain AS t0 WHERE TRUE) TO '$DATA_DIR/strain.csv' WITH CSV HEADER; \echo Dumping table subject \copy (SELECT t0.* FROM subject AS t0 JOIN entity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/subject.csv' WITH CSV HEADER; +\echo Dumping table task_config +\copy (SELECT t0.* FROM task_config AS t0 JOIN entity AS t1 ON t1.id=t0.id JOIN entity AS t2 ON t2.id=t0.campaign_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/task_config.csv' WITH CSV HEADER; +\echo Dumping table task_config__entity +\copy (SELECT t0.* FROM task_config__entity AS t0 JOIN entity AS t1 ON t1.id=t0.task_config_id JOIN entity AS t2 ON t2.id=t0.entity_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/task_config__entity.csv' WITH CSV HEADER; +\echo Dumping table task_config_generation +\copy (SELECT t0.* FROM task_config_generation AS t0 JOIN activity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/task_config_generation.csv' WITH CSV HEADER; +\echo Dumping table task_execution +\copy (SELECT t0.* FROM task_execution AS t0 JOIN activity AS t1 ON t1.id=t0.id WHERE t1.authorized_public IS NOT false) TO '$DATA_DIR/task_execution.csv' WITH CSV HEADER; \echo Dumping table usage \copy (SELECT t0.* FROM usage AS t0 JOIN entity AS t1 ON t1.id=t0.usage_entity_id JOIN activity AS t2 ON t2.id=t0.usage_activity_id WHERE t1.authorized_public IS NOT false AND t2.authorized_public IS NOT false) TO '$DATA_DIR/usage.csv' WITH CSV HEADER; \echo Dumping table validation @@ -263,7 +275,7 @@ install -m 755 /dev/stdin "$WORK_DIR/load.sh" <<'EOF_LOAD_SCRIPT' # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="523e523531a7" +SCRIPT_DB_VERSION="b21b11d836a8" echo "DB load (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" diff --git a/tests/conftest.py b/tests/conftest.py index b7b7eb4f..af698216 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1919,3 +1919,68 @@ def skeletonization_config_id(client, skeletonization_config_json_data): json=skeletonization_config_json_data, ).json() return data["id"] + + +@pytest.fixture +def campaign_json_data(): + return { + "name": "campaign", + "description": "campaign-description", + "scan_parameters": {"foo": "bar"}, + "task_type": "skeletonization", + } + + +@pytest.fixture +def campaign_with_nested_relationships_json_data(campaign_json_data, em_cell_mesh): + return campaign_json_data | { + "input_ids": [str(em_cell_mesh.id)], + } + + +@pytest.fixture +def campaign_id(client, campaign_json_data): + data = assert_request( + client.post, + url="/campaign", + json=campaign_json_data | {"authorized_public": False}, + ).json() + return data["id"] + + +@pytest.fixture +def public_campaign_id(client, campaign_json_data): + data = assert_request( + client.post, + url="/campaign", + json=campaign_json_data | {"authorized_public": True}, + ).json() + return data["id"] + + +@pytest.fixture +def task_config_json_data(public_campaign_id): + return { + "name": "task-config", + "description": "task-config-description", + "campaign_id": public_campaign_id, + "scan_parameters": {"foo": "bar"}, + "task_type": "skeletonization", + } + + +@pytest.fixture +def task_config_with_nested_relationships_json_data(task_config_json_data, em_cell_mesh): + return task_config_json_data | { + "input_ids": [str(em_cell_mesh.id)], + } + + +@pytest.fixture +def task_config_id(client, task_config_json_data): + data = assert_request( + client.post, + url="/task-config", + json=task_config_json_data, + ).json() + return data["id"] diff --git a/tests/test_campaign.py b/tests/test_campaign.py new file mode 100644 index 00000000..0190bf13 --- /dev/null +++ b/tests/test_campaign.py @@ -0,0 +1,249 @@ +import pytest + +from app.db.model import Campaign, TaskConfig +from app.db.types import EntityType + +from .utils import ( + PROJECT_ID, + USER_SUB_ID_1, + add_all_db, + add_db, + assert_request, + check_authorization, + check_entity_delete_one, + check_entity_read_response, + check_entity_update_one, + check_entity_update_one__fail_if_nested_ids_exists, + check_entity_update_one__fail_if_nested_ids_unauthorized, + check_missing, + check_pagination, + create_cell_morphology_id, +) + +ROUTE = "campaign" +ADMIN_ROUTE = "/admin/campaign" + + +@pytest.fixture +def json_data(campaign_json_data): + return campaign_json_data + + +@pytest.fixture +def public_json_data(json_data): + return json_data | {"authorized_public": True} + + +@pytest.fixture +def model_id(campaign_id): + return campaign_id + + +@pytest.fixture +def create_id(client, json_data): + def _create_id(**kwargs): + return assert_request(client.post, url=ROUTE, json=json_data | kwargs).json()["id"] + + return _create_id + + +def _assert_read_response(data, json_data): + check_entity_read_response(data, json_data, EntityType.campaign) + assert "input" in data + assert data["scan_parameters"] == json_data["scan_parameters"] + + +def test_create_one(client, json_data): + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + +def test_create_one_with_nested_relationships(client, campaign_with_nested_relationships_json_data): + json_data = campaign_with_nested_relationships_json_data + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + input_ids = json_data["input_ids"] + assert data["input"] == [ + { + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + "id": input_ids[0], + "type": "em_cell_mesh", + }, + ] + + +def test_read_one(clients, model_id, json_data): + data = assert_request(clients.user_1.get, url=f"{ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + data = assert_request(clients.user_1.get, url=f"{ROUTE}").json()["data"] + assert len(data) == 1 + _assert_read_response(data[0], json_data) + + data = assert_request(clients.admin.get, url=f"{ADMIN_ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + +def test_delete_one(db, clients, public_json_data): + check_entity_delete_one( + db=db, + route=ROUTE, + admin_route=ADMIN_ROUTE, + clients=clients, + json_data=public_json_data, + expected_counts_before={ + Campaign: 1, + }, + expected_counts_after={ + Campaign: 0, + }, + ) + + +def test_update_one(clients, json_data): + check_entity_update_one( + route=ROUTE, + admin_route=ADMIN_ROUTE, + clients=clients, + json_data=json_data, + patch_payload={ + "name": "name", + "description": "description", + }, + optional_payload=None, + ) + + +def test_update_one__fail_if_nested_ids_unauthorized( + db, client_user_1, client_user_2, json_data, subject_id, brain_region_id +): + """Test that it is not allowed to update the nested ids with unauthorized entities.""" + + user2_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_entity_update_one__fail_if_nested_ids_unauthorized( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u2_private_entity_id=user2_generated_id, + relationship_key="input_ids", + ) + + +def test_update_one__fail_if_nested_ids_exists( + db, client_user_1, json_data, subject_id, brain_region_id +): + """Test that nested ids cannot be updated if they already exist.""" + user1_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_entity_update_one__fail_if_nested_ids_exists( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_generated_id, + relationship_key="input_ids", + ) + + +def test_missing(client): + check_missing(ROUTE, client) + + +def test_authorization(clients, public_json_data): + check_authorization(ROUTE, clients.user_1, clients.user_2, clients.no_project, public_json_data) + + +def test_pagination(client, create_id): + check_pagination(ROUTE, client, create_id) + + +@pytest.fixture +def models(db, json_data, person_id): + db_campaigns = add_all_db( + db, + [ + Campaign( + **( + json_data + | { + "name": f"campaign-{i}", + "description": f"campaign-description-{i}", + "scan_parameters": {"foo": "bar"}, + "created_by_id": person_id, + "updated_by_id": person_id, + "authorized_project_id": PROJECT_ID, + } + ) + ) + for i in range(4) + ], + ) + + mapping = [ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + ] + + for i, campaign in enumerate(db_campaigns): + for j in mapping[i]: + add_db( + db, + TaskConfig( + task_type="skeletonization", + name=f"config-{j}", + description=f"config-{j}", + campaign_id=campaign.id, + scan_parameters=campaign.scan_parameters, + created_by_id=person_id, + updated_by_id=person_id, + authorized_project_id=PROJECT_ID, + ), + ) + + return db_campaigns + + +def test_filtering_ordering(client, models): + def _req(query): + return assert_request(client.get, url=ROUTE, params=query).json()["data"] + + data = _req({}) + assert len(data) == len(models) + + data = _req({"name__ilike": "campaign"}) + assert len(data) == len(models) + + data = _req({"name": "campaign-1"}) + assert len(data) == 1 + assert data[0]["name"] == "campaign-1" + + data = _req({"task_config__name": "config-2"}) + assert {d["name"] for d in data} == {"campaign-1", "campaign-2"} + + data = _req({"order_by": "-name"}) + assert [d["name"] for d in data] == [f"campaign-{i}" for i in range(4)][::-1] + + data = _req({"task_config__name": "config-2", "order_by": "name"}) + assert [d["name"] for d in data] == ["campaign-1", "campaign-2"] + + data = _req({"created_by__sub_id": USER_SUB_ID_1, "updated_by__sub_id": USER_SUB_ID_1}) + assert len(data) == len(models) + + data = _req({"ilike_search": "*description*"}) + assert len(data) == len(models) + + data = _req({"ilike_search": "campaign-1"}) + assert len(data) == 1 diff --git a/tests/test_task_config.py b/tests/test_task_config.py new file mode 100644 index 00000000..0a5199a4 --- /dev/null +++ b/tests/test_task_config.py @@ -0,0 +1,210 @@ +import pytest + +from app.db.model import TaskConfig +from app.db.types import EntityType + +from .utils import ( + PROJECT_ID, + USER_SUB_ID_1, + assert_request, + check_authorization, + check_entity_delete_one, + check_entity_read_response, + check_entity_update_one, + check_entity_update_one__fail_if_nested_ids_exists, + check_entity_update_one__fail_if_nested_ids_unauthorized, + check_missing, + check_pagination, + create_cell_morphology_id, +) + +ROUTE = "task-config" +ADMIN_ROUTE = "/admin/task-config" + + +@pytest.fixture +def json_data(task_config_json_data): + return task_config_json_data + + +@pytest.fixture +def public_json_data(json_data): + return json_data | {"authorized_public": True} + + +@pytest.fixture +def create_id(client, json_data): + def _create_id(**kwargs): + return assert_request(client.post, url=ROUTE, json=json_data | kwargs).json()["id"] + + return _create_id + + +@pytest.fixture +def model_id(task_config_id): + return task_config_id + + +def _assert_read_response(data, json_data): + check_entity_read_response(data, json_data, EntityType.task_config) + assert "input" in data + assert data["campaign_id"] + assert data["scan_parameters"] == json_data["scan_parameters"] + + +def test_create_one(client, json_data): + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + +def test_create_one_with_nested_relationships( + client, task_config_with_nested_relationships_json_data +): + json_data = task_config_with_nested_relationships_json_data + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + input_ids = json_data["input_ids"] + assert data["input"] == [ + { + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + "id": input_ids[0], + "type": "em_cell_mesh", + }, + ] + + +def test_read_one(clients, model_id, json_data): + data = assert_request(clients.user_1.get, url=f"{ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + data = assert_request(clients.user_1.get, url=f"{ROUTE}").json()["data"] + assert len(data) == 1 + _assert_read_response(data[0], json_data) + + data = assert_request(clients.admin.get, url=f"{ADMIN_ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + +def test_delete_one(db, clients, public_json_data): + check_entity_delete_one( + db=db, + route=ROUTE, + admin_route=ADMIN_ROUTE, + clients=clients, + json_data=public_json_data, + expected_counts_before={ + TaskConfig: 1, + }, + expected_counts_after={ + TaskConfig: 0, + }, + ) + + +def test_update_one(clients, json_data): + check_entity_update_one( + route=ROUTE, + admin_route=ADMIN_ROUTE, + clients=clients, + json_data=json_data, + patch_payload={ + "name": "name", + "description": "description", + }, + optional_payload=None, + ) + + +def test_update_one__fail_if_nested_ids_unauthorized( + db, client_user_1, client_user_2, json_data, subject_id, brain_region_id +): + """Test that it is not allowed to update the nested ids with unauthorized entities.""" + + user2_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_entity_update_one__fail_if_nested_ids_unauthorized( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u2_private_entity_id=user2_generated_id, + relationship_key="input_ids", + ) + + +def test_update_one__fail_if_nested_ids_exists( + db, client_user_1, json_data, subject_id, brain_region_id +): + """Test that nested ids cannot be updated if they already exist.""" + user1_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_entity_update_one__fail_if_nested_ids_exists( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_generated_id, + relationship_key="input_ids", + ) + + +def test_missing(client): + check_missing(ROUTE, client) + + +def test_authorization(clients, public_json_data): + check_authorization(ROUTE, clients.user_1, clients.user_2, clients.no_project, public_json_data) + + +def test_pagination(client, create_id): + check_pagination(ROUTE, client, create_id) + + +@pytest.fixture +def models(create_id): + return [create_id(name=f"config-{i}") for i in range(3)] + + +def test_filtering_ordering(client, models, public_campaign_id): + def _req(query): + return assert_request(client.get, url=ROUTE, params=query).json()["data"] + + data = _req({}) + assert len(data) == len(models) + + data = _req({"name__ilike": "config"}) + assert len(data) == len(models) + + data = _req({"name": "config-0"}) + assert len(data) == 1 + assert data[0]["name"] == "config-0" + + data = _req({"campaign_id": public_campaign_id}) + assert len(data) == len(models) + + data = _req({"campaign_id__in": [public_campaign_id]}) + assert len(data) == len(models) + + data = _req({"order_by": "-name"}) + assert [d["name"] for d in data] == ["config-2", "config-1", "config-0"] + + data = _req({"order_by": "-name", "name__in": ["config-1", "config-2"]}) + assert [d["name"] for d in data] == ["config-2", "config-1"] + + data = _req({"created_by__sub_id": USER_SUB_ID_1, "updated_by__sub_id": USER_SUB_ID_1}) + assert len(data) == len(models) + + data = _req({"ilike_search": "*description*"}) + assert len(data) == len(models) + + data = _req({"ilike_search": "config-1"}) + assert len(data) == 1 diff --git a/tests/test_task_config_generation.py b/tests/test_task_config_generation.py new file mode 100644 index 00000000..344ed08f --- /dev/null +++ b/tests/test_task_config_generation.py @@ -0,0 +1,342 @@ +from datetime import UTC, datetime + +import pytest +from pydantic import TypeAdapter + +from app.db.model import ( + Campaign, + Generation, + TaskConfig, + TaskConfigGeneration, + Usage, +) +from app.db.types import ActivityType + +from .utils import ( + PROJECT_ID, + USER_SUB_ID_1, + assert_request, + check_activity_create_one__unauthorized_entities, + check_activity_delete_one, + check_activity_update_one, + check_activity_update_one__fail_if_generated_ids_exists, + check_activity_update_one__fail_if_generated_ids_unauthorized, + check_creation_fields, + check_missing, + check_pagination, + create_campaign_id, + create_task_config_id, +) + +DateTimeAdapter = TypeAdapter(datetime) + +ROUTE = "task-config-generation" +ADMIN_ROUTE = "/admin/task-config-generation" +TASK_TYPE = "skeletonization" + + +@pytest.fixture +def json_data(campaign_id, task_config_id): + return { + "start_time": str(datetime.now(UTC)), + "end_time": str(datetime.now(UTC)), + "used_ids": [campaign_id], + "generated_ids": [task_config_id], + } + + +@pytest.fixture +def create_id(client, json_data): + def _create_id(**kwargs): + return assert_request(client.post, url=ROUTE, json=json_data | kwargs).json()["id"] + + return _create_id + + +@pytest.fixture +def model_id(create_id): + return create_id() + + +def _assert_read_response(data, json_data, *, empty_ids=False): + assert "id" in data + + if empty_ids: + assert len(data["used"]) == 0 + assert len(data["generated"]) == 0 + else: + assert data["used"] == [ + { + "id": json_data["used_ids"][0], + "type": "campaign", + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + } + ] + assert data["generated"] == [ + { + "id": json_data["generated_ids"][0], + "type": "task_config", + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + } + ] + check_creation_fields(data) + assert DateTimeAdapter.validate_python(data["start_time"]) == DateTimeAdapter.validate_python( + json_data["start_time"] + ) + assert DateTimeAdapter.validate_python(data["end_time"]) == DateTimeAdapter.validate_python( + json_data["end_time"] + ) + assert data["type"] == ActivityType.task_config_generation.value + + +def test_create_one(clients, json_data): + data = assert_request(clients.user_1.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + +def test_read_one(clients, json_data, model_id): + data = assert_request(clients.user_1.get, url=f"{ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + data = assert_request(clients.user_1.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data) + + data = assert_request(clients.admin.get, url=f"{ADMIN_ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + +def test_create_one__empty_ids(client, client_admin, json_data): + json_data = {k: v for k, v in json_data.items() if k not in {"used_ids", "generated_ids"}} + + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client.get, url=f"{ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client_admin.get, url=f"{ADMIN_ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data, empty_ids=True) + + +def test_create_one__unauthorized_entities(db, client_user_1, client_user_2, json_data): + """Do not allow associations with entities that are not authorized to the user.""" + + user1_private_used_id = create_campaign_id( + client_user_1, + authorized_public=False, + task_type=TASK_TYPE, + ) + user2_private_used_id = create_campaign_id( + client_user_2, + authorized_public=False, + task_type=TASK_TYPE, + ) + user2_public_used_id = create_campaign_id( + client_user_2, + authorized_public=True, + task_type=TASK_TYPE, + ) + check_activity_create_one__unauthorized_entities( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_private_used_id, + u2_private_entity_id=user2_private_used_id, + u2_public_entity_id=user2_public_used_id, + ) + + +def test_missing(client): + check_missing(ROUTE, client) + + +def test_pagination(client, create_id): + check_pagination(ROUTE, client, create_id) + + +@pytest.fixture +def task_config_id_1(client, campaign_id): + return create_task_config_id( + client=client, + campaign_id=campaign_id, + task_type=TASK_TYPE, + ) + + +@pytest.fixture +def task_config_id_2(client, campaign_id): + return create_task_config_id( + client=client, + campaign_id=campaign_id, + task_type=TASK_TYPE, + ) + + +@pytest.fixture +def models( + create_id, + campaign_id, + task_config_id_1, + task_config_id_2, +): + return [ + create_id( + used_ids=[campaign_id], + generated_ids=[], + ), + create_id( + used_ids=[campaign_id], + generated_ids=[task_config_id_1], + ), + create_id( + used_ids=[campaign_id], + generated_ids=[ + task_config_id_1, + task_config_id_2, + ], + ), + create_id( + used_ids=[], + generated_ids=[], + ), + ] + + +def test_filtering( + client, + models, + campaign_id, + task_config_id_1, + task_config_id_2, +): + data = assert_request(client.get, url=ROUTE).json()["data"] + assert len(data) == len(models) + + data = assert_request( + client.get, + url=ROUTE, + params={"used__id": campaign_id}, + ).json()["data"] + assert len(data) == 3 + + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id": task_config_id_1}, + ).json()["data"] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={ + "used__id": campaign_id, + "generated__id": task_config_id_2, + }, + ).json()["data"] + assert len(data) == 1 + + data = assert_request( + client.get, + url=ROUTE, + params={"used__id__in": [campaign_id]}, + ).json()["data"] + assert len(data) == 3 + + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": [task_config_id_2]}, + ).json()["data"] + assert len(data) == 1 + + data = assert_request( + client.get, + url=ROUTE, + params={"created_by__sub_id": USER_SUB_ID_1, "updated_by__sub_id": USER_SUB_ID_1}, + ).json()["data"] + assert len(data) == len(models) + + +def test_delete_one(db, clients, json_data): + check_activity_delete_one( + db=db, + clients=clients, + json_data=json_data, + route=ROUTE, + admin_route=ADMIN_ROUTE, + expected_counts_before={ + Campaign: 2, + TaskConfig: 1, + Usage: 1, + Generation: 1, + TaskConfigGeneration: 1, + }, + expected_counts_after={ + Campaign: 2, + TaskConfig: 1, + Usage: 0, + Generation: 0, + TaskConfigGeneration: 0, + }, + ) + + +def test_update_one( + client, + client_admin, + campaign_id, + task_config_id, + create_id, +): + check_activity_update_one( + client=client, + client_admin=client_admin, + route=ROUTE, + admin_route=ADMIN_ROUTE, + used_id=campaign_id, + generated_id=task_config_id, + constructor_func=create_id, + ) + + +def test_update_one__fail_if_generated_ids_unauthorized( + db, client_user_1, client_user_2, json_data +): + """Test that it is not allowed to update generated_ids with unauthorized entities.""" + + user1_private_used_id = create_campaign_id( + client_user_1, + authorized_public=False, + task_type=TASK_TYPE, + ) + user2_private_used_id = create_campaign_id( + client_user_2, + authorized_public=False, + task_type=TASK_TYPE, + ) + check_activity_update_one__fail_if_generated_ids_unauthorized( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_private_used_id, + u2_private_entity_id=user2_private_used_id, + ) + + +def test_update_one__fail_if_generated_ids_exists(client, campaign_id, task_config_id, create_id): + """Test activity Generation associations cannot be updated if they already exist.""" + check_activity_update_one__fail_if_generated_ids_exists( + client=client, + route=ROUTE, + entity_id_1=campaign_id, + entity_id_2=task_config_id, + constructor_func=create_id, + ) diff --git a/tests/test_task_execution.py b/tests/test_task_execution.py new file mode 100644 index 00000000..a673b17f --- /dev/null +++ b/tests/test_task_execution.py @@ -0,0 +1,349 @@ +from datetime import UTC, datetime + +import pytest +from pydantic import TypeAdapter + +from app.db.model import ( + Generation, + TaskConfig, + TaskExecution, + Usage, +) +from app.db.types import ActivityType, ExecutorType + +from .utils import ( + PROJECT_ID, + USER_SUB_ID_1, + assert_request, + check_activity_create_one__unauthorized_entities, + check_activity_delete_one, + check_activity_update_one, + check_activity_update_one__fail_if_generated_ids_exists, + check_activity_update_one__fail_if_generated_ids_unauthorized, + check_creation_fields, + check_missing, + check_pagination, + create_cell_morphology_id, +) + +DateTimeAdapter = TypeAdapter(datetime) + +ROUTE = "task-execution" +ADMIN_ROUTE = "/admin/task-execution" + + +@pytest.fixture +def json_data(task_config_id, morphology_id): + return { + "start_time": str(datetime.now(UTC)), + "end_time": str(datetime.now(UTC)), + "used_ids": [str(task_config_id)], + "generated_ids": [str(morphology_id)], + "status": "done", + "executor": str(ExecutorType.single_node_job), + "execution_id": "1739b817-26bb-4dad-93f4-0279a1b2cf6e", + } + + +@pytest.fixture +def create_id(client, json_data): + def _create_id(**kwargs): + return assert_request(client.post, url=ROUTE, json=json_data | kwargs).json()["id"] + + return _create_id + + +@pytest.fixture +def model_id(create_id): + return create_id() + + +def _assert_read_response(data, json_data, *, empty_ids=False): + assert "id" in data + + if empty_ids: + assert len(data["used"]) == 0 + assert len(data["generated"]) == 0 + else: + assert data["used"] == [ + { + "id": json_data["used_ids"][0], + "type": "task_config", + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + } + ] + assert data["generated"] == [ + { + "id": json_data["generated_ids"][0], + "type": "cell_morphology", + "authorized_project_id": PROJECT_ID, + "authorized_public": False, + } + ] + check_creation_fields(data) + assert DateTimeAdapter.validate_python(data["start_time"]) == DateTimeAdapter.validate_python( + json_data["start_time"] + ) + assert DateTimeAdapter.validate_python(data["end_time"]) == DateTimeAdapter.validate_python( + json_data["end_time"] + ) + assert data["type"] == ActivityType.task_execution + + +def test_create_one(clients, json_data): + data = assert_request(clients.user_1.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + +def test_read_one(clients, json_data, model_id): + data = assert_request(clients.user_1.get, url=f"{ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + data = assert_request(clients.user_1.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data) + + data = assert_request(clients.admin.get, url=f"{ADMIN_ROUTE}/{model_id}").json() + _assert_read_response(data, json_data) + + +def test_create_one__empty_ids(client, client_admin, json_data): + json_data = {k: v for k, v in json_data.items() if k not in {"used_ids", "generated_ids"}} + + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client.get, url=f"{ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data, empty_ids=True) + + data = assert_request(client_admin.get, url=f"{ADMIN_ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data, empty_ids=True) + + +def test_create_one__unauthorized_entities( + db, + client_user_1, + client_user_2, + json_data, + subject_id, + brain_region_id, +): + """Do not allow associations with entities that are not authorized to the user.""" + + user1_private_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + user2_private_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + user2_public_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=True, + ) + check_activity_create_one__unauthorized_entities( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_private_generated_id, + u2_private_entity_id=user2_private_generated_id, + u2_public_entity_id=user2_public_generated_id, + ) + + +def test_missing(client): + check_missing(ROUTE, client) + + +def test_pagination(client, create_id): + check_pagination(ROUTE, client, create_id) + + +@pytest.fixture +def morphology_id_1(client, subject_id, brain_region_id): + return create_cell_morphology_id( + client=client, + subject_id=subject_id, + brain_region_id=brain_region_id, + ) + + +@pytest.fixture +def morphology_id_2(client, subject_id, brain_region_id): + return create_cell_morphology_id( + client=client, + subject_id=subject_id, + brain_region_id=brain_region_id, + ) + + +@pytest.fixture +def models(create_id, task_config_id, morphology_id_1, morphology_id_2): + return [ + create_id( + used_ids=[task_config_id], + generated_ids=[], + ), + create_id( + used_ids=[task_config_id], + generated_ids=[morphology_id_1], + ), + create_id( + used_ids=[task_config_id], + generated_ids=[morphology_id_1, morphology_id_2], + ), + create_id( + used_ids=[], + generated_ids=[], + ), + ] + + +def test_filtering(client, models, task_config_id, morphology_id_1, morphology_id_2): + data = assert_request(client.get, url=ROUTE).json()["data"] + assert len(data) == len(models) + + data = assert_request( + client.get, + url=ROUTE, + params={"used__id": task_config_id}, + ).json()["data"] + assert len(data) == 3 + + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id": morphology_id_1}, + ).json()["data"] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={ + "used__id": task_config_id, + "generated__id": morphology_id_1, + }, + ).json()["data"] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={"used__id__in": [task_config_id]}, + ).json()["data"] + assert len(data) == 3 + + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": [morphology_id_2]}, + ).json()["data"] + assert len(data) == 1 + + data = assert_request( + client.get, + url=ROUTE, + params={"created_by__sub_id": USER_SUB_ID_1, "updated_by__sub_id": USER_SUB_ID_1}, + ).json()["data"] + assert len(data) == len(models) + + for executor, count in ( + (ExecutorType.single_node_job, len(models)), + (ExecutorType.distributed_job, 0), + ): + data = assert_request( + client.get, + url=ROUTE, + params={"executor": str(executor)}, + ).json()["data"] + assert len(data) == count + + +def test_delete_one(db, clients, json_data): + check_activity_delete_one( + db=db, + clients=clients, + json_data=json_data, + route=ROUTE, + admin_route=ADMIN_ROUTE, + expected_counts_before={ + TaskConfig: 1, + TaskExecution: 1, + Usage: 1, + Generation: 1, + }, + expected_counts_after={ + TaskConfig: 1, + TaskExecution: 0, + Usage: 0, + Generation: 0, + }, + ) + + +def test_update_one( + client, + client_admin, + task_config_id, + morphology_id, + create_id, +): + check_activity_update_one( + client=client, + client_admin=client_admin, + route=ROUTE, + admin_route=ADMIN_ROUTE, + used_id=task_config_id, + generated_id=morphology_id, + constructor_func=create_id, + ) + + +def test_update_one__fail_if_generated_ids_unauthorized( + db, client_user_1, client_user_2, json_data, subject_id, brain_region_id +): + """Test that it is not allowed to update generated_ids with unauthorized entities.""" + + user1_generated_id = create_cell_morphology_id( + client_user_1, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + user2_generated_id = create_cell_morphology_id( + client_user_2, + subject_id=subject_id, + brain_region_id=brain_region_id, + authorized_public=False, + ) + check_activity_update_one__fail_if_generated_ids_unauthorized( + db=db, + route=ROUTE, + client_user_1=client_user_1, + json_data=json_data, + u1_private_entity_id=user1_generated_id, + u2_private_entity_id=user2_generated_id, + ) + + +def test_update_one__fail_if_generated_ids_exists(client, morphology_id, task_config_id, create_id): + """Test activity Generation associations cannot be updated if they already exist.""" + check_activity_update_one__fail_if_generated_ids_exists( + client=client, + route=ROUTE, + entity_id_1=task_config_id, + entity_id_2=morphology_id, + constructor_func=create_id, + ) diff --git a/tests/utils.py b/tests/utils.py index f022964a..0ebba045 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -315,6 +315,54 @@ def create_circuit_extraction_campaign_id( ).json()["id"] +def create_campaign_id( + client, + task_type, + name="Test Campaign Name", + description="Test Campaign Description", + *, + authorized_public: bool = False, +): + response = client.post( + "/campaign", + json={ + "task_type": task_type, + "name": name, + "description": description, + "authorized_public": authorized_public, + "scan_parameters": {"foo": "bar"}, + }, + ) + + assert response.status_code == 200 + return response.json()["id"] + + +def create_task_config_id( + client, + task_type, + campaign_id, + name="Test Task Config Name", + description="Test Task Config Description", + *, + authorized_public: bool = False, +): + response = client.post( + "/task-config", + json={ + "task_type": task_type, + "name": name, + "description": description, + "campaign_id": str(campaign_id), + "authorized_public": authorized_public, + "scan_parameters": {"foo": "bar"}, + }, + ) + + assert response.status_code == 200 + return response.json()["id"] + + def add_db(db, row): """Add one row to the db and commit the transaction.""" db.add(row) @@ -1646,11 +1694,68 @@ def check_activity_update_one__fail_if_generated_ids_exists( "generated_ids": [str(entity_id_1)], } data = assert_request( - client.patch, url=f"{route}/{gen1}", json=update_json, expected_status_code=404 + client.patch, url=f"{route}/{gen1}", json=update_json, expected_status_code=409 ).json() assert data["details"] == "It is forbidden to update generated_ids if they exist." +def check_entity_update_one__fail_if_nested_ids_unauthorized( + db, + route, + client_user_1, + json_data, + u2_private_entity_id, + relationship_key, +): + """Test that it is not allowed to update nested ids with unauthorized entities.""" + # sanity check to ensure that authorized_project_id and authorized_public are consistent + e2 = _get_entity(db, entity_id=u2_private_entity_id) + assert e2.authorized_public is False + + # create an entity without relationships + json_data |= { + relationship_key: [], + } + data = assert_request(client_user_1.post, url=route, json=json_data).json() + + # update the entity with invalid relationships + update_json = { + relationship_key: [str(u2_private_entity_id)], + } + data = assert_request( + client_user_1.patch, url=f"{route}/{data['id']}", json=update_json, expected_status_code=404 + ).json() + assert data["details"] == f"Cannot access entities {u2_private_entity_id}" + + +def check_entity_update_one__fail_if_nested_ids_exists( + db, + route, + client_user_1, + json_data, + u1_private_entity_id, + relationship_key, +): + # sanity check to ensure that authorized_project_id and authorized_public are consistent + e1 = _get_entity(db, entity_id=u1_private_entity_id) + assert e1.authorized_public is False + + # create an entity with valid relationships + json_data |= { + relationship_key: [str(u1_private_entity_id)], + } + data = assert_request(client_user_1.post, url=route, json=json_data).json() + + # update the entity when the nested relationships exist already + update_json = { + relationship_key: [str(u1_private_entity_id)], + } + data = assert_request( + client_user_1.patch, url=f"{route}/{data['id']}", json=update_json, expected_status_code=409 + ).json() + assert data["details"] == f"It is forbidden to update {relationship_key} if they exist." + + def s3_key_exists(s3_client, key: str, storage_type=StorageType.aws_s3_internal) -> bool: bucket = storages[storage_type].bucket