diff --git a/bin/load_and_evaluate_patchcore.py b/bin/load_and_evaluate_patchcore.py index 27f8ca0..c2f39dd 100644 --- a/bin/load_and_evaluate_patchcore.py +++ b/bin/load_and_evaluate_patchcore.py @@ -210,7 +210,7 @@ def get_patchcore_iter(device): [x for x in os.listdir(patch_core_path) if ".faiss" in x] ) if n_patchcores == 1: - nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers) + nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers, device) patchcore_instance = patchcore.patchcore.PatchCore(device) patchcore_instance.load_from_path( load_path=patch_core_path, device=device, nn_method=nn_method @@ -219,7 +219,7 @@ def get_patchcore_iter(device): else: for i in range(n_patchcores): nn_method = patchcore.common.FaissNN( - faiss_on_gpu, faiss_num_workers + faiss_on_gpu, faiss_num_workers, device ) patchcore_instance = patchcore.patchcore.PatchCore(device) patchcore_instance.load_from_path( diff --git a/bin/run_patchcore.py b/bin/run_patchcore.py index 8666b2b..f9bf9ae 100644 --- a/bin/run_patchcore.py +++ b/bin/run_patchcore.py @@ -294,7 +294,7 @@ def get_patchcore(input_shape, sampler, device): backbone = patchcore.backbones.load(backbone_name) backbone.name, backbone.seed = backbone_name, backbone_seed - nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers) + nn_method = patchcore.common.FaissNN(faiss_on_gpu, faiss_num_workers, device) patchcore_instance = patchcore.patchcore.PatchCore(device) patchcore_instance.load( diff --git a/src/patchcore/common.py b/src/patchcore/common.py index eeb3c64..642d19d 100644 --- a/src/patchcore/common.py +++ b/src/patchcore/common.py @@ -12,16 +12,21 @@ class FaissNN(object): - def __init__(self, on_gpu: bool = False, num_workers: int = 4) -> None: + def __init__(self, on_gpu: bool = False, num_workers: int = 4, device: Union[int,torch.device]=0) -> None: """FAISS Nearest neighbourhood search. Args: on_gpu: If set true, nearest neighbour searches are done on GPU. num_workers: Number of workers to use with FAISS for similarity search. + device: a gpu id or gpu device for FAISS NN search. """ faiss.omp_set_num_threads(num_workers) self.on_gpu = on_gpu self.search_index = None + + if isinstance(device, torch.device): + device = int(torch.cuda.current_device()) + self.device = device def _gpu_cloner_options(self): return faiss.GpuClonerOptions() @@ -42,8 +47,10 @@ def _index_to_cpu(self, index): def _create_index(self, dimension): if self.on_gpu: + gpu_config = faiss.GpuIndexFlatConfig() + gpu_config.device = self.device return faiss.GpuIndexFlatL2( - faiss.StandardGpuResources(), dimension, faiss.GpuIndexFlatConfig() + faiss.StandardGpuResources(), dimension, gpu_config ) return faiss.IndexFlatL2(dimension)