diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 6582fdb2..a83f33a8 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -46,15 +46,12 @@ services: - ${CUSTOM_SCRIPT:-./toktagger/api/run.py}:/app/run.py - ~/.sal/:/root/.sal environment: - MONGO_URL: "mongodb://${MONGO_USERNAME}:${MONGO_PASSWORD}@mongo:27017" - UDA_HOST: "uda2.mast.l" - UDA_META_PLUGINNAME: "MASTU_DB" - UDA_METANEW_PLUGINNAME: "MAST_DB" - SAL_HOST: "https://sal.jetdata.eu" - MODEL_STORAGE: "/app/data/models" - API_URL: "http://api_app:8002" + DATABASE_MONGO_URL: "mongodb://${MONGO_USERNAME}:${MONGO_PASSWORD}@mongo:27017" + MODELS_CACHE_DIR: "/app/data/models" + SERVER_HOST: api_app + SERVER_PORT: 8002 + SERVER_RELOAD: "true" CUSTOM_SCRIPT: ${CUSTOM_SCRIPT} - RELOAD: "true" working_dir: /app command: ["python", "run.py"] networks: diff --git a/docker-compose.yml b/docker-compose.yml index ac750f45..a2dfde29 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -46,15 +46,12 @@ services: - ${CUSTOM_SCRIPT:-./toktagger/api/run.py}:/app/run.py - ~/.sal/:/root/.sal environment: - MONGO_URL: "mongodb://${MONGO_USERNAME}:${MONGO_PASSWORD}@mongo:27017" - UDA_HOST: "uda2.mast.l" - UDA_META_PLUGINNAME: "MASTU_DB" - UDA_METANEW_PLUGINNAME: "MAST_DB" - MODEL_STORAGE: "/app/data/models" - SAL_HOST: "https://sal.jetdata.eu" - API_URL: "http://api_app:8002" + DATABASE_MONGO_URL: "mongodb://${MONGO_USERNAME}:${MONGO_PASSWORD}@mongo:27017" + MODELS_CACHE_DIR: "/app/data/models" + SERVER_HOST: api_app + SERVER_PORT: 8002 + SERVER_RELOAD: "false" CUSTOM_SCRIPT: ${CUSTOM_SCRIPT} - RELOAD: "false" working_dir: /app command: ["python", "run.py"] networks: diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 00000000..76a85e6c --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,43 @@ +# Configuration Options +The following options can be configured within TokTagger to improve your experience. They can either be set via a `toktagger.toml` configuration file in your working directory, or via environment variables. Environment variables will take precidence over settings within the TOML file. + +## Server settings +These settings should be defined under the `[server]` heading in the TOML file: + +| Setting | Environment Variable | Type | Default | Description | +|-----------------|-------------------------|--------------|-----------------------------------------|--------------------------------------------------------------------------| +| host | SERVER_HOST | str | localhost | Address of the host to launch TokTagger on. | +| port | SERVER_PORT | int | 8002 | The port to use for the TokTagger Rest API. | +| reload | SERVER_RELOAD | bool | False | Whether to hot reload the TokTagger server on changes to files. | +| cache_dir | SERVER_CACHE_DIR | pathlib.Path | ~/.cache/toktagger | The directory to use for storing entries in the Mongita database. | + +## Database Settings +These settings should be defined under the `[database]` heading in the TOML file: + +| Setting | Environment Variable | Type | Default | Description | +|-----------------|-------------------------|--------------|-----------------------------------------|---------------------------------------------------------------------------------------------| +| mongo_url | DATABASE_MONGO_URL | str | ./toktagger_db | URL of the MongoDB server to connect to as a backend, by default uses local Mongita client. | + +## Models Settings +These settings should be defined under the `[models]` heading in the TOML file: + +| Setting | Environment Variable | Type | Default | Description | +|-----------------|-------------------------|--------------|-----------------------------------------|--------------------------------------------------------------------------| +| cache_dir | MODELS_CACHE_DIR | pathlib.Path | ~/.cache/toktagger/models | The directory to use for storing ML model weights. | +| max_actors | MODELS_MAX_ACTORS | int | 5 | The maximum number of ML models which can be loaded concurrently. | + +## UDA Connection Settings +These settings should be defined under the `[uda]` heading in the TOML file: + +| Setting | Environment Variable | Type | Default | Description | +|--------------------|-------------------------|--------------|-----------------------------------------|--------------------------------------------------------------------------| +| host | UDA_HOST | str | uda2.mast.l | Host name for the UDA server to connect to for MAST data loaders. | +| meta_pluginname | UDA_META_PLUGINNAME | str | MASTU_DB | ??? | +| metanew_pluginname | UDA_METANEW_PLUGINNAME | str | MAST_DB | ??? | + +## SAL Connection Settings +These settings should be defined under the `[sal]` heading in the TOML file: + +| Setting | Environment Variable | Type | Default | Description | +|-----------------|-------------------------|--------------|-----------------------------------------|--------------------------------------------------------------------------| +| host | SAL_HOST | str | https://sal.jetdata.eu | URL for the SAL server to connect to for JET data loaders. | \ No newline at end of file diff --git a/docs/custom_dataloaders.md b/docs/custom_dataloaders.md index d5af4e73..d450d7cc 100644 --- a/docs/custom_dataloaders.md +++ b/docs/custom_dataloaders.md @@ -294,6 +294,31 @@ server.run() Here's an example of loading data from a SQL database: +### Update Config Settings +If your data loader requires configuration inputs from the user, then the `config.Settings` object should be updated to accept this. This takes the form of a [Pydantic Settings object](https://pydantic.dev/docs/validation/latest/concepts/pydantic_settings/#usage), where nested `BaseModels` represent sections inside the `toktagger.toml` configuration file. For example, we can make create a new Settings object which inherits from the one in `toktagger.api.config.py`, and we can add a new `SQL` section where we need the database URL to connect to with our dataloader: +```python +from toktagger.api.config import Settings + +class SQL(pydantic.BaseModel): + url: str | None = pydantic.Field( + None, + description="URL of the SQL database to connect to", + ) +class UpdatedSettings(Settings): + sql: SQL = pydantic.Field(SQL) +``` +Note that this will load settings from the following sources, in the following order: +1. Any values which the `UpdatedSettings` class is initialized with +2. Environment variables, case insensitive, named using the nested model names. Eg for the above setting, it would be `SQL_URL`. +3. Values in the `toktagger.toml` configuration file, with section titles according to nesting. Eg: +```toml +[sql] +url = "sqlite:///./test.db" +``` +4. Environment variables provided in a .env file + +### Create the DataLoader +We can then create our dataloader, accessing the setting we defined above: ```python import sqlalchemy as sa from typing import Type @@ -302,7 +327,7 @@ import pydantic from toktagger.api.core.data_loaders import DataLoader, LoaderRegistry from toktagger.api.schemas.data import MultiVariateTimeSeriesData, TimeSeriesData, DataParams from toktagger.api.schemas.samples import ShotData - +import toktagger.api.config as config @LoaderRegistry.register("sql_database") class SQLDatabaseLoader(DataLoader): @@ -310,10 +335,8 @@ class SQLDatabaseLoader(DataLoader): def __init__(self): # Initialize database connection - # Connection string should be in environment variable - import os - connection_string = os.environ.get("DATABASE_URL") - self.engine = sa.create_engine(connection_string) + # Connection string should be in the settings object + self.engine = sa.create_engine(config.settings.sql.url) @classmethod def sample_data_type(cls) -> Type[ShotData]: @@ -353,6 +376,20 @@ class SQLDatabaseLoader(DataLoader): return MultiVariateTimeSeriesData(values=results) ``` +### Launch the Server +To run the server with our custom Settings object and DataLoader, we should create a run script as follows: +```python title="run.py" +from settings import UpdatedSettings +from loader import CSVTimeSeriesLoader +from toktagger.api.main import Server +import toktagger.api.config as config + +# Update config.settings to use our new object +config.settings = UpdatedSettings() + +server = Server() +server.run() +``` ## Using Docker If you are using the docker compose option to run the server, you can provide a custom script similar to the one above to add your own data loaders. To do this, create a file similar to the one above, but making sure to pass the following arguments into `server.run()`: @@ -364,10 +401,10 @@ server.run( ) ``` -You can then provide the path to your script when running docker compose. For example, say we have the above script in a file called `custom_toktagger.py` - We simply need to add `CUSTOM_SCRIPT=./custom_toktagger.py` before the docker compose command! +You can then provide the path to your script when running docker compose. For example, say we have the above script in a file called `custom_toktagger.py` - We simply need to add `CUSTOM_SCRIPT=./custom_toktagger.py` before the docker compose command, and a SQL URL as an environment variable: ```sh -CUSTOM_SCRIPT=./custom_toktagger.py docker compose --env-file .env.dev -f docker-compose.dev.yml up --build +CUSTOM_SCRIPT=./custom_toktagger.py SQL_URL= docker compose --env-file .env.dev -f docker-compose.dev.yml up --build ``` !!! tip diff --git a/docs/index.md b/docs/index.md index 94be7ac2..81f15d8a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -57,6 +57,9 @@ toktagger This will start a local instance of the application running at `http://localhost:8002`. +## Configuration +There are a series of additional options which you can configure to customise the functionality of TokTagger - [find details about these here.](./configuration.md) + ## Project Links - [Git Repo](https://github.com/ukaea/toktagger) diff --git a/pyproject.toml b/pyproject.toml index 6e353aad..7a2d82a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "sal-xarray>=0.2.1", "bump-my-version>=1.2.7", "platformdirs>=4.4.0", + "pydantic-settings>=2.11.0", ] [project.optional-dependencies] models = [ diff --git a/tests/api/core/test_config.py b/tests/api/core/test_config.py new file mode 100644 index 00000000..777da23e --- /dev/null +++ b/tests/api/core/test_config.py @@ -0,0 +1,242 @@ +# tests/test_settings.py + +import pathlib + +import pydantic +import pytest +from pydantic_settings import SettingsConfigDict +import tempfile +from toktagger.api.config import Settings + + +ENV_VARS = [ + "SERVER_HOST", + "SERVER_PORT", + "SERVER_RELOAD", + "SERVER_CACHE_DIR", + "DATABASE_MONGO_URL", + "UDA_HOST", + "UDA_META_PLUGINNAME", + "UDA_METANEW_PLUGINNAME", + "SAL_HOST", + "MODELS_CACHE_DIR", + "MODELS_MAX_ACTORS", +] + + +@pytest.fixture +def setup_test_settings(monkeypatch): + """ + A Settings subclass that reads TOML from a temp file instead of the real + project working directory. + """ + for name in ENV_VARS: + monkeypatch.delenv(name, raising=False) + with tempfile.NamedTemporaryFile(mode="w", prefix=".toml") as tempf: + + class TestSettings(Settings): + model_config = SettingsConfigDict( + toml_file=tempf.name, + env_nested_delimiter="_", + ) + + yield TestSettings, tempf + + +def test_default_settings(setup_test_settings): + TestSettings, _ = setup_test_settings + + settings = TestSettings() + + assert settings.server.host == "localhost" + assert settings.server.port == 8002 + assert settings.server.reload is False + assert isinstance(settings.server.cache_dir, pathlib.Path) + + assert settings.database.mongo_url == "./toktagger_db" + + assert settings.uda.host == "uda2.mast.l" + assert settings.uda.meta_pluginname == "MASTU_DB" + assert settings.uda.metanew_pluginname == "MAST_DB" + + assert settings.sal.host == "https://sal.jetdata.eu" + + assert isinstance(settings.models.cache_dir, pathlib.Path) + assert settings.models.max_actors == 5 + + +def test_env_overrides_simple_nested_fields(monkeypatch, setup_test_settings): + TestSettings, _ = setup_test_settings + + monkeypatch.setenv("SERVER_HOST", "0.0.0.0") + monkeypatch.setenv("SERVER_PORT", "9000") + monkeypatch.setenv("SERVER_RELOAD", "true") + monkeypatch.setenv("UDA_HOST", "uda-test-host") + monkeypatch.setenv("SAL_HOST", "https://sal.example.com") + + settings = TestSettings() + + assert settings.server.host == "0.0.0.0" + assert settings.server.port == 9000 + assert settings.server.reload is True + assert settings.uda.host == "uda-test-host" + assert settings.sal.host == "https://sal.example.com" + + +def test_env_overrides_fields_with_underscores(monkeypatch, setup_test_settings): + TestSettings, _ = setup_test_settings + + monkeypatch.setenv("DATABASE_MONGO_URL", "mongodb://user:pass@mongo:27017") + monkeypatch.setenv("UDA_META_PLUGINNAME", "TEST_META") + monkeypatch.setenv("UDA_METANEW_PLUGINNAME", "TEST_METANEW") + monkeypatch.setenv("MODELS_MAX_ACTORS", "10") + + settings = TestSettings() + + assert settings.database.mongo_url == "mongodb://user:pass@mongo:27017" + assert settings.uda.meta_pluginname == "TEST_META" + assert settings.uda.metanew_pluginname == "TEST_METANEW" + assert settings.models.max_actors == 10 + + +def test_toml_loading(setup_test_settings): + TestSettings, toml_file = setup_test_settings + + toml_file.write( + """ + [server] + host = "127.0.0.1" + port = 9999 + reload = true + cache_dir = "/tmp/toktagger-cache" + + [database] + mongo_url = "mongodb://mongo:27017" + + [uda] + host = "uda.example.com" + meta_pluginname = "CUSTOM_META" + metanew_pluginname = "CUSTOM_METANEW" + + [sal] + host = "https://sal.example.com" + + [models] + cache_dir = "/tmp/toktagger-models" + max_actors = 3 + """ + ) + toml_file.flush() + + settings = TestSettings() + + assert settings.server.host == "127.0.0.1" + assert settings.server.port == 9999 + assert settings.server.reload is True + assert settings.server.cache_dir == pathlib.Path("/tmp/toktagger-cache") + + assert settings.database.mongo_url == "mongodb://mongo:27017" + + assert settings.uda.host == "uda.example.com" + assert settings.uda.meta_pluginname == "CUSTOM_META" + assert settings.uda.metanew_pluginname == "CUSTOM_METANEW" + + assert settings.sal.host == "https://sal.example.com" + + assert settings.models.cache_dir == pathlib.Path("/tmp/toktagger-models") + assert settings.models.max_actors == 3 + + +def test_env_takes_precedence_over_toml(monkeypatch, setup_test_settings): + TestSettings, toml_file = setup_test_settings + + toml_file.write( + """ + [server] + host = "toml-host" + port = 1111 + """ + ) + toml_file.flush() + + monkeypatch.setenv("SERVER_HOST", "env-host") + monkeypatch.setenv("SERVER_PORT", "2222") + + settings = TestSettings() + + assert settings.server.host == "env-host" + assert settings.server.port == 2222 + + +def test_env_and_toml_applied(monkeypatch, setup_test_settings): + TestSettings, toml_file = setup_test_settings + + toml_file.write( + """ + [server] + host = "toml-host" + port = 1111 + """ + ) + toml_file.flush() + + monkeypatch.setenv("SERVER_PORT", "2222") + + settings = TestSettings() + + assert settings.server.host == "toml-host" + assert settings.server.port == 2222 + + +def test_init_kwargs_take_precedence_over_env_and_toml( + monkeypatch, setup_test_settings +): + TestSettings, toml_file = setup_test_settings + + toml_file.write( + """ + [server] + host = "toml-host" + port = 1111 + """ + ) + toml_file.flush() + + monkeypatch.setenv("SERVER_HOST", "env-host") + monkeypatch.setenv("SERVER_PORT", "2222") + + settings = TestSettings( + server={ + "host": "init-host", + "port": 3333, + } + ) + + assert settings.server.host == "init-host" + assert settings.server.port == 3333 + + +def test_invalid_models_max_actors_rejected(setup_test_settings): + TestSettings, _ = setup_test_settings + + with pytest.raises(pydantic.ValidationError): + TestSettings(models={"max_actors": 0}) + + +def test_invalid_server_port_rejected(setup_test_settings): + TestSettings, _ = setup_test_settings + + with pytest.raises(pydantic.ValidationError): + TestSettings(server={"port": "not-a-port"}) + + +def test_path_env_vars_are_converted_to_paths(monkeypatch, setup_test_settings): + TestSettings, _ = setup_test_settings + + monkeypatch.setenv("SERVER_CACHE_DIR", "/tmp/server-cache") + monkeypatch.setenv("MODELS_CACHE_DIR", "/tmp/models-cache") + + settings = TestSettings() + + assert settings.server.cache_dir == pathlib.Path("/tmp/server-cache") + assert settings.models.cache_dir == pathlib.Path("/tmp/models-cache") diff --git a/tests/api/routers/test_models.py b/tests/api/routers/test_models.py index 27f07bde..9adf0be1 100644 --- a/tests/api/routers/test_models.py +++ b/tests/api/routers/test_models.py @@ -1,5 +1,4 @@ import pytest -import pathlib from toktagger.api.schemas.models import ModelUpdate from toktagger.api.models.base import ActorRegistry from toktagger.api.core.sender import ( @@ -10,7 +9,6 @@ import ray from unittest.mock import patch from bson import ObjectId -import os import time @@ -145,11 +143,11 @@ async def test_model_batch_predict_version(api_client, db_client, setup_model_db @pytest.mark.asyncio -async def test_model_predict_missing_weights(api_client, db_client, setup_model_db): +async def test_model_predict_missing_weights( + api_client, db_client, setup_model_db, settings +): # Delete weights - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath( - f"{setup_model_db['model_id_1']}.model" - ).unlink() + settings.models.cache_dir.joinpath(f"{setup_model_db['model_id_1']}.model").unlink() response = await api_client.post( f"/projects/{setup_model_db['project_id']}/models/mock_disruption_cnn/predict?num_predictions=5&version=1" ) @@ -219,7 +217,7 @@ async def test_model_get_sample_prediction(api_client, db_client, setup_model_db # Poll the endpoint until results arrive t = 0 - while t < 10: + while t < 20: get_response = await api_client.get( f"/projects/{setup_model_db['project_id']}/samples/{setup_model_db['sample_ids'][-1]}/models/mock_disruption_cnn/predict/{task_id}" ) @@ -264,7 +262,7 @@ async def test_model_get_sample_prediction_wrong_sample( # Ask for predictions from this task for a sample which we did not predict on t = 0 - while t < 10: + while t < 20: get_response = await api_client.get( f"/projects/{setup_model_db['project_id']}/samples/{setup_model_db['sample_ids'][-2]}/models/mock_disruption_cnn/predict/{task_id}" ) @@ -327,7 +325,9 @@ async def test_model_update(api_client, db_client, setup_model_db): @pytest.mark.asyncio -async def test_model_start_training_no_params(api_client, db_client, setup_model_db): +async def test_model_start_training_no_params( + api_client, db_client, setup_model_db, settings +): response = await api_client.put( f"/projects/{setup_model_db['project_id']}/models/mock_disruption_cnn/train" ) @@ -345,9 +345,7 @@ async def test_model_start_training_no_params(api_client, db_client, setup_model assert model["score"] == 60 # value returned by train method # Check model has been saved after completion - assert ( - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{model_id}.model").exists() - ) + assert settings.models.cache_dir.joinpath(f"{model_id}.model").exists() @pytest.mark.asyncio @@ -400,7 +398,9 @@ async def test_model_missing_params(api_client, db_client, setup_model_db, metho @pytest.mark.asyncio -async def test_model_start_training_params(api_client, db_client, setup_model_db): +async def test_model_start_training_params( + api_client, db_client, setup_model_db, settings +): response = await api_client.put( f"/projects/{setup_model_db['project_id']}/models/mock_params_timeseries_cnn/train", json={ @@ -426,14 +426,12 @@ async def test_model_start_training_params(api_client, db_client, setup_model_db assert model["score"] == 50 # value returned from params # Check model has been saved after completion - assert ( - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{model_id}.model").exists() - ) + assert settings.models.cache_dir.joinpath(f"{model_id}.model").exists() # Test delete model @pytest.mark.asyncio -async def test_model_delete_type(api_client, db_client, setup_db): +async def test_model_delete_type(api_client, db_client, setup_db, settings): response = await api_client.delete( f"/projects/{setup_db['project_id_1']}/models/mock_disruption_cnn" ) @@ -447,26 +445,20 @@ async def test_model_delete_type(api_client, db_client, setup_db): assert models[0]["type"] != "mock_disruption_cnn" # Check for models 1 and 2, their file no longer exists - assert ( - not pathlib.Path(os.environ["MODEL_STORAGE"]) - .joinpath(f"{setup_db['model_id_1']}.model") - .exists() - ) - assert ( - not pathlib.Path(os.environ["MODEL_STORAGE"]) - .joinpath(f"{setup_db['model_id_2']}.model") - .exists() - ) + assert not settings.models.cache_dir.joinpath( + f"{setup_db['model_id_1']}.model" + ).exists() + assert not settings.models.cache_dir.joinpath( + f"{setup_db['model_id_2']}.model" + ).exists() # And for model 3 it does still exist - assert ( - pathlib.Path(os.environ["MODEL_STORAGE"]) - .joinpath(f"{setup_db['model_id_3']}.model") - .exists() - ) + assert settings.models.cache_dir.joinpath( + f"{setup_db['model_id_3']}.model" + ).exists() @pytest.mark.asyncio -async def test_model_delete_type_version(api_client, db_client, setup_db): +async def test_model_delete_type_version(api_client, db_client, setup_db, settings): response = await api_client.delete( f"/projects/{setup_db['project_id_1']}/models/mock_disruption_cnn?version=2" ) @@ -481,22 +473,16 @@ async def test_model_delete_type_version(api_client, db_client, setup_db): assert models[1]["type"] == "disruption_cnn" # Check for model 2, their file no longer exists - assert ( - not pathlib.Path(os.environ["MODEL_STORAGE"]) - .joinpath(f"{setup_db['model_id_2']}.model") - .exists() - ) + assert not settings.models.cache_dir.joinpath( + f"{setup_db['model_id_2']}.model" + ).exists() # And for models 1 and 3 it does still exist - assert ( - pathlib.Path(os.environ["MODEL_STORAGE"]) - .joinpath(f"{setup_db['model_id_1']}.model") - .exists() - ) - assert ( - pathlib.Path(os.environ["MODEL_STORAGE"]) - .joinpath(f"{setup_db['model_id_3']}.model") - .exists() - ) + assert settings.models.cache_dir.joinpath( + f"{setup_db['model_id_1']}.model" + ).exists() + assert settings.models.cache_dir.joinpath( + f"{setup_db['model_id_3']}.model" + ).exists() @pytest.mark.asyncio diff --git a/tests/conftest.py b/tests/conftest.py index 7c02fffa..9736f9a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,9 +17,9 @@ import time from pymongo import MongoClient import tempfile -import pathlib import ray import random +import toktagger.api.config as config @pytest.fixture(scope="session") @@ -45,34 +45,46 @@ def mongo_container(): yield mongo.get_connection_url() -@pytest.fixture(scope="package") -def ray_session(): +@pytest.fixture(scope="session") +def settings(mongo_container): with tempfile.TemporaryDirectory(suffix="toktagger_") as tempd: - os.environ["MODEL_STORAGE"] = tempd - # Ray copies the value of the API_URL env var if already set in this local env - # We want it to be blank inside the ray worker nodes, so that it doesn't try to send stuff to API - # Cannot explicitly pass a None, it requires a str:str dict in env_vars - # So will pop the env varvalue, init ray, then restore it - if (api_url := os.environ.get("API_URL")) is not None: - api_url = os.environ.pop("API_URL") - ray.init( - ignore_reinit_error=True, - include_dashboard=False, - runtime_env={"env_vars": {"MODEL_STORAGE": tempd}}, + settings = config.Settings( + server=config.Server(), + models=config.Models(cache_dir=tempd, max_actors=1), + database=config.Database(mongo_url=mongo_container), + uda=config.UDA(), + sal=config.SAL(), ) - if api_url is not None: - os.environ["API_URL"] = api_url + config.settings = settings + yield settings - # Create a ray actor for use as a model registry - WorkerRegistry.options(name="WorkerModelRegistry", lifetime="detached").remote( - ModelRegistry._registry - ) - # And one for use as a dataloader registry - WorkerRegistry.options(name="WorkerLoaderRegistry", lifetime="detached").remote( - LoaderRegistry._registry - ) - yield - ray.shutdown() + +@pytest.fixture(scope="package") +def ray_session(settings): + # Ray copies the value of the API_URL env var if already set in this local env + # We want it to be blank inside the ray worker nodes, so that it doesn't try to send stuff to API + # Cannot explicitly pass a None, it requires a str:str dict in env_vars + # So will pop the env var value, init ray, then restore it + if (api_url := os.environ.get("API_URL")) is not None: + os.environ.pop("API_URL") + ray.init( + ignore_reinit_error=True, + include_dashboard=False, + runtime_env={"env_vars": {"MODEL_STORAGE": str(settings.models.cache_dir)}}, + ) + if api_url is not None: + os.environ["API_URL"] = api_url + + # Create a ray actor for use as a model registry + WorkerRegistry.options(name="WorkerModelRegistry", lifetime="detached").remote( + ModelRegistry._registry + ) + # And one for use as a dataloader registry + WorkerRegistry.options(name="WorkerLoaderRegistry", lifetime="detached").remote( + LoaderRegistry._registry + ) + yield + ray.shutdown() @pytest_asyncio.fixture(scope="function") @@ -95,8 +107,7 @@ async def task_actor(): @pytest_asyncio.fixture(scope="function") -async def api_client(task_actor, mongo_container): - os.environ["MONGO_URL"] = mongo_container +async def api_client(task_actor, settings): # Have hit various issues getting this setup # Using fastAPI TestClient() doesn't play well with async pymongo as it tries to do stuff in different event loops # So have to use this AsyncClient from httpx, but this no longer just accepts an app @@ -104,8 +115,8 @@ async def api_client(task_actor, mongo_container): # So have to run this manually, however trying to run the close after the yield to close the db connection gives errors # So am just going to leave it open, since the db container will be deleted after anyway # Any alternative solution ideas are welcome..... + os.environ["API_URL"] = "http://test" server = Server() - os.environ["API_URL"] = "" server._setup_app() app = server.app lifespan_ctx = app.router.lifespan_context(app) @@ -179,21 +190,21 @@ async def setup_db(db_client): db_definitions.MODEL_1, ids={"project_id": ObjectId(project_id_1)}, ) - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{model_id_1}.model").touch() + config.settings.models.cache_dir.joinpath(f"{model_id_1}.model").touch() await asyncio.sleep(0.01) model_id_2 = await db_client.insert( "models", db_definitions.MODEL_2, ids={"project_id": ObjectId(project_id_1)}, ) - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{model_id_2}.model").touch() + config.settings.models.cache_dir.joinpath(f"{model_id_2}.model").touch() await asyncio.sleep(0.01) model_id_3 = await db_client.insert( "models", db_definitions.MODEL_3, ids={"project_id": ObjectId(project_id_1)}, ) - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{model_id_3}.model").touch() + config.settings.models.cache_dir.joinpath(f"{model_id_3}.model").touch() yield { "project_id_1": project_id_1, "project_id_2": project_id_2, @@ -301,7 +312,7 @@ async def setup_model_db(setup_model_samples, ray_session, db_client): # Create temp files for each for _id in (model_id_1, model_id_2, model_id_4): - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{_id}.model").touch() + config.settings.models.cache_dir.joinpath(f"{_id}.model").touch() yield { "project_id": project_id, @@ -311,17 +322,15 @@ async def setup_model_db(setup_model_samples, ray_session, db_client): } -def run_server(): - # Import to register mock model - +def run_server(settings): + config.settings = settings server = Server() server.run() @pytest.fixture(scope="package") -def start_server(mongo_container): - os.environ["MONGO_URL"] = mongo_container - proc = multiprocessing.Process(target=run_server) +def start_server(settings): + proc = multiprocessing.Process(target=run_server, args=(settings,)) proc.start() # Wait for server to start server_up = False @@ -348,6 +357,6 @@ def start_server(mongo_container): @pytest.fixture(scope="function") def server_setup(start_server): yield - client = MongoClient(os.environ["MONGO_URL"]) + client = MongoClient(config.settings.database.mongo_url) client.drop_database("annotate_db") client.close() diff --git a/toktagger/api/cli.py b/toktagger/api/cli.py index bf901921..cf1b6681 100644 --- a/toktagger/api/cli.py +++ b/toktagger/api/cli.py @@ -1,11 +1,11 @@ import webbrowser import argparse from toktagger.api.main import Server +import toktagger.api.config as config from toktagger.api.models import models_dependencies_installed import uvicorn import time import threading -import os # Need to point to app as a module level string if we want reload option @@ -33,29 +33,38 @@ def main(): """) argparser = argparse.ArgumentParser(description="Run the FastAPI application") - argparser.add_argument("--host", default="localhost", help="Host to run the app on") argparser.add_argument( - "--port", default=8002, type=int, help="Port to run the app on" + "--host", help="Host to run the app on, by default localhost" + ) + argparser.add_argument( + "--port", type=int, help="Port to run the app on, by default 8002" ) argparser.add_argument( "--no-browser", action="store_true", help="Don't open a browser" ) argparser.add_argument( - "--reload", action="store_true", help="Reload the API on changes" + "--reload", + action="store_true", + help="Reload the API on changes, by default False", ) args = argparser.parse_args() open_browser = not args.no_browser if open_browser: threading.Thread(target=do_open_browser, args=(args.host, args.port)).start() - os.environ["API_URL"] = f"http://{args.host}:{args.port}" + if args.host: + config.settings.server.host = args.host + if args.port: + config.settings.server.port = args.port + if args.reload: + config.settings.server.reload = args.reload uvicorn.run( "toktagger.api.cli:create_app", factory=True, - host=args.host, - port=args.port, - reload=args.reload, + host=config.settings.server.host, + port=config.settings.server.port, + reload=config.settings.server.reload, ) diff --git a/toktagger/api/config.py b/toktagger/api/config.py new file mode 100644 index 00000000..ae357fc9 --- /dev/null +++ b/toktagger/api/config.py @@ -0,0 +1,108 @@ +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + TomlConfigSettingsSource, +) +import pydantic +import typing +import pathlib +from platformdirs import user_cache_dir + + +class UDA(pydantic.BaseModel): + host: str = pydantic.Field( + "uda2.mast.l", + description="Host name for the UDA server to connect to for MAST data loaders.", + ) + meta_pluginname: str = pydantic.Field( + "MASTU_DB", + description="???", # TODO whats this? + ) + metanew_pluginname: str = pydantic.Field( + "MAST_DB", + description="???", # TODO whats this? + ) + + +class SAL(pydantic.BaseModel): + host: str = pydantic.Field( + "https://sal.jetdata.eu", + description="URL for the SAL server to connect to for JET data loaders.", + ) + + +class Database(pydantic.BaseModel): + mongo_url: str = pydantic.Field( + "./toktagger_db", + description="URL of the MongoDB server to connect to as a backend. If not set, uses a local mongita client.", + ) + + +class Server(pydantic.BaseModel): + host: str = pydantic.Field( + "localhost", + description="Address of the host to launch TokTagger on.", + ) + port: int = pydantic.Field( + 8002, + description="The port to use for the TokTagger Rest API.", + ) + reload: bool = pydantic.Field( + False, + description="Whether to hot reload the TokTagger server on changes to files.", + ) + cache_dir: pathlib.Path = pydantic.Field( + user_cache_dir("toktagger", "ukaea"), + description="The directory to use for storing entries in the Mongita database, if used.", + validate_default=True, + ) + + +class Models(pydantic.BaseModel): + cache_dir: pathlib.Path = pydantic.Field( + pathlib.Path(user_cache_dir("toktagger", "ukaea")).joinpath("models"), + description="The directory to use for storing ML model weights.", + validate_default=True, + ) + max_actors: typing.Annotated[ + int, + pydantic.Field( + default=5, + description="The maximum number of ML models which can be loaded concurrently.", + gt=0, + ), + ] + + +class Settings(BaseSettings): + server: Server = pydantic.Field(default_factory=Server) + database: Database = pydantic.Field(default_factory=Database) + uda: UDA = pydantic.Field(default_factory=UDA) + sal: SAL = pydantic.Field(default_factory=SAL) + models: Models = pydantic.Field(default_factory=Models) + + model_config = SettingsConfigDict( + toml_file="toktagger.toml", + env_nested_delimiter="_", + env_nested_max_split=1, + ) + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + init_settings, + env_settings, + TomlConfigSettingsSource(settings_cls), + dotenv_settings, + ) + + +settings = Settings() diff --git a/toktagger/api/core/data_loaders.py b/toktagger/api/core/data_loaders.py index e3e44736..7f2a264e 100644 --- a/toktagger/api/core/data_loaders.py +++ b/toktagger/api/core/data_loaders.py @@ -31,18 +31,21 @@ ImageArrayFileData, DataTypes, ) +import toktagger.api.config as config # Set up UDA environment variables with defaults if not already set. This is required for # the pyuda client to work correctly outside of Freia. -os.environ["UDA_HOST"] = os.environ.get("UDA_HOST", "uda2.mast.l") -os.environ["UDA_META_PLUGINNAME"] = os.environ.get("UDA_META_PLUGINNAME", "MASTU_DB") +os.environ["UDA_HOST"] = os.environ.get("UDA_HOST", config.settings.uda.host) +os.environ["UDA_META_PLUGINNAME"] = os.environ.get( + "UDA_META_PLUGINNAME", config.settings.uda.meta_pluginname +) os.environ["UDA_METANEW_PLUGINNAME"] = os.environ.get( - "UDA_METANEW_PLUGINNAME", "MAST_DB" + "UDA_METANEW_PLUGINNAME", config.settings.uda.metanew_pluginname ) # Setup SAL environment variables with defaults if not already set. This is required for # the SAL client to work correctly. -os.environ["SAL_HOST"] = os.environ.get("SAL_HOST", "https://sal.jetdata.eu") +os.environ["SAL_HOST"] = os.environ.get("SAL_HOST", config.settings.sal.host) class DataLoaderError(Exception): diff --git a/toktagger/api/core/sender.py b/toktagger/api/core/sender.py index d8297d88..21e264e0 100644 --- a/toktagger/api/core/sender.py +++ b/toktagger/api/core/sender.py @@ -47,6 +47,7 @@ def send_model_updates( updates : ModelUpdate Updates about the model to be sent - parameters which are unset or None will be ignored """ + # Note that we use env vars here since it is inside a worker node... if (api_url := os.environ.get("API_URL")) is not None: url = f"{api_url}/projects/{project_id}/models/{model_id}" diff --git a/toktagger/api/main.py b/toktagger/api/main.py index 734f18ed..5e2018b0 100644 --- a/toktagger/api/main.py +++ b/toktagger/api/main.py @@ -1,10 +1,10 @@ -import os import pathlib from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager import uvicorn +import warnings from toktagger.api.routers.annotations import router as annotations_router from toktagger.api.routers.annotators import router as annotators_router from toktagger.api.routers.data import router as data_router @@ -17,6 +17,7 @@ from toktagger.api.core.data_loaders import LoaderRegistry from toktagger.api.crud.db import MongoDBClient from toktagger.api.models import models_dependencies_installed +import toktagger.api.config as config # Only import large packages if models dependencies installed if models_dependencies_installed(): @@ -30,10 +31,11 @@ @asynccontextmanager async def lifespan(app: FastAPI): - mongo_url = os.environ.get("MONGO_URL", "./toktagger_db") db_name = "annotate_db" - app.state.db_client = MongoDBClient(mongo_url, db_name) + app.state.db_client = MongoDBClient( + str(config.settings.database.mongo_url), db_name + ) app.state.project = None yield @@ -45,14 +47,12 @@ def __init__(self): self.frontend_path = pathlib.Path(__file__).parent / "static" def _setup_ray(self): - if (api_url := os.environ.get("API_URL")) is None: - raise ValueError("API URL must be set!") if not ray.is_initialized(): ray.init( runtime_env={ "env_vars": { - "API_URL": api_url, - "MODEL_STORAGE": os.environ.get("MODEL_STORAGE"), + "API_URL": f"http://{config.settings.server.host}:{config.settings.server.port}", + "MODEL_STORAGE": str(config.settings.models.cache_dir), } }, ) @@ -67,7 +67,7 @@ def _setup_ray(self): # Create a task registry self.app.state.task_registry = ActorRegistry( - max_actors=os.environ.get("MAX_ACTORS", 5) + max_actors=config.settings.models.max_actors ) def _setup_app(self): @@ -104,14 +104,41 @@ def _setup_app(self): self.app.include_router(meta_router) self.app.include_router(base_router) - def run( - self, - host: str = "localhost", - port: int = 8002, - ): - os.environ["API_URL"] = f"http://{host}:{port}" + def run(self, host: str | None = None, port: int | None = None): + """ + Launch the TokTagger server. + + Parameters + ---------- + host : str + DEPRECATED - use config file or environment variables instead. + The host to launch the server on, by default 'localhost' + port : int + DEPRECATED - use config file or environment variables instead. + The port to launch the server on, by default 8002 + """ + # Provide deprecation warning + if host or port: + warnings.warn( + """ + Specifying host and port within Server.run() is deprecated and will be removed in a future version. + Please provide these arguments via configuration file or environment variable instead. + See https://ukaea.github.io/toktagger/configuration for details. + """, + DeprecationWarning, + stacklevel=2, + ) + if host: + config.settings.server.host = host + if port: + config.settings.server.port = port + self._setup_app() # Setup ray if required if models_dependencies_installed(): self._setup_ray() - uvicorn.run(self.app, host=host, port=port) + uvicorn.run( + self.app, + host=config.settings.server.host, + port=config.settings.server.port, + ) diff --git a/toktagger/api/routers/models.py b/toktagger/api/routers/models.py index 6921bb22..903de81b 100644 --- a/toktagger/api/routers/models.py +++ b/toktagger/api/routers/models.py @@ -1,7 +1,5 @@ from fastapi import APIRouter, Request, Depends, Path, Query, Body, HTTPException from fastapi.responses import JSONResponse -import pathlib -import os import random from bson.objectid import ObjectId from toktagger.api.crud import utils @@ -11,6 +9,7 @@ from toktagger.api.models.base import ModelRegistry from pydantic import ValidationError from collections import defaultdict +import toktagger.api.config as config # Only import large packages if models dependencies installed if models_dependencies_installed(): @@ -141,7 +140,7 @@ async def delete_models( ) # And delete file from storage (if it exists - may not if the job failed) - pathlib.Path(os.environ["MODEL_STORAGE"]).joinpath(f"{model.id}.model").unlink( + config.settings.models.cache_dir.joinpath(f"{model.id}.model").unlink( missing_ok=True ) diff --git a/toktagger/api/run.py b/toktagger/api/run.py index 31fa0ae6..ff7a08db 100644 --- a/toktagger/api/run.py +++ b/toktagger/api/run.py @@ -1,13 +1,11 @@ import uvicorn -import os +import toktagger.api.config as config if __name__ == "__main__": - os.environ["API_URL"] = "http://0.0.0.0:8002" - uvicorn.run( "toktagger.api.cli:create_app", factory=True, - host="0.0.0.0", - port=8002, - reload=True if os.environ.get("RELOAD") == "true" else False, + host=config.settings.server.host, + port=config.settings.server.port, + reload=config.settings.server.reload, ) diff --git a/toktagger/api/worker.py b/toktagger/api/worker.py index 95e29e4e..cc91f113 100644 --- a/toktagger/api/worker.py +++ b/toktagger/api/worker.py @@ -27,6 +27,7 @@ models_dir_default.mkdir(parents=True, exist_ok=True) # Set model storage to default path if not already set +# Note that we still use env vars here since this is inside a worker node... os.environ["MODEL_STORAGE"] = os.environ.get("MODEL_STORAGE", str(models_dir_default)) diff --git a/zensical.toml b/zensical.toml index 5c31ccbc..e4f7282e 100644 --- a/zensical.toml +++ b/zensical.toml @@ -50,6 +50,7 @@ Annotators = "annotators.md" Advanced = [ {"Custom Data Loaders" = "custom_dataloaders.md"}, {"Custom Models" = "custom_models.md"}, + {"Configuration Options" = "configuration.md"} ] [[project.nav]]