#!/usr/bin/env python3
import importlib
import logging
import os
import re
import subprocess
import sys
import warnings
try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main
import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs
# Enable FX graph caching unless the environment variable already overrides it
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True
def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# Install timm on demand; the imports in the finally block run either way.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model
TIMM_MODELS = dict()
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)
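# Each line of timm_models_list.txt is "<model name> <batch size>" separated
# by a single space, e.g. "beit_base_patch16_224 64" (the batch size in this
# example is illustrative, not taken from the checked-in list).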
# TODO - Figure out the reason for the cold-start memory spike
BATCH_SIZE_DIVISORS = {
"beit_base_patch16_224": 2,
"convit_base": 2,
"convmixer_768_32": 2,
"convnext_base": 2,
"cspdarknet53": 2,
"deit_base_distilled_patch16_224": 2,
"gluon_xception65": 2,
"mobilevit_s": 2,
"pnasnet5large": 2,
"poolformer_m36": 2,
"resnest101e": 2,
"swin_base_patch4_window7_224": 2,
"swsl_resnext101_32x16d": 2,
"vit_base_patch16_224": 2,
"volo_d1_224": 2,
"jx_nest_base": 4,
}
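# The divisors are applied in load_model below: the recorded batch size is
# divided by the model's divisor (with a floor of 1), working around the
# cold-start memory spike mentioned in the TODO above.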
REQUIRE_HIGHER_TOLERANCE = {
"fbnetv3_b",
"gmixer_24_224",
"hrnet_w18",
"inception_v3",
"mixer_b16_224",
"mobilenetv3_large_100",
"sebotnet33ts_256",
"selecsls42b",
"cspdarknet53",
}
REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
"adv_inception_v3",
"botnet26t_256",
"gluon_inception_v3",
"selecsls42b",
"swsl_resnext101_32x16d",
}
SCALED_COMPUTE_LOSS = {
"ese_vovnet19b_dw",
"fbnetc_100",
"mnasnet_100",
"mobilevit_s",
"sebotnet33ts_256",
}
FORCE_AMP_FOR_FP16_BF16_MODELS = {
"convit_base",
"xcit_large_24_p8_224",
}
SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
"xcit_large_24_p8_224",
}
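# The sets above opt individual models into special handling by the harness:
# relaxed accuracy tolerances, scaled loss, forced autocast for fp16/bf16
# runs, and skipped accuracy checks where eager itself is non-deterministic.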
def refresh_model_names():
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue
                    model = line.split("'")[1]
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = dict()
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    # Drop the families already covered by the docs, then pick one
    # representative model from every family.
    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())
    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")
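# refresh_model_names is a maintenance helper that regenerates
# timm_models_list.txt with one representative model per family; it is not
# called by the benchmark harness and is presumably run by hand from a
# checkout that has the pytorch-image-models docs available (see the TODO).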
class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }
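
    # The properties above are hooks consumed by the shared BenchmarkRunner
    # base class in common.py (an assumption based on the names and on how
    # self.args is used); they let this suite request per-model special
    # casing without modifying the generic harness.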
    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model
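
    # download_retry_decorator comes from common.py and, as the name
    # suggests, retries transient failures while create_model downloads
    # the pretrained weights (pretrained=True above).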
    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes
        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        # Generate a deterministic, normalized random input batch so that
        # accuracy comparisons across backends are reproducible.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size
    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            # args.filter and args.exclude are lists of regexes that are
            # OR-ed together; exclude_exact matches full model names.
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name
    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()
    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # The conv-batchnorm fusion used under freezing may cause relatively
            # large numerical differences, so we need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context.
            tolerance = 8 * 1e-2
        if is_training:
            if name in ["levit_128"]:
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine
    def _gen_target(self, batch_size, device):
        # Random integer class labels in [0, num_classes) for the loss.
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values for these models need to be scaled down further.
        return reduce_to_scalar_loss(pred) / 1000.0
    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        # grad_scaler is provided by the shared harness; when AMP is disabled
        # it is expected to be a no-op scaler, so scale() just passes the loss
        # through (assumption about common.py).
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()
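

# Example invocations, as a sketch: the flags below are assumed to be defined
# by the shared harness in common.py and may differ across checkouts.
#
#   python timm_models.py --performance --training --inductor -dcuda
#   python timm_models.py --accuracy --inductor --only beit_base_patch16_224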