From 876bbf194a70bed67b014ba2a91866237859e3ec Mon Sep 17 00:00:00 2001
From: Joan Fontanals Martinez
Date: Mon, 14 Jun 2021 12:41:38 +0200
Subject: [PATCH] fix: fix sparse pipeline test

---
Reviewer notes (free-form area after the `---` marker, not part of the
commit message):
  * Executor endpoints are registered as '/index' and '/search' instead of
    'index' and 'search'.
  * `encode` assigns a `sparse.coo_matrix` embedding to each incoming doc
    before extending `self.docs`; the `csr_matrix` conversion and the
    `self.vectors` dict are removed.
  * `query` now receives `docs` and sets `doc.matches` to the first `top_k`
    stored docs instead of returning a (docs, distances) pair.
  * The test drives the Flow through `f.post(on=...)` and asserts one
    response doc carrying TOP_K matches.

 .../sparse_pipeline/test_sparse_pipeline.py | 39 +++++++++----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/tests/integration/sparse_pipeline/test_sparse_pipeline.py b/tests/integration/sparse_pipeline/test_sparse_pipeline.py
index 31323c121dbf7..b45f5bf0271ce 100644
--- a/tests/integration/sparse_pipeline/test_sparse_pipeline.py
+++ b/tests/integration/sparse_pipeline/test_sparse_pipeline.py
@@ -1,4 +1,4 @@
-from typing import Any, Iterable
+from typing import Any
 import os
 
 import pytest
@@ -10,6 +10,7 @@
 from tests import validate_callback
 
 cur_dir = os.path.dirname(os.path.abspath(__file__))
+TOP_K = 3
 
 
 @pytest.fixture(scope='function')
@@ -27,32 +28,28 @@ def docs_to_index(num_docs):
 
 
 class DummyCSRSparseIndexEncoder(Executor):
-    embedding_cls_type = 'scipy_csr'
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.docs = DocumentArray()
-        self.vectors = {}
 
-    @requests(on='index')
+    @requests(on='/index')
     def encode(self, docs: 'DocumentArray', *args, **kwargs) -> Any:
+        for i, doc in enumerate(docs):
+            doc.embedding = sparse.coo_matrix(doc.content)
         self.docs.extend(docs)
-        for i, doc in enumerate(self.docs):
-            doc.embedding = sparse.csr_matrix(doc.content)
-            self.vectors[doc.id] = doc.embedding.getrow(i)
 
-    @requests(on='search')
-    def query(self, parameters, *args, **kwargs):
-        top_k = parameters['top_k']
-        doc = parameters['doc']
-        distances = [item for item in range(0, min(top_k, len(self.docs)))]
-        return [self.docs[:top_k]], np.array([distances])
+    @requests(on='/search')
+    def query(self, docs: 'DocumentArray', parameters, *args, **kwargs):
+        top_k = int(parameters['top_k'])
+        for doc in docs:
+            doc.matches = self.docs[:top_k]
 
 
 def test_sparse_pipeline(mocker, docs_to_index):
     def validate(response):
-        assert len(response.data.docs) == 10
-        for doc in response.data.docs:
+        assert len(response.docs) == 1
+        for doc in response.docs:
+            assert len(doc.matches) == TOP_K
             for i, match in enumerate(doc.matches):
                 assert match.id == docs_to_index[i].id
                 assert isinstance(match.embedding, sparse.coo_matrix)
@@ -63,13 +60,15 @@ def validate(response):
     error_mock = mocker.Mock()
 
     with f:
-        f.index(
+        f.post(
+            on='/index',
             inputs=docs_to_index,
-            on_done=mock,
+            on_error=error_mock,
         )
-        f.search(
+        f.post(
+            on='/search',
             inputs=docs_to_index[0],
-            parameters={'doc': docs_to_index[0], 'top_k': 1},
+            parameters={'top_k': TOP_K},
             on_done=mock,
             on_error=error_mock,
         )