Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 24 additions & 79 deletions tests/test_binary_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import numpy as np
import unittest
import faiss
import os
import tempfile


def make_binary_dataset(d, nb, nt, nq):
Expand All @@ -37,21 +35,12 @@ def test_flat(self):
index = faiss.IndexBinaryFlat(d)
index.add(self.xb)
D, I = index.search(self.xq, 3)
index2 = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))

fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)
D2, I2 = index2.search(self.xq, 3)

index2 = faiss.read_index_binary(tmpnam)

D2, I2 = index2.search(self.xq, 3)

assert (I2 == I).all()
assert (D2 == D).all()

finally:
os.remove(tmpnam)
assert (I2 == I).all()
assert (D2 == D).all()


class TestBinaryIVF(unittest.TestCase):
Expand All @@ -76,20 +65,12 @@ def test_ivf_flat(self):
index.add(self.xb)
D, I = index.search(self.xq, 3)

fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)

index2 = faiss.read_index_binary(tmpnam)

D2, I2 = index2.search(self.xq, 3)
index2 = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))

assert (I2 == I).all()
assert (D2 == D).all()
D2, I2 = index2.search(self.xq, 3)

finally:
os.remove(tmpnam)
assert (I2 == I).all()
assert (D2 == D).all()


class TestObjectOwnership(unittest.TestCase):
Expand All @@ -109,16 +90,10 @@ def test_read_index_ownership(self):
index = faiss.IndexBinaryFlat(d)
index.add(self.xb)

fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)
# this is the output of read_index_binary (==> checks ownership)
index2 = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))

index2 = faiss.read_index_binary(tmpnam)

assert index2.thisown
finally:
os.remove(tmpnam)
assert index2.thisown


class TestBinaryFromFloat(unittest.TestCase):
Expand All @@ -140,21 +115,11 @@ def test_binary_from_float(self):
index.add(self.xb)
D, I = index.search(self.xq, 3)

fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)

index2 = faiss.read_index_binary(tmpnam)

D2, I2 = index2.search(self.xq, 3)

assert (I2 == I).all()
assert (D2 == D).all()

finally:
os.remove(tmpnam)
index2 = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))
D2, I2 = index2.search(self.xq, 3)

assert (I2 == I).all()
assert (D2 == D).all()

class TestBinaryHNSW(unittest.TestCase):

Expand All @@ -174,20 +139,12 @@ def test_hnsw(self):
index.add(self.xb)
D, I = index.search(self.xq, 3)

fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)

index2 = faiss.read_index_binary(tmpnam)
index2 = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))

D2, I2 = index2.search(self.xq, 3)
D2, I2 = index2.search(self.xq, 3)

assert (I2 == I).all()
assert (D2 == D).all()

finally:
os.remove(tmpnam)
assert (I2 == I).all()
assert (D2 == D).all()

def test_ivf_hnsw(self):
d = self.xq.shape[1] * 8
Expand All @@ -200,21 +157,9 @@ def test_ivf_hnsw(self):
index.add(self.xb)
D, I = index.search(self.xq, 3)

fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)

index2 = faiss.read_index_binary(tmpnam)

D2, I2 = index2.search(self.xq, 3)

assert (I2 == I).all()
assert (D2 == D).all()

finally:
os.remove(tmpnam)

index2 = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))

D2, I2 = index2.search(self.xq, 3)

if __name__ == '__main__':
unittest.main()
assert (I2 == I).all()
assert (D2 == D).all()
28 changes: 6 additions & 22 deletions tests/test_fast_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

import unittest
import time
import os
import tempfile

import numpy as np
import faiss
Expand Down Expand Up @@ -587,16 +585,9 @@ def subtest_io(self, factory_str):
index.add(ds.get_database())
D1, I1 = index.search(ds.get_queries(), 1)

fd, fname = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, fname)
index2 = faiss.read_index(fname)
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)
finally:
if os.path.exists(fname):
os.unlink(fname)
index2 = faiss.deserialize_index(faiss.serialize_index(index))
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)

def test_io(self):
self.subtest_io('LSQ4x4fs_Nlsq2x4')
Expand Down Expand Up @@ -685,16 +676,9 @@ def subtest_io(self, factory_str):
index.add(ds.get_database())
D1, I1 = index.search(ds.get_queries(), 1)

fd, fname = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, fname)
index2 = faiss.read_index(fname)
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)
finally:
if os.path.exists(fname):
os.unlink(fname)
index2 = faiss.deserialize_index(faiss.serialize_index(index))
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)

def test_io(self):
self.subtest_io('PLSQ2x3x4fs_Nlsq2x4')
Expand Down
30 changes: 7 additions & 23 deletions tests/test_fast_scan_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# LICENSE file in the root directory of this source tree.


import os
import unittest
import tempfile

import numpy as np
import faiss
Expand Down Expand Up @@ -821,18 +819,11 @@ def subtest_io(self, factory_str):
index = faiss.index_factory(d, factory_str)
index.train(ds.get_train())
index.add(ds.get_database())
D1, I1 = index.search(ds.get_queries(), 1)
_, I1 = index.search(ds.get_queries(), 1)

