Skip to content

Commit 55fe934

Browse files
Add support for int8 values to Categorify inference (#1818)
1 parent feaa418 commit 55fe934

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

cpp/nvtabular/inference/categorify.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ namespace nvtabular
9393
case 'u':
9494
switch (dtype.itemsize())
9595
{
96+
case 1:
97+
insert_int_mapping<uint8_t>(values);
98+
return;
9699
case 2:
97100
insert_int_mapping<uint16_t>(values);
98101
return;
@@ -107,6 +110,9 @@ namespace nvtabular
107110
case 'i':
108111
switch (dtype.itemsize())
109112
{
113+
case 1:
114+
insert_int_mapping<int8_t>(values);
115+
return;
110116
case 2:
111117
insert_int_mapping<int16_t>(values);
112118
return;
@@ -204,6 +210,8 @@ namespace nvtabular
204210
case 'u':
205211
switch (itemsize)
206212
{
213+
case 1:
214+
return transform_int<uint8_t>(input);
207215
case 2:
208216
return transform_int<uint16_t>(input);
209217
case 4:
@@ -215,6 +223,8 @@ namespace nvtabular
215223
case 'i':
216224
switch (itemsize)
217225
{
226+
case 1:
227+
return transform_int<int8_t>(input);
218228
case 2:
219229
return transform_int<int16_t>(input);
220230
case 4:

tests/unit/ops/test_categorify.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,11 @@ def test_categorify_inference():
704704
"unicode_string": np.random.randint(
705705
low=a_char, high=z_char, size=num_rows * 10, dtype="int32"
706706
).view("U10"),
707+
"int8_feature": np.random.randint(0, 10, dtype="int8", size=num_rows),
707708
"int16_feature": np.random.randint(0, 10, dtype="int16", size=num_rows),
708709
"int32_feature": np.random.randint(0, 10, dtype="int32", size=num_rows),
709710
"int64_feature": np.random.randint(0, 10, dtype="int64", size=num_rows),
711+
"uint8_feature": np.random.randint(0, 10, dtype="uint8", size=num_rows),
710712
"uint16_feature": np.random.randint(0, 10, dtype="uint16", size=num_rows),
711713
"uint32_feature": np.random.randint(0, 10, dtype="uint32", size=num_rows),
712714
"uint64_feature": np.random.randint(0, 10, dtype="uint64", size=num_rows),

0 commit comments

Comments
 (0)