diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..52d9ab137634ffb42324f41514882527736f3223
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,16 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.6.0
+  hooks:
+    - id: check-yaml
+    - id: end-of-file-fixer
+    - id: trailing-whitespace
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.11.2
+  hooks:
+    # Run the linter.
+    - id: ruff
+      args: [ --fix ]
+    # Run the formatter.
+    - id: ruff-format
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b0078edcbd7495eb00660d99c3fbdd4b8bd566ea
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,9 @@
+[tool.ruff]
+line-length = 88
+
+[tool.ruff.lint]
+extend-select = ["I", "W505"]
+
+[tool.ruff.lint.pycodestyle]
+max-doc-length = 88
+max-line-length = 88
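
For reference, the two opted-in rule groups: `I` enables isort-style import sorting and `W505` flags doc lines longer than `max-doc-length`. A toy snippet that trips both under the config above (hypothetical code, for illustration):

```python
# I001: this third-party block is unsorted; `ruff --fix` would reorder it.
import tensorflow as tf
import click


def train():
    """W505 fires on doc lines past max-doc-length, so a docstring line as long as this one is flagged."""
```
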
diff --git a/tensorflow/src/train_synthetic.py b/tensorflow/src/train_synthetic.py
index ce1b075f1a64e3dda2ea0ba7e55ccb0ca71b5df0..f725e060ee81be6fe587070ff3a3a4532292864e 100644
--- a/tensorflow/src/train_synthetic.py
+++ b/tensorflow/src/train_synthetic.py
@@ -5,19 +5,14 @@ import os
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from pathlib import Path
+from time import perf_counter
 from typing import Any, Dict, Optional, Union
 
 import click
-import mlflow
-from mlflow_utils import (
-    MLflowMetricsCallback,
-    MlflowTimingCallback,
-    TimingCallback,
-    mlflow_log_sbatch_logs,
-    mlflow_log_sbatch_scripts,
-)
+import pandas as pd
 
 import tensorflow as tf
+
 logger = logging.getLogger(__name__)
 
 
@@ -106,11 +101,11 @@ class TimingCallback(tf.keras.callbacks.Callback):
             )
 
 
-
 def set_seed(seed: int = 5):
     import random
+
     import numpy as np
-    
+
     if not isinstance(seed, int):
         raise ValueError(
             "Expected `seed` argument to be an integer. "
@@ -122,7 +117,6 @@ def set_seed(seed: int = 5):
     tf.random.set_seed(seed)
 
 
-
 class NullStrategy:
     @staticmethod
     def scope():
@@ -153,11 +147,11 @@ class TrainingStrategy:
     communication_type: Union[
         str, tf.distribute.experimental.CommunicationImplementation
     ] = field(init=False)
-    
-    cross_device_communication_type: Union[
-        str, tf.distribute.CrossDeviceOps
-    ] = field(init=False)
-    
+
+    cross_device_communication_type: Union[str, tf.distribute.CrossDeviceOps] = field(
+        init=False
+    )
+
     communication_options: Optional[tf.distribute.experimental.CommunicationOptions] = (
         field(default=None, init=False)
     )
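
All of these `field(init=False)` members are derived by the class itself rather than accepted as constructor arguments. The idiom, reduced to a minimal example (names hypothetical):

```python
from dataclasses import dataclass, field


@dataclass
class Example:
    device_type: str                        # supplied by the caller
    strategy_type: str = field(init=False)  # filled in by the class

    def __post_init__(self):
        self.strategy_type = "Mirrored" if self.device_type == "NVIDIA" else "Null"


Example(device_type="NVIDIA")  # ok; passing strategy_type= raises TypeError
```
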
@@ -224,9 +218,12 @@ class TrainingStrategy:
     def _use_single_node_multi_gpu_strategy(self) -> None:
         self.strategy_type = "MirroredStrategy"
         self.communication_type = NullCommunication()
-        self.cross_device_communication_type =  self._get_cross_device_ops_implementation(self.device_type)
-        self.strategy = tf.distribute.MirroredStrategy(cross_device_ops = self.cross_device_communication_type)
-        
+        self.cross_device_communication_type = (
+            self._get_cross_device_ops_implementation(self.device_type)
+        )
+        self.strategy = tf.distribute.MirroredStrategy(
+            cross_device_ops=self.cross_device_communication_type
+        )
 
     def _use_multi_node_strategy(self) -> None:
         self.cross_device_communication_type = NullCommunication()
@@ -242,7 +239,7 @@ class TrainingStrategy:
         self.strategy = tf.distribute.MultiWorkerMirroredStrategy(
             communication_options=self.communication_options
         )
-    
+
     def _get_cross_device_ops_implementation(self, device_type: str):
         """Map device type to appropriate communication implementation."""
         if device_type == "NVIDIA":
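
Only the `NVIDIA` branch of `_get_cross_device_ops_implementation` is visible; presumably it maps to NCCL. A plausible shape for the whole mapping (the non-NVIDIA fallback is an assumption, not necessarily the file's choice):

```python
import tensorflow as tf


def get_cross_device_ops(device_type: str):
    if device_type == "NVIDIA":
        # NCCL all-reduce is the usual choice for NVIDIA GPUs on one node.
        return tf.distribute.NcclAllReduce()
    # Generic fallback: copy tensors to one device and reduce there.
    return tf.distribute.ReductionToOneDevice()
```
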
@@ -266,12 +263,14 @@
             )
 
     def _log_strategy_params(self) -> None:
-        """Log key strategy configuration to MLflow."""
+        """Log key strategy configuration"""
         logger.info(f"num_replicas_in_sync = {self.strategy.num_replicas_in_sync}")
         logger.info(f"strategy_type = {self.strategy_type}")
         logger.info(f"communication_type = {self.communication_type.name}")
-        logger.info(f"cross_device_communication_type = {type(self.cross_device_communication_type)}")
-        
+        logger.info(
+            "cross_device_communication_type = "
+            f"{type(self.cross_device_communication_type)}"
+        )
 
 
 @dataclass
@@ -345,7 +343,6 @@ class SYNTH_classifier:
         return model
 
     def prepare_dataset(self):
-        
         @tf.function
         def gen_fn(_):
             image = tf.random.uniform([224, 224, 3])
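
Only the generator head of `prepare_dataset` is in view. A sketch of the likely full pipeline (sample count, label range, and batching are assumptions):

```python
import tensorflow as tf


def prepare_dataset(num_samples: int = 1024, batch_size: int = 256):
    @tf.function
    def gen_fn(_):
        image = tf.random.uniform([224, 224, 3])                    # fake RGB image
        label = tf.random.uniform([], maxval=1000, dtype=tf.int32)  # fake class id
        return image, label

    dataset = tf.data.Dataset.range(num_samples)
    dataset = dataset.map(gen_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```
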
@@ -360,10 +357,9 @@ class SYNTH_classifier:
         return dataset
 
     def train(self):
-        
         self.train_dataset = self.prepare_dataset()
         logger.info(f"train_dataset: {type(self.train_dataset)}")
-        
+
         # Define distributed strategy
         if self.opts.distributed:
             self.train_dataset = (
@@ -372,7 +368,8 @@ class SYNTH_classifier:
                 )
             )
 
-        # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of distributed training, or just NullStrategy instead.
+        # Use a MirroredStrategy or MultiWorkerMirroredStrategy for distributed
+        # training; otherwise fall back to the no-op NullStrategy.
         with self.opts.training_strategy.strategy.scope():
             self.model = self.get_compiled_model()
 
@@ -390,7 +387,7 @@ class SYNTH_classifier:
             TimingCallback(
                 batch_size=self.opts.global_batch_size,
                 log_freq=self.opts.timing_log_freq,
-                rank=int(os.environ["RANK"]),
+                rank=int(os.environ.get("RANK", 0)),
                 num_warmup_batches=self.opts.timing_warmup_batches,
             ),
         ]
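
The `os.environ.get` change is what lets the script run outside a launcher: distributed launchers typically export `RANK`, while a plain `python train_synthetic.py` does not. In short:

```python
import os

# Old: KeyError as soon as no launcher has exported RANK.
# rank = int(os.environ["RANK"])

# New: falls back to rank 0, so single-process runs work unchanged.
rank = int(os.environ.get("RANK", 0))
```
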
@@ -401,23 +398,30 @@ class SYNTH_classifier:
         test_loss, test_acc = self.model.evaluate(test_dataset, verbose=0)
         return test_loss, test_acc
 
+
 @click.group()
 def cli():
     pass
 
-@click.command(no_args_is_help=True)
-@click.option("--run_cfg", type=click.Path(exists=True))
-@click.option("--batch_size_per_device", type=int, default=None)
+
+@cli.command(no_args_is_help=False)
+@click.option("--batch_size_per_device", type=int, default=256)
+@click.option("--run_cfg", type=click.Path(exists=True), default=None)
 def train(
-    run_cfg,
     batch_size_per_device,
+    run_cfg,
 ):
-    training_options = TrainingOptions.from_yaml(
-        cfg_path=run_cfg,
-        cli_kwargs=dict(
+    if run_cfg is not None:
+        training_options = TrainingOptions.from_yaml(
+            cfg_path=run_cfg,
+            cli_kwargs=dict(
+                batch_size_per_device=batch_size_per_device,
+            ),
+        )
+    else:
+        training_options = TrainingOptions(
-            batch_size_per_device=batch_size_per_device,
-        ),
-    )
+            batch_size_per_device=batch_size_per_device or 256,
+        )
 
     set_seed(training_options.seed)
     mnist_classifier = SYNTH_classifier(opts=training_options)
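
With `run_cfg` optional and the batch size falling back to 256, `train` now runs with no arguments. A quick smoke test of the CLI wiring via click's test runner (assumes the module is importable as `train_synthetic`):

```python
from click.testing import CliRunner

from train_synthetic import cli  # assumption: script directory on sys.path

result = CliRunner().invoke(cli, ["train", "--batch_size_per_device", "32"])
print(result.exit_code)  # 0 on success
print(result.output)
```
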
@@ -432,5 +436,5 @@ if __name__ == "__main__":
     urllib3_logger.setLevel(logging.WARNING)
     simple_parsing_logger = logging.getLogger("simple_parsing")
     simple_parsing_logger.setLevel(logging.INFO)
-    
+
     cli()
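
The `setLevel` calls in `__main__` only matter once a handler exists; presumably that configuration lives in the elided lines. For a standalone run, something like this ahead of them would surface the `logger.info` strategy/timing messages (a sketch, not necessarily the file's code):

```python
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)
```
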