fd, fname = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, fname)
index2 = faiss.read_index(fname)
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)
finally:
if os.path.exists(fname):
os.unlink(fname)
index2 = faiss.deserialize_index(faiss.serialize_index(index))
_, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)

def test_io(self):
self.subtest_io('IVF16,LSQ4x4fs_Nlsq2x4')
Expand Down Expand Up @@ -929,16 +920,9 @@ def subtest_io(self, factory_str):
index.add(ds.get_database())
D1, I1 = index.search(ds.get_queries(), 1)

fd, fname = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, fname)
index2 = faiss.read_index(fname)
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)
finally:
if os.path.exists(fname):
os.unlink(fname)
index2 = faiss.deserialize_index(faiss.serialize_index(index))
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)

def test_io(self):
self.subtest_io('IVF16,PLSQ2x3x4fsr_Nlsq2x4')
Expand Down
1 change: 0 additions & 1 deletion tests/test_flat_l2_panorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

import unittest
import tempfile
import os

import faiss
Expand Down
28 changes: 2 additions & 26 deletions tests/test_graph_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import numpy as np
import unittest
import faiss
import tempfile
import os

from common_faiss_tests import get_dataset_2

Expand Down Expand Up @@ -280,14 +278,7 @@ def make_knn_graph(self, metric):
return knn_graph

def subtest_io_and_clone(self, index, Dnsg, Insg):
fd, tmpfile = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, tmpfile)
index2 = faiss.read_index(tmpfile)
finally:
if os.path.exists(tmpfile):
os.unlink(tmpfile)
index2 = faiss.deserialize_index(faiss.serialize_index(index))

Dnsg2, Insg2 = index2.search(self.xq, 1)
np.testing.assert_array_equal(Dnsg2, Dnsg)
Expand All @@ -306,8 +297,6 @@ def subtest_connectivity(self, index, nb):

def subtest_add(self, build_type, thresh, metric=faiss.METRIC_L2):
d = self.xq.shape[1]
metrics = {faiss.METRIC_L2: 'L2',
faiss.METRIC_INNER_PRODUCT: 'IP'}

flat_index = faiss.IndexFlat(d, metric)
flat_index.add(self.xb)
Expand Down Expand Up @@ -381,8 +370,6 @@ def test_build_invalid_knng(self):
def test_reset(self):
"""test IndexNSG.reset()"""
d = self.xq.shape[1]
metrics = {faiss.METRIC_L2: 'L2',
faiss.METRIC_INNER_PRODUCT: 'IP'}

metric = faiss.METRIC_L2
flat_index = faiss.IndexFlat(d, metric)
Expand Down Expand Up @@ -546,14 +533,7 @@ def test_nndescentflat(self):
self.assertGreaterEqual(recalls, 450) # 462

# do some IO tests
fd, tmpfile = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, tmpfile)
index2 = faiss.read_index(tmpfile)
finally:
if os.path.exists(tmpfile):
os.unlink(tmpfile)
index2 = faiss.deserialize_index(faiss.serialize_index(index))

D2, I2 = index2.search(self.xq, 1)
np.testing.assert_array_equal(D2, D)
Expand Down Expand Up @@ -592,10 +572,6 @@ def test_knng_IP(self):
self.subtest(32, 10, faiss.METRIC_INNER_PRODUCT)

def subtest(self, d, K, metric):
metric_names = {faiss.METRIC_L1: 'L1',
faiss.METRIC_L2: 'L2',
faiss.METRIC_INNER_PRODUCT: 'IP'}

nb = 1000
_, xb, _ = get_dataset_2(d, 0, nb, 0)

Expand Down
18 changes: 2 additions & 16 deletions tests/test_index_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

""" more elaborate that test_index.py """
from __future__ import absolute_import, division, print_function

import numpy as np
import unittest
Expand Down Expand Up @@ -193,13 +192,7 @@ def test_remove_id_map_binary(self):
assert False, 'should have raised an exception'

# while we are there, let's test I/O as well...
fd, tmpnam = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index_binary(index, tmpnam)
index = faiss.read_index_binary(tmpnam)
finally:
os.remove(tmpnam)
index = faiss.deserialize_index_binary(faiss.serialize_index_binary(index))

assert index.reconstruct(1004)[0] == 104
try:
Expand Down Expand Up @@ -469,14 +462,7 @@ def test_dedup(self):
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)

# test I/O
fd, tmpfile = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index_new, tmpfile)
index_st = faiss.read_index(tmpfile)
finally:
if os.path.exists(tmpfile):
os.unlink(tmpfile)
index_st = faiss.deserialize_index(faiss.serialize_index(index_new))
Dst, Ist = index_st.search(xq, 20)

check_ref_knn_with_draws(Dnew, Inew, Dst, Ist)
Expand Down
Loading
Loading