blob: 8886553c973602d733ca52bb9eb11ae7082134e4 [file] [log] [blame]
import argparse
import inspect
import os
import sys
import time
from datetime import timedelta
import torch
import torch._dynamo
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Allow TF32 tensor-core math in CUDA matmuls (faster on Ampere+ GPUs at
# slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
# You will download around 84G dataset if you run this end to end training/evaluation example.
# Silence the HF tokenizers fork-after-parallelism warning in DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Prefer GPU when available; everything below moves tensors/models to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def data_processing(num_samples, batch_size):
    """Build train/eval DataLoaders over a slice of yelp_review_full.

    Downloads the dataset and a bert-base-cased tokenizer, tokenizes every
    example (padded/truncated to the model max length), renames ``label`` to
    ``labels`` as the model expects, keeps the first ``num_samples`` rows of
    the train and test splits, and wraps each in a DataLoader.
    """
    raw = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    # Tokenize once, drop the raw text, and expose tensors to the DataLoader.
    tokenized = (
        raw.map(tokenize_function, batched=True)
        .remove_columns(["text"])
        .rename_column("label", "labels")
    )
    tokenized.set_format("torch")
    train_subset = tokenized["train"].select(range(num_samples))
    eval_subset = tokenized["test"].select(range(num_samples))
    return (
        DataLoader(train_subset, batch_size=batch_size),
        DataLoader(eval_subset, batch_size=batch_size),
    )
def training_iter_fn(batch, model, optimizer):
    """Run one optimization step (forward, backward, update, zero grads).

    Returns the loss tensor produced by the model for this batch.
    """
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    """Train ``model`` (optionally through TorchDynamo) and maybe evaluate it.

    When ``backend`` is falsy the training step runs as native eager PyTorch;
    otherwise the training step — and, for evaluation, the model forward —
    are wrapped with ``torch._dynamo.optimize(backend)``.

    Returns ``(loss_history, accuracy)``: ``loss_history`` holds the mean
    running loss over each window of 100 training iterations; ``accuracy``
    is the computed metric dict, or None when ``evaluation`` is falsy.
    """
    model.to(device)
    model.train()
    loss_history = []
    if backend:
        # Supported backends include eager, aot_eager, aot_nvfuser and inductor.
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    else:
        # Run with native PyTorch.
        opt_training_iter_fn = training_iter_fn
    for _epoch in range(num_epochs):
        running_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            # Record the mean loss over every completed window of 100 steps.
            if step % 100 == 99:
                loss_history.append(running_loss / 100)
                running_loss = 0.0
    if not evaluation:
        return loss_history, None
    metric = load_metric("accuracy")
    model.eval()
    opt_model = torch._dynamo.optimize(backend)(model) if backend else model
    for batch in eval_dataloader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        # Only the forward pass runs under no_grad, matching the training-loop
        # structure; argmax on the detached logits happens outside.
        with torch.no_grad():
            outputs = opt_model(**batch)
        predictions = torch.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    return loss_history, metric.compute()
def check_loss(ref_loss, res_loss):
    """Return True when the dynamo run's tail loss is no worse than native's.

    Compares the mean of the last ``min(len, 10)`` recorded losses from each
    run, allowing the dynamo result to exceed the native result by up to 0.1.

    Bug fix: the original divided both tail sums by a hard-coded 10 even when
    fewer than 10 samples were available (``x = min(length, 10)`` was used for
    slicing but not for the mean), so short runs compared scaled-down sums
    rather than true means. Divide by the actual window size instead.
    """
    assert len(ref_loss) == len(res_loss)
    window = min(len(ref_loss), 10)
    if window == 0:
        # Nothing recorded: vacuously passing (matches the original's
        # 0.0 <= 0.0 + 0.1 outcome for empty inputs).
        return True
    return sum(res_loss[-window:]) / window <= sum(ref_loss[-window:]) / window + 1e-1
def parse_args():
    """Define and parse the benchmark's command-line options."""
    parser = argparse.ArgumentParser(
        description="TorchDynamo end to end training/evaluation benchmark"
    )
    # (flags, keyword-arguments) pairs, added in order.
    options = [
        (
            ("--epochs",),
            dict(type=int, default=10, help="number of epochs to train (default: 10)"),
        ),
        (
            ("--num-samples",),
            dict(
                type=int,
                default=1000,
                help="number of samples to train/eval (default: 1000)",
            ),
        ),
        (
            ("--batch-size",),
            dict(
                type=int,
                default=8,
                help="input batch size for training (default: 8)",
            ),
        ),
        (
            ("--lr",),
            dict(type=float, default=5e-5, help="learning rate (default: 5e-5)"),
        ),
        (
            ("--backend",),
            dict(
                choices=torch._dynamo.list_backends(exclude_tags=None),
                default="inductor",
                help="train/evaluate model with a given backend (default: inductor)",
            ),
        ),
        (
            ("--optimizer",),
            dict(
                default="Adam",
                help="train model using a given optimizer (default: Adam)",
            ),
        ),
        (
            ("--evaluation",),
            dict(
                action="store_true",
                help="running evaluation after model training",
            ),
        ),
    ]
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
def main():
    """Benchmark native PyTorch vs. TorchDynamo on a small BERT fine-tune.

    Runs the same training (and optional evaluation) twice — first in eager
    mode, then through the chosen dynamo backend — checks the dynamo run's
    loss against the native run, and reports per-epoch wall-clock time.
    """
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    # 5-way classifier head; yelp_review_full labels span five star ratings.
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    # Resolve the optimizer class by name from torch.optim (e.g. "Adam").
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
    # Pass capturable=True when the optimizer supports it.
    # NOTE(review): presumably to allow cuda-graph-style capture under dynamo
    # backends — confirm against torch.optim docs for the chosen optimizer.
    if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
        optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
    else:
        optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    native_start = time.time()
    # Reference run: backend=None means plain eager PyTorch.
    ref_loss, accuracy = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    # native_end doubles as the start timestamp for the dynamo run below.
    native_end = time.time()
    # NOTE(review): this second run reuses the already-trained model and
    # optimizer state, so the two runs are not fully independent — verify
    # that this warm start is intended for the loss comparison.
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.time()
    # Compare tail-averaged training losses between the two runs.
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
            "[FAILED] TorchDynamo end to end training loss is greater than native Pytorch"
        )
    if args.evaluation:
        # `accuracy` here is from the dynamo run (the second call's result).
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed/args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed/args.epochs)} per epoch"
    )
# Script entry point.
if __name__ == "__main__":
    main()