Skip to content

Commit 0931633

Browse files
marcromeyn and edknv authored
Introducing Block (#1087)
* Introducing Block
* Adding improved doc-strings
* Adding torch github-action + add copyright
* Trying to fix failing tests
* Fixing bug in pytorch gh-action
* Increase test-coverage
* Expose Sequence in __init__
* Give n in repeat a default value

---------

Co-authored-by: edknv <[email protected]>
1 parent 8bb655b commit 0931633

File tree

16 files changed

+1647
-42
lines changed

16 files changed

+1647
-42
lines changed

.github/workflows/pytorch.yml

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11

2-
# name: pytorch
2+
name: pytorch
33

4-
# on:
5-
# push:
6-
# branches: [main]
7-
# pull_request:
8-
# branches: [main]
4+
on:
5+
push:
6+
branches: [main]
7+
pull_request:
8+
branches: [main]
99

1010
concurrency:
1111
group: ${{ github.workflow }}-${{ github.ref }}
@@ -18,38 +18,43 @@ jobs:
1818
matrix:
1919
python-version: [3.8]
2020
os: [ubuntu-latest]
21+
torch-version: ["<2.0", "~=2.0"]
2122

22-
# steps:
23-
# - uses: actions/checkout@v3
24-
# - name: Set up Python ${{ matrix.python-version }}
25-
# uses: actions/setup-python@v4
26-
# with:
27-
# python-version: ${{ matrix.python-version }}
28-
# cache: 'pip'
29-
# cache-dependency-path: 'requirements/dev.txt'
30-
# - name: Install Ubuntu packages
31-
# run: |
32-
# sudo apt-get update -y
33-
# sudo apt-get install -y protobuf-compiler
34-
# - name: Install Merlin dependencies
35-
# run: |
36-
# ref_type=${{ github.ref_type }}
37-
# branch=main
38-
# if [[ $ref_type == "tag"* ]]
39-
# then
40-
# git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
41-
# branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
42-
# fi
43-
# pip install "pandas>=1.2.0,<1.4.0dev0"
44-
# pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
45-
# pip install "merlin-core@git+https://github.com/NVIDIA-Merlin/core.git@$branch"
46-
# - name: Install dependencies
47-
# run: |
48-
# python -m pip install --upgrade pip
49-
# python -m pip install .[pytorch-dev]
50-
# - name: Build
51-
# run: |
52-
# python setup.py develop
53-
# - name: Run unittests
54-
# run: |
55-
# make tests-torch
23+
steps:
24+
- uses: actions/checkout@v3
25+
- name: Set up Python ${{ matrix.python-version }}
26+
uses: actions/setup-python@v4
27+
with:
28+
python-version: ${{ matrix.python-version }}
29+
cache: 'pip'
30+
cache-dependency-path: |
31+
**/setup.cfg
32+
requirements/*.txt
33+
- name: Install Ubuntu packages
34+
run: |
35+
sudo apt-get update -y
36+
sudo apt-get install -y protobuf-compiler
37+
- name: Install Merlin dependencies
38+
run: |
39+
ref_type=${{ github.ref_type }}
40+
branch=main
41+
if [[ $ref_type == "tag"* ]]
42+
then
43+
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
44+
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
45+
fi
46+
pip install "pandas>=1.2.0,<1.4.0dev0"
47+
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
48+
pip install "merlin-dataloader@git+https://github.com/NVIDIA-Merlin/dataloader.git@$branch"
49+
pip install "merlin-core@git+https://github.com/NVIDIA-Merlin/core.git@$branch"
50+
pip install "merlin-systems@git+https://github.com/NVIDIA-Merlin/systems.git@$branch"
51+
- name: Install dependencies
52+
run: |
53+
python -m pip install "torch${{ matrix.torch-version }}"
54+
python -m pip install .[pytorch-dev]
55+
- name: Build
56+
run: |
57+
python setup.py develop
58+
- name: Run unittests
59+
run: |
60+
make tests-torch

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ tests-tf-integration:
3333

3434
tests-torch:
3535
coverage run -m pytest -rsx tests -m "torch" || exit 1
36-
coverage report --include 'merlin/models/*'
37-
coverage html --include 'merlin/models/*'
36+
coverage report --include 'merlin/models/torch/*'
37+
coverage html --include 'merlin/models/torch/*'
3838

3939
tests-implicit:
4040
coverage run -m pytest -rsx tests -m "implicit" || exit 1

merlin/models/torch/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from merlin.models.torch.batch import Batch, Sequence
18+
from merlin.models.torch.block import Block
19+
20+
__all__ = ["Batch", "Block", "Sequence"]

merlin/models/torch/batch.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from typing import Dict, Optional, Union
18+
19+
import torch
20+
21+
22+
@torch.jit.script
class Sequence:
    """
    A PyTorch scriptable class representing a sequence of tabular data.

    Attributes:
        lengths (Dict[str, torch.Tensor]): A dictionary mapping the feature names to their
            corresponding sequence lengths. A bare tensor passed to the constructor is
            stored under the key "default".
        masks (Dict[str, torch.Tensor]): A dictionary mapping the feature names to their
            corresponding masks. Default is an empty dictionary. A bare tensor passed to
            the constructor is stored under the key "default".

    Examples:
        >>> lengths = {'feature1': torch.tensor([4, 5]), 'feature2': torch.tensor([3, 7])}
        >>> masks = {'feature1': torch.tensor([[1, 0], [1, 1]]), 'feature2': torch.tensor([[1, 1], [1, 0]])}  # noqa: E501
        >>> seq = Sequence(lengths, masks)
    """

    def __init__(
        self,
        lengths: Union[torch.Tensor, Dict[str, torch.Tensor]],
        masks: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
    ):
        # Normalise `lengths` to a dictionary; a single tensor goes under "default".
        if isinstance(lengths, torch.Tensor):
            _lengths = {"default": lengths}
        elif torch.jit.isinstance(lengths, Dict[str, torch.Tensor]):
            _lengths = lengths
        else:
            raise ValueError("Lengths must be a tensor or a dictionary of tensors")
        self.lengths: Dict[str, torch.Tensor] = _lengths

        # Normalise `masks` the same way; `None` means "no masks".
        if masks is None:
            _masks = {}
        elif isinstance(masks, torch.Tensor):
            _masks = {"default": masks}
        elif torch.jit.isinstance(masks, Dict[str, torch.Tensor]):
            _masks = masks
        else:
            raise ValueError("Masks must be a tensor or a dictionary of tensors")
        self.masks: Dict[str, torch.Tensor] = _masks

    def __contains__(self, name: str) -> bool:
        """Return True if sequence lengths are stored for the feature `name`."""

        return name in self.lengths

    def length(self, name: str = "default") -> torch.Tensor:
        """Retrieve the sequence-length tensor stored under a feature name.

        Parameters
        ----------
        name : str
            The name of the feature whose lengths to return. Default is "default".

        Returns
        -------
        torch.Tensor
            The lengths tensor of the specified name.

        Raises
        ------
        ValueError
            If the specified name does not exist in the lengths attribute.
        """

        if name in self.lengths:
            return self.lengths[name]

        # The default key is only present when a bare tensor was provided, so
        # asking for it here means the caller needs to pick a feature name.
        if name == "default":
            raise ValueError("Sequence has multiple lengths, please specify a feature name")

        raise ValueError("Sequence has no lengths for the requested feature name")

    def mask(self, name: str = "default") -> torch.Tensor:
        """Retrieve the mask tensor stored under a feature name.

        Parameters
        ----------
        name : str
            The name of the feature whose mask to return. Default is "default".

        Returns
        -------
        torch.Tensor
            The mask tensor of the specified name.

        Raises
        ------
        ValueError
            If the specified name does not exist in the masks attribute.
        """

        if name in self.masks:
            return self.masks[name]

        # Same reasoning as in `length`: no "default" entry means a name is required.
        if name == "default":
            raise ValueError("Sequence has multiple masks, please specify a feature name")

        raise ValueError("Sequence has no mask for the requested feature name")
76+
77+
78+
@torch.jit.script
class Batch:
    """
    A PyTorch scriptable class representing a batch of data.

    Attributes:
        features (Dict[str, torch.Tensor]): A dictionary mapping feature names to their
            corresponding feature values. A bare tensor passed to the constructor is
            stored under the key "default".
        targets (Dict[str, torch.Tensor]): A dictionary mapping target names to their
            corresponding target values. Default is an empty dictionary. A bare tensor
            passed to the constructor is stored under the key "default".
        sequences (Optional[Sequence]): An optional instance of the Sequence class
            representing sequence lengths and masks for the batch.

    Examples:
        >>> features = {'feature1': torch.tensor([1, 2]), 'feature2': torch.tensor([3, 4])}
        >>> targets = {'target1': torch.tensor([0, 1])}
        >>> batch = Batch(features, targets)
    """

    def __init__(
        self,
        features: Union[torch.Tensor, Dict[str, torch.Tensor]],
        targets: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
        sequences: Optional[Sequence] = None,
    ):
        default_key = "default"

        # Normalise `features` to a dictionary; a single tensor goes under "default".
        if isinstance(features, torch.Tensor):
            _features = {default_key: features}
        elif torch.jit.isinstance(features, Dict[str, torch.Tensor]):
            _features = features
        else:
            raise ValueError("Features must be a tensor or a dictionary of tensors")
        self.features: Dict[str, torch.Tensor] = _features

        # Normalise `targets` with the same None/Tensor/Dict chain instead of
        # reassigning the Union-typed parameter; `None` means "no targets".
        if targets is None:
            _targets = {}
        elif isinstance(targets, torch.Tensor):
            _targets = {default_key: targets}
        elif torch.jit.isinstance(targets, Dict[str, torch.Tensor]):
            _targets = targets
        else:
            raise ValueError("Targets must be a tensor or a dictionary of tensors")
        self.targets: Dict[str, torch.Tensor] = _targets
        self.sequences: Optional[Sequence] = sequences

    def replace(
        self,
        features: Optional[Dict[str, torch.Tensor]] = None,
        targets: Optional[Dict[str, torch.Tensor]] = None,
        sequences: Optional[Sequence] = None,
    ) -> "Batch":
        """
        Create a new `Batch` instance, replacing specified attributes with new values.

        Arguments left as None keep the value currently stored on this batch.

        Parameters
        ----------
        features : Optional[Dict[str, torch.Tensor]]
            A dictionary of tensors representing the features of the batch. Default is None.
        targets : Optional[Dict[str, torch.Tensor]]
            A dictionary of tensors representing the targets of the batch. Default is None.
        sequences : Optional[Sequence]
            An instance of the Sequence class representing sequence lengths and masks for the
            batch. Default is None.

        Returns
        -------
        Batch
            A new Batch object with replaced attributes.
        """

        return Batch(
            features=features if features is not None else self.features,
            targets=targets if targets is not None else self.targets,
            sequences=sequences if sequences is not None else self.sequences,
        )

    def feature(self, name: str = "default") -> torch.Tensor:
        """Retrieve a feature tensor from the batch by its name.

        Parameters
        ----------
        name : str
            The name of the feature tensor to return. Default is "default".

        Returns
        -------
        torch.Tensor
            The feature tensor of the specified name.

        Raises
        ------
        ValueError
            If the specified name does not exist in the features attribute.
        """

        if name in self.features:
            return self.features[name]

        # The default key is only present when a bare tensor was provided, so
        # asking for it here means the caller needs to pick a feature name.
        if name == "default":
            raise ValueError("Batch has multiple features, please specify a feature name")

        raise ValueError("Batch has no feature with the requested name")

    def target(self, name: str = "default") -> torch.Tensor:
        """Retrieve a target tensor from the batch by its name.

        Parameters
        ----------
        name : str
            The name of the target tensor to return. Default is "default".

        Returns
        -------
        torch.Tensor
            The target tensor of the specified name.

        Raises
        ------
        ValueError
            If the specified name does not exist in the targets attribute.
        """

        if name in self.targets:
            return self.targets[name]

        # Same reasoning as in `feature`: no "default" entry means a name is required.
        if name == "default":
            raise ValueError("Batch has multiple targets, please specify a target name")

        raise ValueError("Batch has no target with the requested name")

    def __bool__(self) -> bool:
        """Return True when the batch holds at least one feature."""

        return bool(self.features)

0 commit comments

Comments
 (0)