diff --git a/README.md b/README.md
index c614cee..c1ebcb1 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ conda install -c conda-forge dask-bigquery
 
 ## Example
 
-`dask-bigquery` assumes that you are already authenticated.
+`dask-bigquery` assumes that you are already authenticated. 
 
 ```python
 import dask_bigquery
diff --git a/dask_bigquery/core.py b/dask_bigquery/core.py
index 3147121..5c96f9e 100644
--- a/dask_bigquery/core.py
+++ b/dask_bigquery/core.py
@@ -3,6 +3,9 @@
 from contextlib import contextmanager
 from functools import partial
 
+import google.auth
+import google.auth.transport.requests
+import google.oauth2.credentials
 import pandas as pd
 import pyarrow
 from dask.base import tokenize
@@ -17,7 +19,7 @@
 
 
 @contextmanager
-def bigquery_clients(project_id):
+def bigquery_clients(project_id, credentials=None):
     """This context manager is a temporary solution until there is an
     upstream solution to handle this.
     See googleapis/google-cloud-python#9457
@@ -30,7 +32,9 @@ def bigquery_clients(project_id):
         user_agent=f"dask-bigquery/{dask_bigquery.__version__}"
     )
 
-    with bigquery.Client(project_id, client_info=bq_client_info) as bq_client:
+    with bigquery.Client(
+        project_id, credentials=credentials, client_info=bq_client_info
+    ) as bq_client:
         bq_storage_client = bigquery_storage.BigQueryReadClient(
             credentials=bq_client._credentials,
             client_info=bqstorage_client_info,
@@ -54,6 +58,7 @@ def bigquery_read(
     make_create_read_session_request: callable,
     project_id: str,
     read_kwargs: dict,
+    cred_token: str,
     stream_name: str,
 ) -> pd.DataFrame:
     """Read a single batch of rows via BQ Storage API, in Arrow binary format.
@@ -70,8 +75,16 @@ def bigquery_read(
         BigQuery Storage API Stream "name"
         NOTE: Please set if reading from Storage API without any `row_restriction`.
         https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
+    cred_token: str
+        Google OAuth2 bearer token forwarded from the client; if None, workers use their default credentials
     """
-    with bigquery_clients(project_id) as (_, bqs_client):
+
+    if cred_token:
+        credentials = google.oauth2.credentials.Credentials(cred_token)
+    else:
+        credentials = None
+
+    with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
         session = bqs_client.create_read_session(make_create_read_session_request())
         schema = pyarrow.ipc.read_schema(
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
@@ -91,6 +104,7 @@ def read_gbq(
     row_filter: str = "",
     columns: list[str] = None,
     read_kwargs: dict = None,
+    fwd_creds: bool = False,
 ):
     """Read table as dask dataframe using BigQuery Storage API via Arrow format.
     Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -109,13 +123,31 @@
         list of columns to load from the table
     read_kwargs: dict
         kwargs to pass to read_rows()
+    fwd_creds: bool
+        Set to True to forward the local credentials to the workers. Defaults to False.
 
     Returns
     -------
         Dask DataFrame
     """
     read_kwargs = read_kwargs or {}
-    with bigquery_clients(project_id) as (bq_client, bqs_client):
+
+    if fwd_creds:
+        credentials, _ = google.auth.default(
+            scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
+        )
+
+        auth_req = google.auth.transport.requests.Request()
+        credentials.refresh(auth_req)
+        cred_token = credentials.token
+    else:
+        credentials = None
+        cred_token = None
+
+    with bigquery_clients(project_id, credentials=credentials) as (
+        bq_client,
+        bqs_client,
+    ):
         table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
         if table_ref.table_type == "VIEW":
             raise TypeError("Table type VIEW not supported")
@@ -161,6 +193,7 @@ def make_create_read_session_request(row_filter=""):
             make_create_read_session_request,
             project_id,
             read_kwargs,
+            cred_token,
         ),
         label=label,
     )
diff --git a/dask_bigquery/tests/test_core.py b/dask_bigquery/tests/test_core.py
index 82b922b..16dcfbe 100644
--- a/dask_bigquery/tests/test_core.py
+++ b/dask_bigquery/tests/test_core.py
@@ -51,22 +51,30 @@ def dataset(df):
     )
 
 
-def test_read_gbq(df, dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_gbq(df, dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
-    ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
+    ddf = read_gbq(
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id,
+        fwd_creds=fwd_creds,
+    )
 
     assert list(ddf.columns) == ["name", "number", "idx"]
     assert ddf.npartitions == 2
     assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
 
 
-def test_read_row_filter(df, dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_row_filter(df, dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
     ddf = read_gbq(
         project_id=project_id,
         dataset_id=dataset_id,
         table_id=table_id,
         row_filter="idx < 5",
+        fwd_creds=fwd_creds,
     )
 
     assert list(ddf.columns) == ["name", "number", "idx"]
@@ -74,20 +82,23 @@ def test_read_row_filter(df, dataset, client):
     assert assert_eq(ddf.set_index("idx").loc[:4], df.set_index("idx").loc[:4])
 
 
-def test_read_kwargs(dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_kwargs(dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
     ddf = read_gbq(
         project_id=project_id,
         dataset_id=dataset_id,
         table_id=table_id,
         read_kwargs={"timeout": 1e-12},
+        fwd_creds=fwd_creds,
     )
 
     with pytest.raises(Exception, match="Deadline Exceeded"):
         ddf.compute()
 
 
-def test_read_columns(df, dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_columns(df, dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
     assert df.shape[1] > 1, "Test data should have multiple columns"
 
@@ -97,5 +108,27 @@ def test_read_columns(df, dataset, client):
         dataset_id=dataset_id,
         table_id=table_id,
         columns=columns,
+        fwd_creds=fwd_creds,
     )
     assert list(ddf.columns) == columns
+
+
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_gbq_no_creds_fail(dataset, fwd_creds, monkeypatch, client):
+    """Check that authentication fails with a DefaultCredentialsError
+    when no credentials are available.
+    """
+    project_id, dataset_id, table_id = dataset
+
+    def mock_auth(scopes=["https://www.googleapis.com/auth/bigquery.readonly"]):
+        raise google.auth.exceptions.DefaultCredentialsError()
+
+    monkeypatch.setattr(google.auth, "default", mock_auth)
+
+    with pytest.raises(google.auth.exceptions.DefaultCredentialsError):
+        read_gbq(
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+            fwd_creds=fwd_creds,
+        )
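For context, here is a minimal sketch of how the new `fwd_creds` flag would be used from the client side; the project, dataset, and table identifiers below are placeholders:

```python
import dask_bigquery

# Placeholder identifiers -- substitute your own project, dataset, and table.
ddf = dask_bigquery.read_gbq(
    project_id="your_project_id",
    dataset_id="your_dataset",
    table_id="your_table",
    fwd_creds=True,  # fetch a bearer token locally and forward it to the workers
)

print(ddf.head())
```

With the default `fwd_creds=False`, no token is forwarded and each worker authenticates with its own default credentials, preserving the previous behavior.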