diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..d4ac666 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: patch
+  changes:
+    changed:
+    - Handle dataframe sum axis properly.
diff --git a/microdf/microdataframe.py b/microdf/microdataframe.py
index 1fbeab5..7ad18c8 100644
--- a/microdf/microdataframe.py
+++ b/microdf/microdataframe.py
@@ -37,24 +37,52 @@ def override_df_functions(self) -> None:
             setattr(self, name, self._create_agnostic_function(name))
 
     def _create_scalar_function(self, name: str) -> Callable:
         """Create a scalar function that returns a Series of results.
 
         :param name: Name of the function to create
-        :return: Function that applies the operation to all columns
+        :return: Function that applies the operation to each numeric
+            column when axis is 0 or "index", or across the numeric
+            columns of each row when axis is 1 or "columns"
         """
 
         def fn(*args, **kwargs) -> pd.Series:
-            results = {}
-            for col in self.columns:
-                if pd.api.types.is_numeric_dtype(self[col]):
-                    try:
-                        results[col] = getattr(self[col], name)(
-                            *args, **kwargs
-                        )
-                    except Exception:
-                        # Skip columns that can't be aggregated
-                        pass
-            return pd.Series(results)
+            axis = kwargs.get("axis", 0)
+            # Remove axis from kwargs since MicroSeries doesn't use it
+            ms_kwargs = {k: v for k, v in kwargs.items() if k != "axis"}
+
+            if axis == 0 or axis == "index":
+                # Column-wise: use MicroSeries methods
+                results = {}
+                for col in self.columns:
+                    if pd.api.types.is_numeric_dtype(self[col]):
+                        try:
+                            results[col] = getattr(self[col], name)(
+                                *args, **ms_kwargs
+                            )
+                        except Exception:
+                            # Skip columns that can't be aggregated
+                            pass
+                return pd.Series(results)
+
+            elif axis == 1 or axis == "columns":
+                # Row-wise: use pandas DataFrame methods
+                numeric_cols = [
+                    col
+                    for col in self.columns
+                    if pd.api.types.is_numeric_dtype(self[col])
+                ]
+                if numeric_cols:
+                    # Create regular DataFrame and call its method
+                    df = pd.DataFrame(
+                        {col: self[col].values for col in numeric_cols},
+                        index=self.index,
+                    )
+                    # Forward positional arguments ahead of the axis keyword
+                    return getattr(df, name)(*args, axis=1, **ms_kwargs)
+                return pd.Series(dtype="float64", index=self.index)
+
+            else:
+                raise ValueError(f"Invalid axis: {axis}")
 
         return fn
 
diff --git a/microdf/tests/test_microseries_dataframe.py b/microdf/tests/test_microseries_dataframe.py
index 6987280..ad9ad25 100644
--- a/microdf/tests/test_microseries_dataframe.py
+++ b/microdf/tests/test_microseries_dataframe.py
@@ -287,3 +287,64 @@ def test_reset_index_inplace() -> None:
     assert "second" in mdf_multi.columns
     assert list(mdf_multi.index) == [0, 1, 2, 3]
     np.testing.assert_array_equal(mdf_multi.weights.values, weights)
+
+
+def test_sum_axis_1() -> None:
+    # Test basic row-wise sum
+    df = mdf.MicroDataFrame(
+        {"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]},
+        weights=[0.5, 1.0, 2.0],
+    )
+
+    # Row-wise sum (axis=1) should not use weights
+    row_sums = df.sum(axis=1)
+    expected = pd.Series([12, 15, 18], index=df.index)  # 1+4+7, 2+5+8, 3+6+9
+    pd.testing.assert_series_equal(row_sums, expected)
+
+    # Column-wise sum (axis=0) should use weights
+    col_sums = df.sum(axis=0)
+    expected_weighted = pd.Series(
+        {
+            "A": 1 * 0.5 + 2 * 1.0 + 3 * 2.0,  # 8.5
+            "B": 4 * 0.5 + 5 * 1.0 + 6 * 2.0,  # 19.0
+            "C": 7 * 0.5 + 8 * 1.0 + 9 * 2.0,  # 29.5
+        }
+    )
+    pd.testing.assert_series_equal(col_sums, expected_weighted)
+
+    # Test with mixed types (non-numeric columns should be ignored)
+    df_mixed = mdf.MicroDataFrame(
+        {"A": [1, 2, 3], "B": [4, 5, 6], "text": ["a", "b", "c"]},
+        weights=[1, 1, 1],
+    )
+
+    row_sums_mixed = df_mixed.sum(axis=1)
+    expected_mixed = pd.Series([5, 7, 9], index=df_mixed.index)  # Only A+B
+    pd.testing.assert_series_equal(row_sums_mixed, expected_mixed)
+
+    # Test with axis='columns' (string form)
+    row_sums_str = df.sum(axis="columns")
+    pd.testing.assert_series_equal(row_sums_str, expected)
+
+    # Test with additional parameters
+    df_with_nan = mdf.MicroDataFrame(
+        {"A": [1, np.nan, 3], "B": [4, 5, 6], "C": [7, 8, np.nan]},
+        weights=[1, 1, 1],
+    )
+
+    # skipna=True (default)
+    row_sums_skipna = df_with_nan.sum(axis=1)
+    expected_skipna = pd.Series([12.0, 13.0, 9.0])  # NaN values skipped
+    pd.testing.assert_series_equal(row_sums_skipna, expected_skipna)
+
+    # skipna=False
+    row_sums_no_skipna = df_with_nan.sum(axis=1, skipna=False)
+    expected_no_skipna = pd.Series([12.0, np.nan, np.nan])  # NaN propagates
+    pd.testing.assert_series_equal(row_sums_no_skipna, expected_no_skipna)
+
+    # Test min_count parameter
+    row_sums_min_count = df_with_nan.sum(axis=1, min_count=3)
+    expected_min_count = pd.Series(
+        [12.0, np.nan, np.nan]
+    )  # Row 1 and 2 have < 3 non-NA values
+    pd.testing.assert_series_equal(row_sums_min_count, expected_min_count)