Skip to content
Snippets Groups Projects
Commit 126e399d authored by Piero Coronica's avatar Piero Coronica
Browse files

STY: similar train_fn

parent 95cdb30d
No related branches found
No related tags found
No related merge requests found
......@@ -7,18 +7,20 @@ from utils import train_epoch, eval_model, get_FashionMNIST
from argparse import ArgumentParser
def train(args):
print(f"Hyper-parameters: {vars(args)}")
def train_fn(args):
# HyperParameters
lr = args.lr
hidden = args.hidden
bs = args.bs
epochs = args.epochs
hidden = args.hidden
# Model and Training configuration
device = torch.device("cuda")
model = MLP(hidden=hidden).to(device)
optimizer = optim.SGD(params=model.parameters(), lr=lr)
loss = F.nll_loss
# Dataset
train_loader, valid_loader = get_FashionMNIST(
batch_size=bs,
device=device
......
......@@ -5,22 +5,27 @@ import torch.optim as optim
from models import MLP
from utils import train_epoch, eval_model, get_FashionMNIST
from argparse import ArgumentParser
import optuna
def train_fn(trial):
# HyperParameters
lr = trial.suggest_float('lr', 1e-3, 1, log=True)
bs = 64
hidden = trial.suggest_int('hidden', 64, 512, 64)
hidden = 128
bs = 64
epochs = 10
device = torch.device("cuda")
train_loader, valid_loader = get_FashionMNIST(batch_size=bs, device=device)
# Model and Training configuration
device = torch.device("cuda")
model = MLP(hidden=hidden).to(device)
optimizer = optim.SGD(params=model.parameters(), lr=lr)
loss = F.nll_loss
# Dataset
train_loader, valid_loader = get_FashionMNIST(
batch_size=bs,
device=device
)
# Training loop
for epoch in range(1, epochs + 1):
train_epoch(model, device, train_loader, optimizer,
......
......@@ -11,24 +11,33 @@ import ray
from ray import train, tune
def train_fn(config, data_dir):
lr = config['lr']
bs = 64
# HyperParameters
lr = config['lr']
hidden = config['hidden']
bs = 64
epochs = 10
# Model and Training configuration
device = torch.device("cuda")
train_loader, valid_loader = get_FashionMNIST(batch_size=bs,
device=device,
path=data_dir)
model = MLP(hidden=hidden).to(device)
optimizer = optim.SGD(params=model.parameters(), lr=lr)
loss = F.nll_loss
# Dataset
train_loader, valid_loader = get_FashionMNIST(
batch_size=bs,
device=device,
path=data_dir
)
# Training loop
for epoch in range(1, epochs + 1):
train_epoch(model, device, train_loader, optimizer, loss, epoch, verbose=False)
_, acc = eval_model(model, device, valid_loader, loss, verbose=False)
train.report({'accuracy': acc})
train_epoch(model, device, train_loader, optimizer,
loss, epoch, verbose=False)
_, valid_acc = eval_model(model, device, valid_loader,
loss, verbose=False)
train.report({'accuracy': valid_acc})
if __name__ == '__main__':
parser = ArgumentParser(add_help=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment