Commit ca0e54fd authored by Nastassya Horlava

fixed

parent 52c9da4a
1 merge request: !4 Docs tensorflow
.pre-commit-config.yaml

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.11.2
    hooks:
      # Run the linter.
      - id: ruff
        args: [ --fix ]
      # Run the formatter.
      - id: ruff-format
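Running pre-commit install once registers these hooks in the local checkout; pre-commit run --all-files applies them to the entire tree on demand.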
pyproject.toml

[tool.ruff]
line-length = 88

[tool.ruff.lint.pycodestyle]
max-doc-length = 88
max-line-length = 88

[tool.ruff.lint]
extend-select = ["I", "W505"]
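This caps code and docstring lines at 88 characters. extend-select = ["I", "W505"] enables Ruff's isort-compatible import-sorting rules (the I family) and doc-line-too-long (W505); W505 only fires when max-doc-length is set, as it is here.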
@@ -5,19 +5,14 @@ import os
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from pathlib import Path
+from time import perf_counter
 from typing import Any, Dict, Optional, Union

 import click
-import mlflow
-from mlflow_utils import (
-    MLflowMetricsCallback,
-    MlflowTimingCallback,
-    TimingCallback,
-    mlflow_log_sbatch_logs,
-    mlflow_log_sbatch_scripts,
-)
+import pandas as pd
 import tensorflow as tf

 logger = logging.getLogger(__name__)
@@ -106,9 +101,9 @@ class TimingCallback(tf.keras.callbacks.Callback):
 )


 def set_seed(seed: int = 5):
     import random

     import numpy as np

     if not isinstance(seed, int):
@@ -122,7 +117,6 @@ def set_seed(seed: int = 5):
     tf.random.set_seed(seed)


-
 class NullStrategy:
     @staticmethod
     def scope():
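NullStrategy is truncated in the hunk above. Given the nullcontext import at the top of the file, its scope() presumably returns a no-op context manager so that single-device runs share the same with strategy.scope(): code path as distributed runs. A minimal sketch of that pattern follows; the pick_strategy helper is hypothetical and merely stands in for the dispatch that TrainingStrategy performs:

from contextlib import nullcontext

import tensorflow as tf


class NullStrategy:
    # Stand-in that satisfies the strategy interface for single-device runs.
    @staticmethod
    def scope():
        # No-op context manager, mirroring tf.distribute.Strategy.scope().
        return nullcontext()


def pick_strategy(num_nodes: int, num_gpus: int):
    # Hypothetical dispatcher illustrating the selection logic.
    if num_nodes > 1:
        return tf.distribute.MultiWorkerMirroredStrategy()
    if num_gpus > 1:
        return tf.distribute.MirroredStrategy()
    return NullStrategy()


strategy = pick_strategy(num_nodes=1, num_gpus=0)
with strategy.scope():  # no-op here; a real device scope under MirroredStrategy
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])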
@@ -154,9 +148,9 @@ class TrainingStrategy:
         str, tf.distribute.experimental.CommunicationImplementation
     ] = field(init=False)
-    cross_device_communication_type: Union[
-        str, tf.distribute.CrossDeviceOps
-    ] = field(init=False)
+    cross_device_communication_type: Union[str, tf.distribute.CrossDeviceOps] = field(
+        init=False
+    )
     communication_options: Optional[tf.distribute.experimental.CommunicationOptions] = (
         field(default=None, init=False)
@@ -224,9 +218,12 @@ class TrainingStrategy:
     def _use_single_node_multi_gpu_strategy(self) -> None:
         self.strategy_type = "MirroredStrategy"
         self.communication_type = NullCommunication()
-        self.cross_device_communication_type = self._get_cross_device_ops_implementation(self.device_type)
-        self.strategy = tf.distribute.MirroredStrategy(cross_device_ops = self.cross_device_communication_type)
+        self.cross_device_communication_type = (
+            self._get_cross_device_ops_implementation(self.device_type)
+        )
+        self.strategy = tf.distribute.MirroredStrategy(
+            cross_device_ops=self.cross_device_communication_type
+        )

     def _use_multi_node_strategy(self) -> None:
         self.cross_device_communication_type = NullCommunication()
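A note on _get_cross_device_ops_implementation (not shown in this diff): MirroredStrategy's cross_device_ops parameter takes an instance such as tf.distribute.NcclAllReduce() or tf.distribute.HierarchicalCopyAllReduce(), so the helper presumably maps device_type to one of those implementations.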
@@ -266,12 +263,13 @@ class TrainingStrategy:
         )

     def _log_strategy_params(self) -> None:
-        """Log key strategy configuration to MLflow."""
+        """Log key strategy configuration"""
         logger.info(f"num_replicas_in_sync = {self.strategy.num_replicas_in_sync}")
         logger.info(f"strategy_type = {self.strategy_type}")
         logger.info(f"communication_type = {self.communication_type.name}")
-        logger.info(f"cross_device_communication_type = {type(self.cross_device_communication_type)}")
+        logger.info(
+            f"cross_device_communication_type = {type(self.cross_device_communication_type)}"
+        )


 @dataclass
@@ -345,7 +343,6 @@ class SYNTH_classifier:
         return model

     def prepare_dataset(self):
-
         @tf.function
         def gen_fn(_):
             image = tf.random.uniform([224, 224, 3])
@@ -360,7 +357,6 @@ class SYNTH_classifier:
         return dataset

     def train(self):
-
         self.train_dataset = self.prepare_dataset()
         logger.info(f"train_dataset: {type(self.train_dataset)}")
@@ -372,7 +368,8 @@ class SYNTH_classifier:
             )
         )
-        # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of distributed training, or just NullStrategy instead.
+        # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of
+        # distributed training, or just NullStrategy instead.
         with self.opts.training_strategy.strategy.scope():
             self.model = self.get_compiled_model()
@@ -390,7 +387,7 @@ class SYNTH_classifier:
             TimingCallback(
                 batch_size=self.opts.global_batch_size,
                 log_freq=self.opts.timing_log_freq,
-                rank=int(os.environ["RANK"]),
+                rank=int(os.environ.get("RANK", 0)),
                 num_warmup_batches=self.opts.timing_warmup_batches,
             ),
         ]
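Note on this hunk: os.environ["RANK"] raises KeyError when RANK is not exported by the launcher (for example, in a plain single-process run), while os.environ.get("RANK", 0) falls back to rank 0 instead of crashing.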
@@ -401,23 +398,30 @@ class SYNTH_classifier:
         test_loss, test_acc = self.model.evaluate(test_dataset, verbose=0)
         return test_loss, test_acc


 @click.group()
 def cli():
     pass


-@click.command(no_args_is_help=True)
-@click.option("--run_cfg", type=click.Path(exists=True))
-@click.option("--batch_size_per_device", type=int, default=None)
+@cli.command(no_args_is_help=False)
+@click.option("--batch_size_per_device", type=int, default=256)
+@click.option("--run_cfg", type=click.Path(exists=True), default=None)
 def train(
-    run_cfg,
     batch_size_per_device,
+    run_cfg,
 ):
-    training_options = TrainingOptions.from_yaml(
-        cfg_path=run_cfg,
-        cli_kwargs=dict(
-            batch_size_per_device=batch_size_per_device,
-        ),
-    )
+    if run_cfg is not None:
+        training_options = TrainingOptions.from_yaml(
+            cfg_path=run_cfg,
+            cli_kwargs=dict(
+                batch_size_per_device=batch_size_per_device,
+            ),
+        )
+    else:
+        training_options = TrainingOptions(
+            batch_size_per_device=batch_size_per_device,
+        )
     set_seed(training_options.seed)
     mnist_classifier = SYNTH_classifier(opts=training_options)
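Note on this hunk: @cli.command attaches train directly to the cli group, where the old standalone @click.command would have needed a separate cli.add_command(train) call to be reachable. With --run_cfg now optional and --batch_size_per_device defaulting to 256, a bare invocation of the train subcommand runs with defaults rather than printing help (hence no_args_is_help=False), and a YAML config, when given, is still honored via TrainingOptions.from_yaml.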