diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..7a731e4 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,19 @@ +[report] +exclude_lines = + # Skip any pass lines such as may be used for @abstractmethod + pass + + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a3c927..e34e1fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Change Log All notable changes to this project will be documented in this file. +## 2.2.1 - 2026-05-31 + ### Runner + - added option to use day="last" for monthly scheduling ## 2.2.0 - 2026-03 ### Runner diff --git a/README.md b/README.md index d78f887..0498fb3 100644 --- a/README.md +++ b/README.md @@ -50,11 +50,12 @@ Runner uses a schedule-based approach: #### Data Flow For each pipeline execution: -1. **Dependency verification**: Check that all required input tables have data within specified time intervals -2. **Transformation execution**: Run your transformation to produce a Spark DataFrame -3. **Automatic enrichment**: Runner adds `INFORMATION_DATE` (the run date) and `VERSION` (package version) columns -4. **Partitioned write**: Data is written to Databricks with partitioning configuration -5. **Reporting**: Optional email notifications on failure and run information stored to tracking table +1. **Previous completion check**: Runner checks if the target table partition for the run date already exists (unless `rerun` is enabled) +2. **Dependency verification**: Check that all required input tables have data within specified time intervals +3. **Transformation execution**: Run your transformation to produce a Spark DataFrame +4. 
**Automatic enrichment**: Runner adds `INFORMATION_DATE` (the run date) and `VERSION` (package version) columns +5. **Partitioned write**: Data is written to Databricks with partitioning configuration +6. **Reporting**: Optional email notifications on failure and run information stored to tracking table #### Dependency Tracking @@ -66,13 +67,13 @@ Runner's dependency tracking ensures that all required input data is available b * **Missing data handling**: If required data is missing, Runner raises an error for that specific pipeline/date but continues executing other pipelines and dates in the queue **Example:** If you're running a pipeline on 2024-01-15 with a dependency that has a 7-day interval: -* Runner checks if the dependency table has data for 2024-01-08 (15 days - 7 days) +* Runner checks if the dependency table has data between 2024-01-08 and 2024-01-15 (the run date minus the 7-day interval, up to the run date) * If the dependency has `filters: {VERSION: "v2"}`, it specifically checks for data where VERSION='v2' * If data exists, the pipeline proceeds; otherwise, an error is raised for this specific execution, but other scheduled runs continue -### Transformation +### Job/Transformation For the details on the interface see the [implementation](rialto/runner/transformation.py) -Inside the transformation you have access to a [TableReader](#common), date of running, and if provided to Runner, a live spark session and [metadata manager](#metadata). +Inside the job you have access to a [TableReader](#common), the run date, and, if provided to Runner, a live spark session and [metadata manager](#metadata). You can either implement your jobs directly via extending the Transformation class, or by using the [jobs](#jobs) abstraction. 
### Runner @@ -163,7 +164,7 @@ pipelines: # a list of pipelines to run python_class: Pipeline2Class schedule: frequency: monthly - day: 6 + day: 6 # or 'latest' for the last day of the month, otherwise avoid using days higher than 28 to ensure all months are covered info_date_shift: # can be combined as a list - units: "days" value: 5 @@ -434,7 +435,7 @@ With that sorted out, we can now provide a quick example of the *rialto.jobs* mo from pyspark.sql import DataFrame from rialto.common import TableReader from rialto.jobs import config_parser, job, datasource -from rialto.runner.config_loader import PipelineConfig +from rialto.runner.services.config_loader import PipelineConfig from pydantic import BaseModel diff --git a/docs/source/conf.py b/docs/source/conf.py index b7592da..e4273bf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,9 +26,9 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "rialto" -copyright = "2022, Marek Dobransky" +copyright = "2022-2026, Marek Dobransky" author = "Marek Dobransky" -release = "2.2.0" +release = "2.2.1" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/modules.rst b/docs/source/modules.rst new file mode 100644 index 0000000..b5c7581 --- /dev/null +++ b/docs/source/modules.rst @@ -0,0 +1,7 @@ +rialto +====== + +.. toctree:: + :maxdepth: 4 + + rialto diff --git a/docs/source/rialto.common.rst b/docs/source/rialto.common.rst index 57e7f9e..6ef2b4c 100644 --- a/docs/source/rialto.common.rst +++ b/docs/source/rialto.common.rst @@ -5,33 +5,33 @@ Submodules ---------- rialto.common.env\_yaml module ----------------------------------- +------------------------------ .. 
automodule:: rialto.common.env_yaml :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.common.table\_reader module ---------------------------------- .. automodule:: rialto.common.table_reader :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.common.utils module -------------------------- .. automodule:: rialto.common.utils :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: rialto.common :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.jobs.rst b/docs/source/rialto.jobs.rst index 7af576a..924cb39 100644 --- a/docs/source/rialto.jobs.rst +++ b/docs/source/rialto.jobs.rst @@ -5,50 +5,49 @@ Submodules ---------- rialto.jobs.decorators module ----------------------------------------- +----------------------------- .. automodule:: rialto.jobs.decorators :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.jobs.job\_base module ---------------------------------------- +---------------------------- .. automodule:: rialto.jobs.job_base :members: - :undoc-members: :show-inheritance: + :undoc-members: -rialto.jobs.module\register module ---------------------------------------- +rialto.jobs.module\_register module +----------------------------------- .. automodule:: rialto.jobs.module_register :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.jobs.resolver module --------------------------------------- +--------------------------- .. automodule:: rialto.jobs.resolver :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.jobs.test\_utils module ------------------------------------------ +------------------------------ .. automodule:: rialto.jobs.test_utils :members: - :undoc-members: :show-inheritance: - + :undoc-members: Module contents --------------- .. 
automodule:: rialto.jobs :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.loader.rst b/docs/source/rialto.loader.rst index 9f258eb..b18b709 100644 --- a/docs/source/rialto.loader.rst +++ b/docs/source/rialto.loader.rst @@ -9,30 +9,29 @@ rialto.loader.config\_loader module .. automodule:: rialto.loader.config_loader :members: - :undoc-members: :show-inheritance: - + :undoc-members: rialto.loader.interfaces module ------------------------------- .. automodule:: rialto.loader.interfaces :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.loader.pyspark\_feature\_loader module --------------------------------------------- .. automodule:: rialto.loader.pyspark_feature_loader :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: rialto.loader :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.maker.rst b/docs/source/rialto.maker.rst index 88a3de7..d3352c3 100644 --- a/docs/source/rialto.maker.rst +++ b/docs/source/rialto.maker.rst @@ -9,37 +9,37 @@ rialto.maker.containers module .. automodule:: rialto.maker.containers :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.maker.feature\_maker module ---------------------------------- .. automodule:: rialto.maker.feature_maker :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.maker.utils module ------------------------- .. automodule:: rialto.maker.utils :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.maker.wrappers module ---------------------------- .. automodule:: rialto.maker.wrappers :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. 
automodule:: rialto.maker :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.metadata.data_classes.rst b/docs/source/rialto.metadata.data_classes.rst index 44353ff..085ae2a 100644 --- a/docs/source/rialto.metadata.data_classes.rst +++ b/docs/source/rialto.metadata.data_classes.rst @@ -9,21 +9,21 @@ rialto.metadata.data\_classes.feature\_metadata module .. automodule:: rialto.metadata.data_classes.feature_metadata :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.metadata.data\_classes.group\_metadata module ---------------------------------------------------- .. automodule:: rialto.metadata.data_classes.group_metadata :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: rialto.metadata.data_classes :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.metadata.rst b/docs/source/rialto.metadata.rst index 08933bf..076e49d 100644 --- a/docs/source/rialto.metadata.rst +++ b/docs/source/rialto.metadata.rst @@ -17,29 +17,29 @@ rialto.metadata.enums module .. automodule:: rialto.metadata.enums :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.metadata.metadata\_manager module ---------------------------------------- .. automodule:: rialto.metadata.metadata_manager :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.metadata.utils module ---------------------------- .. automodule:: rialto.metadata.utils :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: rialto.metadata :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.rst b/docs/source/rialto.rst new file mode 100644 index 0000000..f55ad4c --- /dev/null +++ b/docs/source/rialto.rst @@ -0,0 +1,23 @@ +rialto package +============== + +Subpackages +----------- + +.. 
toctree:: + :maxdepth: 4 + + rialto.common + rialto.jobs + rialto.loader + rialto.maker + rialto.metadata + rialto.runner + +Module contents +--------------- + +.. automodule:: rialto + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.runner.reporting.rst b/docs/source/rialto.runner.reporting.rst index 7ef8db5..543934f 100644 --- a/docs/source/rialto.runner.reporting.rst +++ b/docs/source/rialto.runner.reporting.rst @@ -1,5 +1,5 @@ rialto.runner.reporting package -===================================== +=============================== Submodules ---------- @@ -9,37 +9,37 @@ rialto.runner.reporting.bookkeeper module .. automodule:: rialto.runner.reporting.bookkeeper :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.runner.reporting.mailer module ------------------------------------- .. automodule:: rialto.runner.reporting.mailer :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.runner.reporting.record module ------------------------------------- .. automodule:: rialto.runner.reporting.record :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.runner.reporting.tracker module -------------------------------------- .. automodule:: rialto.runner.reporting.tracker :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: rialto.runner.reporting :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.runner.rst b/docs/source/rialto.runner.rst index c4ec20a..2666ec6 100644 --- a/docs/source/rialto.runner.rst +++ b/docs/source/rialto.runner.rst @@ -8,70 +8,55 @@ Subpackages :maxdepth: 4 rialto.runner.reporting + rialto.runner.services Submodules ---------- -rialto.runner.config\_loader module ------------------------------------ +rialto.runner.engine module +--------------------------- -.. automodule:: rialto.runner.config_loader +.. 
automodule:: rialto.runner.engine :members: - :undoc-members: :show-inheritance: - -rialto.runner.config\_overrides module --------------------------------------- - -.. automodule:: rialto.runner.config_overrides - :members: :undoc-members: - :show-inheritance: - -rialto.runner.date\_manager module ----------------------------------- - -.. automodule:: rialto.runner.date_manager - :members: - :undoc-members: - :show-inheritance: rialto.runner.runner module --------------------------- .. automodule:: rialto.runner.runner :members: - :undoc-members: :show-inheritance: + :undoc-members: -rialto.runner.table module --------------------------- +rialto.runner.runner\_services module +------------------------------------- -.. automodule:: rialto.runner.table +.. automodule:: rialto.runner.runner_services :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.runner.transformation module ----------------------------------- .. automodule:: rialto.runner.transformation :members: - :undoc-members: :show-inheritance: + :undoc-members: rialto.runner.utils module ------------------------------------ +-------------------------- .. automodule:: rialto.runner.utils :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: rialto.runner :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/source/rialto.runner.services.rst b/docs/source/rialto.runner.services.rst new file mode 100644 index 0000000..158db5f --- /dev/null +++ b/docs/source/rialto.runner.services.rst @@ -0,0 +1,93 @@ +rialto.runner.services package +============================== + +Submodules +---------- + +rialto.runner.services.config\_loader module +-------------------------------------------- + +.. automodule:: rialto.runner.services.config_loader + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.config\_overrides module +----------------------------------------------- + +.. 
automodule:: rialto.runner.services.config_overrides + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.data\_checker module +------------------------------------------- + +.. automodule:: rialto.runner.services.data_checker + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.date\_manager module +------------------------------------------- + +.. automodule:: rialto.runner.services.date_manager + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.executor module +-------------------------------------- + +.. automodule:: rialto.runner.services.executor + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.result\_mapper module +-------------------------------------------- + +.. automodule:: rialto.runner.services.result_mapper + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.table module +----------------------------------- + +.. automodule:: rialto.runner.services.table + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.task\_registry module +-------------------------------------------- + +.. automodule:: rialto.runner.services.task_registry + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.task\_status\_checker module +--------------------------------------------------- + +.. automodule:: rialto.runner.services.task_status_checker + :members: + :show-inheritance: + :undoc-members: + +rialto.runner.services.writer module +------------------------------------ + +.. automodule:: rialto.runner.services.writer + :members: + :show-inheritance: + :undoc-members: + +Module contents +--------------- + +.. 
automodule:: rialto.runner.services + :members: + :show-inheritance: + :undoc-members: diff --git a/pyproject.toml b/pyproject.toml index 71e9217..9d4404a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rialto" -version = "2.2.0" +version = "2.2.1" description = "Rialto is a framework for building and deploying machine learning features in a scalable and reusable way. It provides a set of tools that make it easy to define and deploy features and models, and it provides a way to orchestrate the execution of these features and models." authors = [ { name = "Marek Dobransky", email = "marekdobr@gmail.com" }, diff --git a/rialto/__init__.py b/rialto/__init__.py index 79c3773..94ab807 100644 --- a/rialto/__init__.py +++ b/rialto/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/common/__init__.py b/rialto/common/__init__.py index 1bd5055..cf821c6 100644 --- a/rialto/common/__init__.py +++ b/rialto/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index 228d59b..6165e68 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -71,6 +71,16 @@ def get_table( """ raise NotImplementedError + @abc.abstractmethod + def table_exists(self, table: str) -> bool: + """ + Check table exists in storage + + :param table: full table path + :return: bool + """ + raise NotImplementedError + class TableReader(DataReader): """An implementation of data reader for databricks tables""" @@ -165,3 +175,12 @@ def get_table( if uppercase_columns: df = self._uppercase_column_names(df) return df + + def table_exists(self, table: str) -> bool: + """ + Check table exists in spark catalog + + :param table: full table path + :return: bool + """ + return self.spark.catalog.tableExists(table) diff --git a/rialto/common/utils.py b/rialto/common/utils.py index 541e977..7bbd00f 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py index 46eb756..d04e954 100644 --- a/rialto/jobs/__init__.py +++ b/rialto/jobs/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/jobs/decorators.py b/rialto/jobs/decorators.py index 8affcbd..b0ff1cb 100644 --- a/rialto/jobs/decorators.py +++ b/rialto/jobs/decorators.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/rialto/jobs/job_base.py b/rialto/jobs/job_base.py index eae0854..ccd5b41 100644 --- a/rialto/jobs/job_base.py +++ b/rialto/jobs/job_base.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner import Transformation -from rialto.runner.config_loader import PipelineConfig +from rialto.runner.services.config_loader import PipelineConfig class JobMetadata(BaseModel): diff --git a/rialto/jobs/module_register.py b/rialto/jobs/module_register.py index 49021ff..5d32471 100644 --- a/rialto/jobs/module_register.py +++ b/rialto/jobs/module_register.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/jobs/resolver.py b/rialto/jobs/resolver.py index 34b08e8..f9dfe21 100644 --- a/rialto/jobs/resolver.py +++ b/rialto/jobs/resolver.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/jobs/test_utils.py b/rialto/jobs/test_utils.py index cced2fe..754c1ee 100644 --- a/rialto/jobs/test_utils.py +++ b/rialto/jobs/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/rialto/loader/__init__.py b/rialto/loader/__init__.py index 7e1e936..543ec90 100644 --- a/rialto/loader/__init__.py +++ b/rialto/loader/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/loader/config_loader.py b/rialto/loader/config_loader.py index ead2705..06dffea 100644 --- a/rialto/loader/config_loader.py +++ b/rialto/loader/config_loader.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/loader/interfaces.py b/rialto/loader/interfaces.py index 9089f40..6b743a6 100644 --- a/rialto/loader/interfaces.py +++ b/rialto/loader/interfaces.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/loader/pyspark_feature_loader.py b/rialto/loader/pyspark_feature_loader.py index bd5f884..c8542a8 100644 --- a/rialto/loader/pyspark_feature_loader.py +++ b/rialto/loader/pyspark_feature_loader.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/rialto/maker/__init__.py b/rialto/maker/__init__.py index d31cd4a..86e6f34 100644 --- a/rialto/maker/__init__.py +++ b/rialto/maker/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/maker/containers.py b/rialto/maker/containers.py index 9a93417..851e627 100644 --- a/rialto/maker/containers.py +++ b/rialto/maker/containers.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/maker/feature_maker.py b/rialto/maker/feature_maker.py index 5fad175..f9201c1 100644 --- a/rialto/maker/feature_maker.py +++ b/rialto/maker/feature_maker.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/maker/utils.py b/rialto/maker/utils.py index 829dc60..89faee7 100644 --- a/rialto/maker/utils.py +++ b/rialto/maker/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/maker/wrappers.py b/rialto/maker/wrappers.py index a7a4103..5e0748e 100644 --- a/rialto/maker/wrappers.py +++ b/rialto/maker/wrappers.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/rialto/metadata/__init__.py b/rialto/metadata/__init__.py index 5e8893c..950221d 100644 --- a/rialto/metadata/__init__.py +++ b/rialto/metadata/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/metadata/data_classes/__init__.py b/rialto/metadata/data_classes/__init__.py index 79c3773..94ab807 100644 --- a/rialto/metadata/data_classes/__init__.py +++ b/rialto/metadata/data_classes/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/metadata/data_classes/feature_metadata.py b/rialto/metadata/data_classes/feature_metadata.py index cff0039..939bf1d 100644 --- a/rialto/metadata/data_classes/feature_metadata.py +++ b/rialto/metadata/data_classes/feature_metadata.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/metadata/data_classes/group_metadata.py b/rialto/metadata/data_classes/group_metadata.py index 5b1eb9c..0c8136c 100644 --- a/rialto/metadata/data_classes/group_metadata.py +++ b/rialto/metadata/data_classes/group_metadata.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -82,5 +82,5 @@ def from_spark(cls, schema: Row) -> Self: frequency=Schedule[schema.group_frequency], description=schema.group_description, key=schema.group_key, - owner=schema.group_owner + owner=schema.group_owner, ) diff --git a/rialto/metadata/enums.py b/rialto/metadata/enums.py index be9ea5b..e910552 100644 --- a/rialto/metadata/enums.py +++ b/rialto/metadata/enums.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/metadata/metadata_manager.py b/rialto/metadata/metadata_manager.py index 90312f3..474bbda 100644 --- a/rialto/metadata/metadata_manager.py +++ b/rialto/metadata/metadata_manager.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/metadata/utils.py b/rialto/metadata/utils.py index 0cb591c..80d2dab 100644 --- a/rialto/metadata/utils.py +++ b/rialto/metadata/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,7 +17,7 @@ def class_to_catalog_name(class_name) -> str: """ - Map python class name of feature group (CammelCase) to databricks compatible format (lowercase with underscores) + Map python class name of feature group (CamelCase) to databricks compatible format (lowercase with underscores) :param class_name: Python class name :return: feature storage name diff --git a/rialto/runner/__init__.py b/rialto/runner/__init__.py index 6ae343f..a75d1e1 100644 --- a/rialto/runner/__init__.py +++ b/rialto/runner/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py deleted file mode 100644 index 1bcef7b..0000000 --- a/rialto/runner/date_manager.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -__all__ = ["DateManager"] - -from datetime import date, datetime -from typing import List - -from dateutil.relativedelta import relativedelta - -from rialto.runner.config_loader import ScheduleConfig - - -class DateManager: - """Date generation and shifts based on configuration""" - - @staticmethod - def str_to_date(str_date: str) -> date: - """ - Convert YYYY-MM-DD string to date - - :param str_date: string date - :return: date - """ - return datetime.strptime(str_date, "%Y-%m-%d").date() - - @staticmethod - def date_subtract(run_date: date, units: str, value: int) -> date: - """ - Generate starting date from given date and config - - :param run_date: base date - :param units: units: years, months, weeks, days - :param value: number of units to subtract - :return: Starting date - """ - if units == "years": - return run_date - relativedelta(years=value) - if units == "months": - return run_date - relativedelta(months=value) - if units == "weeks": - return run_date - relativedelta(weeks=value) - if units == "days": - return run_date - relativedelta(days=value) - raise ValueError(f"Unknown time unit {units}") - - @staticmethod - def all_dates(date_from: date, date_to: date) -> List[date]: - """ - Get list of all dates between, inclusive - - :param date_from: starting date - :param date_to: ending date - :return: List[date] - """ - if date_to < date_from: - date_to, date_from = date_from, date_to - - return [date_from + relativedelta(days=n) for n in range((date_to - date_from).days + 1)] - - @staticmethod - def run_dates(date_from: date, date_to: date, schedule: ScheduleConfig) -> List[date]: - """ - Select dates inside given interval depending on frequency and selected day - - :param date_from: interval start - :param date_to: interval end - :param schedule: schedule config - :return: list of dates - """ - options = DateManager.all_dates(date_from, date_to) - if schedule.frequency == "daily": - return options - if schedule.frequency == "weekly": - return [x for x 
in options if x.isoweekday() == schedule.day] - if schedule.frequency == "monthly": - return [x for x in options if x.day == schedule.day] - raise ValueError(f"Unknown frequency {schedule.frequency}") - - @staticmethod - def to_info_date(date: date, schedule: ScheduleConfig) -> date: - """ - Shift given date according to config - - :param date: input date - :param schedule: schedule config - :return: date - """ - if isinstance(schedule.info_date_shift, List): - for shift in schedule.info_date_shift: - date = DateManager.date_subtract(date, units=shift.units, value=shift.value) - else: - date = DateManager.date_subtract( - date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value - ) - return date diff --git a/rialto/runner/engine.py b/rialto/runner/engine.py new file mode 100644 index 0000000..43dcc39 --- /dev/null +++ b/rialto/runner/engine.py @@ -0,0 +1,159 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
__all__ = ["RunnerEngine"]

import traceback
from datetime import datetime
from typing import List, Optional

from loguru import logger
from pyspark.sql import DataFrame

from rialto.runner.runner_services import RunnerServices
from rialto.runner.services.config_loader import PipelineConfig
from rialto.runner.services.result_mapper import TaskResultMapper
from rialto.runner.services.task_registry import PipelineTask


class RunnerEngine:
    """Orchestrates pipeline execution lifecycle and task tracking"""

    def __init__(self, services: RunnerServices, rerun: bool, skip_dependencies: bool):
        """
        Store the collaborators and the run flags.

        :param services: bundle of services the engine delegates to
        :param rerun: when True, completion checks are skipped and already complete tasks are executed again
        :param skip_dependencies: when True, dependency checks are skipped and never block execution
        """
        self.services = services
        self.rerun = rerun
        self.skip_dependencies = skip_dependencies

    def select_pipelines(self, op: Optional[str] = None) -> List[PipelineConfig]:
        """
        Select pipelines to run based on operation name

        :param op: pipeline name to select, all configured pipelines are returned when empty
        :return: list of selected pipeline configs
        :raises ValueError: when no configured pipeline is named ``op``
        """
        if op:
            selected = [p for p in self.services.config.pipelines if p.name == op]
            if not selected:
                raise ValueError(f"Unknown operation selected: {op}")
            return selected
        return self.services.config.pipelines

    def register_tasks(self, pipelines: List[PipelineConfig]) -> None:
        """
        Register tasks for all pipelines and date combinations

        :param pipelines: pipeline configs to create tasks for
        """
        for pipeline in pipelines:
            for exec_date, partition_date in self.services.date_manager.get_execution_and_partition_dates(
                pipeline.schedule
            ):
                self.services.registry.add_task(
                    name=pipeline.name,
                    execution_date=exec_date,
                    partition_date=partition_date,
                    config=pipeline,
                )

    @staticmethod
    def _record_precheck_failure(task: PipelineTask, stage: str, error: Exception) -> None:
        """
        Log a failed pre-run check and mark the task so that it is reported instead of executed

        Has to be called from inside the ``except`` block, the stored trace is taken from
        the exception that is currently being handled.

        :param task: task whose check failed
        :param stage: name of the failed check used in the log message ("completion" or "dependency")
        :param error: the caught exception
        """
        logger.error(f"{task.name} {stage} check failed for {task.partition_date}:\n\t{error}")
        task.precheck_failed = True
        task.error = str(error)
        task.error_trace = traceback.format_exc()

    def check_tasks(self) -> None:
        """Check task completion and dependency status"""
        for task in self.services.registry.tasks:
            if not self.rerun:
                try:
                    self.services.task_checker.check_completion(task)
                except Exception as e:
                    self._record_precheck_failure(task, "completion", e)
            if not self.skip_dependencies:
                try:
                    self.services.task_checker.check_pipeline_dependencies(task)
                except Exception as e:
                    self._record_precheck_failure(task, "dependency", e)

    def log_task_status(self) -> None:
        """Log summary of task statuses"""
        self.services.registry.log_status()

    def run_tasks(self) -> None:
        """Execute runnable tasks with per-task error isolation"""
        for task in self.services.registry.tasks:
            logger.info(f"Executing task {task.name} for partition date {task.partition_date}")
            self._execute_task_with_tracking(task)

    def _execute_task_with_tracking(self, task: PipelineTask) -> None:
        """
        Execute single task with record tracking

        Exactly one record is added to the tracker for the task, whatever the outcome.

        :param task: task to execute
        :raises KeyboardInterrupt: re-raised after the interruption is recorded
        """
        run_start = datetime.now()

        # A failed pre-run check is reported as an exception, the task itself is not run
        if task.precheck_failed:
            self.services.tracker.add(TaskResultMapper.exception(task, run_start, task.error, task.error_trace))
            return

        # Skip already-complete tasks
        if task.completion and not self.rerun:
            logger.info(f"Skipping task {task.name} for partition {task.partition_date} - already complete")
            self.services.tracker.add(TaskResultMapper.already_complete(task, run_start))
            return

        # Skip if dependencies not met
        incomplete_deps = [
            f"{dep.table.get_table_path()} from {dep.date_from} until {dep.date_until}"
            for dep in task.dependencies
            if not dep.complete
        ]
        if incomplete_deps and not self.skip_dependencies:
            logger.info(
                f"Incomplete dependencies for task {task.name} for "
                f"partition {task.partition_date} - {', '.join(incomplete_deps)}"
            )
            self.services.tracker.add(TaskResultMapper.dependencies_incomplete(task, run_start, incomplete_deps))
            return

        # Execute task
        try:
            df = self.services.executor.execute(task)
            self.services.writer.write(df, task.partition_date, task.target)
            records = self.services.data_checker.check_written(task.target, task.partition_date, df)
            logger.info(
                f"Task {task.name} for partition {task.partition_date} completed successfully with {records} records"
            )
            self.services.tracker.add(TaskResultMapper.success(task, run_start, records))
        except KeyboardInterrupt:
            # Not an Exception subclass, record it and let it stop the whole run
            self.services.tracker.add(TaskResultMapper.interrupted(task, run_start))
            raise
        except Exception as e:
            logger.exception(f"Task {task.name} failed for partition {task.partition_date}")
            self.services.tracker.add(TaskResultMapper.exception(task, run_start, str(e), traceback.format_exc()))

    def finalize(self) -> None:
        """Send final reports via mail/bookkeeping"""
        self.services.tracker.report_by_mail()
        self.log_task_status()

    def run(self, op: Optional[str] = None) -> None:
        """
        Execute all tasks

        :param op: pipeline name to run, all configured pipelines are run when empty
        :raises ValueError: when no configured pipeline is named ``op``
        """
        pipelines = self.select_pipelines(op)
        self.register_tasks(pipelines)
        self.check_tasks()
        self.log_task_status()
        try:
            self.run_tasks()
        finally:
            # An interrupted run re-raises out of run_tasks, records collected so far still get reported
            self.finalize()

    def dry_run_execution(self, op: Optional[str] = None) -> None:
        """
        Execute pre-run checks without task execution

        :param op: pipeline name to check, all configured pipelines are checked when empty
        """
        pipelines = self.select_pipelines(op)
        self.register_tasks(pipelines)
        self.check_tasks()
        self.log_task_status()

    def debug_first_task(self, op: Optional[str] = None) -> Optional[DataFrame]:
        """
        Debug mode: execute first task and return result

        No checks are run, nothing is written and nothing is added to the tracker.

        :param op: pipeline name to debug, the first registered task of any pipeline is used when empty
        :return: DataFrame returned by the executor, None when no task was registered
        """
        pipelines = self.select_pipelines(op)
        self.register_tasks(pipelines)
        tasks = self.services.registry.tasks
        if not tasks:
            # Nothing scheduled in the watched period, there is no first task to index
            logger.info("No dates to run in debug mode")
            return None
        return self.services.executor.execute(tasks[0])
diff --git a/rialto/runner/reporting/mailer.py b/rialto/runner/reporting/mailer.py index 1485a57..46c8bf3 100644 --- a/rialto/runner/reporting/mailer.py +++ b/rialto/runner/reporting/mailer.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ class HTMLMessage: @staticmethod def _get_status_color(status: str): - if status == "Success": + if status == "Success" or status == "Skipped": return "#398f00" elif status == "Error": return "#ff0000" @@ -128,10 +128,6 @@ def _head(): - """ @staticmethod @@ -164,14 +160,10 @@ def _make_exceptions(records: List[Record]): - Expand -
- - - - -
{record.exception}
-
+ + + + """ html += r return html diff --git a/rialto/runner/reporting/record.py b/rialto/runner/reporting/record.py index 7099b2e..bd6ab4d 100644 --- a/rialto/runner/reporting/record.py +++ b/rialto/runner/reporting/record.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/runner/reporting/tracker.py b/rialto/runner/reporting/tracker.py index 0f926cc..8adc796 100644 --- a/rialto/runner/reporting/tracker.py +++ b/rialto/runner/reporting/tracker.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,10 +18,10 @@ from pyspark.sql import SparkSession -from rialto.runner.config_loader import MailConfig from rialto.runner.reporting.bookkeeper import BookKeeper from rialto.runner.reporting.mailer import HTMLMessage, Mailer from rialto.runner.reporting.record import Record +from rialto.runner.services.config_loader import MailConfig class Tracker: @@ -29,9 +29,7 @@ class Tracker: def __init__(self, mail_cfg: MailConfig, bookkeeping: str = None, spark: SparkSession = None): self.records = [] - self.last_error = None self.pipeline_start = datetime.now() - self.exceptions = [] self.mail_cfg = mail_cfg self.bookkeeper = None diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 384998f..dec6045 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,26 +14,16 @@ __all__ = ["Runner"] -import datetime -from datetime import date -from typing import Dict, List, Tuple +from typing import Dict -from loguru import logger from pyspark.sql import DataFrame, SparkSession -import rialto.runner.utils as utils -from rialto.common import TableReader -from rialto.runner.config_loader import PipelineConfig, get_pipelines_config -from rialto.runner.date_manager import DateManager -from rialto.runner.reporting.record import Record -from rialto.runner.reporting.tracker import Tracker -from rialto.runner.table import Table -from rialto.runner.transformation import Transformation -from rialto.runner.writer import Writer +from rialto.runner.engine import RunnerEngine +from rialto.runner.runner_services import DefaultRunnerServices, RunnerServices class Runner: - """A scheduler and dependency checker for feature runs""" + """Entry point for pipeline execution orchestration (beginner-friendly API)""" def __init__( self, @@ -45,312 +35,43 @@ def __init__( skip_dependencies: bool = False, overrides: Dict = None, merge_schema: bool = False, + services: RunnerServices = None, ): - self.spark = spark - self.config = get_pipelines_config(config_path, overrides) - self.reader = TableReader(spark) - self.rerun = rerun - self.skip_dependencies = skip_dependencies - self.op = op - self.writer = Writer(spark, merge_schema=merge_schema) - self.tracker = Tracker( - mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark - ) - - if run_date: - run_date = DateManager.str_to_date(run_date) - else: - run_date = date.today() - - self.date_from = DateManager.date_subtract( - run_date=run_date, - units=self.config.runner.watched_period_units, - value=self.config.runner.watched_period_value, - ) - - self.date_until = run_date - - if self.date_from > self.date_until: - raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") - logger.info(f"Running period set to: {self.date_from} - 
{self.date_until}") - - def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: - """ - Run the job - - :param instance: Instance of Transformation - :param run_date: date to run for - :param pipeline: pipeline configuration - :return: Dataframe """ - metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) - - df = instance.run( - spark=self.spark, + Initialize Runner for pipeline orchestration. + + :param spark: SparkSession instance + :param config_path: Path to pipeline configuration YAML + :param run_date: Override run date (optional) + :param rerun: Force re-execution of completed tasks + :param op: Target specific pipeline by name (optional) + :param skip_dependencies: Skip dependency validation + :param overrides: Configuration overrides + :param merge_schema: Enable schema merging in writer + :param services: Custom RunnerServices bundle (optional, for advanced users) + """ + self._services = services or DefaultRunnerServices.build( + spark=spark, + config_path=config_path, run_date=run_date, - config=pipeline, - reader=self.reader, - metadata_manager=metadata_manager, - feature_loader=feature_loader, + merge_schema=merge_schema, + overrides=overrides, ) - - return df - - def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int: - """ - Check if there are records written for given date - - :param info_date: date to check - :param table: target table object - :return: number of records - """ - filters = {} - if pipeline.target.rerun_filters is not None: - filters = pipeline.target.rerun_filters - else: - if table.secondary_partitions: - row = df.select(*table.secondary_partitions).distinct().collect()[0] - for c in table.secondary_partitions: - val = row[0][c] - filters[c] = val - - df = self.reader.get_table( - table.get_table_path(), date_column=table.partition, date_from=info_date, date_to=info_date, filters=filters + self._engine = RunnerEngine( 
+ services=self._services, + rerun=rerun, + skip_dependencies=skip_dependencies, ) - - return df.count() - - def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]: - """ - For given list of dates, check if there is a matching partition for each - - :param table: Table object - :param dates: list of dates to check - :return: list of bool - """ - if utils.table_exists(self.spark, table.get_table_path()): - checks = [] - for check_date in dates: - df = self.reader.get_table( - table.get_table_path(), - date_column=table.partition, - date_from=check_date, - date_to=check_date, - filters=target_filters, - ) - data_exists = df.count() > 0 - if data_exists and target_filters is None and table.secondary_partitions is not None: - # ensure rerun if the write consideres secondary partitions but the filter doesn't - data_exists = False - checks.append(data_exists) - return checks - else: - logger.info(f"Table {table.get_table_path()} doesn't exist!") - return [False for _ in dates] - - def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: - """ - Check for all dependencies in config if they have available partitions - - :param pipeline: configuration - :param run_date: run date - :return: bool - """ - logger.info(f"{pipeline.name} checking dependencies for {run_date}") - - error = "" - - for dependency in pipeline.dependencies: - dep_from = DateManager.date_subtract(run_date, dependency.interval.units, dependency.interval.value) - logger.info(f"Looking for {dependency.table} from {dep_from} until {run_date}") - - possible_dep_dates = DateManager.all_dates(dep_from, run_date) - - logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") - - source = Table(table_path=dependency.table, partition=dependency.date_col) - if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters): - logger.info(f"Dependency for {dependency.table} from {dep_from} until 
{run_date} is fulfilled") - else: - msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}" - logger.info(msg) - error = error + msg + "\n" - - if error != "": - self.tracker.last_error = error - return False - - return True - - def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]: - """ - Check if model has run for given dates - - :param target_path: Table object - :param info_dates: list of dates - :return: bool list - """ - if self.rerun: - return [False for _ in info_dates] - else: - return self.check_dates_have_data(target, info_dates, filters) - - def _select_run_dates(self, pipeline: PipelineConfig, table: Table, filters: Dict = None) -> Tuple[List, List]: - """ - Select run dates and info dates based on completion - - :param pipeline: pipeline config - :param table: table path - :return: list of run dates and list of info dates - """ - possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule) - possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] - current_state = self._get_completion(table, possible_info_dates, filters) - - selection = [ - (run, info) for run, info, state in zip(possible_run_dates, possible_info_dates, current_state) if not state - ] - - if not len(selection): - logger.info(f"{pipeline.name} has no dates to run") - return [], [] - - selected_run_dates, selected_info_dates = zip(*selection) - logger.info(f"{pipeline.name} identified to run for {selected_run_dates}") - - return list(selected_run_dates), list(selected_info_dates) - - def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: date, target: Table) -> int: - """ - Run one pipeline for one date - - :param pipeline: pipeline cfg - :param run_date: run date - :param info_date: information date - :param target: target Table - :return: success bool - """ - if self.skip_dependencies or 
self.check_dependencies(pipeline, run_date): - logger.info(f"Running {pipeline.name} for {run_date}") - - feature_group = utils.load_module(pipeline.module) - df = self._execute(feature_group, run_date, pipeline) - self.writer.write(df, info_date, target) - records = self._check_written(info_date, target, df, pipeline) - logger.info(f"Generated {records} records") - if records == 0: - raise RuntimeError("No records generated") - else: - return records - return 0 - - def _run_pipeline(self, pipeline: PipelineConfig): - """ - Run single pipeline for all required dates - - :param pipeline: pipeline cfg - :return: success bool - """ - target = Table( - schema_path=pipeline.target.target_schema, - class_name=pipeline.module.python_class, - partition=pipeline.target.target_partition_column, - secondary_partitions=pipeline.target.secondary_partition_columns, - table=pipeline.target.custom_name, - ) - logger.info(f"Loaded pipeline {pipeline.name}") - - selected_run_dates, selected_info_dates = self._select_run_dates( - pipeline, target, pipeline.target.rerun_filters - ) - - # ----------- Checking dependencies available ---------- - for run_date, info_date in zip(selected_run_dates, selected_info_dates): - run_start = datetime.datetime.now() - try: - records = self._run_one_date(pipeline, run_date, info_date, target) - if records > 0: - status = "Success" - message = "" - else: - status = "Failure" - message = self.tracker.last_error - self.tracker.add( - Record( - job=pipeline.name, - target=target.get_table_path(), - date=info_date, - time=datetime.datetime.now() - run_start, - records=records, - status=status, - reason=message, - ) - ) - except Exception as error: - logger.error(f"An exception occurred in pipeline {pipeline.name}") - logger.exception(error) - self.tracker.add( - Record( - job=pipeline.name, - target=target.get_table_path(), - date=info_date, - time=datetime.datetime.now() - run_start, - records=0, - status="Error", - reason="Exception", - 
exception=str(error), - ) - ) - except KeyboardInterrupt: - logger.error(f"Pipeline {pipeline.name} interrupted") - self.tracker.add( - Record( - job=pipeline.name, - target=target.get_table_path(), - date=info_date, - time=datetime.datetime.now() - run_start, - records=0, - status="Error", - reason="Interrupted by user", - ) - ) - raise KeyboardInterrupt + self.op = op def __call__(self): """Execute pipelines""" - logger.info("Executing pipelines") - try: - if self.op: - selected = [p for p in self.config.pipelines if p.name == self.op] - if len(selected) < 1: - raise ValueError(f"Unknown operation selected: {self.op}") - self._run_pipeline(selected[0]) - else: - for pipeline in self.config.pipelines: - self._run_pipeline(pipeline) - finally: - print(self.tracker.records) - self.tracker.report_by_mail() - logger.info("Execution finished") + self._engine.run(self.op) - def debug(self) -> DataFrame: - """Debug mode - run only first op for one date and return the resulting dataframe""" - logger.info("Running in debug mode") - if self.op: - pipeline = [p for p in self.config.pipelines if p.name == self.op][0] - else: - pipeline = self.config.pipelines[0] + def dry_run(self): + """Dry run - log status of pipelines without executing""" + self._engine.dry_run_execution(self.op) - target = Table( - schema_path=pipeline.target.target_schema, - class_name=pipeline.module.python_class, - partition=pipeline.target.target_partition_column, - secondary_partitions=pipeline.target.secondary_partition_columns, - table=pipeline.target.custom_name, - ) - selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) - if len(selected_run_dates) > 0: - df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline) - return self.writer._process(df, selected_info_dates[0], target) - else: - logger.info("No dates to run in debug mode") + def _debug(self) -> DataFrame: + """Debug mode - run only first op for one date and return the resulting 
dataframe""" + return self._engine.debug_first_task(self.op) diff --git a/rialto/runner/runner_services.py b/rialto/runner/runner_services.py new file mode 100644 index 0000000..9eea552 --- /dev/null +++ b/rialto/runner/runner_services.py @@ -0,0 +1,91 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["RunnerServices", "DefaultRunnerServices"] + +from dataclasses import dataclass + +from pyspark.sql import SparkSession + +from rialto.common import TableReader +from rialto.runner.reporting.tracker import Tracker +from rialto.runner.services.config_loader import ConfigLoader, PipelinesConfig +from rialto.runner.services.data_checker import DataChecker +from rialto.runner.services.date_manager import DateManager +from rialto.runner.services.executor import PipelineExecutor +from rialto.runner.services.task_registry import TaskRegistry +from rialto.runner.services.task_status_checker import TaskStatusChecker +from rialto.runner.services.writer import DatabricksWriter + + +@dataclass +class RunnerServices: + """Bundle of collaborator services for Runner orchestration""" + + config: PipelinesConfig + date_manager: DateManager + writer: DatabricksWriter + data_checker: DataChecker + task_checker: TaskStatusChecker + registry: TaskRegistry + executor: PipelineExecutor + tracker: Tracker + + +class DefaultRunnerServices: + """Factory for default Runner services composition""" + + @staticmethod + def build( + spark: 
SparkSession, + config_path: str, + run_date: str = None, + merge_schema: bool = False, + overrides: dict = None, + ) -> RunnerServices: + """ + Build default services for Runner. + + :param spark: SparkSession instance + :param config_path: Path to pipeline configuration YAML + :param run_date: Override run date (optional) + :param merge_schema: Enable schema merging in writer + :param overrides: Configuration overrides + :return: RunnerServices bundle + """ + config = ConfigLoader.load_yaml(config_path, overrides) + date_manager = DateManager(config.runner, run_date) + writer = DatabricksWriter(spark, merge_schema=merge_schema) + + reader = TableReader(spark) + data_checker = DataChecker(reader) + task_checker = TaskStatusChecker(data_checker) + registry = TaskRegistry(spark, date_manager=date_manager) + executor = PipelineExecutor(spark=spark, reader=reader, checker=data_checker) + tracker = Tracker( + mail_cfg=config.runner.mail, + bookkeeping=config.runner.bookkeeping, + spark=spark, + ) + + return RunnerServices( + config=config, + date_manager=date_manager, + writer=writer, + data_checker=data_checker, + task_checker=task_checker, + registry=registry, + executor=executor, + tracker=tracker, + ) diff --git a/rialto/runner/services/__init__.py b/rialto/runner/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rialto/runner/config_loader.py b/rialto/runner/services/config_loader.py similarity index 72% rename from rialto/runner/config_loader.py rename to rialto/runner/services/config_loader.py index 7978ac5..64801fb 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/services/config_loader.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +13,15 @@ # limitations under the License. 
__all__ = [ - "get_pipelines_config", + "ConfigLoader", ] -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from rialto.common.utils import load_yaml -from rialto.runner.config_overrides import override_config +from rialto.runner.services.config_overrides import override_config class BaseConfig(BaseModel): @@ -35,8 +35,10 @@ class IntervalConfig(BaseConfig): class ScheduleConfig(BaseConfig): frequency: str - day: Optional[int] = 0 - info_date_shift: Optional[List[IntervalConfig]] = IntervalConfig(units="days", value=0) + day: Optional[Union[int, str]] = 0 + info_date_shift: Optional[Union[IntervalConfig, List[IntervalConfig]]] = Field( + default_factory=lambda: IntervalConfig(units="days", value=0) + ) class DependencyConfig(BaseConfig): @@ -88,16 +90,16 @@ class PipelineConfig(BaseConfig): name: str module: ModuleConfig schedule: ScheduleConfig - dependencies: Optional[List[DependencyConfig]] = [] - target: TargetConfig = None + dependencies: Optional[List[DependencyConfig]] = Field(default_factory=list) + target: Optional[TargetConfig] = None metadata_manager: Optional[MetadataManagerConfig] = None feature_loader: Optional[FeatureLoaderConfig] = None - extras: Optional[Dict] = {} + extras: Optional[Dict] = Field(default_factory=dict) class PipelinesConfig(BaseConfig): runner: RunnerConfig - pipelines: list[PipelineConfig] + pipelines: List[PipelineConfig] def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: @@ -108,3 +110,12 @@ def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: return PipelinesConfig(**cfg) else: return PipelinesConfig(**raw_config) + + +class ConfigLoader: + """Loader for pipelines config""" + + @staticmethod + def load_yaml(path: str, overrides: Dict) -> PipelinesConfig: + """Load yaml config and apply overrides""" + return get_pipelines_config(path, overrides) diff --git 
a/rialto/runner/config_overrides.py b/rialto/runner/services/config_overrides.py similarity index 98% rename from rialto/runner/config_overrides.py rename to rialto/runner/services/config_overrides.py index cd81232..9350bca 100644 --- a/rialto/runner/config_overrides.py +++ b/rialto/runner/services/config_overrides.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rialto/runner/services/data_checker.py b/rialto/runner/services/data_checker.py new file mode 100644 index 0000000..6e55e94 --- /dev/null +++ b/rialto/runner/services/data_checker.py @@ -0,0 +1,103 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+__all__ = ["DataChecker"] + +from datetime import date +from typing import Dict + +from loguru import logger +from pyspark.sql import DataFrame + +from rialto.common import DataReader +from rialto.runner.services.table import Table + + +class DataChecker: + """Checks if data for given date or date range is present in storage""" + + def __init__(self, reader: DataReader): + self.reader = reader + + def check_date(self, target: Table, partition_date: date) -> bool: + """Check if data for given date is present in target + + :param target: target Table to check + :param partition_date: Date to check + :return: True if data for given date is present, False otherwise + """ + return self.check_range(target, partition_date, partition_date) + + def check_range(self, target: Table, start_date: date, end_date: date) -> bool: + """Check if data for given date range is present in target + + :param target: target Table to check + :param start_date: Starting date of the range to check + :param end_date: Ending date of the range to check + :return: True if data for given date range is present, False otherwise + """ + if self.reader.table_exists(target.get_table_path()): + df = self.reader.get_table( + target.get_table_path(), + date_column=target.partition, + date_from=start_date, + date_to=end_date, + filters=target.filters, + ) + data_exists = df.count() > 0 + if ( + data_exists + and (target.filters is None or target.filters == {}) + and target.secondary_partitions is not None + ): + logger.warning( + f"Overwriting {target.get_table_path()} completion status for {start_date} due to presence of " + f"secondary partitions and no filters." 
+ ) + data_exists = False + return data_exists + else: + logger.warning(f"Target table {target.get_table_path()} doesn't exist yet.") + return False + + def _get_filters(self, target: Table, df: DataFrame) -> Dict: + if target.filters is not None: + return target.filters + elif target.secondary_partitions: + filters = {} + logger.info("Inferring target sub-partition values from generated data.") + row = df.select(*target.secondary_partitions).distinct().collect()[0] + for c in target.secondary_partitions: + filters[c] = row[c] + return filters + else: + return {} + + def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int: + """Check how many records were written + + :param target: target Table to check + :param partition_date: Date to check + :param df: DataFrame that was written, used to determine filters if not provided in config + :return: Number of records for given date + """ + filters = self._get_filters(target, df) + df = self.reader.get_table( + target.get_table_path(), + date_column=target.partition, + date_from=partition_date, + date_to=partition_date, + filters=filters, + ) + + return df.count() diff --git a/rialto/runner/services/date_manager.py b/rialto/runner/services/date_manager.py new file mode 100644 index 0000000..518129c --- /dev/null +++ b/rialto/runner/services/date_manager.py @@ -0,0 +1,144 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = ["DateManager"] + +from datetime import date, datetime +from typing import List + +from dateutil.relativedelta import relativedelta +from loguru import logger + +from rialto.runner.services.config_loader import RunnerConfig, ScheduleConfig + + +class DateManager: + """Date generation and shifts based on configuration""" + + def __init__(self, config: RunnerConfig, run_date: str = None): + if run_date: + run_date = self.str_to_date(run_date) + else: + run_date = date.today() + + self.date_from = self.date_subtract( + input_date=run_date, + units=config.watched_period_units, + value=config.watched_period_value, + ) + + self.date_until = run_date + + if self.date_from > self.date_until: + raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") + logger.info(f"Running period set to: {self.date_from} - {self.date_until}") + + def get_date_from(self) -> date: + """Get starting date of the execution window""" + return self.date_from + + def get_date_until(self) -> date: + """Get ending date of the execution window""" + return self.date_until + + @staticmethod + def str_to_date(str_date: str) -> date: + """ + Convert YYYY-MM-DD string to date + + :param str_date: string date + :return: date + """ + try: + return datetime.strptime(str_date, "%Y-%m-%d").date() + except ValueError: + raise ValueError(f"Invalid date format: {str_date}. 
Expected YYYY-MM-DD.") + + @staticmethod + def date_subtract(input_date: date, units: str, value: int) -> date: + """ + Subtract given number of units from input date + + :param input_date: base date + :param units: units: years, months, weeks, days + :param value: number of units to subtract + :return: Starting date + """ + if units == "years": + return input_date - relativedelta(years=value) + if units == "months": + return input_date - relativedelta(months=value) + if units == "weeks": + return input_date - relativedelta(weeks=value) + if units == "days": + return input_date - relativedelta(days=value) + raise ValueError(f"Unknown time unit {units}") + + @staticmethod + def all_dates(date_from: date, date_until: date) -> List[date]: + """ + Get list of all dates between, inclusive + + :param date_from: starting date + :param date_until: ending date + :return: List[date] + """ + return [date_from + relativedelta(days=n) for n in range((date_until - date_from).days + 1)] + + def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]: + """ + Get list of execution and partition dates for given configuration + + :return: List of tuples with execution and partition dates + """ + execution_dates = self._execution_dates(schedule) + return [(ex_date, self._to_partition_date(ex_date, schedule)) for ex_date in execution_dates] + + def _execution_dates(self, schedule: ScheduleConfig) -> List[date]: + """ + Select dates inside given interval depending on frequency and selected day + + :param schedule: schedule config + :return: List of execution dates + """ + options = self.all_dates(self.date_from, self.date_until) + frequency = schedule.frequency.lower() + if frequency == "daily": + return options + if frequency == "weekly": + if not (1 <= schedule.day <= 7): + raise ValueError(f"Invalid day for weekly frequency: {schedule.day}. 
Must be 1-7.") + return [x for x in options if x.isoweekday() == schedule.day] + if frequency == "monthly": + if schedule.day == "last": + return [x for x in options if (x + relativedelta(days=1)).month != x.month] + if not (1 <= schedule.day <= 31): + raise ValueError(f"Invalid day for monthly frequency: {schedule.day}. Must be 1-31 or last.") + return [x for x in options if x.day == schedule.day] + raise ValueError(f"Unknown frequency: {schedule.frequency}") + + def _to_partition_date(self, date: date, schedule: ScheduleConfig) -> date: + """ + Shift given date according to config + + :param date: input date + :param schedule: schedule config + :return: date + """ + if isinstance(schedule.info_date_shift, list): + for shift in schedule.info_date_shift: + date = self.date_subtract(date, units=shift.units, value=shift.value) + else: + date = self.date_subtract(date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value) + return date diff --git a/rialto/runner/services/executor.py b/rialto/runner/services/executor.py new file mode 100644 index 0000000..be84ffe --- /dev/null +++ b/rialto/runner/services/executor.py @@ -0,0 +1,96 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = ["PipelineExecutor"] + +from importlib import import_module +from typing import Tuple + +from loguru import logger +from pyspark.sql import DataFrame, SparkSession + +from rialto.common import DataReader +from rialto.loader import PysparkFeatureLoader +from rialto.metadata import MetadataManager +from rialto.runner.services.config_loader import ModuleConfig, PipelineConfig +from rialto.runner.services.data_checker import DataChecker +from rialto.runner.services.task_registry import PipelineTask +from rialto.runner.transformation import Transformation + + +class PipelineExecutor: + """Executes a single pipeline task.""" + + def __init__(self, spark: SparkSession, reader: DataReader, checker: DataChecker): + self.spark = spark + self.reader = reader + self.checker = checker + + def _init_tools( + self, spark: SparkSession, pipeline: PipelineConfig + ) -> Tuple[MetadataManager, PysparkFeatureLoader]: + """ + Initialize metadata manager and feature loader + + :param spark: Spark session + :param pipeline: Pipeline configuration + :return: MetadataManager and PysparkFeatureLoader + """ + if pipeline.metadata_manager is not None: + metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema) + else: + metadata_manager = None + + if pipeline.feature_loader is not None: + feature_loader = PysparkFeatureLoader( + spark, + feature_schema=pipeline.feature_loader.feature_schema, + metadata_schema=pipeline.feature_loader.metadata_schema, + ) + else: + feature_loader = None + return metadata_manager, feature_loader + + def _load_module(self, cfg: ModuleConfig) -> Transformation: + """ + Load feature group + + :param cfg: Feature configuration + :return: Transformation object + """ + module = import_module(cfg.python_module) + class_obj = getattr(module, cfg.python_class) + return class_obj() + + def execute(self, pipeline: PipelineTask) -> DataFrame: + """ + Execute the pipeline task. + + :param pipeline: Pipeline object to execute. 
+ :return: DataFrame resulting from pipeline execution. + """ + logger.info(f"Executing pipeline {pipeline.name} for partition date {pipeline.partition_date}") + + # Load and run the job + job = self._load_module(pipeline.config.module) + metadata_manager, feature_loader = self._init_tools(self.spark, pipeline.config) + df = job.run( + spark=self.spark, + run_date=pipeline.execution_date, + config=pipeline.config, + reader=self.reader, + metadata_manager=metadata_manager, + feature_loader=feature_loader, + ) + return df diff --git a/rialto/runner/services/result_mapper.py b/rialto/runner/services/result_mapper.py new file mode 100644 index 0000000..a749324 --- /dev/null +++ b/rialto/runner/services/result_mapper.py @@ -0,0 +1,113 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = ["TaskResultMapper"] + +from datetime import datetime + +from rialto.runner.reporting.record import Record +from rialto.runner.services.task_registry import PipelineTask + + +class TaskResultMapper: + """Maps task execution outcomes to Record objects with consistent schema""" + + @staticmethod + def success( + task: PipelineTask, + run_start: datetime, + records_count: int, + ) -> Record: + """Map successful task execution to Record""" + task.result = str(records_count) + return Record( + job=task.name, + target=task.target.get_table_path(), + date=task.partition_date, + time=datetime.now() - run_start, + records=records_count, + status="Success", + reason="OK", + exception=None, + ) + + @staticmethod + def already_complete(task: PipelineTask, run_start: datetime) -> Record: + """Map skipped (already complete) task to Record""" + task.result = "Skipped" + return Record( + job=task.name, + target=task.target.get_table_path(), + date=task.partition_date, + time=datetime.now() - run_start, + records=0, + status="Skipped", + reason="AlreadyComplete", + exception=None, + ) + + @staticmethod + def dependencies_incomplete( + task: PipelineTask, + run_start: datetime, + failed_deps: list, + ) -> Record: + """Map dependency failure to Record""" + task.result = "Failed" + details = "Dependencies Incomplete: " + ",\n".join(failed_deps) if failed_deps else "Unknown" + return Record( + job=task.name, + target=task.target.get_table_path(), + date=task.partition_date, + time=datetime.now() - run_start, + records=0, + status="Failed", + reason="Dependencies Incomplete", + exception=details, + ) + + @staticmethod + def exception( + task: PipelineTask, + run_start: datetime, + exception_message: str, + traceback_str: str, + ) -> Record: + """Map exception during execution to Record""" + task.result = "Error" + return Record( + job=task.name, + target=task.target.get_table_path(), + date=task.partition_date, + time=datetime.now() - run_start, + records=0, + 
status="Error", + reason=exception_message, + exception=traceback_str, + ) + + @staticmethod + def interrupted(task: PipelineTask, run_start: datetime) -> Record: + """Map keyboard interrupt to Record""" + task.result = "Interrupt" + return Record( + job=task.name, + target=task.target.get_table_path(), + date=task.partition_date, + time=datetime.now() - run_start, + records=0, + status="Error", + reason="Keyboard Interrupt", + exception=None, + ) diff --git a/rialto/runner/table.py b/rialto/runner/services/table.py similarity index 61% rename from rialto/runner/table.py rename to rialto/runner/services/table.py index 2d44498..dedd072 100644 --- a/rialto/runner/table.py +++ b/rialto/runner/services/table.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +14,48 @@ __all__ = ["Table"] -from typing import List +from typing import Dict, List from rialto.metadata import class_to_catalog_name +from rialto.runner.services.config_loader import DependencyConfig, PipelineConfig class Table: """Handler for databricks catalog paths""" + @classmethod + def from_target_config(cls, config: PipelineConfig) -> "Table": + """ + Create table object from pipeline config target section + + :param config: Pipeline configuration + + :return: Table object + """ + return cls( + schema_path=config.target.target_schema, + class_name=config.module.python_class, + partition=config.target.target_partition_column, + secondary_partitions=config.target.secondary_partition_columns, + table=config.target.custom_name, + filters=config.target.rerun_filters, + ) + + @classmethod + def from_dependency_config(cls, config: DependencyConfig) -> "Table": + """ + Create table object from pipeline config dependency section + + :param config: Dependency configuration + + :return: Table object + """ + return cls( + 
table_path=config.table, + partition=config.date_col, + filters=config.filters, + ) + def __init__( self, catalog: str = None, @@ -32,12 +66,14 @@ def __init__( class_name: str = None, partition: str = None, secondary_partitions: List[str] = None, + filters: Dict = None, ): self.catalog = catalog self.schema = schema self.table = table self.partition = partition self.secondary_partitions = secondary_partitions + self.filters = filters if schema_path: schema_path = schema_path.split(".") self.catalog = schema_path[0] @@ -58,7 +94,7 @@ def get_table_path(self) -> str: """Get full table path""" return f"{self.catalog}.{self.schema}.{self.table}" - def get_all_partitions(self) -> List[str]: + def get_all_partition_columns(self) -> List[str]: """Get list of all partitions""" if self.secondary_partitions: return [self.partition] + self.secondary_partitions diff --git a/rialto/runner/services/task_registry.py b/rialto/runner/services/task_registry.py new file mode 100644 index 0000000..2bb7d13 --- /dev/null +++ b/rialto/runner/services/task_registry.py @@ -0,0 +1,109 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+__all__ = ["TaskRegistry", "PipelineTask", "PipelineDependency"] + +from dataclasses import dataclass, field +from datetime import date +from typing import Iterator, List + +from loguru import logger +from pyspark.sql import SparkSession + +from rialto.runner.services.config_loader import PipelineConfig +from rialto.runner.services.date_manager import DateManager +from rialto.runner.services.table import Table + + +@dataclass +class PipelineDependency: + """Class representing a pipeline dependency, with associated table and date range for checking completion""" + + table: Table + date_from: date + date_until: date + complete: bool = False + + +@dataclass +class PipelineTask: + """Class representing a pipeline to be executed.""" + + name: str + execution_date: date + partition_date: date + config: PipelineConfig + target: Table + dependencies: List[PipelineDependency] = field(default_factory=list) + completion: bool = False + dependencies_complete: bool = False + precheck_failed: bool = False + error: str | None = None + error_trace: str | None = None + result: str = "" + + +class TaskRegistry: + """Registry for pipeline tasks to be executed""" + + def __init__(self, spark: SparkSession, date_manager: DateManager): + self.spark = spark + self.date_manager = date_manager + self.tasks = [] + + def add_task(self, name: str, execution_date: date, partition_date: date, config: PipelineConfig) -> None: + """ + Add task to registry + + :param name: Name of the pipeline + :param execution_date: Date when the pipeline is scheduled to run + :param partition_date: Date for which the pipeline is processing data + :param config: PipelineConfig object with pipeline configuration + + :return: None, adds a Pipeline object to self.tasks + """ + target = Table.from_target_config(config) + new_pipe = PipelineTask( + name=name, execution_date=execution_date, partition_date=partition_date, config=config, target=target + ) + + for dependency_config in config.dependencies: + 
dependency_table = Table.from_dependency_config(dependency_config) + dependency_from = self.date_manager.date_subtract( + execution_date, dependency_config.interval.units, dependency_config.interval.value + ) + dependency = PipelineDependency( + table=dependency_table, date_from=dependency_from, date_until=execution_date + ) + new_pipe.dependencies.append(dependency) + + self.tasks.append(new_pipe) + + def __iter__(self) -> Iterator[PipelineTask]: + """Allow iteration over tasks in execution plan""" + return iter(self.tasks) + + def log_status(self) -> None: + """Log status of all tasks in registry, showing completion and dependency status""" + check = "\u2714" # ✔ + cross = "\u2718" # ✘ + status = f"\n{'Job Name':<50} {'Partition Date':<15} {'Complete':<8} {'Dependencies':<12} {'Result':<9}\n" + status = status + ("-" * 70 + "\n") + for task in self.tasks: + complete_icon = check if task.completion else cross + deps_icon = check if task.dependencies_complete else cross + status = ( + status + f"{task.name:<50} {str(task.partition_date):<15} {complete_icon:^8} " + f"{deps_icon:^12} {task.result:<9}\n" + ) + logger.info(status) diff --git a/rialto/runner/services/task_status_checker.py b/rialto/runner/services/task_status_checker.py new file mode 100644 index 0000000..10f5c31 --- /dev/null +++ b/rialto/runner/services/task_status_checker.py @@ -0,0 +1,58 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+__all__ = ["TaskStatusChecker"] + +from loguru import logger + +from rialto.runner.services.data_checker import DataChecker +from rialto.runner.services.task_registry import PipelineTask + + +class TaskStatusChecker: + """Handles completion and dependency checks for pipeline tasks.""" + + def __init__(self, checker: DataChecker): + self.checker = checker + + def check_completion(self, pipeline: PipelineTask) -> None: + """ + Check if pipeline is complete by checking if target data exists for partition date + + :param pipeline: Pipeline object for which to check completion + + :return: None, updates self.completion attribute + """ + pipeline.completion = self.checker.check_date(pipeline.target, pipeline.partition_date) + logger.info( + f"Job {pipeline.name} completion status for partition date " + f"{pipeline.partition_date}: {pipeline.completion}" + ) + + def check_pipeline_dependencies(self, pipeline: PipelineTask) -> None: + """ + Check if dependencies are complete by checking if data exists for each dependency in date range + + :param pipeline: Pipeline object for which to check dependencies + + :return: None, updates self.dependencies_complete attribute + """ + for dependency in pipeline.dependencies: + dependency.complete = self.checker.check_range( + dependency.table, dependency.date_from, dependency.date_until + ) + logger.info( + f"Dependency {dependency.table.get_table_path()} completion status for date range " + f"{dependency.date_from} - {dependency.date_until}: {dependency.complete}" + ) + pipeline.dependencies_complete = all([dependency.complete for dependency in pipeline.dependencies]) diff --git a/rialto/runner/writer.py b/rialto/runner/services/writer.py similarity index 51% rename from rialto/runner/writer.py rename to rialto/runner/services/writer.py index bc147fd..b4c8c91 100644 --- a/rialto/runner/writer.py +++ b/rialto/runner/services/writer.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # 
Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["Writer"] +__all__ = ["DatabricksWriter", "Writer"] +from abc import ABC, abstractmethod from datetime import date -from typing import List +from typing import Any, List import pyspark.sql.functions as F from loguru import logger from pyspark.sql import DataFrame, SparkSession -from rialto.runner.table import Table +from rialto.runner.services.table import Table -class Writer: +class Writer(ABC): """Supporting class for runner""" + @abstractmethod + def write(self, df: DataFrame, info_date: date, table: Table) -> None: + """ + Write dataframe to storage + + :param df: dataframe to write + :param info_date: date to partition + :param table: path to write to + :return: None + """ + pass + + +class DatabricksWriter(Writer): + """Supporting class for runner, Databricks write operations""" + def __init__(self, spark: SparkSession, merge_schema=False): self.spark = spark self.merge_schema = merge_schema @@ -61,50 +78,71 @@ def _align_schema(self, df: DataFrame, existing_columns: List) -> DataFrame: :return: dataframe with aligned schema """ if existing_columns is not None: + missing = [c for c in existing_columns if c not in df.columns] + if missing: + raise ValueError(f"DataFrame is missing columns present in existing table: {missing}") return df.select( - *[F.col(c) for c in existing_columns if c in df.columns], + *[F.col(c) for c in existing_columns], *[F.col(c) for c in df.columns if c not in existing_columns], ) return df - def _process(self, df: DataFrame, info_date: date, table: Table) -> DataFrame: - df = df.withColumn(table.partition, F.lit(info_date)) + def _process(self, df: DataFrame, partition_date: date, table: Table) -> DataFrame: + df = df.withColumn(table.partition, F.lit(partition_date)) df = 
self._align_schema(df, self._get_existing_columns(table)) return df - def _get_replace_condition(self, df: DataFrame, partition_cols: List[str]) -> str: - row = df.select(*partition_cols).distinct().collect() - if len(row) > 1: - raise ValueError(f"Some of the partitions to write have more than 1 distinct value \n {row}") - - parts = [] - for c in partition_cols: - val = row[0][c] - if val is None: - parts.append(f"{c} IS NULL") - elif isinstance(val, (int, float)): - parts.append(f"{c} = {val}") - else: - parts.append(f"{c} = '{val}'") - condition = " AND ".join(parts) - return condition - - def write(self, df: DataFrame, info_date: date, table: Table) -> None: + def _get_replace_expression(self, key: str, value: Any) -> str: + if value is None: + return f"{key} IS NULL" + elif isinstance(value, (int, float)): + return f"{key} = {value}" + else: + return f"{key} = '{value}'" + + def _get_replace_condition(self, df: DataFrame, target: Table, partition_date: date) -> str: + partition_cols = target.get_all_partition_columns() + + # only date column + if len(partition_cols) == 1: + return f"{partition_cols[0]} = '{partition_date.strftime('%Y-%m-%d')}'" + + # if target filters present for all partitions + elif target.filters and len(partition_cols) - 1 == len(target.filters): + target.filters[target.partition] = partition_date.strftime("%Y-%m-%d") + parts = [] + for c in partition_cols: + parts.append(self._get_replace_expression(c, target.filters[c])) + condition = " AND ".join(parts) + return condition + # grab from dataframe + else: + row = df.select(*partition_cols).distinct().collect() + if len(row) > 1: + raise ValueError(f"Some of the partitions to write have more than 1 distinct value \n {row}") + + parts = [] + for c in partition_cols: + parts.append(self._get_replace_expression(c, row[0][c])) + condition = " AND ".join(parts) + return condition + + def write(self, df: DataFrame, partition_date: date, table: Table) -> None: """ Write dataframe to storage 
:param df: dataframe to write - :param info_date: date to partition + :param partition_date: date to partition :param table: path to write to :return: None """ self._create_schema(table) - df = self._process(df, info_date, table) + df = self._process(df, partition_date, table) - replace_where = self._get_replace_condition(df, table.get_all_partitions()) + replace_where = self._get_replace_condition(df, table, partition_date) df.write.format("delta").partitionBy(table.partition).mode("overwrite").option( "mergeSchema", "true" if self.merge_schema else "false" diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py index 5b6f2eb..be5b903 100644 --- a/rialto/runner/transformation.py +++ b/rialto/runner/transformation.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ from rialto.common import DataReader from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager -from rialto.runner.config_loader import PipelineConfig +from rialto.runner.services.config_loader import PipelineConfig class Transformation(metaclass=abc.ABCMeta): diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py index 5af1723..7e42542 100644 --- a/rialto/runner/utils.py +++ b/rialto/runner/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,82 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__all__ = ["load_module", "table_exists", "get_partitions", "init_tools", "find_dependency"] +__all__ = ["find_dependency"] -from datetime import date -from importlib import import_module -from typing import List, Tuple - -from pyspark.sql import SparkSession - -from rialto.common import DataReader -from rialto.loader import PysparkFeatureLoader -from rialto.metadata import MetadataManager -from rialto.runner.config_loader import ModuleConfig, PipelineConfig -from rialto.runner.table import Table -from rialto.runner.transformation import Transformation - - -def load_module(cfg: ModuleConfig) -> Transformation: - """ - Load feature group - - :param cfg: Feature configuration - :return: Transformation object - """ - module = import_module(cfg.python_module) - class_obj = getattr(module, cfg.python_class) - return class_obj() - - -def table_exists(spark: SparkSession, table: str) -> bool: - """ - Check table exists in spark catalog - - :param table: full table path - :return: bool - """ - return spark.catalog.tableExists(table) - - -def get_partitions(reader: DataReader, table: Table) -> List[date]: - """ - Get partition values - - :param table: Table object - :return: List of partition values - """ - rows = ( - reader.get_table(table.get_table_path(), date_column=table.partition) - .select(table.partition) - .distinct() - .collect() - ) - return [r[table.partition] for r in rows] - - -def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]: - """ - Initialize metadata manager and feature loader - - :param spark: Spark session - :param pipeline: Pipeline configuration - :return: MetadataManager and PysparkFeatureLoader - """ - if pipeline.metadata_manager is not None: - metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema) - else: - metadata_manager = None - - if pipeline.feature_loader is not None: - feature_loader = PysparkFeatureLoader( - spark, - 
feature_schema=pipeline.feature_loader.feature_schema, - metadata_schema=pipeline.feature_loader.metadata_schema, - ) - else: - feature_loader = None - return metadata_manager, feature_loader +from rialto.runner.services.config_loader import PipelineConfig def find_dependency(config: PipelineConfig, name: str): diff --git a/tests/__init__.py b/tests/__init__.py index 79c3773..94ab807 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/common/conftest.py b/tests/common/conftest.py index 79455ff..198d297 100644 --- a/tests/common/conftest.py +++ b/tests/common/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/common/test_reader.py b/tests/common/test_reader.py index 452ad13..8127154 100644 --- a/tests/common/test_reader.py +++ b/tests/common/test_reader.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index f11344d..fb2a7ba 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/jobs/__init__.py b/tests/jobs/__init__.py index 79c3773..94ab807 100644 --- a/tests/jobs/__init__.py +++ b/tests/jobs/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py index dda863d..c00517d 100644 --- a/tests/jobs/conftest.py +++ b/tests/jobs/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/jobs/resources.py b/tests/jobs/resources.py index 3d785c3..ad36d7d 100644 --- a/tests/jobs/resources.py +++ b/tests/jobs/resources.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py index b0fb898..c93b90e 100644 --- a/tests/jobs/test_decorators.py +++ b/tests/jobs/test_decorators.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index 7069b12..86314e3 100644 --- a/tests/jobs/test_job/test_job.py +++ b/tests/jobs/test_job/test_job.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py index 7d9b409..55ea96e 100644 --- a/tests/jobs/test_job_base.py +++ b/tests/jobs/test_job_base.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/jobs/test_resolver.py b/tests/jobs/test_resolver.py index 443e27b..eca150d 100644 --- a/tests/jobs/test_resolver.py +++ b/tests/jobs/test_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/jobs/test_test_utils.py b/tests/jobs/test_test_utils.py index 3143210..a3df677 100644 --- a/tests/jobs/test_test_utils.py +++ b/tests/jobs/test_test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/__init__.py b/tests/loader/__init__.py index 79c3773..94ab807 100644 --- a/tests/loader/__init__.py +++ b/tests/loader/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/loader/metadata_config/full_example.yaml b/tests/loader/metadata_config/full_example.yaml index 9ad780c..a8e7d1c 100644 --- a/tests/loader/metadata_config/full_example.yaml +++ b/tests/loader/metadata_config/full_example.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/metadata_config/missing_field_example.yaml b/tests/loader/metadata_config/missing_field_example.yaml index 1caf3b0..7956c08 100644 --- a/tests/loader/metadata_config/missing_field_example.yaml +++ b/tests/loader/metadata_config/missing_field_example.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/metadata_config/missing_value_example.yaml b/tests/loader/metadata_config/missing_value_example.yaml index 844e25f..e7da0c2 100644 --- a/tests/loader/metadata_config/missing_value_example.yaml +++ b/tests/loader/metadata_config/missing_value_example.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/metadata_config/no_map_example.yaml b/tests/loader/metadata_config/no_map_example.yaml index d3679fa..6234e76 100644 --- a/tests/loader/metadata_config/no_map_example.yaml +++ b/tests/loader/metadata_config/no_map_example.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/loader/metadata_config/test_main_config.py b/tests/loader/metadata_config/test_main_config.py index b09f155..c8033eb 100644 --- a/tests/loader/metadata_config/test_main_config.py +++ b/tests/loader/metadata_config/test_main_config.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/pyspark/dataframe_builder.py b/tests/loader/pyspark/dataframe_builder.py index 94a755e..0fa4967 100644 --- a/tests/loader/pyspark/dataframe_builder.py +++ b/tests/loader/pyspark/dataframe_builder.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/pyspark/example_cfg.yaml b/tests/loader/pyspark/example_cfg.yaml index 6b19277..b5fa6c3 100644 --- a/tests/loader/pyspark/example_cfg.yaml +++ b/tests/loader/pyspark/example_cfg.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/loader/pyspark/resources.py b/tests/loader/pyspark/resources.py index 64a8363..1f69226 100644 --- a/tests/loader/pyspark/resources.py +++ b/tests/loader/pyspark/resources.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py index dd2049f..5fc7338 100644 --- a/tests/loader/pyspark/test_from_cfg.py +++ b/tests/loader/pyspark/test_from_cfg.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/__init__.py b/tests/maker/__init__.py index 79c3773..94ab807 100644 --- a/tests/maker/__init__.py +++ b/tests/maker/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/conftest.py b/tests/maker/conftest.py index 79455ff..198d297 100644 --- a/tests/maker/conftest.py +++ b/tests/maker/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_FeatureFunction.py b/tests/maker/test_FeatureFunction.py index 43590ae..d478d76 100644 --- a/tests/maker/test_FeatureFunction.py +++ b/tests/maker/test_FeatureFunction.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/maker/test_FeatureHolder.py b/tests/maker/test_FeatureHolder.py index 5c00cdb..c02f389 100644 --- a/tests/maker/test_FeatureHolder.py +++ b/tests/maker/test_FeatureHolder.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_FeatureMaker.py b/tests/maker/test_FeatureMaker.py index f3da26d..ed99155 100644 --- a/tests/maker/test_FeatureMaker.py +++ b/tests/maker/test_FeatureMaker.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/__init__.py b/tests/maker/test_features/__init__.py index 79c3773..94ab807 100644 --- a/tests/maker/test_features/__init__.py +++ b/tests/maker/test_features/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/aggregated_num_sum_outbound.py b/tests/maker/test_features/aggregated_num_sum_outbound.py index ce3937b..8d3bb4c 100644 --- a/tests/maker/test_features/aggregated_num_sum_outbound.py +++ b/tests/maker/test_features/aggregated_num_sum_outbound.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/maker/test_features/aggregated_num_sum_txn.py b/tests/maker/test_features/aggregated_num_sum_txn.py index 6c807af..a005ef8 100644 --- a/tests/maker/test_features/aggregated_num_sum_txn.py +++ b/tests/maker/test_features/aggregated_num_sum_txn.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/dependent_features_fail.py b/tests/maker/test_features/dependent_features_fail.py index d9a8c7f..a702028 100644 --- a/tests/maker/test_features/dependent_features_fail.py +++ b/tests/maker/test_features/dependent_features_fail.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/dependent_features_fail2.py b/tests/maker/test_features/dependent_features_fail2.py index 4964c8a..96db490 100644 --- a/tests/maker/test_features/dependent_features_fail2.py +++ b/tests/maker/test_features/dependent_features_fail2.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/dependent_features_ok.py b/tests/maker/test_features/dependent_features_ok.py index 232f08b..f96852a 100644 --- a/tests/maker/test_features/dependent_features_ok.py +++ b/tests/maker/test_features/dependent_features_ok.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/maker/test_features/sequential_avg_outbound.py b/tests/maker/test_features/sequential_avg_outbound.py index cedad5f..aefe46e 100644 --- a/tests/maker/test_features/sequential_avg_outbound.py +++ b/tests/maker/test_features/sequential_avg_outbound.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/sequential_avg_txn.py b/tests/maker/test_features/sequential_avg_txn.py index 65d1f7f..4e3bad3 100644 --- a/tests/maker/test_features/sequential_avg_txn.py +++ b/tests/maker/test_features/sequential_avg_txn.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/sequential_for_testing.py b/tests/maker/test_features/sequential_for_testing.py index 5a8de84..b9e0dd5 100644 --- a/tests/maker/test_features/sequential_for_testing.py +++ b/tests/maker/test_features/sequential_for_testing.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_features/sequential_outbound.py b/tests/maker/test_features/sequential_outbound.py index 6b0764e..ffe5961 100644 --- a/tests/maker/test_features/sequential_outbound.py +++ b/tests/maker/test_features/sequential_outbound.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/maker/test_features/sequential_outbound_with_param.py b/tests/maker/test_features/sequential_outbound_with_param.py index eb50d80..3a003c4 100644 --- a/tests/maker/test_features/sequential_outbound_with_param.py +++ b/tests/maker/test_features/sequential_outbound_with_param.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/maker/test_wrappers.py b/tests/maker/test_wrappers.py index 135b4ad..fc80051 100644 --- a/tests/maker/test_wrappers.py +++ b/tests/maker/test_wrappers.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/metadata/__init__.py b/tests/metadata/__init__.py index 79c3773..94ab807 100644 --- a/tests/metadata/__init__.py +++ b/tests/metadata/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/metadata/conftest.py b/tests/metadata/conftest.py index d500b57..01cfcde 100644 --- a/tests/metadata/conftest.py +++ b/tests/metadata/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/metadata/resources.py b/tests/metadata/resources.py index fb9df8c..5ab4e85 100644 --- a/tests/metadata/resources.py +++ b/tests/metadata/resources.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/metadata/test_metadata_connector.py b/tests/metadata/test_metadata_connector.py index 6594e6c..01e5ade 100644 --- a/tests/metadata/test_metadata_connector.py +++ b/tests/metadata/test_metadata_connector.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/runner/__init__.py b/tests/runner/__init__.py index 79c3773..94ab807 100644 --- a/tests/runner/__init__.py +++ b/tests/runner/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py index 4e527be..095aa0f 100644 --- a/tests/runner/conftest.py +++ b/tests/runner/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -37,6 +37,6 @@ def spark(request): return spark -@pytest.fixture(scope="function") +@pytest.fixture(scope="session") def basic_runner(spark): return Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") diff --git a/tests/runner/reporting/__init__.py b/tests/runner/reporting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/runner/reporting/test_record.py b/tests/runner/reporting/test_record.py new file mode 100644 index 0000000..37489f5 --- /dev/null +++ b/tests/runner/reporting/test_record.py @@ -0,0 +1,54 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from datetime import datetime, timedelta + +import pytest + +from rialto.runner.reporting.record import Record +from rialto.runner.services.config_loader import RunnerConfig +from rialto.runner.services.date_manager import DateManager + + +@pytest.fixture(scope="module") +def basic_date_manager() -> DateManager: + cfg = RunnerConfig(watched_period_units="months", watched_period_value=4) + return DateManager(cfg) + + +@pytest.fixture(scope="function") +def record(basic_date_manager): + return Record( + "job", + "target", + basic_date_manager.str_to_date("2024-01-01"), + timedelta(days=0, hours=1, minutes=2, seconds=3), + 1, + "status", + "reason", + None, + datetime(2024, 1, 1, 1, 2, 3), + ) + + +def test_record_to_spark(spark, basic_date_manager, record): + row = record.to_spark_row() + assert row.job == "job" + assert row.target == "target" + assert row.date == basic_date_manager.str_to_date("2024-01-01") + assert row.time == "1:02:03" + assert row.records == 1 + assert row.status == "status" + assert row.reason == "reason" + assert row.exception is None + assert row.run_timestamp == datetime(2024, 1, 1, 1, 2, 3) diff --git a/tests/runner/transformations/__init__.py b/tests/runner/resources/__init__.py similarity index 83% rename from tests/runner/transformations/__init__.py rename to tests/runner/resources/__init__.py index eaa15cd..08a0182 100644 --- a/tests/runner/transformations/__init__.py +++ b/tests/runner/resources/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from tests.runner.transformations.simple_group import SimpleGroup # noqa +from tests.runner.resources.simple_group import SimpleGroup # noqa diff --git a/tests/runner/transformations/config.yaml b/tests/runner/resources/config.yaml similarity index 98% rename from tests/runner/transformations/config.yaml rename to tests/runner/resources/config.yaml index 11f9fc9..659a893 100644 --- a/tests/runner/transformations/config.yaml +++ b/tests/runner/resources/config.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/runner/transformations/config2.yaml b/tests/runner/resources/config2.yaml similarity index 96% rename from tests/runner/transformations/config2.yaml rename to tests/runner/resources/config2.yaml index f7b9604..31067c1 100644 --- a/tests/runner/transformations/config2.yaml +++ b/tests/runner/resources/config2.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/runner/transformations/config3.yaml b/tests/runner/resources/config3.yaml similarity index 96% rename from tests/runner/transformations/config3.yaml rename to tests/runner/resources/config3.yaml index 72af1da..70e34d3 100644 --- a/tests/runner/transformations/config3.yaml +++ b/tests/runner/resources/config3.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/runner/overrider.yaml b/tests/runner/resources/overrider.yaml similarity index 98% rename from tests/runner/overrider.yaml rename to tests/runner/resources/overrider.yaml index 75257a7..add1fb4 100644 --- a/tests/runner/overrider.yaml +++ b/tests/runner/resources/overrider.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/runner/transformations/simple_group.py b/tests/runner/resources/simple_group.py similarity index 96% rename from tests/runner/transformations/simple_group.py rename to tests/runner/resources/simple_group.py index ec2311c..044c3bf 100644 --- a/tests/runner/transformations/simple_group.py +++ b/tests/runner/resources/simple_group.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/runner/runner_resources.py b/tests/runner/runner_resources.py index 17b5447..a657c0f 100644 --- a/tests/runner/runner_resources.py +++ b/tests/runner/runner_resources.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. 
from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType -from rialto.runner.date_manager import DateManager +from rialto.runner.services.date_manager import DateManager simple_group_data = [ ("A", DateManager.str_to_date("2023-03-05")), diff --git a/tests/runner/services/__init__.py b/tests/runner/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/runner/test_overrides.py b/tests/runner/services/test_config_overrides.py similarity index 66% rename from tests/runner/test_overrides.py rename to tests/runner/services/test_config_overrides.py index 5c738fb..6be3603 100644 --- a/tests/runner/test_overrides.py +++ b/tests/runner/services/test_config_overrides.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,51 +16,53 @@ from rialto.runner import Runner +CONFIG_PATH = "tests/runner/resources/overrider.yaml" + def test_overrides_simple(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]}, ) - assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + assert runner._services.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] def test_overrides_array_index(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.mail.to[1]": "a@b.c"}, ) - assert runner.config.runner.mail.to == ["developer@testing.org", "a@b.c"] + assert runner._services.config.runner.mail.to == ["developer@testing.org", "a@b.c"] def test_overrides_array_append(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.mail.to[-1]": "test"}, ) 
- assert runner.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"] + assert runner._services.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"] def test_overrides_array_lookup(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}, ) - assert runner.config.pipelines[0].target.target_schema == "new_schema" + assert runner._services.config.pipelines[0].target.target_schema == "new_schema" def test_overrides_combined(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={ "runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"], @@ -68,16 +70,16 @@ def test_overrides_combined(spark): "pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1, }, ) - assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] - assert runner.config.pipelines[0].target.target_schema == "new_schema" - assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1 + assert runner._services.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + assert runner._services.config.pipelines[0].target.target_schema == "new_schema" + assert runner._services.config.pipelines[0].schedule.info_date_shift[0].value == 1 def test_index_out_of_range(spark): with pytest.raises(IndexError) as error: Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.mail.to[8]": "test"}, ) @@ -88,7 +90,7 @@ def test_invalid_index_key(spark): with pytest.raises(ValueError) as error: Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.mail.test[8]": "test"}, ) @@ -99,7 +101,7 @@ def test_invalid_key(spark): with pytest.raises(ValueError) as error: Runner( spark, - 
config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.mail.test.param": "test"}, ) @@ -110,7 +112,7 @@ def test_new_key(spark): with pytest.raises(ValidationError): Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={"runner.some_value": 5}, ) @@ -119,7 +121,7 @@ def test_new_key(spark): def test_replace_section(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={ "pipelines[name=SimpleGroup].feature_loader": { @@ -128,13 +130,13 @@ def test_replace_section(spark): } }, ) - assert runner.config.pipelines[0].feature_loader.feature_schema == "catalog.features" + assert runner._services.config.pipelines[0].feature_loader.feature_schema == "catalog.features" def test_add_section(spark): runner = Runner( spark, - config_path="tests/runner/overrider.yaml", + config_path=CONFIG_PATH, run_date="2023-03-31", overrides={ "pipelines[name=OtherGroup].feature_loader": { @@ -143,4 +145,15 @@ def test_add_section(spark): } }, ) - assert runner.config.pipelines[1].feature_loader.feature_schema == "catalog.features" + assert runner._services.config.pipelines[1].feature_loader.feature_schema == "catalog.features" + + +def test_invalid_append_index_for_nested_path(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path=CONFIG_PATH, + run_date="2023-03-31", + overrides={"runner.mail.to[-1].domain": "example.com"}, + ) + assert error.value.args[0] == "Invalid index -1 for key to in path ['to[-1]', 'domain']" diff --git a/tests/runner/services/test_data_checker.py b/tests/runner/services/test_data_checker.py new file mode 100644 index 0000000..a2bf90c --- /dev/null +++ b/tests/runner/services/test_data_checker.py @@ -0,0 +1,230 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
# may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``DataChecker``: partition-date checks and written-row counting."""
from datetime import date
from unittest.mock import MagicMock

