Skip to content

Commit d26e776

Browse files
authored
working subgraphs in workflow (#1842)
* working subgraphs in workflow * add in subgraph transform to compare * change to subworkflows in workflow instead of subgraph4 * remove reset_index calls in asserts of test
1 parent 25151f7 commit d26e776

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

nvtabular/workflow/workflow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def fit_schema(self, input_schema: Schema):
142142
self.graph.construct_schema(input_schema)
143143
return self
144144

145+
@property
146+
def subworkflows(self):
147+
return list(self.graph.subgraphs.keys())
148+
145149
@property
146150
def input_dtypes(self):
147151
return self.graph.input_dtypes
@@ -165,6 +169,10 @@ def output_node(self):
165169
def _input_columns(self):
166170
return self.graph._input_columns()
167171

172+
def get_subworkflow(self, subgraph_name):
173+
subgraph = self.graph.subgraph(subgraph_name)
174+
return Workflow(subgraph.output_node)
175+
168176
def remove_inputs(self, input_cols) -> "Workflow":
169177
"""Removes input columns from the workflow.
170178
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import os
18+
19+
import numpy as np
20+
import pytest
21+
from pandas.api.types import is_integer_dtype
22+
23+
from merlin.core.utils import set_dask_client
24+
from merlin.dag.ops.subgraph import Subgraph
25+
from nvtabular import Workflow, ops
26+
from tests.conftest import assert_eq
27+
28+
29+
@pytest.mark.parametrize("gpu_memory_frac", [0.01, 0.1])
30+
@pytest.mark.parametrize("engine", ["parquet", "csv", "csv-no-header"])
31+
@pytest.mark.parametrize("dump", [True, False])
32+
@pytest.mark.parametrize("replace", [True, False])
33+
def test_workflow_subgraphs(tmpdir, client, df, dataset, gpu_memory_frac, engine, dump, replace):
34+
cat_names = ["name-cat", "name-string"] if engine == "parquet" else ["name-string"]
35+
cont_names = ["x", "y", "id"]
36+
label_name = ["label"]
37+
38+
norms = ops.Normalize()
39+
cat_features = cat_names >> ops.Categorify()
40+
if replace:
41+
cont_features = cont_names >> ops.FillMissing() >> ops.LogOp >> norms
42+
else:
43+
fillmissing_logop = (
44+
cont_names
45+
>> ops.FillMissing()
46+
>> ops.LogOp
47+
>> ops.Rename(postfix="_FillMissing_1_LogOp_1")
48+
)
49+
cont_features = cont_names + fillmissing_logop >> norms
50+
51+
set_dask_client(client=client)
52+
wkflow_ops = Subgraph("cat_graph", cat_features) + Subgraph("cont_graph", cont_features)
53+
workflow = Workflow(wkflow_ops + label_name)
54+
55+
workflow.fit(dataset)
56+
57+
if dump:
58+
workflow_dir = os.path.join(tmpdir, "workflow")
59+
workflow.save(workflow_dir)
60+
workflow = None
61+
62+
workflow = Workflow.load(workflow_dir)
63+
64+
def get_norms(tar):
65+
ser_median = tar.dropna().quantile(0.5, interpolation="linear")
66+
gdf = tar.fillna(ser_median)
67+
gdf = np.log(gdf + 1)
68+
return gdf
69+
70+
concat_ops = "_FillMissing_1_LogOp_1"
71+
if replace:
72+
concat_ops = ""
73+
74+
df_pp = workflow.transform(dataset).to_ddf().compute()
75+
76+
if engine == "parquet":
77+
assert is_integer_dtype(df_pp["name-cat"].dtype)
78+
assert is_integer_dtype(df_pp["name-string"].dtype)
79+
80+
subgraph_cat = workflow.get_subworkflow("cat_graph")
81+
subgraph_cont = workflow.get_subworkflow("cont_graph")
82+
assert isinstance(subgraph_cat, Workflow)
83+
assert isinstance(subgraph_cont, Workflow)
84+
# will not be the same nodes of saved out and loaded back
85+
if not dump:
86+
assert subgraph_cat.output_node == cat_features
87+
assert subgraph_cont.output_node == cont_features
88+
# check failure path works as expected
89+
with pytest.raises(ValueError) as exc:
90+
workflow.get_subworkflow("not_exist")
91+
assert "No subgraph named" in str(exc.value)
92+
93+
# test transform results from subgraph
94+
sub_cat_df = subgraph_cat.transform(dataset).to_ddf().compute()
95+
assert_eq(sub_cat_df, df_pp[cat_names])
96+
97+
cont_names = [name + concat_ops for name in cont_names]
98+
sub_cont_df = subgraph_cont.transform(dataset).to_ddf().compute()
99+
assert_eq(sub_cont_df[cont_names], df_pp[cont_names])

0 commit comments

Comments
 (0)