
Commit 858f0dd

clear init to ensure loading of framework specific ops only (#48)
remove embedding ops from init to ensure loading of only available framework ops
Parent: 680f444

File tree

3 files changed: 4 additions, 16 deletions

merlin/loader/ops/embeddings/__init__.py

Lines changed: 0 additions & 12 deletions
@@ -13,15 +13,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-# flake8: noqa
-from merlin.loader.ops.embeddings.tf_embedding_op import (
-    TF_MmapNumpyTorchEmbedding,
-    TF_NumpyEmbeddingOperator,
-    TFEmbeddingOperator,
-)
-from merlin.loader.ops.embeddings.torch_embedding_op import (
-    Torch_MmapNumpyTorchEmbedding,
-    Torch_NumpyEmbeddingOperator,
-    TorchEmbeddingOperator,
-)
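With these re-exports removed from the package __init__, callers import embedding operators from the framework-specific submodule directly, so importing the TensorFlow operators no longer pulls in the Torch ones (and vice versa). A minimal sketch of the updated import style, mirroring the test changes below:

    # Import only the TensorFlow embedding operators; torch_embedding_op is never loaded.
    from merlin.loader.ops.embeddings.tf_embedding_op import TFEmbeddingOperator

    # PyTorch users import from the Torch-specific module instead; tf_embedding_op is never loaded.
    from merlin.loader.ops.embeddings.torch_embedding_op import TorchEmbeddingOperator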

tests/unit/loader/test_tf_embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -20,16 +20,16 @@
 
 from merlin.core.dispatch import HAS_GPU
 from merlin.io import Dataset
-from merlin.loader.tensorflow import Loader
 from merlin.schema import Tags
 
 tf = pytest.importorskip("tensorflow")
 
-from merlin.loader.ops.embeddings import (  # noqa
+from merlin.loader.ops.embeddings.tf_embedding_op import (  # noqa
     TF_MmapNumpyTorchEmbedding,
     TF_NumpyEmbeddingOperator,
     TFEmbeddingOperator,
 )
+from merlin.loader.tensorflow import Loader  # noqa
 
 
 @pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])

tests/unit/loader/test_torch_embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -20,16 +20,16 @@
 
 from merlin.core.dispatch import HAS_GPU
 from merlin.io import Dataset
-from merlin.loader.torch import Loader
 from merlin.schema import Tags
 
 torch = pytest.importorskip("torch")
 
-from merlin.loader.ops.embeddings import (  # noqa
+from merlin.loader.ops.embeddings.torch_embedding_op import (  # noqa
     Torch_MmapNumpyTorchEmbedding,
     Torch_NumpyEmbeddingOperator,
     TorchEmbeddingOperator,
 )
+from merlin.loader.torch import Loader  # noqa
 
 
 @pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
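For code that should work with whichever framework happens to be installed, a guarded import keeps the same spirit as the tests' pytest.importorskip calls. A hypothetical sketch, not part of this commit:

    # Hypothetical fallback: prefer the TensorFlow operator, fall back to the Torch one
    # if TensorFlow (or the TF op module) is not available in the environment.
    try:
        from merlin.loader.ops.embeddings.tf_embedding_op import TFEmbeddingOperator as EmbeddingOperator
    except ImportError:
        from merlin.loader.ops.embeddings.torch_embedding_op import TorchEmbeddingOperator as EmbeddingOperator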