import pytest
from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType

from rialto.common import TableReader
from rialto.runner.services.data_checker import DataChecker
from rialto.runner.services.table import Table


def _partitioned_schema():
    """Return the VALUE/VERSION/TYPE/DATE schema shared by the partitioned fixtures."""
    return StructType(
        [
            StructField("VALUE", StringType(), True),
            StructField("VERSION", IntegerType(), True),
            StructField("TYPE", StringType(), True),
            StructField("DATE", DateType(), True),
        ]
    )


def _checker_backed_by(mocker, spark, dataframe):
    """Build a ``DataChecker`` whose reader reports an existing table holding *dataframe*.

    Both ``TableReader.table_exists`` and ``TableReader._get_raw_data`` are
    patched, so no real catalog is touched.
    """
    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True)
    mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=dataframe)
    return DataChecker(TableReader(spark))


def _reader_counting(rows):
    """Return a mocked reader whose ``get_table(...).count()`` yields *rows*."""
    reader = MagicMock()
    reader.get_table.return_value.count.return_value = rows
    return reader


@pytest.fixture(scope="module")
def simple_dataframe(spark):
    """Three rows keyed A/B/C, one per week of March 2023."""
    rows = [
        ("A", date(2023, 3, 5)),
        ("B", date(2023, 3, 12)),
        ("C", date(2023, 3, 19)),
    ]
    schema = StructType([StructField("KEY", StringType(), True), StructField("DATE", DateType(), True)])
    return spark.createDataFrame(rows, schema=schema)


