diff --git a/README.md b/README.md index 5e5433cc182cf0dbf053da2a283f58f4f2b7a021..e4c6b8582616068e92950011a86a94bc11b68e59 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ import slurm_sweeps as ss def train(cfg: dict): for epoch in range(cfg["epochs"]): sleep(0.5) - loss = (cfg["parameter"] - 1) ** 2 * epoch + loss = (cfg["parameter"] - 1) ** 2 / (epoch + 1) # log your metrics ss.log({"loss": loss}, epoch) @@ -66,10 +66,10 @@ experiment = ss.Experiment( # Run your experiment -dataframe = experiment.run(n_trials=1000) +result = experiment.run(n_trials=1000) -# Your results are stored in a pandas DataFrame -print(f"\nBest trial:\n{dataframe.sort_values('loss').iloc[0]}") +# Show the best performing trial +print(result.best_trial()) ``` Or submit it to a SLURM cluster. @@ -251,6 +251,110 @@ A configuration class for the SlurmBackend. - `ntasks` - How many tasks do you request for your srun? - `args` - Additional command line arguments for srun, formatted as a string. +### CLASS `slurm_sweeps.Result` + +```python +class Result( + experiment: str, + local_dir: Union[str, Path] = "./slurm-sweeps", +) +``` + +The result of an experiment. + +**Arguments**: + +- `experiment` - The name of the experiment. +- `local_dir` - The directory where we find the `slurm-sweeps.db` database. + +#### `Result.experiment` + +```python +@property +def experiment() -> str +``` + +The name of the experiment. + +#### `Result.trials` + +```python +@property +def trials() -> List[Trial] +``` + +A list of the trials of the experiment. + +#### `Result.best_trial` + +```python +def best_trial( + metric: Optional[str] = None, + mode: Optional[str] = None +) -> Trial +``` + +Get the best performing trial of the experiment. + +**Arguments**: + +- `metric` - The metric. By default, we take the one defined by ASHA. +- `mode` - The mode of the metric, either 'min' or 'max'. By default, we take the one defined by ASHA. + +**Returns**: + + The best trial. + +### CLASS `slurm_sweeps.trial.Trial` + +```python +@dataclass +class Trial: + cfg: Dict + process: Optional[subprocess.Popen] = None + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + status: Optional[Union[str, Status]] = None + metrics: Optional[Dict[str, Dict[int, Union[int, float]]]] = None +``` + +A trial of an experiment. + +**Arguments**: + +- `cfg` - The config of the trial. +- `process` - The subprocess that runs the trial. +- `start_time` - The start time of the trial. +- `end_time` - The end time of the trial. +- `status` - Status of the trial. If `process` is not None, we will always query the process for the status. +- `metrics` - Logged metrics of the trial. + +#### `Trial.trial_id` + +```python +@property +def trial_id() -> str +``` + +The trial ID is a 6-digit hash from the config. + +#### `Trial.runtime` + +```python +@property +def runtime() -> Optional[timedelta] +``` + +The runtime of the trial. + +#### `Trial.is_terminated` + +```python +def is_terminated() -> bool +``` + +Return True, if the trial has been completed or pruned. + ### FUNCTION `slurm_sweeps.log` ```python diff --git a/src/slurm_sweeps/__init__.py b/src/slurm_sweeps/__init__.py index e239ed91f24c230b12fa695004db9d03fee9fc7c..dd095d4da1f556df7f679b721202ad1c9d36da94 100644 --- a/src/slurm_sweeps/__init__.py +++ b/src/slurm_sweeps/__init__.py @@ -2,7 +2,7 @@ import logging as _logging from .asha import ASHA from .backends import SlurmCfg -from .experiment import Experiment +from .experiment import Experiment, Result from .logger import log from .sampler import Choice, Grid, LogUniform, Uniform diff --git a/src/slurm_sweeps/backends.py b/src/slurm_sweeps/backends.py index 706a40816398b09ae8557420b35dfa42e716e8c7..83cd6eb473ece879aa2407a22ecc91772dce284a 100644 --- a/src/slurm_sweeps/backends.py +++ b/src/slurm_sweeps/backends.py @@ -9,7 +9,7 @@ from typing import Optional import yaml from .constants import CFG_YML, DB_TRAIN, TRAIN_PY -from .database import SqlDatabase +from .database import Database from .trial import Trial @@ -38,7 +38,7 @@ class Backend: database: The database of the experiment. """ - def __init__(self, execution_dir: Path, database: SqlDatabase): + def __init__(self, execution_dir: Path, database: Database): self._execution_dir = execution_dir self._database = database @@ -78,10 +78,10 @@ class Backend: """The python script that executes the train function.""" template = dedent( f"""\ - from slurm_sweeps.database import SqlDatabase + from slurm_sweeps.database import Database from slurm_sweeps.logger import Logger - database = SqlDatabase("{self._database.experiment}", "{self._database.path}") + database = Database("{self._database.experiment}", "{self._database.path}") train = database.load("{DB_TRAIN}") trial = database.read_trials(trial_id="{trial_id}")[0] @@ -110,7 +110,7 @@ class SlurmBackend(Backend): """ def __init__( - self, execution_dir: Path, database: SqlDatabase, cfg: Optional[SlurmCfg] = None + self, execution_dir: Path, database: Database, cfg: Optional[SlurmCfg] = None ): super().__init__(execution_dir=execution_dir, database=database) diff --git a/src/slurm_sweeps/constants.py b/src/slurm_sweeps/constants.py index 1ead9958d22048487ce241e42541c43b6f53b5d1..b7958f369e879217ba4abf1b07cc1fd096fd8d80 100644 --- a/src/slurm_sweeps/constants.py +++ b/src/slurm_sweeps/constants.py @@ -1,8 +1,3 @@ -# env variables -DB_PATH = "SLURMSWEEPS_DB_PATH" -EXPERIMENT_NAME = "SLURMSWEEPS_EXPERIMENT_NAME" -TRIAL_ID = "SLURMSWEEPS_TRIAL_ID" - # DB tables DB_TRIALS = "_trials" DB_METRICS = "_metrics" diff --git a/src/slurm_sweeps/database.py b/src/slurm_sweeps/database.py index 3f0ebd04783b42e7a4d894cd7cdff516b00b6317..de58f5a0fd34875c3d116c0b8e1ebdd851ba1123 100644 --- a/src/slurm_sweeps/database.py +++ b/src/slurm_sweeps/database.py @@ -28,7 +28,7 @@ from .constants import ( from .trial import Trial -class SqlDatabase: +class Database: """An SQLite database that stores the trials of an experiment and their metrics. It also serves as a storage for pickled objects. @@ -38,7 +38,7 @@ class SqlDatabase: path: The path to the database file. """ - def __init__(self, experiment: str, path: Union[str, Path] = "./slurm_sweeps.db"): + def __init__(self, experiment: str, path: Union[str, Path] = "./slurm-sweeps.db"): self._experiment = experiment self._path = Path(path).resolve() @@ -275,6 +275,21 @@ class SqlDatabase: df = df.replace([None], float("nan")) return df + def get_logged_metrics(self) -> List[str]: + """Returns the names of the logged metrics.""" + with self._connection() as con: + response = con.execute( + f"pragma table_info({self.experiment}{DB_METRICS})" + ).fetchall() + + metrics = [ + col[1].replace(DB_METRIC, "", 1) + for col in response + if (col[1].startswith(DB_METRIC) and not col[1].endswith(DB_LOGGED)) + ] + + return metrics + def dump(self, data: Dict[str, Any]): """Pickles and dumps the data to the storage table. @@ -307,6 +322,8 @@ class SqlDatabase: f"where {DB_EXPERIMENT}='{self.experiment}' and {DB_OBJECT_NAME}='{name}'" ).fetchone() + if response is None: + return None return cloudpickle.loads(response[0]) diff --git a/src/slurm_sweeps/experiment.py b/src/slurm_sweeps/experiment.py index 70c4e8bbe256a59c338d2657e953270fde38ca1f..ad295407c7015ea1d0c8d7932c725f59f506586d 100644 --- a/src/slurm_sweeps/experiment.py +++ b/src/slurm_sweeps/experiment.py @@ -7,6 +7,7 @@ from pathlib import Path from textwrap import dedent from typing import Callable, Dict, List, Optional, Union +import numpy as np import pandas as pd from .asha import ASHA @@ -20,13 +21,107 @@ from .constants import ( DB_TRIAL_ID, WAITING_TIME_IN_SEC, ) -from .database import ExperimentExistsError, ExperimentNotFoundError, SqlDatabase +from .database import Database, ExperimentExistsError, ExperimentNotFoundError from .sampler import Sampler from .trial import Status, Trial _logger = logging.getLogger(__name__) +class Result: + """The result of an experiment. + + Args: + experiment: The name of the experiment. + local_dir: The directory where we find the `slurm-sweeps.db` database. + """ + + def __init__( + self, + experiment: str, + local_dir: Union[str, Path] = "./slurm-sweeps", + ): + self._experiment = experiment + self._database = Database(experiment, Path(local_dir) / "slurm-sweeps.db") + if not self._database.exists(): + raise ExperimentNotFoundError(experiment) + + @property + def experiment(self) -> str: + """The name of the experiment.""" + return self._experiment + + @property + def trials(self) -> List[Trial]: + """A list of the trials of the experiment.""" + trials = self._database.read_trials() + + metrics = self._database.get_logged_metrics() + for metric in metrics: + metric_df = self._database.read_metrics(metric) + metric_by_trial = metric_df.groupby(DB_TRIAL_ID) + for trial in trials: + df = metric_by_trial.get_group(trial.trial_id).sort_values(DB_ITERATION) + trial_metrics = { + metric: { + row[1][DB_ITERATION]: row[1][f"{DB_METRIC}{metric}"] + for row in df.iterrows() + } + } + if trial.metrics is None: + trial.metrics = trial_metrics + else: + trial.metrics.update(trial_metrics) + + return trials + + def best_trial( + self, metric: Optional[str] = None, mode: Optional[str] = None + ) -> Trial: + """Get the best performing trial of the experiment. + + Args: + metric: The metric. By default, we take the one defined by ASHA. + mode: The mode of the metric, either 'min' or 'max'. By default, we take the one defined by ASHA. + + Returns: + The best trial. + """ + asha: Optional[ASHA] = None + if metric is None or mode is None: + asha = self._database.load(DB_ASHA) + + if (metric is None or mode is None) and asha is None: + raise ValueError( + "ASHA was not defined, and you did not specify a `metric` or `mode`. " + "Please specify both of them." + ) + + metric = metric or asha.metric + mode = mode or asha.mode + if mode not in ("min", "max"): + raise ValueError(f"`mode` has to be either 'min' or 'max', but is '{mode}'") + + trials = self.trials + + trial, best_metric = None, np.inf if mode == "min" else -np.inf + for trial in trials: + metrics = np.array(list(trial.metrics.get(metric, {}).values())) + if metrics.size == 0: + continue + if mode == "min" and metrics.min() < best_metric: + best_metric, best_trial = metrics.min(), trial + elif mode == "max" and metrics.max() > best_metric: + best_metric, best_trial = metrics.max(), trial + + if trial is None: + raise ValueError( + f"None of the trials contain the metric '{metric}', cannot determine best trial." + ) + + return trial + + class Experiment: """Set up an HPO experiment. @@ -61,7 +156,7 @@ class Experiment: self._create_experiment_dir(self._local_dir / name, restore, overwrite) - self._database = SqlDatabase(self.name, self.local_dir / "slurm_sweeps.db") + self._database = Database(self.name, self.local_dir / "slurm-sweeps.db") if not restore: self._database.create(overwrite=overwrite) elif not self._database.exists(): @@ -116,7 +211,7 @@ class Experiment: summary_interval_in_sec: float = 5.0, nr_of_rows_in_summary: int = 10, summarize_cfg_and_metrics: Union[bool, List[str]] = True, - ) -> pd.DataFrame: + ) -> Result: """Run the experiment. Args: @@ -172,7 +267,9 @@ class Experiment: # print current summary if (time.time() - time_of_last_summary) > summary_interval_in_sec: self._print_summary( - running_trials + scheduled_trials + terminated_trials, + list( + reversed(scheduled_trials + terminated_trials + running_trials) + ), n_rows=nr_of_rows_in_summary, summarize_cfg_and_metrics=summarize_cfg_and_metrics, ) @@ -183,14 +280,14 @@ class Experiment: time.sleep(WAITING_TIME_IN_SEC) # print final summary - summary_df = self._print_summary( + self._print_summary( terminated_trials, n_rows=None, summarize_cfg_and_metrics=summarize_cfg_and_metrics, sort_by="RUNTIME", ) - return summary_df + return Result(self.name, self.local_dir) def _run_trial(self, trials: List[Trial], trial_nr: int) -> Trial: trial = trials[trial_nr] diff --git a/src/slurm_sweeps/logger.py b/src/slurm_sweeps/logger.py index 2df813739291f9f16d3a924d6ff80e722439097c..5a6f8205dfd2346f7df9d15a37c676d5fcd06296 100644 --- a/src/slurm_sweeps/logger.py +++ b/src/slurm_sweeps/logger.py @@ -2,7 +2,7 @@ from typing import Dict, Optional, Union from .asha import ASHA from .constants import DB_ASHA -from .database import SqlDatabase +from .database import Database class Logger: @@ -17,7 +17,7 @@ class Logger: instance: Optional["Logger"] = None - def __init__(self, trial_id: str, database: SqlDatabase): + def __init__(self, trial_id: str, database: Database): self._trial_id = trial_id self._database = database diff --git a/src/slurm_sweeps/trial.py b/src/slurm_sweeps/trial.py index 0686ba7fd4b63467c8d10f13ff9c95bfcedb488c..272fda8894afe9fbff94883d8844a260ebc0cc10 100644 --- a/src/slurm_sweeps/trial.py +++ b/src/slurm_sweeps/trial.py @@ -33,6 +33,7 @@ class Trial: start_time: The start time of the trial. end_time: The end time of the trial. status: Status of the trial. If `process` is not None, we will always query the process for the status. + metrics: Logged metrics of the trial. """ cfg: Dict @@ -41,6 +42,7 @@ class Trial: end_time: Optional[datetime] = None status: Optional[Union[str, Status]] = None _status: Optional[Status] = field(init=False, repr=False) + metrics: Optional[Dict[str, Dict[int, Union[int, float]]]] = None @property def trial_id(self) -> str: @@ -72,7 +74,7 @@ class Trial: @property def runtime(self) -> Optional[timedelta]: """The runtime of the trial.""" - if self.end_time is not None: + if self.end_time is not None and self.start_time is not None: return self.end_time - self.start_time if self.start_time is not None: return datetime.now() - self.start_time diff --git a/tests/test_database.py b/tests/test_database.py index 48d963eeff1cb6b37d401c0601ef7f6d6d18e94a..73481358d7e3b77ea06dba7a9cfa3c91c72f0d97 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -28,9 +28,9 @@ from slurm_sweeps.constants import ( DB_TRIALS, ) from slurm_sweeps.database import ( + Database, ExperimentExistsError, ExperimentNotFoundError, - SqlDatabase, ) from slurm_sweeps.trial import Trial @@ -45,9 +45,9 @@ def connection(path: Path): @pytest.fixture -def database(tmp_path) -> SqlDatabase: +def database(tmp_path) -> Database: db_path = tmp_path / "slurm_sweeps.db" - return SqlDatabase(experiment="test_experiment", path=db_path) + return Database(experiment="test_experiment", path=db_path) def test_init(database): @@ -79,6 +79,7 @@ def test_dump_load(database): assert response[0] == 2 assert database.load("a") == {} assert database.load("b") is None + assert database.load("c") is None def test_exists(database): @@ -248,7 +249,7 @@ def test_write_read_metrics(database): ) -def read_or_write(mode: str, database: SqlDatabase): +def read_or_write(mode: str, database: Database): if mode == "w": database.write_metrics(trial_id="test", iteration=0, metrics={"loss": 0.9}) else: @@ -284,8 +285,9 @@ def test_nan_values(database): assert np.isnan(df[f"{DB_METRIC}loss"].iloc[0]) -@pytest.mark.skip("Only for speed comparisons") +@pytest.mark.skip("Only for speed comparisons (OUTDATED!)") def test_speed(monkeypatch, database): + # TODO: Update! from slurm_sweeps.constants import DB_PATH, EXPERIMENT_NAME from slurm_sweeps.logger import Logger diff --git a/tests/test_readme.py b/tests/test_readme.py index 2a1741aef3a359a6efc830a2697481dc36132fa6..760b3b307a82d9694d51b89e2c01faadbd72cc3e 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -3,8 +3,7 @@ from time import sleep import pytest -from slurm_sweeps.constants import DB_ITERATION -from slurm_sweeps.database import SqlDatabase +from slurm_sweeps import Result def is_slurm_available() -> bool: @@ -25,7 +24,7 @@ def test_readme_example_on_local(tmp_path): def train(cfg: dict): for epoch in range(cfg["epochs"]): sleep(0.5) - loss = (cfg["parameter"] - 1) ** 2 * epoch + loss = (cfg["parameter"] - 1) ** 2 / (epoch + 1) # log your metrics ss.log({"loss": loss}, epoch) @@ -41,13 +40,13 @@ def test_readme_example_on_local(tmp_path): ) # Run your experiment - dataframe = experiment.run(n_trials=10) + result = experiment.run(n_trials=10) - # Your results are stored in a pandas DataFrame - print(f"\nBest trial:\n{dataframe.sort_values('loss').iloc[0]}") + assert len(result.trials) == 10 - assert len(dataframe) == 10 - assert dataframe["ITERATION"].sort_values().iloc[-1] == 9 + trial = result.best_trial() + + assert trial.metrics["loss"][9] < 0.05 @pytest.mark.skipif(not is_slurm_available(), reason="requires a SLURM cluster") @@ -104,11 +103,16 @@ python train.py # check output job_out = subprocess.check_output(["cat", "slurm-3.out"], cwd=tmp_path) - dataframe = SqlDatabase("MySweep", local_dir / "slurm_sweeps.db").read_metrics() + result = Result("MySweep", local_dir) # Relax until issue with slurm GitHub action is fixed: https://github.com/koesterlab/setup-slurm-action/issues/4 assert ("max number of concurrent trials: 2" in job_out.decode()[50:]) or ( "max number of concurrent trials: 4" in job_out.decode()[50:] ) - assert len(dataframe) > 10 - assert dataframe[DB_ITERATION].sort_values().iloc[-1] == 9 + assert len(result.trials) == 10 + + trial = result.best_trial() + print(result.trials) + print(trial) + + assert trial.metrics["loss"][9] < 0.05