Commit ca0e54fd authored by Nastassya Horlava

fixed

parent 52c9da4a
1 merge request: !4 Docs tensorflow
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.11.2
    hooks:
      # Run the linter.
      - id: ruff
        args: [ --fix ]
      # Run the formatter.
      - id: ruff-format
[tool.ruff]
line-length = 88
[tool.ruff.lint.pycodestyle]
max-doc-length = 88
max-line-length = 88
[tool.ruff.lint]
extend-select = ["I", "W505"]
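# Illustrative aside, not part of this commit: with line-length = 88 and the import-
# sorting rule "I" enabled above, the linter wraps long calls and reorders imports,
# which is exactly the kind of mechanical change visible in the Python hunks below.
# A minimal sketch of running the same tools outside of pre-commit (the subprocess
# approach and the target path "." are assumptions):
import subprocess

subprocess.run(["ruff", "check", "--fix", "."], check=False)  # lint + autofix
subprocess.run(["ruff", "format", "."], check=False)  # reformat to 88 columns
subprocess.run(["pre-commit", "run", "--all-files"], check=False)  # run all hooks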
@@ -5,19 +5,14 @@ import os
from contextlib import nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from time import perf_counter
from typing import Any, Dict, Optional, Union
import click
import mlflow
from mlflow_utils import (
MLflowMetricsCallback,
MlflowTimingCallback,
TimingCallback,
mlflow_log_sbatch_logs,
mlflow_log_sbatch_scripts,
)
import pandas as pd
import tensorflow as tf
logger = logging.getLogger(__name__)
@@ -106,11 +101,11 @@ class TimingCallback(tf.keras.callbacks.Callback):
)
def set_seed(seed: int = 5):
import random
import numpy as np
if not isinstance(seed, int):
raise ValueError(
"Expected `seed` argument to be an integer. "
@@ -122,7 +117,6 @@ def set_seed(seed: int = 5):
tf.random.set_seed(seed)
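# Illustrative aside, not part of this commit: the middle of set_seed() is elided by
# the hunk above. A minimal sketch of what the full function plausibly does, seeding
# every random source imported there (the exact body is an assumption):
def set_seed_sketch(seed: int = 5) -> None:
    import random

    import numpy as np

    if not isinstance(seed, int):
        raise ValueError("Expected `seed` argument to be an integer.")
    random.seed(seed)  # Python's built-in RNG
    np.random.seed(seed)  # NumPy RNG
    tf.random.set_seed(seed)  # TensorFlow global RNG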
class NullStrategy:
@staticmethod
def scope():
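# Illustrative aside, not part of this commit: the body of NullStrategy.scope() is
# elided above. A null-object strategy for single-device runs would plausibly reuse
# the nullcontext imported at the top of the file (num_replicas_in_sync = 1 is an
# assumption mirroring the tf.distribute API surface):
class NullStrategySketch:
    num_replicas_in_sync = 1

    @staticmethod
    def scope():
        return nullcontext()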
@@ -153,11 +147,11 @@ class TrainingStrategy:
communication_type: Union[
str, tf.distribute.experimental.CommunicationImplementation
] = field(init=False)
- cross_device_communication_type: Union[
- str, tf.distribute.CrossDeviceOps
- ] = field(init=False)
+ cross_device_communication_type: Union[str, tf.distribute.CrossDeviceOps] = field(
+ init=False
+ )
communication_options: Optional[tf.distribute.experimental.CommunicationOptions] = (
field(default=None, init=False)
)
@@ -224,9 +218,12 @@ class TrainingStrategy:
def _use_single_node_multi_gpu_strategy(self) -> None:
self.strategy_type = "MirroredStrategy"
self.communication_type = NullCommunication()
- self.cross_device_communication_type = self._get_cross_device_ops_implementation(self.device_type)
- self.strategy = tf.distribute.MirroredStrategy(cross_device_ops = self.cross_device_communication_type)
+ self.cross_device_communication_type = (
+ self._get_cross_device_ops_implementation(self.device_type)
+ )
+ self.strategy = tf.distribute.MirroredStrategy(
+ cross_device_ops=self.cross_device_communication_type
+ )
def _use_multi_node_strategy(self) -> None:
self.cross_device_communication_type = NullCommunication()
@@ -242,7 +239,7 @@ class TrainingStrategy:
self.strategy = tf.distribute.MultiWorkerMirroredStrategy(
communication_options=self.communication_options
)
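# Illustrative aside, not part of this commit: a minimal sketch of how the
# communication_options passed above are typically built, assuming NCCL collectives
# are wanted for multi-node training on NVIDIA GPUs:
options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
)
strategy = tf.distribute.MultiWorkerMirroredStrategy(communication_options=options)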
def _get_cross_device_ops_implementation(self, device_type: str):
"""Map device type to appropriate communication implementation."""
if device_type == "NVIDIA":
@@ -266,12 +263,13 @@ class TrainingStrategy:
)
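# Illustrative aside, not part of this commit: the body of
# _get_cross_device_ops_implementation() is mostly elided above. A plausible mapping
# from device type to cross-device ops (the fallback choice is an assumption):
def get_cross_device_ops_sketch(device_type: str) -> tf.distribute.CrossDeviceOps:
    if device_type == "NVIDIA":
        return tf.distribute.NcclAllReduce()  # NCCL collectives on NVIDIA GPUs
    # Generic fallback that works without NCCL support.
    return tf.distribute.HierarchicalCopyAllReduce()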
def _log_strategy_params(self) -> None:
"""Log key strategy configuration to MLflow."""
"""Log key strategy configuration"""
logger.info(f"num_replicas_in_sync = {self.strategy.num_replicas_in_sync}")
logger.info(f"strategy_type = {self.strategy_type}")
logger.info(f"communication_type = {self.communication_type.name}")
logger.info(f"cross_device_communication_type = {type(self.cross_device_communication_type)}")
logger.info(
f"cross_device_communication_type = {type(self.cross_device_communication_type)}"
)
@dataclass
@@ -345,7 +343,6 @@ class SYNTH_classifier:
return model
def prepare_dataset(self):
@tf.function
def gen_fn(_):
image = tf.random.uniform([224, 224, 3])
@@ -360,10 +357,9 @@ class SYNTH_classifier:
return dataset
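# Illustrative aside, not part of this commit: most of prepare_dataset() is elided
# above. A minimal sketch of a synthetic-image tf.data pipeline consistent with the
# visible gen_fn (the label generation, sample count, and batching are assumptions):
def prepare_dataset_sketch(num_samples: int, batch_size: int) -> "tf.data.Dataset":
    @tf.function
    def gen_fn(_):
        image = tf.random.uniform([224, 224, 3])
        label = tf.random.uniform([], minval=0, maxval=1000, dtype=tf.int64)
        return image, label

    dataset = tf.data.Dataset.range(num_samples)
    dataset = dataset.map(gen_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)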
def train(self):
self.train_dataset = self.prepare_dataset()
logger.info(f"train_dataset: {type(self.train_dataset)}")
# Define distributed strategy
if self.opts.distributed:
self.train_dataset = (
@@ -372,7 +368,8 @@
)
)
- # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of distributed training, or just NullStrategy instead.
+ # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of
+ # distributed training, or just NullStrategy instead.
with self.opts.training_strategy.strategy.scope():
self.model = self.get_compiled_model()
@@ -390,7 +387,7 @@
TimingCallback(
batch_size=self.opts.global_batch_size,
log_freq=self.opts.timing_log_freq,
rank=int(os.environ["RANK"]),
rank=int(os.environ.get("RANK", 0)),
num_warmup_batches=self.opts.timing_warmup_batches,
),
]
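# Illustrative aside, not part of this commit: the one-line change above swaps
# os.environ["RANK"] for os.environ.get("RANK", 0), so single-process runs where the
# launcher exports no RANK variable default to rank 0 instead of raising a KeyError:
rank = int(os.environ.get("RANK", 0))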
@@ -401,23 +398,30 @@
test_loss, test_acc = self.model.evaluate(test_dataset, verbose=0)
return test_loss, test_acc
@click.group()
def cli():
pass
- @click.command(no_args_is_help=True)
- @click.option("--run_cfg", type=click.Path(exists=True))
- @click.option("--batch_size_per_device", type=int, default=None)
+ @cli.command(no_args_is_help=False)
+ @click.option("--batch_size_per_device", type=int, default=256)
+ @click.option("--run_cfg", type=click.Path(exists=True), default=None)
def train(
- run_cfg,
batch_size_per_device,
+ run_cfg,
):
- training_options = TrainingOptions.from_yaml(
- cfg_path=run_cfg,
- cli_kwargs=dict(
+ if run_cfg is not None:
+ training_options = TrainingOptions.from_yaml(
+ cfg_path=run_cfg,
+ cli_kwargs=dict(
+ batch_size_per_device=batch_size_per_device,
+ ),
+ )
+ else:
+ training_options = TrainingOptions(
batch_size_per_device=batch_size_per_device,
- ),
- )
+ )
set_seed(training_options.seed)
mnist_classifier = SYNTH_classifier(opts=training_options)
@@ -432,5 +436,5 @@ if __name__ == "__main__":
urllib3_logger.setLevel(logging.WARNING)
simple_parsing_logger = logging.getLogger("simple_parsing")
simple_parsing_logger.setLevel(logging.INFO)
cli()
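# Illustrative aside, not part of this commit: with the reworked options above, the
# train command can now run without a YAML config and fall back to the default batch
# size. A minimal invocation sketch using click's test runner (note that this would
# actually start training):
from click.testing import CliRunner

runner = CliRunner()
result = runner.invoke(cli, ["train", "--batch_size_per_device", "32"])
print(result.exit_code)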