@pytest.fixture(scope="module")
def partitioned_dataframe(spark):
    """Rows spread over several VERSION/TYPE combinations and three dates."""
    rows = [
        ("W", 1, "A", date(2023, 3, 5)),
        ("E", 1, "B", date(2023, 3, 5)),
        ("R", 2, "B", date(2023, 3, 5)),
        ("T", 1, "B", date(2023, 3, 12)),
        ("Y", 2, "A", date(2023, 3, 19)),
    ]
    return spark.createDataFrame(rows, schema=_partitioned_schema())


@pytest.fixture(scope="module")
def new_insert_partitioned_dataframe(spark):
    """A freshly written batch: a single date with VERSION=1 and TYPE="B" only."""
    rows = [
        ("E", 1, "B", date(2023, 3, 5)),
        ("T", 1, "B", date(2023, 3, 5)),
    ]
    return spark.createDataFrame(rows, schema=_partitioned_schema())


@pytest.mark.parametrize(
    "partition_date, expected",
    [
        (date(2023, 3, 12), True),
        (date(2023, 3, 10), False),
        (date(2023, 3, 19), True),
        (date(2023, 3, 26), False),
    ],
)
def test_check_date(mocker, spark, simple_dataframe, partition_date, expected):
    """``check_date`` is true only for dates present in the partition column."""
    checker = _checker_backed_by(mocker, spark, simple_dataframe)
    table = Table(table_path="catalog.schema.simple_group", partition="DATE")

    assert checker.check_date(table, partition_date) == expected


