diff --git a/eodag/api/product/_product.py b/eodag/api/product/_product.py index da7c63308..55b34216c 100644 --- a/eodag/api/product/_product.py +++ b/eodag/api/product/_product.py @@ -28,12 +28,16 @@ import geojson import orjson import requests +from boto3 import Session +from boto3.resources.base import ServiceResource from pystac import Item -from requests import RequestException +from requests import PreparedRequest, RequestException from requests.auth import AuthBase +from requests.structures import CaseInsensitiveDict from shapely import geometry from shapely.errors import ShapelyError +from eodag.plugins.authentication.aws_auth import AwsAuth from eodag.types.queryables import CommonStacMetadata from eodag.types.stac_metadata import create_stac_metadata_model @@ -73,7 +77,12 @@ _import_stac_item_from_known_provider, _import_stac_item_from_unknown_provider, ) -from eodag.utils.exceptions import DownloadError, MisconfiguredError, ValidationError +from eodag.utils.exceptions import ( + DatasetCreationError, + DownloadError, + MisconfiguredError, + ValidationError, +) from eodag.utils.repr import dict_to_html_table if TYPE_CHECKING: @@ -622,6 +631,65 @@ def stream_download( **kwargs, ) + def get_storage_options( + self, + asset_key: Optional[str] = None, + ) -> dict[str, Any]: + """ + Get fsspec storage_options keyword arguments + """ + auth = self.downloader_auth.authenticate() if self.downloader_auth else None + if self.downloader is None: + return {} + + # default url and headers + try: + url = self.assets[asset_key]["href"] if asset_key else self.location + except KeyError as e: + raise DatasetCreationError(f"{asset_key} not found in {self} assets") from e + headers = {**USER_AGENT} + + if isinstance(auth, ServiceResource) and isinstance( + self.downloader_auth, AwsAuth + ): + auth_kwargs: dict[str, Any] = dict() + # AwsAuth + if s3_endpoint := getattr(self.downloader_auth.config, "s3_endpoint", None): + auth_kwargs["client_kwargs"] = {"endpoint_url": s3_endpoint} + if creds := cast( + Session, self.downloader_auth.s3_session + ).get_credentials(): + auth_kwargs["key"] = creds.access_key + auth_kwargs["secret"] = creds.secret_key + if creds.token: + auth_kwargs["token"] = creds.token + if requester_pays := getattr( + self.downloader_auth.config, "requester_pays", False + ): + auth_kwargs["requester_pays"] = requester_pays + else: + auth_kwargs["anon"] = True + return {"path": url, **auth_kwargs} + + if isinstance(auth, AuthBase): + # update url and headers with auth + req = PreparedRequest() + req.url = url + req.headers = CaseInsensitiveDict(headers) + if auth: + auth(req) + return {"path": req.url, "headers": dict(req.headers)} + + return {"path": url} + + def request_asset( + self, + url: str, + ) -> requests.Response: + """Perform a GET request to the given URL using product's authentication headers.""" + headers = self.get_storage_options().get("headers", {}) + return requests.get(url, headers=headers, stream=True) + def _init_progress_bar( self, progress_callback: Optional[ProgressCallback], diff --git a/eodag/plugins/authentication/openid_connect.py b/eodag/plugins/authentication/openid_connect.py index ae97d64f4..711863c99 100644 --- a/eodag/plugins/authentication/openid_connect.py +++ b/eodag/plugins/authentication/openid_connect.py @@ -17,11 +17,13 @@ # limitations under the License. from __future__ import annotations +import base64 import logging import re import string from datetime import datetime, timedelta, timezone from random import SystemRandom +from threading import Lock from typing import TYPE_CHECKING, Any, Optional from urllib.parse import parse_qs, urlparse @@ -76,6 +78,7 @@ class OIDCRefreshTokenBase(Authentication): def __init__(self, provider: str, config: PluginConfig) -> None: super(OIDCRefreshTokenBase, self).__init__(provider, config) + self._auth_lock = Lock() self.access_token = "" self.access_token_expiration = datetime.min.replace(tzinfo=timezone.utc) @@ -252,8 +255,9 @@ class OIDCAuthorizationCodeFlowAuth(OIDCRefreshTokenBase): * :attr:`~eodag.config.PluginConfig.token_key` (``str``): The key pointing to the token in the json response to the POST request to the token server * :attr:`~eodag.config.PluginConfig.token_provision` (``str``) (**mandatory**): One of - ``qs`` or ``header``. This is how the token obtained will be used to authenticate the - user on protected requests. If ``qs`` is chosen, then ``token_qs_key`` is mandatory + ``qs``, ``header`` or ``basic``. This is how the token obtained will be used to authenticate the + user on protected requests. If ``qs`` is chosen, then ``token_qs_key`` is mandatory. If ``basic`` is chosen, + the token is used as password with username "anonymous". * :attr:`~eodag.config.PluginConfig.login_form_xpath` (``str``) (**mandatory**): The xpath to the HTML form element representing the user login form * :attr:`~eodag.config.PluginConfig.authentication_uri_source` (``str``) (**mandatory**): Where @@ -301,9 +305,13 @@ def __init__(self, provider: str, config: PluginConfig) -> None: def validate_config_credentials(self) -> None: """Validate configured credentials""" super(OIDCAuthorizationCodeFlowAuth, self).validate_config_credentials() - if getattr(self.config, "token_provision", None) not in ("qs", "header"): + if getattr(self.config, "token_provision", None) not in ( + "qs", + "header", + "basic", + ): raise MisconfiguredError( - 'Provider config parameter "token_provision" must be one of "qs" or "header"' + 'Provider config parameter "token_provision" must be one of "qs", "header", or "basic"' ) if self.config.token_provision == "qs" and not getattr( self.config, "token_qs_key", "" @@ -315,12 +323,14 @@ def validate_config_credentials(self) -> None: def authenticate(self) -> CodeAuthorizedAuth: """Authenticate""" - self._get_access_token() + with self._auth_lock: + self._get_access_token() return CodeAuthorizedAuth( self.access_token, self.config.token_provision, key=getattr(self.config, "token_qs_key", None), + refresh_token=self.refresh_token, ) def _request_new_token(self) -> dict[str, str]: @@ -583,10 +593,17 @@ def compute_state() -> str: class CodeAuthorizedAuth(AuthBase): """CodeAuthorizedAuth custom authentication class to be used with requests module""" - def __init__(self, token: str, where: str, key: Optional[str] = None) -> None: + def __init__( + self, + token: str, + where: str, + key: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> None: self.token = token self.where = where self.key = key + self.refresh_token = refresh_token def __call__(self, request: PreparedRequest) -> PreparedRequest: """Perform the actual authentication""" @@ -601,6 +618,13 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: elif self.where == "header": request.headers["Authorization"] = "Bearer {}".format(self.token) + + if self.where == "basic" and self.refresh_token is not None: + auth_str = base64.b64encode( + f"anonymous:{self.refresh_token}".encode() + ).decode() + request.headers["Authorization"] = f"Basic {auth_str}" + logger.debug( re.sub( r"'Bearer [^']+'", diff --git a/eodag/utils/exceptions.py b/eodag/utils/exceptions.py index 6fec02b18..bc2080a46 100644 --- a/eodag/utils/exceptions.py +++ b/eodag/utils/exceptions.py @@ -140,3 +140,7 @@ def __init__( f"Request timeout {timeout_msg} for URL {url}" if url else str(exception) ) super().__init__(message) + + +class DatasetCreationError(EodagError): + """An error indicating that :class:`xarray.Dataset` or :class:`eodag_cube.types.XarrayDict` could not be created""" diff --git a/tests/context.py b/tests/context.py index ea7a2345a..9c2a28bde 100644 --- a/tests/context.py +++ b/tests/context.py @@ -58,6 +58,8 @@ from eodag.plugins.authentication.aws_auth import AwsAuth from eodag.plugins.authentication.header import HeaderAuth from eodag.plugins.authentication.openid_connect import CodeAuthorizedAuth +from eodag.plugins.authentication.header import HTTPHeaderAuth +from eodag.plugins.authentication.qsauth import HttpQueryStringAuth from eodag.plugins.base import PluginTopic from eodag.plugins.crunch.filter_date import FilterDate from eodag.plugins.crunch.filter_latest_tpl_name import FilterLatestByName @@ -136,6 +138,7 @@ UnsupportedProvider, ValidationError, InvalidDataError, + DatasetCreationError, ) from eodag.utils.stac_reader import fetch_stac_items, _TextOpener from tests import TEST_RESOURCES_PATH diff --git a/tests/units/test_auth_plugins.py b/tests/units/test_auth_plugins.py index 55429483b..2ef729be6 100644 --- a/tests/units/test_auth_plugins.py +++ b/tests/units/test_auth_plugins.py @@ -18,7 +18,7 @@ import pickle import unittest -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest import mock import boto3 @@ -2112,12 +2112,18 @@ def get_auth_plugin(self, provider): "jwks_uri": "http://foo.bar/auth/realms/myrealm/protocol/openid-connect/certs", "id_token_signing_alg_values_supported": ["RS256", "HS512"], } - mock_request.return_value.json.side_effect = [oidc_config, oidc_config] + mock_request.return_value.json.return_value = oidc_config auth_plugin = super( TestAuthPluginOIDCAuthorizationCodeFlowAuth, self ).get_auth_plugin(provider) - # reset token info - auth_plugin.token_info = {} + auth_plugin.access_token = "" + auth_plugin.refresh_token = "" + auth_plugin.access_token_expiration = datetime.min.replace( + tzinfo=timezone.utc + ) + auth_plugin.refresh_token_expiration = datetime.min.replace( + tzinfo=timezone.utc + ) return auth_plugin def test_plugins_auth_codeflowauth_validate_credentials(self): @@ -2128,7 +2134,7 @@ def test_plugins_auth_codeflowauth_validate_credentials(self): with self.assertRaises(MisconfiguredError) as context: auth_plugin.validate_config_credentials() self.assertTrue( - '"token_provision" must be one of "qs" or "header"' + '"token_provision" must be one of "qs", "header", or "basic"' in str(context.exception) ) # `token_provision=="qs"` but `token_qs_key` is missing @@ -2364,6 +2370,40 @@ def test_plugins_auth_codeflowauth_authenticate_token_qs_key_ok( self.assertEqual(auth.where, "qs") self.assertEqual(auth.key, auth_plugin.config.token_qs_key) + @mock.patch( + "eodag.plugins.authentication.openid_connect.OIDCRefreshTokenBase.decode_jwt_token", + autospec=True, + ) + @mock.patch( + "eodag.plugins.authentication.openid_connect.OIDCAuthorizationCodeFlowAuth._request_new_token", + autospec=True, + ) + def test_plugins_auth_codeflowauth_authenticate_basic_ok( + self, + mock_request_new_token, + mock_decode, + ): + """OIDCAuthorizationCodeFlowAuth.authenticate must return a basic auth object with the refresh token.""" + auth_plugin = self.get_auth_plugin("provider_ok") + auth_plugin.config.token_provision = "basic" + json_response = { + "access_token": "obtained-access-token", + "expires_in": "3600", + "refresh_expires_in": "7200", + "refresh_token": "obtained-refresh-token", + } + mock_request_new_token.return_value = json_response + mock_decode.return_value = { + "exp": (now_in_utc() + timedelta(seconds=3600)).timestamp() + } + + auth = auth_plugin.authenticate() + + self.assertIsInstance(auth, CodeAuthorizedAuth) + self.assertEqual(auth.token, json_response["access_token"]) + self.assertEqual(auth.where, "basic") + self.assertEqual(auth.refresh_token, json_response["refresh_token"]) + @mock.patch( "eodag.plugins.authentication.openid_connect.OIDCAuthorizationCodeFlowAuth.authenticate_user", autospec=True, diff --git a/tests/units/test_eoproduct.py b/tests/units/test_eoproduct.py index ffa8bbd5c..6d02eeaa2 100644 --- a/tests/units/test_eoproduct.py +++ b/tests/units/test_eoproduct.py @@ -37,9 +37,14 @@ from tests.context import ( DEFAULT_SHAPELY_GEOMETRY, NOT_AVAILABLE, + USER_AGENT, + AwsAuth, + DatasetCreationError, DatasetDriver, Download, EOProduct, + HTTPHeaderAuth, + HttpQueryStringAuth, ProgressCallback, mock, ) @@ -548,6 +553,44 @@ def test_eoproduct_download_http_dynamic_options(self): product_zip_file = "{}.zip".format(product_dir_path) self.assertTrue(os.path.isfile(product_zip_file)) + @mock.patch("eodag.api.product._product.requests.get") + def test_eoproduct_request_asset(self, mock_get): + """EOProduct.request_asset must perform a GET request with storage options headers.""" + product = self._dummy_product() + + product.request_asset("https://example.com/zarr/.zmetadata") + + mock_get.assert_called_once_with( + "https://example.com/zarr/.zmetadata", + headers={}, + stream=True, + ) + + @mock.patch("eodag.api.product._product.requests.get") + def test_eoproduct_request_asset_with_auth_headers(self, mock_get): + """EOProduct.request_asset must forward authentication headers from get_storage_options.""" + product = self._dummy_product() + # Mock downloader and auth + mock_downloader = mock.MagicMock() + mock_auth = mock.MagicMock() + product.register_downloader(mock_downloader, mock_auth) + + # Mock get_storage_options to return auth headers + product.get_storage_options = mock.MagicMock( + return_value={ + "path": "https://example.com/zarr/.zmetadata", + "headers": {"Authorization": "Bearer token123"}, + } + ) + + product.request_asset("https://example.com/zarr/.zmetadata") + + mock_get.assert_called_once_with( + "https://example.com/zarr/.zmetadata", + headers={"Authorization": "Bearer token123"}, + stream=True, + ) + @responses.activate def test_eoproduct_download_progress_bar(self): """eoproduct.download must show a progress bar""" @@ -803,3 +846,186 @@ def test_eoproduct_from_pystac(self): pystac_item = Item.from_dict(product.as_dict()) product_from_pystac = EOProduct.from_pystac(pystac_item) self.assertIsInstance(product_from_pystac, EOProduct) + + def test_get_storage_options_http_headers(self): + """get_storage_options should be adapted to the provider config""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + # http headers auth + product.register_downloader( + Download("foo", PluginConfig()), + HTTPHeaderAuth( + "foo", + PluginConfig.from_mapping( + { + "type": "Download", + "credentials": {"apikey": "foo"}, + "headers": {"X-API-Key": "{apikey}"}, + } + ), + ), + ) + self.assertDictEqual( + product.get_storage_options(), + { + "path": self.download_url, + "headers": {"X-API-Key": "foo", **USER_AGENT}, + }, + ) + + def test_get_storage_options_http_no_auth(self): + """get_storage_options should return path when no auth""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + # http headers auth + product.register_downloader( + Download("foo", PluginConfig()), + None, + ) + self.assertDictEqual( + product.get_storage_options(), + { + "path": self.download_url, + }, + ) + + def test_get_storage_options_http_qs(self): + """get_storage_options should be adapted to the provider config""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + # http qs auth + product.register_downloader( + Download("foo", PluginConfig()), + HttpQueryStringAuth( + "foo", + PluginConfig.from_mapping( + { + "type": "Download", + "credentials": {"apikey": "foo"}, + } + ), + ), + ) + self.assertDictEqual( + product.get_storage_options(), + { + "path": f"{self.download_url}?apikey=foo", + "headers": USER_AGENT, + }, + ) + + @mock.patch("eodag.api.product._product.ServiceResource", new=object) + def test_get_storage_options_s3_credentials_endpoint(self): + """get_storage_options should be adapted to the provider config using s3 credentials and endpoint""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + auth_plugin = AwsAuth( + "foo", + PluginConfig.from_mapping( + { + "type": "Authentication", + "s3_endpoint": "http://foo.bar", + "credentials": { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_session_token": "baz", + }, + "requester_pays": True, + } + ), + ) + auth_plugin.s3_session = mock.MagicMock() + auth_plugin.s3_session.get_credentials.return_value = mock.Mock( + access_key="foo", + secret_key="bar", + token="baz", + ) + auth_plugin.authenticate = mock.MagicMock(return_value=object()) + product.register_downloader(Download("foo", PluginConfig()), auth_plugin) + self.assertDictEqual( + product.get_storage_options(), + { + "path": self.download_url, + "key": "foo", + "secret": "bar", + "token": "baz", + "client_kwargs": {"endpoint_url": "http://foo.bar"}, + "requester_pays": True, + }, + ) + + @mock.patch("eodag.api.product._product.ServiceResource", new=object) + def test_get_storage_options_s3_credentials(self): + """get_storage_options should be adapted to the provider config using s3 credentials""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + auth_plugin = AwsAuth( + "foo", + PluginConfig.from_mapping( + { + "type": "Authentication", + "credentials": { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_session_token": "baz", + }, + } + ), + ) + auth_plugin.s3_session = mock.MagicMock() + auth_plugin.s3_session.get_credentials.return_value = mock.Mock( + access_key="foo", + secret_key="bar", + token="baz", + ) + auth_plugin.authenticate = mock.MagicMock(return_value=object()) + product.register_downloader(Download("foo", PluginConfig()), auth_plugin) + self.assertDictEqual( + product.get_storage_options(), + { + "path": self.download_url, + "key": "foo", + "secret": "bar", + "token": "baz", + }, + ) + + @mock.patch("eodag.api.product._product.ServiceResource", new=object) + def test_get_storage_options_s3_anon(self): + """get_storage_options should be adapted to the provider config using anonymous s3 access""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + auth_plugin = AwsAuth( + "foo", + PluginConfig.from_mapping( + {"type": "Authentication", "requester_pays": True} + ), + ) + auth_plugin.s3_session = mock.MagicMock() + auth_plugin.s3_session.get_credentials.return_value = None + auth_plugin.authenticate = mock.MagicMock(return_value=object()) + product.register_downloader(Download("foo", PluginConfig()), auth_plugin) + self.assertDictEqual( + product.get_storage_options(), + { + "path": self.download_url, + "anon": True, + }, + ) + + def test_get_storage_options_error(self): + """get_storage_options should raise when the asset key is missing""" + product = EOProduct( + self.provider, self.eoproduct_props, collection=self.collection + ) + product.downloader = mock.MagicMock() + with self.assertRaises( + DatasetCreationError, msg=f"foo not found in {product} assets" + ): + product.get_storage_options(asset_key="foo")