From 687ee3adb45784ba8661cf6f42e5db5e046c89bb Mon Sep 17 00:00:00 2001 From: Maximilian Werk Date: Mon, 8 Feb 2021 22:30:43 +0100 Subject: [PATCH] fix: size assertion for first add call --- jina/executors/indexers/vector.py | 3 ++- tests/unit/executors/indexers/test_numpyindexer.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jina/executors/indexers/vector.py b/jina/executors/indexers/vector.py index 40873dcba96d4..c7ce4d7dc25c7 100644 --- a/jina/executors/indexers/vector.py +++ b/jina/executors/indexers/vector.py @@ -116,7 +116,8 @@ def _validate_key_vector_shapes(self, keys, vectors): elif self.dtype != vectors.dtype.name: raise TypeError( f'vectors\' dtype {vectors.dtype.name} does not match with indexers\'s dtype: {self.dtype}') - elif keys.shape[0] != vectors.shape[0]: + + if keys.shape[0] != vectors.shape[0]: raise ValueError(f'number of key {keys.shape[0]} not equal to number of vectors {vectors.shape[0]}') def add(self, keys: Iterable[str], vectors: 'np.ndarray', *args, **kwargs) -> None: diff --git a/tests/unit/executors/indexers/test_numpyindexer.py b/tests/unit/executors/indexers/test_numpyindexer.py index 251719cffb345..a863d5ba18276 100644 --- a/tests/unit/executors/indexers/test_numpyindexer.py +++ b/tests/unit/executors/indexers/test_numpyindexer.py @@ -47,7 +47,6 @@ def test_numpy_indexer_long_ids(test_metas): indexer.save() assert os.path.exists(indexer.index_abspath) save_abspath = indexer.save_abspath - # assert False with BaseIndexer.load(save_abspath) as indexer: assert isinstance(indexer, NumpyIndexer) @@ -56,6 +55,16 @@ def test_numpy_indexer_long_ids(test_metas): assert idx.shape == (num_query, 4) +def test_numpy_indexer_assert_shape_mismatch(test_metas): + with NumpyIndexer(metric='euclidean', index_filename='np.test.gz', compress_level=0, + metas=test_metas) as indexer: + indexer.batch_size = 4 + vec_short = np.array([[1, 1, 1], [2, 2, 2]]) + vec_keys = np.array([1, 2, 3]) + with pytest.raises(ValueError): + indexer.add(vec_keys, vec_short) + + @pytest.mark.parametrize('batch_size, compress_level', [(None, 0), (None, 1), (16, 0), (16, 1)]) def test_numpy_indexer_known(batch_size, compress_level, test_metas): vectors = np.array([[1, 1, 1],