@pytest.mark.parametrize(
    "start_date, end_date, expected",
    [
        (date(2023, 3, 12), date(2023, 4, 12), True),
        (date(2023, 3, 10), date(2023, 3, 11), False),
        (date(2023, 3, 19), date(2023, 3, 19), True),
        (date(2023, 3, 26), date(2023, 3, 29), False),
    ],
)
def test_check_range(mocker, spark, simple_dataframe, start_date, end_date, expected):
    """``check_range`` is true when the range contains a partition date (bounds included)."""
    checker = _checker_backed_by(mocker, spark, simple_dataframe)
    table = Table(table_path="catalog.schema.simple_group", partition="DATE")

    assert checker.check_range(table, start_date, end_date) == expected


def test_check_range_no_table(
    mocker,
    spark,
):
    """A table that does not exist is reported as having no data."""
    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=False)

    checker = DataChecker(TableReader(spark))
    table = Table(table_path="catalog.schema.simple_group", partition="DATE")

    # NOTE(review): despite the test name this exercises ``check_date``, not
    # ``check_range``. Kept as-is because ``check_range``'s missing-table
    # behaviour is not visible here -- confirm and rename or extend.
    assert checker.check_date(table, date(2023, 3, 12)) is False


@pytest.mark.parametrize(
    "partition_date, expected",
    [
        (date(2023, 2, 26), False),
        (date(2023, 3, 5), True),
        (date(2023, 3, 12), False),
        (date(2023, 3, 19), False),
        (date(2023, 3, 26), False),
    ],
)
def test_check_date_secondary_partitions_and_filters(mocker, spark, partitioned_dataframe, partition_date, expected):
    """With filters on the secondary partitions only the matching slice counts."""
    checker = _checker_backed_by(mocker, spark, partitioned_dataframe)
    table = Table(
        table_path="catalog.schema.simple_group",
        partition="DATE",
        secondary_partitions=["VERSION", "TYPE"],
        filters={"version": 1, "type": "A"},
    )

    assert checker.check_date(table, partition_date) == expected


