Skip to content

Commit 6f18e25

Browse files
committed
param ordering with defaults
1 parent 0711d50 commit 6f18e25

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

hummingbird/ml/operator_converters/_one_hot_encoder_implementations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module):
2222
Because we are dealing with tensors, strings require additional length information for processing.
2323
"""
2424

25-
def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None, extra_config={}):
25+
def __init__(self, logical_operator, categories, device, extra_config={}, handle_unknown='error', infrequent=None):
2626
super(OneHotEncoderString, self).__init__(logical_operator, transformer=True)
2727

2828
self.num_columns = len(categories)
@@ -113,7 +113,7 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module):
113113
Class implementing OneHotEncoder operators for ints in PyTorch.
114114
"""
115115

116-
def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None):
116+
def __init__(self, logical_operator, categories, device, handle_unknown='error', infrequent=None):
117117
super(OneHotEncoder, self).__init__(logical_operator, transformer=True)
118118

119119
self.num_columns = len(categories)

hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config):
4949
]
5050
):
5151
categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
52-
return OneHotEncoderString(operator, categories, operator.raw_operator.handle_unknown,
53-
device, infrequent, extra_config)
52+
return OneHotEncoderString(operator, categories, device, extra_config=extra_config,
53+
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)
5454
else:
55-
return OneHotEncoder(operator, operator.raw_operator.categories_, operator.raw_operator.handle_unknown,
56-
device, infrequent)
55+
return OneHotEncoder(operator, operator.raw_operator.categories_, device,
56+
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)
5757

5858

5959
register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder)

0 commit comments

Comments
 (0)