diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index a4bf886472..2f2090f876 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -419,7 +419,9 @@ def get_cached_mlflow_url(): raise TimeoutExceededError(resource_type="TrainingJob", status=status) else: - print(f"\nTrainingJob Name: {training_job.training_job_name}") + print(f"\nTraining job started: {training_job.training_job_name}", flush=True) + print(f"Log group: /aws/sagemaker/TrainingJobs", flush=True) + print(f"Log stream prefix: {training_job.training_job_name}", flush=True) iteration = 0 while True: iteration += 1 @@ -432,8 +434,8 @@ def get_cached_mlflow_url(): # Show transitions with checkmarks if training_job.secondary_status_transitions: - print("\n--------------------------------------\n") - print("Status Transitions:") + print("\n--------------------------------------\n", flush=True) + print("Status Transitions:", flush=True) for trans in training_job.secondary_status_transitions: duration, check = _calculate_transition_duration(trans) @@ -452,28 +454,39 @@ def get_cached_mlflow_url(): if progress_pct is not None: step_msg += f" - {progress_pct:.1f}%{progress_text.replace(chr(10), ', ')}" - print(step_msg) - print(f"\nStatus: {status} - {secondary_status} (Elapsed: {elapsed:.1f}s)") + print(step_msg, flush=True) + print(f"\nStatus: {status} - {secondary_status} (Elapsed: {elapsed:.1f}s)", flush=True) if status in ["Completed", "Failed", "Stopped"]: if status == "Completed": if mlflow_url: - print(f"\n✓ Training completed! View metrics in MLflow: {mlflow_url}") + print(f"\n✓ Training completed! View metrics in MLflow: {mlflow_url}", flush=True) try: steps_per_epoch = training_job.progress_info.total_step_count_per_epoch loss_metrics_by_epoch = metrics_util._get_loss_metrics_by_epoch(run_name=mlflow_run_name, steps_per_epoch=steps_per_epoch) if loss_metrics_by_epoch: - print("\n------------ Loss Metrics by Epoch ------------") + print("\n------------ Loss Metrics by Epoch ------------", flush=True) for epoch, metrics in list(loss_metrics_by_epoch.items())[:-1]: - print(f"Epoch {epoch}: {metrics}") - print("----------------------------------------------") + print(f"Epoch {epoch}: {metrics}", flush=True) + print("----------------------------------------------", flush=True) except Exception: pass + if status == "Failed": + failure_reason = training_job.failure_reason + if failure_reason and not _is_unassigned_attribute(failure_reason): + print(f"\nFailure reason: {failure_reason}", flush=True) + print(f"\nLog group: /aws/sagemaker/TrainingJobs", flush=True) + print(f"Log stream prefix: {training_job.training_job_name}", flush=True) + from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url + cw_url = get_cloudwatch_logs_url(training_job.training_job_arn) + if cw_url: + print(f"CloudWatch Logs: {cw_url}", flush=True) + raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason) return failure_reason = training_job.failure_reason - if status == "Failed" or (failure_reason and not _is_unassigned_attribute(failure_reason)): + if failure_reason and not _is_unassigned_attribute(failure_reason): raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason) if timeout and elapsed >= timeout: diff --git a/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait_observability.py b/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait_observability.py new file mode 100644 index 0000000000..789ac955d7 --- /dev/null +++ b/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait_observability.py @@ -0,0 +1,78 @@ +"""Tests for training job observability prints in script/terminal mode.""" +import time +from unittest.mock import MagicMock, patch + +import pytest + +from sagemaker.train.common_utils.trainer_wait import wait, _is_unassigned_attribute + + +class MockUnassigned: + pass + + +class MockTrainingJob: + def __init__(self, status="Completed", failure_reason=None): + self.training_job_name = "test-sft-job-2026" + self.training_job_arn = "arn:aws:sagemaker:us-west-2:123456789:training-job/test-sft-job-2026" + self.training_job_status = status + self.secondary_status = "Training" + self.secondary_status_transitions = [] + self.progress_info = MockUnassigned() + self.failure_reason = failure_reason + self.mlflow_config = MockUnassigned() + self._call_count = 0 + + def refresh(self): + self._call_count += 1 + + +class TestTrainingObservabilityAtStart: + """Test that job info is printed at start in terminal mode.""" + + @patch("sagemaker.train.common_utils.trainer_wait._is_jupyter_environment", return_value=False) + @patch("sagemaker.train.common_utils.trainer_wait._setup_mlflow_integration", return_value=(None, None, None)) + def test_prints_job_info_at_start(self, mock_mlflow, mock_jupyter, capsys): + job = MockTrainingJob(status="Completed") + wait(job, poll=0, timeout=1) + captured = capsys.readouterr() + assert "Training job started: test-sft-job-2026" in captured.out + assert "Log group: /aws/sagemaker/TrainingJobs" in captured.out + assert "Log stream prefix: test-sft-job-2026" in captured.out + + +class TestTrainingObservabilityOnFailure: + """Test that debug info is printed on failure.""" + + @patch("sagemaker.train.common_utils.trainer_wait._is_jupyter_environment", return_value=False) + @patch("sagemaker.train.common_utils.trainer_wait._setup_mlflow_integration", return_value=(None, None, None)) + def test_prints_debug_info_on_failure(self, mock_mlflow, mock_jupyter, capsys): + job = MockTrainingJob(status="Failed", failure_reason="OOM error") + with pytest.raises(Exception): + wait(job, poll=0, timeout=1) + captured = capsys.readouterr() + assert "Failure reason: OOM error" in captured.out + assert "Log group: /aws/sagemaker/TrainingJobs" in captured.out + assert "Log stream prefix: test-sft-job-2026" in captured.out + assert "CloudWatch Logs:" in captured.out + + @patch("sagemaker.train.common_utils.trainer_wait._is_jupyter_environment", return_value=False) + @patch("sagemaker.train.common_utils.trainer_wait._setup_mlflow_integration", return_value=(None, None, None)) + def test_prints_cloudwatch_url_on_failure(self, mock_mlflow, mock_jupyter, capsys): + job = MockTrainingJob(status="Failed", failure_reason="ClientError") + with pytest.raises(Exception): + wait(job, poll=0, timeout=1) + captured = capsys.readouterr() + assert "us-west-2.console.aws.amazon.com/cloudwatch" in captured.out + + +class TestTrainingObservabilityOnSuccess: + """Test that MLflow link is printed on success (existing behavior preserved).""" + + @patch("sagemaker.train.common_utils.trainer_wait._is_jupyter_environment", return_value=False) + @patch("sagemaker.train.common_utils.trainer_wait._setup_mlflow_integration", return_value=("https://mlflow.example.com", None, None)) + def test_prints_mlflow_on_success(self, mock_mlflow, mock_jupyter, capsys): + job = MockTrainingJob(status="Completed") + wait(job, poll=0, timeout=1) + captured = capsys.readouterr() + assert "mlflow.example.com" in captured.out