@pytest.mark.parametrize(
    "partition_date, expected",
    [
        (date(2023, 2, 26), False),
        (date(2023, 3, 5), False),
        (date(2023, 3, 12), False),
        (date(2023, 3, 19), False),
        (date(2023, 3, 26), False),
    ],
)
def test_check_date_secondary_partitions_no_filters(mocker, spark, partitioned_dataframe, partition_date, expected):
    """Secondary partitions without filters never report a date as present."""
    checker = _checker_backed_by(mocker, spark, partitioned_dataframe)
    table = Table(
        table_path="catalog.schema.simple_group",
        partition="DATE",
        secondary_partitions=["VERSION", "TYPE"],
        filters=None,
    )

    assert checker.check_date(table, partition_date) == expected


def test_check_written_with_no_filters_or_secondary_partitions():
    """Without filters the written-row count is read with an empty filter dict."""
    reader = _reader_counting(42)
    table = Table(table_path="dummy.table.path", partition="DATE")

    written = DataChecker(reader).check_written(table, date(2023, 3, 5), MagicMock())

    assert written == 42
    reader.get_table.assert_called_once_with(
        "dummy.table.path",
        date_column="DATE",
        date_from=date(2023, 3, 5),
        date_to=date(2023, 3, 5),
        filters={},
    )


def test_check_written_with_filters():
    """Configured table filters are forwarded unchanged to the reader."""
    reader = _reader_counting(42)
    table = Table(table_path="dummy.table.path", partition="DATE", filters={"foo": "bar"})

    written = DataChecker(reader).check_written(table, date(2023, 3, 5), MagicMock())

    assert written == 42
    reader.get_table.assert_called_once_with(
        "dummy.table.path",
        date_column="DATE",
        date_from=date(2023, 3, 5),
        date_to=date(2023, 3, 5),
        filters={"foo": "bar"},
    )


def test_check_written_with_secondary_partitions(new_insert_partitioned_dataframe):
    """Filters are derived from the secondary-partition values of the written frame.

    The previously requested ``mocker`` fixture was never used and has been
    dropped; the reader is a plain ``MagicMock``.
    """
    reader = _reader_counting(7)
    table = Table(
        table_path="dummy.table.path", partition="DATE", filters=None, secondary_partitions=["VERSION", "TYPE"]
    )

    written = DataChecker(reader).check_written(table, date(2023, 3, 5), new_insert_partitioned_dataframe)

    assert written == 7
    reader.get_table.assert_called_once_with(
        "dummy.table.path",
        date_column="DATE",
        date_from=date(2023, 3, 5),
        date_to=date(2023, 3, 5),
        filters={"VERSION": 1, "TYPE": "B"},
    )


# ---------------------------------------------------------------------------
# Next file in this patch: tests/runner/services/test_date_manager.py (new file)
# ---------------------------------------------------------------------------
# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``DateManager``: date parsing, date arithmetic and schedule expansion."""
from datetime import date

import pytest

from rialto.runner.services.config_loader import (
    IntervalConfig,
    RunnerConfig,
    ScheduleConfig,
)
from rialto.runner.services.date_manager import DateManager


def _manager(units, value, run_date="2026-05-20"):
    """Build a ``DateManager`` watching *value* *units* back from *run_date*."""
    runner_cfg = RunnerConfig(watched_period_units=units, watched_period_value=value)
    return DateManager(config=runner_cfg, run_date=run_date)


def _scheduled_dates(manager, schedule):
    """Return ``(execution_dates, partition_dates)`` as two parallel lists.

    Local names deliberately avoid ``exec``, which would shadow the builtin.
    """
    execution_dates, partition_dates = zip(*manager.get_execution_and_partition_dates(schedule=schedule))
    return list(execution_dates), list(partition_dates)


def test_str_to_date():
    """ISO ``YYYY-MM-DD`` strings are parsed into ``date`` objects."""
    assert DateManager.str_to_date("2023-03-05") == date(2023, 3, 5)


def test_str_to_date_bad():
    """Any other separator is rejected."""
    with pytest.raises(ValueError):
        DateManager.str_to_date("2023/03/05")


def test_invalid_range():
    """A negative watched period is rejected at construction time."""
    runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=-3)
    with pytest.raises(ValueError):
        DateManager(config=runner_cfg, run_date="2023-03-05")


@pytest.mark.parametrize(
    "units, value, res",
    [
        ("days", 7, "2023-02-26"),
        ("weeks", 3, "2023-02-12"),
        ("months", 5, "2022-10-05"),
        ("years", 2, "2021-03-05"),
    ],
)
def test_date_subtract(units, value, res):
    """Subtracting each supported unit from 2023-03-05 lands on the expected date."""
    rundate = DateManager.str_to_date("2023-03-05")
    date_from = DateManager.date_subtract(input_date=rundate, units=units, value=value)
    assert date_from == DateManager.str_to_date(res)


def test_date_subtract_bad():
    """An unknown unit raises ``ValueError`` naming the offending unit."""
    rundate = DateManager.str_to_date("2023-03-05")
    with pytest.raises(ValueError) as exception:
        DateManager.date_subtract(input_date=rundate, units="random", value=1)
    # Checked outside the ``with`` body so the assertion is actually reached.
    assert str(exception.value) == "Unknown time unit random"


def test_all_dates():
    """``all_dates`` enumerates every day of the range, both ends included."""
    all_dates = DateManager.all_dates(
        date_from=DateManager.str_to_date("2023-02-05"),
        date_until=DateManager.str_to_date("2023-04-12"),
    )
    assert len(all_dates) == 67
    assert all_dates[1] == DateManager.str_to_date("2023-02-06")


def test_date_from():
    """The watched window starts three months before the run date."""
    date_manager = _manager("months", 3, run_date="2023-03-05")
    assert date_manager.get_date_from() == DateManager.str_to_date("2022-12-05")


def test_date_until():
    """The watched window ends on the run date itself."""
    date_manager = _manager("months", 3, run_date="2023-03-05")
    assert date_manager.get_date_until() == DateManager.str_to_date("2023-03-05")


def test_run_dates_daily_no_shift():
    """Daily schedule without shift: partition date equals execution date."""
    manager = _manager("weeks", 1)
    schedule = ScheduleConfig(frequency="daily")

    execution_dates, partition_dates = _scheduled_dates(manager, schedule)

    # 2026-05-13 .. 2026-05-20, both ends included.
    every_day = [date(2026, 5, day) for day in range(13, 21)]
    assert every_day == execution_dates
    assert every_day == partition_dates


def test_run_dates_weekly_backwards_shift():
    """Weekly schedule on day 5 with a +2 day shift: partitions trail by two days."""
    manager = _manager("months", 1)
    schedule = ScheduleConfig(frequency="weekly", day=5, info_date_shift=IntervalConfig(units="days", value=2))

    execution_dates, partition_dates = _scheduled_dates(manager, schedule)

    expected_execution_dates = [
        date(2026, 4, 24),
        date(2026, 5, 1),
        date(2026, 5, 8),
        date(2026, 5, 15),
    ]
    expected_partition_dates = [
        date(2026, 4, 22),
        date(2026, 4, 29),
        date(2026, 5, 6),
        date(2026, 5, 13),
    ]
    assert expected_execution_dates == execution_dates
    assert expected_partition_dates == partition_dates


def test_run_dates_monthly_with_forward_shift():
    """Monthly schedule on day 5 with a -2 day shift: partitions lead by two days."""
    manager = _manager("months", 3)
    schedule = ScheduleConfig(frequency="monthly", day=5, info_date_shift=IntervalConfig(units="days", value=-2))

    execution_dates, partition_dates = _scheduled_dates(manager, schedule)

    expected_execution_dates = [
        date(2026, 3, 5),
        date(2026, 4, 5),
        date(2026, 5, 5),
    ]
    expected_partition_dates = [
        date(2026, 3, 7),
        date(2026, 4, 7),
        date(2026, 5, 7),
    ]
    assert expected_execution_dates == execution_dates
    assert expected_partition_dates == partition_dates


def test_run_dates_monthly_with_double_shift():
    """A list of shifts is applied cumulatively (-2 days, then 2 months)."""
    manager = _manager("months", 3)
    schedule = ScheduleConfig(
        frequency="monthly",
        day=5,
        info_date_shift=[IntervalConfig(units="days", value=-2), IntervalConfig(units="months", value=2)],
    )

    execution_dates, partition_dates = _scheduled_dates(manager, schedule)

    expected_execution_dates = [
        date(2026, 3, 5),
        date(2026, 4, 5),
        date(2026, 5, 5),
    ]
    expected_partition_dates = [
        date(2026, 1, 7),
        date(2026, 2, 7),
        date(2026, 3, 7),
    ]
    assert expected_execution_dates == execution_dates
    assert expected_partition_dates == partition_dates


def test_run_dates_monthly_last():
    """``day="last"`` schedules the final calendar day of each month."""
    manager = _manager("months", 3)
    schedule = ScheduleConfig(frequency="monthly", day="last")

    execution_dates, partition_dates = _scheduled_dates(manager, schedule)

    month_ends = [
        date(2026, 2, 28),
        date(2026, 3, 31),
        date(2026, 4, 30),
    ]
    assert month_ends == execution_dates
    assert month_ends == partition_dates


def test_invalid_days():
    """Out-of-range day numbers are rejected for weekly and monthly schedules."""
    manager = _manager("months", 3)
    weekly_cfg = ScheduleConfig(frequency="weekly", day=12)
    monthly_cfg = ScheduleConfig(frequency="monthly", day=42)

    with pytest.raises(ValueError):
        manager.get_execution_and_partition_dates(schedule=weekly_cfg)

    with pytest.raises(ValueError):
        manager.get_execution_and_partition_dates(schedule=monthly_cfg)


def test_invalid_frequency():
    """An unknown frequency is rejected.

    The original repeated the identical ``pytest.raises`` block twice; a
    single check gives the same coverage.
    """
    manager = _manager("months", 3)
    bad_cfg = ScheduleConfig(frequency="abc", day=42)

    with pytest.raises(ValueError):
        manager.get_execution_and_partition_dates(schedule=bad_cfg)


# ---------------------------------------------------------------------------
# Next file in this patch: tests/runner/services/test_executor.py (new file)
# ---------------------------------------------------------------------------
# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``PipelineExecutor``: module loading, tool wiring and job execution."""

from datetime import date
from unittest.mock import MagicMock, patch

import pytest

from rialto.runner.services.executor import PipelineExecutor
from rialto.runner.services.task_registry import PipelineTask


def _make_executor():
    """Return an executor whose spark, reader and checker are all mocks."""
    collaborators = {
        "spark": MagicMock(),
        "reader": MagicMock(),
        "checker": MagicMock(),
    }
    return PipelineExecutor(**collaborators)


def _make_task():
    """Return a ``PipelineTask`` for ``SimpleGroup`` with no optional tools configured."""
    pipeline_cfg = MagicMock()
    pipeline_cfg.module = MagicMock()
    pipeline_cfg.module.python_module = "tests.runner.transformations"
    pipeline_cfg.module.python_class = "SimpleGroup"
    pipeline_cfg.metadata_manager = None
    pipeline_cfg.feature_loader = None

    return PipelineTask(
        name="SimpleGroup",
        execution_date=date(2026, 1, 1),
        partition_date=date(2026, 1, 1),
        config=pipeline_cfg,
        target=MagicMock(),
    )


def test_load_module_imports_and_instantiates_class():
    """``_load_module`` imports the configured module and instantiates the named class."""
    built_instance = MagicMock()
    job_class = MagicMock(return_value=built_instance)
    imported = MagicMock()
    imported.FakeClass = job_class

    module_cfg = MagicMock()
    module_cfg.python_module = "fake.module"
    module_cfg.python_class = "FakeClass"

    executor = _make_executor()
    with patch("rialto.runner.services.executor.import_module", return_value=imported) as importer:
        loaded = executor._load_module(module_cfg)

    importer.assert_called_once_with("fake.module")
    job_class.assert_called_once_with()
    assert loaded is built_instance


def test_init_tools_returns_none_when_not_configured():
    """Neither tool is created when the pipeline config leaves both unset."""
    executor = _make_executor()
    cfg = MagicMock()
    cfg.metadata_manager = None
    cfg.feature_loader = None

    tools = executor._init_tools(executor.spark, cfg)

    manager, loader = tools
    assert manager is None
    assert loader is None


def test_init_tools_creates_both_tools_when_configured():
    """Both tools are built from their config sections when those are present."""
    executor = _make_executor()

    cfg = MagicMock()
    cfg.metadata_manager = MagicMock()
    cfg.metadata_manager.metadata_schema = "meta_schema"
    cfg.feature_loader = MagicMock()
    cfg.feature_loader.feature_schema = "feature_schema"
    cfg.feature_loader.metadata_schema = "feature_meta_schema"

    with patch("rialto.runner.services.executor.MetadataManager") as manager_factory:
        with patch("rialto.runner.services.executor.PysparkFeatureLoader") as loader_factory:
            manager, loader = executor._init_tools(executor.spark, cfg)

    manager_factory.assert_called_once_with(executor.spark, "meta_schema")
    loader_factory.assert_called_once_with(
        executor.spark,
        feature_schema="feature_schema",
        metadata_schema="feature_meta_schema",
    )
    assert manager is manager_factory.return_value
    assert loader is loader_factory.return_value


def test_execute_calls_job_run_and_returns_df():
    """``execute`` loads the job, wires the tools and returns whatever ``run`` produced."""
    executor = _make_executor()
    task = _make_task()

    produced = MagicMock()
    job = MagicMock()
    job.run.return_value = produced

    with patch.object(executor, "_load_module", return_value=job) as loader_spy:
        with patch.object(executor, "_init_tools", return_value=(None, None)) as tools_spy:
            outcome = executor.execute(task)

    loader_spy.assert_called_once_with(task.config.module)
    tools_spy.assert_called_once_with(executor.spark, task.config)
    job.run.assert_called_once_with(
        spark=executor.spark,
        run_date=task.execution_date,
        config=task.config,
        reader=executor.reader,
        metadata_manager=None,
        feature_loader=None,
    )
    assert outcome is produced


def test_execute_logs_start_of_execution():
    """Exactly one info message announces the pipeline and its partition date."""
    executor = _make_executor()
    task = _make_task()

    job = MagicMock()
    job.run.return_value = MagicMock()

    with patch("rialto.runner.services.executor.logger.info") as info_spy:
        with patch.object(executor, "_load_module", return_value=job):
            with patch.object(executor, "_init_tools", return_value=(None, None)):
                executor.execute(task)

    info_spy.assert_called_once_with(f"Executing pipeline {task.name} for partition date {task.partition_date}")


