Skip to content

Conversation

@edknv
Copy link
Contributor

@edknv edknv commented May 30, 2023

This PR introduces the Model class for constructing models from individual blocks in the pytorch backend.

Some follow-up will be necessary, such as testing with a toy dataset using dataloader, batch_predict, etc.

@github-actions
Copy link

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1126

@edknv edknv self-assigned this Jun 1, 2023
@edknv edknv added enhancement New feature or request area/pytorch labels Jun 1, 2023
@edknv edknv added this to the Merlin 23.06 milestone Jun 1, 2023
@edknv edknv requested review from marcromeyn and oliverholworthy and removed request for marcromeyn June 1, 2023 03:07
... BinaryOutput(schema.select_by_tag(Tags.TARGET).first),
... )
... trainer = Trainer(max_epochs=1)
... with Loader(dataset, batch_size=16) as loader:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it required to use it as a context-manager?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's probably not necessary in the torch case, but I want to promote this idiom using the context manager everywhere, because the tensorflow equivalent has a memory leak in some cases without a context manager.

super().__init__()
self.schema = schema

self.pre = BlockContainer(name="pre")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we break this part (until line 77) as a function an use it inside Block as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that would be cleaner, but I'm having trouble with torchscript shenanigans.

"""Finds all instances of `ModelOutput` in the model."""
return module_utils.find_all_instances(self, ModelOutput)

def first(self) -> nn.Module:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this part of the Block class as well?

and len(model_outputs) > 1
):
raise RuntimeError("Multiple outputs but only one target was provided.")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check if all model_outputs have a target property set?

Copy link
Contributor Author

@edknv edknv Jun 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the logic to check if model_outputs has a target when no targets are provided in lines 205-206.

self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
):
"""Performs a forward pass through the model."""
outputs = inputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here as for the init, should we break this out in a function so it can be shared between here and the Block class.

def training_step(self, batch, batch_idx):
"""Performs a training step with a single batch."""
del batch_idx
inputs, targets = batch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need some logic here to construct the Batch class and pass it to the forward-pass.

if self.schema:
return self.schema
return Schema([])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would need to add a method here for output_schema that combines all the output-schema's of the various model-outputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an output_schema() method.

@edknv edknv marked this pull request as ready for review June 4, 2023 23:34
@edknv edknv merged commit 92833fa into NVIDIA-Merlin:main Jun 13, 2023
@edknv edknv deleted the torch/model_cls branch June 13, 2023 16:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area/pytorch enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants