| |
| |
| |
| |
| |
| import argparse |
| import json |
| import os |
| |
| import caffe2.contrib.playground.AnyExp as AnyExp |
| import caffe2.contrib.playground.checkpoint as checkpoint |
| |
import logging
# Root logging config plus a dedicated logger for this driver script;
# DEBUG level so all experiment progress is visible on the terminal.
logging.basicConfig()
log = logging.getLogger("AnyExpOnTerm")
log.setLevel(logging.DEBUG)
| |
| |
def runShardedTrainLoop(opts, myTrainFun):
    """Drive sharded training across epoch schedules.

    Runs ``myTrainFun(opts)`` once per shard for each scheduled epoch range,
    threading checkpoint state between schedules via ``opts['temp_var']``.

    Args:
        opts: experiment options dict; reads ``model_param``, ``epoch_iter``,
            ``distributed`` and mutates ``temp_var`` in place.
        myTrainFun: callable taking ``opts`` and returning a result dict
            (expected keys include 'model' and 'metrics') or None.

    Returns:
        The last non-None shard result (shard order), or None if no shard
        returned a result.
    """
    start_epoch = 0
    pretrained_model = opts['model_param']['pretrained_model']
    if pretrained_model != '' and os.path.exists(pretrained_model):
        # Only want to get start_epoch.
        start_epoch, prev_checkpointed_lr, best_metric = \
            checkpoint.initialize_params_from_file(
                model=None,
                weights_file=pretrained_model,
                num_xpus=1,
                opts=opts,
                broadcast_computed_param=True,
                reset_epoch=opts['model_param']['reset_epoch'],
            )
    log.info('start epoch: {}'.format(start_epoch))
    # Normalize '' to None so downstream code sees a single "no model" value.
    pretrained_model = None if pretrained_model == '' else pretrained_model
    # BUG FIX: a leftover `pretrained_model = ""` here used to clobber the
    # path computed above, so shards never received the pretrained weights
    # on the first scheduled epoch.
    ret = None

    for epoch in range(start_epoch,
                       opts['epoch_iter']['num_epochs'],
                       opts['epoch_iter']['num_epochs_per_flow_schedule']):
        # must support checkpoint or the multiple schedule will always
        # start from initial state
        checkpoint_model = None if epoch == start_epoch else ret['model']
        # Pretrained weights only apply to the very first schedule; later
        # schedules resume from the previous schedule's returned model.
        pretrained_model = None if epoch > start_epoch else pretrained_model
        shard_results = []
        for shard_id in range(opts['distributed']['num_shards']):
            opts['temp_var']['shard_id'] = shard_id
            opts['temp_var']['pretrained_model'] = pretrained_model
            opts['temp_var']['checkpoint_model'] = checkpoint_model
            opts['temp_var']['epoch'] = epoch
            opts['temp_var']['start_epoch'] = start_epoch
            shard_results.append(myTrainFun(opts))

        ret = None
        # always only take shard_0 return
        for shard_ret in shard_results:
            if shard_ret is not None:
                ret = shard_ret
                opts['temp_var']['metrics_output'] = ret['metrics']
                break
        log.info('ret is: {}'.format(str(ret)))

    return ret
| |
| |
def trainFun():
    """Build and return the default per-shard training function.

    Wrapped in a factory so callers can substitute a customized training
    function with the same ``f(opts) -> result`` signature.
    """
    def simpleTrainFun(opts):
        # Construct the trainer type for these opts, layer on the
        # additional-method overrides, then instantiate and train.
        trainer_cls = AnyExp.overrideAdditionalMethods(
            AnyExp.createTrainerClass(opts), opts)
        return trainer_cls(opts).buildModelAndTrain(opts)

    return simpleTrainFun
| |
| |
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Any Experiment training.')
    # The whole experiment configuration arrives as a single JSON blob;
    # argparse parses it straight into a dict via type=json.loads.
    parser.add_argument("--parameters-json", type=json.loads,
                        help='model options in json format', dest="params")

    args = parser.parse_args()
    opts = args.params['opts']
    opts = AnyExp.initOpts(opts)
    log.info('opts is: {}'.format(str(opts)))

    AnyExp.initDefaultModuleMap()

    opts['input']['datasets'] = AnyExp.aquireDatasets(opts)

    # defined this way so that AnyExp.trainFun(opts) can be replaced with
    # some other customized training function.
    ret = runShardedTrainLoop(opts, trainFun())

    log.info('ret is: {}'.format(str(ret)))