MeanEnsemble to accumulate results in-place  #6366

@wyli

Describe the bug

img_ = self.get_stacked_torch(img)
if self.weights is not None:
    self.weights = self.weights.to(img_.device)
    shape = tuple(self.weights.shape)
    for _ in range(img_.ndimension() - self.weights.ndimension()):
        shape += (1,)
    weights = self.weights.reshape(*shape)
    img_ = img_ * weights / weights.mean(dim=0, keepdim=True)

The current implementation stacks and keeps all the predictions, which takes up unnecessary space compared with maintaining a single output and updating it in-place.
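
A rough sketch of what in-place accumulation could look like, as a starting point rather than a proposed patch. The helper name `running_weighted_mean` is hypothetical and not part of MONAI. It assumes one scalar weight per prediction; per-channel weights would still need the reshape from the snippet above. It relies on the identity that scaling by `weights / weights.mean(dim=0)` and then averaging over the stacked dimension gives sum(w_i * x_i) / sum(w_i), so only a running sum and a running weight total are needed. It is written against plain Python arithmetic so it runs without torch; with tensors, the `+=` updates the accumulator in place.

```python
from typing import Iterable, Optional, Sequence


def running_weighted_mean(preds: Iterable, weights: Optional[Sequence[float]] = None):
    """Average predictions one at a time instead of stacking them first.

    Hypothetical helper, not part of MONAI. Uses only +, * and /, so it
    accepts floats, NumPy arrays or torch tensors.
    """
    out = None          # running weighted sum of the predictions
    weight_sum = 0.0    # running sum of the weights
    for i, pred in enumerate(preds):
        w = 1.0 if weights is None else weights[i]
        if out is None:
            out = pred * w      # new object, so the caller's input is not modified
        else:
            out += pred * w     # accumulate into the existing output
        weight_sum += w
    if out is None:
        raise ValueError("at least one prediction is required")
    return out / weight_sum
```

Because `preds` can be a generator, each prediction can be dropped as soon as it has been added, so peak memory is roughly the accumulator plus one prediction (and a temporary for `pred * w`) instead of all of them at once.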
