diff --git a/src/pathpyG/io/netzschleuder.py b/src/pathpyG/io/netzschleuder.py index d4d911e8..64a2c718 100644 --- a/src/pathpyG/io/netzschleuder.py +++ b/src/pathpyG/io/netzschleuder.py @@ -172,7 +172,7 @@ def read_netzschleuder_graph( # construct graph and assign edge attributes if timestamps: - g = df_to_temporal_graph(df=edges, is_undirected=not is_directed, num_nodes=num_nodes) + g = df_to_temporal_graph(df=edges, is_undirected=not is_directed, multiedges=multiedges, num_nodes=num_nodes) else: g = df_to_graph(df=edges, multiedges=multiedges, is_undirected=not is_directed, num_nodes=num_nodes) diff --git a/src/pathpyG/io/pandas.py b/src/pathpyG/io/pandas.py index 03d357bf..51c4f941 100644 --- a/src/pathpyG/io/pandas.py +++ b/src/pathpyG/io/pandas.py @@ -281,7 +281,7 @@ def add_edge_attributes(df: pd.DataFrame, g: Graph, time_attr: str | None = None def df_to_temporal_graph( - df: pd.DataFrame, is_undirected: bool = False, timestamp_format="%Y-%m-%d %H:%M:%S", time_rescale=1, num_nodes: int | None = None + df: pd.DataFrame, is_undirected: bool = False, multiedges: bool = False, timestamp_format="%Y-%m-%d %H:%M:%S", time_rescale=1, num_nodes: int | None = None ) -> TemporalGraph: """Reads a temporal graph from a pandas data frame. @@ -352,6 +352,9 @@ def df_to_temporal_graph( f"Found {df['t'].dtype} instead." ) + if not multiedges: + df = df.drop_duplicates(subset=["v", "w", "t"]) + mapping = IndexMap(node_ids=np.unique(df[["v", "w"]].values)) data = Data( edge_index=mapping.to_idxs(df[["v", "w"]].values.T), diff --git a/tests/io/test_netzschleuder.py b/tests/io/test_netzschleuder.py index ea080158..bfa96982 100644 --- a/tests/io/test_netzschleuder.py +++ b/tests/io/test_netzschleuder.py @@ -116,7 +116,7 @@ def test_read_netzschleuder_graph(): def test_read_netzschleuder_graph_temporal(): """Test the read_netzschleuder_graph() function for timestamped data.""" - g = read_netzschleuder_graph(name="email_company", time_attr="time") + g = read_netzschleuder_graph(name="email_company", time_attr="time", multiedges=True) assert isinstance(g, TemporalGraph) assert g.n == 167 assert g.m == 82927