Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ disable = "R1705"

[tool.pyright]
reportGeneralTypeIssues = false
reportCallIssue = false
reportOptionalMemberAccess = false
reportArgumentType = false
reportOptionalSubscript = false

[tool.mypy] # Static type checker
check_untyped_defs = true
Expand Down
46 changes: 27 additions & 19 deletions src/pathpyG/core/graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
Tuple,
List,
Union,
Any,
Optional,
Generator,
)

import numpy as np
Expand All @@ -19,10 +17,8 @@
import torch_geometric.utils
from torch_geometric import EdgeIndex
from torch_geometric.data import Data
from torch_geometric.transforms.to_undirected import ToUndirected
from torch_geometric.utils import scatter
from torch_geometric.utils import scatter, to_undirected

from pathpyG.utils.config import config
from pathpyG.core.index_map import IndexMap


Expand Down Expand Up @@ -180,13 +176,10 @@ def from_edge_list(
return Graph(Data(edge_index=edge_index, num_nodes=num_nodes), mapping=mapping)

def to_undirected(self) -> Graph:
"""
Returns an undirected version of a directed graph.
"""Return an undirected version of this directed graph.

This method transforms the current graph instance into an undirected graph by
adding all directed edges in opposite direction. It applies [`ToUndirected`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.ToUndirected.html#torch_geometric.transforms.ToUndirected)
transform to the underlying [`torch_geometric.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object, which automatically
duplicates edge attributes for newly created directed edges.
This method creates a new undirected Graph from the current graph instance by
adding all directed edges in opposite direction.

Examples:
>>> import pathpyG as pp
Expand All @@ -195,15 +188,30 @@ def to_undirected(self) -> Graph:
>>> print(g_u)
Undirected graph with 3 nodes and 6 (directed) edges
"""
tf = ToUndirected()
d = tf(self.data)
# unfortunately, the application of a transform creates a new edge_index of type tensor
# so we have to recreate the EdgeIndex tensor and sort it again
# create undirected edge index by coalescing the directed edges and keep
# track of the original edge index for the edge attributes
attr_idx = torch.arange(self.data.num_edges, device=self.data.edge_index.device)
edge_index, attr_idx = to_undirected(
self.data.edge_index,
edge_attr=attr_idx,
num_nodes=self.data.num_nodes,
reduce="min",
)

e = EdgeIndex(data=d.edge_index, sparse_size=(self.data.num_nodes, self.data.num_nodes), is_undirected=True)
d.edge_index = e
d.num_nodes = self.data.num_nodes
return Graph(d, self.mapping)
data = Data(
edge_index=EdgeIndex(data=edge_index, sparse_size=(self.data.num_nodes, self.data.num_nodes), is_undirected=True),
num_nodes=self.data.num_nodes
)
# Note that while the torch_geometric.transforms.ToUndirected function would do this automatically,
# we do it manually since the transform cannot handle numpy arrays as edge attributes.
# make sure to copy all node and (undirected) edge attributes
for node_attr in self.node_attrs():
data[node_attr] = self.data[node_attr]
for edge_attr in self.edge_attrs():
if edge_attr != "edge_index":
data[edge_attr] = self.data[edge_attr][attr_idx]

return Graph(data, self.mapping)

def to_weighted_graph(self) -> Graph:
"""Coalesces multi-edges to single-edges with an additional weight attribute
Expand Down
30 changes: 28 additions & 2 deletions src/pathpyG/core/temporal_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Any, Optional, Generator
from typing import Tuple, Union, Any, Optional, Generator

import numpy as np

Expand Down Expand Up @@ -53,7 +53,10 @@ def __init__(self, data: Data, mapping: IndexMap | None = None) -> None:

# create mapping between edge index and edge tuples
self.edge_to_index = {
(e[0].item(), e[1].item()): i for i, e in enumerate([e for e in self.data.edge_index.t()])
(e[0].item(), e[1].item()): i for i, e in enumerate(self.data.edge_index.t())
}
self.tedge_to_index = {
(e[0].item(), e[1].item(), t.item()): i for i, (e, t) in enumerate(zip([e for e in self.data.edge_index.t()], self.data.time))
}

self.start_time = self.data.time[0].item()
Expand Down Expand Up @@ -163,6 +166,29 @@ def get_window(self, start_time: int, end_time: int) -> TemporalGraph:

return TemporalGraph(data=self.data.snapshot(start_time, end_time), mapping=self.mapping)

def __getitem__(self, key: Union[tuple, str]) -> Any:
"""Return node, edge, temporal edge, or graph attribute.

Args:
key: name of attribute to be returned
"""
if not isinstance(key, tuple):
if key in self.data.keys():
return self.data[key]
else:
raise KeyError(key + " is not a graph attribute")
elif key[0] in self.node_attrs():
return self.data[key[0]][self.mapping.to_idx(key[1])]
elif key[0] in self.edge_attrs():
# TODO: Get item for non-temporal edges will only return the last occurence of the edge
# This is a limitation and should be fixed in the future.
if len(key) == 3:
return self.data[key[0]][self.edge_to_index[self.mapping.to_idx(key[1]), self.mapping.to_idx(key[2])]]
else:
return self.data[key[0]][self.tedge_to_index[self.mapping.to_idx(key[1]), self.mapping.to_idx(key[2]), key[3]]]
else:
raise KeyError(key[0] + " is not a node or edge attribute")

def __str__(self) -> str:
"""
Return a string representation of the graph
Expand Down
2 changes: 1 addition & 1 deletion src/pathpyG/io/netzschleuder.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def read_netzschleuder_graph(

# construct graph and assign edge attributes
if timestamps:
g = df_to_temporal_graph(df=edges, is_undirected=not is_directed, multiedges=multiedges, num_nodes=num_nodes)
g = df_to_temporal_graph(df=edges, multiedges=multiedges, num_nodes=num_nodes)
else:
g = df_to_graph(df=edges, multiedges=multiedges, is_undirected=not is_directed, num_nodes=num_nodes)

Expand Down
Loading
Loading