diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52d9ab137634ffb42324f41514882527736f3223 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.11.2 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..b0078edcbd7495eb00660d99c3fbdd4b8bd566ea --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[tool.ruff] +line-length = 88 + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 +max-line-length = 88 + + +[tool.ruff.lint] +extend-select = ["I", "W505"] diff --git a/tensorflow/src/train_synthetic.py b/tensorflow/src/train_synthetic.py index ce1b075f1a64e3dda2ea0ba7e55ccb0ca71b5df0..f725e060ee81be6fe587070ff3a3a4532292864e 100644 --- a/tensorflow/src/train_synthetic.py +++ b/tensorflow/src/train_synthetic.py @@ -5,19 +5,14 @@ import os from contextlib import nullcontext from dataclasses import dataclass, field from pathlib import Path +from time import perf_counter from typing import Any, Dict, Optional, Union import click -import mlflow -from mlflow_utils import ( - MLflowMetricsCallback, - MlflowTimingCallback, - TimingCallback, - mlflow_log_sbatch_logs, - mlflow_log_sbatch_scripts, -) +import pandas as pd import tensorflow as tf + logger = logging.getLogger(__name__) @@ -106,11 +101,11 @@ class TimingCallback(tf.keras.callbacks.Callback): ) - def set_seed(seed: int = 5): import random + import numpy as np - + if not isinstance(seed, int): raise ValueError( "Expected `seed` argument to be an integer. " @@ -122,7 +117,6 @@ def set_seed(seed: int = 5): tf.random.set_seed(seed) - class NullStrategy: @staticmethod def scope(): @@ -153,11 +147,11 @@ class TrainingStrategy: communication_type: Union[ str, tf.distribute.experimental.CommunicationImplementation ] = field(init=False) - - cross_device_communication_type: Union[ - str, tf.distribute.CrossDeviceOps - ] = field(init=False) - + + cross_device_communication_type: Union[str, tf.distribute.CrossDeviceOps] = field( + init=False + ) + communication_options: Optional[tf.distribute.experimental.CommunicationOptions] = ( field(default=None, init=False) ) @@ -224,9 +218,12 @@ class TrainingStrategy: def _use_single_node_multi_gpu_strategy(self) -> None: self.strategy_type = "MirroredStrategy" self.communication_type = NullCommunication() - self.cross_device_communication_type = self._get_cross_device_ops_implementation(self.device_type) - self.strategy = tf.distribute.MirroredStrategy(cross_device_ops = self.cross_device_communication_type) - + self.cross_device_communication_type = ( + self._get_cross_device_ops_implementation(self.device_type) + ) + self.strategy = tf.distribute.MirroredStrategy( + cross_device_ops=self.cross_device_communication_type + ) def _use_multi_node_strategy(self) -> None: self.cross_device_communication_type = NullCommunication() @@ -242,7 +239,7 @@ class TrainingStrategy: self.strategy = tf.distribute.MultiWorkerMirroredStrategy( communication_options=self.communication_options ) - + def _get_cross_device_ops_implementation(self, device_type: str): """Map device type to appropriate communication implementation.""" if device_type == "NVIDIA": @@ -266,12 +263,13 @@ class TrainingStrategy: ) def _log_strategy_params(self) -> None: - """Log key strategy configuration to MLflow.""" + """Log key strategy configuration""" logger.info(f"num_replicas_in_sync = {self.strategy.num_replicas_in_sync}") logger.info(f"strategy_type = {self.strategy_type}") logger.info(f"communication_type = {self.communication_type.name}") - logger.info(f"cross_device_communication_type = {type(self.cross_device_communication_type)}") - + logger.info( + f"cross_device_communication_type = {type(self.cross_device_communication_type)}" + ) @dataclass @@ -345,7 +343,6 @@ class SYNTH_classifier: return model def prepare_dataset(self): - @tf.function def gen_fn(_): image = tf.random.uniform([224, 224, 3]) @@ -360,10 +357,9 @@ class SYNTH_classifier: return dataset def train(self): - self.train_dataset = self.prepare_dataset() logger.info(f"train_dataset: {type(self.train_dataset)}") - + # Define distributed strategy if self.opts.distributed: self.train_dataset = ( @@ -372,7 +368,8 @@ class SYNTH_classifier: ) ) - # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of distributed training, or just NullStrategy instead. + # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of + # distributed training, or just NullStrategy instead. with self.opts.training_strategy.strategy.scope(): self.model = self.get_compiled_model() @@ -390,7 +387,7 @@ class SYNTH_classifier: TimingCallback( batch_size=self.opts.global_batch_size, log_freq=self.opts.timing_log_freq, - rank=int(os.environ["RANK"]), + rank=int(os.environ.get("RANK", 0)), num_warmup_batches=self.opts.timing_warmup_batches, ), ] @@ -401,23 +398,30 @@ class SYNTH_classifier: test_loss, test_acc = self.model.evaluate(test_dataset, verbose=0) return test_loss, test_acc + @click.group() def cli(): pass -@click.command(no_args_is_help=True) -@click.option("--run_cfg", type=click.Path(exists=True)) -@click.option("--batch_size_per_device", type=int, default=None) + +@cli.command(no_args_is_help=False) +@click.option("--batch_size_per_device", type=int, default=256) +@click.option("--run_cfg", type=click.Path(exists=True), default=None) def train( - run_cfg, batch_size_per_device, + run_cfg, ): - training_options = TrainingOptions.from_yaml( - cfg_path=run_cfg, - cli_kwargs=dict( + if run_cfg is not None: + training_options = TrainingOptions.from_yaml( + cfg_path=run_cfg, + cli_kwargs=dict( + batch_size_per_device=batch_size_per_device, + ), + ) + else: + training_options = TrainingOptions( batch_size_per_device=batch_size_per_device, - ), - ) + ) set_seed(training_options.seed) mnist_classifier = SYNTH_classifier(opts=training_options) @@ -432,5 +436,5 @@ if __name__ == "__main__": urllib3_logger.setLevel(logging.WARNING) simple_parsing_logger = logging.getLogger("simple_parsing") simple_parsing_logger.setLevel(logging.INFO) - + cli()