-
Notifications
You must be signed in to change notification settings - Fork 53
Add Model class #1126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Model class #1126
Conversation
Documentation preview |
| ... BinaryOutput(schema.select_by_tag(Tags.TARGET).first), | ||
| ... ) | ||
| ... trainer = Trainer(max_epochs=1) | ||
| ... with Loader(dataset, batch_size=16) as loader: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it required to use it as a context-manager?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's probably not necessary in the torch case, but I want to promote this idiom using the context manager everywhere, because the tensorflow equivalent has a memory leak in some cases without a context manager.
| super().__init__() | ||
| self.schema = schema | ||
|
|
||
| self.pre = BlockContainer(name="pre") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we break this part (until line 77) as a function an use it inside Block as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that would be cleaner, but I'm having trouble with torchscript shenanigans.
| """Finds all instances of `ModelOutput` in the model.""" | ||
| return module_utils.find_all_instances(self, ModelOutput) | ||
|
|
||
| def first(self) -> nn.Module: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this part of the Block class as well?
| and len(model_outputs) > 1 | ||
| ): | ||
| raise RuntimeError("Multiple outputs but only one target was provided.") | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to check if all model_outputs have a target property set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the logic to check if model_outputs has a target when no targets are provided in lines 205-206.
| self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None | ||
| ): | ||
| """Performs a forward pass through the model.""" | ||
| outputs = inputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here as for the init, should we break this out in a function so it can be shared between here and the Block class.
merlin/models/torch/models/base.py
Outdated
| def training_step(self, batch, batch_idx): | ||
| """Performs a training step with a single batch.""" | ||
| del batch_idx | ||
| inputs, targets = batch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need some logic here to construct the Batch class and pass it to the forward-pass.
| if self.schema: | ||
| return self.schema | ||
| return Schema([]) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would need to add a method here for output_schema that combines all the output-schema's of the various model-outputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added an output_schema() method.
This PR introduces the
Modelclass for constructing models from individual blocks in the pytorch backend.Some follow-up will be necessary, such as testing with a toy dataset using dataloader,
batch_predict, etc.