diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..7a731e4
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,19 @@
+[report]
+exclude_lines =
+ # Skip bare pass statements (e.g. @abstractmethod stubs); anchored so names like "password" are not excluded
+ ^\s*pass\b
+
+ # Have to re-enable the standard pragma
+ pragma: no cover
+
+ # Don't complain about missing debug-only code:
+ def __repr__
+ if self\.debug
+
+ # Don't complain if tests don't hit defensive assertion code:
+ raise AssertionError
+ raise NotImplementedError
+
+ # Don't complain if non-runnable code isn't run:
+ if 0:
+ if __name__ == .__main__.:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a3c927..e34e1fa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,9 @@
# Change Log
All notable changes to this project will be documented in this file.
+## 2.2.1 - 2026-05-31
+ ### Runner
+ - added option to use day="last" for monthly scheduling
## 2.2.0 - 2026-03
### Runner
diff --git a/README.md b/README.md
index d78f887..0498fb3 100644
--- a/README.md
+++ b/README.md
@@ -50,11 +50,12 @@ Runner uses a schedule-based approach:
#### Data Flow
For each pipeline execution:
-1. **Dependency verification**: Check that all required input tables have data within specified time intervals
-2. **Transformation execution**: Run your transformation to produce a Spark DataFrame
-3. **Automatic enrichment**: Runner adds `INFORMATION_DATE` (the run date) and `VERSION` (package version) columns
-4. **Partitioned write**: Data is written to Databricks with partitioning configuration
-5. **Reporting**: Optional email notifications on failure and run information stored to tracking table
+1. **Previous completion check**: Runner checks if the target table partition for the run date already exists (unless `rerun` is enabled)
+2. **Dependency verification**: Check that all required input tables have data within specified time intervals
+3. **Transformation execution**: Run your transformation to produce a Spark DataFrame
+4. **Automatic enrichment**: Runner adds `INFORMATION_DATE` (the run date) and `VERSION` (package version) columns
+5. **Partitioned write**: Data is written to Databricks with partitioning configuration
+6. **Reporting**: Optional email notifications on failure and run information stored to tracking table
#### Dependency Tracking
@@ -66,13 +67,13 @@ Runner's dependency tracking ensures that all required input data is available b
* **Missing data handling**: If required data is missing, Runner raises an error for that specific pipeline/date but continues executing other pipelines and dates in the queue
**Example:** If you're running a pipeline on 2024-01-15 with a dependency that has a 7-day interval:
-* Runner checks if the dependency table has data for 2024-01-08 (15 days - 7 days)
+* Runner checks if the dependency table has data between 2024-01-08 (the run date minus the 7-day interval) and 2024-01-15 (the run date)
* If the dependency has `filters: {VERSION: "v2"}`, it specifically checks for data where VERSION='v2'
* If data exists, the pipeline proceeds; otherwise, an error is raised for this specific execution, but other scheduled runs continue
-### Transformation
+### Job/Transformation
For the details on the interface see the [implementation](rialto/runner/transformation.py)
-Inside the transformation you have access to a [TableReader](#common), date of running, and if provided to Runner, a live spark session and [metadata manager](#metadata).
+Inside the job you have access to a [TableReader](#common), date of running, and if provided to Runner, a live spark session and [metadata manager](#metadata).
You can either implement your jobs directly via extending the Transformation class, or by using the [jobs](#jobs) abstraction.
### Runner
@@ -163,7 +164,7 @@ pipelines: # a list of pipelines to run
python_class: Pipeline2Class
schedule:
frequency: monthly
- day: 6
+ day: 6 # or 'latest' for the last day of the month, otherwise avoid using days higher than 28 to ensure all months are covered
info_date_shift: # can be combined as a list
- units: "days"
value: 5
@@ -434,7 +435,7 @@ With that sorted out, we can now provide a quick example of the *rialto.jobs* mo
from pyspark.sql import DataFrame
from rialto.common import TableReader
from rialto.jobs import config_parser, job, datasource
-from rialto.runner.config_loader import PipelineConfig
+from rialto.runner.services.config_loader import PipelineConfig
from pydantic import BaseModel
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b7592da..e4273bf 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -26,9 +26,9 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "rialto"
-copyright = "2022, Marek Dobransky"
+copyright = "2022-2026, Marek Dobransky"
author = "Marek Dobransky"
-release = "2.2.0"
+release = "2.2.1"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
new file mode 100644
index 0000000..b5c7581
--- /dev/null
+++ b/docs/source/modules.rst
@@ -0,0 +1,7 @@
+rialto
+======
+
+.. toctree::
+ :maxdepth: 4
+
+ rialto
diff --git a/docs/source/rialto.common.rst b/docs/source/rialto.common.rst
index 57e7f9e..6ef2b4c 100644
--- a/docs/source/rialto.common.rst
+++ b/docs/source/rialto.common.rst
@@ -5,33 +5,33 @@ Submodules
----------
rialto.common.env\_yaml module
-----------------------------------
+------------------------------
.. automodule:: rialto.common.env_yaml
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.common.table\_reader module
----------------------------------
.. automodule:: rialto.common.table_reader
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.common.utils module
--------------------------
.. automodule:: rialto.common.utils
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.common
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.jobs.rst b/docs/source/rialto.jobs.rst
index 7af576a..924cb39 100644
--- a/docs/source/rialto.jobs.rst
+++ b/docs/source/rialto.jobs.rst
@@ -5,50 +5,49 @@ Submodules
----------
rialto.jobs.decorators module
-----------------------------------------
+-----------------------------
.. automodule:: rialto.jobs.decorators
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.jobs.job\_base module
----------------------------------------
+----------------------------
.. automodule:: rialto.jobs.job_base
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
-rialto.jobs.module\register module
----------------------------------------
+rialto.jobs.module\_register module
+-----------------------------------
.. automodule:: rialto.jobs.module_register
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.jobs.resolver module
---------------------------------------
+---------------------------
.. automodule:: rialto.jobs.resolver
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.jobs.test\_utils module
------------------------------------------
+------------------------------
.. automodule:: rialto.jobs.test_utils
:members:
- :undoc-members:
:show-inheritance:
-
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.jobs
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.loader.rst b/docs/source/rialto.loader.rst
index 9f258eb..b18b709 100644
--- a/docs/source/rialto.loader.rst
+++ b/docs/source/rialto.loader.rst
@@ -9,30 +9,29 @@ rialto.loader.config\_loader module
.. automodule:: rialto.loader.config_loader
:members:
- :undoc-members:
:show-inheritance:
-
+ :undoc-members:
rialto.loader.interfaces module
-------------------------------
.. automodule:: rialto.loader.interfaces
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.loader.pyspark\_feature\_loader module
---------------------------------------------
.. automodule:: rialto.loader.pyspark_feature_loader
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.loader
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.maker.rst b/docs/source/rialto.maker.rst
index 88a3de7..d3352c3 100644
--- a/docs/source/rialto.maker.rst
+++ b/docs/source/rialto.maker.rst
@@ -9,37 +9,37 @@ rialto.maker.containers module
.. automodule:: rialto.maker.containers
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.maker.feature\_maker module
----------------------------------
.. automodule:: rialto.maker.feature_maker
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.maker.utils module
-------------------------
.. automodule:: rialto.maker.utils
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.maker.wrappers module
----------------------------
.. automodule:: rialto.maker.wrappers
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.maker
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.metadata.data_classes.rst b/docs/source/rialto.metadata.data_classes.rst
index 44353ff..085ae2a 100644
--- a/docs/source/rialto.metadata.data_classes.rst
+++ b/docs/source/rialto.metadata.data_classes.rst
@@ -9,21 +9,21 @@ rialto.metadata.data\_classes.feature\_metadata module
.. automodule:: rialto.metadata.data_classes.feature_metadata
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.metadata.data\_classes.group\_metadata module
----------------------------------------------------
.. automodule:: rialto.metadata.data_classes.group_metadata
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.metadata.data_classes
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.metadata.rst b/docs/source/rialto.metadata.rst
index 08933bf..076e49d 100644
--- a/docs/source/rialto.metadata.rst
+++ b/docs/source/rialto.metadata.rst
@@ -17,29 +17,29 @@ rialto.metadata.enums module
.. automodule:: rialto.metadata.enums
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.metadata.metadata\_manager module
----------------------------------------
.. automodule:: rialto.metadata.metadata_manager
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.metadata.utils module
----------------------------
.. automodule:: rialto.metadata.utils
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.metadata
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.rst b/docs/source/rialto.rst
new file mode 100644
index 0000000..f55ad4c
--- /dev/null
+++ b/docs/source/rialto.rst
@@ -0,0 +1,23 @@
+rialto package
+==============
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ rialto.common
+ rialto.jobs
+ rialto.loader
+ rialto.maker
+ rialto.metadata
+ rialto.runner
+
+Module contents
+---------------
+
+.. automodule:: rialto
+ :members:
+ :show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.runner.reporting.rst b/docs/source/rialto.runner.reporting.rst
index 7ef8db5..543934f 100644
--- a/docs/source/rialto.runner.reporting.rst
+++ b/docs/source/rialto.runner.reporting.rst
@@ -1,5 +1,5 @@
rialto.runner.reporting package
-=====================================
+===============================
Submodules
----------
@@ -9,37 +9,37 @@ rialto.runner.reporting.bookkeeper module
.. automodule:: rialto.runner.reporting.bookkeeper
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.runner.reporting.mailer module
-------------------------------------
.. automodule:: rialto.runner.reporting.mailer
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.runner.reporting.record module
-------------------------------------
.. automodule:: rialto.runner.reporting.record
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.runner.reporting.tracker module
--------------------------------------
.. automodule:: rialto.runner.reporting.tracker
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.runner.reporting
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.runner.rst b/docs/source/rialto.runner.rst
index c4ec20a..2666ec6 100644
--- a/docs/source/rialto.runner.rst
+++ b/docs/source/rialto.runner.rst
@@ -8,70 +8,55 @@ Subpackages
:maxdepth: 4
rialto.runner.reporting
+ rialto.runner.services
Submodules
----------
-rialto.runner.config\_loader module
------------------------------------
+rialto.runner.engine module
+---------------------------
-.. automodule:: rialto.runner.config_loader
+.. automodule:: rialto.runner.engine
:members:
- :undoc-members:
:show-inheritance:
-
-rialto.runner.config\_overrides module
---------------------------------------
-
-.. automodule:: rialto.runner.config_overrides
- :members:
:undoc-members:
- :show-inheritance:
-
-rialto.runner.date\_manager module
-----------------------------------
-
-.. automodule:: rialto.runner.date_manager
- :members:
- :undoc-members:
- :show-inheritance:
rialto.runner.runner module
---------------------------
.. automodule:: rialto.runner.runner
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
-rialto.runner.table module
---------------------------
+rialto.runner.runner\_services module
+-------------------------------------
-.. automodule:: rialto.runner.table
+.. automodule:: rialto.runner.runner_services
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.runner.transformation module
-----------------------------------
.. automodule:: rialto.runner.transformation
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
rialto.runner.utils module
------------------------------------
+--------------------------
.. automodule:: rialto.runner.utils
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
Module contents
---------------
.. automodule:: rialto.runner
:members:
- :undoc-members:
:show-inheritance:
+ :undoc-members:
diff --git a/docs/source/rialto.runner.services.rst b/docs/source/rialto.runner.services.rst
new file mode 100644
index 0000000..158db5f
--- /dev/null
+++ b/docs/source/rialto.runner.services.rst
@@ -0,0 +1,93 @@
+rialto.runner.services package
+==============================
+
+Submodules
+----------
+
+rialto.runner.services.config\_loader module
+--------------------------------------------
+
+.. automodule:: rialto.runner.services.config_loader
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.config\_overrides module
+-----------------------------------------------
+
+.. automodule:: rialto.runner.services.config_overrides
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.data\_checker module
+-------------------------------------------
+
+.. automodule:: rialto.runner.services.data_checker
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.date\_manager module
+-------------------------------------------
+
+.. automodule:: rialto.runner.services.date_manager
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.executor module
+--------------------------------------
+
+.. automodule:: rialto.runner.services.executor
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.result\_mapper module
+--------------------------------------------
+
+.. automodule:: rialto.runner.services.result_mapper
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.table module
+-----------------------------------
+
+.. automodule:: rialto.runner.services.table
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.task\_registry module
+--------------------------------------------
+
+.. automodule:: rialto.runner.services.task_registry
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.task\_status\_checker module
+---------------------------------------------------
+
+.. automodule:: rialto.runner.services.task_status_checker
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+rialto.runner.services.writer module
+------------------------------------
+
+.. automodule:: rialto.runner.services.writer
+ :members:
+ :show-inheritance:
+ :undoc-members:
+
+Module contents
+---------------
+
+.. automodule:: rialto.runner.services
+ :members:
+ :show-inheritance:
+ :undoc-members:
diff --git a/pyproject.toml b/pyproject.toml
index 71e9217..9d4404a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "rialto"
-version = "2.2.0"
+version = "2.2.1"
description = "Rialto is a framework for building and deploying machine learning features in a scalable and reusable way. It provides a set of tools that make it easy to define and deploy features and models, and it provides a way to orchestrate the execution of these features and models."
authors = [
{ name = "Marek Dobransky", email = "marekdobr@gmail.com" },
diff --git a/rialto/__init__.py b/rialto/__init__.py
index 79c3773..94ab807 100644
--- a/rialto/__init__.py
+++ b/rialto/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/common/__init__.py b/rialto/common/__init__.py
index 1bd5055..cf821c6 100644
--- a/rialto/common/__init__.py
+++ b/rialto/common/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py
index 228d59b..6165e68 100644
--- a/rialto/common/table_reader.py
+++ b/rialto/common/table_reader.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -71,6 +71,16 @@ def get_table(
"""
raise NotImplementedError
+ @abc.abstractmethod
+ def table_exists(self, table: str) -> bool:
+ """
+ Check whether the table exists in storage
+
+ :param table: full table path
+ :return: True if the table exists, False otherwise
+ """
+ raise NotImplementedError
+
class TableReader(DataReader):
"""An implementation of data reader for databricks tables"""
@@ -165,3 +175,12 @@ def get_table(
if uppercase_columns:
df = self._uppercase_column_names(df)
return df
+
+ def table_exists(self, table: str) -> bool:
+ """
+ Check whether the table exists in the Spark catalog
+
+ :param table: full table path
+ :return: True if the table exists, False otherwise
+ """
+ return self.spark.catalog.tableExists(table)
diff --git a/rialto/common/utils.py b/rialto/common/utils.py
index 541e977..7bbd00f 100644
--- a/rialto/common/utils.py
+++ b/rialto/common/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py
index 46eb756..d04e954 100644
--- a/rialto/jobs/__init__.py
+++ b/rialto/jobs/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/jobs/decorators.py b/rialto/jobs/decorators.py
index 8affcbd..b0ff1cb 100644
--- a/rialto/jobs/decorators.py
+++ b/rialto/jobs/decorators.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/jobs/job_base.py b/rialto/jobs/job_base.py
index eae0854..ccd5b41 100644
--- a/rialto/jobs/job_base.py
+++ b/rialto/jobs/job_base.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@
from rialto.loader import PysparkFeatureLoader
from rialto.metadata import MetadataManager
from rialto.runner import Transformation
-from rialto.runner.config_loader import PipelineConfig
+from rialto.runner.services.config_loader import PipelineConfig
class JobMetadata(BaseModel):
diff --git a/rialto/jobs/module_register.py b/rialto/jobs/module_register.py
index 49021ff..5d32471 100644
--- a/rialto/jobs/module_register.py
+++ b/rialto/jobs/module_register.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/jobs/resolver.py b/rialto/jobs/resolver.py
index 34b08e8..f9dfe21 100644
--- a/rialto/jobs/resolver.py
+++ b/rialto/jobs/resolver.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/jobs/test_utils.py b/rialto/jobs/test_utils.py
index cced2fe..754c1ee 100644
--- a/rialto/jobs/test_utils.py
+++ b/rialto/jobs/test_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/loader/__init__.py b/rialto/loader/__init__.py
index 7e1e936..543ec90 100644
--- a/rialto/loader/__init__.py
+++ b/rialto/loader/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/loader/config_loader.py b/rialto/loader/config_loader.py
index ead2705..06dffea 100644
--- a/rialto/loader/config_loader.py
+++ b/rialto/loader/config_loader.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/loader/interfaces.py b/rialto/loader/interfaces.py
index 9089f40..6b743a6 100644
--- a/rialto/loader/interfaces.py
+++ b/rialto/loader/interfaces.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/loader/pyspark_feature_loader.py b/rialto/loader/pyspark_feature_loader.py
index bd5f884..c8542a8 100644
--- a/rialto/loader/pyspark_feature_loader.py
+++ b/rialto/loader/pyspark_feature_loader.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/maker/__init__.py b/rialto/maker/__init__.py
index d31cd4a..86e6f34 100644
--- a/rialto/maker/__init__.py
+++ b/rialto/maker/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/maker/containers.py b/rialto/maker/containers.py
index 9a93417..851e627 100644
--- a/rialto/maker/containers.py
+++ b/rialto/maker/containers.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/maker/feature_maker.py b/rialto/maker/feature_maker.py
index 5fad175..f9201c1 100644
--- a/rialto/maker/feature_maker.py
+++ b/rialto/maker/feature_maker.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/maker/utils.py b/rialto/maker/utils.py
index 829dc60..89faee7 100644
--- a/rialto/maker/utils.py
+++ b/rialto/maker/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/maker/wrappers.py b/rialto/maker/wrappers.py
index a7a4103..5e0748e 100644
--- a/rialto/maker/wrappers.py
+++ b/rialto/maker/wrappers.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/metadata/__init__.py b/rialto/metadata/__init__.py
index 5e8893c..950221d 100644
--- a/rialto/metadata/__init__.py
+++ b/rialto/metadata/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/metadata/data_classes/__init__.py b/rialto/metadata/data_classes/__init__.py
index 79c3773..94ab807 100644
--- a/rialto/metadata/data_classes/__init__.py
+++ b/rialto/metadata/data_classes/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/metadata/data_classes/feature_metadata.py b/rialto/metadata/data_classes/feature_metadata.py
index cff0039..939bf1d 100644
--- a/rialto/metadata/data_classes/feature_metadata.py
+++ b/rialto/metadata/data_classes/feature_metadata.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/metadata/data_classes/group_metadata.py b/rialto/metadata/data_classes/group_metadata.py
index 5b1eb9c..0c8136c 100644
--- a/rialto/metadata/data_classes/group_metadata.py
+++ b/rialto/metadata/data_classes/group_metadata.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,5 +82,5 @@ def from_spark(cls, schema: Row) -> Self:
frequency=Schedule[schema.group_frequency],
description=schema.group_description,
key=schema.group_key,
- owner=schema.group_owner
+ owner=schema.group_owner,
)
diff --git a/rialto/metadata/enums.py b/rialto/metadata/enums.py
index be9ea5b..e910552 100644
--- a/rialto/metadata/enums.py
+++ b/rialto/metadata/enums.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/metadata/metadata_manager.py b/rialto/metadata/metadata_manager.py
index 90312f3..474bbda 100644
--- a/rialto/metadata/metadata_manager.py
+++ b/rialto/metadata/metadata_manager.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/metadata/utils.py b/rialto/metadata/utils.py
index 0cb591c..80d2dab 100644
--- a/rialto/metadata/utils.py
+++ b/rialto/metadata/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
def class_to_catalog_name(class_name) -> str:
"""
- Map python class name of feature group (CammelCase) to databricks compatible format (lowercase with underscores)
+ Map python class name of feature group (CamelCase) to databricks compatible format (lowercase with underscores)
:param class_name: Python class name
:return: feature storage name
diff --git a/rialto/runner/__init__.py b/rialto/runner/__init__.py
index 6ae343f..a75d1e1 100644
--- a/rialto/runner/__init__.py
+++ b/rialto/runner/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py
deleted file mode 100644
index 1bcef7b..0000000
--- a/rialto/runner/date_manager.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright 2022 ABSA Group Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-__all__ = ["DateManager"]
-
-from datetime import date, datetime
-from typing import List
-
-from dateutil.relativedelta import relativedelta
-
-from rialto.runner.config_loader import ScheduleConfig
-
-
-class DateManager:
- """Date generation and shifts based on configuration"""
-
- @staticmethod
- def str_to_date(str_date: str) -> date:
- """
- Convert YYYY-MM-DD string to date
-
- :param str_date: string date
- :return: date
- """
- return datetime.strptime(str_date, "%Y-%m-%d").date()
-
- @staticmethod
- def date_subtract(run_date: date, units: str, value: int) -> date:
- """
- Generate starting date from given date and config
-
- :param run_date: base date
- :param units: units: years, months, weeks, days
- :param value: number of units to subtract
- :return: Starting date
- """
- if units == "years":
- return run_date - relativedelta(years=value)
- if units == "months":
- return run_date - relativedelta(months=value)
- if units == "weeks":
- return run_date - relativedelta(weeks=value)
- if units == "days":
- return run_date - relativedelta(days=value)
- raise ValueError(f"Unknown time unit {units}")
-
- @staticmethod
- def all_dates(date_from: date, date_to: date) -> List[date]:
- """
- Get list of all dates between, inclusive
-
- :param date_from: starting date
- :param date_to: ending date
- :return: List[date]
- """
- if date_to < date_from:
- date_to, date_from = date_from, date_to
-
- return [date_from + relativedelta(days=n) for n in range((date_to - date_from).days + 1)]
-
- @staticmethod
- def run_dates(date_from: date, date_to: date, schedule: ScheduleConfig) -> List[date]:
- """
- Select dates inside given interval depending on frequency and selected day
-
- :param date_from: interval start
- :param date_to: interval end
- :param schedule: schedule config
- :return: list of dates
- """
- options = DateManager.all_dates(date_from, date_to)
- if schedule.frequency == "daily":
- return options
- if schedule.frequency == "weekly":
- return [x for x in options if x.isoweekday() == schedule.day]
- if schedule.frequency == "monthly":
- return [x for x in options if x.day == schedule.day]
- raise ValueError(f"Unknown frequency {schedule.frequency}")
-
- @staticmethod
- def to_info_date(date: date, schedule: ScheduleConfig) -> date:
- """
- Shift given date according to config
-
- :param date: input date
- :param schedule: schedule config
- :return: date
- """
- if isinstance(schedule.info_date_shift, List):
- for shift in schedule.info_date_shift:
- date = DateManager.date_subtract(date, units=shift.units, value=shift.value)
- else:
- date = DateManager.date_subtract(
- date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value
- )
- return date
diff --git a/rialto/runner/engine.py b/rialto/runner/engine.py
new file mode 100644
index 0000000..43dcc39
--- /dev/null
+++ b/rialto/runner/engine.py
@@ -0,0 +1,159 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["RunnerEngine"]
+
+import traceback
+from datetime import datetime
+from typing import List
+
+from loguru import logger
+from pyspark.sql import DataFrame
+
+from rialto.runner.runner_services import RunnerServices
+from rialto.runner.services.config_loader import PipelineConfig
+from rialto.runner.services.result_mapper import TaskResultMapper
+from rialto.runner.services.task_registry import PipelineTask
+
+
+class RunnerEngine:
+    """Orchestrates pipeline execution lifecycle and task tracking"""
+
+    def __init__(self, services: RunnerServices, rerun: bool, skip_dependencies: bool):
+        self.services = services
+        self.rerun = rerun
+        self.skip_dependencies = skip_dependencies
+
+    def select_pipelines(self, op: str = None) -> List[PipelineConfig]:
+        """Select pipelines to run: those named op, or all when op is empty (ValueError if op is unknown)"""
+        if op:
+            selected = [p for p in self.services.config.pipelines if p.name == op]
+            if not selected:
+                raise ValueError(f"Unknown operation selected: {op}")
+            return selected
+        return self.services.config.pipelines
+
+    def register_tasks(self, pipelines: List[PipelineConfig]) -> None:
+        """Register one task per pipeline and (execution date, partition date) pair"""
+        for pipeline in pipelines:
+            for exec_date, partition_date in self.services.date_manager.get_execution_and_partition_dates(
+                pipeline.schedule
+            ):
+                self.services.registry.add_task(
+                    name=pipeline.name,
+                    execution_date=exec_date,
+                    partition_date=partition_date,
+                    config=pipeline,
+                )
+
+    def check_tasks(self) -> None:
+        """Run completion checks (unless rerun) and dependency checks (unless skipped) for every task"""
+        for task in self.services.registry.tasks:
+            if not self.rerun:
+                self._precheck(task, "completion", self.services.task_checker.check_completion)
+            if not self.skip_dependencies:
+                self._precheck(task, "dependency", self.services.task_checker.check_pipeline_dependencies)
+
+    def _precheck(self, task: PipelineTask, label: str, check) -> None:
+        """Run one pre-run check; a failure is recorded on the task so the remaining tasks still run"""
+        try:
+            check(task)
+        except Exception as e:
+            logger.error(f"{task.name} {label} check failed for {task.partition_date}:\n\t{e}")
+            task.precheck_failed = True
+            task.error = str(e)
+            task.error_trace = traceback.format_exc()
+
+    def log_task_status(self) -> None:
+        """Log summary of task statuses"""
+        self.services.registry.log_status()
+
+    def run_tasks(self) -> None:
+        """Execute runnable tasks with per-task error isolation"""
+        for task in self.services.registry.tasks:
+            logger.info(f"Executing task {task.name} for partition date {task.partition_date}")
+            self._execute_task_with_tracking(task)
+
+    def _execute_task_with_tracking(self, task: PipelineTask) -> None:
+        """Execute single task and add exactly one record to the tracker, whatever the outcome"""
+        run_start = datetime.now()
+
+        if task.precheck_failed:
+            self.services.tracker.add(TaskResultMapper.exception(task, run_start, task.error, task.error_trace))
+            return
+
+        # Skip already-complete tasks
+        if task.completion and not self.rerun:
+            logger.info(f"Skipping task {task.name} for partition {task.partition_date} - already complete")
+            self.services.tracker.add(TaskResultMapper.already_complete(task, run_start))
+            return
+
+        # Skip if dependencies not met
+        incomplete_deps = [
+            f"{dep.table.get_table_path()} from {dep.date_from} until {dep.date_until}"
+            for dep in task.dependencies
+            if not dep.complete
+        ]
+        if incomplete_deps and not self.skip_dependencies:
+            logger.info(
+                f"Incomplete dependencies for task {task.name} for "
+                f"partition {task.partition_date} - {', '.join(incomplete_deps)}"
+            )
+            self.services.tracker.add(TaskResultMapper.dependencies_incomplete(task, run_start, incomplete_deps))
+            return
+
+        # Execute task; KeyboardInterrupt is recorded and re-raised, any other error is recorded and swallowed
+        try:
+            df = self.services.executor.execute(task)
+            self.services.writer.write(df, task.partition_date, task.target)
+            records = self.services.data_checker.check_written(task.target, task.partition_date, df)
+            logger.info(
+                f"Task {task.name} for partition {task.partition_date} completed successfully with {records} records"
+            )
+            self.services.tracker.add(TaskResultMapper.success(task, run_start, records))
+        except KeyboardInterrupt:
+            self.services.tracker.add(TaskResultMapper.interrupted(task, run_start))
+            raise
+        except Exception as e:
+            logger.exception(f"Task {task.name} failed for partition {task.partition_date}")
+            self.services.tracker.add(TaskResultMapper.exception(task, run_start, str(e), traceback.format_exc()))
+
+    def finalize(self) -> None:
+        """Send final reports via mail/bookkeeping"""
+        self.services.tracker.report_by_mail()
+        self.log_task_status()
+
+    def run(self, op: str = None) -> None:
+        """Execute all tasks; finalize() runs even when a step raises or the user interrupts"""
+        try:
+            self.register_tasks(self.select_pipelines(op))
+            self.check_tasks()
+            self.log_task_status()
+            self.run_tasks()
+        finally:
+            self.finalize()
+
+    def dry_run_execution(self, op: str = None) -> None:
+        """Execute pre-run checks without task execution"""
+        self.register_tasks(self.select_pipelines(op))
+        self.check_tasks()
+        self.log_task_status()
+
+    def debug_first_task(self, op: str = None) -> DataFrame:
+        """Debug mode: execute first task and return result, or None when no task is registered"""
+        self.register_tasks(self.select_pipelines(op))
+        if not self.services.registry.tasks:
+            logger.info("No dates to run in debug mode")
+            return None
+        return self.services.executor.execute(self.services.registry.tasks[0])
diff --git a/rialto/runner/reporting/bookkeeper.py b/rialto/runner/reporting/bookkeeper.py
index f00d0bc..e06cbff 100644
--- a/rialto/runner/reporting/bookkeeper.py
+++ b/rialto/runner/reporting/bookkeeper.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/runner/reporting/mailer.py b/rialto/runner/reporting/mailer.py
index 1485a57..46c8bf3 100644
--- a/rialto/runner/reporting/mailer.py
+++ b/rialto/runner/reporting/mailer.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ class HTMLMessage:
@staticmethod
def _get_status_color(status: str):
- if status == "Success":
+ if status == "Success" or status == "Skipped":
return "#398f00"
elif status == "Error":
return "#ff0000"
@@ -128,10 +128,6 @@ def _head():
-
"""
@staticmethod
@@ -164,14 +160,10 @@ def _make_exceptions(records: List[Record]):
- Expand
-
-
-
-
{record.exception}
-
-
-
+
+
+
{record.exception}
+
"""
html += r
return html
diff --git a/rialto/runner/reporting/record.py b/rialto/runner/reporting/record.py
index 7099b2e..bd6ab4d 100644
--- a/rialto/runner/reporting/record.py
+++ b/rialto/runner/reporting/record.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/runner/reporting/tracker.py b/rialto/runner/reporting/tracker.py
index 0f926cc..8adc796 100644
--- a/rialto/runner/reporting/tracker.py
+++ b/rialto/runner/reporting/tracker.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,10 +18,10 @@
from pyspark.sql import SparkSession
-from rialto.runner.config_loader import MailConfig
from rialto.runner.reporting.bookkeeper import BookKeeper
from rialto.runner.reporting.mailer import HTMLMessage, Mailer
from rialto.runner.reporting.record import Record
+from rialto.runner.services.config_loader import MailConfig
class Tracker:
@@ -29,9 +29,7 @@ class Tracker:
def __init__(self, mail_cfg: MailConfig, bookkeeping: str = None, spark: SparkSession = None):
self.records = []
- self.last_error = None
self.pipeline_start = datetime.now()
- self.exceptions = []
self.mail_cfg = mail_cfg
self.bookkeeper = None
diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py
index 384998f..dec6045 100644
--- a/rialto/runner/runner.py
+++ b/rialto/runner/runner.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,26 +14,16 @@
__all__ = ["Runner"]
-import datetime
-from datetime import date
-from typing import Dict, List, Tuple
+from typing import Dict
-from loguru import logger
from pyspark.sql import DataFrame, SparkSession
-import rialto.runner.utils as utils
-from rialto.common import TableReader
-from rialto.runner.config_loader import PipelineConfig, get_pipelines_config
-from rialto.runner.date_manager import DateManager
-from rialto.runner.reporting.record import Record
-from rialto.runner.reporting.tracker import Tracker
-from rialto.runner.table import Table
-from rialto.runner.transformation import Transformation
-from rialto.runner.writer import Writer
+from rialto.runner.engine import RunnerEngine
+from rialto.runner.runner_services import DefaultRunnerServices, RunnerServices
class Runner:
- """A scheduler and dependency checker for feature runs"""
+ """Entry point for pipeline execution orchestration (beginner-friendly API)"""
def __init__(
self,
@@ -45,312 +35,43 @@ def __init__(
skip_dependencies: bool = False,
overrides: Dict = None,
merge_schema: bool = False,
+ services: RunnerServices = None,
):
- self.spark = spark
- self.config = get_pipelines_config(config_path, overrides)
- self.reader = TableReader(spark)
- self.rerun = rerun
- self.skip_dependencies = skip_dependencies
- self.op = op
- self.writer = Writer(spark, merge_schema=merge_schema)
- self.tracker = Tracker(
- mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark
- )
-
- if run_date:
- run_date = DateManager.str_to_date(run_date)
- else:
- run_date = date.today()
-
- self.date_from = DateManager.date_subtract(
- run_date=run_date,
- units=self.config.runner.watched_period_units,
- value=self.config.runner.watched_period_value,
- )
-
- self.date_until = run_date
-
- if self.date_from > self.date_until:
- raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}")
- logger.info(f"Running period set to: {self.date_from} - {self.date_until}")
-
- def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame:
- """
- Run the job
-
- :param instance: Instance of Transformation
- :param run_date: date to run for
- :param pipeline: pipeline configuration
- :return: Dataframe
"""
- metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline)
-
- df = instance.run(
- spark=self.spark,
+ Initialize Runner for pipeline orchestration.
+
+ :param spark: SparkSession instance
+ :param config_path: Path to pipeline configuration YAML
+ :param run_date: Override run date (optional)
+ :param rerun: Force re-execution of completed tasks
+ :param op: Target specific pipeline by name (optional)
+ :param skip_dependencies: Skip dependency validation
+ :param overrides: Configuration overrides
+ :param merge_schema: Enable schema merging in writer
+ :param services: Custom RunnerServices bundle (optional, for advanced users)
+ """
+ self._services = services or DefaultRunnerServices.build(
+ spark=spark,
+ config_path=config_path,
run_date=run_date,
- config=pipeline,
- reader=self.reader,
- metadata_manager=metadata_manager,
- feature_loader=feature_loader,
+ merge_schema=merge_schema,
+ overrides=overrides,
)
-
- return df
-
- def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int:
- """
- Check if there are records written for given date
-
- :param info_date: date to check
- :param table: target table object
- :return: number of records
- """
- filters = {}
- if pipeline.target.rerun_filters is not None:
- filters = pipeline.target.rerun_filters
- else:
- if table.secondary_partitions:
- row = df.select(*table.secondary_partitions).distinct().collect()[0]
- for c in table.secondary_partitions:
- val = row[0][c]
- filters[c] = val
-
- df = self.reader.get_table(
- table.get_table_path(), date_column=table.partition, date_from=info_date, date_to=info_date, filters=filters
+ self._engine = RunnerEngine(
+ services=self._services,
+ rerun=rerun,
+ skip_dependencies=skip_dependencies,
)
-
- return df.count()
-
- def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]:
- """
- For given list of dates, check if there is a matching partition for each
-
- :param table: Table object
- :param dates: list of dates to check
- :return: list of bool
- """
- if utils.table_exists(self.spark, table.get_table_path()):
- checks = []
- for check_date in dates:
- df = self.reader.get_table(
- table.get_table_path(),
- date_column=table.partition,
- date_from=check_date,
- date_to=check_date,
- filters=target_filters,
- )
- data_exists = df.count() > 0
- if data_exists and target_filters is None and table.secondary_partitions is not None:
- # ensure rerun if the write consideres secondary partitions but the filter doesn't
- data_exists = False
- checks.append(data_exists)
- return checks
- else:
- logger.info(f"Table {table.get_table_path()} doesn't exist!")
- return [False for _ in dates]
-
- def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool:
- """
- Check for all dependencies in config if they have available partitions
-
- :param pipeline: configuration
- :param run_date: run date
- :return: bool
- """
- logger.info(f"{pipeline.name} checking dependencies for {run_date}")
-
- error = ""
-
- for dependency in pipeline.dependencies:
- dep_from = DateManager.date_subtract(run_date, dependency.interval.units, dependency.interval.value)
- logger.info(f"Looking for {dependency.table} from {dep_from} until {run_date}")
-
- possible_dep_dates = DateManager.all_dates(dep_from, run_date)
-
- logger.debug(f"Date column for {dependency.table} is {dependency.date_col}")
-
- source = Table(table_path=dependency.table, partition=dependency.date_col)
- if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters):
- logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled")
- else:
- msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}"
- logger.info(msg)
- error = error + msg + "\n"
-
- if error != "":
- self.tracker.last_error = error
- return False
-
- return True
-
- def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]:
- """
- Check if model has run for given dates
-
- :param target_path: Table object
- :param info_dates: list of dates
- :return: bool list
- """
- if self.rerun:
- return [False for _ in info_dates]
- else:
- return self.check_dates_have_data(target, info_dates, filters)
-
- def _select_run_dates(self, pipeline: PipelineConfig, table: Table, filters: Dict = None) -> Tuple[List, List]:
- """
- Select run dates and info dates based on completion
-
- :param pipeline: pipeline config
- :param table: table path
- :return: list of run dates and list of info dates
- """
- possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule)
- possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates]
- current_state = self._get_completion(table, possible_info_dates, filters)
-
- selection = [
- (run, info) for run, info, state in zip(possible_run_dates, possible_info_dates, current_state) if not state
- ]
-
- if not len(selection):
- logger.info(f"{pipeline.name} has no dates to run")
- return [], []
-
- selected_run_dates, selected_info_dates = zip(*selection)
- logger.info(f"{pipeline.name} identified to run for {selected_run_dates}")
-
- return list(selected_run_dates), list(selected_info_dates)
-
- def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: date, target: Table) -> int:
- """
- Run one pipeline for one date
-
- :param pipeline: pipeline cfg
- :param run_date: run date
- :param info_date: information date
- :param target: target Table
- :return: success bool
- """
- if self.skip_dependencies or self.check_dependencies(pipeline, run_date):
- logger.info(f"Running {pipeline.name} for {run_date}")
-
- feature_group = utils.load_module(pipeline.module)
- df = self._execute(feature_group, run_date, pipeline)
- self.writer.write(df, info_date, target)
- records = self._check_written(info_date, target, df, pipeline)
- logger.info(f"Generated {records} records")
- if records == 0:
- raise RuntimeError("No records generated")
- else:
- return records
- return 0
-
- def _run_pipeline(self, pipeline: PipelineConfig):
- """
- Run single pipeline for all required dates
-
- :param pipeline: pipeline cfg
- :return: success bool
- """
- target = Table(
- schema_path=pipeline.target.target_schema,
- class_name=pipeline.module.python_class,
- partition=pipeline.target.target_partition_column,
- secondary_partitions=pipeline.target.secondary_partition_columns,
- table=pipeline.target.custom_name,
- )
- logger.info(f"Loaded pipeline {pipeline.name}")
-
- selected_run_dates, selected_info_dates = self._select_run_dates(
- pipeline, target, pipeline.target.rerun_filters
- )
-
- # ----------- Checking dependencies available ----------
- for run_date, info_date in zip(selected_run_dates, selected_info_dates):
- run_start = datetime.datetime.now()
- try:
- records = self._run_one_date(pipeline, run_date, info_date, target)
- if records > 0:
- status = "Success"
- message = ""
- else:
- status = "Failure"
- message = self.tracker.last_error
- self.tracker.add(
- Record(
- job=pipeline.name,
- target=target.get_table_path(),
- date=info_date,
- time=datetime.datetime.now() - run_start,
- records=records,
- status=status,
- reason=message,
- )
- )
- except Exception as error:
- logger.error(f"An exception occurred in pipeline {pipeline.name}")
- logger.exception(error)
- self.tracker.add(
- Record(
- job=pipeline.name,
- target=target.get_table_path(),
- date=info_date,
- time=datetime.datetime.now() - run_start,
- records=0,
- status="Error",
- reason="Exception",
- exception=str(error),
- )
- )
- except KeyboardInterrupt:
- logger.error(f"Pipeline {pipeline.name} interrupted")
- self.tracker.add(
- Record(
- job=pipeline.name,
- target=target.get_table_path(),
- date=info_date,
- time=datetime.datetime.now() - run_start,
- records=0,
- status="Error",
- reason="Interrupted by user",
- )
- )
- raise KeyboardInterrupt
+ self.op = op
def __call__(self):
"""Execute pipelines"""
- logger.info("Executing pipelines")
- try:
- if self.op:
- selected = [p for p in self.config.pipelines if p.name == self.op]
- if len(selected) < 1:
- raise ValueError(f"Unknown operation selected: {self.op}")
- self._run_pipeline(selected[0])
- else:
- for pipeline in self.config.pipelines:
- self._run_pipeline(pipeline)
- finally:
- print(self.tracker.records)
- self.tracker.report_by_mail()
- logger.info("Execution finished")
+ self._engine.run(self.op)
- def debug(self) -> DataFrame:
- """Debug mode - run only first op for one date and return the resulting dataframe"""
- logger.info("Running in debug mode")
- if self.op:
- pipeline = [p for p in self.config.pipelines if p.name == self.op][0]
- else:
- pipeline = self.config.pipelines[0]
+ def dry_run(self):
+ """Dry run - register tasks, run the pre-run checks and log task status without executing any task"""
+ self._engine.dry_run_execution(self.op)
- target = Table(
- schema_path=pipeline.target.target_schema,
- class_name=pipeline.module.python_class,
- partition=pipeline.target.target_partition_column,
- secondary_partitions=pipeline.target.secondary_partition_columns,
- table=pipeline.target.custom_name,
- )
- selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target)
- if len(selected_run_dates) > 0:
- df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline)
- return self.writer._process(df, selected_info_dates[0], target)
- else:
- logger.info("No dates to run in debug mode")
+ def _debug(self) -> DataFrame:
+ """Debug mode - execute only the first registered task (no pre-run checks) and return its dataframe"""
+ return self._engine.debug_first_task(self.op)
diff --git a/rialto/runner/runner_services.py b/rialto/runner/runner_services.py
new file mode 100644
index 0000000..9eea552
--- /dev/null
+++ b/rialto/runner/runner_services.py
@@ -0,0 +1,91 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["RunnerServices", "DefaultRunnerServices"]
+
+from dataclasses import dataclass
+
+from pyspark.sql import SparkSession
+
+from rialto.common import TableReader
+from rialto.runner.reporting.tracker import Tracker
+from rialto.runner.services.config_loader import ConfigLoader, PipelinesConfig
+from rialto.runner.services.data_checker import DataChecker
+from rialto.runner.services.date_manager import DateManager
+from rialto.runner.services.executor import PipelineExecutor
+from rialto.runner.services.task_registry import TaskRegistry
+from rialto.runner.services.task_status_checker import TaskStatusChecker
+from rialto.runner.services.writer import DatabricksWriter
+
+
+@dataclass
+class RunnerServices:
+ """Bundle of collaborator services for Runner orchestration"""
+
+ config: PipelinesConfig  # parsed configuration; RunnerEngine reads config.pipelines
+ date_manager: DateManager  # yields (execution date, partition date) pairs for a pipeline schedule
+ writer: DatabricksWriter  # writes a task's DataFrame to its target for a partition date
+ data_checker: DataChecker  # counts the records present in the target after a write
+ task_checker: TaskStatusChecker  # completion and dependency pre-run checks
+ registry: TaskRegistry  # stores registered tasks and logs their status
+ executor: PipelineExecutor  # executes a task and returns its DataFrame
+ tracker: Tracker  # collects per-task records and sends the mail report
+
+
+class DefaultRunnerServices:
+ """Factory that wires the default RunnerServices bundle from a config file and a SparkSession"""
+
+ @staticmethod
+ def build(
+ spark: SparkSession,
+ config_path: str,
+ run_date: str = None,
+ merge_schema: bool = False,
+ overrides: dict = None,
+ ) -> RunnerServices:
+ """
+ Build default services for Runner.
+
+ :param spark: SparkSession instance, shared by the writer, reader, registry, executor and tracker
+ :param config_path: Path to pipeline configuration YAML
+ :param run_date: Override run date (optional), passed to DateManager
+ :param merge_schema: Enable schema merging in writer
+ :param overrides: Configuration overrides, passed to ConfigLoader.load_yaml
+ :return: RunnerServices bundle
+ """
+ config = ConfigLoader.load_yaml(config_path, overrides)
+ date_manager = DateManager(config.runner, run_date)
+ writer = DatabricksWriter(spark, merge_schema=merge_schema)
+
+ reader = TableReader(spark)  # one reader instance, used by both the data checker and the executor
+ data_checker = DataChecker(reader)
+ task_checker = TaskStatusChecker(data_checker)  # pre-run checks go through the same data checker
+ registry = TaskRegistry(spark, date_manager=date_manager)
+ executor = PipelineExecutor(spark=spark, reader=reader, checker=data_checker)
+ tracker = Tracker(
+ mail_cfg=config.runner.mail,
+ bookkeeping=config.runner.bookkeeping,
+ spark=spark,
+ )
+
+ return RunnerServices(
+ config=config,
+ date_manager=date_manager,
+ writer=writer,
+ data_checker=data_checker,
+ task_checker=task_checker,
+ registry=registry,
+ executor=executor,
+ tracker=tracker,
+ )
diff --git a/rialto/runner/services/__init__.py b/rialto/runner/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/rialto/runner/config_loader.py b/rialto/runner/services/config_loader.py
similarity index 72%
rename from rialto/runner/config_loader.py
rename to rialto/runner/services/config_loader.py
index 7978ac5..64801fb 100644
--- a/rialto/runner/config_loader.py
+++ b/rialto/runner/services/config_loader.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,15 @@
# limitations under the License.
__all__ = [
- "get_pipelines_config",
+ "ConfigLoader",
]
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
from rialto.common.utils import load_yaml
-from rialto.runner.config_overrides import override_config
+from rialto.runner.services.config_overrides import override_config
class BaseConfig(BaseModel):
@@ -35,8 +35,10 @@ class IntervalConfig(BaseConfig):
class ScheduleConfig(BaseConfig):
frequency: str
- day: Optional[int] = 0
- info_date_shift: Optional[List[IntervalConfig]] = IntervalConfig(units="days", value=0)
+ day: Optional[Union[int, str]] = 0
+ info_date_shift: Optional[Union[IntervalConfig, List[IntervalConfig]]] = Field(
+ default_factory=lambda: IntervalConfig(units="days", value=0)
+ )
class DependencyConfig(BaseConfig):
@@ -88,16 +90,16 @@ class PipelineConfig(BaseConfig):
name: str
module: ModuleConfig
schedule: ScheduleConfig
- dependencies: Optional[List[DependencyConfig]] = []
- target: TargetConfig = None
+ dependencies: Optional[List[DependencyConfig]] = Field(default_factory=list)
+ target: Optional[TargetConfig] = None
metadata_manager: Optional[MetadataManagerConfig] = None
feature_loader: Optional[FeatureLoaderConfig] = None
- extras: Optional[Dict] = {}
+ extras: Optional[Dict] = Field(default_factory=dict)
class PipelinesConfig(BaseConfig):
runner: RunnerConfig
- pipelines: list[PipelineConfig]
+ pipelines: List[PipelineConfig]
def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig:
@@ -108,3 +110,12 @@ def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig:
return PipelinesConfig(**cfg)
else:
return PipelinesConfig(**raw_config)
+
+
+class ConfigLoader:
+ """Class-style entry point for loading the pipelines config; delegates to get_pipelines_config"""
+
+ @staticmethod
+ def load_yaml(path: str, overrides: Dict) -> PipelinesConfig:
+ """Return the PipelinesConfig that get_pipelines_config builds from the yaml at path and the overrides"""
+ return get_pipelines_config(path, overrides)
diff --git a/rialto/runner/config_overrides.py b/rialto/runner/services/config_overrides.py
similarity index 98%
rename from rialto/runner/config_overrides.py
rename to rialto/runner/services/config_overrides.py
index cd81232..9350bca 100644
--- a/rialto/runner/config_overrides.py
+++ b/rialto/runner/services/config_overrides.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/rialto/runner/services/data_checker.py b/rialto/runner/services/data_checker.py
new file mode 100644
index 0000000..6e55e94
--- /dev/null
+++ b/rialto/runner/services/data_checker.py
@@ -0,0 +1,103 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+__all__ = ["DataChecker"]
+
+from datetime import date
+from typing import Dict
+
+from loguru import logger
+from pyspark.sql import DataFrame
+
+from rialto.common import DataReader
+from rialto.runner.services.table import Table
+
+
+class DataChecker:
+    """Checks if data for given date or date range is present in storage"""
+
+    def __init__(self, reader: DataReader):
+        self.reader = reader
+
+    def check_date(self, target: Table, partition_date: date) -> bool:
+        """Check if data for given date is present in target
+
+        :param target: target Table to check
+        :param partition_date: Date to check
+        :return: True if data for given date is present, False otherwise
+        """
+        return self.check_range(target, partition_date, partition_date)
+
+    def check_range(self, target: Table, start_date: date, end_date: date) -> bool:
+        """Check if data for given date range is present in target
+
+        :param target: target Table to check
+        :param start_date: Starting date of the range to check
+        :param end_date: Ending date of the range to check
+        :return: True if data for given date range is present, False otherwise
+        """
+        if not self.reader.table_exists(target.get_table_path()):
+            logger.warning(f"Target table {target.get_table_path()} doesn't exist yet.")
+            return False
+        df = self.reader.get_table(
+            target.get_table_path(),
+            date_column=target.partition,
+            date_from=start_date,
+            date_to=end_date,
+            filters=target.filters,
+        )
+        data_exists = df.count() > 0
+        # Rows found without filters are not accepted as completion when the table has secondary partitions.
+        if data_exists and not target.filters and target.secondary_partitions:
+            logger.warning(
+                f"Overwriting {target.get_table_path()} completion status for {start_date} due to presence of "
+                f"secondary partitions and no filters."
+            )
+            data_exists = False
+        return data_exists
+
+    def _get_filters(self, target: Table, df: DataFrame) -> Dict:
+        """Select the filters used to verify a write
+
+        :param target: target Table; its configured filters are used whenever they are not None
+        :param df: generated DataFrame, source of sub-partition values when no filters are configured
+        :return: filter dict, empty when there are neither filters nor secondary partitions
+        :raises IndexError: if sub-partition values must be inferred but df has no rows
+        """
+        if target.filters is not None:
+            return target.filters
+        if not target.secondary_partitions:
+            return {}
+        logger.info("Inferring target sub-partition values from generated data.")
+        row = df.select(*target.secondary_partitions).distinct().first()  # fetch one row, not all combinations
+        if row is None:
+            raise IndexError(f"No rows generated for {target.get_table_path()}, cannot infer sub-partition values")
+        return {c: row[c] for c in target.secondary_partitions}
+
+    def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int:
+        """Check how many records were written
+
+        :param target: target Table to check
+        :param partition_date: Date to check
+        :param df: DataFrame that was written, used to determine filters if not provided in config
+        :return: Number of records for given date
+        """
+        filters = self._get_filters(target, df)
+        df = self.reader.get_table(
+            target.get_table_path(),
+            date_column=target.partition,
+            date_from=partition_date,
+            date_to=partition_date,
+            filters=filters,
+        )
+
+        return df.count()
diff --git a/rialto/runner/services/date_manager.py b/rialto/runner/services/date_manager.py
new file mode 100644
index 0000000..518129c
--- /dev/null
+++ b/rialto/runner/services/date_manager.py
@@ -0,0 +1,144 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["DateManager"]
+
+from datetime import date, datetime
+from typing import List
+
+from dateutil.relativedelta import relativedelta
+from loguru import logger
+
+from rialto.runner.services.config_loader import RunnerConfig, ScheduleConfig
+
+
+class DateManager:
+    """
+    Date generation and shifts based on configuration
+
+    Keeps the inclusive execution window date_from - date_until and derives
+    execution and partition dates from a schedule configuration.
+    """
+
+    def __init__(self, config: RunnerConfig, run_date: str = None):
+        """
+        Set up the execution window ending on the run date
+
+        :param config: runner configuration, watched_period_units and watched_period_value size the window
+        :param run_date: YYYY-MM-DD string, today is used when not provided
+        :raises ValueError: on malformed run_date, unknown units or a window starting after its end
+        """
+        if run_date:
+            run_date = self.str_to_date(run_date)
+        else:
+            run_date = date.today()
+
+        self.date_from = self.date_subtract(
+            input_date=run_date,
+            units=config.watched_period_units,
+            value=config.watched_period_value,
+        )
+
+        self.date_until = run_date
+
+        # a negative watched period value would move the start past the end
+        if self.date_from > self.date_until:
+            raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}")
+        logger.info(f"Running period set to: {self.date_from} - {self.date_until}")
+
+    def get_date_from(self) -> date:
+        """Get starting date of the execution window"""
+        return self.date_from
+
+    def get_date_until(self) -> date:
+        """Get ending date of the execution window"""
+        return self.date_until
+
+    @staticmethod
+    def str_to_date(str_date: str) -> date:
+        """
+        Convert YYYY-MM-DD string to date
+
+        :param str_date: string date
+        :return: date
+        :raises ValueError: if the string does not match YYYY-MM-DD
+        """
+        try:
+            return datetime.strptime(str_date, "%Y-%m-%d").date()
+        except ValueError as err:
+            # keep the parser error as the cause so the original detail is not lost
+            raise ValueError(f"Invalid date format: {str_date}. Expected YYYY-MM-DD.") from err
+
+    @staticmethod
+    def date_subtract(input_date: date, units: str, value: int) -> date:
+        """
+        Subtract given number of units from input date
+
+        :param input_date: base date
+        :param units: units: years, months, weeks, days
+        :param value: number of units to subtract
+        :return: Starting date
+        :raises ValueError: if units is not one of the supported names
+        """
+        if units not in ("years", "months", "weeks", "days"):
+            raise ValueError(f"Unknown time unit {units}")
+        # the supported unit names are exactly the relativedelta keyword names
+        return input_date - relativedelta(**{units: value})
+
+    @staticmethod
+    def all_dates(date_from: date, date_until: date) -> List[date]:
+        """
+        Get list of all dates between, inclusive
+
+        :param date_from: starting date
+        :param date_until: ending date
+        :return: List[date], empty when date_until precedes date_from
+        """
+        return [date_from + relativedelta(days=n) for n in range((date_until - date_from).days + 1)]
+
+    def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]:
+        """
+        Get list of execution and partition dates for given configuration
+
+        :param schedule: schedule config
+        :return: List of tuples with execution and partition dates
+        """
+        execution_dates = self._execution_dates(schedule)
+        return [(ex_date, self._to_partition_date(ex_date, schedule)) for ex_date in execution_dates]
+
+    @staticmethod
+    def _day_in_range(day, upper: int) -> bool:
+        """Tell whether day is a number within 1..upper, values that cannot be compared count as invalid"""
+        try:
+            return 1 <= day <= upper
+        except TypeError:
+            return False
+
+    def _execution_dates(self, schedule: ScheduleConfig) -> List[date]:
+        """
+        Select dates inside given interval depending on frequency and selected day
+
+        :param schedule: schedule config
+        :return: List of execution dates
+        :raises ValueError: on unknown frequency or a day that is invalid for the frequency
+        """
+        options = self.all_dates(self.date_from, self.date_until)
+        frequency = schedule.frequency.lower()
+        if frequency == "daily":
+            return options
+        if frequency == "weekly":
+            # a non-numeric day such as "last" is reported as invalid instead of failing the comparison
+            if not self._day_in_range(schedule.day, 7):
+                raise ValueError(f"Invalid day for weekly frequency: {schedule.day}. Must be 1-7.")
+            return [x for x in options if x.isoweekday() == schedule.day]
+        if frequency == "monthly":
+            if schedule.day == "last":
+                # the last day of a month is the one whose following day falls into another month
+                return [x for x in options if (x + relativedelta(days=1)).month != x.month]
+            if not self._day_in_range(schedule.day, 31):
+                raise ValueError(f"Invalid day for monthly frequency: {schedule.day}. Must be 1-31 or last.")
+            return [x for x in options if x.day == schedule.day]
+        raise ValueError(f"Unknown frequency: {schedule.frequency}")
+
+    def _to_partition_date(self, date: date, schedule: ScheduleConfig) -> date:
+        """
+        Shift given date according to config
+
+        :param date: input date
+        :param schedule: schedule config, info_date_shift is a single shift or a list of shifts
+        :return: date
+        """
+        shifts = schedule.info_date_shift
+        if not isinstance(shifts, list):
+            shifts = [shifts]
+        # shifts are applied one after another in the configured order
+        for shift in shifts:
+            date = self.date_subtract(date, units=shift.units, value=shift.value)
+        return date
diff --git a/rialto/runner/services/executor.py b/rialto/runner/services/executor.py
new file mode 100644
index 0000000..be84ffe
--- /dev/null
+++ b/rialto/runner/services/executor.py
@@ -0,0 +1,96 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["PipelineExecutor"]
+
+from importlib import import_module
+from typing import Tuple
+
+from loguru import logger
+from pyspark.sql import DataFrame, SparkSession
+
+from rialto.common import DataReader
+from rialto.loader import PysparkFeatureLoader
+from rialto.metadata import MetadataManager
+from rialto.runner.services.config_loader import ModuleConfig, PipelineConfig
+from rialto.runner.services.data_checker import DataChecker
+from rialto.runner.services.task_registry import PipelineTask
+from rialto.runner.transformation import Transformation
+
+
+class PipelineExecutor:
+    """Runs the transformation of a single pipeline task."""
+
+    def __init__(self, spark: SparkSession, reader: DataReader, checker: DataChecker):
+        self.spark = spark
+        self.reader = reader
+        self.checker = checker
+
+    def _init_tools(
+        self, spark: SparkSession, pipeline: PipelineConfig
+    ) -> Tuple[MetadataManager, PysparkFeatureLoader]:
+        """
+        Create the optional helpers handed to the transformation
+
+        :param spark: Spark session
+        :param pipeline: Pipeline configuration
+        :return: MetadataManager and PysparkFeatureLoader, each None when its section is not configured
+        """
+        metadata_manager = None
+        if pipeline.metadata_manager is not None:
+            metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema)
+
+        feature_loader = None
+        if pipeline.feature_loader is not None:
+            feature_loader = PysparkFeatureLoader(
+                spark,
+                feature_schema=pipeline.feature_loader.feature_schema,
+                metadata_schema=pipeline.feature_loader.metadata_schema,
+            )
+        return metadata_manager, feature_loader
+
+    def _load_module(self, cfg: ModuleConfig) -> Transformation:
+        """
+        Import the configured module and instantiate the configured class
+
+        :param cfg: Module configuration with python_module and python_class
+        :return: Transformation object
+        """
+        transformation_class = getattr(import_module(cfg.python_module), cfg.python_class)
+        return transformation_class()
+
+    def execute(self, pipeline: PipelineTask) -> DataFrame:
+        """
+        Execute the pipeline task.
+
+        :param pipeline: Pipeline object to execute.
+        :return: DataFrame resulting from pipeline execution.
+        """
+        logger.info(f"Executing pipeline {pipeline.name} for partition date {pipeline.partition_date}")
+
+        # the job is loaded before its tools are initialized
+        job = self._load_module(pipeline.config.module)
+        metadata_manager, feature_loader = self._init_tools(self.spark, pipeline.config)
+        return job.run(
+            spark=self.spark,
+            run_date=pipeline.execution_date,
+            config=pipeline.config,
+            reader=self.reader,
+            metadata_manager=metadata_manager,
+            feature_loader=feature_loader,
+        )
diff --git a/rialto/runner/services/result_mapper.py b/rialto/runner/services/result_mapper.py
new file mode 100644
index 0000000..a749324
--- /dev/null
+++ b/rialto/runner/services/result_mapper.py
@@ -0,0 +1,113 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["TaskResultMapper"]
+
+from datetime import datetime
+
+from rialto.runner.reporting.record import Record
+from rialto.runner.services.task_registry import PipelineTask
+
+
+class TaskResultMapper:
+    """Builds Record objects with a consistent schema from task execution outcomes"""
+
+    @staticmethod
+    def _record(task: PipelineTask, run_start: datetime, records, status, reason, exception) -> Record:
+        """Assemble a Record for the task, its time is the span from run_start until now"""
+        return Record(
+            job=task.name,
+            target=task.target.get_table_path(),
+            date=task.partition_date,
+            time=datetime.now() - run_start,
+            records=records,
+            status=status,
+            reason=reason,
+            exception=exception,
+        )
+
+    @staticmethod
+    def success(
+        task: PipelineTask,
+        run_start: datetime,
+        records_count: int,
+    ) -> Record:
+        """Record a successful run, the task result becomes the record count as a string"""
+        task.result = str(records_count)
+        return TaskResultMapper._record(task, run_start, records_count, "Success", "OK", None)
+
+    @staticmethod
+    def already_complete(task: PipelineTask, run_start: datetime) -> Record:
+        """Record a task that was skipped because it is already complete"""
+        task.result = "Skipped"
+        return TaskResultMapper._record(task, run_start, 0, "Skipped", "AlreadyComplete", None)
+
+    @staticmethod
+    def dependencies_incomplete(
+        task: PipelineTask,
+        run_start: datetime,
+        failed_deps: list,
+    ) -> Record:
+        """Record a task that failed on its dependencies, they are listed in the exception field"""
+        task.result = "Failed"
+        if failed_deps:
+            details = "Dependencies Incomplete: " + ",\n".join(failed_deps)
+        else:
+            # without any listed dependency the detail is just the word Unknown
+            details = "Unknown"
+        return TaskResultMapper._record(task, run_start, 0, "Failed", "Dependencies Incomplete", details)
+
+    @staticmethod
+    def exception(
+        task: PipelineTask,
+        run_start: datetime,
+        exception_message: str,
+        traceback_str: str,
+    ) -> Record:
+        """Record a task that raised, the message becomes the reason and the traceback the exception"""
+        task.result = "Error"
+        return TaskResultMapper._record(task, run_start, 0, "Error", exception_message, traceback_str)
+
+    @staticmethod
+    def interrupted(task: PipelineTask, run_start: datetime) -> Record:
+        """Record a task stopped by a keyboard interrupt, reported with status Error"""
+        task.result = "Interrupt"
+        return TaskResultMapper._record(task, run_start, 0, "Error", "Keyboard Interrupt", None)
diff --git a/rialto/runner/table.py b/rialto/runner/services/table.py
similarity index 61%
rename from rialto/runner/table.py
rename to rialto/runner/services/table.py
index 2d44498..dedd072 100644
--- a/rialto/runner/table.py
+++ b/rialto/runner/services/table.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,48 @@
__all__ = ["Table"]
-from typing import List
+from typing import Dict, List
from rialto.metadata import class_to_catalog_name
+from rialto.runner.services.config_loader import DependencyConfig, PipelineConfig
class Table:
"""Handler for databricks catalog paths"""
+    @classmethod
+    def from_target_config(cls, config: PipelineConfig) -> "Table":
+        """
+        Build the target table of a pipeline
+
+        :param config: Pipeline configuration, its target and module sections are read
+
+        :return: Table object
+        """
+        target_cfg = config.target
+        return cls(
+            schema_path=target_cfg.target_schema,
+            class_name=config.module.python_class,
+            partition=target_cfg.target_partition_column,
+            secondary_partitions=target_cfg.secondary_partition_columns,
+            table=target_cfg.custom_name,
+            filters=target_cfg.rerun_filters,
+        )
+
+    @classmethod
+    def from_dependency_config(cls, config: DependencyConfig) -> "Table":
+        """
+        Build the table of a single pipeline dependency
+
+        :param config: Dependency configuration providing table, date_col and filters
+
+        :return: Table object
+        """
+        table_kwargs = {
+            "table_path": config.table,
+            "partition": config.date_col,
+            "filters": config.filters,
+        }
+        return cls(**table_kwargs)
+
def __init__(
self,
catalog: str = None,
@@ -32,12 +66,14 @@ def __init__(
class_name: str = None,
partition: str = None,
secondary_partitions: List[str] = None,
+ filters: Dict = None,
):
self.catalog = catalog
self.schema = schema
self.table = table
self.partition = partition
self.secondary_partitions = secondary_partitions
+ self.filters = filters
if schema_path:
schema_path = schema_path.split(".")
self.catalog = schema_path[0]
@@ -58,7 +94,7 @@ def get_table_path(self) -> str:
"""Get full table path"""
return f"{self.catalog}.{self.schema}.{self.table}"
- def get_all_partitions(self) -> List[str]:
+ def get_all_partition_columns(self) -> List[str]:
"""Get list of all partitions"""
if self.secondary_partitions:
return [self.partition] + self.secondary_partitions
diff --git a/rialto/runner/services/task_registry.py b/rialto/runner/services/task_registry.py
new file mode 100644
index 0000000..2bb7d13
--- /dev/null
+++ b/rialto/runner/services/task_registry.py
@@ -0,0 +1,109 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+__all__ = ["TaskRegistry", "PipelineTask", "PipelineDependency"]
+
+from dataclasses import dataclass, field
+from datetime import date
+from typing import Iterator, List
+
+from loguru import logger
+from pyspark.sql import SparkSession
+
+from rialto.runner.services.config_loader import PipelineConfig
+from rialto.runner.services.date_manager import DateManager
+from rialto.runner.services.table import Table
+
+
+@dataclass
+class PipelineDependency:
+    """Class representing a pipeline dependency, with associated table and date range for checking completion"""
+
+    # table the pipeline depends on
+    table: Table
+    # start of the date range checked for data
+    date_from: date
+    # end of the date range checked for data
+    date_until: date
+    # outcome of the completion check, False until a check stores its result
+    complete: bool = False
+
+
+@dataclass
+class PipelineTask:
+    """Class representing a pipeline to be executed."""
+
+    # name of the pipeline
+    name: str
+    # date when the pipeline is scheduled to run
+    execution_date: date
+    # date for which the pipeline is processing data
+    partition_date: date
+    # configuration of the pipeline
+    config: PipelineConfig
+    # table the pipeline writes to
+    target: Table
+    # dependencies of the pipeline, each instance gets its own list
+    dependencies: List[PipelineDependency] = field(default_factory=list)
+    # whether target data exists for partition_date, False until checked
+    completion: bool = False
+    # whether all dependencies are complete, False until checked
+    dependencies_complete: bool = False
+    # NOTE(review): not written anywhere in this module - presumably set by the runner, confirm
+    precheck_failed: bool = False
+    # NOTE(review): presumably message and traceback of a failed run - confirm against the runner
+    error: str | None = None
+    error_trace: str | None = None
+    # short outcome text printed in the Result column of the status table
+    result: str = ""
+
+
+class TaskRegistry:
+    """Collection of the pipeline tasks planned for execution"""
+
+    def __init__(self, spark: SparkSession, date_manager: DateManager):
+        self.spark = spark
+        self.date_manager = date_manager
+        self.tasks = []
+
+    def add_task(self, name: str, execution_date: date, partition_date: date, config: PipelineConfig) -> None:
+        """
+        Register a task together with its dependencies
+
+        :param name: Name of the pipeline
+        :param execution_date: Date when the pipeline is scheduled to run
+        :param partition_date: Date for which the pipeline is processing data
+        :param config: PipelineConfig object with pipeline configuration
+
+        :return: None, a PipelineTask is appended to self.tasks
+        """
+        task = PipelineTask(
+            name=name,
+            execution_date=execution_date,
+            partition_date=partition_date,
+            config=config,
+            target=Table.from_target_config(config),
+        )
+
+        for dep_config in config.dependencies:
+            interval = dep_config.interval
+            # every dependency range ends on the execution date and reaches one interval back
+            task.dependencies.append(
+                PipelineDependency(
+                    table=Table.from_dependency_config(dep_config),
+                    date_from=self.date_manager.date_subtract(execution_date, interval.units, interval.value),
+                    date_until=execution_date,
+                )
+            )
+
+        self.tasks.append(task)
+
+    def __iter__(self) -> Iterator[PipelineTask]:
+        """Iterate over the registered tasks in insertion order"""
+        return iter(self.tasks)
+
+    def log_status(self) -> None:
+        """Log one table with completion, dependency status and result of every registered task"""
+        check = "\u2714"  # ✔
+        cross = "\u2718"  # ✘
+        lines = [
+            f"\n{'Job Name':<50} {'Partition Date':<15} {'Complete':<8} {'Dependencies':<12} {'Result':<9}\n",
+            "-" * 70 + "\n",
+        ]
+        for task in self.tasks:
+            complete_icon = check if task.completion else cross
+            deps_icon = check if task.dependencies_complete else cross
+            lines.append(
+                f"{task.name:<50} {str(task.partition_date):<15} {complete_icon:^8} "
+                f"{deps_icon:^12} {task.result:<9}\n"
+            )
+        logger.info("".join(lines))
diff --git a/rialto/runner/services/task_status_checker.py b/rialto/runner/services/task_status_checker.py
new file mode 100644
index 0000000..10f5c31
--- /dev/null
+++ b/rialto/runner/services/task_status_checker.py
@@ -0,0 +1,58 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+__all__ = ["TaskStatusChecker"]
+
+from loguru import logger
+
+from rialto.runner.services.data_checker import DataChecker
+from rialto.runner.services.task_registry import PipelineTask
+
+
+class TaskStatusChecker:
+    """Runs completion and dependency checks and stores the outcome on the checked task."""
+
+    def __init__(self, checker: DataChecker):
+        self.checker = checker
+
+    def check_completion(self, pipeline: PipelineTask) -> None:
+        """
+        Check whether target data exists for the partition date of the pipeline
+
+        :param pipeline: Pipeline object for which to check completion
+
+        :return: None, the outcome is stored in pipeline.completion
+        """
+        is_complete = self.checker.check_date(pipeline.target, pipeline.partition_date)
+        pipeline.completion = is_complete
+        logger.info(
+            f"Job {pipeline.name} completion status for partition date "
+            f"{pipeline.partition_date}: {pipeline.completion}"
+        )
+
+    def check_pipeline_dependencies(self, pipeline: PipelineTask) -> None:
+        """
+        Check whether data exists for every dependency within its date range
+
+        :param pipeline: Pipeline object for which to check dependencies
+
+        :return: None, the outcome is stored on each dependency and in pipeline.dependencies_complete
+        """
+        # every dependency is checked and logged, even when an earlier one is incomplete
+        for dep in pipeline.dependencies:
+            dep.complete = self.checker.check_range(dep.table, dep.date_from, dep.date_until)
+            logger.info(
+                f"Dependency {dep.table.get_table_path()} completion status for date range "
+                f"{dep.date_from} - {dep.date_until}: {dep.complete}"
+            )
+        pipeline.dependencies_complete = all(dep.complete for dep in pipeline.dependencies)
diff --git a/rialto/runner/writer.py b/rialto/runner/services/writer.py
similarity index 51%
rename from rialto/runner/writer.py
rename to rialto/runner/services/writer.py
index bc147fd..b4c8c91 100644
--- a/rialto/runner/writer.py
+++ b/rialto/runner/services/writer.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,21 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ["Writer"]
+__all__ = ["DatabricksWriter", "Writer"]
+from abc import ABC, abstractmethod
from datetime import date
-from typing import List
+from typing import Any, List
import pyspark.sql.functions as F
from loguru import logger
from pyspark.sql import DataFrame, SparkSession
-from rialto.runner.table import Table
+from rialto.runner.services.table import Table
-class Writer:
+class Writer(ABC):
"""Supporting class for runner"""
+    @abstractmethod
+    def write(self, df: DataFrame, info_date: date, table: Table) -> None:
+        """
+        Write dataframe to storage
+
+        Abstract method without behaviour of its own.
+
+        :param df: dataframe to write
+        :param info_date: date to partition
+        :param table: Table describing where to write
+        :return: None
+        """
+        # NOTE(review): DatabricksWriter.write names its date parameter partition_date - confirm keyword callers
+        pass
+
+
+class DatabricksWriter(Writer):
+ """Supporting class for runner, Databricks write operations"""
+
def __init__(self, spark: SparkSession, merge_schema=False):
self.spark = spark
self.merge_schema = merge_schema
@@ -61,50 +78,71 @@ def _align_schema(self, df: DataFrame, existing_columns: List) -> DataFrame:
:return: dataframe with aligned schema
"""
if existing_columns is not None:
+ missing = [c for c in existing_columns if c not in df.columns]
+ if missing:
+ raise ValueError(f"DataFrame is missing columns present in existing table: {missing}")
return df.select(
- *[F.col(c) for c in existing_columns if c in df.columns],
+ *[F.col(c) for c in existing_columns],
*[F.col(c) for c in df.columns if c not in existing_columns],
)
return df
- def _process(self, df: DataFrame, info_date: date, table: Table) -> DataFrame:
- df = df.withColumn(table.partition, F.lit(info_date))
+ def _process(self, df: DataFrame, partition_date: date, table: Table) -> DataFrame:
+ df = df.withColumn(table.partition, F.lit(partition_date))
df = self._align_schema(df, self._get_existing_columns(table))
return df
- def _get_replace_condition(self, df: DataFrame, partition_cols: List[str]) -> str:
- row = df.select(*partition_cols).distinct().collect()
- if len(row) > 1:
- raise ValueError(f"Some of the partitions to write have more than 1 distinct value \n {row}")
-
- parts = []
- for c in partition_cols:
- val = row[0][c]
- if val is None:
- parts.append(f"{c} IS NULL")
- elif isinstance(val, (int, float)):
- parts.append(f"{c} = {val}")
- else:
- parts.append(f"{c} = '{val}'")
- condition = " AND ".join(parts)
- return condition
-
- def write(self, df: DataFrame, info_date: date, table: Table) -> None:
+    def _get_replace_expression(self, key: str, value: Any) -> str:
+        """
+        Render the predicate of one column of the replace condition
+
+        :param key: column name
+        :param value: value to match, None renders IS NULL and int/float values are left unquoted
+        :return: predicate string
+        """
+        if value is None:
+            return f"{key} IS NULL"
+        # anything that is not an int or float is wrapped in single quotes
+        rendered = f"{value}" if isinstance(value, (int, float)) else f"'{value}'"
+        return f"{key} = {rendered}"
+
+    def _get_replace_condition(self, df: DataFrame, target: Table, partition_date: date) -> str:
+        """
+        Build the predicate selecting the partition replaced by the write
+
+        :param df: dataframe being written, read only when the filters do not cover the secondary partitions
+        :param target: target Table providing partition columns and optional filters, it is not modified
+        :param partition_date: value of the date partition column
+        :return: predicates of all partition columns joined by AND
+        :raises ValueError: if the dataframe holds none or more than one combination of partition values
+        """
+        partition_cols = target.get_all_partition_columns()
+        date_value = partition_date.strftime("%Y-%m-%d")
+
+        # only date column
+        if len(partition_cols) == 1:
+            return f"{partition_cols[0]} = '{date_value}'"
+
+        filters = target.filters
+        # if target filters present for exactly the secondary partitions
+        if filters and set(filters) == set(partition_cols) - {target.partition}:
+            # merge into a new dict, target.filters must stay untouched for later reads of the target
+            values = {**filters, target.partition: date_value}
+        # grab from dataframe
+        else:
+            rows = df.select(*partition_cols).distinct().collect()
+            if len(rows) > 1:
+                raise ValueError(f"Some of the partitions to write have more than 1 distinct value \n {rows}")
+            if not rows:
+                raise ValueError(f"No partition values found in the data to write for columns {partition_cols}")
+            values = rows[0]
+
+        return " AND ".join(self._get_replace_expression(c, values[c]) for c in partition_cols)
+
+ def write(self, df: DataFrame, partition_date: date, table: Table) -> None:
"""
Write dataframe to storage
:param df: dataframe to write
- :param info_date: date to partition
+ :param partition_date: date to partition
:param table: path to write to
:return: None
"""
self._create_schema(table)
- df = self._process(df, info_date, table)
+ df = self._process(df, partition_date, table)
- replace_where = self._get_replace_condition(df, table.get_all_partitions())
+ replace_where = self._get_replace_condition(df, table, partition_date)
df.write.format("delta").partitionBy(table.partition).mode("overwrite").option(
"mergeSchema", "true" if self.merge_schema else "false"
diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py
index 5b6f2eb..be5b903 100644
--- a/rialto/runner/transformation.py
+++ b/rialto/runner/transformation.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
from rialto.common import DataReader
from rialto.loader import PysparkFeatureLoader
from rialto.metadata import MetadataManager
-from rialto.runner.config_loader import PipelineConfig
+from rialto.runner.services.config_loader import PipelineConfig
class Transformation(metaclass=abc.ABCMeta):
diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py
index 5af1723..7e42542 100644
--- a/rialto/runner/utils.py
+++ b/rialto/runner/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,82 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ["load_module", "table_exists", "get_partitions", "init_tools", "find_dependency"]
+__all__ = ["find_dependency"]
-from datetime import date
-from importlib import import_module
-from typing import List, Tuple
-
-from pyspark.sql import SparkSession
-
-from rialto.common import DataReader
-from rialto.loader import PysparkFeatureLoader
-from rialto.metadata import MetadataManager
-from rialto.runner.config_loader import ModuleConfig, PipelineConfig
-from rialto.runner.table import Table
-from rialto.runner.transformation import Transformation
-
-
-def load_module(cfg: ModuleConfig) -> Transformation:
- """
- Load feature group
-
- :param cfg: Feature configuration
- :return: Transformation object
- """
- module = import_module(cfg.python_module)
- class_obj = getattr(module, cfg.python_class)
- return class_obj()
-
-
-def table_exists(spark: SparkSession, table: str) -> bool:
- """
- Check table exists in spark catalog
-
- :param table: full table path
- :return: bool
- """
- return spark.catalog.tableExists(table)
-
-
-def get_partitions(reader: DataReader, table: Table) -> List[date]:
- """
- Get partition values
-
- :param table: Table object
- :return: List of partition values
- """
- rows = (
- reader.get_table(table.get_table_path(), date_column=table.partition)
- .select(table.partition)
- .distinct()
- .collect()
- )
- return [r[table.partition] for r in rows]
-
-
-def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]:
- """
- Initialize metadata manager and feature loader
-
- :param spark: Spark session
- :param pipeline: Pipeline configuration
- :return: MetadataManager and PysparkFeatureLoader
- """
- if pipeline.metadata_manager is not None:
- metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema)
- else:
- metadata_manager = None
-
- if pipeline.feature_loader is not None:
- feature_loader = PysparkFeatureLoader(
- spark,
- feature_schema=pipeline.feature_loader.feature_schema,
- metadata_schema=pipeline.feature_loader.metadata_schema,
- )
- else:
- feature_loader = None
- return metadata_manager, feature_loader
+from rialto.runner.services.config_loader import PipelineConfig
def find_dependency(config: PipelineConfig, name: str):
diff --git a/tests/__init__.py b/tests/__init__.py
index 79c3773..94ab807 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/common/conftest.py b/tests/common/conftest.py
index 79455ff..198d297 100644
--- a/tests/common/conftest.py
+++ b/tests/common/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/common/test_reader.py b/tests/common/test_reader.py
index 452ad13..8127154 100644
--- a/tests/common/test_reader.py
+++ b/tests/common/test_reader.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py
index f11344d..fb2a7ba 100644
--- a/tests/common/test_utils.py
+++ b/tests/common/test_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/__init__.py b/tests/jobs/__init__.py
index 79c3773..94ab807 100644
--- a/tests/jobs/__init__.py
+++ b/tests/jobs/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py
index dda863d..c00517d 100644
--- a/tests/jobs/conftest.py
+++ b/tests/jobs/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/resources.py b/tests/jobs/resources.py
index 3d785c3..ad36d7d 100644
--- a/tests/jobs/resources.py
+++ b/tests/jobs/resources.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py
index b0fb898..c93b90e 100644
--- a/tests/jobs/test_decorators.py
+++ b/tests/jobs/test_decorators.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py
index 7069b12..86314e3 100644
--- a/tests/jobs/test_job/test_job.py
+++ b/tests/jobs/test_job/test_job.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py
index 7d9b409..55ea96e 100644
--- a/tests/jobs/test_job_base.py
+++ b/tests/jobs/test_job_base.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/test_resolver.py b/tests/jobs/test_resolver.py
index 443e27b..eca150d 100644
--- a/tests/jobs/test_resolver.py
+++ b/tests/jobs/test_resolver.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/jobs/test_test_utils.py b/tests/jobs/test_test_utils.py
index 3143210..a3df677 100644
--- a/tests/jobs/test_test_utils.py
+++ b/tests/jobs/test_test_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/__init__.py b/tests/loader/__init__.py
index 79c3773..94ab807 100644
--- a/tests/loader/__init__.py
+++ b/tests/loader/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/metadata_config/full_example.yaml b/tests/loader/metadata_config/full_example.yaml
index 9ad780c..a8e7d1c 100644
--- a/tests/loader/metadata_config/full_example.yaml
+++ b/tests/loader/metadata_config/full_example.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/metadata_config/missing_field_example.yaml b/tests/loader/metadata_config/missing_field_example.yaml
index 1caf3b0..7956c08 100644
--- a/tests/loader/metadata_config/missing_field_example.yaml
+++ b/tests/loader/metadata_config/missing_field_example.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/metadata_config/missing_value_example.yaml b/tests/loader/metadata_config/missing_value_example.yaml
index 844e25f..e7da0c2 100644
--- a/tests/loader/metadata_config/missing_value_example.yaml
+++ b/tests/loader/metadata_config/missing_value_example.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/metadata_config/no_map_example.yaml b/tests/loader/metadata_config/no_map_example.yaml
index d3679fa..6234e76 100644
--- a/tests/loader/metadata_config/no_map_example.yaml
+++ b/tests/loader/metadata_config/no_map_example.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/metadata_config/test_main_config.py b/tests/loader/metadata_config/test_main_config.py
index b09f155..c8033eb 100644
--- a/tests/loader/metadata_config/test_main_config.py
+++ b/tests/loader/metadata_config/test_main_config.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/pyspark/dataframe_builder.py b/tests/loader/pyspark/dataframe_builder.py
index 94a755e..0fa4967 100644
--- a/tests/loader/pyspark/dataframe_builder.py
+++ b/tests/loader/pyspark/dataframe_builder.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/pyspark/example_cfg.yaml b/tests/loader/pyspark/example_cfg.yaml
index 6b19277..b5fa6c3 100644
--- a/tests/loader/pyspark/example_cfg.yaml
+++ b/tests/loader/pyspark/example_cfg.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/pyspark/resources.py b/tests/loader/pyspark/resources.py
index 64a8363..1f69226 100644
--- a/tests/loader/pyspark/resources.py
+++ b/tests/loader/pyspark/resources.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py
index dd2049f..5fc7338 100644
--- a/tests/loader/pyspark/test_from_cfg.py
+++ b/tests/loader/pyspark/test_from_cfg.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/__init__.py b/tests/maker/__init__.py
index 79c3773..94ab807 100644
--- a/tests/maker/__init__.py
+++ b/tests/maker/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/conftest.py b/tests/maker/conftest.py
index 79455ff..198d297 100644
--- a/tests/maker/conftest.py
+++ b/tests/maker/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_FeatureFunction.py b/tests/maker/test_FeatureFunction.py
index 43590ae..d478d76 100644
--- a/tests/maker/test_FeatureFunction.py
+++ b/tests/maker/test_FeatureFunction.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_FeatureHolder.py b/tests/maker/test_FeatureHolder.py
index 5c00cdb..c02f389 100644
--- a/tests/maker/test_FeatureHolder.py
+++ b/tests/maker/test_FeatureHolder.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_FeatureMaker.py b/tests/maker/test_FeatureMaker.py
index f3da26d..ed99155 100644
--- a/tests/maker/test_FeatureMaker.py
+++ b/tests/maker/test_FeatureMaker.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/__init__.py b/tests/maker/test_features/__init__.py
index 79c3773..94ab807 100644
--- a/tests/maker/test_features/__init__.py
+++ b/tests/maker/test_features/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/aggregated_num_sum_outbound.py b/tests/maker/test_features/aggregated_num_sum_outbound.py
index ce3937b..8d3bb4c 100644
--- a/tests/maker/test_features/aggregated_num_sum_outbound.py
+++ b/tests/maker/test_features/aggregated_num_sum_outbound.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/aggregated_num_sum_txn.py b/tests/maker/test_features/aggregated_num_sum_txn.py
index 6c807af..a005ef8 100644
--- a/tests/maker/test_features/aggregated_num_sum_txn.py
+++ b/tests/maker/test_features/aggregated_num_sum_txn.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/dependent_features_fail.py b/tests/maker/test_features/dependent_features_fail.py
index d9a8c7f..a702028 100644
--- a/tests/maker/test_features/dependent_features_fail.py
+++ b/tests/maker/test_features/dependent_features_fail.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/dependent_features_fail2.py b/tests/maker/test_features/dependent_features_fail2.py
index 4964c8a..96db490 100644
--- a/tests/maker/test_features/dependent_features_fail2.py
+++ b/tests/maker/test_features/dependent_features_fail2.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/dependent_features_ok.py b/tests/maker/test_features/dependent_features_ok.py
index 232f08b..f96852a 100644
--- a/tests/maker/test_features/dependent_features_ok.py
+++ b/tests/maker/test_features/dependent_features_ok.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/sequential_avg_outbound.py b/tests/maker/test_features/sequential_avg_outbound.py
index cedad5f..aefe46e 100644
--- a/tests/maker/test_features/sequential_avg_outbound.py
+++ b/tests/maker/test_features/sequential_avg_outbound.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/sequential_avg_txn.py b/tests/maker/test_features/sequential_avg_txn.py
index 65d1f7f..4e3bad3 100644
--- a/tests/maker/test_features/sequential_avg_txn.py
+++ b/tests/maker/test_features/sequential_avg_txn.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/sequential_for_testing.py b/tests/maker/test_features/sequential_for_testing.py
index 5a8de84..b9e0dd5 100644
--- a/tests/maker/test_features/sequential_for_testing.py
+++ b/tests/maker/test_features/sequential_for_testing.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/sequential_outbound.py b/tests/maker/test_features/sequential_outbound.py
index 6b0764e..ffe5961 100644
--- a/tests/maker/test_features/sequential_outbound.py
+++ b/tests/maker/test_features/sequential_outbound.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_features/sequential_outbound_with_param.py b/tests/maker/test_features/sequential_outbound_with_param.py
index eb50d80..3a003c4 100644
--- a/tests/maker/test_features/sequential_outbound_with_param.py
+++ b/tests/maker/test_features/sequential_outbound_with_param.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/maker/test_wrappers.py b/tests/maker/test_wrappers.py
index 135b4ad..fc80051 100644
--- a/tests/maker/test_wrappers.py
+++ b/tests/maker/test_wrappers.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/metadata/__init__.py b/tests/metadata/__init__.py
index 79c3773..94ab807 100644
--- a/tests/metadata/__init__.py
+++ b/tests/metadata/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/metadata/conftest.py b/tests/metadata/conftest.py
index d500b57..01cfcde 100644
--- a/tests/metadata/conftest.py
+++ b/tests/metadata/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/metadata/resources.py b/tests/metadata/resources.py
index fb9df8c..5ab4e85 100644
--- a/tests/metadata/resources.py
+++ b/tests/metadata/resources.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/metadata/test_metadata_connector.py b/tests/metadata/test_metadata_connector.py
index 6594e6c..01e5ade 100644
--- a/tests/metadata/test_metadata_connector.py
+++ b/tests/metadata/test_metadata_connector.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/__init__.py b/tests/runner/__init__.py
index 79c3773..94ab807 100644
--- a/tests/runner/__init__.py
+++ b/tests/runner/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py
index 4e527be..095aa0f 100644
--- a/tests/runner/conftest.py
+++ b/tests/runner/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,6 +37,6 @@ def spark(request):
return spark
-@pytest.fixture(scope="function")
+@pytest.fixture(scope="session")
def basic_runner(spark):
- return Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31")
+ return Runner(spark, config_path="tests/runner/resources/config.yaml", run_date="2023-03-31")
diff --git a/tests/runner/reporting/__init__.py b/tests/runner/reporting/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/runner/reporting/test_record.py b/tests/runner/reporting/test_record.py
new file mode 100644
index 0000000..37489f5
--- /dev/null
+++ b/tests/runner/reporting/test_record.py
@@ -0,0 +1,54 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import datetime, timedelta
+
+import pytest
+
+from rialto.runner.reporting.record import Record
+from rialto.runner.services.config_loader import RunnerConfig
+from rialto.runner.services.date_manager import DateManager
+
+
+@pytest.fixture(scope="module")
+def basic_date_manager() -> DateManager:
+ cfg = RunnerConfig(watched_period_units="months", watched_period_value=4)
+ return DateManager(cfg)
+
+
+@pytest.fixture(scope="function")
+def record(basic_date_manager):
+ return Record(
+ "job",
+ "target",
+ basic_date_manager.str_to_date("2024-01-01"),
+ timedelta(days=0, hours=1, minutes=2, seconds=3),
+ 1,
+ "status",
+ "reason",
+ None,
+ datetime(2024, 1, 1, 1, 2, 3),
+ )
+
+
+def test_record_to_spark(spark, basic_date_manager, record):
+ row = record.to_spark_row()
+ assert row.job == "job"
+ assert row.target == "target"
+ assert row.date == basic_date_manager.str_to_date("2024-01-01")
+ assert row.time == "1:02:03"
+ assert row.records == 1
+ assert row.status == "status"
+ assert row.reason == "reason"
+ assert row.exception is None
+ assert row.run_timestamp == datetime(2024, 1, 1, 1, 2, 3)
diff --git a/tests/runner/transformations/__init__.py b/tests/runner/resources/__init__.py
similarity index 83%
rename from tests/runner/transformations/__init__.py
rename to tests/runner/resources/__init__.py
index eaa15cd..08a0182 100644
--- a/tests/runner/transformations/__init__.py
+++ b/tests/runner/resources/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.runner.transformations.simple_group import SimpleGroup # noqa
+from tests.runner.resources.simple_group import SimpleGroup # noqa
diff --git a/tests/runner/transformations/config.yaml b/tests/runner/resources/config.yaml
similarity index 98%
rename from tests/runner/transformations/config.yaml
rename to tests/runner/resources/config.yaml
index 11f9fc9..659a893 100644
--- a/tests/runner/transformations/config.yaml
+++ b/tests/runner/resources/config.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/transformations/config2.yaml b/tests/runner/resources/config2.yaml
similarity index 96%
rename from tests/runner/transformations/config2.yaml
rename to tests/runner/resources/config2.yaml
index f7b9604..31067c1 100644
--- a/tests/runner/transformations/config2.yaml
+++ b/tests/runner/resources/config2.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/transformations/config3.yaml b/tests/runner/resources/config3.yaml
similarity index 96%
rename from tests/runner/transformations/config3.yaml
rename to tests/runner/resources/config3.yaml
index 72af1da..70e34d3 100644
--- a/tests/runner/transformations/config3.yaml
+++ b/tests/runner/resources/config3.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/overrider.yaml b/tests/runner/resources/overrider.yaml
similarity index 98%
rename from tests/runner/overrider.yaml
rename to tests/runner/resources/overrider.yaml
index 75257a7..add1fb4 100644
--- a/tests/runner/overrider.yaml
+++ b/tests/runner/resources/overrider.yaml
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/transformations/simple_group.py b/tests/runner/resources/simple_group.py
similarity index 96%
rename from tests/runner/transformations/simple_group.py
rename to tests/runner/resources/simple_group.py
index ec2311c..044c3bf 100644
--- a/tests/runner/transformations/simple_group.py
+++ b/tests/runner/resources/simple_group.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/runner/runner_resources.py b/tests/runner/runner_resources.py
index 17b5447..a657c0f 100644
--- a/tests/runner/runner_resources.py
+++ b/tests/runner/runner_resources.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.
from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType
-from rialto.runner.date_manager import DateManager
+from rialto.runner.services.date_manager import DateManager
simple_group_data = [
("A", DateManager.str_to_date("2023-03-05")),
diff --git a/tests/runner/services/__init__.py b/tests/runner/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/runner/test_overrides.py b/tests/runner/services/test_config_overrides.py
similarity index 66%
rename from tests/runner/test_overrides.py
rename to tests/runner/services/test_config_overrides.py
index 5c738fb..6be3603 100644
--- a/tests/runner/test_overrides.py
+++ b/tests/runner/services/test_config_overrides.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,51 +16,53 @@
from rialto.runner import Runner
+CONFIG_PATH = "tests/runner/resources/overrider.yaml"
+
def test_overrides_simple(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]},
)
- assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"]
+ assert runner._services.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"]
def test_overrides_array_index(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.mail.to[1]": "a@b.c"},
)
- assert runner.config.runner.mail.to == ["developer@testing.org", "a@b.c"]
+ assert runner._services.config.runner.mail.to == ["developer@testing.org", "a@b.c"]
def test_overrides_array_append(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.mail.to[-1]": "test"},
)
- assert runner.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"]
+ assert runner._services.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"]
def test_overrides_array_lookup(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"},
)
- assert runner.config.pipelines[0].target.target_schema == "new_schema"
+ assert runner._services.config.pipelines[0].target.target_schema == "new_schema"
def test_overrides_combined(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={
"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"],
@@ -68,16 +70,16 @@ def test_overrides_combined(spark):
"pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1,
},
)
- assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"]
- assert runner.config.pipelines[0].target.target_schema == "new_schema"
- assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1
+ assert runner._services.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"]
+ assert runner._services.config.pipelines[0].target.target_schema == "new_schema"
+ assert runner._services.config.pipelines[0].schedule.info_date_shift[0].value == 1
def test_index_out_of_range(spark):
with pytest.raises(IndexError) as error:
Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.mail.to[8]": "test"},
)
@@ -88,7 +90,7 @@ def test_invalid_index_key(spark):
with pytest.raises(ValueError) as error:
Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.mail.test[8]": "test"},
)
@@ -99,7 +101,7 @@ def test_invalid_key(spark):
with pytest.raises(ValueError) as error:
Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.mail.test.param": "test"},
)
@@ -110,7 +112,7 @@ def test_new_key(spark):
with pytest.raises(ValidationError):
Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={"runner.some_value": 5},
)
@@ -119,7 +121,7 @@ def test_new_key(spark):
def test_replace_section(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={
"pipelines[name=SimpleGroup].feature_loader": {
@@ -128,13 +130,13 @@ def test_replace_section(spark):
}
},
)
- assert runner.config.pipelines[0].feature_loader.feature_schema == "catalog.features"
+ assert runner._services.config.pipelines[0].feature_loader.feature_schema == "catalog.features"
def test_add_section(spark):
runner = Runner(
spark,
- config_path="tests/runner/overrider.yaml",
+ config_path=CONFIG_PATH,
run_date="2023-03-31",
overrides={
"pipelines[name=OtherGroup].feature_loader": {
@@ -143,4 +145,15 @@ def test_add_section(spark):
}
},
)
- assert runner.config.pipelines[1].feature_loader.feature_schema == "catalog.features"
+ assert runner._services.config.pipelines[1].feature_loader.feature_schema == "catalog.features"
+
+
+def test_invalid_append_index_for_nested_path(spark):
+ with pytest.raises(ValueError) as error:
+ Runner(
+ spark,
+ config_path=CONFIG_PATH,
+ run_date="2023-03-31",
+ overrides={"runner.mail.to[-1].domain": "example.com"},
+ )
+ assert error.value.args[0] == "Invalid index -1 for key to in path ['to[-1]', 'domain']"
diff --git a/tests/runner/services/test_data_checker.py b/tests/runner/services/test_data_checker.py
new file mode 100644
index 0000000..a2bf90c
--- /dev/null
+++ b/tests/runner/services/test_data_checker.py
@@ -0,0 +1,230 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date
+from unittest.mock import MagicMock
+
+import pytest
+from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType
+
+from rialto.common import TableReader
+from rialto.runner.services.data_checker import DataChecker
+from rialto.runner.services.table import Table
+
+
+@pytest.fixture(scope="module")
+def simple_dataframe(spark):
+    """Three-row Spark DataFrame with nullable KEY (string) and DATE (date) columns."""
+    rows = [("A", date(2023, 3, 5)), ("B", date(2023, 3, 12)), ("C", date(2023, 3, 19))]
+    fields = [
+        StructField("KEY", StringType(), True),
+        StructField("DATE", DateType(), True),
+    ]
+    return spark.createDataFrame(rows, schema=StructType(fields))
+
+
+@pytest.fixture(scope="module")
+def partitioned_dataframe(spark):
+    """Five-row Spark DataFrame with nullable VALUE, VERSION, TYPE and DATE columns."""
+    columns = [
+        StructField("VALUE", StringType(), True),
+        StructField("VERSION", IntegerType(), True),
+        StructField("TYPE", StringType(), True),
+        StructField("DATE", DateType(), True),
+    ]
+    rows = [
+        ("W", 1, "A", date(2023, 3, 5)),
+        ("E", 1, "B", date(2023, 3, 5)),
+        ("R", 2, "B", date(2023, 3, 5)),
+        ("T", 1, "B", date(2023, 3, 12)),
+        ("Y", 2, "A", date(2023, 3, 19)),
+    ]
+    frame_schema = StructType(columns)
+    return spark.createDataFrame(rows, schema=frame_schema)
+
+
+@pytest.fixture(scope="module")
+def new_insert_partitioned_dataframe(spark):
+    """Two-row Spark DataFrame whose rows share VERSION=1, TYPE="B" and DATE=2023-03-05."""
+    columns = [
+        StructField("VALUE", StringType(), True),
+        StructField("VERSION", IntegerType(), True),
+        StructField("TYPE", StringType(), True),
+        StructField("DATE", DateType(), True),
+    ]
+    rows = [
+        ("E", 1, "B", date(2023, 3, 5)),
+        ("T", 1, "B", date(2023, 3, 5)),
+    ]
+    frame_schema = StructType(columns)
+    return spark.createDataFrame(rows, schema=frame_schema)
+
+
+@pytest.mark.parametrize(
+    "partition_date, expected",
+    [
+        (date(2023, 3, 12), True),
+        (date(2023, 3, 10), False),
+        (date(2023, 3, 19), True),
+        (date(2023, 3, 26), False),
+    ],
+)
+def test_check_date(mocker, spark, simple_dataframe, partition_date, expected):
+    """check_date is True only for dates present in the DATE column of the mocked table."""
+    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True)
+    mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=simple_dataframe)
+    target = Table(table_path="catalog.schema.simple_group", partition="DATE")
+
+    checker = DataChecker(TableReader(spark))
+    assert checker.check_date(target, partition_date) == expected
+
+
+@pytest.mark.parametrize(
+    "start_date, end_date, expected",
+    [
+        (date(2023, 3, 12), date(2023, 4, 12), True),
+        (date(2023, 3, 10), date(2023, 3, 11), False),
+        (date(2023, 3, 19), date(2023, 3, 19), True),
+        (date(2023, 3, 26), date(2023, 3, 29), False),
+    ],
+)
+def test_check_range(mocker, spark, simple_dataframe, start_date, end_date, expected):
+    """check_range is True only when the mocked table has a DATE inside the inclusive range."""
+    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True)
+    mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=simple_dataframe)
+    target = Table(table_path="catalog.schema.simple_group", partition="DATE")
+
+    checker = DataChecker(TableReader(spark))
+    assert checker.check_range(target, start_date, end_date) == expected
+
+
+def test_check_range_no_table(mocker, spark):
+    """A table reported as non-existent yields False."""
+    # NOTE(review): despite the test name, check_date (not check_range) is what is exercised here.
+    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=False)
+    missing = Table(table_path="catalog.schema.simple_group", partition="DATE")
+
+    checker = DataChecker(TableReader(spark))
+    found = checker.check_date(missing, date(2023, 3, 12))
+
+    assert found is False
+
+
+@pytest.mark.parametrize(
+    "partition_date, expected",
+    [
+        (date(2023, 2, 26), False),
+        (date(2023, 3, 5), True),
+        (date(2023, 3, 12), False),
+        (date(2023, 3, 19), False),
+        (date(2023, 3, 26), False),
+    ],
+)
+def test_check_date_secondary_partitions_and_filters(mocker, spark, partitioned_dataframe, partition_date, expected):
+    """With version/type filters set, only a date holding a matching row counts as present."""
+    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True)
+    mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=partitioned_dataframe)
+    target = Table(
+        table_path="catalog.schema.simple_group",
+        partition="DATE",
+        secondary_partitions=["VERSION", "TYPE"],
+        filters={"version": 1, "type": "A"},
+    )
+
+    checker = DataChecker(TableReader(spark))
+    assert checker.check_date(target, partition_date) == expected
+
+
+@pytest.mark.parametrize(
+    "partition_date, expected",
+    [
+        (date(2023, 2, 26), False),
+        (date(2023, 3, 5), False),
+        (date(2023, 3, 12), False),
+        (date(2023, 3, 19), False),
+        (date(2023, 3, 26), False),
+    ],
+)
+def test_check_date_secondary_partitions_no_filters(mocker, spark, partitioned_dataframe, partition_date, expected):
+    """With secondary partitions declared but no filters, every tested date is reported absent."""
+    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True)
+    mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=partitioned_dataframe)
+    target = Table(
+        table_path="catalog.schema.simple_group",
+        partition="DATE",
+        secondary_partitions=["VERSION", "TYPE"],
+        filters=None,
+    )
+
+    checker = DataChecker(TableReader(spark))
+    assert checker.check_date(target, partition_date) == expected
+
+
+def test_check_written_with_no_filters_or_secondary_partitions():
+    """check_written returns the count read back for the partition, querying with empty filters."""
+    written_on = date(2023, 3, 5)
+    reader = MagicMock()
+    reader.get_table.return_value.count.return_value = 42
+    target = Table(table_path="dummy.table.path", partition="DATE")
+
+    written = DataChecker(reader).check_written(target, written_on, MagicMock())
+
+    assert written == 42
+    reader.get_table.assert_called_once_with(
+        "dummy.table.path",
+        date_column="DATE",
+        date_from=written_on,
+        date_to=written_on,
+        filters={},
+    )
+
+
+def test_check_written_with_filters():
+    """check_written forwards the table's own filters to the reader and returns the count."""
+    written_on = date(2023, 3, 5)
+    reader = MagicMock()
+    reader.get_table.return_value.count.return_value = 42
+    target = Table(table_path="dummy.table.path", partition="DATE", filters={"foo": "bar"})
+
+    written = DataChecker(reader).check_written(target, written_on, MagicMock())
+
+    assert written == 42
+    reader.get_table.assert_called_once_with(
+        "dummy.table.path",
+        date_column="DATE",
+        date_from=written_on,
+        date_to=written_on,
+        filters={"foo": "bar"},
+    )
+
+
+def test_check_written_with_secondary_partitions(new_insert_partitioned_dataframe):
+    """Secondary-partition values of the written frame become the read-back filters."""
+    written_on = date(2023, 3, 5)
+    reader = MagicMock()
+    reader.get_table.return_value.count.return_value = 7
+    target = Table(
+        table_path="dummy.table.path", partition="DATE", filters=None, secondary_partitions=["VERSION", "TYPE"]
+    )
+
+    written = DataChecker(reader).check_written(target, written_on, new_insert_partitioned_dataframe)
+
+    assert written == 7
+    # Every row of the fixture carries VERSION=1 and TYPE="B", hence the expected filters below.
+    reader.get_table.assert_called_once_with(
+        "dummy.table.path",
+        date_column="DATE",
+        date_from=written_on,
+        date_to=written_on,
+        filters={"VERSION": 1, "TYPE": "B"},
+    )
diff --git a/tests/runner/services/test_date_manager.py b/tests/runner/services/test_date_manager.py
new file mode 100644
index 0000000..61c9615
--- /dev/null
+++ b/tests/runner/services/test_date_manager.py
@@ -0,0 +1,227 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date, datetime
+
+import pytest
+
+from rialto.runner.services.config_loader import (
+ IntervalConfig,
+ RunnerConfig,
+ ScheduleConfig,
+)
+from rialto.runner.services.date_manager import DateManager
+
+
+def test_str_to_date():
+    assert datetime.strptime("2023-03-05", "%Y-%m-%d").date() == DateManager.str_to_date("2023-03-05")
+
+
+def test_str_to_date_bad():
+    with pytest.raises(ValueError):  # slash-separated input is expected to be rejected
+        DateManager.str_to_date("2023/03/05")
+
+
+def test_invalid_range():
+    negative_period_cfg = RunnerConfig(watched_period_units="months", watched_period_value=-3)
+    with pytest.raises(ValueError):  # a negative watched period must be refused at construction
+        DateManager(run_date="2023-03-05", config=negative_period_cfg)
+
+
+@pytest.mark.parametrize(
+    "units , value, res",
+    [("days", 7, "2023-02-26"), ("weeks", 3, "2023-02-12"), ("months", 5, "2022-10-05"), ("years", 2, "2021-03-05")],
+)
+def test_date_subtract(units, value, res):
+    """Subtracting ``value`` ``units`` from 2023-03-05 yields the date given by ``res``."""
+    start = DateManager.str_to_date("2023-03-05")
+    assert DateManager.date_subtract(input_date=start, units=units, value=value) == DateManager.str_to_date(res)
+
+
+def test_date_subtract_bad():
+    anchor = DateManager.str_to_date("2023-03-05")  # any valid date; only the unit is wrong
+    with pytest.raises(ValueError) as excinfo:
+        DateManager.date_subtract(input_date=anchor, units="random", value=1)
+    assert str(excinfo.value) == "Unknown time unit random"
+
+
+def test_all_dates():
+    """all_dates lists every day from 2023-02-05 through 2023-04-12 inclusive (67 days)."""
+    first_day = DateManager.str_to_date("2023-02-05")
+    last_day = DateManager.str_to_date("2023-04-12")
+    days = DateManager.all_dates(date_from=first_day, date_until=last_day)
+    assert len(days) == 67
+    assert days[1] == DateManager.str_to_date("2023-02-06")
+
+
+def test_date_from():
+    """With a three-month watched period, get_date_from is three months before the run date."""
+    cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    assert DateManager(config=cfg, run_date="2023-03-05").get_date_from() == DateManager.str_to_date("2022-12-05")
+
+
+def test_date_until():
+    """get_date_until equals the run date itself."""
+    cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    assert DateManager(config=cfg, run_date="2023-03-05").get_date_until() == DateManager.str_to_date("2023-03-05")
+
+
+def test_run_dates_daily_no_shift():
+    """Daily schedule without shift: every day of the one-week window, partition == execution."""
+    runner_cfg = RunnerConfig(watched_period_units="weeks", watched_period_value=1)
+    schedule = ScheduleConfig(frequency="daily")
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+
+    execution_dates, partition_dates = zip(*manager.get_execution_and_partition_dates(schedule=schedule))
+
+    expected_execution_dates = [
+        date(2026, 5, 13),
+        date(2026, 5, 14),
+        date(2026, 5, 15),
+        date(2026, 5, 16),
+        date(2026, 5, 17),
+        date(2026, 5, 18),
+        date(2026, 5, 19),
+        date(2026, 5, 20),
+    ]
+    expected_partition_dates = [
+        date(2026, 5, 13),
+        date(2026, 5, 14),
+        date(2026, 5, 15),
+        date(2026, 5, 16),
+        date(2026, 5, 17),
+        date(2026, 5, 18),
+        date(2026, 5, 19),
+        date(2026, 5, 20),
+    ]
+    assert list(execution_dates) == expected_execution_dates
+    assert list(partition_dates) == expected_partition_dates
+
+
+def test_run_dates_weekly_backwards_shift():
+    """Weekly schedule (day 5) with a 2-day shift: partitions fall two days before execution dates."""
+    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=1)
+    schedule = ScheduleConfig(frequency="weekly", day=5, info_date_shift=IntervalConfig(units="days", value=2))
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+
+    execution_dates, partition_dates = zip(*manager.get_execution_and_partition_dates(schedule=schedule))
+
+    expected_execution_dates = [
+        date(2026, 4, 24),
+        date(2026, 5, 1),
+        date(2026, 5, 8),
+        date(2026, 5, 15),
+    ]
+    expected_partition_dates = [
+        date(2026, 4, 22),
+        date(2026, 4, 29),
+        date(2026, 5, 6),
+        date(2026, 5, 13),
+    ]
+    assert list(execution_dates) == expected_execution_dates
+    assert list(partition_dates) == expected_partition_dates
+
+
+def test_run_dates_monthly_with_forward_shift():
+    """Monthly schedule (day 5) with a -2 day shift: partitions land two days after execution."""
+    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    schedule = ScheduleConfig(frequency="monthly", day=5, info_date_shift=IntervalConfig(units="days", value=-2))
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+
+    execution_dates, partition_dates = zip(*manager.get_execution_and_partition_dates(schedule=schedule))
+
+    expected_execution_dates = [
+        date(2026, 3, 5),
+        date(2026, 4, 5),
+        date(2026, 5, 5),
+    ]
+    expected_partition_dates = [
+        date(2026, 3, 7),
+        date(2026, 4, 7),
+        date(2026, 5, 7),
+    ]
+    assert list(execution_dates) == expected_execution_dates
+    assert list(partition_dates) == expected_partition_dates
+
+
+def test_run_dates_monthly_with_double_shift():
+    """A list of shifts is combined: -2 days plus 2 months maps the 5th to the 7th, two months back."""
+    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    schedule = ScheduleConfig(
+        frequency="monthly",
+        day=5,
+        info_date_shift=[IntervalConfig(units="days", value=-2), IntervalConfig(units="months", value=2)],
+    )
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+
+    execution_dates, partition_dates = zip(*manager.get_execution_and_partition_dates(schedule=schedule))
+
+    expected_execution_dates = [
+        date(2026, 3, 5),
+        date(2026, 4, 5),
+        date(2026, 5, 5),
+    ]
+    expected_partition_dates = [
+        date(2026, 1, 7),
+        date(2026, 2, 7),
+        date(2026, 3, 7),
+    ]
+    assert list(execution_dates) == expected_execution_dates
+    assert list(partition_dates) == expected_partition_dates
+
+
+def test_run_dates_monthly_last():
+    """day="last" schedules the final calendar day of each month inside the watched window."""
+    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    schedule = ScheduleConfig(frequency="monthly", day="last")
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+
+    execution_dates, partition_dates = zip(*manager.get_execution_and_partition_dates(schedule=schedule))
+
+    expected_execution_dates = [
+        date(2026, 2, 28),
+        date(2026, 3, 31),
+        date(2026, 4, 30),
+    ]
+    expected_partition_dates = [
+        date(2026, 2, 28),
+        date(2026, 3, 31),
+        date(2026, 4, 30),
+    ]
+    assert list(execution_dates) == expected_execution_dates
+    assert list(partition_dates) == expected_partition_dates
+
+
+def test_invalid_days():
+    """Out-of-range day values are rejected for both weekly and monthly schedules."""
+    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    bad_schedules = [
+        ScheduleConfig(frequency="weekly", day=12),
+        ScheduleConfig(frequency="monthly", day=42),
+    ]
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+    for bad_schedule in bad_schedules:
+        with pytest.raises(ValueError):
+            manager.get_execution_and_partition_dates(schedule=bad_schedule)
+
+
+def test_invalid_frequency():
+    """An unrecognised schedule frequency is rejected with ValueError."""
+    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3)
+    # NOTE(review): day=42 is out of range as well; presumably the frequency check fires first,
+    # otherwise this test would pass for the wrong reason - confirm against DateManager.
+    bad_cfg = ScheduleConfig(frequency="abc", day=42)
+    manager = DateManager(config=runner_cfg, run_date="2026-05-20")
+
+    with pytest.raises(ValueError):
+        manager.get_execution_and_partition_dates(schedule=bad_cfg)
diff --git a/tests/runner/services/test_executor.py b/tests/runner/services/test_executor.py
new file mode 100644
index 0000000..b7582c1
--- /dev/null
+++ b/tests/runner/services/test_executor.py
@@ -0,0 +1,160 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import date
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from rialto.runner.services.executor import PipelineExecutor
+from rialto.runner.services.task_registry import PipelineTask
+
+
+def _make_executor():
+    """Build a PipelineExecutor whose spark, reader and checker are all MagicMocks."""
+    spark = MagicMock()
+    reader = MagicMock()
+    checker = MagicMock()
+    return PipelineExecutor(spark=spark, reader=reader, checker=checker)
+
+
+def _make_task():
+    """Build a PipelineTask for 2026-01-01 with a mocked config whose optional tools are None."""
+    module_cfg = MagicMock()
+    module_cfg.python_module = "tests.runner.transformations"
+    module_cfg.python_class = "SimpleGroup"
+    pipeline_cfg = MagicMock()
+    pipeline_cfg.module = module_cfg
+    pipeline_cfg.metadata_manager = None
+    pipeline_cfg.feature_loader = None
+    return PipelineTask(
+        name="SimpleGroup",
+        execution_date=date(2026, 1, 1),
+        partition_date=date(2026, 1, 1),
+        config=pipeline_cfg,
+        target=MagicMock(),
+    )
+
+
+def test_load_module_imports_and_instantiates_class():
+    """_load_module imports the configured module and returns an instance of the configured class."""
+    executor = _make_executor()
+    module_cfg = MagicMock()
+    module_cfg.python_module = "fake.module"
+    module_cfg.python_class = "FakeClass"
+
+    fake_instance = MagicMock()
+    fake_class = MagicMock(return_value=fake_instance)
+    fake_module = MagicMock()
+    fake_module.FakeClass = fake_class
+
+    with patch("rialto.runner.services.executor.import_module", return_value=fake_module) as import_module_mock:
+        loaded = executor._load_module(module_cfg)
+
+    import_module_mock.assert_called_once_with("fake.module")
+    fake_class.assert_called_once_with()
+    assert loaded is fake_instance
+
+
+def test_init_tools_returns_none_when_not_configured():
+    """Without metadata_manager/feature_loader config, _init_tools returns None for both tools."""
+    executor = _make_executor()
+    bare_cfg = MagicMock()
+    bare_cfg.metadata_manager = None
+    bare_cfg.feature_loader = None
+
+    tools = executor._init_tools(executor.spark, bare_cfg)
+
+    manager_tool, loader_tool = tools
+    assert manager_tool is None and loader_tool is None
+
+
+def test_init_tools_creates_both_tools_when_configured():
+    """With both sections configured, _init_tools builds each tool from the (patched) classes."""
+    executor = _make_executor()
+    full_cfg = MagicMock()
+    full_cfg.metadata_manager = MagicMock()
+    full_cfg.metadata_manager.metadata_schema = "meta_schema"
+    full_cfg.feature_loader = MagicMock()
+    full_cfg.feature_loader.feature_schema = "feature_schema"
+    full_cfg.feature_loader.metadata_schema = "feature_meta_schema"
+
+    manager_patch = patch("rialto.runner.services.executor.MetadataManager")
+    loader_patch = patch("rialto.runner.services.executor.PysparkFeatureLoader")
+    with manager_patch as manager_cls, loader_patch as loader_cls:
+        manager_tool, loader_tool = executor._init_tools(executor.spark, full_cfg)
+
+    manager_cls.assert_called_once_with(executor.spark, "meta_schema")
+    loader_cls.assert_called_once_with(
+        executor.spark,
+        feature_schema="feature_schema",
+        metadata_schema="feature_meta_schema",
+    )
+    assert manager_tool is manager_cls.return_value
+    assert loader_tool is loader_cls.return_value
+
+
+def test_execute_calls_job_run_and_returns_df():
+    """execute loads the job, initialises the tools and returns whatever job.run produces."""
+    executor = _make_executor()
+    task = _make_task()
+    produced_df = MagicMock()
+    job = MagicMock()
+    job.run.return_value = produced_df
+
+    load_patch = patch.object(executor, "_load_module", return_value=job)
+    tools_patch = patch.object(executor, "_init_tools", return_value=(None, None))
+    with load_patch as load_module_mock, tools_patch as init_tools_mock:
+        outcome = executor.execute(task)
+
+    load_module_mock.assert_called_once_with(task.config.module)
+    init_tools_mock.assert_called_once_with(executor.spark, task.config)
+    job.run.assert_called_once_with(
+        spark=executor.spark,
+        run_date=task.execution_date,
+        config=task.config,
+        reader=executor.reader,
+        metadata_manager=None,
+        feature_loader=None,
+    )
+    assert outcome is produced_df
+
+
+def test_execute_logs_start_of_execution():
+    """execute emits exactly one info log naming the pipeline and its partition date."""
+    executor = _make_executor()
+    task = _make_task()
+    job = MagicMock()
+    job.run.return_value = MagicMock()
+    log_patch = patch("rialto.runner.services.executor.logger.info")
+    load_patch = patch.object(executor, "_load_module", return_value=job)
+    tools_patch = patch.object(executor, "_init_tools", return_value=(None, None))
+    with log_patch as logger_info_mock, load_patch, tools_patch:
+        executor.execute(task)
+
+    logger_info_mock.assert_called_once_with(f"Executing pipeline {task.name} for partition date {task.partition_date}")
+
+
+def test_execute_raises_when_job_run_fails():
+    """An exception raised by job.run propagates out of execute."""
+    executor = _make_executor()
+    task = _make_task()
+    failing_job = MagicMock()
+    failing_job.run.side_effect = RuntimeError("job failed")
+
+    load_patch = patch.object(executor, "_load_module", return_value=failing_job)
+    tools_patch = patch.object(executor, "_init_tools", return_value=(None, None))
+    with load_patch, tools_patch:
+        with pytest.raises(RuntimeError, match="job failed"):
+            executor.execute(task)
diff --git a/tests/runner/services/test_result_mapper.py b/tests/runner/services/test_result_mapper.py
new file mode 100644
index 0000000..6a6e8d0
--- /dev/null
+++ b/tests/runner/services/test_result_mapper.py
@@ -0,0 +1,120 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date, datetime, timedelta
+from unittest.mock import Mock
+
+from rialto.runner.services.result_mapper import TaskResultMapper
+
+
+def _make_task(name="my_pipeline", table_path="catalog.schema.table", partition_date=date(2026, 1, 1)):
+    """Return a Mock task exposing name, partition_date and target.get_table_path()."""
+    stub = Mock(partition_date=partition_date)
+    stub.name = name  # Mock(name=...) only labels the mock, so the attribute is set afterwards
+    stub.target.get_table_path.return_value = table_path
+    return stub
+
+
+def _run_start():
+    return datetime.now() + timedelta(seconds=-5)  # a start time five seconds in the past
+
+
+# ── success ──────────────────────────────────────────────────────────────────
+
+
+def test_success():
+    """success() yields a "Success"/"OK" record carrying the row count and a positive duration."""
+    pipeline_task = _make_task(name="pipe_a", table_path="cat.sch.tbl", partition_date=date(2026, 5, 1))
+    rec = TaskResultMapper.success(pipeline_task, _run_start(), records_count=42)
+    assert rec.job == "pipe_a"
+    assert rec.target == "cat.sch.tbl"
+    assert rec.date == date(2026, 5, 1)
+    assert rec.status == "Success"
+    assert rec.reason == "OK"
+    assert rec.exception is None
+    assert rec.records == 42
+    assert isinstance(rec.time, timedelta) and rec.time.total_seconds() > 0
+
+
+# ── already_complete ──────────────────────────────────────────────────────────
+
+
+def test_already_complete():
+    """already_complete() yields a "Skipped"/"AlreadyComplete" record with zero rows."""
+    rec = TaskResultMapper.already_complete(_make_task("pipe_b", "cat.sch.tbl2", date(2026, 3, 15)), _run_start())
+    assert rec.job == "pipe_b"
+    assert rec.target == "cat.sch.tbl2"
+    assert rec.date == date(2026, 3, 15)
+    assert rec.status == "Skipped"
+    assert rec.reason == "AlreadyComplete"
+    assert rec.exception is None
+    assert rec.records == 0
+
+
+# ── dependencies_incomplete ───────────────────────────────────────────────────
+
+
+def test_dependencies_incomplete_status_and_reason():
+    """dependencies_incomplete() yields a "Failed" record whose exception names each dependency."""
+    pipeline_task = _make_task(name="pipe_b", table_path="cat.sch.tbl2", partition_date=date(2026, 3, 15))
+    failed = ["cat.s.dep1 from 2026-01-01 until 2026-01-07", "cat.s.dep2 from 2026-01-01 until 2026-01-07"]
+    rec = TaskResultMapper.dependencies_incomplete(pipeline_task, _run_start(), failed)
+    assert rec.job == "pipe_b"
+    assert rec.target == "cat.sch.tbl2"
+    assert rec.date == date(2026, 3, 15)
+    assert rec.status == "Failed"
+    assert rec.reason == "Dependencies Incomplete"
+    assert rec.records == 0
+    assert all(dep in rec.exception for dep in failed)
+
+
+def test_dependencies_incomplete_lists_failed_deps_in_exception():
+    """Every failed dependency description appears in the record's exception text."""
+    failed = ["cat.s.dep1 from 2026-01-01 until 2026-01-07", "cat.s.dep2 from 2026-01-01 until 2026-01-07"]
+    rec = TaskResultMapper.dependencies_incomplete(_make_task(), _run_start(), failed)
+    assert all(dep in rec.exception for dep in failed)
+
+
+def test_dependencies_incomplete_empty_list_falls_back():
+    no_failures = []  # with nothing to list, the exception text is the "Unknown" placeholder
+    assert TaskResultMapper.dependencies_incomplete(_make_task(), _run_start(), no_failures).exception == "Unknown"
+
+
+# ── exception ─────────────────────────────────────────────────────────────────
+
+
+def test_exception():
+    """exception() yields an "Error" record carrying the given reason and traceback text."""
+    pipeline_task = _make_task(name="pipe_c", table_path="cat.sch.tbl3", partition_date=date(2026, 4, 10))
+    rec = TaskResultMapper.exception(pipeline_task, _run_start(), "ValueError", "Traceback...")
+    assert (rec.job, rec.target) == ("pipe_c", "cat.sch.tbl3")
+    assert rec.date == date(2026, 4, 10)
+    assert rec.status == "Error"
+    assert rec.reason == "ValueError"
+    assert rec.exception == "Traceback..."
+    assert rec.records == 0
+
+
+# ── interrupted ───────────────────────────────────────────────────────────────
+
+
+def test_interrupted():
+    """interrupted() yields an "Error"/"Keyboard Interrupt" record without exception text."""
+    pipeline_task = _make_task(name="pipe_c", table_path="cat.sch.tbl3", partition_date=date(2026, 4, 10))
+    rec = TaskResultMapper.interrupted(pipeline_task, _run_start())
+    assert (rec.job, rec.target) == ("pipe_c", "cat.sch.tbl3")
+    assert rec.date == date(2026, 4, 10)
+    assert rec.status == "Error"
+    assert rec.reason == "Keyboard Interrupt"
+    assert rec.exception is None
+    assert rec.records == 0
diff --git a/tests/runner/services/test_table.py b/tests/runner/services/test_table.py
new file mode 100644
index 0000000..d34d5c9
--- /dev/null
+++ b/tests/runner/services/test_table.py
@@ -0,0 +1,107 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from rialto.runner.services.config_loader import TargetConfig
+from rialto.runner.services.table import Table
+
+
+def test_table_basic_init():
+    """Explicit catalog/schema/table parts are joined into dotted schema and table paths."""
+    tbl = Table(catalog="cat", schema="sch", table="tab", schema_path=None, table_path=None, class_name=None)
+    assert tbl.get_table_path() == "cat.sch.tab"
+    assert tbl.get_schema_path() == "cat.sch"
+
+
+def test_table_classname_init():
+    """Given schema_path and class_name only, the table name is "cla_ss" for class "ClaSs"."""
+    tbl = Table(catalog=None, schema=None, table=None, schema_path="cat.sch", table_path=None, class_name="ClaSs")
+    assert tbl.catalog == "cat"
+    assert tbl.schema == "sch"
+    assert tbl.table == "cla_ss"
+    assert tbl.get_schema_path() == "cat.sch"
+    assert tbl.get_table_path() == "cat.sch.cla_ss"
+
+
+def test_table_path_init():
+    """A full dotted table_path is split into catalog, schema and table."""
+    tbl = Table(catalog=None, schema=None, table=None, schema_path=None, table_path="cat.sch.tab", class_name=None)
+    assert tbl.catalog == "cat"
+    assert tbl.schema == "sch"
+    assert tbl.table == "tab"
+    assert tbl.get_schema_path() == "cat.sch"
+    assert tbl.get_table_path() == "cat.sch.tab"
+
+
+def test_table_secondary_partitions():
+    """The main partition column comes first, followed by the secondary ones in order."""
+    tbl = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"])
+    assert tbl.get_all_partition_columns() == ["part", "sec1", "sec2"]
+
+
+def test_table_get_partitions_only_main():
+    """Without secondary partitions only the main partition column is listed."""
+    tbl = Table(catalog="cat", schema="sch", table="tab", partition="part")
+    assert tbl.get_all_partition_columns() == ["part"]
+
+
+def test_table_prioritize_table_name():
+    """An explicit table name wins over the name derived from class_name."""
+    tbl = Table(catalog=None, schema=None, table="custom", schema_path="cat.sch", table_path=None, class_name="ClaSs")
+    assert tbl.catalog == "cat"
+    assert tbl.schema == "sch"
+    assert tbl.table == "custom"
+    assert tbl.get_schema_path() == "cat.sch"
+    assert tbl.get_table_path() == "cat.sch.custom"
+
+
+def test_from_target_config():
+    """from_target_config names the table "test_class" for class "TestClass" and copies partitions/filters."""
+    target_cfg = TargetConfig(
+        target_schema="cat.sch",
+        target_partition_column="part",
+        secondary_partition_columns=["sec1", "sec2"],
+        custom_name=None,
+        rerun_filters={"col": "value"},
+    )
+    pipeline_cfg = Mock()
+    pipeline_cfg.target = target_cfg
+    pipeline_cfg.module.python_class = "TestClass"
+
+    tbl = Table.from_target_config(pipeline_cfg)
+
+    assert tbl.catalog == "cat"
+    assert tbl.schema == "sch"
+    assert tbl.table == "test_class"
+    assert tbl.get_schema_path() == "cat.sch"
+    assert tbl.get_table_path() == "cat.sch.test_class"
+    assert tbl.get_all_partition_columns() == ["part", "sec1", "sec2"]
+    assert tbl.filters == {"col": "value"}
+
+
+def test_from_dependency_config():
+    """from_dependency_config maps table, date_col and filters onto the Table."""
+    dep_cfg = Mock()
+    dep_cfg.filters = {"col": "value"}
+    dep_cfg.date_col = "date"
+    dep_cfg.table = "cat.sch.tab"
+    tbl = Table.from_dependency_config(dep_cfg)
+
+    assert tbl.catalog == "cat"
+    assert tbl.schema == "sch"
+    assert tbl.table == "tab"
+    assert tbl.partition == "date"
+    assert tbl.filters == {"col": "value"}
+    assert tbl.get_schema_path() == "cat.sch"
+    assert tbl.get_table_path() == "cat.sch.tab"
diff --git a/tests/runner/services/test_task_registry.py b/tests/runner/services/test_task_registry.py
new file mode 100644
index 0000000..b13a579
--- /dev/null
+++ b/tests/runner/services/test_task_registry.py
@@ -0,0 +1,177 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date
+from unittest.mock import patch
+
+import pytest
+
+from rialto.runner.services.config_loader import (
+ DependencyConfig,
+ IntervalConfig,
+ ModuleConfig,
+ PipelineConfig,
+ RunnerConfig,
+ ScheduleConfig,
+ TargetConfig,
+)
+from rialto.runner.services.date_manager import DateManager
+from rialto.runner.services.task_registry import (
+ PipelineDependency,
+ PipelineTask,
+ TaskRegistry,
+)
+
+
+@pytest.fixture(scope="module")
+def date_manager():
+    """Module-scoped DateManager: run date 2020-01-01, three-month watched period."""
+    return DateManager(RunnerConfig(watched_period_units="months", watched_period_value=3), "2020-01-01")
+
+
+@pytest.fixture
+def pipeline_config_no_deps():
+    """Monthly (day 1) pipeline config with secondary partitions and rerun filters, no dependencies given."""
+    target_cfg = TargetConfig(
+        target_schema="cat.sch",
+        target_partition_column="part",
+        secondary_partition_columns=["sec1", "sec2"],
+        custom_name=None,
+        rerun_filters={"col": "value"},
+    )
+    module_cfg = ModuleConfig(python_module="some.module", python_class="TestClass")
+    schedule_cfg = ScheduleConfig(frequency="monthly", day=1)
+
+    return PipelineConfig(name="test_pipeline", module=module_cfg, schedule=schedule_cfg, target=target_cfg)
+
+
+@pytest.fixture
+def pipeline_config_with_deps():
+    """Monthly (day 1) pipeline config with a single dependency using a one-month interval."""
+    dependency_cfg = DependencyConfig(
+        table="cat.sch.dep_table",
+        date_col="part",
+        interval=IntervalConfig(units="months", value=1),
+    )
+
+    target_cfg = TargetConfig(target_schema="cat.sch", target_partition_column="part")
+
+    return PipelineConfig(
+        name="test_pipeline_with_deps",
+        module=ModuleConfig(python_module="some.module", python_class="TestClass"),
+        schedule=ScheduleConfig(frequency="monthly", day=1),
+        target=target_cfg,
+        dependencies=[dependency_cfg],
+    )
+
+
+def test_registry_initializes_empty(spark, date_manager):
+    """A freshly built registry iterates over no tasks."""
+    assert list(TaskRegistry(spark, date_manager)) == []
+
+
+def test_add_task_no_dependencies(spark, date_manager, pipeline_config_no_deps):
+    """add_task stores one PipelineTask with the given fields and both status flags False."""
+    registry = TaskRegistry(spark, date_manager)
+    registry.add_task(
+        name="test_pipeline",
+        execution_date=date(2020, 1, 1),
+        partition_date=date(2019, 12, 31),
+        config=pipeline_config_no_deps,
+    )
+
+    registered = list(registry)
+    assert len(registered) == 1
+    (only_task,) = registered
+    assert isinstance(only_task, PipelineTask)
+    assert only_task.name == "test_pipeline"
+    assert only_task.execution_date == date(2020, 1, 1)
+    assert only_task.partition_date == date(2019, 12, 31)
+    assert only_task.config is pipeline_config_no_deps
+    assert only_task.dependencies == []
+    assert only_task.completion is False
+    assert only_task.dependencies_complete is False
+
+
+def test_add_task_target_table(spark, date_manager, pipeline_config_no_deps):
+    """The task's target table mirrors the target section of the pipeline config."""
+    registry = TaskRegistry(spark, date_manager)
+    registry.add_task(
+        name="test_pipeline",
+        execution_date=date(2020, 1, 1),
+        partition_date=date(2019, 12, 31),
+        config=pipeline_config_no_deps,
+    )
+    target = list(registry)[0].target
+    assert target.catalog == "cat"
+    assert target.schema == "sch"
+    assert target.partition == "part"
+    assert target.secondary_partitions == ["sec1", "sec2"]
+    assert target.filters == {"col": "value"}
+
+
+def test_add_task_with_dependencies(spark, date_manager, pipeline_config_with_deps):
+    """Each configured dependency becomes a PipelineDependency with a resolved date window."""
+    registry = TaskRegistry(spark, date_manager)
+    registry.add_task(
+        name="test_pipeline_with_deps",
+        execution_date=date(2020, 1, 1),
+        partition_date=date(2019, 12, 31),
+        config=pipeline_config_with_deps,
+    )
+
+    dependencies = list(registry)[0].dependencies
+    assert len(dependencies) == 1
+    dep = dependencies[0]
+    assert isinstance(dep, PipelineDependency)
+    assert dep.table.get_table_path() == "cat.sch.dep_table"
+    assert dep.date_until == date(2020, 1, 1)
+    assert dep.date_from == date(2019, 12, 1)  # 1 month subtracted
+    assert dep.complete is False
+
+
+def test_add_multiple_tasks(spark, date_manager, pipeline_config_no_deps, pipeline_config_with_deps):
+    """Two added tasks are both kept, in insertion order."""
+    registry = TaskRegistry(spark, date_manager)
+    configs_by_name = {"pipeline_a": pipeline_config_no_deps, "pipeline_b": pipeline_config_with_deps}
+    for pipeline_name, pipeline_cfg in configs_by_name.items():
+        registry.add_task(pipeline_name, date(2020, 1, 1), date(2019, 12, 31), pipeline_cfg)
+
+    tasks = list(registry)
+    assert [task.name for task in tasks] == ["pipeline_a", "pipeline_b"]
+
+
+def test_iteration(spark, date_manager, pipeline_config_no_deps):
+    """Iterating the registry yields tasks in the order they were added."""
+    registry = TaskRegistry(spark, date_manager)
+    for pipeline_name in ("p1", "p2"):
+        registry.add_task(pipeline_name, date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps)
+
+    assert [task.name for task in registry] == ["p1", "p2"]
+
+
+def test_log_status_contains_task_info(spark, date_manager, pipeline_config_no_deps):
+    """log_status emits one info message with the task name, partition date and both status marks."""
+    registry = TaskRegistry(spark, date_manager)
+    registry.add_task("test_pipeline", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps)
+    registry.tasks[0].completion = True  # mark the task as already complete
+
+    with patch("rialto.runner.services.task_registry.logger") as mock_logger:
+        registry.log_status()
+
+    mock_logger.info.assert_called_once()
+    message = mock_logger.info.call_args[0][0]
+    assert "test_pipeline" in message
+    assert "2019-12-31" in message
+    assert "✔" in message
+    assert "✘" in message
diff --git a/tests/runner/services/test_task_status_checker.py b/tests/runner/services/test_task_status_checker.py
new file mode 100644
index 0000000..55ad406
--- /dev/null
+++ b/tests/runner/services/test_task_status_checker.py
@@ -0,0 +1,132 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import date
+from unittest.mock import Mock, patch
+
+import pytest
+
+from rialto.runner.services.table import Table
+from rialto.runner.services.task_registry import PipelineDependency
+from rialto.runner.services.task_status_checker import TaskStatusChecker
+
+
+def make_task(name="my_pipeline", partition_date=date(2020, 1, 1), dependencies=None):
+    """Return a Mock task with a real Table target and both status flags set to False."""
+    stub = Mock()
+    stub.completion = stub.dependencies_complete = False
+    stub.dependencies = dependencies if dependencies else []
+    stub.target = Table(schema_path="cat.sch", class_name="TestClass", partition="part")
+    stub.partition_date = partition_date
+    stub.name = name  # assigned after creation because Mock(name=...) only labels the mock
+    return stub
+
+
+def make_dependency(table_path="cat.sch.dep_table", date_from=date(2019, 12, 1), date_until=date(2020, 1, 1)):
+    window = {"date_from": date_from, "date_until": date_until}  # forwarded unchanged to PipelineDependency
+    return PipelineDependency(table=Table(table_path=table_path, partition="part"), **window)
+
+
+@pytest.fixture
+def mock_data_checker():
+    return Mock()  # plain Mock standing in for DataChecker; tests stub check_date / check_range on it
+
+
+@pytest.fixture
+def status_checker(mock_data_checker):
+    return TaskStatusChecker(checker=mock_data_checker)  # built on the shared mock so tests can inspect its calls
+
+
+def test_check_completion_sets_true_when_data_exists(status_checker, mock_data_checker):
+    """completion becomes True when check_date reports data for the task's target and partition date."""
+    pipeline_task = make_task()
+    mock_data_checker.configure_mock(**{"check_date.return_value": True})
+
+    status_checker.check_completion(pipeline_task)
+    mock_data_checker.check_date.assert_called_once_with(pipeline_task.target, pipeline_task.partition_date)
+    assert pipeline_task.completion is True
+
+
+def test_check_completion_sets_false_when_no_data(status_checker, mock_data_checker):
+    """check_completion must overwrite a previously True flag with False when check_date finds no data."""
+    mock_data_checker.check_date.return_value = False
+    task = make_task()
+    task.completion = True  # make_task() starts at False, which would let a no-op implementation pass
+    status_checker.check_completion(task)
+    assert task.completion is False
+
+
+# --- check_pipeline_dependencies ---
+
+
+def test_check_pipeline_dependencies_no_deps(status_checker, mock_data_checker):
+    """A task without dependencies is marked complete and no range check is issued."""
+    dependency_free = make_task(dependencies=[])
+    status_checker.check_pipeline_dependencies(dependency_free)
+
+    mock_data_checker.check_range.assert_not_called()
+    assert dependency_free.dependencies_complete is True  # all([]) == True
+
+
+def test_check_pipeline_dependencies_all_complete(status_checker, mock_data_checker):
+    """When every range check succeeds, each dependency and the task itself are flagged complete."""
+    mock_data_checker.check_range.return_value = True
+    upstream = [make_dependency(path) for path in ("cat.sch.table_a", "cat.sch.table_b")]
+    pipeline_task = make_task(dependencies=upstream)
+
+    status_checker.check_pipeline_dependencies(pipeline_task)
+
+    assert mock_data_checker.check_range.call_count == 2
+    for dependency in upstream:
+        assert dependency.complete is True
+    assert pipeline_task.dependencies_complete is True
+
+
+def test_check_pipeline_dependencies_one_incomplete(status_checker, mock_data_checker):
+    """A failed range check leaves that dependency and the whole task incomplete."""
+    mock_data_checker.check_range.side_effect = [True, False]  # first check passes, second fails
+    satisfied, missing = (make_dependency(path) for path in ("cat.sch.table_a", "cat.sch.table_b"))
+    pipeline_task = make_task(dependencies=[satisfied, missing])
+
+    status_checker.check_pipeline_dependencies(pipeline_task)
+
+    assert pipeline_task.dependencies_complete is False
+    assert satisfied.complete is True
+    assert missing.complete is False
+
+
+def test_check_pipeline_dependencies_passes_correct_dates(status_checker, mock_data_checker):
+    """check_range receives the dependency's table together with its date_from and date_until."""
+    window_start, window_end = date(2019, 10, 1), date(2020, 1, 1)
+    mock_data_checker.check_range.return_value = True
+    dependency = make_dependency(date_from=window_start, date_until=window_end)
+
+    status_checker.check_pipeline_dependencies(make_task(dependencies=[dependency]))
+    mock_data_checker.check_range.assert_called_once_with(dependency.table, window_start, window_end)
+
+
+def test_check_completion_logs_status():
+    """check_completion logs one info message naming the pipeline, the partition date and the result."""
+    data_checker = Mock(**{"check_date.return_value": True})
+    pipeline_task = make_task(name="logged_pipeline", partition_date=date(2020, 3, 1))
+    checker_under_test = TaskStatusChecker(checker=data_checker)
+
+    with patch("rialto.runner.services.task_status_checker.logger") as logger_double:
+        checker_under_test.check_completion(pipeline_task)
+
+    logger_double.info.assert_called_once()
+    message = logger_double.info.call_args[0][0]
+    # The pipeline name, the ISO-formatted date and the boolean outcome must all be in the message.
+    for fragment in ("logged_pipeline", "2020-03-01", "True"):
+        assert fragment in message
diff --git a/tests/runner/services/test_writer.py b/tests/runner/services/test_writer.py
new file mode 100644
index 0000000..818d9fc
--- /dev/null
+++ b/tests/runner/services/test_writer.py
@@ -0,0 +1,244 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import date, datetime
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+from pyspark.sql import Row
+from pyspark.sql.types import StringType, StructField, StructType
+
+from rialto.runner.services.table import Table
+from rialto.runner.services.writer import DatabricksWriter
+
+
+@pytest.fixture
+def writer(spark):
+    return DatabricksWriter(spark, merge_schema=False)  # writer under test with merge_schema switched off
+
+
+@pytest.fixture
+def writer_merge(spark):
+    return DatabricksWriter(spark, merge_schema=True)  # same writer, but with merge_schema switched on
+
+
+@pytest.fixture
+def simple_table():
+    return Table(schema_path="default.test_schema", class_name="MyTable", partition="info_date")  # primary partition only
+
+
+@pytest.fixture
+def table_with_secondary(simple_table):
+    simple_table.secondary_partitions = ["region"]  # mutates the simple_table fixture instance in place
+    return simple_table
+
+
+# --- _align_schema ---
+
+
+def test_align_schema_no_existing_columns_returns_df_unchanged(spark, writer):
+    """With None as the existing column list, _align_schema keeps the DataFrame's column order."""
+    frame = spark.createDataFrame([Row(a=1, b=2)])
+    assert writer._align_schema(frame, None).columns == frame.columns
+
+
+def test_align_schema_reorders_to_existing_columns(spark, writer):
+    """Columns come back in the order given by the existing column list."""
+    frame = spark.createDataFrame([Row(a=1, b=2, c=3)])
+    assert writer._align_schema(frame, ["c", "a", "b"]).columns == ["c", "a", "b"]
+
+
+def test_align_schema_new_columns_appended_after_existing(spark, writer):
+    """A column absent from the existing list is placed after the existing columns."""
+    frame = spark.createDataFrame([Row(a=1, b=2, new_col=3)])
+    assert writer._align_schema(frame, ["a", "b"]).columns == ["a", "b", "new_col"]
+
+
+def test_align_schema_missing_existing_columns_raises_value_error(spark, writer):
+    """An existing column that the DataFrame lacks ("gone") makes _align_schema raise ValueError."""
+    frame = spark.createDataFrame([Row(a=1, b=2)])
+    with pytest.raises(ValueError):
+        writer._align_schema(frame, ["gone", "a", "b"])
+
+
+# --- _get_replace_condition ---
+
+
+def test_get_replace_condition_string_value(spark, writer):
+    """With only a primary partition, the condition is that column compared to the quoted date."""
+    frame = spark.createDataFrame([Row(info_date=date(2020, 1, 1))])
+    destination = Table(schema_path="default.test_schema", class_name="MyTable", partition="info_date")
+    assert writer._get_replace_condition(frame, destination, datetime(2020, 1, 1)) == "info_date = '2020-01-01'"
+
+
+def test_get_replace_condition_second_value_null(spark, writer):
+    """A null secondary partition value is rendered as an IS NULL predicate."""
+    # Explicit schema: region contains only None, so its type is declared rather than inferred.
+    string_columns = StructType(
+        [
+            StructField("information_date", StringType(), nullable=False),
+            StructField("region", StringType(), nullable=True),
+        ]
+    )
+    frame = spark.createDataFrame([Row(information_date="2020-01-01", region=None)], schema=string_columns)
+    destination = Table(
+        schema_path="default.test_schema",
+        class_name="MyTable",
+        partition="information_date",
+        secondary_partitions=["region"],
+    )
+    predicate = writer._get_replace_condition(frame, destination, datetime(2020, 1, 1))
+    assert predicate == "information_date = '2020-01-01' AND region IS NULL"
+
+
+def test_get_replace_condition_second_value_no_filters(spark, writer):
+    destination = Table(
+        schema_path="default.test_schema",
+        class_name="MyTable",
+        partition="information_date",
+        secondary_partitions=["region"],
+    )
+    frame = spark.createDataFrame([Row(information_date="2020-01-01", region=1)])
+    predicate = writer._get_replace_condition(frame, destination, datetime(2020, 1, 1))
+    assert predicate == "information_date = '2020-01-01' AND region = 1"  # no filters set; frame holds region=1
+
+
+def test_get_replace_condition_second_value_with_filters(spark, writer):
+    destination = Table(
+        schema_path="default.test_schema",
+        class_name="MyTable",
+        partition="information_date",
+        secondary_partitions=["region"],
+        filters={"region": 1},
+    )
+    frame = spark.createDataFrame([Row(information_date="2020-01-01", region=1)])
+    predicate = writer._get_replace_condition(frame, destination, datetime(2020, 1, 1))
+    assert predicate == "information_date = '2020-01-01' AND region = 1"  # filter and frame agree on region=1
+
+
+def test_get_replace_condition_raises_on_multiple_distinct_values(spark, writer):
+    """Two different region values in one DataFrame make _get_replace_condition raise ValueError."""
+    rows = [Row(information_date="2020-01-01", region=region) for region in (1, 2)]
+    frame = spark.createDataFrame(rows)
+    destination = Table(
+        schema_path="default.test_schema",
+        class_name="MyTable",
+        partition="information_date",
+        secondary_partitions=["region"],
+    )
+    with pytest.raises(ValueError, match="more than 1 distinct value"):
+        writer._get_replace_condition(frame, destination, datetime(2020, 1, 1))
+
+
+# --- _process ---
+
+
+def test_process_adds_partition_column(spark, writer, simple_table):
+    """_process adds the table's partition column, filled with the date that was passed in."""
+    with patch.object(writer, "_get_existing_columns", return_value=None):
+        processed = writer._process(spark.createDataFrame([Row(a=1)]), date(2020, 1, 1), simple_table)
+    assert "info_date" in processed.columns
+    assert processed.collect()[0]["info_date"] == date(2020, 1, 1)
+
+
+def test_get_existing_columns_returns_columns():
+    """_get_existing_columns returns the columns of whatever spark.table yields for the table path."""
+    table_path = "catalog.schema.tbl"
+    table_double = Mock()
+    table_double.get_table_path.return_value = table_path
+
+    # spark.table(...) hands back a Mock whose `columns` attribute is preset.
+    session_double = Mock()
+    session_double.table.return_value = Mock(columns=["id", "name"])
+
+    schema_writer = DatabricksWriter(spark=session_double)
+    columns = schema_writer._get_existing_columns(table_double)
+
+    assert columns == ["id", "name"]
+    session_double.table.assert_called_once_with(table_path)
+    table_double.get_table_path.assert_called_once_with()
+
+
+def test_get_existing_columns_returns_none_on_exception():
+    """When spark.table raises, _get_existing_columns returns None and logs exactly one warning."""
+    table_path = "catalog.schema.tbl"
+    table_double = Mock()
+    table_double.get_table_path.return_value = table_path
+    session_double = Mock()
+    session_double.table.side_effect = Exception("table not found")
+
+    schema_writer = DatabricksWriter(spark=session_double)
+
+    # Replace logger.warning in the writer module so the call can be counted.
+    with patch("rialto.runner.services.writer.logger.warning") as warning_double:
+        columns = schema_writer._get_existing_columns(table_double)
+
+    assert columns is None
+    warning_double.assert_called_once()
+    session_double.table.assert_called_once_with(table_path)
+    # The path may be requested more than once (lookup plus log message), so only "called" is checked.
+    table_double.get_table_path.assert_called()
+
+
+# --- write (integration of internal steps) ---
+
+
+def test_write_calls_create_schema(spark, writer, simple_table):
+    """write() calls _create_schema once with the target table."""
+    frame = Mock(write=Mock())
+    # _process and _get_replace_condition are stubbed so only the schema call is examined.
+    with patch.object(writer, "_create_schema") as create_schema_double:
+        with patch.object(writer, "_process"), patch.object(writer, "_get_replace_condition"):
+            writer.write(frame, Mock(), simple_table)
+
+    create_schema_double.assert_called_once_with(simple_table)
+
+
+def test_create_schema_uses_table_schema_path():
+    """_create_schema issues CREATE SCHEMA IF NOT EXISTS for the table's schema path."""
+    session_double = Mock()
+    table_double = Mock()
+    table_double.get_schema_path.return_value = "my_catalog.my_schema"
+
+    schema_writer = DatabricksWriter(spark=session_double)
+    schema_writer._create_schema(table_double)
+
+    table_double.get_schema_path.assert_called_once_with()
+    session_double.sql.assert_called_once_with("CREATE SCHEMA IF NOT EXISTS my_catalog.my_schema")
+
+
+def test_write_merge_schema_option(spark, writer_merge, simple_table):
+    """A writer built with merge_schema=True passes option("mergeSchema", "true") during write()."""
+    frame = MagicMock()
+    frame.write = MagicMock()
+    with patch.object(writer_merge, "_create_schema"), patch.object(writer_merge, "_get_replace_condition"):
+        with patch.object(writer_merge, "_process", return_value=frame):
+            writer_merge.write(frame, date(2020, 1, 1), simple_table)
+
+    # Follow the mocked chain write.format(...).partitionBy(...).mode(...) down to its option() calls.
+    mode_result = frame.write.format.return_value.partitionBy.return_value.mode.return_value
+    assert ("mergeSchema", "true") in [recorded.args for recorded in mode_result.option.call_args_list]
+
+
+def test_write_not_merge_schema_option(spark, writer, simple_table):
+    """A writer built with merge_schema=False passes option("mergeSchema", "false") during write()."""
+    frame = MagicMock()
+    frame.write = MagicMock()
+    with patch.object(writer, "_create_schema"), patch.object(writer, "_get_replace_condition"):
+        with patch.object(writer, "_process", return_value=frame):
+            writer.write(frame, date(2020, 1, 1), simple_table)
+
+    # Follow the mocked chain write.format(...).partitionBy(...).mode(...) down to its option() calls.
+    mode_result = frame.write.format.return_value.partitionBy.return_value.mode.return_value
+    assert ("mergeSchema", "false") in [recorded.args for recorded in mode_result.option.call_args_list]
diff --git a/tests/runner/test_bookkeeping.py b/tests/runner/test_bookkeeping.py
deleted file mode 100644
index 06507d5..0000000
--- a/tests/runner/test_bookkeeping.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from datetime import datetime, timedelta
-
-from rialto.runner.date_manager import DateManager
-from rialto.runner.reporting.record import Record
-
-record = Record(
- "job",
- "target",
- DateManager.str_to_date("2024-01-01"),
- timedelta(days=0, hours=1, minutes=2, seconds=3),
- 1,
- "status",
- "reason",
- None,
- datetime(2024, 1, 1, 1, 2, 3),
-)
-
-
-def test_record_to_spark(spark):
- row = record.to_spark_row()
- assert row.job == "job"
- assert row.target == "target"
- assert row.date == DateManager.str_to_date("2024-01-01")
- assert row.time == "1:02:03"
- assert row.records == 1
- assert row.status == "status"
- assert row.reason == "reason"
- assert row.exception is None
- assert row.run_timestamp == datetime(2024, 1, 1, 1, 2, 3)
diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py
deleted file mode 100644
index 73b61b8..0000000
--- a/tests/runner/test_date_manager.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2022 ABSA Group Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from datetime import datetime
-
-import pytest
-
-from rialto.runner.config_loader import IntervalConfig, ScheduleConfig
-from rialto.runner.date_manager import DateManager
-
-
-def test_str_to_date():
- assert DateManager.str_to_date("2023-03-05") == datetime.strptime("2023-03-05", "%Y-%m-%d").date()
-
-
-@pytest.mark.parametrize(
- "units , value, res",
- [("days", 7, "2023-02-26"), ("weeks", 3, "2023-02-12"), ("months", 5, "2022-10-05"), ("years", 2, "2021-03-5")],
-)
-def test_date_from(units, value, res):
- rundate = DateManager.str_to_date("2023-03-05")
- date_from = DateManager.date_subtract(run_date=rundate, units=units, value=value)
- assert date_from == DateManager.str_to_date(res)
-
-
-def test_date_from_bad():
- rundate = DateManager.str_to_date("2023-03-05")
- with pytest.raises(ValueError) as exception:
- DateManager.date_subtract(run_date=rundate, units="random", value=1)
- assert str(exception.value) == "Unknown time unit random"
-
-
-def test_all_dates():
- all_dates = DateManager.all_dates(
- date_from=DateManager.str_to_date("2023-02-05"),
- date_to=DateManager.str_to_date("2023-04-12"),
- )
- assert len(all_dates) == 67
- assert all_dates[1] == DateManager.str_to_date("2023-02-06")
-
-
-def test_all_dates_reversed():
- all_dates = DateManager.all_dates(
- date_from=DateManager.str_to_date("2023-04-12"),
- date_to=DateManager.str_to_date("2023-02-05"),
- )
- assert len(all_dates) == 67
- assert all_dates[1] == DateManager.str_to_date("2023-02-06")
-
-
-def test_run_dates_weekly():
- cfg = ScheduleConfig(frequency="weekly", day=5)
-
- run_dates = DateManager.run_dates(
- date_from=DateManager.str_to_date("2023-02-05"),
- date_to=DateManager.str_to_date("2023-04-07"),
- schedule=cfg,
- )
-
- expected = [
- "2023-02-10",
- "2023-02-17",
- "2023-02-24",
- "2023-03-03",
- "2023-03-10",
- "2023-03-17",
- "2023-03-24",
- "2023-03-31",
- "2023-04-07",
- ]
- expected = [DateManager.str_to_date(d) for d in expected]
- assert run_dates == expected
-
-
-def test_run_dates_monthly():
- cfg = ScheduleConfig(frequency="monthly", day=5)
-
- run_dates = DateManager.run_dates(
- date_from=DateManager.str_to_date("2022-08-05"),
- date_to=DateManager.str_to_date("2023-04-07"),
- schedule=cfg,
- )
-
- expected = [
- "2022-08-05",
- "2022-09-05",
- "2022-10-05",
- "2022-11-05",
- "2022-12-05",
- "2023-01-05",
- "2023-02-05",
- "2023-03-05",
- "2023-04-05",
- ]
- expected = [DateManager.str_to_date(d) for d in expected]
- assert run_dates == expected
-
-
-def test_run_dates_daily():
- cfg = ScheduleConfig(frequency="daily")
-
- run_dates = DateManager.run_dates(
- date_from=DateManager.str_to_date("2023-03-28"),
- date_to=DateManager.str_to_date("2023-04-03"),
- schedule=cfg,
- )
-
- expected = [
- "2023-03-28",
- "2023-03-29",
- "2023-03-30",
- "2023-03-31",
- "2023-04-01",
- "2023-04-02",
- "2023-04-03",
- ]
- expected = [DateManager.str_to_date(d) for d in expected]
- assert run_dates == expected
-
-
-def test_run_dates_invalid():
- cfg = ScheduleConfig(frequency="random")
- with pytest.raises(ValueError) as exception:
- DateManager.run_dates(
- date_from=DateManager.str_to_date("2023-03-28"),
- date_to=DateManager.str_to_date("2023-04-03"),
- schedule=cfg,
- )
- assert str(exception.value) == "Unknown frequency random"
-
-
-@pytest.mark.parametrize(
- "shift, res",
- [(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")],
-)
-def test_to_info_date(shift, res):
- cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)])
- base = DateManager.str_to_date("2023-03-05")
- info = DateManager.to_info_date(base, cfg)
- assert DateManager.str_to_date(res) == info
-
-
-@pytest.mark.parametrize(
- "unit, result",
- [("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")],
-)
-def test_info_date_shift_units(unit, result):
- cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)])
- base = DateManager.str_to_date("2023-03-05")
- info = DateManager.to_info_date(base, cfg)
- assert DateManager.str_to_date(result) == info
-
-
-def test_info_date_shift_combined():
- cfg = ScheduleConfig(
- frequency="daily",
- info_date_shift=[IntervalConfig(units="months", value=3), IntervalConfig(units="days", value=4)],
- )
- base = DateManager.str_to_date("2023-03-05")
- info = DateManager.to_info_date(base, cfg)
- assert DateManager.str_to_date("2022-12-01") == info
diff --git a/tests/runner/test_engine.py b/tests/runner/test_engine.py
new file mode 100644
index 0000000..db74f8e
--- /dev/null
+++ b/tests/runner/test_engine.py
@@ -0,0 +1,450 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import date
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from rialto.runner.engine import RunnerEngine
+from rialto.runner.services.config_loader import (
+ ModuleConfig,
+ PipelineConfig,
+ ScheduleConfig,
+)
+
+
+def _pipeline(name="p1", schedule="weekly"):
+    module_cfg = ModuleConfig(python_module="mod", python_class="Class")  # dummy module reference
+    return PipelineConfig(name=name, schedule=ScheduleConfig(frequency=schedule, day=1), module=module_cfg)
+
+
+def _dependency(complete=True, table_path="src.sch.dep", date_from=date(2026, 1, 1), date_until=date(2026, 1, 7)):
+    """Return a MagicMock dependency carrying the given completion flag, date range and table path."""
+    attributes = {"complete": complete, "date_from": date_from, "date_until": date_until}
+    stub = MagicMock(**attributes)
+    # The path is exposed through the nested table.get_table_path() call.
+    stub.table.get_table_path.return_value = table_path
+    return stub
+
+
+def _task(
+    name="p1",
+    completion=False,
+    dependencies_complete=True,
+    precheck_failed=False,
+    partition_date=date(2026, 1, 8),
+    execution_date=date(2026, 1, 8),
+    deps=None,
+):
+    """Build a MagicMock task with plain attribute values; deps=None becomes an empty list."""
+    flags = {
+        "completion": completion,
+        "dependencies_complete": dependencies_complete,
+        "precheck_failed": precheck_failed,
+    }
+    stub = MagicMock(partition_date=partition_date, execution_date=execution_date, target=MagicMock(), **flags)
+    stub.dependencies = [] if deps is None else deps
+    stub.name = name  # assigned afterwards because MagicMock(name=...) only labels the mock
+    return stub
+
+
+def _services():
+    """Return a MagicMock service container with two configured pipelines and an empty task list."""
+    container = MagicMock()
+    collaborators = (
+        "config", "date_manager", "registry", "task_checker",
+        "executor", "writer", "data_checker", "tracker",
+    )
+    # Attach an explicit MagicMock for each named service attribute.
+    for attribute in collaborators:
+        setattr(container, attribute, MagicMock())
+    container.config.pipelines = [_pipeline(pipeline_name) for pipeline_name in ("p1", "p2")]
+    container.registry.tasks = []
+    return container
+
+
+# ---- select_pipelines ------------------------------------------------------
+
+
+def test_select_pipelines_returns_all_when_op_none():
+    """select_pipelines(None) returns every configured pipeline."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    selected = runner_engine.select_pipelines(None)
+    assert selected == service_doubles.config.pipelines
+
+
+def test_select_pipelines_returns_matching_op():
+    """Selecting by name returns only the pipeline with that name."""
+    runner_engine = RunnerEngine(services=_services(), rerun=False, skip_dependencies=False)
+
+    selected = runner_engine.select_pipelines("p2")
+
+    selected_names = [pipeline.name for pipeline in selected]
+    assert selected_names == ["p2"]
+
+
+def test_select_pipelines_raises_for_unknown_op():
+    """An operation name that matches no pipeline raises ValueError naming the operation."""
+    runner_engine = RunnerEngine(services=_services(), rerun=False, skip_dependencies=False)
+
+    with pytest.raises(ValueError, match="Unknown operation selected: nope"):
+        runner_engine.select_pipelines("nope")
+
+
+# ---- register_tasks --------------------------------------------------------
+
+
+def test_register_tasks_adds_task_for_each_execution_partition_pair():
+    """register_tasks adds one task per (execution date, partition date) pair from the date manager."""
+    date_pairs = (
+        (date(2026, 1, 10), date(2026, 1, 8)),
+        (date(2026, 1, 17), date(2026, 1, 15)),
+    )
+
+    service_doubles = _services()
+    service_doubles.date_manager.get_execution_and_partition_dates.return_value = list(date_pairs)
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+    weekly_pipeline = _pipeline("p1", "weekly")
+
+    runner_engine.register_tasks([weekly_pipeline])
+
+    # Each pair is expected as keyword arguments together with the pipeline's name and config.
+    expected_calls = [
+        call(
+            name="p1",
+            execution_date=execution_day,
+            partition_date=partition_day,
+            config=weekly_pipeline,
+        )
+        for execution_day, partition_day in date_pairs
+    ]
+
+    add_task = service_doubles.registry.add_task
+    assert add_task.call_count == 2
+    add_task.assert_has_calls(expected_calls)
+
+
+# ---- check_tasks -----------------------------------------------------------
+
+
+def test_check_tasks_calls_both_checks_by_default():
+    """With rerun and skip_dependencies both off, both checks are called for every registered task."""
+    service_doubles = _services()
+    registered = [_task(name="p1"), _task(name="p2")]
+    service_doubles.registry.tasks = list(registered)
+
+    RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False).check_tasks()
+
+    per_task_calls = [call(task) for task in registered]
+    service_doubles.task_checker.check_completion.assert_has_calls(per_task_calls)
+    service_doubles.task_checker.check_pipeline_dependencies.assert_has_calls(per_task_calls)
+
+
+def test_check_tasks_skips_completion_when_rerun_true():
+    """rerun=True bypasses the completion check but still checks dependencies."""
+    service_doubles, only_task = _services(), _task()
+    service_doubles.registry.tasks = [only_task]
+
+    RunnerEngine(services=service_doubles, rerun=True, skip_dependencies=False).check_tasks()
+
+    task_checker = service_doubles.task_checker
+    task_checker.check_pipeline_dependencies.assert_called_once_with(only_task)
+    task_checker.check_completion.assert_not_called()
+
+
+def test_check_tasks_skips_dependencies_when_skip_dependencies_true():
+    """skip_dependencies=True bypasses the dependency check but still checks completion."""
+    service_doubles, only_task = _services(), _task()
+    service_doubles.registry.tasks = [only_task]
+
+    RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=True).check_tasks()
+
+    task_checker = service_doubles.task_checker
+    task_checker.check_pipeline_dependencies.assert_not_called()
+    task_checker.check_completion.assert_called_once_with(only_task)
+
+
+def test_check_completion_records_exception_and_sets_precheck_failed():
+    """An exception raised by check_completion is stored on the task and flags the precheck as failed."""
+    service_doubles, failing_task = _services(), _task()
+    service_doubles.registry.tasks = [failing_task]
+    completion_check = service_doubles.task_checker.check_completion
+    completion_check.side_effect = RuntimeError("boom")
+
+    RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False).check_tasks()
+
+    completion_check.assert_called_once_with(failing_task)
+    assert failing_task.precheck_failed is True
+    assert failing_task.error == "boom"
+    assert "RuntimeError: boom" in failing_task.error_trace
+
+
+def test_check_dependencies_records_exception_and_sets_precheck_failed():
+    """An exception raised by the dependency check is stored on the task and flags the precheck as failed."""
+    service_doubles, failing_task = _services(), _task()
+    service_doubles.registry.tasks = [failing_task]
+    dependency_check = service_doubles.task_checker.check_pipeline_dependencies
+    dependency_check.side_effect = RuntimeError("boom")
+    RunnerEngine(services=service_doubles, rerun=True, skip_dependencies=False).check_tasks()
+
+    dependency_check.assert_called_once_with(failing_task)
+    assert failing_task.precheck_failed is True
+    assert failing_task.error == "boom"
+    assert "RuntimeError: boom" in failing_task.error_trace
+
+
+# ---- run_tasks -------------------------------------------------------------
+
+
+def test_run_tasks_calls_execute_with_tracking_for_each_task():
+    """run_tasks hands every registered task to _execute_task_with_tracking and logs two info messages."""
+    service_doubles = _services()
+    registered = [_task(name="p1"), _task(name="p2")]
+    service_doubles.registry.tasks = list(registered)
+
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    # The per-task executor is replaced so only the iteration in run_tasks is exercised.
+    with patch("rialto.runner.engine.logger.info") as info_double:
+        with patch.object(runner_engine, "_execute_task_with_tracking") as execute_double:
+            runner_engine.run_tasks()
+
+    assert execute_double.call_count == 2
+    execute_double.assert_has_calls([call(task) for task in registered])
+    assert info_double.call_count == 2
+
+
+# ---- _execute_task_with_tracking branches ----------------------------------
+
+
+def test_execute_task_with_tracking_records_precheck_failure_and_skips_execution():
+    """A task whose precheck failed is recorded through TaskResultMapper.exception and never executed."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    with patch("rialto.runner.engine.TaskResultMapper.exception", return_value="rec") as exception_record:
+        runner_engine._execute_task_with_tracking(_task(precheck_failed=True))
+
+    service_doubles.executor.execute.assert_not_called()
+    exception_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_skips_already_complete():
+    """A completed task is recorded through TaskResultMapper.already_complete and not executed."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+    finished_task = _task(completion=True, dependencies_complete=True, precheck_failed=False)
+
+    with patch("rialto.runner.engine.logger.info"):
+        with patch("rialto.runner.engine.TaskResultMapper.already_complete", return_value="rec") as skip_record:
+            runner_engine._execute_task_with_tracking(finished_task)
+
+    service_doubles.executor.execute.assert_not_called()
+    skip_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_skips_incomplete_dependencies():
+    """A task with incomplete dependencies is recorded as such and not executed."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+    missing_dependency = _dependency(complete=False, table_path="src.sch.dep1")
+    blocked_task = _task(completion=False, dependencies_complete=False, deps=[missing_dependency])
+
+    with patch("rialto.runner.engine.logger.info"):
+        with patch("rialto.runner.engine.TaskResultMapper.dependencies_incomplete", return_value="rec") as skip_record:
+            runner_engine._execute_task_with_tracking(blocked_task)
+
+    service_doubles.executor.execute.assert_not_called()
+    skip_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_success_path():
+    """A runnable task is executed, written, checked and recorded through TaskResultMapper.success."""
+    service_doubles = _services()
+    service_doubles.executor.execute.return_value = "df"
+    service_doubles.data_checker.check_written.return_value = 123
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+    runnable_task = _task(completion=False, dependencies_complete=True)
+
+    with patch("rialto.runner.engine.logger.info"):
+        with patch("rialto.runner.engine.TaskResultMapper.success", return_value="rec") as success_record:
+            runner_engine._execute_task_with_tracking(runnable_task)
+
+    service_doubles.executor.execute.assert_called_once_with(runnable_task)
+    service_doubles.writer.write.assert_called_once_with("df", runnable_task.partition_date, runnable_task.target)
+    service_doubles.data_checker.check_written.assert_called_once_with(runnable_task.target, runnable_task.partition_date, "df")
+    success_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_exception_path():
+    """An executor failure is logged via logger.exception, recorded, and nothing is written."""
+    service_doubles = _services()
+    service_doubles.executor.execute.side_effect = RuntimeError("boom")
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+    failing_task = _task(completion=False, dependencies_complete=True)
+
+    # The RuntimeError must not escape _execute_task_with_tracking.
+    with patch("rialto.runner.engine.logger.exception") as log_exception_double:
+        with patch("rialto.runner.engine.TaskResultMapper.exception", return_value="rec") as exception_record:
+            runner_engine._execute_task_with_tracking(failing_task)
+
+    service_doubles.writer.write.assert_not_called()
+    log_exception_double.assert_called_once()
+    exception_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_keyboard_interrupt_records_and_reraises():
+    """A KeyboardInterrupt from the executor is recorded via TaskResultMapper.interrupted and re-raised."""
+    service_doubles = _services()
+    service_doubles.executor.execute.side_effect = KeyboardInterrupt()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    with pytest.raises(KeyboardInterrupt):
+        with patch("rialto.runner.engine.TaskResultMapper.interrupted", return_value="rec") as interrupt_record:
+            runner_engine._execute_task_with_tracking(_task(completion=False, dependencies_complete=True))
+
+    interrupt_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_runs_when_skip_dependencies_true_even_if_incomplete():
+    """With skip_dependencies=True a task with incomplete dependencies is still recorded as a success."""
+    service_doubles = _services()
+    service_doubles.executor.execute.return_value = "df"
+    service_doubles.data_checker.check_written.return_value = 1
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=True)
+    blocked_task = _task(completion=False, dependencies_complete=False, deps=[_dependency(complete=False)])
+
+    with patch("rialto.runner.engine.logger.info"):
+        with patch("rialto.runner.engine.TaskResultMapper.dependencies_incomplete") as skip_record:
+            with patch("rialto.runner.engine.TaskResultMapper.success", return_value="rec") as success_record:
+                runner_engine._execute_task_with_tracking(blocked_task)
+
+    success_record.assert_called_once()
+    skip_record.assert_not_called()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+def test_execute_task_with_tracking_runs_when_rerun_true_even_if_completion_true():
+    """With rerun=True an already completed task is still recorded as a success."""
+    service_doubles = _services()
+    service_doubles.executor.execute.return_value = "df"
+    service_doubles.data_checker.check_written.return_value = 10
+    runner_engine = RunnerEngine(services=service_doubles, rerun=True, skip_dependencies=False)
+    finished_task = _task(completion=True, dependencies_complete=True)
+
+    with patch("rialto.runner.engine.logger.info"):
+        with patch("rialto.runner.engine.TaskResultMapper.success", return_value="rec") as success_record:
+            runner_engine._execute_task_with_tracking(finished_task)
+
+    success_record.assert_called_once()
+    service_doubles.tracker.add.assert_called_once_with("rec")
+
+
+# ---- wrappers --------------------------------------------------------------
+
+
+def test_finalize_calls_tracker_report_by_mail_and_log():
+    """finalize() calls tracker.report_by_mail and log_task_status once each, without arguments."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    with patch.object(runner_engine, "log_task_status") as log_status_double:
+        runner_engine.finalize()
+    log_status_double.assert_called_once_with()
+    service_doubles.tracker.report_by_mail.assert_called_once_with()
+
+
+def test_run_calls_full_flow():
+    """run(op) invokes every stage exactly once and passes the selected pipelines to register_tasks."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    with (
+        patch.object(runner_engine, "select_pipelines", return_value=["pipes"]) as select_double,
+        patch.object(runner_engine, "register_tasks") as register_double,
+        patch.object(runner_engine, "check_tasks") as check_double,
+        patch.object(runner_engine, "log_task_status") as log_status_double,
+        patch.object(runner_engine, "run_tasks") as run_tasks_double,
+        patch.object(runner_engine, "finalize") as finalize_double,
+    ):
+        runner_engine.run("my_op")
+
+    select_double.assert_called_once_with("my_op")
+    register_double.assert_called_once_with(["pipes"])
+    check_double.assert_called_once_with()
+    log_status_double.assert_called_once_with()
+    run_tasks_double.assert_called_once_with()
+    finalize_double.assert_called_once_with()
+
+
+def test_dry_run_execution_calls_expected_flow_without_run_or_finalize():
+    """dry_run_execution(op) selects, registers, checks and logs, but never runs tasks or finalizes."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    with (
+        patch.object(runner_engine, "select_pipelines", return_value=["pipes"]) as select_double,
+        patch.object(runner_engine, "register_tasks") as register_double,
+        patch.object(runner_engine, "check_tasks") as check_double,
+        patch.object(runner_engine, "log_task_status") as log_status_double,
+        patch.object(runner_engine, "run_tasks") as run_tasks_double,
+        patch.object(runner_engine, "finalize") as finalize_double,
+    ):
+        runner_engine.dry_run_execution("my_op")
+
+    select_double.assert_called_once_with("my_op")
+    register_double.assert_called_once_with(["pipes"])
+    check_double.assert_called_once_with()
+    log_status_double.assert_called_once_with()
+    run_tasks_double.assert_not_called()
+    finalize_double.assert_not_called()
+
+
+def test_debug_first_task_registers_and_executes_first_task():
+    """debug_first_task(op) registers the selected pipelines and returns the executor result for the first task."""
+    service_doubles = _services()
+    first_task = _task(name="first")
+    second_task = _task(name="second")
+    service_doubles.registry.tasks = [first_task, second_task]
+    service_doubles.executor.execute.return_value = "df_debug"
+
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    with patch.object(runner_engine, "register_tasks") as register_double:
+        with patch.object(runner_engine, "select_pipelines", return_value=["pipes"]) as select_double:
+            outcome = runner_engine.debug_first_task("my_op")
+
+    assert outcome == "df_debug"
+    select_double.assert_called_once_with("my_op")
+    register_double.assert_called_once_with(["pipes"])
+    # Only the first registered task may reach the executor.
+    service_doubles.executor.execute.assert_called_once_with(first_task)
+
+
+# ---- logging ---------------------------------------------------------------
+def test_log_task_status_calls_registry_log_status():
+    """log_task_status() delegates to registry.log_status() with no arguments."""
+    service_doubles = _services()
+    runner_engine = RunnerEngine(services=service_doubles, rerun=False, skip_dependencies=False)
+
+    runner_engine.log_task_status()
+    service_doubles.registry.log_status.assert_called_once_with()
diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py
index cbb4b7b..b07d9d9 100644
--- a/tests/runner/test_runner.py
+++ b/tests/runner/test_runner.py
@@ -1,4 +1,4 @@
-# Copyright 2022 ABSA Group Limited
+# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,320 +11,117 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-from pyspark.sql import DataFrame
-import rialto.runner.utils as utils
-from rialto.common.table_reader import TableReader
-from rialto.runner.runner import DateManager, Runner
-from rialto.runner.table import Table
-from tests.runner.runner_resources import (
- dep1_data,
- dep2_data,
- general_schema,
- multi_part_data,
- multi_schema,
- simple_group_data,
-)
-from tests.runner.transformations.simple_group import SimpleGroup
+from unittest.mock import MagicMock, patch
+from rialto.runner.runner import Runner
-class MockReader(TableReader):
- def __init__(self, spark):
- self.spark = spark
- def _get_raw_data(self, table: str) -> DataFrame:
- if table == "catalog.schema.simple_group":
- return self.spark.createDataFrame(simple_group_data, general_schema)
- if table == "source.schema.dep1":
- return self.spark.createDataFrame(dep1_data, general_schema)
- if table == "source.schema.dep2":
- return self.spark.createDataFrame(dep2_data, general_schema)
- if table == "source.schema.multi_part_data":
- return self.spark.createDataFrame(multi_part_data, multi_schema)
+def test_runner_init_builds_default_services_when_not_provided():
+ spark = MagicMock()
+ built_services = MagicMock()
+ with (
+ patch("rialto.runner.runner.DefaultRunnerServices.build", return_value=built_services) as build_mock,
+ patch("rialto.runner.runner.RunnerEngine") as engine_cls,
+ ):
+ runner = Runner(
+ spark=spark,
+ config_path="tests/runner/transformations/config.yaml",
+ run_date="2026-01-01",
+ rerun=True,
+ op="SimpleGroup",
+ skip_dependencies=True,
+ overrides={"runner": {"watched_period_units": "days", "watched_period_value": 7}},
+ merge_schema=True,
+ )
-def test_table_exists(spark, mocker):
- mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True)
- utils.table_exists(spark, "abc")
- mock.assert_called_once_with("abc")
-
-
-def test_load_module(spark, basic_runner):
- module = utils.load_module(basic_runner.config.pipelines[0].module)
- assert isinstance(module, SimpleGroup)
-
-
-def test_generate(spark, mocker, basic_runner):
- run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run")
- group = SimpleGroup()
- config = basic_runner.config.pipelines[0]
- basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), config)
-
- run.assert_called_once_with(
- reader=basic_runner.reader,
- run_date=DateManager.str_to_date("2023-01-31"),
+ build_mock.assert_called_once_with(
spark=spark,
- config=config,
- metadata_manager=None,
- feature_loader=None,
- )
-
-
-def test_generate_w_dep(spark, mocker, basic_runner):
- run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run")
- group = SimpleGroup()
- basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2])
- run.assert_called_once_with(
- reader=basic_runner.reader,
- run_date=DateManager.str_to_date("2023-01-31"),
- spark=spark,
- config=basic_runner.config.pipelines[2],
- metadata_manager=None,
- feature_loader=None,
- )
-
-
-def test_init_dates(spark):
- runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31")
- assert runner.date_from == DateManager.str_to_date("2023-01-31")
- assert runner.date_until == DateManager.str_to_date("2023-03-31")
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-31",
- overrides={"runner.watched_period_units": "weeks", "runner.watched_period_value": 2},
- )
- assert runner.date_from == DateManager.str_to_date("2023-03-17")
- assert runner.date_until == DateManager.str_to_date("2023-03-31")
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config2.yaml",
- run_date="2023-03-31",
- )
- assert runner.date_from == DateManager.str_to_date("2023-02-24")
- assert runner.date_until == DateManager.str_to_date("2023-03-31")
-
-
-def test_completion(spark, mocker, basic_runner):
- mocker.patch("rialto.runner.utils.table_exists", return_value=True)
-
- basic_runner.reader = MockReader(spark)
-
- dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
- dates = [DateManager.str_to_date(d) for d in dates]
-
- comp = basic_runner._get_completion(Table(table_path="catalog.schema.simple_group", partition="DATE"), dates)
- expected = [False, True, True, True, False]
- assert comp == expected
-
-
-def test_completion_rerun(spark, mocker, basic_runner):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
-
- runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31")
- runner.reader = MockReader(spark)
-
- dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
- dates = [DateManager.str_to_date(d) for d in dates]
-
- comp = runner._get_completion(Table(table_path="catalog.schema.simple_group", partition="DATE"), dates)
- expected = [False, True, True, True, False]
- assert comp == expected
-
-
-def test_completion_secondary_partitions(spark, mocker, basic_runner):
- mocker.patch("rialto.runner.utils.table_exists", return_value=True)
-
- basic_runner.reader = MockReader(spark)
-
- dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
- dates = [DateManager.str_to_date(d) for d in dates]
- filters = {"version": 1, "type": "A"}
-
- comp = basic_runner._get_completion(
- Table(table_path="source.schema.multi_part_data", partition="DATE"), dates, filters
- )
- expected = [False, True, False, False, False]
- assert comp == expected
-
-
-def test_completion_secondary_partitions_no_filter(spark, mocker, basic_runner):
- mocker.patch("rialto.runner.utils.table_exists", return_value=True)
-
- basic_runner.reader = MockReader(spark)
-
- dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
- dates = [DateManager.str_to_date(d) for d in dates]
-
- comp = basic_runner._get_completion(
- Table(table_path="source.schema.multi_part_data", partition="DATE", secondary_partitions=["VERSION"]), dates
- )
- expected = [False, False, False, False, False]
- assert comp == expected
-
-
-def test_check_dates_have_partition(spark, mocker):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-31",
- )
- runner.reader = MockReader(spark)
- dates = ["2023-03-04", "2023-03-05", "2023-03-06"]
- dates = [DateManager.str_to_date(d) for d in dates]
- res = runner.check_dates_have_data(Table(schema_path="source.schema", table="dep1", partition="DATE"), dates)
- expected = [False, True, False]
- assert res == expected
-
-
-def test_check_dates_have_partition_no_table(spark, mocker):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=False)
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-31",
- )
- dates = ["2023-03-04", "2023-03-05", "2023-03-06"]
- dates = [DateManager.str_to_date(d) for d in dates]
- res = runner.check_dates_have_data(Table(schema_path="source.schema", table="dep66", partition="DATE"), dates)
- expected = [False, False, False]
- assert res == expected
-
-
-@pytest.mark.parametrize(
- "r_date, expected",
- [("2023-02-26", False), ("2023-03-05", True)],
-)
-def test_check_dependencies(spark, mocker, r_date, expected):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-31",
- )
- runner.reader = MockReader(spark)
- res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date))
- assert res == expected
-
-
-@pytest.mark.parametrize(
- "r_date, expected",
- [("2023-03-19", True), ("2023-03-18", False)],
-)
-def test_check_dependencies_filter(spark, mocker, r_date, expected):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config3.yaml",
- run_date="2023-03-19",
- )
- runner.reader = MockReader(spark)
- res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date))
- assert res == expected
-
-
-def test_check_no_dependencies(spark, mocker):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
-
- runner = Runner(
- spark,
config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-31",
+ run_date="2026-01-01",
+ merge_schema=True,
+ overrides={"runner": {"watched_period_units": "days", "watched_period_value": 7}},
)
- runner.reader = MockReader(spark)
- res = runner.check_dependencies(runner.config.pipelines[1], DateManager.str_to_date("2023-03-05"))
- assert res is True
-
-
-def test_select_dates(spark, mocker):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
-
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-31",
- overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 1},
+ engine_cls.assert_called_once_with(
+ services=built_services,
+ rerun=True,
+ skip_dependencies=True,
)
- runner.reader = MockReader(spark)
-
- r, i = runner._select_run_dates(
- runner.config.pipelines[0], Table(table_path="catalog.schema.simple_group", partition="DATE")
+ assert runner.op == "SimpleGroup"
+
+
+def test_runner_init_uses_injected_services_without_build():
+    """Injected services must reach RunnerEngine as-is, without DefaultRunnerServices.build being called."""
+    spark = MagicMock()
+    injected_services = MagicMock()
+
+    with (
+        patch("rialto.runner.runner.DefaultRunnerServices.build") as build_mock,
+        patch("rialto.runner.runner.RunnerEngine") as engine_cls,
+    ):
+        Runner(
+            spark=spark, config_path="tests/runner/transformations/config.yaml", services=injected_services
+        )
+
+    build_mock.assert_not_called()
+    engine_cls.assert_called_once_with(
+        services=injected_services,
+        rerun=False,
+        skip_dependencies=False,
+    )
- expected_run = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
- expected_run = [DateManager.str_to_date(d) for d in expected_run]
- expected_info = ["2023-03-02", "2023-03-09", "2023-03-16", "2023-03-23"]
- expected_info = [DateManager.str_to_date(d) for d in expected_info]
- assert r == expected_run
- assert i == expected_info
-def test_select_dates_all_done(spark, mocker):
- mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
+def test_runner_call_delegates_to_engine_run():
+ spark = MagicMock()
+ injected_services = MagicMock()
+ engine = MagicMock()
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- run_date="2023-03-02",
- overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 0},
- )
- runner.reader = MockReader(spark)
-
- r, i = runner._select_run_dates(
- runner.config.pipelines[0], Table(table_path="catalog.schema.simple_group", partition="DATE")
- )
- expected_run = []
- expected_run = [DateManager.str_to_date(d) for d in expected_run]
- expected_info = []
- expected_info = [DateManager.str_to_date(d) for d in expected_info]
- assert r == expected_run
- assert i == expected_info
-
-
-def test_op_selected(spark, mocker):
- mocker.patch("rialto.runner.reporting.tracker.Tracker.report_by_mail")
- run = mocker.patch("rialto.runner.runner.Runner._run_pipeline")
-
- runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="SimpleGroup")
+ with patch("rialto.runner.runner.RunnerEngine", return_value=engine):
+ runner = Runner(
+ spark=spark,
+ config_path="tests/runner/transformations/config.yaml",
+ op="SimpleGroup",
+ services=injected_services,
+ )
runner()
- run.assert_called_once()
-
-
-def test_op_bad(spark, mocker):
- mocker.patch("rialto.runner.reporting.tracker.Tracker.report_by_mail")
- mocker.patch("rialto.runner.runner.Runner._run_pipeline")
-
- runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="BadOp")
+ engine.run.assert_called_once_with("SimpleGroup")
- with pytest.raises(ValueError) as exception:
- runner()
- assert str(exception.value) == "Unknown operation selected: BadOp"
+def test_runner_dry_run_delegates_to_engine_dry_run_execution():
+ spark = MagicMock()
+ injected_services = MagicMock()
+ engine = MagicMock()
-def test_bookkeeping_active(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._run_pipeline")
+ with patch("rialto.runner.runner.RunnerEngine", return_value=engine):
+ runner = Runner(
+ spark=spark,
+ config_path="tests/runner/transformations/config.yaml",
+ op="SimpleGroup",
+ services=injected_services,
+ )
- runner = Runner(spark, config_path="tests/runner/transformations/config.yaml")
- assert runner.config.runner.bookkeeping == "some.test.location"
+ runner.dry_run()
+ engine.dry_run_execution.assert_called_once_with("SimpleGroup")
-def test_bookkeeping_inactive(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._run_pipeline")
+def test_runner_debug_delegates_to_engine_and_returns_dataframe():
+ spark = MagicMock()
+ injected_services = MagicMock()
+ engine = MagicMock()
+ debug_df = MagicMock()
+ engine.debug_first_task.return_value = debug_df
- runner = Runner(spark, config_path="tests/runner/transformations/config2.yaml")
- assert runner.config.runner.bookkeeping is None
+ with patch("rialto.runner.runner.RunnerEngine", return_value=engine):
+ runner = Runner(
+ spark=spark,
+ config_path="tests/runner/transformations/config.yaml",
+ op="SimpleGroup",
+ services=injected_services,
+ )
+ result = runner._debug()
-def test_config_multi_partition(spark, mocker):
- runner = Runner(spark, config_path="tests/runner/transformations/config3.yaml")
- assert runner.config.pipelines[0].target.secondary_partition_columns == ["VERSION", "ENV"]
- assert runner.config.pipelines[0].dependencies[0].filters == {"VERSION": 2, "TYPE": "A"}
- assert runner.config.pipelines[0].target.rerun_filters == {"version": 3, "env": "dev"}
+ engine.debug_first_task.assert_called_once_with("SimpleGroup")
+ assert result is debug_df
diff --git a/tests/runner/test_services.py b/tests/runner/test_services.py
new file mode 100644
index 0000000..5ab63ac
--- /dev/null
+++ b/tests/runner/test_services.py
@@ -0,0 +1,129 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# tests/runner/test_services.py
+from datetime import date
+from unittest.mock import Mock, patch
+
+from rialto.runner.reporting.tracker import Tracker
+from rialto.runner.runner_services import DefaultRunnerServices, RunnerServices
+from rialto.runner.services.data_checker import DataChecker
+from rialto.runner.services.date_manager import DateManager
+from rialto.runner.services.executor import PipelineExecutor
+from rialto.runner.services.task_registry import TaskRegistry
+from rialto.runner.services.task_status_checker import TaskStatusChecker
+from rialto.runner.services.writer import DatabricksWriter
+
+CONFIG_PATH = "tests/runner/resources/config.yaml"
+
+
+# ── RunnerServices dataclass ──────────────────────────────────────────────────
+
+
+def test_runner_services_stores_all_fields():
+    """RunnerServices exposes every injected collaborator under the attribute it was passed as"""
+    config_mock = Mock()
+    date_manager_mock = Mock()
+    writer_mock = Mock()
+    data_checker_mock = Mock()
+    task_checker_mock = Mock()
+    registry_mock = Mock()
+    executor_mock = Mock()
+    tracker_mock = Mock()
+
+    services = RunnerServices(
+        config=config_mock,
+        date_manager=date_manager_mock,
+        writer=writer_mock,
+        data_checker=data_checker_mock,
+        task_checker=task_checker_mock,
+        registry=registry_mock,
+        executor=executor_mock,
+        tracker=tracker_mock,
+    )
+
+    assert services.config is config_mock
+    assert services.date_manager is date_manager_mock
+    assert services.writer is writer_mock
+    assert services.data_checker is data_checker_mock
+    assert services.task_checker is task_checker_mock
+    assert services.registry is registry_mock
+    assert services.executor is executor_mock
+    assert services.tracker is tracker_mock
+
+
+# ── DefaultRunnerServices.build ───────────────────────────────────────────────
+
+
+def test_build_returns_runner_services_instance(spark):
+    """build() returns a RunnerServices container"""
+    assert isinstance(DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH), RunnerServices)
+
+
+def test_build_creates_correct_types(spark):
+    built = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH)
+    assert isinstance(built.date_manager, DateManager)
+    assert isinstance(built.writer, DatabricksWriter)
+    assert isinstance(built.data_checker, DataChecker)
+    assert isinstance(built.task_checker, TaskStatusChecker)
+    assert isinstance(built.registry, TaskRegistry)
+    assert isinstance(built.executor, PipelineExecutor)
+    assert isinstance(built.tracker, Tracker)
+
+
+def test_build_passes_merge_schema_to_writer(spark):
+    writer = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH, merge_schema=True).writer
+    assert writer.merge_schema is True
+
+
+def test_build_merge_schema_default_is_false(spark):
+    writer = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH).writer
+    assert writer.merge_schema is False
+
+
+def test_build_shares_data_checker_between_task_checker_and_executor(spark):
+    """executor.checker must be services.data_checker itself; task_checker sharing is not asserted here"""
+    services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH)
+    assert services.executor.checker is services.data_checker
+
+
+def test_build_config_loads_pipelines(spark):
+    """The loaded config should contain the SimpleGroup pipeline"""
+    pipelines = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH).config.pipelines
+    assert "SimpleGroup" in [pipeline.name for pipeline in pipelines]
+
+
+def test_build_with_run_date_passed_to_date_manager(spark):
+    """run_date is passed to DateManager — its date_until should equal the requested run date"""
+    services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH, run_date="2023-03-31")
+    assert isinstance(services.date_manager, DateManager)
+    assert services.date_manager.date_until == date(2023, 3, 31)
+
+
+def test_build_tracker_has_bookkeeper_when_config_has_bookkeeping(spark):
+    """The built tracker should carry a bookkeeper (assumes the test config sets bookkeeping)"""
+    tracker = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH).tracker
+    assert tracker.bookkeeper is not None
+
+
+def test_build_calls_config_loader_with_overrides(spark):
+    """build() must hand the config path and overrides to ConfigLoader.load_yaml unchanged."""
+    overrides = {"runner": {"watched_period_units": "days", "watched_period_value": 7}}
+    with (
+        patch("rialto.runner.runner_services.ConfigLoader") as config_loader_cls,
+        # DateManager is replaced too; presumably it cannot consume the Mock config — TODO confirm
+        patch("rialto.runner.runner_services.DateManager", Mock()),
+    ):
+        DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH, overrides=overrides)
+
+    config_loader_cls.load_yaml.assert_called_once_with(CONFIG_PATH, overrides)
diff --git a/tests/runner/test_table.py b/tests/runner/test_table.py
deleted file mode 100644
index f1e4ead..0000000
--- a/tests/runner/test_table.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from rialto.runner.table import Table
-
-
-def test_table_basic_init():
- t = Table(catalog="cat", schema="sch", table="tab", schema_path=None, table_path=None, class_name=None)
-
- assert t.get_table_path() == "cat.sch.tab"
- assert t.get_schema_path() == "cat.sch"
-
-
-def test_table_classname_init():
- t = Table(catalog=None, schema=None, table=None, schema_path="cat.sch", table_path=None, class_name="ClaSs")
-
- assert t.get_table_path() == "cat.sch.cla_ss"
- assert t.get_schema_path() == "cat.sch"
- assert t.catalog == "cat"
- assert t.schema == "sch"
- assert t.table == "cla_ss"
-
-
-def test_table_path_init():
- t = Table(catalog=None, schema=None, table=None, schema_path=None, table_path="cat.sch.tab", class_name=None)
-
- assert t.get_table_path() == "cat.sch.tab"
- assert t.get_schema_path() == "cat.sch"
- assert t.catalog == "cat"
- assert t.schema == "sch"
- assert t.table == "tab"
-
-
-def test_table_secondary_partitions():
- t = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"])
-
- assert t.get_all_partitions() == ["part", "sec1", "sec2"]
-
-
-def test_table_get_partitions_only_main():
- t = Table(catalog="cat", schema="sch", table="tab", partition="part")
-
- assert t.get_all_partitions() == ["part"]
-
-
-def test_table_prioritize_table_name():
- t = Table(catalog=None, schema=None, table="custom", schema_path="cat.sch", table_path=None, class_name="ClaSs")
-
- assert t.get_table_path() == "cat.sch.custom"
- assert t.get_schema_path() == "cat.sch"
- assert t.catalog == "cat"
- assert t.schema == "sch"
- assert t.table == "custom"
diff --git a/tests/runner/test_utils.py b/tests/runner/test_utils.py
new file mode 100644
index 0000000..4cfd626
--- /dev/null
+++ b/tests/runner/test_utils.py
@@ -0,0 +1,44 @@
+# Copyright 2022-2026 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+from rialto.runner.utils import find_dependency
+
+
+def _config_with_dependencies(dep_names):
+    """Return a mock config whose ``dependencies`` are mocks named after ``dep_names``, in order"""
+    config = MagicMock()
+    config.dependencies = []
+    for dep_name in dep_names:
+        dependency = MagicMock()
+        dependency.name = dep_name
+        config.dependencies.append(dependency)
+    return config
+
+
+def test_find_dependency_returns_matching_dependency():
+    """find_dependency returns the dependency object whose name matches"""
+    config = _config_with_dependencies(["dep_a", "dep_b", "dep_c"])
+    expected = config.dependencies[1]
+
+    assert find_dependency(config, "dep_b") is expected
+
+
+def test_find_dependency_returns_none_when_not_found():
+    """find_dependency returns None when no dependency carries the requested name"""
+    config = _config_with_dependencies(["dep_a", "dep_b"])
+
+    lookup = find_dependency(config, "missing_dep")
+    assert lookup is None
diff --git a/tests/runner/test_writer.py b/tests/runner/test_writer.py
deleted file mode 100644
index 5ec9ec6..0000000
--- a/tests/runner/test_writer.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2022 ABSA Group Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from datetime import date
-
-import pytest
-
-from rialto.runner.writer import Writer
-
-
-@pytest.fixture
-def sample_multi_partition(spark):
- df = spark.createDataFrame(
- [
- ("REGION_A", 3, date(2023, 1, 1), 100),
- ("REGION_A", 3, date(2023, 1, 1), 300),
- ],
- schema="region string, version int, info_date date, value int",
- )
- return df
-
-
-@pytest.fixture
-def sample_multi_partition_non_unique(spark):
- df = spark.createDataFrame(
- [
- ("REGION_A", 1, date(2023, 1, 1), 100),
- ("REGION_A", 2, date(2023, 1, 1), 300),
- ],
- schema="region string, version int, info_date date, value int",
- )
- return df
-
-
-def test_replace_condition(sample_multi_partition):
- writer = Writer(spark=None)
- condition = writer._get_replace_condition(sample_multi_partition, partition_cols=["region", "version", "info_date"])
- expected_condition = "region = 'REGION_A' AND version = 3 AND info_date = '2023-01-01'"
- assert condition == expected_condition
-
-
-def test_replace_condition_non_unique(sample_multi_partition_non_unique):
- writer = Writer(spark=None)
- with pytest.raises(ValueError):
- writer._get_replace_condition(
- sample_multi_partition_non_unique, partition_cols=["region", "version", "info_date"]
- )
diff --git a/uv.lock b/uv.lock
index d693ef0..d2512f2 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1078,7 +1078,7 @@ wheels = [
[[package]]
name = "rialto"
-version = "2.2.0"
+version = "2.2.1"
source = { editable = "." }
dependencies = [
{ name = "delta-spark" },