Skip to content

Fix misleading data_time / batch_time / data_wait metrics on CUDA (async bleed-through)#749

Open
DLemming wants to merge 6 commits into
lightly-ai:mainfrom
DLemming:dlemming-fix-cuda-async-data-time
Open

Fix misleading data_time / batch_time / data_wait metrics on CUDA (async bleed-through)#749
DLemming wants to merge 6 commits into
lightly-ai:mainfrom
DLemming:dlemming-fix-cuda-async-data-time

Conversation

@DLemming

@DLemming DLemming commented May 26, 2026

Copy link
Copy Markdown

What has changed and why?

On CUDA, on_train_batch_end fires before the GPU finishes the backward pass and optimizer step. The previous implementation recorded batch_end_time at that point, so the remaining async GPU work bled into the next step's data_time, rendering all data/batch-time realted metrics misleading / wildly inaccurate.

The result: data_wait reported ~60% when the true value was ~10%. This is actively misleading — users would chase data-loading bottlenecks that don't exist while the GPU was fully utilized the entire time.

_callbacks/tqdm_progress_bar.py:

  • Read profiling/data_time and profiling/batch_time from trainer.callback_metrics

_methods/method.py

  • Separate branches for CPU and CUDA
  • For CUDA, register non-blocking torch.cuda.Event at on_train_batch_start / on_train_batch_end, measuring real gpu time
  • data_time = wall_gap (start → start) − GPU duration (from events)
  • CPU-only training falls back to the previous time.perf_counter() approach, which remains accurate when compute is synchronous.

How has it been tested?

Changes have been tested on cuda running a SimCLR-Pretraining (bs=1232, res=224px) and a DINOv2-Pretraining (bs=128, global_crop_res=196px). data_wait dropped from ~60% to ~10% for SimCLR, and increased from ~0.8% to ~2.5% for DINOv2.

Did you update CHANGELOG.md?

  • Yes
  • Not needed (internal change)

Did you update the documentation?

  • Yes
  • Not needed (internal change without effects for user)

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@CLAassistant

CLAassistant commented May 26, 2026

Copy link
Copy Markdown

CLA assistant check
All committers have signed the CLA.

@mrpositron

Copy link
Copy Markdown
Contributor

/review

@mrpositron mrpositron left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! I left a few comments.

if trainer.strategy.root_device.type == "cuda":
# Record end event — queried at next step start
self._step_end_event = torch.cuda.Event(enable_timing=True) # type: ignore[no-untyped-call]
self._step_end_event.record() # type: ignore[no-untyped-call]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

record() only enqueues the marker; elapsed_time() needs the GPU to have actually reached _step_end_event. On a compute-bound run the CPU can reach the next on_train_batch_start while the GPU is still finishing the previous step, then elapsed_time raises RuntimeError and the run crashes.

self.batch_time: float | None = None
# CUDA-only: events bracketing the GPU step.
self._step_start_event: torch.cuda.Event | None = None
self._step_end_event: torch.cuda.Event | None = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Event-bracketing timing logic is now duplicated in both tqdm_progress_bar.py and method.py, each with its own _step_start_event / _step_end_event state. Could the progress bar instead read the metrics the Method already logs (profiling/data_time / profiling/batch_time) rather than re-timing the GPU independently?

@DLemming

DLemming commented Jun 3, 2026

Copy link
Copy Markdown
Author

Thanks for the feedback.

Implemented your suggestions. The progress bar now reads profiling/data_time / profiling/batch_time from trainer.callback_metrics instead of re-timing the GPU, so the event-bracketing logic lives only in Method. I also guarded the elapsed_time crash possibility you mentioned with a non-blocking _end_event.query().

One honest limitation worth noting: When the GPU hasn't reached the end event yet we skip that step's update rather than block, resulting in batch_time/data_time going stale, and metrics being slightly biased toward less compute-bound steps. I think that's an acceptable trade-off vs. the old behavior, which reported flat-out wrong numbers (~60% data_wait instad of ~10%).

In practice I haven't been able to trigger the stale path once even on clearly compute-bound runs (e.g. DINOv2), so something in the loop seems to synchronize often enough that the end event is usually ready by the next step start anyway.

Happy to look more into that edge case later, but for now this seems like a solid improvement over the currently available version.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants