Feature: modification and creation date in the aggregate messages by 1yam · Pull Request #473 · aleph-im/pyaleph
Feature: modification and creation date in the aggregate messages #473

Merged: 6 commits merged on Oct 6, 2023
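This PR adds a with_info query parameter to the aggregates endpoint: when it is set, the response carries an extra "info" mapping with creation and last-update metadata for each aggregate key. A minimal usage sketch follows; the aiohttp call is illustrative only, and the /api/v0/aggregates/{address}.json route, the node URL and the address value are assumptions, not part of this diff.

import asyncio

import aiohttp


async def fetch_aggregates_with_info(base_url: str, address: str) -> dict:
    # Ask a CCN for the aggregates of one address, including the new metadata.
    async with aiohttp.ClientSession() as session:
        async with session.get(
            f"{base_url}/api/v0/aggregates/{address}.json",
            params={"with_info": "true"},
        ) as response:
            response.raise_for_status()
            return await response.json()


if __name__ == "__main__":
    # Placeholder node URL and address; replace with a real CCN and account.
    result = asyncio.run(
        fetch_aggregates_with_info("http://localhost:4024", "0xExampleAddress")
    )
    print(result.get("info", {}))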
77 changes: 64 additions & 13 deletions src/aleph/db/accessors/aggregates.py
@@ -1,6 +1,15 @@
import datetime as dt
from typing import Optional, Iterable, Any, Dict, Tuple, Sequence

from typing import (
Optional,
Iterable,
Any,
Dict,
Tuple,
Sequence,
overload,
Literal,
Union,
)
from sqlalchemy import (
select,
delete,
@@ -22,20 +31,64 @@ def aggregate_exists(session: DbSession, key: str, owner: str) -> bool:
)


AggregateContent = Iterable[Tuple[str, Dict[str, Any]]]
AggregateContentWithInfo = Iterable[Tuple[str, dt.datetime, dt.datetime, str, str]]


@overload
def get_aggregates_by_owner(
session: Any,
owner: str,
with_info: Literal[False],
keys: Optional[Sequence[str]] = None,
) -> AggregateContent:
...


@overload
def get_aggregates_by_owner(
session: Any,
owner: str,
with_info: Literal[True],
keys: Optional[Sequence[str]] = None,
) -> AggregateContentWithInfo:
...


@overload
def get_aggregates_by_owner(
session: DbSession, owner: str, keys: Optional[Sequence[str]] = None
) -> Iterable[Tuple[str, Dict[str, Any]]]:
session, owner: str, with_info: bool, keys: Optional[Sequence[str]] = None
) -> Union[AggregateContent, AggregateContentWithInfo]:
...


def get_aggregates_by_owner(session, owner, with_info, keys=None):
where_clause = AggregateDb.owner == owner
if keys:
where_clause = where_clause & AggregateDb.key.in_(keys)

select_stmt = (
select(AggregateDb.key, AggregateDb.content)
.where(where_clause)
.order_by(AggregateDb.key)
)
return session.execute(select_stmt).all() # type: ignore
if with_info:
query = (
session.query(
AggregateDb.key,
AggregateDb.content,
AggregateDb.creation_datetime.label("created"),
AggregateElementDb.creation_datetime.label("last_updated"),
AggregateDb.last_revision_hash.label("last_update_item_hash"),
AggregateElementDb.item_hash.label("original_item_hash"),
)
.join(
AggregateElementDb,
AggregateDb.last_revision_hash == AggregateElementDb.item_hash,
)
.filter(AggregateDb.owner == owner)
)
else:
query = (
session.query(AggregateDb.key, AggregateDb.content)
.filter(where_clause)
.order_by(AggregateDb.key)
)
return query.all()


def get_aggregate_by_key(
@@ -44,7 +97,6 @@ def get_aggregate_by_key(
key: str,
with_content: bool = True,
) -> Optional[AggregateDb]:

options = []

if not with_content:
@@ -170,7 +222,6 @@ def mark_aggregate_as_dirty(session: DbSession, owner: str, key: str) -> None:


def refresh_aggregate(session: DbSession, owner: str, key: str) -> None:

# Step 1: use a group by to retrieve the aggregate content. This uses a custom
# aggregate function (see 78dd67881db4_jsonb_merge_aggregate.py).
select_merged_aggregate_subquery = (
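The accessor relies on typing.overload together with Literal[False]/Literal[True] so that type checkers infer AggregateContentWithInfo only when with_info is statically known to be true, while the single runtime implementation stays untyped. A self-contained sketch of that pattern, with placeholder names unrelated to the pyaleph models:

from typing import Any, Dict, Iterable, List, Literal, Tuple, Union, overload

Plain = Iterable[Tuple[str, Dict[str, Any]]]
WithInfo = Iterable[Tuple[str, Dict[str, Any], str]]


@overload
def load_aggregates(with_info: Literal[False]) -> Plain:
    ...


@overload
def load_aggregates(with_info: Literal[True]) -> WithInfo:
    ...


@overload
def load_aggregates(with_info: bool) -> Union[Plain, WithInfo]:
    ...


def load_aggregates(with_info: bool = False):
    # Single runtime implementation; the overloads above only guide the type checker.
    rows: List[Tuple[str, Dict[str, Any]]] = [("example_key", {"a": 1})]
    if with_info:
        return [(key, content, "example-item-hash") for key, content in rows]
    return rows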
48 changes: 40 additions & 8 deletions src/aleph/web/controllers/aggregates.py
@@ -1,11 +1,15 @@
import logging
from typing import List, Optional, Any, Dict, Tuple
import datetime as dt
from typing import List, Optional, Any, Dict, Tuple, Literal, Sequence

from aiohttp import web
from pydantic import BaseModel, validator, ValidationError
from sqlalchemy import select

from aleph.db.accessors.aggregates import get_aggregates_by_owner, refresh_aggregate
from aleph.db.accessors.aggregates import (
get_aggregates_by_owner,
refresh_aggregate,
)
from aleph.db.models import AggregateDb
from .utils import LIST_FIELD_SEPARATOR

@@ -17,6 +21,7 @@
class AggregatesQueryParams(BaseModel):
keys: Optional[List[str]] = None
limit: int = DEFAULT_LIMIT
with_info: bool = False

@validator(
"keys",
@@ -33,17 +38,15 @@ async def address_aggregate(request: web.Request) -> web.Response:
TODO: handle filter on a single key, or even subkey.
"""

address = request.match_info["address"]
address: str = request.match_info["address"]

try:
query_params = AggregatesQueryParams.parse_obj(request.query)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(
text=e.json(), content_type="application/json"
)

session_factory = request.app["session_factory"]

with session_factory() as session:
dirty_aggregates = session.execute(
select(AggregateDb.key).where(
@@ -57,9 +60,9 @@ async def address_aggregate(request: web.Request) -> web.Response:
refresh_aggregate(session=session, owner=address, key=key)
session.commit()

aggregates: List[Tuple[str, Dict[str, Any]]] = list(
aggregates = list(
get_aggregates_by_owner(
session=session, owner=address, keys=query_params.keys
session=session, owner=address, with_info=query_params.with_info
)
)

@@ -68,6 +71,35 @@

output = {
"address": address,
"data": {result[0]: result[1] for result in aggregates},
"data": {},
}
info: Dict = {}
data: Dict = {}

for result in aggregates:
data[result[0]] = result[1]
if query_params.with_info:
(
aggregate_key,
content,
created,
last_updated,
original_item_hash,
last_update_item_hash,
) = result

if isinstance(created, dt.datetime):
created = created.isoformat()
if isinstance(last_updated, dt.datetime):
last_updated = last_updated.isoformat()
info[aggregate_key] = {
"created": str(created),
"last_updated": str(last_updated),
"original_item_hash": str(original_item_hash),
"last_update_item_hash": str(last_update_item_hash),
}

output["data"] = data
output["info"] = info

return web.json_response(output)
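With with_info=true the handler now returns, per aggregate key, timestamps for when the aggregate was created and last updated plus the item hashes of the original element and the latest revision, alongside the usual "address" and "data" fields. An illustrative response body (the address, key, hashes and timestamps below are placeholders, not fixture values):

expected_shape = {
    "address": "0xExampleAddress",
    "data": {"profile": {"a": 1, "b": 2}},
    "info": {
        "profile": {
            "created": "2023-09-29T12:00:00+00:00",
            "last_updated": "2023-10-02T08:30:00+00:00",
            "original_item_hash": "abc123",
            "last_update_item_hash": "def456",
        }
    },
}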
32 changes: 22 additions & 10 deletions tests/api/test_aggregates.py
@@ -26,12 +26,17 @@ def make_uri(address: str) -> str:
return AGGREGATES_URI.format(address=address)


async def get_aggregates(api_client, address: str, **params) -> aiohttp.ClientResponse:
async def get_aggregates(
api_client, address: str, with_info: bool, **params
) -> aiohttp.ClientResponse:
params["with_info"] = str(with_info)
return await api_client.get(make_uri(address), params=params)


async def get_aggregates_expect_success(api_client, address: str, **params):
response = await get_aggregates(api_client, address, **params)
async def get_aggregates_expect_success(
api_client, address: str, with_info: bool, **params
):
response = await get_aggregates(api_client, address, with_info, **params)
assert response.status == 200, await response.text()
return await response.json()

@@ -47,7 +52,7 @@ async def test_get_aggregates_no_update(
assert fixture_aggregate_messages # To avoid unused parameter warnings

address = ADDRESS_2
aggregates = await get_aggregates_expect_success(ccn_api_client, address)
aggregates = await get_aggregates_expect_success(ccn_api_client, address, False)

assert aggregates["address"] == address
assert aggregates["data"] == EXPECTED_AGGREGATES[address]
@@ -63,12 +68,17 @@ async def test_get_aggregates(
assert fixture_aggregate_messages # To avoid unused parameter warnings

address = ADDRESS_1
aggregates = await get_aggregates_expect_success(ccn_api_client, address)
aggregates = await get_aggregates_expect_success(ccn_api_client, address, True)

assert address == aggregates["address"]
assert aggregates["data"]["test_key"] == {"a": 1, "b": 2}
assert aggregates["data"]["test_target"] == {"a": 1, "b": 2}
assert aggregates["data"]["test_reference"] == {"a": 1, "b": 2, "c": 3, "d": 4}
assert aggregates["info"]["test_reference"]
assert (
aggregates["info"]["test_reference"]["original_item_hash"]
== fixture_aggregate_messages[1].item_hash
)


@pytest.mark.asyncio
@@ -83,15 +93,15 @@ async def test_get_aggregates_filter_by_key(

address, key = ADDRESS_1, "test_target"
aggregates = await get_aggregates_expect_success(
ccn_api_client, address=address, keys=key
ccn_api_client, address=address, keys=key, with_info=False
)
assert aggregates["address"] == address
assert aggregates["data"][key] == EXPECTED_AGGREGATES[address][key]

# Multiple keys
address, keys = ADDRESS_1, ["test_target", "test_reference"]
aggregates = await get_aggregates_expect_success(
ccn_api_client, address=address, keys=",".join(keys)
ccn_api_client, address=address, keys=",".join(keys), with_info=False
)
assert aggregates["address"] == address
for key in keys:
@@ -114,7 +124,7 @@ async def test_get_aggregates_limit(

address, key = ADDRESS_1, "test_reference"
aggregates = await get_aggregates_expect_success(
ccn_api_client, address=address, keys=key, limit=1
ccn_api_client, address=address, keys=key, limit=1, with_info=False
)
assert aggregates["address"] == address
assert aggregates["data"][key] == {"c": 3, "d": 4}
@@ -131,7 +141,7 @@ async def test_get_aggregates_invalid_address(

invalid_address = "unknown"

response = await get_aggregates(ccn_api_client, invalid_address)
response = await get_aggregates(ccn_api_client, invalid_address, False)
assert response.status == 404


@@ -145,7 +155,9 @@ async def test_get_aggregates_invalid_params(
assert fixture_aggregate_messages # To avoid unused parameter warnings

# A string as limit
response = await get_aggregates(ccn_api_client, ADDRESS_1, limit="abc")
response = await get_aggregates(
ccn_api_client, ADDRESS_1, limit="abc", with_info=False
)
assert response.status == 422
assert response.content_type == "application/json"

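The test helper forwards the flag as str(with_info), so the query string carries "True" or "False"; pydantic's bool validator coerces those strings back to booleans when AggregatesQueryParams is parsed. A quick standalone illustration of that coercion (the model name here is a stand-in, not the project's class):

from pydantic import BaseModel


class Params(BaseModel):
    with_info: bool = False


# Query-string values arrive as plain strings; pydantic coerces common boolean spellings.
assert Params(with_info="True").with_info is True
assert Params(with_info="False").with_info is False
assert Params().with_info is False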
1 change: 1 addition & 0 deletions tests/message_processing/conftest.py
@@ -97,3 +97,4 @@ def message_processor(mocker, mock_config: Config, session_factory: DbSessionFac
mq_conn=mocker.AsyncMock(),
)
return message_processor
