diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 32c6adf..ed6f375 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 + rev: v1.10.1 hooks: - id: mypy additional_dependencies: @@ -12,7 +12,7 @@ repos: - types-tabulate - types-tqdm - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.4.9" + rev: "v0.5.2" hooks: - id: ruff - id: ruff-format diff --git a/CHANGELOG.md b/CHANGELOG.md index af6a77b..c954390 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased] +## [0.4.1] - 2024-07-17 + +### Added + +- OAuth2 support ([#180](https://github.com/stac-utils/stac-asset/pull/180)) +- Retry configuration for HTTP ([#192](https://github.com/stac-utils/stac-asset/pull/192)) + ## [0.4.0] - 2024-06-19 ### Added @@ -165,7 +172,8 @@ Used to be . Initial release. -[unreleased]: https://github.com/stac-utils/stac-asset/compare/v0.4.0...HEAD +[unreleased]: https://github.com/stac-utils/stac-asset/compare/v0.4.1...HEAD +[0.4.1]: [0.4.0]: [0.3.3]: [0.3.2]: diff --git a/pyproject.toml b/pyproject.toml index b414654..78db00b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "stac-asset" -version = "0.4.0" +version = "0.4.1" description = "Read and download STAC assets across platforms and providers" authors = [{ name = "Pete Gadomski", email = "pete.gadomski@gmail.com" }] readme = "README.md" @@ -23,6 +23,8 @@ dependencies = [ "pystac>=1.8.4", "python-dateutil>=2.7.0", "yarl>=1.9.2", + "aiohttp-oauth2-client>=1.0.2", + "aiohttp-retry>=2.8.3", ] [project.optional-dependencies] @@ -40,8 +42,8 @@ dev = [ "pytest-asyncio~=0.21", "pytest-cov>=5.0", "pytest-recording~=0.13.1", - "ruff==0.4.9", - "types-aiofiles~=23.1", + "ruff==0.5.2", + "types-aiofiles~=24.1", "types-python-dateutil~=2.9", "types-tqdm~=4.66.0", 
"types-tabulate~=0.9.0", diff --git a/src/stac_asset/_cli.py b/src/stac_asset/_cli.py index 5e4df16..ff0e161 100644 --- a/src/stac_asset/_cli.py +++ b/src/stac_asset/_cli.py @@ -18,7 +18,11 @@ from . import Config, ErrorStrategy, _functions from .client import Clients -from .config import DEFAULT_S3_MAX_ATTEMPTS, DEFAULT_S3_RETRY_MODE +from .config import ( + DEFAULT_HTTP_MAX_ATTEMPTS, + DEFAULT_S3_MAX_ATTEMPTS, + DEFAULT_S3_RETRY_MODE, +) from .errors import DownloadError from .messages import ( ErrorAssetDownload, @@ -100,6 +104,11 @@ def cli() -> None: help="If downloading via the s3 client, the max number of retries", default=DEFAULT_S3_MAX_ATTEMPTS, ) +@click.option( + "--http-max-attempts", + help="If downloading via the http client, the max number of retries", + default=DEFAULT_HTTP_MAX_ATTEMPTS, +) @click.option( "-k", "--keep", @@ -139,6 +148,7 @@ def download( s3_requester_pays: bool, s3_retry_mode: str, s3_max_attempts: int, + http_max_attempts: int, keep: bool, fail_fast: bool, overwrite: bool, @@ -175,6 +185,7 @@ def download( s3_requester_pays, s3_retry_mode, s3_max_attempts, + http_max_attempts, keep=keep, fail_fast=fail_fast, overwrite=overwrite, @@ -194,6 +205,7 @@ async def download_async( s3_requester_pays: bool, s3_retry_mode: str, s3_max_attempts: int, + http_max_attempts: int, keep: bool, fail_fast: bool, overwrite: bool, @@ -205,6 +217,7 @@ async def download_async( s3_requester_pays=s3_requester_pays, s3_retry_mode=s3_retry_mode, s3_max_attempts=s3_max_attempts, + http_max_attempts=http_max_attempts, error_strategy=ErrorStrategy.KEEP if keep else ErrorStrategy.DELETE, warn=not fail_fast, fail_fast=fail_fast, @@ -392,12 +405,18 @@ class Download: help="If checking via the s3 client, the max number of retries", default=DEFAULT_S3_MAX_ATTEMPTS, ) +@click.option( + "--http-max-attempts", + help="If checking via the http client, the max number of retries", + default=DEFAULT_HTTP_MAX_ATTEMPTS, +) def info( href: Optional[str], alternate_assets: 
List[str], s3_requester_pays: bool, s3_retry_mode: str, s3_max_attempts: int, + http_max_attempts: int, ) -> None: asyncio.run( info_async( @@ -406,6 +425,7 @@ def info( s3_requester_pays=s3_requester_pays, s3_max_attempts=s3_max_attempts, s3_retry_mode=s3_retry_mode, + http_max_attempts=http_max_attempts, ) ) @@ -416,6 +436,7 @@ async def info_async( s3_requester_pays: bool, s3_retry_mode: str, s3_max_attempts: int, + http_max_attempts: int, ) -> None: """Prints information about an item or item collection. @@ -426,6 +447,7 @@ async def info_async( s3_requester_pays=s3_requester_pays, s3_retry_mode=s3_retry_mode, s3_max_attempts=s3_max_attempts, + http_max_attempts=http_max_attempts, ) input_dict = await read_as_dict(href, config) type_ = input_dict.get("type") diff --git a/src/stac_asset/config.py b/src/stac_asset/config.py index 33802db..9a85fae 100644 --- a/src/stac_asset/config.py +++ b/src/stac_asset/config.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import os from dataclasses import dataclass, field from typing import Dict, List, Optional @@ -11,6 +12,7 @@ DEFAULT_S3_RETRY_MODE = "adaptive" DEFAULT_S3_MAX_ATTEMPTS = 10 DEFAULT_HTTP_CLIENT_TIMEOUT = 300 +DEFAULT_HTTP_MAX_ATTEMPTS = 10 @dataclass @@ -63,6 +65,9 @@ class Config: http_client_timeout: Optional[float] = DEFAULT_HTTP_CLIENT_TIMEOUT """Total number of seconds for the whole request.""" + http_max_attempts: int = DEFAULT_HTTP_MAX_ATTEMPTS + """The maximum number of attempts when downloading assets via http.""" + http_check_content_type: bool = True """If true, check the asset's content type against the response from the server.""" @@ -87,6 +92,88 @@ class Config: s3_endpoint_url: Optional[str] = None """Set an optional custom endpoint url for s3.""" + oauth2_grant: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_GRANT") + ) + """OAuth2 grant type.
+ + If a value is provided for this field, + the :py:class:`~stac_asset.http_client.HttpClient` will be configured with + support for OAuth2 access tokens. + Can be configured with the ``OAUTH2_GRANT`` environment variable. + """ + + oauth2_token_url: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_TOKEN_URL") + ) + """OAuth2 token URL. + + Can be configured with the ``OAUTH2_TOKEN_URL`` environment variable. + """ + + oauth2_authorization_url: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_AUTHORIZATION_URL") + ) + """OAuth2 authorization URL. + + Can be configured with the ``OAUTH2_AUTHORIZATION_URL`` environment variable. + """ + + oauth2_device_authorization_url: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_DEVICE_AUTHORIZATION_URL") + ) + """OAuth2 device authorization URL. + + Can be configured with the ``OAUTH2_DEVICE_AUTHORIZATION_URL`` environment variable. + """ + + oauth2_client_id: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_CLIENT_ID") + ) + """OAuth2 client identifier. + + Can be configured with the ``OAUTH2_CLIENT_ID`` environment variable. + """ + + oauth2_client_secret: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_CLIENT_SECRET") + ) + """OAuth2 client secret. + + Can be configured with the ``OAUTH2_CLIENT_SECRET`` environment variable. + """ + + oauth2_pkce: bool = field( + default_factory=lambda: ( + os.getenv("OAUTH2_PKCE", "true").lower() not in ("false", "0") + ) + ) + """OAuth2 Proof Key for Code Exchange. + + Can be configured with the ``OAUTH2_PKCE`` environment variable. + By default, PKCE is enabled. + """ + + oauth2_username: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_USERNAME") + ) + """OAuth2 username for resource owner password credentials grant. + + Can be configured with the ``OAUTH2_USERNAME`` environment variable. + """ + + oauth2_password: Optional[str] = field( + default_factory=lambda: os.getenv("OAUTH2_PASSWORD") + ) + """OAuth2 password for resource owner password credentials grant. + + Can be configured with the ``OAUTH2_PASSWORD`` environment variable.
+ """ + + oauth2_extra: Dict[str, str] = field(default_factory=dict) + """Extra configuration options for the OAuth2 grant. + """ + def validate(self) -> None: """Validates this configuration. diff --git a/src/stac_asset/http_client.py b/src/stac_asset/http_client.py index 608c0e4..3a2a1c6 100644 --- a/src/stac_asset/http_client.py +++ b/src/stac_asset/http_client.py @@ -3,7 +3,11 @@ from types import TracebackType from typing import AsyncIterator, Optional, Type, TypeVar -from aiohttp import ClientSession, ClientTimeout +from aiohttp import ClientResponseError, ClientSession, ClientTimeout +from aiohttp_oauth2_client.client import OAuth2Client +from aiohttp_oauth2_client.models.grant import GrantType +from aiohttp_retry import JitterRetry, RetryClient +from aiohttp_retry.types import ClientType from yarl import URL from . import validate @@ -24,15 +28,100 @@ class HttpClient(Client): @classmethod async def from_config(cls: Type[T], config: Config) -> T: - """Creates the default http client with a vanilla session object.""" + """Creates an HTTP client with an aiohttp session object. + + To use OAuth2 access tokens, configure the + :py:attr:`~stac_asset.Config.oauth2_grant` and the necessary parameters for the + chosen grant type.
+ + OAuth2 device code grant (``urn:ietf:params:oauth:grant-type:device_code`` or ``device_code``): + - :py:attr:`~stac_asset.Config.oauth2_token_url` + - :py:attr:`~stac_asset.Config.oauth2_device_authorization_url` + - :py:attr:`~stac_asset.Config.oauth2_client_id` + - :py:attr:`~stac_asset.Config.oauth2_pkce` + + OAuth2 authorization code grant (``authorization_code``): + - :py:attr:`~stac_asset.Config.oauth2_token_url` + - :py:attr:`~stac_asset.Config.oauth2_authorization_url` + - :py:attr:`~stac_asset.Config.oauth2_client_id` + - :py:attr:`~stac_asset.Config.oauth2_pkce` + + OAuth2 resource owner password credentials grant (``password``): + - :py:attr:`~stac_asset.Config.oauth2_token_url` + - :py:attr:`~stac_asset.Config.oauth2_username` + - :py:attr:`~stac_asset.Config.oauth2_password` + - :py:attr:`~stac_asset.Config.oauth2_client_id` + + OAuth2 client credentials grant (``client_credentials``): + - :py:attr:`~stac_asset.Config.oauth2_token_url` + - :py:attr:`~stac_asset.Config.oauth2_client_id` + - :py:attr:`~stac_asset.Config.oauth2_client_secret` + """ # noqa: E501 # TODO add basic auth timeout = ClientTimeout(total=config.http_client_timeout) - session = ClientSession(timeout=timeout, headers=config.http_headers) + if config.oauth2_grant is not None: + # Exact match: the RFC 8628 grant URN or the short "device_code" alias. + if config.oauth2_grant in (GrantType.DEVICE_CODE, "device_code"): + from aiohttp_oauth2_client.grant.device_code import DeviceCodeGrant + + grant = DeviceCodeGrant( + token_url=config.oauth2_token_url, + device_authorization_url=config.oauth2_device_authorization_url, + client_id=config.oauth2_client_id, + pkce=config.oauth2_pkce, + **config.oauth2_extra, + ) + elif config.oauth2_grant == GrantType.AUTHORIZATION_CODE: + from aiohttp_oauth2_client.grant.authorization_code import ( + AuthorizationCodeGrant, + ) + + grant = AuthorizationCodeGrant( + token_url=config.oauth2_token_url, + authorization_url=config.oauth2_authorization_url, + client_id=config.oauth2_client_id, + pkce=config.oauth2_pkce, +
**config.oauth2_extra, + ) + elif config.oauth2_grant == GrantType.RESOURCE_OWNER_PASSWORD_CREDENTIALS: + from aiohttp_oauth2_client.grant.resource_owner_password_credentials import ( # noqa: E501 + ResourceOwnerPasswordCredentialsGrant, + ) + + grant = ResourceOwnerPasswordCredentialsGrant( + token_url=config.oauth2_token_url, + username=config.oauth2_username, + password=config.oauth2_password, + client_id=config.oauth2_client_id, + **config.oauth2_extra, + ) + elif config.oauth2_grant == GrantType.CLIENT_CREDENTIALS: + from aiohttp_oauth2_client.grant.client_credentials import ( + ClientCredentialsGrant, + ) + + grant = ClientCredentialsGrant( + token_url=config.oauth2_token_url, + client_id=config.oauth2_client_id, + client_secret=config.oauth2_client_secret, + **config.oauth2_extra, + ) + else: + raise ValueError(f"Unknown grant type: {config.oauth2_grant!r}") + session = OAuth2Client(grant, timeout=timeout, headers=config.http_headers) + else: + session = ClientSession(timeout=timeout, headers=config.http_headers) + session = RetryClient( + client_session=session, + retry_options=JitterRetry( + attempts=config.http_max_attempts, exceptions={ClientResponseError} + ), + ) return cls(session, config.http_check_content_type) def __init__( self, - session: ClientSession, + session: ClientType, check_content_type: bool, ) -> None: super().__init__() diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 47f2062..09ca244 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -1,4 +1,11 @@ import pytest +from aiohttp_oauth2_client.client import OAuth2Client +from aiohttp_oauth2_client.grant.authorization_code import AuthorizationCodeGrant +from aiohttp_oauth2_client.grant.client_credentials import ClientCredentialsGrant +from aiohttp_oauth2_client.grant.device_code import DeviceCodeGrant +from aiohttp_oauth2_client.grant.resource_owner_password_credentials import ( + ResourceOwnerPasswordCredentialsGrant, +) from stac_asset import Config, HttpClient
pytestmark = [ @@ -14,4 +21,60 @@ async def test_href_exists() -> None: async def test_default_http_timeout() -> None: async with await HttpClient.from_config(Config(http_client_timeout=42)) as client: - assert client.session.timeout.total == 42 + assert client.session._client.timeout.total == 42 + + +async def test_oauth2_device_code_config() -> None: + config = Config( + oauth2_grant="device_code", + oauth2_token_url="https://example.com/token", + oauth2_device_authorization_url="https://example.com/auth/device", + oauth2_client_id="public", + ) + async with await HttpClient.from_config(config) as client: + assert isinstance(client.session, OAuth2Client) + assert isinstance(client.session.grant, DeviceCodeGrant) + + +async def test_oauth2_authorization_code_config() -> None: + config = Config( + oauth2_grant="authorization_code", + oauth2_token_url="https://example.com/token", + oauth2_authorization_url="https://example.com/auth", + oauth2_client_id="public", + oauth2_pkce=False, + ) + async with await HttpClient.from_config(config) as client: + assert isinstance(client.session, OAuth2Client) + assert isinstance(client.session.grant, AuthorizationCodeGrant) + + +async def test_oauth2_password_config() -> None: + config = Config( + oauth2_grant="password", + oauth2_token_url="https://example.com/token", + oauth2_client_id="public", + oauth2_username="user", + oauth2_password="secret", + ) + async with await HttpClient.from_config(config) as client: + assert isinstance(client.session, OAuth2Client) + assert isinstance(client.session.grant, ResourceOwnerPasswordCredentialsGrant) + + +async def test_oauth2_client_credentials_config() -> None: + config = Config( + oauth2_grant="client_credentials", + oauth2_token_url="https://example.com/token", + oauth2_client_id="my-client", + oauth2_client_secret="secret", + ) + async with await HttpClient.from_config(config) as client: + assert isinstance(client.session, OAuth2Client) + assert isinstance(client.session.grant,
ClientCredentialsGrant) + + +@pytest.mark.parametrize("grant", ["", "code", "unknown"]) +async def test_oauth2_unknown_grant(grant: str) -> None: + with pytest.raises(ValueError, match="Unknown grant type"): + await HttpClient.from_config(Config(oauth2_grant=grant))