diff --git a/.gitignore b/.gitignore index 15d7809..cec5f55 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ build .tox .idea .vscode +.kiro venv dist brainbuilder/version.py diff --git a/brainbuilder/app/sonata.py b/brainbuilder/app/sonata.py index 5efe35e..f091d6c 100644 --- a/brainbuilder/app/sonata.py +++ b/brainbuilder/app/sonata.py @@ -449,3 +449,22 @@ def resize_datatypes(file, population_name, population_type, attributes): click.secho(f"The following updates were performed:\n{updates}") click.secho(f"One should run `h5repack {file} output.h5` to repack the file", fg="green") + + +@app.command() +@click.argument("circuit-config", type=REQUIRED_PATH) +@click.option("-o", "--output", help="Save to file (e.g., circuit.png). Otherwise opens viewer.") +@click.option("-t", "--title", help="Title for the graph (default: parent directory name)") +def visualize(circuit_config, output, title): + """Display a graph of circuit connectivity grouped by population. + + Requires: pip install brainbuilder[viz] and system graphviz. + """ + from brainbuilder.utils.sonata.visualize import draw_circuit + + if not title: + title = Path(circuit_config).parent.name + + draw_circuit(circuit_config, output_path=output, title=title) + if output: + click.echo(f"Saved to {output}") diff --git a/brainbuilder/utils/sonata/split_population.py b/brainbuilder/utils/sonata/split_population.py index a3e3aa7..54de57a 100755 --- a/brainbuilder/utils/sonata/split_population.py +++ b/brainbuilder/utils/sonata/split_population.py @@ -48,25 +48,36 @@ @dataclass class WriteEdgeConfig: - input_path: str | Path + input_path: str | Path | list[str | Path] output_path: str | Path src_node_name: str dst_node_name: str - src_edge_name: str + src_edge_name: str | list[str] dst_edge_name: str - src_mapping: pd.DataFrame + src_mapping: pd.DataFrame | list[pd.DataFrame] dst_mapping: pd.DataFrame h5_read_chunk_size: int | None = None edge_type: type[bytes] | None = None def __post_init__(self): - self.input_path = ( - Path(self.input_path) if isinstance(self.input_path, str) else self.input_path - ) + # Normalize to lists for uniform processing + if not isinstance(self.input_path, list): + self.input_path = [self.input_path] + self.src_edge_name = [self.src_edge_name] + self.src_mapping = [self.src_mapping] + self.input_path = [Path(p) if isinstance(p, str) else p for p in self.input_path] self.output_path = ( Path(self.output_path) if isinstance(self.output_path, str) else self.output_path ) + @property + def source_node_count(self) -> int: + return int(max(np.max(m) for m in self.src_mapping)) + 1 + + @property + def target_node_count(self) -> int: + return int(np.max(self.dst_mapping)) + 1 + def _create_chunked_slices(length, chunk_size): """return `slices` each of size `chunk_size`, that cover `length`""" @@ -237,124 +248,127 @@ def _h5_get_read_chunk_size(): def _copy_filtered_edges( - h5in: h5py.File, - h5out: h5py.File, write_edge_config: WriteEdgeConfig, edge_mappings: dict[str, tuple[pd.DataFrame, str]] = None, -): - """ - Copy and filter edge datasets from an input HDF5 file to an output HDF5 file. +) -> int: + """Copy and filter edge datasets from input HDF5 file(s) to an output HDF5 file. - This function: - - Reads the source edge population in chunks. - - Filters edges based on source and target node mappings. - - Copies source/target node IDs and all associated datasets. - - Initializes and populates new edge groups, preserving attributes. - - Raises errors if invalid IDs or unsupported groups (e.g., @library) are encountered. + Opens the output file, writes filtered edges, finalizes, and closes. + Supports multiple input sources (input_path/src_edge_name/src_mapping lists) + which are processed sequentially into a single output edge population. Args: - h5in (h5py.File): Input HDF5 file containing original edges. - h5out (h5py.File): Output HDF5 file to store filtered edges. - write_edge_config (WriteEdgeConfig): Configuration specifying - source/target populations, edge names, mappings, and read chunk size. - edge_mappings (dict[str, tuple[pd.DataFrame, str]]): Optional dict - updated with old→new edge ID mappings. The key is the old edge file name, - pd.DataFrame is the id remapping and the last str is the new edge file name. - - Notes: - - Only the "dynamics_params" group is currently supported in edge groups. - - Supports appendable datasets; existing datasets will raise an error if re-created. - - Processing is chunked to handle large edge populations efficiently. + write_edge_config: Configuration specifying source/target populations, + edge names, mappings, output path, and read chunk size. + edge_mappings: Optional dict updated with old->new edge ID mappings. + + Returns: + The number of edges written. """ - # extract values h5_read_chunk_size = ( write_edge_config.h5_read_chunk_size if write_edge_config.h5_read_chunk_size is not None else _h5_get_read_chunk_size() ) is_neuroglial = write_edge_config.edge_type == "synapse_astrocyte" - # get groups - orig_edges = h5in["edges"][write_edge_config.src_edge_name] - orig_group = _get_unique_group(orig_edges) - new_edges = h5out.create_group("edges/" + write_edge_config.dst_edge_name) - new_group = new_edges.create_group(GROUP_NAME) - - hdf5.create_appendable_dataset(new_edges, "source_node_id", np.uint64) - hdf5.create_appendable_dataset(new_edges, "target_node_id", np.uint64) - - # since create_appendable_dataset already fails if we append to an existing - # dataset, the following attribute assignments are safe - new_edges["source_node_id"].attrs["node_population"] = write_edge_config.src_node_name - new_edges["target_node_id"].attrs["node_population"] = write_edge_config.dst_node_name - - additional_attrs = {} - if is_neuroglial: - # find new name of synapse_edge_pop - src_syn_edge_pop = orig_edges[GROUP_NAME]["synapse_id"].attrs["edge_population"] - _, dst_syn_edge_pop = edge_mappings[src_syn_edge_pop] - # create dataset and add correct attr - additional_attrs["synapse_id"] = {"edge_population": dst_syn_edge_pop} - - _init_edge_group(orig_group, new_group, additional_attrs) - - sgids_new = write_edge_config.src_mapping.index.to_numpy() - tgids_new = write_edge_config.dst_mapping.index.to_numpy() - assert (sgids_new >= 0).all(), "Source population ids must be positive." - assert (tgids_new >= 0).all(), "Target population ids must be positive." - - total_chunks = math.ceil(len(orig_edges["source_node_id"]) / h5_read_chunk_size) - L.debug( - "Processing %s edges in %s chunks of size %s [src_edge_name=%s]", - len(orig_edges["source_node_id"]), - total_chunks, - h5_read_chunk_size, - write_edge_config.src_edge_name, - ) - sl_and_masks = _compute_chunks_and_masks( - orig_edges=orig_edges, - sgids_new=sgids_new, - tgids_new=tgids_new, - h5_read_chunk_size=h5_read_chunk_size, - edge_mappings=edge_mappings, - is_neuroglial=is_neuroglial, - ) + write_edge_config.output_path.parent.mkdir(parents=True, exist_ok=True) - if edge_mappings is not None: - offset = new_edges["source_node_id"].shape[0] - if offset != 0: - raise RuntimeError( - "Cannot append edges when edge_mappings is enabled and the destination already " - f"contains {offset} edges. The current implementation only supports edges created " - "in a single pass and does not capture cross-generation connections " - "(old astrocyte -> new neuron or new astrocyte -> old neuron). " - "Only new->new connections would be handled correctly. " - "Appending is therefore blocked as a safety safeguard." - ) + with h5py.File(write_edge_config.output_path, "a") as h5out: + # Create output edge group (once, using first input as template) + with h5py.File(write_edge_config.input_path[0], "r") as h5_first: + first_orig_edges = h5_first["edges"][write_edge_config.src_edge_name[0]] + orig_group = _get_unique_group(first_orig_edges) - assert write_edge_config.src_edge_name not in edge_mappings, ( - f"Source edge population '{write_edge_config.src_edge_name}' " - "already exists in edge_mappings. " - "Cannot overwrite an existing mapping; check your inputs or " - "ensure edge populations are unique." - ) - edge_mappings[write_edge_config.src_edge_name] = ( - _compute_edge_mapping(sl_and_masks=sl_and_masks, offset=offset), - write_edge_config.dst_edge_name, - ) + new_edges = h5out.create_group("edges/" + write_edge_config.dst_edge_name) + new_group = new_edges.create_group(GROUP_NAME) - _write_masked_edges( - sl_and_masks=sl_and_masks, - new_edges=new_edges, - orig_edges=orig_edges, - src_mapping=write_edge_config.src_mapping, - dst_mapping=write_edge_config.dst_mapping, - edge_mappings=edge_mappings, - is_neuroglial=is_neuroglial, - ) + hdf5.create_appendable_dataset(new_edges, "source_node_id", np.uint64) + hdf5.create_appendable_dataset(new_edges, "target_node_id", np.uint64) + + new_edges["source_node_id"].attrs["node_population"] = write_edge_config.src_node_name + new_edges["target_node_id"].attrs["node_population"] = write_edge_config.dst_node_name + + additional_attrs = {} + if is_neuroglial: + src_syn_edge_pop = orig_group["synapse_id"].attrs["edge_population"] + _, dst_syn_edge_pop = edge_mappings[src_syn_edge_pop] + additional_attrs["synapse_id"] = {"edge_population": dst_syn_edge_pop} - L.debug("Finalize edges") - _finalize_edges(new_edges) + _init_edge_group(orig_group, new_group, additional_attrs) + + # Process each input source + tgids_new = write_edge_config.dst_mapping.index.to_numpy() + assert (tgids_new >= 0).all(), "Target population ids must be positive." + + for input_path, src_edge_name, src_mapping in zip( + write_edge_config.input_path, + write_edge_config.src_edge_name, + write_edge_config.src_mapping, + ): + sgids_new = src_mapping.index.to_numpy() + assert (sgids_new >= 0).all(), "Source population ids must be positive." + + with h5py.File(input_path, "r") as h5in: + orig_edges = h5in["edges"][src_edge_name] + + total_chunks = math.ceil(len(orig_edges["source_node_id"]) / h5_read_chunk_size) + L.debug( + "Processing %s edges in %s chunks of size %s [src_edge_name=%s]", + len(orig_edges["source_node_id"]), + total_chunks, + h5_read_chunk_size, + src_edge_name, + ) + + sl_and_masks = _compute_chunks_and_masks( + orig_edges=orig_edges, + sgids_new=sgids_new, + tgids_new=tgids_new, + h5_read_chunk_size=h5_read_chunk_size, + edge_mappings=edge_mappings, + is_neuroglial=is_neuroglial, + ) + + if edge_mappings is not None: + offset = new_edges["source_node_id"].shape[0] + if offset != 0 and is_neuroglial: + raise RuntimeError( + "Cannot append neuroglial edges when edge_mappings is enabled and the " + f"destination already contains {offset} edges." + ) + + assert src_edge_name not in edge_mappings, ( + f"Source edge population '{src_edge_name}' already exists in edge_mappings." + ) + edge_mappings[src_edge_name] = ( + _compute_edge_mapping(sl_and_masks=sl_and_masks, offset=offset), + write_edge_config.dst_edge_name, + ) + + _write_masked_edges( + sl_and_masks=sl_and_masks, + new_edges=new_edges, + orig_edges=orig_edges, + src_mapping=src_mapping, + dst_mapping=write_edge_config.dst_mapping, + edge_mappings=edge_mappings, + is_neuroglial=is_neuroglial, + ) + + edge_count = len(new_edges["source_node_id"]) + + # Verify consistency + if edge_count > 0: + assert write_edge_config.source_node_count >= int(np.max(new_edges["source_node_id"])) + assert write_edge_config.target_node_count >= int(np.max(new_edges["target_node_id"])) + L.debug("Finalize edges") + _finalize_edges(new_edges) + else: + # Remove empty population from the file + del h5out[f"edges/{write_edge_config.dst_edge_name}"] + + return edge_count def _compute_edge_mapping(sl_and_masks, offset=0): @@ -438,33 +452,6 @@ def _write_masked_edges( _populate_edge_group(orig_edges[GROUP_NAME], new_edges[GROUP_NAME], sl, mask, overrides) -def _get_node_counts( - h5out: h5py.File, new_edge_pop_name: str, src_mapping: pd.DataFrame, dst_mapping: pd.DataFrame -): - """for `h5out`, return the `new_edge_pop_name`, `source_node_count`, and `target_node_count`""" - - source_node_count = int(np.max(src_mapping)) + 1 - target_node_count = int(np.max(dst_mapping)) + 1 - - new_edges = h5out["edges"][new_edge_pop_name] - edge_count = len(new_edges["source_node_id"]) - - if edge_count > 0: - assert source_node_count >= int(np.max(new_edges["source_node_id"])) - assert target_node_count >= int(np.max(new_edges["target_node_id"])) - - return edge_count, source_node_count, target_node_count - - -def _write_indexes( - edge_file_name: str | Path, new_pop_name: str, source_node_count: int, target_node_count: int -): - """ibid""" - libsonata.EdgePopulation.write_indices( - str(edge_file_name), new_pop_name, source_node_count, target_node_count - ) - - def _check_all_edges_used(h5in, written_edges): """Verify that the number of written edges matches the number of initial edges.""" orig_edges = h5in["edges"][_get_unique_population(h5in["edges"])] @@ -501,15 +488,16 @@ def _write_edges( ) L.debug("Writing to %s", write_edge_config.output_path) - with h5py.File(write_edge_config.output_path, "w") as h5out: - _copy_filtered_edges(h5in=h5in, h5out=h5out, write_edge_config=write_edge_config) - edge_count, sgid_count, tgid_count = _get_node_counts( - h5out, edge_pop_name, id_mapping[src_node_pop], id_mapping[dst_node_pop] - ) + edge_count = _copy_filtered_edges(write_edge_config=write_edge_config) - # after the h5 file is closed, it's indexed if valid, or it's removed if empty + # after the file is closed, it's indexed if valid, or it's removed if empty if edge_count > 0: - _write_indexes(write_edge_config.output_path, edge_pop_name, sgid_count, tgid_count) + libsonata.EdgePopulation.write_indices( + str(write_edge_config.output_path), + edge_pop_name, + write_edge_config.source_node_count, + write_edge_config.target_node_count, + ) L.debug("Wrote %s edges to %s", edge_count, write_edge_config.output_path) written_edges += edge_count else: @@ -661,11 +649,15 @@ def simple_split_subcircuit(output, node_set_name, node_set_path, nodes_path, ed def _write_subcircuit_edges( write_edge_config: WriteEdgeConfig, edge_mappings: dict[str, tuple[pd.DataFrame, str]] ): - """copy a population to an edge file + """Write a filtered edge population to an HDF5 file. + + If the population has no edges after filtering, it is removed from the file. + If the file becomes empty as a result, the file itself is deleted. - If DELETED_EMPTY_EDGES_FILE is returned, the file was removed since no - populations existed in it any more - If DELETED_EMPTY_EDGES_POPULATION is returned, the population was removed + Returns: + (output_path, edge_count): the path written to, and the number of edges. + output_path may be DELETED_EMPTY_EDGES_FILE (file removed) or + DELETED_EMPTY_EDGES_POPULATION (population removed, file kept). """ output_path = write_edge_config.output_path @@ -677,47 +669,24 @@ def _write_subcircuit_edges( str(output_path), ) - output_path.parent.mkdir(parents=True, exist_ok=True) - - with h5py.File(write_edge_config.input_path, "r") as h5in: - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - is_file_empty = False - - with h5py.File(output_path, "a") as h5out: - _copy_filtered_edges( - h5in=h5in, - h5out=h5out, - write_edge_config=write_edge_config, - edge_mappings=edge_mappings, - ) - - edge_count, sgid_count, tgid_count = _get_node_counts( - h5out=h5out, - new_edge_pop_name=write_edge_config.dst_edge_name, - src_mapping=write_edge_config.src_mapping, - dst_mapping=write_edge_config.dst_mapping, - ) - - if edge_count == 0: - del h5out[f"/edges/{write_edge_config.dst_edge_name}"] - is_file_empty = len(h5out["/edges"]) == 0 + edge_count = _copy_filtered_edges( + write_edge_config=write_edge_config, + edge_mappings=edge_mappings, + ) - # after the h5 file is closed, it's indexed if valid, or it's removed if empty - if edge_count > 0: - _write_indexes( - edge_file_name=write_edge_config.output_path, - new_pop_name=write_edge_config.dst_edge_name, - source_node_count=sgid_count, - target_node_count=tgid_count, - ) - L.debug("Wrote %s edges to %s", edge_count, output_path) - elif is_file_empty: - Path(output_path).unlink(missing_ok=True) + if edge_count > 0: + L.debug("Wrote %s edges to %s", edge_count, output_path) + elif output_path.exists(): + # Check if the file has any other populations; if not, remove it + with h5py.File(output_path, "r") as h5: + is_file_empty = len(h5["edges"]) == 0 + if is_file_empty: + output_path.unlink(missing_ok=True) output_path = DELETED_EMPTY_EDGES_FILE - else: # population empty, but not file + else: output_path = DELETED_EMPTY_EDGES_POPULATION - return output_path + return output_path, edge_count def _get_storage_path(edge): @@ -789,10 +758,20 @@ def _orchestrate_write_subcircuit_edges(write_edge_configs: list[WriteEdgeConfig ) for write_edge_config in write_edge_configs_sorted: - new_edges_files[write_edge_config.dst_edge_name] = _write_subcircuit_edges( + output_path, edge_count = _write_subcircuit_edges( write_edge_config=write_edge_config, edge_mappings=edge_mappings ) + if edge_count > 0: + libsonata.EdgePopulation.write_indices( + str(output_path), + write_edge_config.dst_edge_name, + write_edge_config.source_node_count, + write_edge_config.target_node_count, + ) + + new_edges_files[write_edge_config.dst_edge_name] = output_path + return new_edges_files @@ -832,20 +811,32 @@ def _get_subcircuit_external_ids(all_sgids, all_tgids, wanted_src_ids, wanted_ds return ret.sort_index() -def _write_subcircuit_external( +def _gather_new_external_subcircuits( output, circuit, id_mapping, node_pop_name_mapping, existing_node_pop_names, existing_edge_pop_names, + external_id_offset=None, ): - """Write external connectivity. + """Gather external connectivity: non-virtual sources projecting into the subcircuit. + + Identifies source nodes that are NOT in the subcircuit but have edges targeting + nodes that ARE in the subcircuit. Builds WriteEdgeConfigs and a nodes_to_write + dict suitable for _write_subcircuit. - returns: (new_node_files, new_edges_files); with, respectively, - dictionaries with node and edge population_name -> path + Updates `id_mapping` and `node_pop_name_mapping` in place. - Warning: this writes `id_mapping` in place + Args: + external_id_offset: Dict of population_name -> starting new_id offset. + Used when merging with pre-existing externals from a parent circuit. + If None or population not in dict, offset is 0. + + Returns: + (write_edge_configs, nodes_to_write): + write_edge_configs: list of WriteEdgeConfig + nodes_to_write: dict of output_pop_name -> list of (source_pop_name, node_ids) """ assert all(edge.type != "neuroglial" for edge in circuit.edges.values()), ( @@ -856,7 +847,7 @@ def _write_subcircuit_external( "that multiplies the amount of additional files required." ) - new_nodes = {} + nodes_to_write = {} write_edge_configs = [] for name, edge in circuit.edges.items(): @@ -870,14 +861,10 @@ def _write_subcircuit_external( ) ] - # only keep ids that are used; this is duplicating work in _copy_edge_attributes - # but the alternative is that it keeps track of the new id_mapping; which - # seemed less ideal with h5py.File(_get_storage_path(edge)) as h5: all_sgids = h5[f"edges/{name}/source_node_id"] all_tgids = h5[f"edges/{name}/target_node_id"] - # overwrite wanted_src_ids with a DataFrame; the numpy array is not needed wanted_src_ids = _get_subcircuit_external_ids( all_sgids, all_tgids, @@ -896,12 +883,12 @@ def _write_subcircuit_external( while new_source_pop_name in existing_node_pop_names: L.debug("%s already exists as an node population", new_source_pop_name) new_source_pop_name = "external_" + new_source_pop_name - node_pop_name_mapping[new_source_pop_name] = edge.source.name + if new_source_pop_name not in node_pop_name_mapping: + node_pop_name_mapping[new_source_pop_name] = [edge.source.name] output_path = output / (new_name + ".h5") - output_path.parent.mkdir(parents=True, exist_ok=True) L.debug( - "Writing edges %s for %s -> %s [%s]", + "Gathering external edges %s for %s -> %s [%s]", name, edge.source.name, edge.target.name, @@ -909,25 +896,25 @@ def _write_subcircuit_external( ) if new_source_pop_name in id_mapping: - # If mapping already exists, only add new IDs w/o changing existing! - # (May happen if different target populations have same external source population) + # Same source, multiple target populations — deduplicate and extend existing_mapping = id_mapping[new_source_pop_name] is_existing = _isin(wanted_src_ids.index, existing_mapping.index) wanted_src_ids.loc[is_existing] = existing_mapping.loc[ wanted_src_ids.loc[is_existing].index ] - new_ids = ( - np.arange(np.sum(~is_existing)) - + wanted_src_ids[NEW_IDS].loc[is_existing].max() - + 1 - ) # New node IDs begin at the lowest unused value (max + 1) + if is_existing.any(): + max_existing_id = wanted_src_ids[NEW_IDS].loc[is_existing].max() + else: + max_existing_id = existing_mapping[NEW_IDS].max() + new_ids = np.arange(np.sum(~is_existing)) + int(max_existing_id) + 1 wanted_src_ids.loc[~is_existing, NEW_IDS] = new_ids - # And merge new into existing id_mapping[new_source_pop_name] = pd.concat( [existing_mapping, wanted_src_ids.loc[~is_existing]], axis=0 ) else: + pop_offset = (external_id_offset or {}).get(new_source_pop_name, 0) + wanted_src_ids[NEW_IDS] = wanted_src_ids[NEW_IDS] + pop_offset id_mapping[new_source_pop_name] = wanted_src_ids write_edge_config = WriteEdgeConfig( @@ -943,41 +930,40 @@ def _write_subcircuit_external( ) write_edge_configs.append(write_edge_config) - new_nodes[new_source_pop_name] = ( - edge.source.name, - wanted_src_ids.index.to_numpy(), - ) - - new_edges_files = _orchestrate_write_subcircuit_edges(write_edge_configs=write_edge_configs) - - new_node_files = {} - # write new virtual nodes from originally non-virtual populations - for population_name, id_tuple in new_nodes.items(): - # Get all properties of the subset of the node population that is relevant - orig_population_name, ids = id_tuple - df = circuit.nodes[orig_population_name].get(ids).reset_index(drop=True) - nodes_path = Path(output) / population_name / "nodes.h5" - nodes_path.parent.mkdir(parents=True, exist_ok=True) - new_node_files[population_name] = _save_sonata_nodes(nodes_path, df, population_name) + nodes_to_write[new_source_pop_name] = [ + ( + edge.source.name, + id_mapping[new_source_pop_name].index.to_numpy(), + ) + ] - return new_node_files, new_edges_files + return write_edge_configs, nodes_to_write -def _write_subcircuit_virtual( +def _filter_virtual_typed_subcircuit( output, circuit, edge_populations_to_paths, id_mapping, node_pop_name_mapping, + do_externals, list_of_sources_to_ignore=(), ): - """write all node/edge populations that have virtual nodes as source + """Filter and gather virtual-typed source populations for subcircuit extraction. - Note: the id_mapping dictionary is updated with the used virtual nodes - """ - # pylint: disable=too-many-locals - new_node_files = {} + Selects edge populations whose source is typed as virtual and whose target + is in `id_mapping`. The `do_externals` flag controls which flavor: + - False: only genuine virtuals (excluding external_* populations) + - True: only external_* populations + Updates `id_mapping` and `node_pop_name_mapping` in place with the selected + source populations. + + Returns: + (write_edge_configs, nodes_to_write): + write_edge_configs: list of WriteEdgeConfig + nodes_to_write: dict of output_pop_name -> (source_pop_name, node_ids) + """ virtual_populations = { name: edge for name, edge in circuit.edges.items() @@ -985,6 +971,7 @@ def _write_subcircuit_virtual( edge.source.type == "virtual" and edge.target.name in id_mapping and edge.source.name not in list_of_sources_to_ignore + and (edge.source.name.startswith("external_") == do_externals) ) } @@ -1014,19 +1001,15 @@ def _write_subcircuit_virtual( if edge.source.name in pop_used_source_node_ids } - # update the mappings with the virtual nodes + # update the mappings with the selected source nodes for name, ids in pop_used_source_node_ids.items(): id_mapping[name] = pd.DataFrame({NEW_IDS: range(len(ids))}, index=ids) - # Virtual input sources retain their name unchanged - node_pop_name_mapping[name] = name - - # write the edges that have the virtual populations as source + node_pop_name_mapping[name] = [name] - write_edge_configs = [] - for edge_pop_name, edge in virtual_populations.items(): - write_edge_config = WriteEdgeConfig( + write_edge_configs = [ + WriteEdgeConfig( output_path=Path(output) / edge_populations_to_paths[edge_pop_name], - input_path=_get_storage_path(edge), # Where to read from + input_path=_get_storage_path(edge), src_node_name=edge.source.name, dst_node_name=edge.target.name, src_edge_name=edge_pop_name, @@ -1035,14 +1018,39 @@ def _write_subcircuit_virtual( dst_mapping=id_mapping[edge.target.name], edge_type=edge.type, ) - write_edge_configs.append(write_edge_config) + for edge_pop_name, edge in virtual_populations.items() + ] + nodes_to_write = {name: [(name, ids)] for name, ids in pop_used_source_node_ids.items()} + + return write_edge_configs, nodes_to_write + +def _write_subcircuit( + output, + circuit, + write_edge_configs, + nodes_to_write, +): + """Write edges and nodes for a subcircuit extraction step. + + Args: + output: Path where files will be written. + circuit: bluepysnap Circuit object. + write_edge_configs: list of WriteEdgeConfig (fully constructed). + nodes_to_write: dict of output_pop_name -> list of (source_pop_name, node_ids). + Each entry is a list of sources to read and concatenate for that + output population. source_pop_name is the population to read from + in the circuit, node_ids are the IDs to extract. + + Returns: + (new_node_files, new_edges_files): dicts of population_name -> path. + """ new_edges_files = _orchestrate_write_subcircuit_edges(write_edge_configs=write_edge_configs) - # write virtual nodes based on virtual populations - for population_name, ids in pop_used_source_node_ids.items(): - # Get all properties of the subset of the node population that is relevant - df = circuit.nodes[population_name].get(ids).reset_index(drop=True) + new_node_files = {} + for population_name, sources in nodes_to_write.items(): + dfs = [circuit.nodes[src_pop].get(ids) for src_pop, ids in sources] + df = pd.concat(dfs, ignore_index=True) nodes_path = Path(output) / population_name / "nodes.h5" nodes_path.parent.mkdir(parents=True, exist_ok=True) new_node_files[population_name] = _save_sonata_nodes(nodes_path, df, population_name) @@ -1167,50 +1175,127 @@ def _should_keep(name): return ret -def _mapping_to_parent_dict(id_mapping, node_pop_name_mapping): +def _mapping_to_parent_dict(id_mapping, node_pop_name_mapping, id_mapping_secondary=None): + """Build the serializable mapping dict from id_mapping and parent names. + + Handles multiple parents per population: first parent uses parent_name/parent_id, + subsequent parents use parent2_name/parent2_id, parent3_name/parent3_id, etc. + """ + if id_mapping_secondary is None: + id_mapping_secondary = {} + mapping = {} for population, df in id_mapping.items(): - mapping[population] = { + parent_names = node_pop_name_mapping[population] + + entry = { PARENT_IDS: df.index.to_list(), NEW_IDS: df[NEW_IDS].to_list(), - PARENT_NAME: node_pop_name_mapping[population], + PARENT_NAME: parent_names[0], } + + # Add secondary parent(s) if present + if population in id_mapping_secondary: + sec_df = id_mapping_secondary[population] + entry[NEW_IDS] = entry[NEW_IDS] + sec_df[NEW_IDS].to_list() + for i, parent_name in enumerate(parent_names[1:], start=2): + suffix = str(i) + entry[f"parent{suffix}_id"] = sec_df.index.to_list() + entry[f"parent{suffix}_name"] = parent_name + + mapping[population] = entry return mapping -def _make_parent_the_original_mapping(this_mapping): - for this_pop in this_mapping.keys(): - this_mapping[this_pop][ORIG_IDS] = this_mapping[this_pop][PARENT_IDS] - this_mapping[this_pop][ORIG_NAME] = this_mapping[this_pop][PARENT_NAME] +def _set_original_ids(this_mapping: dict, parent_mapping: dict | None) -> None: + """Set original_id and original_name for each population in the mapping. + If parent_mapping is None (parent is the root circuit), original_id is + copied from parent_id. Otherwise, original_id is traced back through the + parent's mapping to the root circuit. -def _add_mapping_to_original(this_mapping, parent_mapping): - for this_pop in this_mapping.keys(): - parent_pop = this_mapping[this_pop][PARENT_NAME] + Handles numbered parent fields (parent2_id/parent2_name, etc.) for + populations with multiple parent sources. + """ + for entry in this_mapping.values(): + if parent_mapping is None: + # Collect all parent_ids across numbered fields + all_parent_ids = list(entry[PARENT_IDS]) + for i in range(2, 100): + key = f"parent{i}_id" + if key not in entry: + break + all_parent_ids.extend(entry[key]) + entry[ORIG_IDS] = all_parent_ids + entry[ORIG_NAME] = entry[PARENT_NAME] + else: + # Resolve primary parent + parent_pop = entry[PARENT_NAME] + backwards_mapped = pd.Series( + parent_mapping[parent_pop][ORIG_IDS], + index=parent_mapping[parent_pop][NEW_IDS], + ) + orig_ids = backwards_mapped[entry[PARENT_IDS]].to_list() + orig_name = parent_mapping[parent_pop][ORIG_NAME] + + # Resolve numbered parents + for i in range(2, 100): + id_key = f"parent{i}_id" + name_key = f"parent{i}_name" + if id_key not in entry: + break + parent_pop_i = entry[name_key] + backwards_mapped_i = pd.Series( + parent_mapping[parent_pop_i][ORIG_IDS], + index=parent_mapping[parent_pop_i][NEW_IDS], + ) + orig_ids.extend(backwards_mapped_i[entry[id_key]].to_list()) + # Assert all parents trace back to the same original + assert parent_mapping[parent_pop_i][ORIG_NAME] == orig_name, ( + f"Cannot merge: original_name mismatch " + f"{orig_name} vs {parent_mapping[parent_pop_i][ORIG_NAME]}" + ) - backwards_mapped = pd.Series( - parent_mapping[parent_pop][ORIG_IDS], index=parent_mapping[parent_pop][NEW_IDS] - ) - orig_ids = backwards_mapped[this_mapping[this_pop][PARENT_IDS]] - orig_name = parent_mapping[parent_pop][ORIG_NAME] + entry[ORIG_IDS] = orig_ids + entry[ORIG_NAME] = orig_name - this_mapping[this_pop][ORIG_IDS] = orig_ids.to_list() - this_mapping[this_pop][ORIG_NAME] = orig_name +def _write_mapping( + output, + parent_circ, + id_mapping, + node_pop_name_mapping, + id_mapping_secondary=None, +): + """Write the id mappings between the old and new populations for future analysis.""" + if id_mapping_secondary is None: + id_mapping_secondary = {} -def _write_mapping(output, parent_circ, id_mapping, node_pop_name_mapping): - """write the id mappings between the old and new populations for future analysis""" - this_mapping = _mapping_to_parent_dict(id_mapping, node_pop_name_mapping) + this_mapping = _mapping_to_parent_dict(id_mapping, node_pop_name_mapping, id_mapping_secondary) provenance = parent_circ.config.get("components", {}).get("provenance", {}) - if "id_mapping" in provenance: - # Currently, bluepysnap does not seem to resolve $BASE_DIR for entries in "provenance". - # Therefore I decided to not prepend it and just assume the file exists near the circuit config. + # Currently, bluepysnap does not resolve $BASE_DIR for provenance entries, + # so assume the mapping file is relative to the circuit config. + parent_mapping = None + if mapping_path := provenance.get("id_mapping"): parent_root = Path(parent_circ._circuit_config_path).parent - parent_mapping = utils.load_json(parent_root / provenance["id_mapping"]) - _add_mapping_to_original(this_mapping, parent_mapping) - else: - _make_parent_the_original_mapping(this_mapping) + parent_mapping = utils.load_json(parent_root / mapping_path) + + _set_original_ids(this_mapping, parent_mapping) + + # Validate length invariants for merged populations + for pop_name, entry in this_mapping.items(): + n_new = len(entry[NEW_IDS]) + n_parent_ids = len(entry[PARENT_IDS]) + for i in range(2, 100): + key = f"parent{i}_id" + if key not in entry: + break + n_parent_ids += len(entry[key]) + assert n_parent_ids == n_new == len(entry[ORIG_IDS]), ( + f"Length mismatch for population '{pop_name}': " + f"sum(parent*_id)={n_parent_ids}, new_id={n_new}, original_id={len(entry[ORIG_IDS])}" + ) mapping_fn = "id_mapping.json" utils.dump_json(output / mapping_fn, this_mapping) @@ -1270,10 +1355,10 @@ def split_subcircuit( id_mapping = _get_node_id_mapping(split_populations) # Intrinsic input sources retain their name unchanged - node_pop_name_mapping = {pop_name: pop_name for pop_name in split_populations.keys()} + node_pop_name_mapping = {pop_name: [pop_name] for pop_name in split_populations.keys()} # TODO: should function `_write_subcircuit_biological`, - # `_write_subcircuit_external`, `_write_subcircuit_virtual` + # `_gather_new_external_subcircuits`, `_filter_virtual_typed_subcircuit`/`_write_subcircuit` # handle node updates and config updates? new_node_files, new_edge_files = _write_subcircuit_biological( @@ -1281,14 +1366,21 @@ def split_subcircuit( ) if do_virtual: - new_virtual_node_files, new_virtual_edge_files = _write_subcircuit_virtual( + write_edge_configs, nodes_to_write = _filter_virtual_typed_subcircuit( output, circuit, edge_pop_to_paths, id_mapping, node_pop_name_mapping, + False, list_of_virtual_sources_to_ignore, ) + new_virtual_node_files, new_virtual_edge_files = _write_subcircuit( + output, + circuit, + write_edge_configs, + nodes_to_write, + ) new_node_files.update(new_virtual_node_files) new_edge_files.update(new_virtual_edge_files) @@ -1296,20 +1388,97 @@ def split_subcircuit( existing_node_pop_names = list(new_node_files.keys()) existing_edge_pop_names = list(new_edge_files.keys()) if create_external: - new_virtual_node_files, new_virtual_edge_files = _write_subcircuit_external( + # Phase A: carry over existing external_ populations from parent circuit + write_edge_configs_a, nodes_to_write_a = _filter_virtual_typed_subcircuit( output, circuit, + edge_pop_to_paths, id_mapping, node_pop_name_mapping, + True, + list_of_virtual_sources_to_ignore, + ) + + # Phase B: create new externals from biophysical populations now outside subcircuit + # Give _gather a clean id_mapping without filter's external entries. + # Pass offset so new_ids continue after filter's max for overlapping populations. + gather_id_mapping = {k: v for k, v in id_mapping.items() if not k.startswith("external_")} + gather_node_pop_name_mapping = { + k: v for k, v in node_pop_name_mapping.items() if not k.startswith("external_") + } + # Compute per-population offset: new_ids continue after filter's max + external_offsets = { + k: int(v[NEW_IDS].max()) + 1 for k, v in id_mapping.items() if k.startswith("external_") + } + write_edge_configs_b, nodes_to_write_b = _gather_new_external_subcircuits( + output, + circuit, + gather_id_mapping, + gather_node_pop_name_mapping, existing_node_pop_names, existing_edge_pop_names, + external_id_offset=external_offsets, + ) + + # Merge: gather's external populations are already offset. + # Move them to id_mapping_secondary (for serialization) or main id_mapping. + id_mapping_secondary = {} + for pop_name in list(gather_id_mapping.keys()): + if not pop_name.startswith("external_"): + continue + gather_df = gather_id_mapping[pop_name] + if pop_name in id_mapping: + # Overlapping: store as secondary (already offset) + id_mapping_secondary[pop_name] = gather_df + # Append parent name + if pop_name in gather_node_pop_name_mapping: + for parent in gather_node_pop_name_mapping[pop_name]: + if parent not in node_pop_name_mapping.get(pop_name, []): + node_pop_name_mapping[pop_name].append(parent) + else: + # Non-overlapping: just move to main id_mapping + id_mapping[pop_name] = gather_df + node_pop_name_mapping[pop_name] = gather_node_pop_name_mapping.get( + pop_name, [pop_name] + ) + + # Merge edge configs: combine configs with same dst_edge_name into + # a single multi-input config; leave others as-is. + filter_configs_by_name = {cfg.dst_edge_name: cfg for cfg in write_edge_configs_a} + merged_edge_configs = list(write_edge_configs_a) + for cfg_b in write_edge_configs_b: + if cfg_b.dst_edge_name in filter_configs_by_name: + cfg_a = filter_configs_by_name[cfg_b.dst_edge_name] + cfg_a.input_path = cfg_a.input_path + cfg_b.input_path + cfg_a.src_edge_name = cfg_a.src_edge_name + cfg_b.src_edge_name + cfg_a.src_mapping = cfg_a.src_mapping + cfg_b.src_mapping + else: + merged_edge_configs.append(cfg_b) + + merged_nodes_to_write = dict(nodes_to_write_a) + for pop_name, sources in nodes_to_write_b.items(): + if pop_name in merged_nodes_to_write: + merged_nodes_to_write[pop_name] = merged_nodes_to_write[pop_name] + sources + else: + merged_nodes_to_write[pop_name] = sources + + new_virtual_node_files, new_virtual_edge_files = _write_subcircuit( + output, + circuit, + merged_edge_configs, + merged_nodes_to_write, ) new_node_files.update(new_virtual_node_files) new_edge_files.update(new_virtual_edge_files) - mapping_fn = _write_mapping(output, circuit, id_mapping, node_pop_name_mapping) - + mapping_fn = _write_mapping( + output, + circuit, + id_mapping, + node_pop_name_mapping, + id_mapping_secondary if create_external else None, + ) config = copy.deepcopy(circuit.config) node_sets = _update_node_sets(utils.load_json(config["node_sets_file"]), id_mapping) diff --git a/brainbuilder/utils/sonata/visualize.py b/brainbuilder/utils/sonata/visualize.py new file mode 100644 index 0000000..c86be29 --- /dev/null +++ b/brainbuilder/utils/sonata/visualize.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Visualize a SONATA circuit as a graph with populations as clusters.""" + +import json +from collections import Counter +from enum import StrEnum +from pathlib import Path + +import bluepysnap +import h5py + +# Populations with more nodes than this threshold are shown as a single +# cluster node with population-level edges instead of individual nodes. +_MAX_NODES_DETAILED = 10 + + +class PopulationType(StrEnum): + """Classification of a node population for visualization purposes.""" + + BIOPHYSICAL = "biophysical" + VIRTUAL = "virtual" + EXTERNAL = "external" + + @property + def color(self) -> str: + return { + PopulationType.BIOPHYSICAL: "lightyellow", + PopulationType.VIRTUAL: "lightblue", + PopulationType.EXTERNAL: "lightsalmon", + }.get(self, "white") + + @classmethod + def from_population(cls, pop_name: str, pop_type: str | None) -> "PopulationType": + if pop_name.startswith("external_"): + return cls.EXTERNAL + if pop_type == "virtual": + return cls.VIRTUAL + return cls.BIOPHYSICAL + + +def _load_id_mapping(circuit_config_path): + """Load id_mapping from circuit provenance if available. + + Returns: + dict: pop_name -> list of parent_ids, or None if no mapping exists. + """ + config = json.loads(Path(circuit_config_path).read_text()) + mapping_path = config.get("components", {}).get("provenance", {}).get("id_mapping") + if not mapping_path: + return None + + mapping_file = Path(circuit_config_path).parent / mapping_path + if not mapping_file.exists(): + return None + + mapping = json.loads(mapping_file.read_text()) + return mapping + + +def draw_circuit( + circuit_config_path, output_path=None, max_nodes_detailed=_MAX_NODES_DETAILED, title=None +): + """Draw a SONATA circuit using graphviz with populations as clusters. + + For small populations (<=max_nodes_detailed), individual nodes and edges + are shown. For large populations, a single summary node is shown with + population-level edges. Duplicate edges are collapsed with a count label. + + If an id_mapping exists in provenance, node labels show the parent (original) + IDs instead of the local IDs. External populations get a distinct color. + + Args: + circuit_config_path: Path to circuit_config.json. + output_path: If provided, save the rendered image to this path. + Otherwise, render to a temp file and open it. + max_nodes_detailed: Populations with more nodes than this are shown + as a single summary node. + + Requires: + pip install brainbuilder[viz] + System graphviz: brew install graphviz (macOS) / apt install graphviz (Linux) + """ + try: + import graphviz + except ImportError as e: + raise ImportError( + "graphviz Python package is required for visualization. " + "Install with: pip install brainbuilder[viz]\n" + "Also requires system graphviz: brew install graphviz" + ) from e + + circuit = bluepysnap.Circuit(str(circuit_config_path)) + id_mapping = _load_id_mapping(circuit_config_path) + + dot = graphviz.Digraph("circuit", format="png") + dot.attr(rankdir="LR") + dot.attr("node", shape="circle", fontsize="9", width="0.3", height="0.3") + if title: + dot.attr(label=title, labelloc="t", fontsize="14") + + detailed_pops = set() + + for pop_name, pop in circuit.nodes.items(): + pop_type = PopulationType.from_population(pop_name, pop.type) + + # Get original IDs for labels if mapping exists + parent_ids = None + if id_mapping and pop_name in id_mapping: + entry = id_mapping[pop_name] + parent_ids = entry.get("original_id", entry.get("parent_id")) + + if pop.size <= max_nodes_detailed: + detailed_pops.add(pop_name) + with dot.subgraph(name=f"cluster_{pop_name}") as sub: + sub.attr( + label=f"{pop_name} ({pop_type}, {pop.size})", + style="filled", + color=pop_type.color, + ) + prev = None + for i in range(pop.size): + label = str(parent_ids[i]) if parent_ids else str(i) + sub.node(f"{pop_name}__{i}", label=f"<{label}>") + if prev is not None: + sub.edge(prev, f"{pop_name}__{i}", style="invis", weight="10") + prev = f"{pop_name}__{i}" + else: + dot.node( + f"{pop_name}__summary", + label=f"{pop_name}\n({pop_type}, {pop.size})", + shape="box", + style="filled", + fillcolor=pop_type.color, + ) + + # Edges — group duplicates and show count + for edge_name, edge in circuit.edges.items(): + src_name = edge.source.name + tgt_name = edge.target.name + src_detailed = src_name in detailed_pops + tgt_detailed = tgt_name in detailed_pops + + with h5py.File(edge.h5_filepath, "r") as h5: + sgids = h5[f"edges/{edge_name}/source_node_id"][:] + tgids = h5[f"edges/{edge_name}/target_node_id"][:] + + if src_detailed and tgt_detailed: + edge_counts = Counter(zip(sgids.tolist(), tgids.tolist())) + for (s, t), count in edge_counts.items(): + attrs = {} + if count > 1: + attrs["label"] = str(count) + attrs["fontsize"] = "8" + dot.edge(f"{src_name}__{s}", f"{tgt_name}__{t}", **attrs) + else: + src_node = f"{src_name}__summary" if not src_detailed else f"{src_name}__{sgids[0]}" + tgt_node = f"{tgt_name}__summary" if not tgt_detailed else f"{tgt_name}__{tgids[0]}" + dot.edge(src_node, tgt_node, label=str(len(sgids)), fontsize="8") + + if output_path: + dot.render(outfile=output_path, cleanup=True) + else: + import tempfile + + filename = title.replace(" ", "_") if title else "circuit" + filepath = Path(tempfile.gettempdir()) / filename + dot.render(filename=str(filepath), view=True, cleanup=True) diff --git a/pyproject.toml b/pyproject.toml index f099150..6f0b5d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ include = ["brainbuilder*"] [project.optional-dependencies] all = [] # for compatibility reindex = [] # for compatibility +viz = ["graphviz"] [project.urls] Homepage = "https://openbraininstitute/brainbuilder" diff --git a/tests/unit/test_sonata/test_split_population.py b/tests/unit/test_sonata/test_split_population.py index 6e3bab2..41b2bd6 100644 --- a/tests/unit/test_sonata/test_split_population.py +++ b/tests/unit/test_sonata/test_split_population.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import shutil from pathlib import Path import bluepysnap @@ -9,7 +10,7 @@ import utils from numpy.testing import assert_array_equal -from brainbuilder.utils import load_json +from brainbuilder.utils import load_json, dump_json from brainbuilder.utils.sonata import split_population from brainbuilder.utils.sonata import utils as sonata_utils @@ -882,13 +883,10 @@ def test_copy_filtered_edges_advanced(tmp_path): h5_read_chunk_size=3, # FORCE chunking edge_type="synapse_astrocyte" ) - with h5py.File(write_edge_config.input_path, "r") as h5in, h5py.File(write_edge_config.output_path, "w") as h5out: - split_population._copy_filtered_edges( - h5in=h5in, - h5out=h5out, - write_edge_config=write_edge_config, - edge_mappings=edge_mappings - ) + split_population._copy_filtered_edges( + write_edge_config=write_edge_config, + edge_mappings=edge_mappings + ) # Verification keep, new_name = edge_mappings["orig_src_edge_pop"] @@ -919,3 +917,313 @@ def test_copy_filtered_edges_advanced(tmp_path): expected_syn_ids = [1, 2] np.testing.assert_array_equal(syn_ids, expected_syn_ids) assert out_edges["0"]["synapse_id"].attrs["edge_population"] == "new_biophysical_edge_pop" + + +def _assert_circuits_equal(path_a, path_b, strict_order=False): + """Assert two extracted circuits are equal using original IDs. + + Checks that they have the same populations with the same original nodes, + and the same edges (translated to original IDs). + + Args: + path_a: Path to first circuit directory. + path_b: Path to second circuit directory. + strict_order: If True, require same node and edge ordering. + If False (default), allow reordering. + """ + circ_a = bluepysnap.Circuit(str(Path(path_a) / "circuit_config.json")) + circ_b = bluepysnap.Circuit(str(Path(path_b) / "circuit_config.json")) + mapping_a = load_json(Path(path_a) / "id_mapping.json") + mapping_b = load_json(Path(path_b) / "id_mapping.json") + + # Same node populations + assert set(circ_a.nodes.keys()) == set(circ_b.nodes.keys()), ( + f"Node populations differ: {set(circ_a.nodes.keys())} vs {set(circ_b.nodes.keys())}" + ) + + # Same original IDs per population + for pop_name in circ_a.nodes.keys(): + orig_a = mapping_a[pop_name]["original_id"] + orig_b = mapping_b[pop_name]["original_id"] + if strict_order: + assert orig_a == orig_b, ( + f"Population '{pop_name}' original_ids differ: {orig_a} vs {orig_b}" + ) + else: + assert sorted(orig_a) == sorted(orig_b), ( + f"Population '{pop_name}' original_ids differ: " + f"{sorted(orig_a)} vs {sorted(orig_b)}" + ) + assert mapping_a[pop_name]["original_name"] == mapping_b[pop_name]["original_name"], ( + f"Population '{pop_name}' original_name differs" + ) + + # Same edge populations + assert set(circ_a.edges.keys()) == set(circ_b.edges.keys()), ( + f"Edge populations differ: {set(circ_a.edges.keys())} vs {set(circ_b.edges.keys())}" + ) + + # Same edges (translated to original IDs) + for edge_name in circ_a.edges.keys(): + edge_a = circ_a.edges[edge_name] + edge_b = circ_b.edges[edge_name] + src_pop = edge_a.source.name + tgt_pop = edge_a.target.name + + # Build new_id -> original_id lookup for each circuit + orig_src_a = dict(zip(mapping_a[src_pop]["new_id"], mapping_a[src_pop]["original_id"])) + orig_tgt_a = dict(zip(mapping_a[tgt_pop]["new_id"], mapping_a[tgt_pop]["original_id"])) + orig_src_b = dict(zip(mapping_b[src_pop]["new_id"], mapping_b[src_pop]["original_id"])) + orig_tgt_b = dict(zip(mapping_b[tgt_pop]["new_id"], mapping_b[tgt_pop]["original_id"])) + + with h5py.File(edge_a.h5_filepath, "r") as h5: + sgids_a = h5[f"edges/{edge_name}/source_node_id"][:] + tgids_a = h5[f"edges/{edge_name}/target_node_id"][:] + with h5py.File(edge_b.h5_filepath, "r") as h5: + sgids_b = h5[f"edges/{edge_name}/source_node_id"][:] + tgids_b = h5[f"edges/{edge_name}/target_node_id"][:] + + edges_a = [(orig_src_a[int(s)], orig_tgt_a[int(t)]) for s, t in zip(sgids_a, tgids_a)] + edges_b = [(orig_src_b[int(s)], orig_tgt_b[int(t)]) for s, t in zip(sgids_b, tgids_b)] + + if strict_order: + assert edges_a == edges_b, ( + f"Edge population '{edge_name}' differs:\n {edges_a}\n vs\n {edges_b}" + ) + else: + assert sorted(edges_a) == sorted(edges_b), ( + f"Edge population '{edge_name}' differs:\n " + f"{sorted(edges_a)}\n vs\n {sorted(edges_b)}" + ) + + +def _split_custom_subcircuit(output, circuit_config, node_set_name, node_set_def, + do_virtual=False, create_external=False): + """Run split_subcircuit after injecting a custom node_set into a copy of the circuit. + + Copies the circuit directory to a sibling of `output`, injects `node_set_def` + into the node_sets.json, and runs the extraction. + + Returns: + Path to the output directory (same as `output`). + """ + fixture = output.parent / (output.name + "_fixture") + shutil.copytree(Path(circuit_config).parent, fixture) + + node_sets = load_json(fixture / "node_sets.json") + node_sets.update(node_set_def) + dump_json(fixture / "node_sets.json", node_sets) + + split_population.split_subcircuit( + output, node_set_name, str(fixture / "circuit_config.json"), + do_virtual=do_virtual, create_external=create_external + ) + return output + + +def test_subsubcircuit_virtual_operates_on_virtuals_only(tmp_path): + """Test that do_virtual skips external populations in nested extraction. + + Extracts c2_c1 from c1 with do_virtual=True, create_external=True (so c2_c1 has external_A). + Then extracts c3_c2_c1 from c2_c1 with do_virtual=True, create_external=False. + Verifies that do_virtual does NOT process external_A (only genuine virtuals like V1). + Compares c3_c2_c1 against c3_c1 (direct extraction from c1) for equivalence. + """ + subset_c2_c1 = { + "subset_c2_c1": ["subset_c2_c1_popA", "subset_c2_c1_popB", "subset_c2_c1_popC"], + "subset_c2_c1_popA": {"population": "A", "node_id": [1, 5]}, + "subset_c2_c1_popB": {"population": "B", "node_id": [0, 1, 2, 3, 4, 5]}, + "subset_c2_c1_popC": {"population": "C", "node_id": [0, 1, 2, 3, 4, 5]}, + } + subset_c3_c2_c1 = { + "subset_c3_c2_c1": ["subset_c3_c2_c1_popA", "subset_c3_c2_c1_popB", "subset_c3_c2_c1_popC"], + "subset_c3_c2_c1_popA": {"population": "A", "node_id": [0, 1]}, + "subset_c3_c2_c1_popB": {"population": "B", "node_id": [1, 2, 3, 4, 5]}, + "subset_c3_c2_c1_popC": {"population": "C", "node_id": [0, 1, 3, 4, 5]}, + } + subset_c3_c1 = { + "subset_c3_c1": ["subset_c3_c1_popA", "subset_c3_c1_popB", "subset_c3_c1_popC"], + "subset_c3_c1_popA": {"population": "A", "node_id": [1, 5]}, + "subset_c3_c1_popB": {"population": "B", "node_id": [1, 2, 3, 4, 5]}, + "subset_c3_c1_popC": {"population": "C", "node_id": [0, 1, 3, 4, 5]}, + } + + circuit_config = str(SPLIT_SUBCIRCUIT_DATA_PATH / "circuit_config.json") + + path_c2_c1 = _split_custom_subcircuit( + tmp_path / "c2_c1", circuit_config, "subset_c2_c1", subset_c2_c1, + do_virtual=True, create_external=True + ) + + path_c3_c2_c1 = _split_custom_subcircuit( + tmp_path / "c3_c2_c1", str(path_c2_c1 / "circuit_config.json"), + "subset_c3_c2_c1", subset_c3_c2_c1, + do_virtual=True, create_external=False + ) + + path_c3_c1 = _split_custom_subcircuit( + tmp_path / "c3_c1", circuit_config, + "subset_c3_c1", subset_c3_c1, + do_virtual=True, create_external=False + ) + + # --- Assertions for c3_c2_c1 --- + circ_c_b_a = bluepysnap.Circuit(str(path_c3_c2_c1 / "circuit_config.json")) + node_pop_names = set(circ_c_b_a.nodes.keys()) + + # do_virtual should NOT have processed external_A (it's not a genuine virtual) + assert "external_A" not in node_pop_names, ( + "external_A should not be extracted by do_virtual" + ) + + # Genuine virtuals should be processed correctly + assert "V1" in node_pop_names, "V1 (genuine virtual) should be extracted" + assert "V2" not in node_pop_names, "V2 should be dropped (its target is outside c3_c2_c1)" + + # c3_c1 and c3_c2_c1 should be equivalent + _assert_circuits_equal(path_c3_c1, path_c3_c2_c1, True) + + +def test_subsubcircuit_externals_merge(tmp_path): + """Nested extraction with create_external merges external populations correctly. + + Verifies that extracting c3 from c2 (from c1) produces the same result as + extracting c3 directly from c1, up to 3 levels of nesting (c4). + """ + ### Nomenclature: + # X_Y: circuit X derived from parent Y. + # Nodes removed from a biophysical population either become external + # (if they have at least one edge targeting a kept node) or are discarded. + # Virtual nodes are kept only if they have at least one edge to a kept + # biophysical node; otherwise they are dropped. + + # c2_c1 from A: keep A:{1,2,3,5}, B:all, C:all + # - external_A: {0, 4} + # - discarded: A:{} (none) + # - V1 kept: {1,2}. V1 dropped: {0,3} + # - V2 kept: {0} + subset_c2_c1 = { + "subset_c2_c1": ["subset_c2_c1_popA", "subset_c2_c1_popB", "subset_c2_c1_popC"], + "subset_c2_c1_popA": {"population": "A", "node_id": [1, 2, 3, 5]}, + "subset_c2_c1_popB": {"population": "B", "node_id": [0, 1, 2, 3, 4, 5]}, + "subset_c2_c1_popC": {"population": "C", "node_id": [0, 1, 2, 3, 4, 5]}, + } + # c3_c1 from A: keep A:{1,2,3}, B:{1,2,3,4,5}, C:{0,1,2,5} + # - external_A: {0, 5}. Discarded A: {4} + # - external_C: {3}. Discarded C: {4} + # - discarded B: {0} + # - V1 kept: {1}. V1 dropped: {2} + # - V2 kept: {0} + subset_c3_c1 = { + "subset_c3_c1": ["subset_c3_c1_popA", "subset_c3_c1_popB", "subset_c3_c1_popC"], + "subset_c3_c1_popA": {"population": "A", "node_id": [1, 2, 3]}, + "subset_c3_c1_popB": {"population": "B", "node_id": [1, 2, 3, 4, 5]}, + "subset_c3_c1_popC": {"population": "C", "node_id": [0, 1, 2, 5]}, + } + # c4_c1 from A: keep A:{1,2}, B:{1,3,4}, C:{0,1,2,5} + # - external_A: {3, 5}. Discarded A: {0, 4} + # - external_B: {2}. Discarded B: {0, 5} + # - external_C: {3}. Discarded C: {4} + # - V1 kept: {1}. V1 dropped: {2} + # - V2 kept: {0} + subset_c4_c1 = { + "subset_c4_c1": ["subset_c4_c1_popA", "subset_c4_c1_popB", "subset_c4_c1_popC"], + "subset_c4_c1_popA": {"population": "A", "node_id": [1, 2]}, + "subset_c4_c1_popB": {"population": "B", "node_id": [1, 3, 4]}, + "subset_c4_c1_popC": {"population": "C", "node_id": [0, 1, 2, 5]}, + } + # c3_c2_c1 from c2_c1 (parent has A:local{0,1,2,3}=orig{1,2,3,5}, external_A:orig{0,4}): + # - Remove A:local{2} (orig 3) -> merges into external_A + # - Remove B:{0} -> discarded + # - Remove C:{3,4} -> C:3 external_C, C:4 discarded + # Result should equal c3_c1 + subset_c3_c2_c1 = { + "subset_c3_c2_c1": ["subset_c3_c2_c1_popA", "subset_c3_c2_c1_popB", "subset_c3_c2_c1_popC"], + "subset_c3_c2_c1_popA": {"population": "A", "node_id": [0, 1, 2]}, + "subset_c3_c2_c1_popB": {"population": "B", "node_id": [1, 2, 3, 4, 5]}, + "subset_c3_c2_c1_popC": {"population": "C", "node_id": [0, 1, 2, 5]}, + } + # c4_c2_c1 from c2_c1 (parent has A:local{0,1,2,3}=orig{1,2,3,5}, external_A:orig{0,4}): + # - Remove A:local{1,2} (orig 2,3) -> A:2(orig 3) merges into external_A, A:1(orig 2) discarded + # - Remove B:{0,2,5} -> B:2 external_B, rest discarded + # - Remove C:{3,4} -> C:3 external_C, C:4 discarded + # Result should equal c4_c1 + subset_c4_c2_c1 = { + "subset_c4_c2_c1": ["subset_c4_c2_c1_popA", "subset_c4_c2_c1_popB", "subset_c4_c2_c1_popC"], + "subset_c4_c2_c1_popA": {"population": "A", "node_id": [0, 1]}, + "subset_c4_c2_c1_popB": {"population": "B", "node_id": [1, 3, 4]}, + "subset_c4_c2_c1_popC": {"population": "C", "node_id": [0, 1, 2, 5]}, + } + # c4_c3_c1 from c3_c1 (parent has A:local{0,1,2}=orig{1,2,3}, external_A:orig{0,5}): + # - Remove A:local{2} (orig 3) -> merges into external_A + # - Remove B:local{1,4} (orig 2,5) -> B:1(orig 2) external_B, B:4(orig 5) discarded + # Result should equal c4_c1 + subset_c4_c3_c1 = { + "subset_c4_c3_c1": ["subset_c4_c3_c1_popA", "subset_c4_c3_c1_popB", "subset_c4_c3_c1_popC"], + "subset_c4_c3_c1_popA": {"population": "A", "node_id": [0, 1]}, + "subset_c4_c3_c1_popB": {"population": "B", "node_id": [0, 2, 3]}, + "subset_c4_c3_c1_popC": {"population": "C", "node_id": [0, 1, 2, 3]}, + } + # c4_c3_c2_c1 from c3_c2_c1 (same original content as c3_c1): + # - Same removals as c4_c3_c1. Result should equal c4_c1 + subset_c4_c3_c2_c1 = { + "subset_c4_c3_c2_c1": ["subset_c4_c3_c2_c1_popA", "subset_c4_c3_c2_c1_popB", "subset_c4_c3_c2_c1_popC"], + "subset_c4_c3_c2_c1_popA": {"population": "A", "node_id": [0, 1]}, + "subset_c4_c3_c2_c1_popB": {"population": "B", "node_id": [0, 2, 3]}, + "subset_c4_c3_c2_c1_popC": {"population": "C", "node_id": [0, 1, 2, 3]}, + } + + + + + circuit_config = str(SPLIT_SUBCIRCUIT_DATA_PATH / "circuit_config.json") + + path_c2_c1 = _split_custom_subcircuit( + tmp_path / "c2_c1", circuit_config, "subset_c2_c1", subset_c2_c1, + do_virtual=True, create_external=True + ) + + path_c3_c2_c1 = _split_custom_subcircuit( + tmp_path / "c3_c2_c1", str(path_c2_c1 / "circuit_config.json"), + "subset_c3_c2_c1", subset_c3_c2_c1, + do_virtual=True, create_external=True + ) + + path_c3_c1 = _split_custom_subcircuit( + tmp_path / "c3_c1", circuit_config, + "subset_c3_c1", subset_c3_c1, + do_virtual=True, create_external=True + ) + + path_c4_c1 = _split_custom_subcircuit( + tmp_path / "c4_c1", circuit_config, + "subset_c4_c1", subset_c4_c1, + do_virtual=True, create_external=True + ) + + path_c4_c2_c1 = _split_custom_subcircuit( + tmp_path / "c4_c2_c1", str(path_c2_c1 / "circuit_config.json"), + "subset_c4_c2_c1", subset_c4_c2_c1, + do_virtual=True, create_external=True + ) + + path_c4_c3_c1 = _split_custom_subcircuit( + tmp_path / "c4_c3_c1", str(path_c3_c1 / "circuit_config.json"), + "subset_c4_c3_c1", subset_c4_c3_c1, + do_virtual=True, create_external=True + ) + + path_c4_c3_c2_c1 = _split_custom_subcircuit( + tmp_path / "c4_c3_c2_c1", str(path_c3_c2_c1 / "circuit_config.json"), + "subset_c4_c3_c2_c1", subset_c4_c3_c2_c1, + do_virtual=True, create_external=True + ) + + + # All C circuits should be equal + _assert_circuits_equal(path_c3_c1, path_c3_c2_c1) + + # All D circuits should be equal + _assert_circuits_equal(path_c4_c1, path_c4_c2_c1) + _assert_circuits_equal(path_c4_c1, path_c4_c3_c1) + _assert_circuits_equal(path_c4_c1, path_c4_c3_c2_c1) diff --git a/tox.ini b/tox.ini index 244cbb6..7c67e2e 100644 --- a/tox.ini +++ b/tox.ini @@ -85,6 +85,6 @@ convention = google [gh-actions] python = - 3.10: py310, lint, coverage + 3.10: py310, coverage 3.11: py311, check-packaging - 3.12: py312, docs + 3.12: py312, lint, docs diff --git a/viz_circuits.sh b/viz_circuits.sh new file mode 100755 index 0000000..7eb7d47 --- /dev/null +++ b/viz_circuits.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Viz all circuits from test_subsubcircuit_externals_merge +# Usage: ./viz_circuits.sh + +set -e + +BASE="removeme/test_subsubcircuit_externals_m0" +BB="venv/bin/brainbuilder" + +rm -rf removeme +venv/bin/tox -e py312 -- tests/unit/test_sonata/test_split_population.py::test_subsubcircuit_externals_merge --basetemp=removeme -vv + +$BB sonata visualize tests/unit/data/sonata/split_subcircuit/circuit_config.json --title "Original (A)" + +for dir in "$BASE"/*/; do + name=$(basename "$dir") + [[ "$name" == *_fixture ]] && continue + config="$dir/circuit_config.json" + if [ -f "$config" ]; then + $BB sonata visualize "$config" --title "$name" + fi +done