dataanalytics-public / AI Containers / Commits / ca0e54fd

Commit ca0e54fd, authored 1 month ago by Nastassya Horlava

Commit message: fixed

Parent: 52c9da4a
No related branches or tags found.
Merge request: !4 "Docs tensorflow"

Showing 3 changed files, with 68 additions and 38 deletions:

  .pre-commit-config.yaml            +16 −0
  pyproject.toml                     +10 −0
  tensorflow/src/train_synthetic.py  +42 −38
.pre-commit-config.yaml (new file, 0 → 100644, +16 −0)

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.11.2
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      # Run the formatter.
      - id: ruff-format
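Usage note: with this file in place, `pip install pre-commit` followed by `pre-commit install` registers the hooks to run against staged files on every commit, and `pre-commit run --all-files` applies them to the whole repository once.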
pyproject.toml (new file, 0 → 100644, +10 −0)

[tool.ruff]
line-length = 88

[tool.ruff.lint.pycodestyle]
max-doc-length = 88
max-line-length = 88

[tool.ruff.lint]
extend-select = ["I", "W505"]
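Usage note: line-length = 88 matches Black's default. In extend-select, "I" enables Ruff's isort-style import-sorting rules and "W505" (doc-line-too-long) flags docstring and comment lines exceeding max-doc-length; W505 is only enforced when max-doc-length is set, as it is here. These settings drive the ruff and ruff-format hooks above, and most of the train_synthetic.py changes below are exactly such auto-fixes: import regrouping and line wraps to fit the 88-character limit.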
tensorflow/src/train_synthetic.py (+42 −38)
@@ -5,19 +5,14 @@ import os
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from pathlib import Path
 from time import perf_counter
 from typing import Any, Dict, Optional, Union

 import click
 import mlflow
 from mlflow_utils import (
     MLflowMetricsCallback,
     MlflowTimingCallback,
     TimingCallback,
     mlflow_log_sbatch_logs,
     mlflow_log_sbatch_scripts,
 )
 import pandas as pd
 import tensorflow as tf

 logger = logging.getLogger(__name__)
@@ -106,11 +101,11 @@ class TimingCallback(tf.keras.callbacks.Callback):
     )


 def set_seed(seed: int = 5):
     import random

     import numpy as np

     if not isinstance(seed, int):
         raise ValueError(
             "Expected `seed` argument to be an integer."
@@ -122,7 +117,6 @@ def set_seed(seed: int = 5):
     tf.random.set_seed(seed)


 class NullStrategy:
     @staticmethod
     def scope():
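The class body is cut off by the diff; given the `from contextlib import nullcontext` import above, scope() plausibly returns a no-op context manager. A minimal sketch of the pattern (the model line is illustrative, not from the repository):

    from contextlib import nullcontext

    import tensorflow as tf


    class NullStrategy:
        """Stand-in so single-process runs share the distributed code path."""

        @staticmethod
        def scope():
            # No-op context manager; real tf.distribute strategies return a scope here.
            return nullcontext()


    strategy = NullStrategy()  # or tf.distribute.MirroredStrategy() when distributed
    with strategy.scope():
        # Illustrative model; the repository compiles its own model inside the scope.
        model = tf.keras.Sequential([tf.keras.layers.Dense(10)])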
@@ -153,11 +147,11 @@ class TrainingStrategy:
     communication_type: Union[
         str, tf.distribute.experimental.CommunicationImplementation
     ] = field(init=False)
-    cross_device_communication_type: Union[str, tf.distribute.CrossDeviceOps] = field(init=False)
+    cross_device_communication_type: Union[str, tf.distribute.CrossDeviceOps] = field(
+        init=False
+    )
     communication_options: Optional[tf.distribute.experimental.CommunicationOptions] = (
         field(default=None, init=False)
     )
@@ -224,9 +218,12 @@ class TrainingStrategy:
     def _use_single_node_multi_gpu_strategy(self) -> None:
         self.strategy_type = "MirroredStrategy"
         self.communication_type = NullCommunication()
-        self.cross_device_communication_type = self._get_cross_device_ops_implementation(self.device_type)
-        self.strategy = tf.distribute.MirroredStrategy(cross_device_ops=self.cross_device_communication_type)
+        self.cross_device_communication_type = (
+            self._get_cross_device_ops_implementation(self.device_type)
+        )
+        self.strategy = tf.distribute.MirroredStrategy(
+            cross_device_ops=self.cross_device_communication_type
+        )

     def _use_multi_node_strategy(self) -> None:
         self.cross_device_communication_type = NullCommunication()
@@ -242,7 +239,7 @@ class TrainingStrategy:
         self.strategy = tf.distribute.MultiWorkerMirroredStrategy(
             communication_options=self.communication_options
         )

     def _get_cross_device_ops_implementation(self, device_type: str):
         """Map device type to appropriate communication implementation."""
         if device_type == "NVIDIA":
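The branch bodies are truncated here. As a hedged sketch of what such a mapping typically returns (the specific ops below are assumptions, not the repository's confirmed choices):

    import tensorflow as tf


    def get_cross_device_ops(device_type: str) -> tf.distribute.CrossDeviceOps:
        # Hypothetical mapping; the real _get_cross_device_ops_implementation
        # body is cut off in this diff.
        if device_type == "NVIDIA":
            # NCCL all-reduce is the usual choice for NVIDIA GPUs.
            return tf.distribute.NcclAllReduce()
        # Portable fallback: reduce on one device, then broadcast.
        return tf.distribute.ReductionToOneDevice()


    strategy = tf.distribute.MirroredStrategy(
        cross_device_ops=get_cross_device_ops("NVIDIA")
    )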
@@ -266,12 +263,13 @@ class TrainingStrategy:
         )

     def _log_strategy_params(self) -> None:
-        """Log key strategy configuration
-        to MLflow."""
+        """Log key strategy configuration"""
         logger.info(f"num_replicas_in_sync = {self.strategy.num_replicas_in_sync}")
         logger.info(f"strategy_type = {self.strategy_type}")
         logger.info(f"communication_type = {self.communication_type.name}")
-        logger.info(f"cross_device_communication_type = {type(self.cross_device_communication_type)}")
+        logger.info(
+            f"cross_device_communication_type = {type(self.cross_device_communication_type)}"
+        )


 @dataclass
@@ -345,7 +343,6 @@ class SYNTH_classifier:
         return model

     def prepare_dataset(self):
         @tf.function
         def gen_fn(_):
             image = tf.random.uniform([224, 224, 3])
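The rest of prepare_dataset is truncated. A self-contained sketch of the synthetic-data pattern it follows; the dataset length, label shape, and batch size below are assumptions:

    import tensorflow as tf


    @tf.function
    def gen_fn(_):
        # Fabricate a random 224x224 RGB image and an assumed integer label.
        image = tf.random.uniform([224, 224, 3])
        label = tf.random.uniform([], maxval=1000, dtype=tf.int32)
        return image, label


    dataset = (
        tf.data.Dataset.range(1024)  # dummy index stream to map over
        .map(gen_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(32)
        .prefetch(tf.data.AUTOTUNE)
    )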
@@ -360,10 +357,9 @@ class SYNTH_classifier:
         return dataset

     def train(self):
         self.train_dataset = self.prepare_dataset()
         logger.info(f"train_dataset: {type(self.train_dataset)}")

         # Define distributed strategy
         if self.opts.distributed:
             self.train_dataset = (
@@ -372,7 +368,8 @@ class SYNTH_classifier:
             )
         )
-        # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of distributed training, or just NullStrategy instead.
+        # Create a MirroredStrategy or MultiWorkerMirroredStrategy in case of
+        # distributed training, or just NullStrategy instead.
         with self.opts.training_strategy.strategy.scope():
             self.model = self.get_compiled_model()
@@ -390,7 +387,7 @@ class SYNTH_classifier:
             TimingCallback(
                 batch_size=self.opts.global_batch_size,
                 log_freq=self.opts.timing_log_freq,
-                rank=int(os.environ["RANK"]),
+                rank=int(os.environ.get("RANK", 0)),
                 num_warmup_batches=self.opts.timing_warmup_batches,
             ),
         ]
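This is the one behavioral fix in the file: os.environ["RANK"] raises KeyError whenever the launcher does not export RANK (for example, a plain single-process run), while .get("RANK", 0) falls back to rank 0. A minimal reproduction:

    import os

    os.environ.pop("RANK", None)  # simulate a run with no distributed launcher

    try:
        rank = int(os.environ["RANK"])  # old code
    except KeyError:
        print("old lookup fails without a distributed launcher")

    rank = int(os.environ.get("RANK", 0))  # new code: defaults to rank 0
    print(rank)  # -> 0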
@@ -401,23 +398,30 @@ class SYNTH_classifier:
         test_loss, test_acc = self.model.evaluate(test_dataset, verbose=0)
         return test_loss, test_acc


 @click.group()
 def cli():
     pass


-@click.command(no_args_is_help=True)
-@click.option("--run_cfg", type=click.Path(exists=True))
-@click.option("--batch_size_per_device", type=int, default=None)
+@cli.command(no_args_is_help=False)
+@click.option("--batch_size_per_device", type=int, default=256)
+@click.option("--run_cfg", type=click.Path(exists=True), default=None)
 def train(
-    run_cfg,
     batch_size_per_device,
+    run_cfg,
 ):
-    training_options = TrainingOptions.from_yaml(
-        cfg_path=run_cfg,
-        cli_kwargs=dict(
-            batch_size_per_device=batch_size_per_device,
-        ),
-    )
+    if run_cfg is not None:
+        training_options = TrainingOptions.from_yaml(
+            cfg_path=run_cfg,
+            cli_kwargs=dict(
+                batch_size_per_device=batch_size_per_device,
+            ),
+        )
+    else:
+        training_options = TrainingOptions(
+            batch_size_per_device=batch_size_per_device,
+        )

     set_seed(training_options.seed)
     mnist_classifier = SYNTH_classifier(opts=training_options)
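A hedged sketch of exercising the reworked command via click's test runner; it assumes the module imports as train_synthetic and that TrainingOptions has usable defaults when --run_cfg is omitted (the new else-branch above). Note that invoking it really starts a training run:

    from click.testing import CliRunner

    from train_synthetic import cli  # assumed module name

    runner = CliRunner()

    # No --run_cfg: the new else-branch constructs TrainingOptions directly,
    # with --batch_size_per_device defaulting to 256 unless overridden.
    result = runner.invoke(cli, ["train", "--batch_size_per_device", "32"])
    print(result.exit_code, result.output)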
@@ -432,5 +436,5 @@ if __name__ == "__main__":
     urllib3_logger.setLevel(logging.WARNING)
     simple_parsing_logger = logging.getLogger("simple_parsing")
     simple_parsing_logger.setLevel(logging.INFO)

     cli()