Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
changed:
- Handle dataframe sum axis properly.
50 changes: 34 additions & 16 deletions microdf/microdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,42 @@ def override_df_functions(self) -> None:
setattr(self, name, self._create_agnostic_function(name))

def _create_scalar_function(self, name: str) -> Callable:
"""Create a scalar function that returns a Series of results.
def fn(*args, **kwargs) -> Union[pd.Series, float]:
axis = kwargs.get("axis", 0)
# Remove axis from kwargs since MicroSeries doesn't use it
ms_kwargs = {k: v for k, v in kwargs.items() if k != "axis"}

:param name: Name of the function to create
:return: Function that applies the operation to all columns
"""
if axis == 0 or axis == "index":
# Column-wise: use MicroSeries methods
results = {}
for col in self.columns:
if pd.api.types.is_numeric_dtype(self[col]):
try:
results[col] = getattr(self[col], name)(
*args, **ms_kwargs
)
except Exception:
pass
return pd.Series(results)

def fn(*args, **kwargs) -> pd.Series:
results = {}
for col in self.columns:
if pd.api.types.is_numeric_dtype(self[col]):
try:
results[col] = getattr(self[col], name)(
*args, **kwargs
)
except Exception:
# Skip columns that can't be aggregated
pass
return pd.Series(results)
elif axis == 1 or axis == "columns":
# Row-wise: use pandas DataFrame methods
numeric_cols = [
col
for col in self.columns
if pd.api.types.is_numeric_dtype(self[col])
]
if numeric_cols:
# Create regular DataFrame and call its method
df = pd.DataFrame(
{col: self[col].values for col in numeric_cols},
index=self.index,
)
return getattr(df, name)(axis=1, *args, **ms_kwargs)
return pd.Series(dtype="float64", index=self.index)

else:
raise ValueError(f"Invalid axis: {axis}")

return fn

Expand Down
61 changes: 61 additions & 0 deletions microdf/tests/test_microseries_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,64 @@ def test_reset_index_inplace() -> None:
assert "second" in mdf_multi.columns
assert list(mdf_multi.index) == [0, 1, 2, 3]
np.testing.assert_array_equal(mdf_multi.weights.values, weights)


def test_sum_axis_1() -> None:
# Test basic row-wise sum
df = mdf.MicroDataFrame(
{"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]},
weights=[0.5, 1.0, 2.0],
)

# Row-wise sum (axis=1) should not use weights
row_sums = df.sum(axis=1)
expected = pd.Series([12, 15, 18], index=df.index) # 1+4+7, 2+5+8, 3+6+9
pd.testing.assert_series_equal(row_sums, expected)

# Column-wise sum (axis=0) should use weights
col_sums = df.sum(axis=0)
expected_weighted = pd.Series(
{
"A": 1 * 0.5 + 2 * 1.0 + 3 * 2.0, # 8.5
"B": 4 * 0.5 + 5 * 1.0 + 6 * 2.0, # 19.0
"C": 7 * 0.5 + 8 * 1.0 + 9 * 2.0, # 29.5
}
)
pd.testing.assert_series_equal(col_sums, expected_weighted)

# Test with mixed types (non-numeric columns should be ignored)
df_mixed = mdf.MicroDataFrame(
{"A": [1, 2, 3], "B": [4, 5, 6], "text": ["a", "b", "c"]},
weights=[1, 1, 1],
)

row_sums_mixed = df_mixed.sum(axis=1)
expected_mixed = pd.Series([5, 7, 9], index=df_mixed.index) # Only A+B
pd.testing.assert_series_equal(row_sums_mixed, expected_mixed)

# Test with axis='columns' (string form)
row_sums_str = df.sum(axis="columns")
pd.testing.assert_series_equal(row_sums_str, expected)

# Test with additional parameters
df_with_nan = mdf.MicroDataFrame(
{"A": [1, np.nan, 3], "B": [4, 5, 6], "C": [7, 8, np.nan]},
weights=[1, 1, 1],
)

# skipna=True (default)
row_sums_skipna = df_with_nan.sum(axis=1)
expected_skipna = pd.Series([12.0, 13.0, 9.0]) # NaN values skipped
pd.testing.assert_series_equal(row_sums_skipna, expected_skipna)

# skipna=False
row_sums_no_skipna = df_with_nan.sum(axis=1, skipna=False)
expected_no_skipna = pd.Series([12.0, np.nan, np.nan]) # NaN propagates
pd.testing.assert_series_equal(row_sums_no_skipna, expected_no_skipna)

# Test min_count parameter
row_sums_min_count = df_with_nan.sum(axis=1, min_count=3)
expected_min_count = pd.Series(
[12.0, np.nan, np.nan]
) # Row 1 and 2 have < 3 non-NA values
pd.testing.assert_series_equal(row_sums_min_count, expected_min_count)