Skip to content
35 changes: 29 additions & 6 deletions packages/bigframes/bigframes/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
import google.cloud.bigquery._job_helpers
import google.cloud.bigquery.job.query
import google.cloud.bigquery.table
from google.cloud.bigquery.job.query import QueryPlanEntry

import bigframes.session.executor

_FALLBACK_TO_GLOBAL = "fallback_to_global"


class Subscriber:
def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher):
def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher): # noqa: E501
self._publisher = publisher
self._callback = callback
self._subscriber_id = uuid.uuid4()
Expand Down Expand Up @@ -125,15 +128,21 @@ class BigQuerySentEvent(ExecutionRunning):
location: Optional[str] = None
job_id: Optional[str] = None
request_id: Optional[str] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QuerySentEvent):
def from_bqclient(
cls,
event: google.cloud.bigquery._job_helpers.QuerySentEvent,
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
return cls(
query=event.query,
billing_project=event.billing_project,
location=event.location,
job_id=event.job_id,
request_id=event.request_id,
progress_bar=progress_bar,
)


Expand All @@ -146,15 +155,21 @@ class BigQueryRetryEvent(ExecutionRunning):
location: Optional[str] = None
job_id: Optional[str] = None
request_id: Optional[str] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QueryRetryEvent):
def from_bqclient(
cls,
event: google.cloud.bigquery._job_helpers.QueryRetryEvent,
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
return cls(
query=event.query,
billing_project=event.billing_project,
location=event.location,
job_id=event.job_id,
request_id=event.request_id,
progress_bar=progress_bar,
)


Expand All @@ -167,14 +182,17 @@ class BigQueryReceivedEvent(ExecutionRunning):
job_id: Optional[str] = None
statement_type: Optional[str] = None
state: Optional[str] = None
query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]] = None
query_plan: Optional[list[QueryPlanEntry]] = None
created: Optional[datetime.datetime] = None
started: Optional[datetime.datetime] = None
ended: Optional[datetime.datetime] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(
cls, event: google.cloud.bigquery._job_helpers.QueryReceivedEvent
cls,
event: google.cloud.bigquery._job_helpers.QueryReceivedEvent,
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
return cls(
billing_project=event.billing_project,
Expand All @@ -186,6 +204,7 @@ def from_bqclient(
created=event.created,
started=event.started,
ended=event.ended,
progress_bar=progress_bar,
)


Expand All @@ -204,10 +223,13 @@ class BigQueryFinishedEvent(ExecutionRunning):
created: Optional[datetime.datetime] = None
started: Optional[datetime.datetime] = None
ended: Optional[datetime.datetime] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(
cls, event: google.cloud.bigquery._job_helpers.QueryFinishedEvent
cls,
event: google.cloud.bigquery._job_helpers.QueryFinishedEvent,
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
return cls(
billing_project=event.billing_project,
Expand All @@ -221,6 +243,7 @@ def from_bqclient(
created=event.created,
started=event.started,
ended=event.ended,
progress_bar=progress_bar,
)


Expand Down
7 changes: 6 additions & 1 deletion packages/bigframes/bigframes/formatting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ def progress_callback(
# This will allow cleanup to continue.
return

progress_bar = bigframes._config.options.display.progress_bar
# Prioritize progress_bar set on the event, falling back to thread-local option.
progress_bar = getattr(
event, "progress_bar", bigframes.core.events._FALLBACK_TO_GLOBAL
)
if progress_bar == bigframes.core.events._FALLBACK_TO_GLOBAL:
progress_bar = bigframes._config.options.display.progress_bar

if progress_bar == "auto":
progress_bar = "notebook" if in_ipython() else "terminal"
Expand Down
27 changes: 16 additions & 11 deletions packages/bigframes/bigframes/session/_io/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,23 @@ def add_and_trim_labels(job_config, session=None):


def create_bq_event_callback(publisher):
    """Build a callback that republishes BigQuery client events as bigframes events.

    The ``display.progress_bar`` option is read once, when this factory is
    called, and the captured value is attached to every event the returned
    callback converts.

    Args:
        publisher: Object whose ``publish(event)`` method receives the
            converted bigframes event.

    Returns:
        A one-argument callable that accepts a BigQuery client query event,
        converts it, and publishes it.
    """
    import bigframes._config

    # Captured at creation time rather than inside the callback.
    # NOTE(review): presumably because the option is thread-local and the
    # callback may run on a different thread -- confirm against the callers.
    progress_bar = bigframes._config.options.display.progress_bar

    bq_helpers = google.cloud.bigquery._job_helpers
    bf_events = bigframes.core.events

    # Checked in order with isinstance; the first matching pair wins.
    event_types = (
        (bq_helpers.QueryFinishedEvent, bf_events.BigQueryFinishedEvent),
        (bq_helpers.QueryReceivedEvent, bf_events.BigQueryReceivedEvent),
        (bq_helpers.QueryRetryEvent, bf_events.BigQueryRetryEvent),
        (bq_helpers.QuerySentEvent, bf_events.BigQuerySentEvent),
    )

    def publish_bq_event(event):
        for bq_type, bf_type in event_types:
            if isinstance(event, bq_type):
                bf_event = bf_type.from_bqclient(  # type: ignore
                    event, progress_bar=progress_bar
                )
                break
        else:
            # No known BigQuery event type matched: wrap the raw event.
            bf_event = bf_events.BigQueryUnknownEvent(event)
        publisher.publish(bf_event)

    return publish_bq_event
Expand Down
17 changes: 17 additions & 0 deletions packages/bigframes/tests/system/small/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ def test_progress_bar_load_jobs(
assert_loading_msg_exist(capsys.readouterr().out, pattern="Load")


def test_progress_bar_uniqueness_check(session: bf.Session, capsys):
    """Reading a table with a non-unique index column emits a loading message."""
    # The uniqueness check is only triggered when the session is strictly
    # ordered, which is the default.
    assert session._strictly_ordered

    # Discard anything printed before the read so only its output is inspected.
    capsys.readouterr()

    with bf.option_context("display.progress_bar", "terminal"):
        # A public table with a non-unique index_col triggers the check
        # against a "real" table.
        session.read_gbq_table(
            "bigquery-public-data.ml_datasets.penguins",
            index_col="island",
        )

    captured = capsys.readouterr()
    assert_loading_msg_exist(captured.out)


def assert_loading_msg_exist(capstdout: str, pattern=job_load_message_regex):
num_loading_msg = 0
lines = capstdout.split("\n")
Expand Down
25 changes: 25 additions & 0 deletions packages/bigframes/tests/unit/test_formatting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,28 @@ def test_get_job_url():
job_id=job_id, location=location, project_id=project_id
)
assert actual_url == expected_url


def test_progress_callback_respects_event_progress_bar():
    """An event carrying ``progress_bar=None`` prints nothing.

    The global option is patched to "terminal", so a print here would mean
    the event-level setting was ignored.
    """
    sent_event = bfevents.BigQuerySentEvent(
        query="SELECT * FROM my_table",
        progress_bar=None,
    )

    global_option = mock.patch(
        "bigframes._config.options.display.progress_bar", "terminal"
    )
    not_in_ipython = mock.patch(
        "bigframes.formatting_helpers.in_ipython", return_value=False
    )
    fake_print = mock.patch("builtins.print")

    # Patchers are entered left to right, the same order as the nested form.
    with global_option, not_in_ipython, fake_print as mock_print:
        formatting_helpers.progress_callback(sent_event)
        mock_print.assert_not_called()


def test_progress_callback_falls_back_to_global():
    """An event built without ``progress_bar`` follows the global option.

    The global option is patched to "terminal", and the callback is expected
    to print exactly once.
    """
    sent_event = bfevents.BigQuerySentEvent(
        query="SELECT * FROM my_table",
    )

    global_option = mock.patch(
        "bigframes._config.options.display.progress_bar", "terminal"
    )
    not_in_ipython = mock.patch(
        "bigframes.formatting_helpers.in_ipython", return_value=False
    )
    fake_print = mock.patch("builtins.print")

    # Patchers are entered left to right, the same order as the nested form.
    with global_option, not_in_ipython, fake_print as mock_print:
        formatting_helpers.progress_callback(sent_event)
        mock_print.assert_called_once()
Loading