diff --git a/packages/bigframes/bigframes/core/events.py b/packages/bigframes/bigframes/core/events.py index 0724cc5414bb..6471cce58733 100644 --- a/packages/bigframes/bigframes/core/events.py +++ b/packages/bigframes/bigframes/core/events.py @@ -23,12 +23,15 @@ import google.cloud.bigquery._job_helpers import google.cloud.bigquery.job.query import google.cloud.bigquery.table +from google.cloud.bigquery.job.query import QueryPlanEntry import bigframes.session.executor +_FALLBACK_TO_GLOBAL = "fallback_to_global" + class Subscriber: - def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher): + def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher): # noqa: E501 self._publisher = publisher self._callback = callback self._subscriber_id = uuid.uuid4() @@ -125,15 +128,21 @@ class BigQuerySentEvent(ExecutionRunning): location: Optional[str] = None job_id: Optional[str] = None request_id: Optional[str] = None + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL @classmethod - def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QuerySentEvent): + def from_bqclient( + cls, + event: google.cloud.bigquery._job_helpers.QuerySentEvent, + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL, + ): return cls( query=event.query, billing_project=event.billing_project, location=event.location, job_id=event.job_id, request_id=event.request_id, + progress_bar=progress_bar, ) @@ -146,15 +155,21 @@ class BigQueryRetryEvent(ExecutionRunning): location: Optional[str] = None job_id: Optional[str] = None request_id: Optional[str] = None + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL @classmethod - def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QueryRetryEvent): + def from_bqclient( + cls, + event: google.cloud.bigquery._job_helpers.QueryRetryEvent, + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL, + ): return cls( query=event.query, billing_project=event.billing_project, location=event.location, job_id=event.job_id, request_id=event.request_id, + progress_bar=progress_bar, ) @@ -167,14 +182,17 @@ class BigQueryReceivedEvent(ExecutionRunning): job_id: Optional[str] = None statement_type: Optional[str] = None state: Optional[str] = None - query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]] = None + query_plan: Optional[list[QueryPlanEntry]] = None created: Optional[datetime.datetime] = None started: Optional[datetime.datetime] = None ended: Optional[datetime.datetime] = None + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL @classmethod def from_bqclient( - cls, event: google.cloud.bigquery._job_helpers.QueryReceivedEvent + cls, + event: google.cloud.bigquery._job_helpers.QueryReceivedEvent, + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL, ): return cls( billing_project=event.billing_project, @@ -186,6 +204,7 @@ def from_bqclient( created=event.created, started=event.started, ended=event.ended, + progress_bar=progress_bar, ) @@ -204,10 +223,13 @@ class BigQueryFinishedEvent(ExecutionRunning): created: Optional[datetime.datetime] = None started: Optional[datetime.datetime] = None ended: Optional[datetime.datetime] = None + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL @classmethod def from_bqclient( - cls, event: google.cloud.bigquery._job_helpers.QueryFinishedEvent + cls, + event: google.cloud.bigquery._job_helpers.QueryFinishedEvent, + progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL, ): return cls( billing_project=event.billing_project, @@ -221,6 +243,7 @@ def from_bqclient( created=event.created, started=event.started, ended=event.ended, + progress_bar=progress_bar, ) diff --git a/packages/bigframes/bigframes/formatting_helpers.py b/packages/bigframes/bigframes/formatting_helpers.py index cef14d39a3f6..3d4082578f5a 100644 --- a/packages/bigframes/bigframes/formatting_helpers.py +++ b/packages/bigframes/bigframes/formatting_helpers.py @@ -152,7 +152,12 @@ def progress_callback( # This will allow cleanup to continue. return - progress_bar = bigframes._config.options.display.progress_bar + # Prioritize progress_bar set on the event, falling back to thread-local option. + progress_bar = getattr( + event, "progress_bar", bigframes.core.events._FALLBACK_TO_GLOBAL + ) + if progress_bar == bigframes.core.events._FALLBACK_TO_GLOBAL: + progress_bar = bigframes._config.options.display.progress_bar if progress_bar == "auto": progress_bar = "notebook" if in_ipython() else "terminal" diff --git a/packages/bigframes/bigframes/session/_io/bigquery/__init__.py b/packages/bigframes/bigframes/session/_io/bigquery/__init__.py index 780ba55c50db..703cb4704fec 100644 --- a/packages/bigframes/bigframes/session/_io/bigquery/__init__.py +++ b/packages/bigframes/bigframes/session/_io/bigquery/__init__.py @@ -245,18 +245,23 @@ def add_and_trim_labels(job_config, session=None): def create_bq_event_callback(publisher): - def publish_bq_event(event): - if isinstance(event, google.cloud.bigquery._job_helpers.QueryFinishedEvent): - bf_event = bigframes.core.events.BigQueryFinishedEvent.from_bqclient(event) - elif isinstance(event, google.cloud.bigquery._job_helpers.QueryReceivedEvent): - bf_event = bigframes.core.events.BigQueryReceivedEvent.from_bqclient(event) - elif isinstance(event, google.cloud.bigquery._job_helpers.QueryRetryEvent): - bf_event = bigframes.core.events.BigQueryRetryEvent.from_bqclient(event) - elif isinstance(event, google.cloud.bigquery._job_helpers.QuerySentEvent): - bf_event = bigframes.core.events.BigQuerySentEvent.from_bqclient(event) - else: - bf_event = bigframes.core.events.BigQueryUnknownEvent(event) + import bigframes._config + + progress_bar = bigframes._config.options.display.progress_bar + event_map = { + google.cloud.bigquery._job_helpers.QueryFinishedEvent: bigframes.core.events.BigQueryFinishedEvent, + google.cloud.bigquery._job_helpers.QueryReceivedEvent: bigframes.core.events.BigQueryReceivedEvent, + google.cloud.bigquery._job_helpers.QueryRetryEvent: bigframes.core.events.BigQueryRetryEvent, + google.cloud.bigquery._job_helpers.QuerySentEvent: bigframes.core.events.BigQuerySentEvent, + } + + def publish_bq_event(event): + bf_event = bigframes.core.events.BigQueryUnknownEvent(event) + for bq_type, bf_type in event_map.items(): + if isinstance(event, bq_type): + bf_event = bf_type.from_bqclient(event, progress_bar=progress_bar) # type: ignore + break publisher.publish(bf_event) return publish_bq_event diff --git a/packages/bigframes/tests/system/small/test_progress_bar.py b/packages/bigframes/tests/system/small/test_progress_bar.py index bc247f6078ce..a179e18332af 100644 --- a/packages/bigframes/tests/system/small/test_progress_bar.py +++ b/packages/bigframes/tests/system/small/test_progress_bar.py @@ -104,6 +104,23 @@ def test_progress_bar_load_jobs( assert_loading_msg_exist(capsys.readouterr().out, pattern="Load") +def test_progress_bar_uniqueness_check(session: bf.Session, capsys): + # Ensure strictly_ordered is True (default) to trigger uniqueness check + assert session._strictly_ordered + + capsys.readouterr() # clear output + + with bf.option_context("display.progress_bar", "terminal"): + # Read a table and specify a non-unique index_col to trigger the check. + # We use a public table to make it a "real" test. + session.read_gbq_table( + "bigquery-public-data.ml_datasets.penguins", + index_col="island", + ) + + assert_loading_msg_exist(capsys.readouterr().out) + + def assert_loading_msg_exist(capstdout: str, pattern=job_load_message_regex): num_loading_msg = 0 lines = capstdout.split("\n") diff --git a/packages/bigframes/tests/unit/test_formatting_helpers.py b/packages/bigframes/tests/unit/test_formatting_helpers.py index ec681b36ab05..67be9398d241 100644 --- a/packages/bigframes/tests/unit/test_formatting_helpers.py +++ b/packages/bigframes/tests/unit/test_formatting_helpers.py @@ -212,3 +212,28 @@ def test_get_job_url(): job_id=job_id, location=location, project_id=project_id ) assert actual_url == expected_url + + +def test_progress_callback_respects_event_progress_bar(): + event = bfevents.BigQuerySentEvent( + query="SELECT * FROM my_table", + progress_bar=None, + ) + + with mock.patch("bigframes._config.options.display.progress_bar", "terminal"): + with mock.patch("bigframes.formatting_helpers.in_ipython", return_value=False): + with mock.patch("builtins.print") as mock_print: + formatting_helpers.progress_callback(event) + mock_print.assert_not_called() + + +def test_progress_callback_falls_back_to_global(): + event = bfevents.BigQuerySentEvent( + query="SELECT * FROM my_table", + ) + + with mock.patch("bigframes._config.options.display.progress_bar", "terminal"): + with mock.patch("bigframes.formatting_helpers.in_ipython", return_value=False): + with mock.patch("builtins.print") as mock_print: + formatting_helpers.progress_callback(event) + mock_print.assert_called_once()