diff --git a/homeassistant/components/sensor/sql.py b/homeassistant/components/sensor/sql.py
index 395c082f9d2fe9..4edb13e0416f24 100644
--- a/homeassistant/components/sensor/sql.py
+++ b/homeassistant/components/sensor/sql.py
@@ -24,9 +24,17 @@
 CONF_QUERY = 'query'
 CONF_COLUMN_NAME = 'column'
 
+
+def validate_sql_select(value):
+    """Validate that value is a SQL SELECT query."""
+    if not value.lstrip().lower().startswith('select'):
+        raise vol.Invalid('Only SELECT queries allowed')
+    return value
+
+
 _QUERY_SCHEME = vol.Schema({
     vol.Required(CONF_NAME): cv.string,
-    vol.Required(CONF_QUERY): cv.string,
+    vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
     vol.Required(CONF_COLUMN_NAME): cv.string,
     vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
     vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
@@ -129,14 +137,19 @@ def update(self):
         finally:
             sess.close()
 
+        # DB-API drivers (sqlite3 among them) report rowcount == -1 for
+        # SELECT statements, so fetch the rows to reliably detect "empty".
+        rows = result.fetchall() if result.returns_rows else []
+        if not rows:
+            _LOGGER.warning("%s returned no results", self._query)
+            self._state = None
+            self._attributes = {}
+            return
+
-        for res in result:
-            _LOGGER.debug(res.items())
+        for res in rows:
+            _LOGGER.debug("result = %s", res.items())
             data = res[self._column_name]
-            self._attributes = {k: str(v) for k, v in res.items()}
-
-        if data is None:
-            _LOGGER.error("%s returned no results", self._query)
-            return
+            self._attributes = {k: v for k, v in res.items()}
 
         if self._template is not None:
             self._state = self._template.async_render_with_possible_json_value(
diff --git a/tests/components/sensor/test_sql.py b/tests/components/sensor/test_sql.py
index ebf2d749e67b53..5e639b9f3386f3 100644
--- a/tests/components/sensor/test_sql.py
+++ b/tests/components/sensor/test_sql.py
@@ -1,7 +1,11 @@
 """The test for the sql sensor platform."""
 import unittest
+import pytest
+import voluptuous as vol
 
+from homeassistant.components.sensor.sql import validate_sql_select
 from homeassistant.setup import setup_component
+from homeassistant.const import STATE_UNKNOWN
 
 from tests.common import get_test_home_assistant
 
@@ -35,3 +39,45 @@ def test_query(self):
 
         state = self.hass.states.get('sensor.count_tables')
         self.assertEqual(state.state, '0')
+
+    def test_invalid_query(self):
+        """Test the SQL sensor for invalid queries."""
+        with pytest.raises(vol.Invalid):
+            validate_sql_select("DROP TABLE *")
+
+        config = {
+            'sensor': {
+                'platform': 'sql',
+                'db_url': 'sqlite://',
+                'queries': [{
+                    'name': 'count_tables',
+                    'query': 'SELECT * value FROM sqlite_master;',
+                    'column': 'value',
+                }]
+            }
+        }
+
+        assert setup_component(self.hass, 'sensor', config)
+
+        state = self.hass.states.get('sensor.count_tables')
+        self.assertEqual(state.state, STATE_UNKNOWN)
+
+    def test_query_no_results(self):
+        """Test the SQL sensor when a SELECT returns zero rows."""
+        config = {
+            'sensor': {
+                'platform': 'sql',
+                'db_url': 'sqlite://',
+                'queries': [{
+                    'name': 'no_rows',
+                    'query': "SELECT name AS value FROM sqlite_master "
+                             "WHERE name = 'does_not_exist';",
+                    'column': 'value',
+                }]
+            }
+        }
+
+        assert setup_component(self.hass, 'sensor', config)
+
+        state = self.hass.states.get('sensor.no_rows')
+        self.assertEqual(state.state, STATE_UNKNOWN)