def test_execute_raises_when_job_run_fails():
    """An exception raised by the job propagates out of ``execute`` unchanged."""
    executor = _make_executor()
    task = _make_task()

    failing_job = MagicMock()
    failing_job.run.side_effect = RuntimeError("job failed")

    with patch.object(executor, "_load_module", return_value=failing_job):
        with patch.object(executor, "_init_tools", return_value=(None, None)):
            with pytest.raises(RuntimeError, match="job failed"):
                executor.execute(task)


# ---------------------------------------------------------------------------
# Next file in this patch: tests/runner/services/test_result_mapper.py (new file)
# ---------------------------------------------------------------------------
# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``TaskResultMapper``: one record factory per task outcome."""
from datetime import date, datetime, timedelta
from unittest.mock import Mock

from rialto.runner.services.result_mapper import TaskResultMapper


def _make_task(name="my_pipeline", table_path="catalog.schema.table", partition_date=date(2026, 1, 1)):
    """Return a mock task exposing ``name``, ``partition_date`` and a target table path."""
    stub = Mock(partition_date=partition_date)
    # ``name`` cannot go through the constructor: there it would name the mock itself.
    stub.name = name
    stub.target.get_table_path.return_value = table_path
    return stub


def _run_start():
    """Return a start timestamp five seconds in the past."""
    elapsed = timedelta(seconds=5)
    return datetime.now() - elapsed


def _assert_identity(record, job, target, day):
    """Check the fields every record copies from its task."""
    assert record.job == job
    assert record.target == target
    assert record.date == day


# success


def test_success():
    """A successful run reports status, row count and a positive duration."""
    task = _make_task(name="pipe_a", table_path="cat.sch.tbl", partition_date=date(2026, 5, 1))

    record = TaskResultMapper.success(task, _run_start(), records_count=42)

    assert record.status == "Success"
    assert record.reason == "OK"
    assert record.exception is None
    assert record.records == 42
    _assert_identity(record, "pipe_a", "cat.sch.tbl", date(2026, 5, 1))
    assert isinstance(record.time, timedelta)
    assert record.time.total_seconds() > 0


# already_complete


def test_already_complete():
    """An already written partition is reported as skipped with zero records."""
    task = _make_task(name="pipe_b", table_path="cat.sch.tbl2", partition_date=date(2026, 3, 15))

    record = TaskResultMapper.already_complete(task, _run_start())

    assert record.status == "Skipped"
    assert record.reason == "AlreadyComplete"
    assert record.exception is None
    assert record.records == 0
    _assert_identity(record, "pipe_b", "cat.sch.tbl2", date(2026, 3, 15))


# dependencies_incomplete


def test_dependencies_incomplete_status_and_reason():
    """Missing dependencies yield a failed record that names each of them."""
    task = _make_task(name="pipe_b", table_path="cat.sch.tbl2", partition_date=date(2026, 3, 15))
    missing = ["cat.s.dep1 from 2026-01-01 until 2026-01-07", "cat.s.dep2 from 2026-01-01 until 2026-01-07"]

    record = TaskResultMapper.dependencies_incomplete(task, _run_start(), missing)

    assert record.status == "Failed"
    assert record.reason == "Dependencies Incomplete"
    assert record.records == 0
    _assert_identity(record, "pipe_b", "cat.sch.tbl2", date(2026, 3, 15))
    for description in missing:
        assert description in record.exception


def test_dependencies_incomplete_lists_failed_deps_in_exception():
    """Every failed dependency description appears in the exception text."""
    missing = ["cat.s.dep1 from 2026-01-01 until 2026-01-07", "cat.s.dep2 from 2026-01-01 until 2026-01-07"]

    record = TaskResultMapper.dependencies_incomplete(_make_task(), _run_start(), missing)

    for description in missing:
        assert description in record.exception


def test_dependencies_incomplete_empty_list_falls_back():
    """With no dependency descriptions the exception text is ``"Unknown"``."""
    nothing_listed = []
    record = TaskResultMapper.dependencies_incomplete(_make_task(), _run_start(), nothing_listed)
    assert record.exception == "Unknown"


# exception


def test_exception():
    """An error record carries the exception type as reason and the traceback text."""
    task = _make_task(name="pipe_c", table_path="cat.sch.tbl3", partition_date=date(2026, 4, 10))

    record = TaskResultMapper.exception(task, _run_start(), "ValueError", "Traceback...")

    assert record.status == "Error"
    assert record.reason == "ValueError"
    assert record.exception == "Traceback..."
    assert record.records == 0
    _assert_identity(record, "pipe_c", "cat.sch.tbl3", date(2026, 4, 10))


# interrupted


def test_interrupted():
    """A keyboard interrupt is recorded as an error without exception text."""
    task = _make_task(name="pipe_c", table_path="cat.sch.tbl3", partition_date=date(2026, 4, 10))

    record = TaskResultMapper.interrupted(task, _run_start())

    assert record.status == "Error"
    assert record.reason == "Keyboard Interrupt"
    assert record.exception is None
    assert record.records == 0
    _assert_identity(record, "pipe_c", "cat.sch.tbl3", date(2026, 4, 10))


# ---------------------------------------------------------------------------
# Next file in this patch: tests/runner/services/test_table.py (new file)
# ---------------------------------------------------------------------------
# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``Table``: path resolution from its various constructor inputs."""
from unittest.mock import Mock

from rialto.runner.services.config_loader import TargetConfig
from rialto.runner.services.table import Table


def _assert_paths(table, table_path, schema_path):
    """Check the fully qualified table path and its schema prefix."""
    assert table.get_table_path() == table_path
    assert table.get_schema_path() == schema_path


def _assert_parts(table, catalog, schema, name):
    """Check the three individual components of the table location."""
    assert table.catalog == catalog
    assert table.schema == schema
    assert table.table == name


def test_table_basic_init():
    """Explicit catalog, schema and table are joined into the paths."""
    subject = Table(catalog="cat", schema="sch", table="tab", schema_path=None, table_path=None, class_name=None)

    _assert_paths(subject, "cat.sch.tab", "cat.sch")


def test_table_classname_init():
    """With only a schema path, the table name is the snake-cased class name."""
    subject = Table(catalog=None, schema=None, table=None, schema_path="cat.sch", table_path=None, class_name="ClaSs")

    _assert_paths(subject, "cat.sch.cla_ss", "cat.sch")
    _assert_parts(subject, "cat", "sch", "cla_ss")


def test_table_path_init():
    """A full table path is split into catalog, schema and table."""
    subject = Table(catalog=None, schema=None, table=None, schema_path=None, table_path="cat.sch.tab", class_name=None)

    _assert_paths(subject, "cat.sch.tab", "cat.sch")
    _assert_parts(subject, "cat", "sch", "tab")


def test_table_secondary_partitions():
    """The main partition comes first, followed by the secondary ones in order."""
    subject = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"])

    columns = subject.get_all_partition_columns()
    assert columns == ["part", "sec1", "sec2"]


def test_table_get_partitions_only_main():
    """Without secondary partitions only the main partition is listed."""
    subject = Table(catalog="cat", schema="sch", table="tab", partition="part")

    columns = subject.get_all_partition_columns()
    assert columns == ["part"]


def test_table_prioritize_table_name():
    """An explicit table name wins over the name derived from the class."""
    subject = Table(
        catalog=None, schema=None, table="custom", schema_path="cat.sch", table_path=None, class_name="ClaSs"
    )

    _assert_paths(subject, "cat.sch.custom", "cat.sch")
    _assert_parts(subject, "cat", "sch", "custom")


def test_from_target_config():
    """``from_target_config`` combines the target section with the module class name."""
    target_section = TargetConfig(
        target_schema="cat.sch",
        target_partition_column="part",
        secondary_partition_columns=["sec1", "sec2"],
        custom_name=None,
        rerun_filters={"col": "value"},
    )
    pipeline_cfg = Mock()
    pipeline_cfg.module.python_class = "TestClass"
    pipeline_cfg.target = target_section

    subject = Table.from_target_config(pipeline_cfg)

    _assert_paths(subject, "cat.sch.test_class", "cat.sch")
    _assert_parts(subject, "cat", "sch", "test_class")
    assert subject.get_all_partition_columns() == ["part", "sec1", "sec2"]
    assert subject.filters == {"col": "value"}


def test_from_dependency_config():
    """``from_dependency_config`` maps table path, date column and filters."""
    dependency_cfg = Mock()
    dependency_cfg.table = "cat.sch.tab"
    dependency_cfg.date_col = "date"
    dependency_cfg.filters = {"col": "value"}

    subject = Table.from_dependency_config(dependency_cfg)

    _assert_paths(subject, "cat.sch.tab", "cat.sch")
    _assert_parts(subject, "cat", "sch", "tab")
    assert subject.partition == "date"
    assert subject.filters == {"col": "value"}


# ---------------------------------------------------------------------------
# Next file in this patch: tests/runner/services/test_task_registry.py (new file)
# ---------------------------------------------------------------------------
# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``TaskRegistry``: task registration, dependency expansion and status logging."""
from datetime import date
from unittest.mock import patch

import pytest

from rialto.runner.services.config_loader import (
    DependencyConfig,
    IntervalConfig,
    ModuleConfig,
    PipelineConfig,
    RunnerConfig,
    ScheduleConfig,
    TargetConfig,
)
from rialto.runner.services.date_manager import DateManager
from rialto.runner.services.task_registry import (
    PipelineDependency,
    PipelineTask,
    TaskRegistry,
)


@pytest.fixture(scope="module")
def date_manager():
    """A ``DateManager`` watching three months back from 2020-01-01."""
    three_months = RunnerConfig(watched_period_units="months", watched_period_value=3)
    return DateManager(three_months, "2020-01-01")


@pytest.fixture
def pipeline_config_no_deps():
    """A monthly pipeline with secondary partitions, rerun filters and no dependencies."""
    module = ModuleConfig(python_module="some.module", python_class="TestClass")
    schedule = ScheduleConfig(frequency="monthly", day=1)
    target = TargetConfig(
        target_schema="cat.sch",
        target_partition_column="part",
        secondary_partition_columns=["sec1", "sec2"],
        custom_name=None,
        rerun_filters={"col": "value"},
    )
    return PipelineConfig(name="test_pipeline", module=module, schedule=schedule, target=target)


@pytest.fixture
def pipeline_config_with_deps():
    """A monthly pipeline depending on one table over a one-month interval."""
    module = ModuleConfig(python_module="some.module", python_class="TestClass")
    schedule = ScheduleConfig(frequency="monthly", day=1)
    target = TargetConfig(
        target_schema="cat.sch",
        target_partition_column="part",
    )
    upstream = DependencyConfig(
        table="cat.sch.dep_table",
        date_col="part",
        interval=IntervalConfig(units="months", value=1),
    )
    return PipelineConfig(
        name="test_pipeline_with_deps",
        module=module,
        schedule=schedule,
        target=target,
        dependencies=[upstream],
    )


def test_registry_initializes_empty(spark, date_manager):
    """A fresh registry iterates over no tasks."""
    fresh = TaskRegistry(spark, date_manager)
    assert list(fresh) == []


def test_add_task_no_dependencies(spark, date_manager, pipeline_config_no_deps):
    """A registered task keeps its inputs and starts with every status flag unset."""
    registry = TaskRegistry(spark, date_manager)
    registry.add_task(
        name="test_pipeline",
        execution_date=date(2020, 1, 1),
        partition_date=date(2019, 12, 31),
        config=pipeline_config_no_deps,
    )

    registered = list(registry)
    assert len(registered) == 1

    only_task = registered[0]
    assert isinstance(only_task, PipelineTask)
    assert only_task.name == "test_pipeline"
    assert only_task.execution_date == date(2020, 1, 1)
    assert only_task.partition_date == date(2019, 12, 31)
    assert only_task.config is pipeline_config_no_deps
    assert only_task.dependencies == []
    assert only_task.completion is False
    assert only_task.dependencies_complete is False


def test_add_task_target_table(spark, date_manager, pipeline_config_no_deps):
    """The task's target table is derived from the pipeline's target section."""
    registry = TaskRegistry(spark, date_manager)
    registry.add_task(
        name="test_pipeline",
        execution_date=date(2020, 1, 1),
        partition_date=date(2019, 12, 31),
        config=pipeline_config_no_deps,
    )

    target = list(registry)[0].target
    assert target.catalog == "cat"
    assert target.schema == "sch"
    assert target.partition == "part"
    assert target.secondary_partitions == ["sec1", "sec2"]
    assert target.filters == {"col": "value"}


def test_add_task_with_dependencies(spark, date_manager, pipeline_config_with_deps):
    """Each configured dependency becomes a ``PipelineDependency`` with a resolved date window."""
    registry = TaskRegistry(spark, date_manager)
    registry.add_task(
        name="test_pipeline_with_deps",
        execution_date=date(2020, 1, 1),
        partition_date=date(2019, 12, 31),
        config=pipeline_config_with_deps,
    )

    dependencies = list(registry)[0].dependencies
    assert len(dependencies) == 1

    upstream = dependencies[0]
    assert isinstance(upstream, PipelineDependency)
    assert upstream.table.get_table_path() == "cat.sch.dep_table"
    assert upstream.date_until == date(2020, 1, 1)
    assert upstream.date_from == date(2019, 12, 1)  # one month before the execution date
    assert upstream.complete is False


def test_add_multiple_tasks(spark, date_manager, pipeline_config_no_deps, pipeline_config_with_deps):
    """Tasks are kept in registration order."""
    registry = TaskRegistry(spark, date_manager)
    registry.add_task("pipeline_a", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps)
    registry.add_task("pipeline_b", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_with_deps)

    registered = list(registry)
    assert len(registered) == 2
    assert registered[0].name == "pipeline_a"
    assert registered[1].name == "pipeline_b"


def test_iteration(spark, date_manager, pipeline_config_no_deps):
    """Iterating the registry yields the tasks themselves, in order."""
    registry = TaskRegistry(spark, date_manager)
    registry.add_task("p1", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps)
    registry.add_task("p2", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps)

    seen = []
    for task in registry:
        seen.append(task.name)
    assert seen == ["p1", "p2"]


def test_log_status_contains_task_info(spark, date_manager, pipeline_config_no_deps):
    """``log_status`` emits one info message with the task, its date and both status marks."""
    registry = TaskRegistry(spark, date_manager)
    registry.add_task("test_pipeline", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps)
    # One flag set and one left unset, so both marks appear in the output.
    registry.tasks[0].completion = True

    with patch("rialto.runner.services.task_registry.logger") as logger_double:
        registry.log_status()

    logger_double.info.assert_called_once()
    positional_args = logger_double.info.call_args[0]
    message = positional_args[0]

    assert "test_pipeline" in message
    assert "2019-12-31" in message
    assert "✔" in message
    assert "✘" in message


