Add AI task structured output by allenporter · Pull Request #148083 · home-assistant/core · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add AI task structured output #148083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions homeassistant/components/ai_task/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Integration to offer AI tasks to Home Assistant."""

import logging
from typing import Any

import voluptuous as vol

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import (
HassJobType,
HomeAssistant,
Expand All @@ -14,12 +15,14 @@
SupportsResponse,
callback,
)
from homeassistant.helpers import config_validation as cv, storage
from homeassistant.helpers import config_validation as cv, selector, storage
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType

from .const import (
ATTR_INSTRUCTIONS,
ATTR_REQUIRED,
ATTR_STRUCTURE,
ATTR_TASK_NAME,
DATA_COMPONENT,
DATA_PREFERENCES,
Expand Down Expand Up @@ -47,6 +50,27 @@

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

# Schema for one entry of the ``structure`` service field. Each entry describes
# a single output field: a selector config (required), plus an optional
# human-readable description and an optional ``required`` flag.
STRUCTURE_FIELD_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_DESCRIPTION): str,
        # Must be a real bool; strings such as "True" are not coerced.
        vol.Optional(ATTR_REQUIRED): bool,
        # The selector config is passed through selector.validate_selector here;
        # _validate_structure_fields later builds the selector itself from it
        # with selector.selector().
        vol.Required(CONF_SELECTOR): selector.validate_selector,
    }
)


def _validate_structure_fields(value: dict[str, Any]) -> vol.Schema:
    """Build a voluptuous Schema from a structure field mapping.

    Each key of ``value`` names an output field. Each value is a field spec
    holding a selector config under ``CONF_SELECTOR``, plus optional
    ``CONF_DESCRIPTION`` and ``ATTR_REQUIRED`` entries. Fields flagged as
    required become ``vol.Required`` keys and all others become
    ``vol.Optional``. The resulting schema rejects keys that are not listed
    (``vol.PREVENT_EXTRA``).

    Raises:
        vol.Invalid: if ``value`` is not a dictionary.

    """
    # NOTE(review): Codecov reported this branch as not covered by tests. When
    # reached through the service schema, vol.Schema({str: ...}) runs first and
    # already rejects non-dict input, so this only guards direct callers.
    if not isinstance(value, dict):
        raise vol.Invalid("Structure must be a dictionary")

    def _field_key(name: str, spec: dict[str, Any]) -> vol.Marker:
        """Return the Required/Optional marker for a single field."""
        marker_cls = vol.Required if spec.get(ATTR_REQUIRED, False) else vol.Optional
        return marker_cls(name, description=spec.get(CONF_DESCRIPTION))

    return vol.Schema(
        {
            _field_key(name, spec): selector.selector(spec[CONF_SELECTOR])
            for name, spec in value.items()
        },
        extra=vol.PREVENT_EXTRA,
    )


async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
Expand All @@ -64,6 +88,10 @@
vol.Required(ATTR_TASK_NAME): cv.string,
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_STRUCTURE): vol.All(
vol.Schema({str: STRUCTURE_FIELD_SCHEMA}),
_validate_structure_fields,
),
}
),
supports_response=SupportsResponse.ONLY,
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/ai_task/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

# Field names of the generate_data service call.
ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
# Optional service field carrying the structured-output definition.
ATTR_STRUCTURE: Final = "structure"
# Per-field flag inside a structure entry; when true the field becomes vol.Required.
ATTR_REQUIRED: Final = "required"

DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks."
Expand Down
6 changes: 6 additions & 0 deletions homeassistant/components/ai_task/services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ generate_data:
domain: ai_task
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_DATA
structure:
advanced: true
required: false
example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": true }, "age": { "selector": { "number": }, "description": "Age of the user" } }'
selector:
object:
4 changes: 4 additions & 0 deletions homeassistant/components/ai_task/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
"entity_id": {
"name": "Entity ID",
"description": "Entity ID to run the task on. If not provided, the preferred entity will be used."
},
"structure": {
"name": "Structured output",
"description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field."
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions homeassistant/components/ai_task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
from typing import Any

import voluptuous as vol

from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError

Expand All @@ -17,6 +19,7 @@ async def async_generate_data(
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
Expand All @@ -38,6 +41,7 @@ async def async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
)
)

Expand All @@ -52,6 +56,9 @@ class GenDataTask:
instructions: str
"""Instructions on what needs to be done."""

structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""

def __str__(self) -> str:
"""Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>"
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
def _base_components() -> dict[str, ModuleType]:
"""Return a cached lookup of base components."""
from homeassistant.components import ( # noqa: PLC0415
ai_task,
alarm_control_panel,
assist_satellite,
calendar,
Expand All @@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
)

return {
"ai_task": ai_task,
"alarm_control_panel": alarm_control_panel,
"assist_satellite": assist_satellite,
"calendar": calendar,
Expand Down
12 changes: 10 additions & 2 deletions tests/components/ai_task/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test helpers for AI Task integration."""

import json

import pytest

from homeassistant.components.ai_task import (
Expand Down Expand Up @@ -45,12 +47,18 @@ async def _async_generate_data(
) -> GenDataTaskResult:
"""Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task)
if task.structure is not None:
data = {"name": "Tracy Chen", "age": 30}
data_chat_log = json.dumps(data)
else:
data = "Mock result"
data_chat_log = data
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "Mock result")
AssistantContent(self.entity_id, data_chat_log)
)
return GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data="Mock result",
data=data,
)


Expand Down
39 changes: 39 additions & 0 deletions tests/components/ai_task/test_entity.py
4832
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for the AI Task entity model."""

from freezegun import freeze_time
import voluptuous as vol

from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import selector

from .conftest import TEST_ENTITY_ID, MockAITaskEntity

Expand Down Expand Up @@ -37,3 +39,40 @@ async def test_state_generate_data(
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt"


async def test_generate_structured_data(
    hass: HomeAssistant,
    init_components: None,
    mock_config_entry: MockConfigEntry,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Test the entity can generate structured data.

    Verifies that the schema passed to ``async_generate_data`` is forwarded to
    the entity unchanged and that the entity's structured result is returned.
    """
    structure = vol.Schema(
        {
            vol.Required("name"): selector.TextSelector(),
            vol.Optional("age"): selector.NumberSelector(
                config=selector.NumberSelectorConfig(
                    min=0,
                    max=120,
                )
            ),
        }
    )
    result = await async_generate_data(
        hass,
        task_name="Test task",
        entity_id=TEST_ENTITY_ID,
        instructions="Please generate a profile for a new user",
        structure=structure,
    )
    # Arbitrary data returned by the mock entity (not determined by above schema in test)
    assert result.data == {
        "name": "Tracy Chen",
        "age": 30,
    }

    assert mock_ai_task_entity.mock_generate_data_tasks
    task = mock_ai_task_entity.mock_generate_data_tasks[0]
    assert task.instructions == "Please generate a profile for a new user"
    # The exact schema object must reach the entity; a plain truthiness or
    # isinstance check would also pass if a different schema were substituted.
    assert task.structure is structure
Loading
0