diff --git a/.gitignore b/.gitignore index c6300704..f8226ab6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ coverage.xml .vscode *.egg-info tmp +temp/ .DS_Store export.tar files.txt diff --git a/alembic/versions/20260508_095241_b8ebf9a4e3da_add_label_to_derivation_and_circuit_.py b/alembic/versions/20260508_095241_b8ebf9a4e3da_add_label_to_derivation_and_circuit_.py new file mode 100644 index 00000000..d18ad48c --- /dev/null +++ b/alembic/versions/20260508_095241_b8ebf9a4e3da_add_label_to_derivation_and_circuit_.py @@ -0,0 +1,62 @@ +"""Add label to derivation and circuit_customization/emodel_circuit types + +Revision ID: b8ebf9a4e3da +Revises: c8cdf20bbb0d +Create Date: 2026-05-08 09:52:41.175792 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from alembic_postgresql_enum import TableReference + +from sqlalchemy import Text +import app.db.types + +# revision identifiers, used by Alembic. +revision: str = "b8ebf9a4e3da" +down_revision: Union[str, None] = "c8cdf20bbb0d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("derivation", sa.Column("label", sa.String(), nullable=True)) + op.sync_enum_values( + enum_schema="public", + enum_name="derivationtype", + new_values=[ + "circuit_customization", + "circuit_extraction", + "circuit_rewiring", + "emodel_circuit", + "unspecified", + ], + affected_columns=[ + TableReference( + table_schema="public", table_name="derivation", column_name="derivation_type" + ) + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="derivationtype", + new_values=["circuit_extraction", "circuit_rewiring", "unspecified"], + affected_columns=[ + TableReference( + table_schema="public", table_name="derivation", column_name="derivation_type" + ) + ], + enum_values_to_rename=[], + ) + op.drop_column("derivation", "label") + # ### end Alembic commands ### diff --git a/app/db/model.py b/app/db/model.py index d1a73dd9..9069a4d6 100644 --- a/app/db/model.py +++ b/app/db/model.py @@ -1671,6 +1671,7 @@ class Derivation(Base): used: Mapped["Entity"] = relationship(foreign_keys=[used_id]) generated: Mapped["Entity"] = relationship(foreign_keys=[generated_id], passive_deletes=True) derivation_type: Mapped[DerivationType] + label: Mapped[str | None] class ScientificArtifactPublicationLink(Identifiable): diff --git a/app/db/types.py b/app/db/types.py index 29926251..117f049e 100644 --- a/app/db/types.py +++ b/app/db/types.py @@ -248,15 +248,25 @@ class DerivationType(StrEnum): """Represents the type of derivation relationship between two entities. Attributes: + circuit_customization: Indicates that a circuit entity was derived from another circuit by + customizing certain components. The optional ``label`` field specifies the type of + customization, such as ``synaptic_modification``, ``emodel_addition``, + ``emodel_modification``, ``population_modification``. circuit_extraction: Indicates that the entity was derived by extracting a set of nodes from a circuit. circuit_rewiring: Indicates that the entity was derived by rewiring the connectivity of a circuit. + emodel_circuit: Indicates that an emodel (used) was assigned to neurons of a circuit + (generated). The optional ``label`` field on the derivation carries the SONATA + ``model_template`` entry, by convention ``hoc:`` + (e.g. ``hoc:cADpyr_L5TPC``, ``hoc:bAC_L23BC``). unspecified: Indicates a derivation that does not require a specific type. """ + circuit_customization = auto() circuit_extraction = auto() circuit_rewiring = auto() + emodel_circuit = auto() unspecified = auto() diff --git a/app/schemas/derivation.py b/app/schemas/derivation.py index 87f2d7f9..dabfc4d9 100644 --- a/app/schemas/derivation.py +++ b/app/schemas/derivation.py @@ -1,10 +1,24 @@ +import re import uuid -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from app.db.types import DerivationType from app.schemas.base import BasicEntityRead +# Allowed label values per derivation type. +# - emodel_circuit: SONATA ``model_template`` entry, by convention ``hoc:``. +# - circuit_customization: type of customization applied to the source circuit. +_HOC_TEMPLATE_RE = re.compile(r"^hoc:[A-Za-z0-9_]+$") +_CIRCUIT_CUSTOMIZATION_LABELS = frozenset( + { + "synaptic_modification", + "emodel_addition", + "emodel_modification", + "population_modification", + } +) + class DerivationBase(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -14,9 +28,35 @@ class DerivationCreate(DerivationBase): used_id: uuid.UUID generated_id: uuid.UUID derivation_type: DerivationType + label: str | None = None + + @model_validator(mode="after") + def label_matches_derivation_type(self): + """Validate the label against the derivation type when provided.""" + if self.label is None: + return self + if self.derivation_type == DerivationType.emodel_circuit and not _HOC_TEMPLATE_RE.fullmatch( + self.label + ): + msg = ( + "label for derivation_type 'emodel_circuit' must match " + "'hoc:' (e.g. 'hoc:cADpyr_L5TPC')" + ) + raise ValueError(msg) + if ( + self.derivation_type == DerivationType.circuit_customization + and self.label not in _CIRCUIT_CUSTOMIZATION_LABELS + ): + msg = ( + "label for derivation_type 'circuit_customization' must be one of " + f"{sorted(_CIRCUIT_CUSTOMIZATION_LABELS)}" + ) + raise ValueError(msg) + return self class DerivationRead(DerivationBase): used: BasicEntityRead generated: BasicEntityRead derivation_type: DerivationType + label: str | None = None diff --git a/app/service/derivation.py b/app/service/derivation.py index a6db6c46..c7d90445 100644 --- a/app/service/derivation.py +++ b/app/service/derivation.py @@ -1,17 +1,21 @@ """Generic derivation service.""" import uuid +from http import HTTPStatus import sqlalchemy as sa from sqlalchemy import and_ from sqlalchemy.orm import aliased, joinedload, raiseload from app.db.model import Derivation, DerivationType, Entity +from app.db.types import EntityType from app.db.utils import ENTITY_TYPE_TO_CLASS, load_db_model_from_pydantic from app.dependencies.auth import AdminContextDep, UserContextDep, UserContextWithProjectIdDep from app.dependencies.common import DerivationQueryDep, PaginationQuery from app.dependencies.db import SessionDep from app.errors import ( + ApiError, + ApiErrorCode, ensure_authorized_references, ensure_foreign_keys_integrity, ensure_result, @@ -27,6 +31,66 @@ from app.schemas.types import ListResponse from app.utils.routers import entity_route_to_type +# Allowed entity types per derivation_type for (used, generated). +# A value of ``None`` means "no type-specific check" (any entity allowed). +# ``unspecified`` is intentionally unconstrained: it is currently used as a +# placeholder for emodel/memodel derivations and will be replaced by dedicated +# derivation types later. +_ALLOWED_ENTITY_TYPES: dict[ + DerivationType, tuple[frozenset[EntityType] | None, frozenset[EntityType] | None] +] = { + DerivationType.circuit_extraction: ( + frozenset({EntityType.circuit}), + frozenset({EntityType.circuit}), + ), + DerivationType.circuit_rewiring: ( + frozenset({EntityType.circuit}), + frozenset({EntityType.circuit}), + ), + DerivationType.circuit_customization: ( + frozenset({EntityType.circuit}), + frozenset({EntityType.circuit}), + ), + DerivationType.emodel_circuit: ( + frozenset({EntityType.emodel}), + frozenset({EntityType.circuit}), + ), + DerivationType.unspecified: (None, None), +} + + +def _validate_entity_types( + derivation_type: DerivationType, + used_entity: Entity, + generated_entity: Entity, +) -> None: + """Validate that the used/generated entity types match the derivation_type. + + Raises: + ApiError: with ``INVALID_REQUEST`` (HTTP 422) if the types do not match. + """ + allowed_used, allowed_generated = _ALLOWED_ENTITY_TYPES[derivation_type] + if allowed_used is not None and used_entity.type not in allowed_used: + raise ApiError( + message=( + f"derivation_type '{derivation_type.value}' requires used entity type " + f"to be one of {sorted(t.value for t in allowed_used)}, " + f"got '{used_entity.type.value}'" + ), + error_code=ApiErrorCode.INVALID_REQUEST, + http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + if allowed_generated is not None and generated_entity.type not in allowed_generated: + raise ApiError( + message=( + f"derivation_type '{derivation_type.value}' requires generated entity type " + f"to be one of {sorted(t.value for t in allowed_generated)}, " + f"got '{generated_entity.type.value}'" + ), + error_code=ApiErrorCode.INVALID_REQUEST, + http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + def _read_many( *, @@ -158,6 +222,7 @@ def create_one( json_model.generated_id, user_context.project_id, ) + _validate_entity_types(json_model.derivation_type, used_entity, generated_entity) db_model_class = Derivation db_model_instance = load_db_model_from_pydantic( json_model=json_model, diff --git a/scripts/export/build_database_archive.sh b/scripts/export/build_database_archive.sh index 8c784eac..b52ecd9b 100755 --- a/scripts/export/build_database_archive.sh +++ b/scripts/export/build_database_archive.sh @@ -2,7 +2,7 @@ # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="c8cdf20bbb0d" +SCRIPT_DB_VERSION="b8ebf9a4e3da" echo "DB dump (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" @@ -271,7 +271,7 @@ install -m 755 /dev/stdin "$WORK_DIR/load.sh" <<'EOF_LOAD_SCRIPT' # Automatically generated, do not edit! set -euo pipefail SCRIPT_VERSION="1" -SCRIPT_DB_VERSION="c8cdf20bbb0d" +SCRIPT_DB_VERSION="b8ebf9a4e3da" echo "DB load (version $SCRIPT_VERSION for db version $SCRIPT_DB_VERSION)" diff --git a/tests/conftest.py b/tests/conftest.py index 85284c10..f5c81c7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1289,7 +1289,7 @@ def root_circuit_json_data(brain_atlas_id, subject_id, brain_region_id, license_ "subject_id": str(subject_id), "build_category": "em_reconstruction", "authorized_project_id": PROJECT_ID, - "authorized_public": True, + "authorized_public": False, "created_by_id": str(person_id), "updated_by_id": str(person_id), "brain_region_id": str(brain_region_id), diff --git a/tests/test_derivation.py b/tests/test_derivation.py index 5f6ab596..fa014f1e 100644 --- a/tests/test_derivation.py +++ b/tests/test_derivation.py @@ -1,6 +1,6 @@ import pytest -from app.db.model import Derivation +from app.db.model import Circuit, Derivation from app.errors import ApiErrorCode from app.schemas.api import ErrorResponse @@ -8,43 +8,66 @@ PROJECT_ID, UNRELATED_PROJECT_ID, add_all_db, + add_db, assert_request, assert_response, - create_electrical_cell_recording_id, ) +def _add_source_circuit(db, root_circuit_json_data, person_id, name): + return add_db( + db, + Circuit( + **root_circuit_json_data + | { + "name": name, + "created_by_id": person_id, + "updated_by_id": person_id, + "authorized_project_id": PROJECT_ID, + "authorized_public": True, + } + ), + ) + + def test_get_derived_from( - db, clients, emodel_id, public_emodel_id, electrical_cell_recording_json_data + db, + clients, + person_id, + public_root_circuit, + root_circuit, + root_circuit_json_data, ): - # create two emodels, one with derivations and one without - trace_ids = [ - create_electrical_cell_recording_id( - clients.user_1, json_data=electrical_cell_recording_json_data | {"name": f"name-{i}"} - ) + # Create source circuits (used) for the typed derivations. + # Source/target circuits use circuit_extraction / circuit_rewiring + # (circuit -> circuit), and a single unspecified derivation goes to a + # private target. Direct DB inserts via add_all_db bypass the create-time + # type validator, but the data still satisfies it for consistency. + source_ids = [ + _add_source_circuit(db, root_circuit_json_data, person_id, f"source-{i}").id for i in range(6) ] derivations = ( [ Derivation( - used_id=ecr_id, - generated_id=public_emodel_id, + used_id=src_id, + generated_id=public_root_circuit.id, derivation_type="circuit_extraction", ) - for ecr_id in trace_ids[:3] + for src_id in source_ids[:3] ] + [ Derivation( - used_id=ecr_id, - generated_id=public_emodel_id, + used_id=src_id, + generated_id=public_root_circuit.id, derivation_type="circuit_rewiring", ) - for ecr_id in trace_ids[3:5] + for src_id in source_ids[3:5] ] + [ Derivation( - used_id=trace_ids[5], - generated_id=emodel_id, # private + used_id=source_ids[5], + generated_id=root_circuit.id, # private derivation_type="unspecified", ) ] @@ -52,44 +75,44 @@ def test_get_derived_from( add_all_db(db, derivations) response = clients.user_1.get( - url=f"/emodel/{public_emodel_id}/derived-from", + url=f"/circuit/{public_root_circuit.id}/derived-from", params={"derivation_type": "circuit_extraction"}, ) assert_response(response, 200) data = response.json()["data"] assert len(data) == 3 - assert [d["id"] for d in data] == [str(id_) for id_ in reversed(trace_ids[:3])] - assert all(d["type"] == "electrical_cell_recording" for d in data) + assert [d["id"] for d in data] == [str(id_) for id_ in reversed(source_ids[:3])] + assert all(d["type"] == "circuit" for d in data) response = clients.user_1.get( - url=f"/emodel/{public_emodel_id}/derived-from", + url=f"/circuit/{public_root_circuit.id}/derived-from", params={"derivation_type": "circuit_rewiring"}, ) assert_response(response, 200) data = response.json()["data"] assert len(data) == 2 - assert [d["id"] for d in data] == [str(id_) for id_ in reversed(trace_ids[3:5])] - assert all(d["type"] == "electrical_cell_recording" for d in data) + assert [d["id"] for d in data] == [str(id_) for id_ in reversed(source_ids[3:5])] + assert all(d["type"] == "circuit" for d in data) response = clients.user_1.get( - url=f"/emodel/{emodel_id}/derived-from", + url=f"/circuit/{root_circuit.id}/derived-from", params={"derivation_type": "unspecified"}, ) assert_response(response, 200) data = response.json()["data"] assert len(data) == 1 - assert data[0]["id"] == str(trace_ids[5]) - assert data[0]["type"] == "electrical_cell_recording" + assert data[0]["id"] == str(source_ids[5]) + assert data[0]["type"] == "circuit" # Test error not derivation_type param - response = clients.user_1.get(url=f"/emodel/{public_emodel_id}/derived-from") + response = clients.user_1.get(url=f"/circuit/{public_root_circuit.id}/derived-from") assert_response(response, 422) error = ErrorResponse.model_validate(response.json()) assert error.error_code == ApiErrorCode.INVALID_REQUEST # Test error invalid derivation_type param response = clients.user_1.get( - url=f"/emodel/{public_emodel_id}/derived-from", + url=f"/circuit/{public_root_circuit.id}/derived-from", params={"derivation_type": "invalid_type"}, ) assert_response(response, 422) @@ -98,7 +121,7 @@ def test_get_derived_from( # Test empty result response = clients.user_1.get( - url=f"/emodel/{public_emodel_id}/derived-from", + url=f"/circuit/{public_root_circuit.id}/derived-from", params={"derivation_type": "unspecified"}, ) assert_response(response, 200) @@ -107,7 +130,7 @@ def test_get_derived_from( # Test private unreadable entity response = clients.user_2.get( - url=f"/emodel/{emodel_id}/derived-from", + url=f"/circuit/{root_circuit.id}/derived-from", params={"derivation_type": "unspecified"}, ) assert_response(response, 404) @@ -115,7 +138,7 @@ def test_get_derived_from( # Test non existing entity response = clients.user_2.get( - url="/emodel/00000000-0000-0000-0000-000000000000/derived-from", + url="/circuit/00000000-0000-0000-0000-000000000000/derived-from", params={"derivation_type": "unspecified"}, ) assert_response(response, 404) @@ -123,14 +146,14 @@ def test_get_derived_from( data = assert_request( clients.admin.get, - url=f"/admin/emodel/{public_emodel_id}/derived-from", + url=f"/admin/circuit/{public_root_circuit.id}/derived-from", params={"derivation_type": "circuit_extraction"}, ).json()["data"] assert len(data) == 3 data = assert_request( clients.admin.get, - url=f"/admin/emodel/{emodel_id}/derived-from", + url=f"/admin/circuit/{root_circuit.id}/derived-from", params={"derivation_type": "unspecified"}, ).json()["data"] assert len(data) == 1 @@ -139,12 +162,18 @@ def test_get_derived_from( @pytest.mark.parametrize( "derivation_type", [ + "circuit_customization", "circuit_extraction", "circuit_rewiring", "unspecified", ], ) def test_create_one(client, derivation_type, root_circuit, circuit): + """Create derivations between two circuits (covered by validation rules). + + ``emodel_circuit`` is excluded here because it requires used=emodel, + generated=circuit; it is exercised in ``test_create_emodel_circuit_with_label``. + """ data = assert_request( client.post, url="/derivation", @@ -158,9 +187,116 @@ def test_create_one(client, derivation_type, root_circuit, circuit): "used": {"type": "circuit", "id": str(root_circuit.id)}, "generated": {"type": "circuit", "id": str(circuit.id)}, "derivation_type": derivation_type, + "label": None, + } + + +def test_create_emodel_circuit_with_label(client, emodel_id, circuit): + """Link an emodel (used) to a circuit (generated) with the SONATA model_template label.""" + data = assert_request( + client.post, + url="/derivation", + json={ + "used_id": str(emodel_id), + "generated_id": str(circuit.id), + "derivation_type": "emodel_circuit", + "label": "hoc:cADpyr_L5TPC", + }, + ).json() + assert data == { + "used": {"type": "emodel", "id": str(emodel_id)}, + "generated": {"type": "circuit", "id": str(circuit.id)}, + "derivation_type": "emodel_circuit", + "label": "hoc:cADpyr_L5TPC", + } + + +def test_create_circuit_customization_with_label(client, root_circuit, circuit): + """Derive a circuit from another by customizing components, with a label for the type.""" + data = assert_request( + client.post, + url="/derivation", + json={ + "used_id": str(root_circuit.id), + "generated_id": str(circuit.id), + "derivation_type": "circuit_customization", + "label": "synaptic_modification", + }, + ).json() + assert data == { + "used": {"type": "circuit", "id": str(root_circuit.id)}, + "generated": {"type": "circuit", "id": str(circuit.id)}, + "derivation_type": "circuit_customization", + "label": "synaptic_modification", } +@pytest.mark.parametrize( + ("derivation_type", "label"), + [ + ("emodel_circuit", "cADpyr_L5TPC"), # missing hoc: prefix + ("emodel_circuit", "hoc:"), # empty template name + ("emodel_circuit", "nml:cADpyr_L5TPC"), # wrong prefix + ("circuit_customization", "unknown_modification"), + ("circuit_customization", ""), + ], +) +def test_create_invalid_label_for_derivation_type( + client, emodel_id, circuit, derivation_type, label +): + used_id = emodel_id if derivation_type == "emodel_circuit" else str(circuit.id) + data = assert_request( + client.post, + url="/derivation", + json={ + "used_id": str(used_id), + "generated_id": str(circuit.id), + "derivation_type": derivation_type, + "label": label, + }, + expected_status_code=422, + ).json() + assert data["error_code"] == "INVALID_REQUEST" + + +@pytest.mark.parametrize( + ("derivation_type", "use_emodel_as_used", "use_emodel_as_generated"), + [ + # circuit-only types reject emodel as the used entity + ("circuit_extraction", True, False), + ("circuit_rewiring", True, False), + ("circuit_customization", True, False), + # emodel_circuit rejects circuit as used + ("emodel_circuit", False, False), + # emodel_circuit rejects emodel as generated + ("emodel_circuit", True, True), + ], +) +def test_create_invalid_entity_types_for_derivation_type( + client, + emodel_id, + public_emodel_id, + root_circuit, + circuit, + derivation_type, + use_emodel_as_used, + use_emodel_as_generated, +): + used_id = str(emodel_id) if use_emodel_as_used else str(root_circuit.id) + generated_id = str(public_emodel_id) if use_emodel_as_generated else str(circuit.id) + data = assert_request( + client.post, + url="/derivation", + json={ + "used_id": used_id, + "generated_id": generated_id, + "derivation_type": derivation_type, + }, + expected_status_code=422, + ).json() + assert data["error_code"] == "INVALID_REQUEST" + + def test_create_invalid_data(client, root_circuit, circuit): # test that the derivation type is mandatory data = assert_request(