# ---------------------------------------------------------------------------
# Next file in this patch: tests/runner/services/test_task_status_checker.py (new file)
# ---------------------------------------------------------------------------
# Copyright 2022-2026 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +from datetime import date +from unittest.mock import Mock, patch + +import pytest + +from rialto.runner.services.table import Table +from rialto.runner.services.task_registry import PipelineDependency +from rialto.runner.services.task_status_checker import TaskStatusChecker + + +def make_task(name="my_pipeline", partition_date=date(2020, 1, 1), dependencies=None): + task = Mock() + task.name = name + task.partition_date = partition_date + task.target = Table(schema_path="cat.sch", class_name="TestClass", partition="part") + task.dependencies = dependencies or [] + task.completion = False + task.dependencies_complete = False + return task + + +def make_dependency(table_path="cat.sch.dep_table", date_from=date(2019, 12, 1), date_until=date(2020, 1, 1)): + table = Table(table_path=table_path, partition="part") + return PipelineDependency(table=table, date_from=date_from, date_until=date_until) + + +@pytest.fixture +def mock_data_checker(): + return Mock() # Mock of DataChecker + + +@pytest.fixture +def status_checker(mock_data_checker): + return TaskStatusChecker(checker=mock_data_checker) + + +def test_check_completion_sets_true_when_data_exists(status_checker, mock_data_checker): + mock_data_checker.check_date.return_value = True + task = make_task() + + status_checker.check_completion(task) + + assert task.completion is True + mock_data_checker.check_date.assert_called_once_with(task.target, task.partition_date) + + +def test_check_completion_sets_false_when_no_data(status_checker, mock_data_checker): + mock_data_checker.check_date.return_value = False + task = make_task() + + status_checker.check_completion(task) + + assert task.completion is False + + +# --- check_pipeline_dependencies --- + + +def test_check_pipeline_dependencies_no_deps(status_checker, mock_data_checker): + task = make_task(dependencies=[]) + + status_checker.check_pipeline_dependencies(task) + + assert task.dependencies_complete is True # all([]) == True + 
mock_data_checker.check_range.assert_not_called() + + +def test_check_pipeline_dependencies_all_complete(status_checker, mock_data_checker): + mock_data_checker.check_range.return_value = True + dep1 = make_dependency("cat.sch.table_a") + dep2 = make_dependency("cat.sch.table_b") + task = make_task(dependencies=[dep1, dep2]) + + status_checker.check_pipeline_dependencies(task) + + assert dep1.complete is True + assert dep2.complete is True + assert task.dependencies_complete is True + assert mock_data_checker.check_range.call_count == 2 + + +def test_check_pipeline_dependencies_one_incomplete(status_checker, mock_data_checker): + mock_data_checker.check_range.side_effect = [True, False] + dep1 = make_dependency("cat.sch.table_a") + dep2 = make_dependency("cat.sch.table_b") + task = make_task(dependencies=[dep1, dep2]) + + status_checker.check_pipeline_dependencies(task) + + assert dep1.complete is True + assert dep2.complete is False + assert task.dependencies_complete is False + + +def test_check_pipeline_dependencies_passes_correct_dates(status_checker, mock_data_checker): + mock_data_checker.check_range.return_value = True + dep = make_dependency(date_from=date(2019, 10, 1), date_until=date(2020, 1, 1)) + task = make_task(dependencies=[dep]) + + status_checker.check_pipeline_dependencies(task) + + mock_data_checker.check_range.assert_called_once_with(dep.table, date(2019, 10, 1), date(2020, 1, 1)) + + +def test_check_completion_logs_status(): + checker = Mock() + checker.check_date.return_value = True + task = make_task(name="logged_pipeline", partition_date=date(2020, 3, 1)) + status_checker = TaskStatusChecker(checker=checker) + + with patch("rialto.runner.services.task_status_checker.logger") as mock_logger: + status_checker.check_completion(task) + + mock_logger.info.assert_called_once() + log_msg = mock_logger.info.call_args[0][0] + assert "logged_pipeline" in log_msg + assert "2020-03-01" in log_msg + assert "True" in log_msg diff --git 
a/tests/runner/services/test_writer.py b/tests/runner/services/test_writer.py new file mode 100644 index 0000000..818d9fc --- /dev/null +++ b/tests/runner/services/test_writer.py @@ -0,0 +1,244 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import date, datetime +from unittest.mock import MagicMock, Mock, patch + +import pytest +from pyspark.sql import Row +from pyspark.sql.types import StringType, StructField, StructType + +from rialto.runner.services.table import Table +from rialto.runner.services.writer import DatabricksWriter + + +@pytest.fixture +def writer(spark): + return DatabricksWriter(spark, merge_schema=False) + + +@pytest.fixture +def writer_merge(spark): + return DatabricksWriter(spark, merge_schema=True) + + +@pytest.fixture +def simple_table(): + return Table(schema_path="default.test_schema", class_name="MyTable", partition="info_date") + + +@pytest.fixture +def table_with_secondary(simple_table): + simple_table.secondary_partitions = ["region"] + return simple_table + + +# --- _align_schema --- + + +def test_align_schema_no_existing_columns_returns_df_unchanged(spark, writer): + df = spark.createDataFrame([Row(a=1, b=2)]) + result = writer._align_schema(df, None) + assert result.columns == df.columns + + +def test_align_schema_reorders_to_existing_columns(spark, writer): + df = spark.createDataFrame([Row(a=1, b=2, c=3)]) + result = writer._align_schema(df, ["c", "a", "b"]) + assert 
result.columns == ["c", "a", "b"] + + +def test_align_schema_new_columns_appended_after_existing(spark, writer): + df = spark.createDataFrame([Row(a=1, b=2, new_col=3)]) + result = writer._align_schema(df, ["a", "b"]) + assert result.columns == ["a", "b", "new_col"] + + +def test_align_schema_missing_existing_columns_raises_value_error(spark, writer): + # column "gone" is in existing but not in df — should raise ValueError + df = spark.createDataFrame([Row(a=1, b=2)]) + with pytest.raises(ValueError): + writer._align_schema(df, ["gone", "a", "b"]) + + +# --- _get_replace_condition --- + + +def test_get_replace_condition_string_value(spark, writer): + df = spark.createDataFrame([Row(info_date=date(2020, 1, 1))]) + target = Table(schema_path="default.test_schema", class_name="MyTable", partition="info_date") + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "info_date = '2020-01-01'" + + +def test_get_replace_condition_second_value_null(spark, writer): + df = spark.createDataFrame( + [Row(information_date="2020-01-01", region=None)], + schema=StructType( + [ + StructField("information_date", StringType(), nullable=False), + StructField("region", StringType(), nullable=True), + ] + ), + ) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], + ) + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "information_date = '2020-01-01' AND region IS NULL" + + +def test_get_replace_condition_second_value_no_filters(spark, writer): + df = spark.createDataFrame([Row(information_date="2020-01-01", region=1)]) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], + ) + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "information_date = '2020-01-01' AND 
region = 1" + + +def test_get_replace_condition_second_value_with_filters(spark, writer): + df = spark.createDataFrame([Row(information_date="2020-01-01", region=1)]) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], + filters={"region": 1}, + ) + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "information_date = '2020-01-01' AND region = 1" + + +def test_get_replace_condition_raises_on_multiple_distinct_values(spark, writer): + df = spark.createDataFrame( + [Row(information_date="2020-01-01", region=1), Row(information_date="2020-01-01", region=2)] + ) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], + ) + with pytest.raises(ValueError, match="more than 1 distinct value"): + writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + + +# --- _process --- + + +def test_process_adds_partition_column(spark, writer, simple_table): + df = spark.createDataFrame([Row(a=1)]) + with patch.object(writer, "_get_existing_columns", return_value=None): + result = writer._process(df, date(2020, 1, 1), simple_table) + assert "info_date" in result.columns + assert result.collect()[0]["info_date"] == date(2020, 1, 1) + + +def test_get_existing_columns_returns_columns(): + # Arrange + spark = Mock() + table = Mock() + table.get_table_path.return_value = "catalog.schema.tbl" + + spark_df = Mock() + spark_df.columns = ["id", "name"] + spark.table.return_value = spark_df + + writer = DatabricksWriter(spark=spark) + result = writer._get_existing_columns(table) + + assert result == ["id", "name"] + table.get_table_path.assert_called_once_with() + spark.table.assert_called_once_with("catalog.schema.tbl") + + +def test_get_existing_columns_returns_none_on_exception(): + # Arrange + spark = Mock() + table = Mock() + 
table.get_table_path.return_value = "catalog.schema.tbl" + spark.table.side_effect = Exception("table not found") + + writer = DatabricksWriter(spark=spark) + + # Patch the module-level logger used in writer.py + with patch("rialto.runner.services.writer.logger.warning") as warning_mock: + # Act + result = writer._get_existing_columns(table) + + # Assert + assert result is None + table.get_table_path.assert_called() # called at least once (try + except log message) + spark.table.assert_called_once_with("catalog.schema.tbl") + warning_mock.assert_called_once() + + +# --- write (integration of internal steps) --- + + +def test_write_calls_create_schema(spark, writer, simple_table): + with patch.object(writer, "_create_schema") as mock_create, patch.object(writer, "_process"), patch.object( + writer, "_get_replace_condition" + ): + df = Mock() + df.write = Mock() + writer.write(df, Mock(), simple_table) + + mock_create.assert_called_once_with(simple_table) + + +def test_create_schema_uses_table_schema_path(): + # Arrange + spark = Mock() + table = Mock() + table.get_schema_path.return_value = "my_catalog.my_schema" + + writer = DatabricksWriter(spark=spark) + writer._create_schema(table) + + spark.sql.assert_called_once_with("CREATE SCHEMA IF NOT EXISTS my_catalog.my_schema") + table.get_schema_path.assert_called_once_with() + + +def test_write_merge_schema_option(spark, writer_merge, simple_table): + df = MagicMock() + df.write = MagicMock() + + with patch.object(writer_merge, "_create_schema"), patch.object( + writer_merge, "_process", return_value=df + ), patch.object(writer_merge, "_get_replace_condition"): + writer_merge.write(df, date(2020, 1, 1), simple_table) + + option_calls = df.write.format.return_value.partitionBy.return_value.mode.return_value.option.call_args_list + assert any(call.args == ("mergeSchema", "true") for call in option_calls) + + +def test_write_not_merge_schema_option(spark, writer, simple_table): + df = MagicMock() + df.write = MagicMock() 
+ + with patch.object(writer, "_create_schema"), patch.object(writer, "_process", return_value=df), patch.object( + writer, "_get_replace_condition" + ): + writer.write(df, date(2020, 1, 1), simple_table) + + option_calls = df.write.format.return_value.partitionBy.return_value.mode.return_value.option.call_args_list + assert any(call.args == ("mergeSchema", "false") for call in option_calls) diff --git a/tests/runner/test_bookkeeping.py b/tests/runner/test_bookkeeping.py deleted file mode 100644 index 06507d5..0000000 --- a/tests/runner/test_bookkeeping.py +++ /dev/null @@ -1,29 +0,0 @@ -from datetime import datetime, timedelta - -from rialto.runner.date_manager import DateManager -from rialto.runner.reporting.record import Record - -record = Record( - "job", - "target", - DateManager.str_to_date("2024-01-01"), - timedelta(days=0, hours=1, minutes=2, seconds=3), - 1, - "status", - "reason", - None, - datetime(2024, 1, 1, 1, 2, 3), -) - - -def test_record_to_spark(spark): - row = record.to_spark_row() - assert row.job == "job" - assert row.target == "target" - assert row.date == DateManager.str_to_date("2024-01-01") - assert row.time == "1:02:03" - assert row.records == 1 - assert row.status == "status" - assert row.reason == "reason" - assert row.exception is None - assert row.run_timestamp == datetime(2024, 1, 1, 1, 2, 3) diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py deleted file mode 100644 index 73b61b8..0000000 --- a/tests/runner/test_date_manager.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from datetime import datetime - -import pytest - -from rialto.runner.config_loader import IntervalConfig, ScheduleConfig -from rialto.runner.date_manager import DateManager - - -def test_str_to_date(): - assert DateManager.str_to_date("2023-03-05") == datetime.strptime("2023-03-05", "%Y-%m-%d").date() - - -@pytest.mark.parametrize( - "units , value, res", - [("days", 7, "2023-02-26"), ("weeks", 3, "2023-02-12"), ("months", 5, "2022-10-05"), ("years", 2, "2021-03-5")], -) -def test_date_from(units, value, res): - rundate = DateManager.str_to_date("2023-03-05") - date_from = DateManager.date_subtract(run_date=rundate, units=units, value=value) - assert date_from == DateManager.str_to_date(res) - - -def test_date_from_bad(): - rundate = DateManager.str_to_date("2023-03-05") - with pytest.raises(ValueError) as exception: - DateManager.date_subtract(run_date=rundate, units="random", value=1) - assert str(exception.value) == "Unknown time unit random" - - -def test_all_dates(): - all_dates = DateManager.all_dates( - date_from=DateManager.str_to_date("2023-02-05"), - date_to=DateManager.str_to_date("2023-04-12"), - ) - assert len(all_dates) == 67 - assert all_dates[1] == DateManager.str_to_date("2023-02-06") - - -def test_all_dates_reversed(): - all_dates = DateManager.all_dates( - date_from=DateManager.str_to_date("2023-04-12"), - date_to=DateManager.str_to_date("2023-02-05"), - ) - assert len(all_dates) == 67 - assert all_dates[1] == DateManager.str_to_date("2023-02-06") - - -def test_run_dates_weekly(): - cfg = ScheduleConfig(frequency="weekly", day=5) - - 
run_dates = DateManager.run_dates( - date_from=DateManager.str_to_date("2023-02-05"), - date_to=DateManager.str_to_date("2023-04-07"), - schedule=cfg, - ) - - expected = [ - "2023-02-10", - "2023-02-17", - "2023-02-24", - "2023-03-03", - "2023-03-10", - "2023-03-17", - "2023-03-24", - "2023-03-31", - "2023-04-07", - ] - expected = [DateManager.str_to_date(d) for d in expected] - assert run_dates == expected - - -def test_run_dates_monthly(): - cfg = ScheduleConfig(frequency="monthly", day=5) - - run_dates = DateManager.run_dates( - date_from=DateManager.str_to_date("2022-08-05"), - date_to=DateManager.str_to_date("2023-04-07"), - schedule=cfg, - ) - - expected = [ - "2022-08-05", - "2022-09-05", - "2022-10-05", - "2022-11-05", - "2022-12-05", - "2023-01-05", - "2023-02-05", - "2023-03-05", - "2023-04-05", - ] - expected = [DateManager.str_to_date(d) for d in expected] - assert run_dates == expected - - -def test_run_dates_daily(): - cfg = ScheduleConfig(frequency="daily") - - run_dates = DateManager.run_dates( - date_from=DateManager.str_to_date("2023-03-28"), - date_to=DateManager.str_to_date("2023-04-03"), - schedule=cfg, - ) - - expected = [ - "2023-03-28", - "2023-03-29", - "2023-03-30", - "2023-03-31", - "2023-04-01", - "2023-04-02", - "2023-04-03", - ] - expected = [DateManager.str_to_date(d) for d in expected] - assert run_dates == expected - - -def test_run_dates_invalid(): - cfg = ScheduleConfig(frequency="random") - with pytest.raises(ValueError) as exception: - DateManager.run_dates( - date_from=DateManager.str_to_date("2023-03-28"), - date_to=DateManager.str_to_date("2023-04-03"), - schedule=cfg, - ) - assert str(exception.value) == "Unknown frequency random" - - -@pytest.mark.parametrize( - "shift, res", - [(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")], -) -def test_to_info_date(shift, res): - cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)]) - base = 
DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) - assert DateManager.str_to_date(res) == info - - -@pytest.mark.parametrize( - "unit, result", - [("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")], -) -def test_info_date_shift_units(unit, result): - cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)]) - base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) - assert DateManager.str_to_date(result) == info - - -def test_info_date_shift_combined(): - cfg = ScheduleConfig( - frequency="daily", - info_date_shift=[IntervalConfig(units="months", value=3), IntervalConfig(units="days", value=4)], - ) - base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) - assert DateManager.str_to_date("2022-12-01") == info diff --git a/tests/runner/test_engine.py b/tests/runner/test_engine.py new file mode 100644 index 0000000..db74f8e --- /dev/null +++ b/tests/runner/test_engine.py @@ -0,0 +1,450 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from datetime import date +from unittest.mock import MagicMock, call, patch + +import pytest + +from rialto.runner.engine import RunnerEngine +from rialto.runner.services.config_loader import ( + ModuleConfig, + PipelineConfig, + ScheduleConfig, +) + + +def _pipeline(name="p1", schedule="weekly"): + sch_cfg = ScheduleConfig(frequency=schedule, day=1) + return PipelineConfig(name=name, schedule=sch_cfg, module=ModuleConfig(python_module="mod", python_class="Class")) + + +def _dependency(complete=True, table_path="src.sch.dep", date_from=date(2026, 1, 1), date_until=date(2026, 1, 7)): + dep = MagicMock() + dep.complete = complete + dep.date_from = date_from + dep.date_until = date_until + dep.table.get_table_path.return_value = table_path + return dep + + +def _task( + name="p1", + completion=False, + dependencies_complete=True, + precheck_failed=False, + partition_date=date(2026, 1, 8), + execution_date=date(2026, 1, 8), + deps=None, +): + t = MagicMock() + t.name = name + t.completion = completion + t.dependencies_complete = dependencies_complete + t.precheck_failed = precheck_failed + t.partition_date = partition_date + t.execution_date = execution_date + t.dependencies = deps if deps is not None else [] + t.target = MagicMock() + return t + + +def _services(): + s = MagicMock() + s.config = MagicMock() + s.config.pipelines = [_pipeline("p1"), _pipeline("p2")] + s.date_manager = MagicMock() + s.registry = MagicMock() + s.registry.tasks = [] + s.task_checker = MagicMock() + s.executor = MagicMock() + s.writer = MagicMock() + s.data_checker = MagicMock() + s.tracker = MagicMock() + return s + + +# ---- select_pipelines ------------------------------------------------------ + + +def test_select_pipelines_returns_all_when_op_none(): + services = _services() + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + result = engine.select_pipelines(None) + + assert result == services.config.pipelines + + +def 
test_select_pipelines_returns_matching_op(): + services = _services() + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + result = engine.select_pipelines("p2") + + assert len(result) == 1 + assert result[0].name == "p2" + + +def test_select_pipelines_raises_for_unknown_op(): + services = _services() + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with pytest.raises(ValueError, match="Unknown operation selected: nope"): + engine.select_pipelines("nope") + + +# ---- register_tasks -------------------------------------------------------- + + +def test_register_tasks_adds_task_for_each_execution_partition_pair(): + services = _services() + services.date_manager.get_execution_and_partition_dates.return_value = [ + (date(2026, 1, 10), date(2026, 1, 8)), + (date(2026, 1, 17), date(2026, 1, 15)), + ] + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + pipelines = [_pipeline("p1", "weekly")] + + engine.register_tasks(pipelines) + + assert services.registry.add_task.call_count == 2 + services.registry.add_task.assert_has_calls( + [ + call( + name="p1", + execution_date=date(2026, 1, 10), + partition_date=date(2026, 1, 8), + config=pipelines[0], + ), + call( + name="p1", + execution_date=date(2026, 1, 17), + partition_date=date(2026, 1, 15), + config=pipelines[0], + ), + ] + ) + + +# ---- check_tasks ----------------------------------------------------------- + + +def test_check_tasks_calls_both_checks_by_default(): + services = _services() + t1 = _task(name="p1") + t2 = _task(name="p2") + services.registry.tasks = [t1, t2] + + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + engine.check_tasks() + + services.task_checker.check_completion.assert_has_calls([call(t1), call(t2)]) + services.task_checker.check_pipeline_dependencies.assert_has_calls([call(t1), call(t2)]) + + +def test_check_tasks_skips_completion_when_rerun_true(): + services = 
_services() + t = _task() + services.registry.tasks = [t] + + engine = RunnerEngine(services=services, rerun=True, skip_dependencies=False) + engine.check_tasks() + + services.task_checker.check_completion.assert_not_called() + services.task_checker.check_pipeline_dependencies.assert_called_once_with(t) + + +def test_check_tasks_skips_dependencies_when_skip_dependencies_true(): + services = _services() + t = _task() + services.registry.tasks = [t] + + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=True) + engine.check_tasks() + + services.task_checker.check_completion.assert_called_once_with(t) + services.task_checker.check_pipeline_dependencies.assert_not_called() + + +def test_check_completion_records_exception_and_sets_precheck_failed(): + services = _services() + t = _task() + services.registry.tasks = [t] + services.task_checker.check_completion.side_effect = RuntimeError("boom") + + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + engine.check_tasks() + + services.task_checker.check_completion.assert_called_once_with(t) + assert t.precheck_failed is True + assert t.error == "boom" + assert "RuntimeError: boom" in t.error_trace + + +def test_check_dependencies_records_exception_and_sets_precheck_failed(): + services = _services() + t = _task() + services.registry.tasks = [t] + services.task_checker.check_pipeline_dependencies.side_effect = RuntimeError("boom") + engine = RunnerEngine(services=services, rerun=True, skip_dependencies=False) + engine.check_tasks() + + services.task_checker.check_pipeline_dependencies.assert_called_once_with(t) + assert t.precheck_failed is True + assert t.error == "boom" + assert "RuntimeError: boom" in t.error_trace + + +# ---- run_tasks ------------------------------------------------------------- + + +def test_run_tasks_calls_execute_with_tracking_for_each_task(): + services = _services() + t1 = _task(name="p1") + t2 = _task(name="p2") + services.registry.tasks = [t1, 
t2] + + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch.object(engine, "_execute_task_with_tracking") as exec_track, patch( + "rialto.runner.engine.logger.info" + ) as log_info: + engine.run_tasks() + + exec_track.assert_has_calls([call(t1), call(t2)]) + assert exec_track.call_count == 2 + assert log_info.call_count == 2 + + +# ---- _execute_task_with_tracking branches ---------------------------------- + + +def test_execute_task_with_tracking_records_precheck_failure_and_skips_execution(): + services = _services() + task = _task(precheck_failed=True) + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch("rialto.runner.engine.TaskResultMapper.exception", return_value="rec") as mapper: + engine._execute_task_with_tracking(task) + + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + services.executor.execute.assert_not_called() + + +def test_execute_task_with_tracking_skips_already_complete(): + services = _services() + task = _task(completion=True, dependencies_complete=True, precheck_failed=False) + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch("rialto.runner.engine.TaskResultMapper.already_complete", return_value="rec") as mapper, patch( + "rialto.runner.engine.logger.info" + ): + engine._execute_task_with_tracking(task) + + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + services.executor.execute.assert_not_called() + + +def test_execute_task_with_tracking_skips_incomplete_dependencies(): + services = _services() + deps = [_dependency(complete=False, table_path="src.sch.dep1")] + task = _task(completion=False, dependencies_complete=False, deps=deps) + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch("rialto.runner.engine.TaskResultMapper.dependencies_incomplete", return_value="rec") as mapper, patch( + 
"rialto.runner.engine.logger.info" + ): + engine._execute_task_with_tracking(task) + + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + services.executor.execute.assert_not_called() + + +def test_execute_task_with_tracking_success_path(): + services = _services() + task = _task(completion=False, dependencies_complete=True) + services.executor.execute.return_value = "df" + services.data_checker.check_written.return_value = 123 + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch("rialto.runner.engine.TaskResultMapper.success", return_value="rec") as mapper, patch( + "rialto.runner.engine.logger.info" + ): + engine._execute_task_with_tracking(task) + + services.executor.execute.assert_called_once_with(task) + services.writer.write.assert_called_once_with("df", task.partition_date, task.target) + services.data_checker.check_written.assert_called_once_with(task.target, task.partition_date, "df") + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + + +def test_execute_task_with_tracking_exception_path(): + services = _services() + task = _task(completion=False, dependencies_complete=True) + services.executor.execute.side_effect = RuntimeError("boom") + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with ( + patch("rialto.runner.engine.TaskResultMapper.exception", return_value="rec") as mapper, + patch("rialto.runner.engine.logger.exception") as log_exc, + ): + engine._execute_task_with_tracking(task) + + log_exc.assert_called_once() + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + services.writer.write.assert_not_called() + + +def test_execute_task_with_tracking_keyboard_interrupt_records_and_reraises(): + services = _services() + task = _task(completion=False, dependencies_complete=True) + services.executor.execute.side_effect = KeyboardInterrupt() + engine = 
RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch("rialto.runner.engine.TaskResultMapper.interrupted", return_value="rec") as mapper: + with pytest.raises(KeyboardInterrupt): + engine._execute_task_with_tracking(task) + + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + + +def test_execute_task_with_tracking_runs_when_skip_dependencies_true_even_if_incomplete(): + services = _services() + deps = [_dependency(complete=False)] + task = _task(completion=False, dependencies_complete=False, deps=deps) + services.executor.execute.return_value = "df" + services.data_checker.check_written.return_value = 1 + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=True) + + with patch("rialto.runner.engine.TaskResultMapper.success", return_value="rec") as success_mapper, patch( + "rialto.runner.engine.TaskResultMapper.dependencies_incomplete" + ) as dep_mapper, patch("rialto.runner.engine.logger.info"): + engine._execute_task_with_tracking(task) + + dep_mapper.assert_not_called() + success_mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + + +def test_execute_task_with_tracking_runs_when_rerun_true_even_if_completion_true(): + services = _services() + task = _task(completion=True, dependencies_complete=True) + services.executor.execute.return_value = "df" + services.data_checker.check_written.return_value = 10 + engine = RunnerEngine(services=services, rerun=True, skip_dependencies=False) + + with patch("rialto.runner.engine.TaskResultMapper.success", return_value="rec") as mapper, patch( + "rialto.runner.engine.logger.info" + ): + engine._execute_task_with_tracking(task) + + mapper.assert_called_once() + services.tracker.add.assert_called_once_with("rec") + + +# ---- wrappers -------------------------------------------------------------- + + +def test_finalize_calls_tracker_report_by_mail_and_log(): + services = _services() + engine = 
RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch.object(engine, "log_task_status") as log_task_status: + engine.finalize() + + services.tracker.report_by_mail.assert_called_once_with() + log_task_status.assert_called_once_with() + + +def test_run_calls_full_flow(): + services = _services() + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch.object(engine, "select_pipelines", return_value=["pipes"]) as select_pipes, patch.object( + engine, "register_tasks" + ) as register_tasks, patch.object(engine, "check_tasks") as check_tasks, patch.object( + engine, "log_task_status" + ) as log_status, patch.object( + engine, "run_tasks" + ) as run_tasks, patch.object( + engine, "finalize" + ) as finalize: + engine.run("my_op") + + select_pipes.assert_called_once_with("my_op") + register_tasks.assert_called_once_with(["pipes"]) + check_tasks.assert_called_once_with() + log_status.assert_called_once_with() + run_tasks.assert_called_once_with() + finalize.assert_called_once_with() + + +def test_dry_run_execution_calls_expected_flow_without_run_or_finalize(): + services = _services() + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with patch.object(engine, "select_pipelines", return_value=["pipes"]) as select_pipes, patch.object( + engine, "register_tasks" + ) as register_tasks, patch.object(engine, "check_tasks") as check_tasks, patch.object( + engine, "log_task_status" + ) as log_status, patch.object( + engine, "run_tasks" + ) as run_tasks, patch.object( + engine, "finalize" + ) as finalize: + engine.dry_run_execution("my_op") + + select_pipes.assert_called_once_with("my_op") + register_tasks.assert_called_once_with(["pipes"]) + check_tasks.assert_called_once_with() + log_status.assert_called_once_with() + run_tasks.assert_not_called() + finalize.assert_not_called() + + +def test_debug_first_task_registers_and_executes_first_task(): + services = _services() + t1 = 
_task(name="first") + t2 = _task(name="second") + services.registry.tasks = [t1, t2] + services.executor.execute.return_value = "df_debug" + + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + with ( + patch.object(engine, "select_pipelines", return_value=["pipes"]) as select_pipes, + patch.object(engine, "register_tasks") as register_tasks, + ): + result = engine.debug_first_task("my_op") + + select_pipes.assert_called_once_with("my_op") + register_tasks.assert_called_once_with(["pipes"]) + services.executor.execute.assert_called_once_with(t1) + assert result == "df_debug" + + +# ---- logging --------------------------------------------------------------- +def test_log_task_status_calls_registry_log_status(): + services = _services() + engine = RunnerEngine(services=services, rerun=False, skip_dependencies=False) + + engine.log_task_status() + + services.registry.log_status.assert_called_once_with() diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index cbb4b7b..b07d9d9 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -1,4 +1,4 @@ -# Copyright 2022 ABSA Group Limited +# Copyright 2022-2026 ABSA Group Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,320 +11,117 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest -from pyspark.sql import DataFrame -import rialto.runner.utils as utils -from rialto.common.table_reader import TableReader -from rialto.runner.runner import DateManager, Runner -from rialto.runner.table import Table -from tests.runner.runner_resources import ( - dep1_data, - dep2_data, - general_schema, - multi_part_data, - multi_schema, - simple_group_data, -) -from tests.runner.transformations.simple_group import SimpleGroup +from unittest.mock import MagicMock, patch +from rialto.runner.runner import Runner -class MockReader(TableReader): - def __init__(self, spark): - self.spark = spark - def _get_raw_data(self, table: str) -> DataFrame: - if table == "catalog.schema.simple_group": - return self.spark.createDataFrame(simple_group_data, general_schema) - if table == "source.schema.dep1": - return self.spark.createDataFrame(dep1_data, general_schema) - if table == "source.schema.dep2": - return self.spark.createDataFrame(dep2_data, general_schema) - if table == "source.schema.multi_part_data": - return self.spark.createDataFrame(multi_part_data, multi_schema) +def test_runner_init_builds_default_services_when_not_provided(): + spark = MagicMock() + built_services = MagicMock() + with ( + patch("rialto.runner.runner.DefaultRunnerServices.build", return_value=built_services) as build_mock, + patch("rialto.runner.runner.RunnerEngine") as engine_cls, + ): + runner = Runner( + spark=spark, + config_path="tests/runner/transformations/config.yaml", + run_date="2026-01-01", + rerun=True, + op="SimpleGroup", + skip_dependencies=True, + overrides={"runner": {"watched_period_units": "days", "watched_period_value": 7}}, + merge_schema=True, + ) -def test_table_exists(spark, mocker): - mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True) - utils.table_exists(spark, "abc") - mock.assert_called_once_with("abc") - - -def test_load_module(spark, basic_runner): - module = utils.load_module(basic_runner.config.pipelines[0].module) - assert 
isinstance(module, SimpleGroup) - - -def test_generate(spark, mocker, basic_runner): - run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") - group = SimpleGroup() - config = basic_runner.config.pipelines[0] - basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), config) - - run.assert_called_once_with( - reader=basic_runner.reader, - run_date=DateManager.str_to_date("2023-01-31"), + build_mock.assert_called_once_with( spark=spark, - config=config, - metadata_manager=None, - feature_loader=None, - ) - - -def test_generate_w_dep(spark, mocker, basic_runner): - run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") - group = SimpleGroup() - basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) - run.assert_called_once_with( - reader=basic_runner.reader, - run_date=DateManager.str_to_date("2023-01-31"), - spark=spark, - config=basic_runner.config.pipelines[2], - metadata_manager=None, - feature_loader=None, - ) - - -def test_init_dates(spark): - runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") - assert runner.date_from == DateManager.str_to_date("2023-01-31") - assert runner.date_until == DateManager.str_to_date("2023-03-31") - - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-31", - overrides={"runner.watched_period_units": "weeks", "runner.watched_period_value": 2}, - ) - assert runner.date_from == DateManager.str_to_date("2023-03-17") - assert runner.date_until == DateManager.str_to_date("2023-03-31") - - runner = Runner( - spark, - config_path="tests/runner/transformations/config2.yaml", - run_date="2023-03-31", - ) - assert runner.date_from == DateManager.str_to_date("2023-02-24") - assert runner.date_until == DateManager.str_to_date("2023-03-31") - - -def test_completion(spark, mocker, basic_runner): - 
mocker.patch("rialto.runner.utils.table_exists", return_value=True) - - basic_runner.reader = MockReader(spark) - - dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - dates = [DateManager.str_to_date(d) for d in dates] - - comp = basic_runner._get_completion(Table(table_path="catalog.schema.simple_group", partition="DATE"), dates) - expected = [False, True, True, True, False] - assert comp == expected - - -def test_completion_rerun(spark, mocker, basic_runner): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - - runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") - runner.reader = MockReader(spark) - - dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - dates = [DateManager.str_to_date(d) for d in dates] - - comp = runner._get_completion(Table(table_path="catalog.schema.simple_group", partition="DATE"), dates) - expected = [False, True, True, True, False] - assert comp == expected - - -def test_completion_secondary_partitions(spark, mocker, basic_runner): - mocker.patch("rialto.runner.utils.table_exists", return_value=True) - - basic_runner.reader = MockReader(spark) - - dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - dates = [DateManager.str_to_date(d) for d in dates] - filters = {"version": 1, "type": "A"} - - comp = basic_runner._get_completion( - Table(table_path="source.schema.multi_part_data", partition="DATE"), dates, filters - ) - expected = [False, True, False, False, False] - assert comp == expected - - -def test_completion_secondary_partitions_no_filter(spark, mocker, basic_runner): - mocker.patch("rialto.runner.utils.table_exists", return_value=True) - - basic_runner.reader = MockReader(spark) - - dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - dates = [DateManager.str_to_date(d) for d in dates] - - comp = basic_runner._get_completion( - 
Table(table_path="source.schema.multi_part_data", partition="DATE", secondary_partitions=["VERSION"]), dates - ) - expected = [False, False, False, False, False] - assert comp == expected - - -def test_check_dates_have_partition(spark, mocker): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-31", - ) - runner.reader = MockReader(spark) - dates = ["2023-03-04", "2023-03-05", "2023-03-06"] - dates = [DateManager.str_to_date(d) for d in dates] - res = runner.check_dates_have_data(Table(schema_path="source.schema", table="dep1", partition="DATE"), dates) - expected = [False, True, False] - assert res == expected - - -def test_check_dates_have_partition_no_table(spark, mocker): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=False) - - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-31", - ) - dates = ["2023-03-04", "2023-03-05", "2023-03-06"] - dates = [DateManager.str_to_date(d) for d in dates] - res = runner.check_dates_have_data(Table(schema_path="source.schema", table="dep66", partition="DATE"), dates) - expected = [False, False, False] - assert res == expected - - -@pytest.mark.parametrize( - "r_date, expected", - [("2023-02-26", False), ("2023-03-05", True)], -) -def test_check_dependencies(spark, mocker, r_date, expected): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-31", - ) - runner.reader = MockReader(spark) - res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date)) - assert res == expected - - -@pytest.mark.parametrize( - "r_date, expected", - [("2023-03-19", True), ("2023-03-18", False)], -) -def test_check_dependencies_filter(spark, mocker, r_date, expected): - 
mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - - runner = Runner( - spark, - config_path="tests/runner/transformations/config3.yaml", - run_date="2023-03-19", - ) - runner.reader = MockReader(spark) - res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date)) - assert res == expected - - -def test_check_no_dependencies(spark, mocker): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-31", + run_date="2026-01-01", + merge_schema=True, + overrides={"runner": {"watched_period_units": "days", "watched_period_value": 7}}, ) - runner.reader = MockReader(spark) - res = runner.check_dependencies(runner.config.pipelines[1], DateManager.str_to_date("2023-03-05")) - assert res is True - - -def test_select_dates(spark, mocker): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) - - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-31", - overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 1}, + engine_cls.assert_called_once_with( + services=built_services, + rerun=True, + skip_dependencies=True, ) - runner.reader = MockReader(spark) - - r, i = runner._select_run_dates( - runner.config.pipelines[0], Table(table_path="catalog.schema.simple_group", partition="DATE") + assert runner.op == "SimpleGroup" + + +def test_runner_init_uses_injected_services_without_build(): + spark = MagicMock() + injected_services = MagicMock() + + with patch("rialto.runner.runner.DefaultRunnerServices.build") as build_mock, patch( + "rialto.runner.runner.RunnerEngine" + ) as engine_cls: + Runner( + spark=spark, + config_path="tests/runner/transformations/config.yaml", + services=injected_services, + ) + + build_mock.assert_not_called() + engine_cls.assert_called_once_with( + 
services=injected_services, + rerun=False, + skip_dependencies=False, ) - expected_run = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - expected_run = [DateManager.str_to_date(d) for d in expected_run] - expected_info = ["2023-03-02", "2023-03-09", "2023-03-16", "2023-03-23"] - expected_info = [DateManager.str_to_date(d) for d in expected_info] - assert r == expected_run - assert i == expected_info -def test_select_dates_all_done(spark, mocker): - mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) +def test_runner_call_delegates_to_engine_run(): + spark = MagicMock() + injected_services = MagicMock() + engine = MagicMock() - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - run_date="2023-03-02", - overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 0}, - ) - runner.reader = MockReader(spark) - - r, i = runner._select_run_dates( - runner.config.pipelines[0], Table(table_path="catalog.schema.simple_group", partition="DATE") - ) - expected_run = [] - expected_run = [DateManager.str_to_date(d) for d in expected_run] - expected_info = [] - expected_info = [DateManager.str_to_date(d) for d in expected_info] - assert r == expected_run - assert i == expected_info - - -def test_op_selected(spark, mocker): - mocker.patch("rialto.runner.reporting.tracker.Tracker.report_by_mail") - run = mocker.patch("rialto.runner.runner.Runner._run_pipeline") - - runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="SimpleGroup") + with patch("rialto.runner.runner.RunnerEngine", return_value=engine): + runner = Runner( + spark=spark, + config_path="tests/runner/transformations/config.yaml", + op="SimpleGroup", + services=injected_services, + ) runner() - run.assert_called_once() - - -def test_op_bad(spark, mocker): - mocker.patch("rialto.runner.reporting.tracker.Tracker.report_by_mail") - mocker.patch("rialto.runner.runner.Runner._run_pipeline") - - runner = 
Runner(spark, config_path="tests/runner/transformations/config.yaml", op="BadOp") + engine.run.assert_called_once_with("SimpleGroup") - with pytest.raises(ValueError) as exception: - runner() - assert str(exception.value) == "Unknown operation selected: BadOp" +def test_runner_dry_run_delegates_to_engine_dry_run_execution(): + spark = MagicMock() + injected_services = MagicMock() + engine = MagicMock() -def test_bookkeeping_active(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._run_pipeline") + with patch("rialto.runner.runner.RunnerEngine", return_value=engine): + runner = Runner( + spark=spark, + config_path="tests/runner/transformations/config.yaml", + op="SimpleGroup", + services=injected_services, + ) - runner = Runner(spark, config_path="tests/runner/transformations/config.yaml") - assert runner.config.runner.bookkeeping == "some.test.location" + runner.dry_run() + engine.dry_run_execution.assert_called_once_with("SimpleGroup") -def test_bookkeeping_inactive(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._run_pipeline") +def test_runner_debug_delegates_to_engine_and_returns_dataframe(): + spark = MagicMock() + injected_services = MagicMock() + engine = MagicMock() + debug_df = MagicMock() + engine.debug_first_task.return_value = debug_df - runner = Runner(spark, config_path="tests/runner/transformations/config2.yaml") - assert runner.config.runner.bookkeeping is None + with patch("rialto.runner.runner.RunnerEngine", return_value=engine): + runner = Runner( + spark=spark, + config_path="tests/runner/transformations/config.yaml", + op="SimpleGroup", + services=injected_services, + ) + result = runner._debug() -def test_config_multi_partition(spark, mocker): - runner = Runner(spark, config_path="tests/runner/transformations/config3.yaml") - assert runner.config.pipelines[0].target.secondary_partition_columns == ["VERSION", "ENV"] - assert runner.config.pipelines[0].dependencies[0].filters == {"VERSION": 2, "TYPE": "A"} - assert 
runner.config.pipelines[0].target.rerun_filters == {"version": 3, "env": "dev"} + engine.debug_first_task.assert_called_once_with("SimpleGroup") + assert result is debug_df diff --git a/tests/runner/test_services.py b/tests/runner/test_services.py new file mode 100644 index 0000000..5ab63ac --- /dev/null +++ b/tests/runner/test_services.py @@ -0,0 +1,129 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# tests/runner/test_services.py +from datetime import date +from unittest.mock import Mock, patch + +from rialto.runner.reporting.tracker import Tracker +from rialto.runner.runner_services import DefaultRunnerServices, RunnerServices +from rialto.runner.services.data_checker import DataChecker +from rialto.runner.services.date_manager import DateManager +from rialto.runner.services.executor import PipelineExecutor +from rialto.runner.services.task_registry import TaskRegistry +from rialto.runner.services.task_status_checker import TaskStatusChecker +from rialto.runner.services.writer import DatabricksWriter + +CONFIG_PATH = "tests/runner/resources/config.yaml" + + +# ── RunnerServices dataclass ────────────────────────────────────────────────── + + +def test_runner_services_stores_all_fields(): + """RunnerServices holds all injected collaborators as attributes""" + config = Mock() + date_manager = Mock() + writer = Mock() + data_checker = Mock() + task_checker = Mock() + registry = Mock() + executor = Mock() + tracker = Mock() 
+ + services = RunnerServices( + config=config, + date_manager=date_manager, + writer=writer, + data_checker=data_checker, + task_checker=task_checker, + registry=registry, + executor=executor, + tracker=tracker, + ) + + assert services.config is config + assert services.date_manager is date_manager + assert services.writer is writer + assert services.data_checker is data_checker + assert services.task_checker is task_checker + assert services.registry is registry + assert services.executor is executor + assert services.tracker is tracker + + +# ── DefaultRunnerServices.build ─────────────────────────────────────────────── + + +def test_build_returns_runner_services_instance(spark): + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH) + assert isinstance(services, RunnerServices) + + +def test_build_creates_correct_types(spark): + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH) + assert isinstance(services.date_manager, DateManager) + assert isinstance(services.writer, DatabricksWriter) + assert isinstance(services.data_checker, DataChecker) + assert isinstance(services.task_checker, TaskStatusChecker) + assert isinstance(services.registry, TaskRegistry) + assert isinstance(services.executor, PipelineExecutor) + assert isinstance(services.tracker, Tracker) + + +def test_build_passes_merge_schema_to_writer(spark): + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH, merge_schema=True) + assert services.writer.merge_schema is True + + +def test_build_merge_schema_default_is_false(spark): + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH) + assert services.writer.merge_schema is False + + +def test_build_shares_data_checker_between_task_checker_and_executor(spark): + """data_checker should be the same instance in task_checker and executor""" + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH) + assert services.executor.checker is 
services.data_checker + + +def test_build_config_loads_pipelines(spark): + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH) + pipeline_names = [p.name for p in services.config.pipelines] + assert "SimpleGroup" in pipeline_names + + +def test_build_with_run_date_passed_to_date_manager(spark): + """run_date is passed to DateManager — verify it doesn't raise and config is loaded""" + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH, run_date="2023-03-31") + assert isinstance(services.date_manager, DateManager) + assert services.date_manager.date_until == date(2023, 3, 31) + + +def test_build_tracker_has_bookkeeper_when_config_has_bookkeeping(spark): + """config.yaml sets bookkeeping: some.test.location — tracker should have a bookkeeper""" + services = DefaultRunnerServices.build(spark=spark, config_path=CONFIG_PATH) + assert services.tracker.bookkeeper is not None + + +def test_build_calls_config_loader_with_overrides(spark): + overrides = {"runner": {"watched_period_units": "days", "watched_period_value": 7}} + with patch("rialto.runner.runner_services.ConfigLoader") as config_loader_cls, patch( + "rialto.runner.runner_services.DateManager", Mock() + ): + DefaultRunnerServices.build( + spark=spark, + config_path=CONFIG_PATH, + overrides=overrides, + ) + config_loader_cls.load_yaml.assert_called_once_with(CONFIG_PATH, overrides) diff --git a/tests/runner/test_table.py b/tests/runner/test_table.py deleted file mode 100644 index f1e4ead..0000000 --- a/tests/runner/test_table.py +++ /dev/null @@ -1,50 +0,0 @@ -from rialto.runner.table import Table - - -def test_table_basic_init(): - t = Table(catalog="cat", schema="sch", table="tab", schema_path=None, table_path=None, class_name=None) - - assert t.get_table_path() == "cat.sch.tab" - assert t.get_schema_path() == "cat.sch" - - -def test_table_classname_init(): - t = Table(catalog=None, schema=None, table=None, schema_path="cat.sch", table_path=None, 
class_name="ClaSs") - - assert t.get_table_path() == "cat.sch.cla_ss" - assert t.get_schema_path() == "cat.sch" - assert t.catalog == "cat" - assert t.schema == "sch" - assert t.table == "cla_ss" - - -def test_table_path_init(): - t = Table(catalog=None, schema=None, table=None, schema_path=None, table_path="cat.sch.tab", class_name=None) - - assert t.get_table_path() == "cat.sch.tab" - assert t.get_schema_path() == "cat.sch" - assert t.catalog == "cat" - assert t.schema == "sch" - assert t.table == "tab" - - -def test_table_secondary_partitions(): - t = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"]) - - assert t.get_all_partitions() == ["part", "sec1", "sec2"] - - -def test_table_get_partitions_only_main(): - t = Table(catalog="cat", schema="sch", table="tab", partition="part") - - assert t.get_all_partitions() == ["part"] - - -def test_table_prioritize_table_name(): - t = Table(catalog=None, schema=None, table="custom", schema_path="cat.sch", table_path=None, class_name="ClaSs") - - assert t.get_table_path() == "cat.sch.custom" - assert t.get_schema_path() == "cat.sch" - assert t.catalog == "cat" - assert t.schema == "sch" - assert t.table == "custom" diff --git a/tests/runner/test_utils.py b/tests/runner/test_utils.py new file mode 100644 index 0000000..4cfd626 --- /dev/null +++ b/tests/runner/test_utils.py @@ -0,0 +1,44 @@ +# Copyright 2022-2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from rialto.runner.utils import find_dependency + + +def _config_with_dependencies(dep_names): + config = MagicMock() + deps = [] + for n in dep_names: + d = MagicMock() + d.name = n + deps.append(d) + config.dependencies = deps + return config + + +def test_find_dependency_returns_matching_dependency(): + config = _config_with_dependencies(["dep_a", "dep_b", "dep_c"]) + + result = find_dependency(config, "dep_b") + + assert result is config.dependencies[1] + + +def test_find_dependency_returns_none_when_not_found(): + config = _config_with_dependencies(["dep_a", "dep_b"]) + + result = find_dependency(config, "missing_dep") + + assert result is None diff --git a/tests/runner/test_writer.py b/tests/runner/test_writer.py deleted file mode 100644 index 5ec9ec6..0000000 --- a/tests/runner/test_writer.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import date - -import pytest - -from rialto.runner.writer import Writer - - -@pytest.fixture -def sample_multi_partition(spark): - df = spark.createDataFrame( - [ - ("REGION_A", 3, date(2023, 1, 1), 100), - ("REGION_A", 3, date(2023, 1, 1), 300), - ], - schema="region string, version int, info_date date, value int", - ) - return df - - -@pytest.fixture -def sample_multi_partition_non_unique(spark): - df = spark.createDataFrame( - [ - ("REGION_A", 1, date(2023, 1, 1), 100), - ("REGION_A", 2, date(2023, 1, 1), 300), - ], - schema="region string, version int, info_date date, value int", - ) - return df - - -def test_replace_condition(sample_multi_partition): - writer = Writer(spark=None) - condition = writer._get_replace_condition(sample_multi_partition, partition_cols=["region", "version", "info_date"]) - expected_condition = "region = 'REGION_A' AND version = 3 AND info_date = '2023-01-01'" - assert condition == expected_condition - - -def test_replace_condition_non_unique(sample_multi_partition_non_unique): - writer = Writer(spark=None) - with pytest.raises(ValueError): - writer._get_replace_condition( - sample_multi_partition_non_unique, partition_cols=["region", "version", "info_date"] - ) diff --git a/uv.lock b/uv.lock index d693ef0..d2512f2 100644 --- a/uv.lock +++ b/uv.lock @@ -1078,7 +1078,7 @@ wheels = [ [[package]] name = "rialto" -version = "2.2.0" +version = "2.2.1" source = { editable = "." } dependencies = [ { name = "delta-spark" },
{record.exception}