diff --git a/src/pathpyG/core/temporal_graph.py b/src/pathpyG/core/temporal_graph.py index 415d9426..73030ad5 100644 --- a/src/pathpyG/core/temporal_graph.py +++ b/src/pathpyG/core/temporal_graph.py @@ -63,15 +63,15 @@ def __init__(self, data: Data, mapping: IndexMap | None = None) -> None: self.end_time = self.data.time[-1].item() @staticmethod - def from_edge_list(edge_list, num_nodes: Optional[int] = None) -> TemporalGraph: + def from_edge_list(edge_list, num_nodes: Optional[int] = None) -> TemporalGraph: # type: ignore + """Create a temporal graph from a list of tuples containing edges with timestamps.""" edge_array = np.array(edge_list) - ts = edge_array[:, 2].astype(np.number) # Convert timestamps to tensor - if np.issubdtype(ts.dtype, np.integer): - ts = torch.tensor(ts, dtype=torch.long) + if isinstance(edge_list[0][2], int): + ts = torch.tensor(edge_array[:, 2].astype(np.long)) else: - ts = torch.tensor(ts, dtype=torch.float32) + ts = torch.tensor(edge_array[:, 2].astype(np.double)) index_map = IndexMap(np.unique(edge_array[:, :2])) edge_index = index_map.to_idxs(edge_array[:, :2].T)