diff --git a/torchreid/data/datamanager.py b/torchreid/data/datamanager.py index 7ae28cbaf..ebd13c938 100644 --- a/torchreid/data/datamanager.py +++ b/torchreid/data/datamanager.py @@ -175,7 +175,8 @@ def __init__( train_sampler_t='RandomSampler', cuhk03_labeled=False, cuhk03_classic_split=False, - market1501_500k=False + market1501_500k=False, + **kwargs ): super(ImageDataManager, self).__init__( @@ -202,7 +203,8 @@ def __init__( split_id=split_id, cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split, - market1501_500k=market1501_500k + market1501_500k=market1501_500k, + **kwargs ) trainset.append(trainset_) trainset = sum(trainset) diff --git a/torchreid/utils/feature_extractor.py b/torchreid/utils/feature_extractor.py index 83635fd3f..c01af91d8 100644 --- a/torchreid/utils/feature_extractor.py +++ b/torchreid/utils/feature_extractor.py @@ -104,7 +104,7 @@ def __init__( self.model = model self.preprocess = preprocess self.to_pil = to_pil - self.device = device + self.device = next(self.model.parameters()).device def __call__(self, input): if isinstance(input, list):