blob: 6e2f1f320688856bdc03e1dc41f4eb28a0eff848 [file] [log] [blame]
#!/usr/bin/env python3
import importlib
import logging
import os
import re
import subprocess
import sys
import time
import warnings
import torch
from common import BenchmarkRunner, main
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import clone_inputs
def pip_install(package):
    """Install *package* with pip, using the interpreter running this script.

    Raises subprocess.CalledProcessError if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(command)
# Install timm on the fly if it is missing, then import the names used below.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing Pytorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # Path-finder caches built before the install can hide a package that was
    # installed after interpreter start; see importlib.invalidate_caches().
    importlib.invalidate_caches()

# Deliberately outside the try statement (previously a `finally:`). A failed
# pip install now propagates as its own CalledProcessError instead of
# surfacing as a ModuleNotFoundError raised from the finally clause.
from timm import __version__ as timmversion
from timm.data import resolve_data_config
from timm.models import create_model
# Map of model name -> recorded benchmark batch size, read from
# timm_models_list.txt next to this file (one "<model_name> <batch_size>"
# pair per line).
TIMM_MODELS = dict()
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
with open(filename, "r") as fh:
    lines = [line.strip() for line in fh]
for line in lines:
    # Skip blank lines. Split on any whitespace so stray double spaces or
    # tabs do not break unpacking (split(" ") raised ValueError on both).
    if not line:
        continue
    model_name, batch_size = line.split()
    TIMM_MODELS[model_name] = int(batch_size)
# TODO - Figure out the reason of cold start memory spike
# Per-model divisor applied to the recorded batch size in load_model (the
# result is floored and clamped to at least 1). Most listed models halve
# their batch; the two at the end quarter it.
BATCH_SIZE_DIVISORS = {
    **dict.fromkeys(
        (
            "beit_base_patch16_224",
            "cait_m36_384",
            "convit_base",
            "convmixer_768_32",
            "convnext_base",
            "cspdarknet53",
            "deit_base_distilled_patch16_224",
            "dpn107",
            "gluon_xception65",
            "mobilevit_s",
            "pit_b_224",
            "pnasnet5large",
            "poolformer_m36",
            "res2net101_26w_4s",
            "resnest101e",
            "sebotnet33ts_256",
            "swin_base_patch4_window7_224",
            "swsl_resnext101_32x16d",
            "twins_pcpvt_base",
            "vit_base_patch16_224",
            "volo_d1_224",
        ),
        2,
    ),
    "jx_nest_base": 4,
    "xcit_large_24_p8_224": 4,
}
# Model names that should get a looser accuracy tolerance when training.
# This must be a set literal: set("botnet26t_256") builds a set of single
# characters, so no model name could ever be found in it.
REQUIRE_HIGHER_TOLERANCE = {"botnet26t_256"}
# Model names to skip. Not read in this file; presumably consumed by the
# BenchmarkRunner skip logic (self.skip_models) — confirm in common.py.
SKIP = {"levit_128"}  # Unusual training setup
# Upper bound on batch size for specific models, applied only when running
# accuracy checks (see load_model).
MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = dict(cait_m36_384=4)
def refresh_model_names():
    """Regenerate ``timm_models_list.txt`` with one representative model per family.

    Model names are grouped into families with ``get_family_name``. Each family
    that appears in the timm docs (``../pytorch-image-models/docs/models/*.md``)
    is represented by its first documented model. Every remaining family among
    the pretrained models from ``list_models`` (``*in21k`` excluded) is
    represented by its first listed model.

    The file goes to ``benchmarks/timm_models_list.txt`` when a ``benchmarks``
    directory exists in the working directory, otherwise to the working
    directory itself.

    NOTE(review): only model names are written, while the module-level loader
    expects "<name> <batch_size>" lines. Batch sizes presumably have to be
    added by hand afterwards.
    """
    import glob

    from timm.models import list_models

    # Captures the first quoted argument of a docs usage line, e.g.
    #   model = timm.create_model('resnet50', pretrained=True)
    # Either quote style is accepted.
    docs_usage = re.compile(r"""^model = timm\.create_model\(.*?['"]([^'"]*)['"]""")

    def read_models_from_docs():
        """Return the set of model names used in the docs' create_model snippets."""
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn, "r") as f:
                for line in f:
                    match = docs_usage.match(line)
                    # split("'")[1] used to raise IndexError on double-quoted names.
                    if match and match.group(1):
                        models.add(match.group(1))
        return models

    def get_family_name(name):
        """Map a model name to a coarse family name.

        The first entry of ``known_families`` that is a substring of *name*
        wins, so list order matters (e.g. "ecaresnet" is checked before
        "resnet"). Otherwise the family is "gluon_<second token>" for
        ``gluon_`` models, else the text before the first underscore.
        """
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]
        for known_family in known_families:
            if known_family in name:
                return known_family
        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group *models* by family, keeping iteration order within each family."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])
    all_models_family = populate_family(all_models)
    # docs_models is a set of strings, whose iteration order varies between
    # processes (hash randomization). Sort it so value[0] below is reproducible.
    docs_models_family = populate_family(sorted(docs_models))
    # A family found in the docs is represented only by its docs model.
    # Use pop() rather than del: a docs family need not exist among the
    # pretrained models, and del would raise KeyError.
    for key in docs_models_family:
        all_models_family.pop(key, None)
    chosen_models = {value[0] for value in docs_models_family.values()}
    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = os.path.join("benchmarks", filename)
    with open(filename, "w") as fw:
        fw.writelines(model_name + "\n" for model_name in sorted(chosen_models))
class TimmRunnner(BenchmarkRunner):
    """BenchmarkRunner for the timm (PyTorch Image Models) model suite.

    Model names come from ``TIMM_MODELS``. Models are built with
    ``timm.models.create_model`` using pretrained weights and are fed
    synthetic, normalized random inputs.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @staticmethod
    def _version_tuple(version):
        """Return the leading numeric components of *version* as a tuple of ints.

        Comparing tuples avoids the lexicographic pitfall of comparing version
        strings (``"0.10.0" >= "0.8.0"`` is False). Non-numeric suffixes are
        dropped: ``"0.9.0rc1"`` -> ``(0, 9, 0)``, ``"1.0.0.dev0"`` -> ``(1, 0, 0)``.
        """
        components = []
        for piece in version.split("."):
            match = re.match(r"\d+", piece)
            if match is None:
                break
            components.append(int(match.group()))
        return tuple(components)

    def _create_model_with_retries(self, model_name, max_attempts=5):
        """Build *model_name* with ``create_model``, retrying on any exception.

        Sleeps 30 s, 60 s, ... between attempts, but not after the final one.
        Raises RuntimeError, chained to the last failure, if all attempts fail.
        """
        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                return create_model(
                    model_name,
                    in_chans=3,
                    scriptable=False,
                    num_classes=None,
                    drop_rate=0.0,
                    drop_path_rate=None,
                    drop_block_rate=None,
                    pretrained=True,
                )
            except Exception as e:
                # Broad on purpose: presumably transient failures (e.g. fetching
                # pretrained weights) should be retried — the cause is kept.
                last_error = e
                if attempt < max_attempts:
                    time.sleep(attempt * 30)
        raise RuntimeError(f"Failed to load model '{model_name}'") from last_error

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
    ):
        """Build *model_name* on *device* together with example inputs.

        Side effects: sets ``self.num_classes``, ``self.target`` (random class
        indices) and ``self.loss`` (CrossEntropyLoss), which ``compute_loss``
        uses. The model is left in train mode only when training without
        ``use_eval_mode``; otherwise it is put in eval mode.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)``, where
            ``example_inputs`` is a one-element list holding the input tensor.

        Raises:
            RuntimeError: if ``create_model`` fails on every retry attempt.
        """
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        channels_last = self._args.channels_last

        model = self._create_model_with_retries(model_name)
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )
        self.num_classes = model.num_classes

        # timm >= 0.8.0 is given a dict of the CLI args; older versions receive
        # the args namespace itself.
        if self._version_tuple(timmversion) >= (0, 8):
            timm_args = vars(self._args)
        else:
            timm_args = self._args
        data_config = resolve_data_config(
            timm_args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]

        recorded_batch_size = TIMM_MODELS[model_name]
        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                recorded_batch_size // BATCH_SIZE_DIVISORS[model_name], 1
            )
        batch_size = batch_size or recorded_batch_size
        # Control the memory footprint for few models
        if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK:
            batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name])

        # Seed first: both the inputs below and the targets drawn by
        # _gen_target come from the global RNG, in this order.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        # Standardize the synthetic pixel values to zero mean / unit std.
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev
        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [example_inputs]

        self.target = self._gen_target(batch_size, device)
        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()
        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield, in sorted order, the models in this shard that pass the CLI filters.

        A model is yielded when its index lies in
        ``[start, end)`` from ``get_benchmark_indices``, it matches
        ``args.filter``, it does not match ``args.exclude`` (both
        case-insensitive regex alternations), and it is not in
        ``self.skip_models``.
        """
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        include_pattern = "|".join(args.filter)
        exclude_pattern = "|".join(args.exclude)
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search(include_pattern, model_name, re.I)
                or re.search(exclude_pattern, model_name, re.I)
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    def pick_grad(self, name, is_training):
        """Return ``torch.enable_grad()`` when training, else ``torch.no_grad()``."""
        return torch.enable_grad() if is_training else torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of model *name*.

        The tolerance is 1e-3 for inference and 1e-2 for training, doubled to
        2e-2 for models listed in ``REQUIRE_HIGHER_TOLERANCE``. ``cosine`` is
        taken from ``self.args.cosine``.
        """
        cosine = self.args.cosine
        if not is_training:
            return 1e-3, cosine
        # Per-model membership test. A bare truthiness check of the (non-empty)
        # collection gave every model the looser bound.
        if name in REQUIRE_HIGHER_TOLERANCE:
            return 2 * 1e-2, cosine
        return 1e-2, cosine

    def _gen_target(self, batch_size, device):
        """Return *batch_size* random int64 class indices in ``[0, self.num_classes)``."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        """Return the cross-entropy of *pred* against ``self.target``, divided by 10."""
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upsets accuracy checks.
        return self.loss(pred, self.target) / 10.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        """Run *mod* on *inputs*. ``collect_outputs`` is accepted for interface parity."""
        return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step on a clone of *inputs*.

        Zeroes gradients, runs the forward pass under ``self.autocast()`` (taking
        the first element if the model returns a tuple), backpropagates the
        scaled loss and steps the optimizer. Returns ``collect_results(...)``
        when *collect_outputs* is true, else None.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast():
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
if __name__ == "__main__":
logging.basicConfig(level=logging.WARNING)
warnings.filterwarnings("ignore")
main(TimmRunnner())