Skip to content

Commit 217bb59

Browse files
authored
Merge pull request #63 from histogrammar/decimal-support
Decimal support
2 parents 0ee95a4 + b5462d0 commit 217bb59

File tree

8 files changed

+47
-34
lines changed

8 files changed

+47
-34
lines changed

histogrammar/dfinterface/filling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def check_column(col, sep=":"):
3838
return col
3939

4040

41-
def check_dtype(dtype):
41+
def normalize_dtype(dtype):
4242
"""Convert datatype to consistent numpy datatype
4343
4444
:param dtype: input datatype

histogrammar/dfinterface/histogram_filler_base.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ..primitives.stack import Stack
2828
from ..primitives.sum import Sum
2929

30-
from .filling_utils import check_column, check_dtype
30+
from .filling_utils import check_column, normalize_dtype
3131

3232

3333
class HistogramFillerBase(object):
@@ -111,7 +111,7 @@ def __init__(
111111
self.bin_specs = bin_specs or {}
112112
self.time_axis = time_axis
113113
var_dtype = var_dtype or {}
114-
self.var_dtype = {k: check_dtype(v) for k, v in var_dtype.items()}
114+
self.var_dtype = {k: normalize_dtype(v) for k, v in var_dtype.items()}
115115
self.read_key = read_key
116116
self.store_key = store_key
117117

@@ -404,32 +404,31 @@ def categorize_features(self, df):
404404

405405
for col_list in features:
406406
for col in col_list:
407+
# data type with metadata
408+
dt_col = self.get_data_type(df, col)
407409

408-
dt = self.var_dtype.get(col, check_dtype(self.get_data_type(df, col)))
410+
# normalized data type
411+
dt = self.var_dtype.get(col, normalize_dtype(dt_col))
409412

410413
if col not in self.var_dtype:
411414
self.var_dtype[col] = dt
412415

416+
# metadata indicates decimal
417+
if hasattr(dt_col, 'metadata') and dt_col.metadata is not None and dt_col.metadata["decimal"]:
418+
cols_by_type["decimal"].add(col)
419+
413420
if np.issubdtype(dt, np.integer):
414-
colset = cols_by_type["int"]
415-
if col not in colset:
416-
colset.add(col)
421+
cols_by_type["int"].add(col)
422+
417423
if np.issubdtype(dt, np.number):
418424
colset = cols_by_type["num"]
419-
if col not in colset:
420-
colset.add(col)
421425
elif np.issubdtype(dt, np.datetime64):
422426
colset = cols_by_type["dt"]
423-
if col not in colset:
424-
colset.add(col)
425427
elif np.issubdtype(dt, np.bool_):
426428
colset = cols_by_type["bool"]
427-
if col not in colset:
428-
colset.add(col)
429429
else:
430430
colset = cols_by_type["str"]
431-
if col not in colset:
432-
colset.add(col)
431+
colset.add(col)
433432

434433
self.logger.debug(
435434
'Data type of column "{col}" is "{type}".'.format(

histogrammar/dfinterface/make_histograms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from .pandas_histogrammar import PandasHistogrammar
4444
from .spark_histogrammar import SparkHistogrammar
45-
from .filling_utils import check_dtype
45+
from .filling_utils import normalize_dtype
4646
from ..util import _get_sub_hist
4747

4848
logger = logging.getLogger()
@@ -232,7 +232,7 @@ def get_time_axes(df):
232232
return [
233233
c
234234
for c in df.columns
235-
if np.issubdtype(check_dtype(get_data_type(df, c)), np.datetime64)
235+
if np.issubdtype(normalize_dtype(get_data_type(df, c)), np.datetime64)
236236
]
237237

238238

histogrammar/dfinterface/pandas_histogrammar.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ def get_data_type(self, df, col):
136136
elif inferred == 'boolean':
137137
data_type = 'bool'
138138
elif inferred in {'decimal', 'floating', 'mixed-integer-float'}:
139-
data_type = 'float'
139+
# decimal needs preprocessing (cast), signal this in metadata
140+
if inferred == "decimal":
141+
data_type = np.dtype('float', metadata={"decimal": True})
142+
else:
143+
data_type = "float"
140144
elif inferred in {'date', 'datetime', 'datetime64'}:
141145
data_type = 'datetime64'
142146
else: # categorical, mixed, etc -> object uses to_string()
@@ -187,6 +191,12 @@ def process_features(self, df, cols_by_type):
187191
)
188192
)
189193
idf[col] = df[col].apply(to_ns)
194+
195+
# treat decimal as float, as decimal is not supported by .quantile
196+
# (https://github.com/pandas-dev/pandas/issues/13157)
197+
for col in cols_by_type["decimal"]:
198+
idf[col] = df[col].apply(float)
199+
190200
return idf
191201

192202
def fill_histograms(self, idf):

histogrammar/dfinterface/spark_histogrammar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def get_data_type(self, df, col):
169169
dt = bool
170170
elif dt == "bigint":
171171
dt = np.int64
172+
elif dt.startswith("decimal("):
173+
return np.dtype(float, metadata={"decimal": True})
172174

173175
return np.dtype(dt)
174176

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from decimal import Decimal
12
from json import load
23
from os.path import dirname
34

@@ -88,4 +89,8 @@ def pytest_configure():
8889

8990
df = pd.read_csv(resources.data(CSV_FILE))
9091
df["date"] = pd.to_datetime(df["date"])
92+
93+
# Decimal type
94+
df["amount"] = df["balance"].str.replace("$", "", regex=False).str.replace(",", "", regex=False).apply(Decimal)
95+
9196
pytest.test_df = df

tests/test_pandas_histogrammar.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414

1515
def test_get_histograms():
16-
1716
pandas_filler = PandasHistogrammar(
1817
features=[
1918
"date",
@@ -47,7 +46,6 @@ def test_get_histograms():
4746

4847

4948
def test_make_histograms():
50-
5149
features = [
5250
"date",
5351
"isActive",
@@ -85,15 +83,14 @@ def test_make_histograms():
8583

8684

8785
def test_make_histograms_no_time_axis():
88-
8986
hists, features, bin_specs, time_axis, var_dtype = make_histograms(
9087
pytest.test_df, time_axis="", ret_specs=True,
9188
)
9289

93-
assert len(hists) == 21
94-
assert len(features) == 21
95-
assert len(bin_specs) == 6
96-
assert len(var_dtype) == 21
90+
assert len(hists) == 22
91+
assert len(features) == 22
92+
assert len(bin_specs) == 7
93+
assert len(var_dtype) == 22
9794
assert time_axis == ""
9895
assert "date" in hists
9996
h = hists["date"]
@@ -110,15 +107,14 @@ def test_make_histograms_no_time_axis():
110107

111108

112109
def test_make_histograms_with_time_axis():
113-
114110
hists, features, bin_specs, time_axis, var_dtype = make_histograms(
115111
pytest.test_df, time_axis=True, ret_specs=True, time_width=None, time_offset=None
116112
)
117113

118-
assert len(hists) == 20
119-
assert len(features) == 20
120-
assert len(bin_specs) == 20
121-
assert len(var_dtype) == 21
114+
assert len(hists) == 21
115+
assert len(features) == 21
116+
assert len(bin_specs) == 21
117+
assert len(var_dtype) == 22
122118
assert time_axis == "date"
123119
assert "date:age" in hists
124120
h = hists["date:age"]
@@ -167,10 +163,10 @@ def test_make_histograms_unit_binning():
167163
pytest.test_df, binning="unit", time_axis="", ret_specs=True
168164
)
169165

170-
assert len(hists) == 21
171-
assert len(features) == 21
166+
assert len(hists) == 22
167+
assert len(features) == 22
172168
assert len(bin_specs) == 0
173-
assert len(var_dtype) == 21
169+
assert len(var_dtype) == 22
174170
assert time_axis == ""
175171
assert "date" in hists
176172
h = hists["date"]

tests/test_spark_histogrammar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def get_spark():
3030
SparkSession.builder.master("local")
3131
.appName("histogrammar-pytest")
3232
.config("spark.jars", f"{hist_spark_jar},{hist_jar}")
33-
.config("spark.sql.execution.arrow.enabled", "false")
3433
.config("spark.sql.session.timeZone", "GMT")
3534
.getOrCreate()
3635
)
@@ -81,6 +80,7 @@ def test_get_histograms(spark_co):
8180
["isActive", "age"],
8281
["latitude", "longitude"],
8382
"transaction",
83+
"amount",
8484
],
8585
bin_specs={
8686
"transaction": {"num": 100, "low": -2000, "high": 2000},
@@ -140,6 +140,7 @@ def test_get_histograms_module(spark_co):
140140
"longitude",
141141
["isActive", "age"],
142142
["latitude", "longitude"],
143+
"amount",
143144
],
144145
bin_specs={
145146
"longitude": {"bin_width": 5.0, "bin_offset": 0.0},

0 commit comments

Comments
 (0)