Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bin/load_and_evaluate_patchcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def get_patchcore_iter(device):
[x for x in os.listdir(patch_core_path) if ".faiss" in x]
)
if n_patchcores == 1:
nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers)
nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers, device)
patchcore_instance = patchcore.patchcore.PatchCore(device)
patchcore_instance.load_from_path(
load_path=patch_core_path, device=device, nn_method=nn_method
Expand All @@ -219,7 +219,7 @@ def get_patchcore_iter(device):
else:
for i in range(n_patchcores):
nn_method = patchcore.common.FaissNN(
faiss_on_gpu, faiss_num_workers
faiss_on_gpu, faiss_num_workers, device
)
patchcore_instance = patchcore.patchcore.PatchCore(device)
patchcore_instance.load_from_path(
Expand Down
2 changes: 1 addition & 1 deletion bin/run_patchcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def get_patchcore(input_shape, sampler, device):
backbone = patchcore.backbones.load(backbone_name)
backbone.name, backbone.seed = backbone_name, backbone_seed

nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers)
nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers, device)

patchcore_instance = patchcore.patchcore.PatchCore(device)
patchcore_instance.load(
Expand Down
11 changes: 9 additions & 2 deletions src/patchcore/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@


class FaissNN(object):
def __init__(self, on_gpu: bool = False, num_workers: int = 4) -> None:
def __init__(self, on_gpu: bool = False, num_workers: int = 4, device: Union[int,torch.device]=0) -> None:
"""FAISS Nearest neighbourhood search.

Args:
on_gpu: If set true, nearest neighbour searches are done on GPU.
num_workers: Number of workers to use with FAISS for similarity search.
device: a gpu id or gpu device for FAISS NN search.
"""
faiss.omp_set_num_threads(num_workers)
self.on_gpu = on_gpu
self.search_index = None

if isinstance(device, torch.device):
device = int(torch.cuda.current_device())
self.device = device

def _gpu_cloner_options(self):
return faiss.GpuClonerOptions()
Expand All @@ -42,8 +47,10 @@ def _index_to_cpu(self, index):

def _create_index(self, dimension):
if self.on_gpu:
gpu_config = faiss.GpuIndexFlatConfig()
gpu_config.device = self.device
return faiss.GpuIndexFlatL2(
faiss.StandardGpuResources(), dimension, faiss.GpuIndexFlatConfig()
faiss.StandardGpuResources(), dimension, gpu_config
)
return faiss.IndexFlatL2(dimension)

Expand Down