diff --git a/partridge/gtfs.py b/partridge/gtfs.py index 62b91c5..16ed0f5 100644 --- a/partridge/gtfs.py +++ b/partridge/gtfs.py @@ -1,6 +1,7 @@ import os from threading import RLock from typing import Dict, Optional, Union +from warnings import warn import networkx as nx import numpy as np @@ -34,9 +35,13 @@ def __init__( self._locks: Dict[str, RLock] = {} if isinstance(source, self.__class__): self._read = source.get + self._proxy_feed = bool(self._view) elif isinstance(source, str) and os.path.isdir(source): self._read = self._read_csv self._bootstrap(source) + self._proxy_feed = True + # Validate the configuration and issue a warning if needed + self._validate_dependencies_conversion() else: raise ValueError("Invalid source") @@ -46,11 +51,15 @@ def get(self, filename: str) -> pd.DataFrame: df = self._cache.get(filename) if df is None: df = self._read(filename) - df = self._filter(filename, df) - df = self._prune(filename, df) - self._convert_types(filename, df) - df = df.reset_index(drop=True) - df = self._transform(filename, df) + if self._proxy_feed: + # proxy feeds (file reader and view filters) are responsible for file access, filtering and pruning + df = self._filter(filename, df) + df = self._prune(filename, df) + df = df.reset_index(drop=True) + else: + # the outermost feed is responsible for data conversion and transformation + self._convert_types(filename, df) + df = self._transform(filename, df) self.set(filename, df) return self._cache[filename] @@ -95,7 +104,7 @@ def _read_csv(self, filename: str) -> pd.DataFrame: # DataFrame containing any required columns. return empty_df(columns) - # If the file isn't in the zip, return an empty DataFrame. 
+ # Read file encoding with open(path, "rb") as f: encoding = detect_encoding(f) @@ -121,7 +130,6 @@ def _filter(self, filename: str, df: pd.DataFrame) -> pd.DataFrame: # If applicable, filter this dataframe by the given set of values if col in df.columns: df = df[df[col].isin(setwrap(values))] - return df def _prune(self, filename: str, df: pd.DataFrame) -> pd.DataFrame: @@ -147,10 +155,44 @@ def _prune(self, filename: str, df: pd.DataFrame) -> pd.DataFrame: depcol = deps[depfile] # If applicable, prune this dataframe by the other if col in df.columns and depcol in depdf.columns: - df = df[df[col].isin(depdf[depcol])] + converter = self._get_convert_function(filename, col) + # Convert the column before pruning since depdf is already converted + col_series = converter(df[col]) if converter else df[col] + df = df[col_series.isin(depdf[depcol])] return df + def _get_convert_function(self, filename, colname): + """Return the convert function from the config + for a specific file and column""" + return self._config.nodes.get(filename, {}).get("converters", {}).get(colname) + + def _validate_dependencies_conversion(self): + """Validate that dependent columns in different files + have the same convert function if one exists. + """ + + def check_column_pair(column_pair: dict) -> bool: + assert len(column_pair) == 2 + convert_funcs = [ + self._get_convert_function(filename, colname) + for filename, colname in column_pair.items() + ] + if convert_funcs[0] != convert_funcs[1]: + return False + return True + + for file_a, file_b, data in self._config.edges(data=True): + dependencies = data.get("dependencies", []) + for column_pair in dependencies: + if check_column_pair(column_pair): + continue + warn( + f"Converters Mismatch: column `{column_pair[file_a]}` in {file_a} " + f"is dependent on column `{column_pair[file_b]}` in {file_b} " + f"but converted with different functions, which might cause merging problems." 
+ ) + def _convert_types(self, filename: str, df: pd.DataFrame) -> None: """ Apply type conversions diff --git a/partridge/readers.py b/partridge/readers.py index 1fdcb4e..54bf924 100644 --- a/partridge/readers.py +++ b/partridge/readers.py @@ -13,8 +13,6 @@ from .gtfs import Feed from .parsers import vparse_date from .types import View -from .utilities import remove_node_attributes - DAY_NAMES = ( "monday", @@ -105,10 +103,9 @@ def finalize() -> None: def _load_feed(path: str, view: View, config: nx.DiGraph) -> Feed: """Multi-file feed filtering""" - config_ = remove_node_attributes(config, ["converters", "transformations"]) - feed_ = Feed(path, view={}, config=config_) + feed_ = Feed(path, view={}, config=config) for filename, column_filters in view.items(): - config_ = reroot_graph(config_, filename) + config_ = reroot_graph(config, filename) view_ = {filename: column_filters} feed_ = Feed(feed_, view=view_, config=config_) return Feed(feed_, config=config) diff --git a/tests/test_feed.py b/tests/test_feed.py index a69a01f..a0b06fd 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -1,4 +1,6 @@ import datetime + +import pandas as pd import pytest import partridge as ptg @@ -225,3 +227,18 @@ def test_filtered_columns(path): assert set(feed_full.trips.columns) == set(feed_view.trips.columns) assert set(feed_full.trips.columns) == set(feed_null.trips.columns) + + +@pytest.mark.parametrize("path", [fixture("amazon-2017-08-06")]) +def test_converted_id_column(path): + conf = default_config() + conf.nodes["routes.txt"]["converters"]["route_id"] = pd.to_numeric + with pytest.warns(UserWarning, match="Converters Mismatch:"): + ptg.load_feed(path, config=conf) + conf.nodes["trips.txt"]["converters"]["route_id"] = pd.to_numeric + # Just to prevent another warning + conf.nodes["fare_rules.txt"]["converters"] = {} + conf.nodes["fare_rules.txt"]["converters"]["route_id"] = pd.to_numeric + feed = ptg.load_feed(path, config=conf) + assert len(feed.trips) > 0 + assert 
len(feed.routes) > 0