Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions imagededup/methods/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,17 @@ class CNN:
def __init__(
self,
verbose: bool = True,
model_config: Optional[CustomModel] = None
model_config: Optional[CustomModel] = None,
batch_size: int = 64
) -> None:
"""
Initialize a pytorch MobileNet model v3 that is sliced at the last convolutional layer.
Set the batch size for pytorch dataloader to be 64 samples.
Set the batch size for pytorch dataloader to be 64 samples by default.

Args:
verbose: Display progress bar if True else disable it. Default value is True.
model_config: A CustomModel that can be used to initialize a custom PyTorch model along with the corresponding transform.
batch_size: Batch size for the dataloader during encoding generation. Lower values use less GPU memory. Default value is 64.
"""
self.model_config = model_config if model_config is not None else CustomModel(
model=MobilenetV3(), transform=MobilenetV3.transform, name=MobilenetV3.name
Expand All @@ -64,7 +66,9 @@ def __init__(
) # The logger needs to be bound to the class, otherwise stderr also gets
# directed to stdout (Don't know why that is the case)

self.batch_size = 64
if not isinstance(batch_size, int) or batch_size < 1:
raise ValueError('batch_size must be a positive integer')
self.batch_size = batch_size
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.logger.info(f"Device set to {self.device} ..")

Expand Down Expand Up @@ -109,12 +113,7 @@ def _get_cnn_features_single(self, image_array: np.ndarray) -> np.ndarray:
image_pp = image_pp.unsqueeze(0)
img_features_tensor = self.model(image_pp.to(self.device))

if self.device.type == "cuda":
unpacked_img_features_tensor = img_features_tensor.cpu().detach().numpy()
else:
unpacked_img_features_tensor = img_features_tensor.detach().numpy()

return unpacked_img_features_tensor
return img_features_tensor.cpu().detach().numpy()

def _get_cnn_features_batch(
self,
Expand Down Expand Up @@ -146,31 +145,36 @@ def _get_cnn_features_batch(

with torch.no_grad():
for ims, filenames, bad_images in self.dataloader:
if ims is None or len(ims) == 0:
bad_im_count += len(bad_images)
continue
arr = self.model(ims.to(self.device))
feat_arr.extend(arr)
feat_arr.append(arr.cpu().detach().numpy())
del arr
if self.device.type == 'cuda':
torch.cuda.empty_cache()
all_filenames.extend(filenames)
if bad_images:
bad_im_count += 1
bad_im_count += len(bad_images)

if bad_im_count:
self.logger.info(
f"Found {bad_im_count} bad images, ignoring for encoding generation .."
)

feat_vec = torch.stack(feat_arr).squeeze()
feat_vec = (
feat_vec.detach().numpy()
if self.device.type == "cpu"
else feat_vec.detach().cpu().numpy()
)
if not feat_arr:
self.logger.info('No valid images found for encoding generation ..')
self.encoding_map = {}
return self.encoding_map

feat_vec = np.vstack(feat_arr)
valid_image_files = [filename for filename in all_filenames if filename]
self.logger.info("End: Image encoding generation")

filenames = generate_relative_names(image_dir, valid_image_files)
if (
len(feat_vec.shape) == 1
feat_vec.shape[0] == 1
): # can happen when encode_images is called on a directory containing a single image
self.encoding_map = {filenames[0]: feat_vec}
self.encoding_map = {filenames[0]: feat_vec[0]}
else:
self.encoding_map = {j: feat_vec[i, :] for i, j in enumerate(filenames)}
return self.encoding_map
Expand Down
4 changes: 4 additions & 0 deletions imagededup/utils/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def _collate_fn(batch: List[Dict]) -> Tuple[torch.tensor, str, str]:
filenames.append(b['filename'])
else:
bad_images.append(b['filename'])

if not ims:
return None, filenames, bad_images

return torch.stack(ims), filenames, bad_images


Expand Down
10 changes: 4 additions & 6 deletions imagededup/utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ def save_json(results: Dict, filename: str, float_scores: bool = False) -> None:

def parallelise(function: Callable, data: List, verbose: bool, num_workers: int) -> List:
num_workers = 1 if num_workers < 1 else num_workers # Pool needs to have at least 1 worker.
pool = Pool(processes=num_workers)
results = list(
tqdm.tqdm(pool.imap(function, data, 100), total=len(data), disable=not verbose)
)
pool.close()
pool.join()
with Pool(processes=num_workers) as pool:
results = list(
tqdm.tqdm(pool.imap(function, data, 100), total=len(data), disable=not verbose)
)
return results


Expand Down
71 changes: 71 additions & 0 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ def test__init_defaults(cnn):
assert cnn.model_config.name == MobilenetV3.name


def test__init_custom_batch_size():
    """A user-supplied batch_size is stored verbatim on the instance."""
    for requested in (32, 1):
        instance = CNN(batch_size=requested)
        assert instance.batch_size == requested


def test__init_invalid_batch_size():
    """Zero, negative, and non-integer batch sizes must raise ValueError."""
    for bad_value in (0, -1, 'abc'):
        with pytest.raises(ValueError):
            CNN(batch_size=bad_value)


def test__init_custom():
cnn = CNN(model_config=CustomModel(model=EfficientNet(),
transform=EfficientNet.transform,
Expand Down Expand Up @@ -983,6 +1002,58 @@ def test_find_duplicates_to_remove_encoding_integration(cnn):
)


# batch_size


def test_small_batch_size_produces_same_results(cnn):
    """Encodings must be independent of the dataloader batch size."""
    small_encoder = CNN(batch_size=2)
    default_map = cnn.encode_images(TEST_IMAGE_DIR)
    small_map = small_encoder.encode_images(TEST_IMAGE_DIR)

    # Same files encoded, and numerically equal features per file.
    assert set(default_map) == set(small_map)
    for name in default_map:
        np.testing.assert_allclose(default_map[name], small_map[name], atol=1e-5)


def test_batch_size_one_produces_same_results(cnn):
    """The degenerate batch size of 1 still matches the default encodings."""
    single_encoder = CNN(batch_size=1)
    default_map = cnn.encode_images(TEST_IMAGE_DIR)
    single_map = single_encoder.encode_images(TEST_IMAGE_DIR)

    # Same files encoded, and numerically equal features per file.
    assert set(default_map) == set(single_map)
    for name in default_map:
        np.testing.assert_allclose(default_map[name], single_map[name], atol=1e-5)


def test_batch_size_larger_than_dataset():
    """A batch size exceeding the dataset size still encodes every image."""
    oversized_encoder = CNN(batch_size=128)
    encoding_map = oversized_encoder.encode_images(TEST_IMAGE_DIR)
    assert len(encoding_map) == 10


def test_small_batch_size_find_duplicates_integration():
    """End-to-end duplicate detection works with a tiny batch size."""
    detector = CNN(batch_size=2)
    found = detector.find_duplicates(
        image_dir=TEST_IMAGE_DIR_MIXED,
        min_similarity_threshold=0.9,
        scores=False,
    )
    # Known duplicate has matches; known singleton has none.
    assert 'ukbench00120.jpg' in found
    assert found['ukbench00120.jpg']
    assert not found['ukbench09268.jpg']


def test_all_bad_images_returns_empty_encoding(tmp_path):
    """A directory containing only unreadable files yields an empty map."""
    (tmp_path / 'corrupt.jpg').write_bytes(b'not an image')
    (tmp_path / 'corrupt2.jpg').write_bytes(b'also not an image')

    encoder = CNN()
    assert encoder.encode_images(tmp_path) == {}


def test_scores_saving(cnn):
save_file = 'myduplicates.json'
cnn.find_duplicates(
Expand Down