Detect if mysql and sqlite support row_number by emontnemery · Pull Request #57475 · home-assistant/core · GitHub

Detect if mysql and sqlite support row_number #57475


Merged · 2 commits · Oct 12, 2021
2 changes: 2 additions & 0 deletions homeassistant/components/recorder/__init__.py
@@ -413,6 +413,7 @@ def __init__(
        self.async_migration_event = asyncio.Event()
        self.migration_in_progress = False
        self._queue_watcher = None
        self._db_supports_row_number = True

        self.enabled = True
@@ -972,6 +973,7 @@ def _setup_connection(self):
        def setup_recorder_connection(dbapi_connection, connection_record):
            """Dbapi specific connection settings."""
            setup_connection_for_dialect(
                self,
                self.engine.dialect.name,
                dbapi_connection,
                not self._completed_first_database_setup,
            )
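Note: `_db_supports_row_number` defaults to True and is cleared during connection setup (see util.py below) when the server is too old. It gates the ROW_NUMBER() window function behind QUERY_STATISTICS_SUMMARY_SUM. As a rough sketch — table and column names are assumptions based on the recorder models, not copied from the PR — the compiled query looks like:

    # Sketch only: approximately what the QUERY_STATISTICS_SUMMARY_SUM
    # SQLAlchemy columns compile to; names assumed from the recorder models.
    ROW_NUMBER_SQL = """
    SELECT metadata_id, start, last_reset, state, sum,
           ROW_NUMBER() OVER (PARTITION BY metadata_id ORDER BY start DESC) AS rownum
    FROM statistics_short_term
    WHERE start >= :start_time AND start < :end_time
    """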
103 changes: 77 additions & 26 deletions homeassistant/components/recorder/statistics.py
@@ -88,6 +88,13 @@
.label("rownum"),
]

QUERY_STATISTICS_SUMMARY_SUM_LEGACY = [
StatisticsShortTerm.metadata_id,
StatisticsShortTerm.last_reset,
StatisticsShortTerm.state,
StatisticsShortTerm.sum,
]

QUERY_STATISTIC_META = [
StatisticsMeta.id,
StatisticsMeta.statistic_id,
@@ -274,37 +281,81 @@ def compile_hourly_statistics(
    }

    # Get last hour's last sum
    if instance._db_supports_row_number:  # pylint: disable=[protected-access]
        subquery = (
            session.query(*QUERY_STATISTICS_SUMMARY_SUM)
            .filter(StatisticsShortTerm.start >= bindparam("start_time"))
            .filter(StatisticsShortTerm.start < bindparam("end_time"))
            .subquery()
        )
        query = (
            session.query(subquery)
            .filter(subquery.c.rownum == 1)
            .order_by(subquery.c.metadata_id)
        )
        stats = execute(query.params(start_time=start_time, end_time=end_time))

        if stats:
            for stat in stats:
                metadata_id, start, last_reset, state, _sum, _ = stat
                if metadata_id in summary:
                    summary[metadata_id].update(
                        {
                            "last_reset": process_timestamp(last_reset),
                            "state": state,
                            "sum": _sum,
                        }
                    )
                else:
                    summary[metadata_id] = {
                        "start": start_time,
                        "last_reset": process_timestamp(last_reset),
                        "state": state,
                        "sum": _sum,
                    }
    else:
        baked_query = instance.hass.data[STATISTICS_SHORT_TERM_BAKERY](
            lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
        )

        baked_query += lambda q: q.filter(
            StatisticsShortTerm.start >= bindparam("start_time")
        )
        baked_query += lambda q: q.filter(
            StatisticsShortTerm.start < bindparam("end_time")
        )
        baked_query += lambda q: q.order_by(
            StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
        )

        stats = execute(
            baked_query(session).params(start_time=start_time, end_time=end_time)
        )

        if stats:
            for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]):  # type: ignore
                (
                    metadata_id,
                    last_reset,
                    state,
                    _sum,
                ) = next(group)
                if metadata_id in summary:
                    summary[metadata_id].update(
                        {
                            "last_reset": process_timestamp(last_reset),
                            "state": state,
                            "sum": _sum,
                        }
                    )
                else:
                    summary[metadata_id] = {
                        "start": start_time,
                        "last_reset": process_timestamp(last_reset),
                        "state": state,
                        "sum": _sum,
                    }
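The fallback emulates `rownum == 1` in plain Python: QUERY_STATISTICS_SUMMARY_SUM_LEGACY deliberately omits the window-function column, the baked query orders by (metadata_id, start DESC), and `next(group)` takes the newest row per metadata_id. A minimal, self-contained illustration of that groupby trick, using hypothetical data rather than anything from the PR:

    from itertools import groupby

    # Rows pre-sorted by (metadata_id, start DESC), as the legacy query returns them.
    rows = [
        (1, "2021-10-12T10:55"),  # newest row for metadata_id 1
        (1, "2021-10-12T10:50"),
        (2, "2021-10-12T10:55"),  # newest row for metadata_id 2
    ]
    # next(group) yields the first -- i.e. newest -- row of each group, the same
    # row the window-function branch selects with rownum == 1.
    latest = {mid: next(group) for mid, group in groupby(rows, lambda row: row[0])}
    assert latest == {1: (1, "2021-10-12T10:55"), 2: (2, "2021-10-12T10:55")}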

    # Insert compiled hourly statistics in the database
    for metadata_id, stat in summary.items():
28 changes: 27 additions & 1 deletion homeassistant/components/recorder/util.py
@@ -266,7 +266,18 @@ def execute_on_connection(dbapi_connection, statement):
    cursor.close()


def query_on_connection(dbapi_connection, statement):
    """Execute a single statement with a dbapi connection and return the result."""
    cursor = dbapi_connection.cursor()
    cursor.execute(statement)
    result = cursor.fetchall()
    cursor.close()
    return result


def setup_connection_for_dialect(
    instance, dialect_name, dbapi_connection, first_connection
):
    """Execute statements needed for dialect connection."""
    # Returns False if the connection needs to be set up
    # on the next connection, returns True if the connection
@@ -280,6 +291,13 @@ def setup_connection_for_dialect(
            # WAL mode only needs to be set up once
            # instead of every time we open the sqlite connection
            # as it's persistent and isn't free to call every time.
            result = query_on_connection(dbapi_connection, "SELECT sqlite_version()")
            version = result[0][0]
            major, minor, _patch = version.split(".", 2)
            if int(major) == 3 and int(minor) < 25:
                instance._db_supports_row_number = (  # pylint: disable=[protected-access]
                    False
                )

        # approximately 8MiB of memory
        execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192")
@@ -289,6 +307,14 @@ def setup_connection_for_dialect(

if dialect_name == "mysql":
execute_on_connection(dbapi_connection, "SET session wait_timeout=28800")
if first_connection:
result = query_on_connection(dbapi_connection, "SELECT VERSION()")
version = result[0][0]
major, minor, _patch = version.split(".", 2)
if int(major) == 5 and int(minor) < 8:
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
False
)
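Both dialects share the same `version.split(".", 2)` parse: window functions arrived in SQLite 3.25 and MySQL 8.0, and a MariaDB server reports a 10.x version string, so it never trips the `major == 5` guard. A quick sketch of the two checks as standalone functions — restated here for illustration, not part of the PR:

    def sqlite_supports_row_number(version: str) -> bool:
        major, minor, _patch = version.split(".", 2)
        return not (int(major) == 3 and int(minor) < 25)

    def mysql_supports_row_number(version: str) -> bool:
        major, minor, _patch = version.split(".", 2)
        return not (int(major) == 5 and int(minor) < 8)

    assert sqlite_supports_row_number("3.24.0") is False  # window functions need 3.25+
    assert mysql_supports_row_number("5.7.0") is False    # window functions need 8.0+
    assert mysql_supports_row_number("10.0.0") is True    # MariaDB-style version string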


def end_incomplete_runs(session, start_time):
80 changes: 62 additions & 18 deletions tests/components/recorder/test_util.py
@@ -122,44 +122,88 @@ async def test_last_run_was_recently_clean(hass):
    )


@pytest.mark.parametrize(
    "mysql_version, db_supports_row_number",
    [
        ("10.0.0", True),
        ("5.8.0", True),
        ("5.7.0", False),
    ],
)
def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number):
    """Test setting up the connection for a mysql dialect."""
    instance_mock = MagicMock(_db_supports_row_number=True)
    execute_args = []
    close_mock = MagicMock()

    def execute_mock(statement):
        nonlocal execute_args
        execute_args.append(statement)

    def fetchall_mock():
        nonlocal execute_args
        if execute_args[-1] == "SELECT VERSION()":
            return [[mysql_version]]
        return None

    def _make_cursor_mock(*_):
        return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)

    dbapi_connection = MagicMock(cursor=_make_cursor_mock)

    util.setup_connection_for_dialect(instance_mock, "mysql", dbapi_connection, True)

    assert len(execute_args) == 2
    assert execute_args[0] == "SET session wait_timeout=28800"
    assert execute_args[1] == "SELECT VERSION()"

    assert instance_mock._db_supports_row_number == db_supports_row_number

@pytest.mark.parametrize(
    "sqlite_version, db_supports_row_number",
    [
        ("3.25.0", True),
        ("3.24.0", False),
    ],
)
def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number):
    """Test setting up the connection for a sqlite dialect."""
    instance_mock = MagicMock(_db_supports_row_number=True)
    execute_args = []
    close_mock = MagicMock()

    def execute_mock(statement):
        nonlocal execute_args
        execute_args.append(statement)

    def fetchall_mock():
        nonlocal execute_args
        if execute_args[-1] == "SELECT sqlite_version()":
            return [[sqlite_version]]
        return None

    def _make_cursor_mock(*_):
        return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)

    dbapi_connection = MagicMock(cursor=_make_cursor_mock)

    util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, True)

    assert len(execute_args) == 4
    assert execute_args[0] == "PRAGMA journal_mode=WAL"
    assert execute_args[1] == "SELECT sqlite_version()"
    assert execute_args[2] == "PRAGMA cache_size = -8192"
    assert execute_args[3] == "PRAGMA foreign_keys=ON"

    execute_args = []
    util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, False)

    assert len(execute_args) == 2
    assert execute_args[0] == "PRAGMA cache_size = -8192"
    assert execute_args[1] == "PRAGMA foreign_keys=ON"

    assert instance_mock._db_supports_row_number == db_supports_row_number
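Both tests fake just enough of the DBAPI cursor protocol — execute() records statements, fetchall() answers only the version probe — for setup_connection_for_dialect to run unmodified. The same pattern in isolation, as a sketch rather than code from the PR:

    from unittest.mock import MagicMock

    executed: list[str] = []

    def _execute(statement):
        executed.append(statement)

    def _fetchall():
        # Answer only the version probe; other statements return no rows.
        return [["3.25.0"]] if executed[-1] == "SELECT sqlite_version()" else None

    cursor = MagicMock(execute=_execute, close=MagicMock(), fetchall=_fetchall)
    dbapi_connection = MagicMock(cursor=lambda *_: cursor)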


def test_basic_sanity_check(hass_recorder):
13 changes: 12 additions & 1 deletion tests/components/sensor/test_recorder.py
@@ -1806,7 +1806,13 @@ def test_compile_hourly_statistics_changing_statistics(
    assert "Error while processing event StatisticsTask" not in caplog.text


@pytest.mark.parametrize(
    "db_supports_row_number,in_log,not_in_log",
    [(True, "row_number", None), (False, None, "row_number")],
)
def test_compile_statistics_hourly_summary(
    hass_recorder, caplog, db_supports_row_number, in_log, not_in_log
):
"""Test compiling hourly statistics."""
zero = dt_util.utcnow()
zero = zero.replace(minute=0, second=0, microsecond=0)
Expand All @@ -1815,6 +1821,7 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
zero += timedelta(hours=1)
hass = hass_recorder()
recorder = hass.data[DATA_INSTANCE]
recorder._db_supports_row_number = db_supports_row_number
setup_component(hass, "sensor", {})
attributes = {
"device_class": None,
@@ -2052,6 +2059,10 @@ def _sum(seq, last_state, last_sum):
        end += timedelta(hours=1)
    assert stats == expected_stats
    assert "Error while processing event StatisticsTask" not in caplog.text
    if in_log:
        assert in_log in caplog.text
    if not_in_log:
        assert not_in_log not in caplog.text


def record_states(hass, zero, entity_id, attributes, seq=None):