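"""Utilities for saving and loading Caffe2 model parameters (checkpoints)
on one or more devices ("XPUs", i.e. GPUs or CPUs)."""
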
import logging
import pickle
from collections import OrderedDict

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, core, scope

logging.basicConfig()
log = logging.getLogger("AnyExpOnTerm")
log.setLevel(logging.DEBUG)


def initialize_params_from_file(
        model, weights_file, num_xpus, opts,
        broadcast_computed_param=False, reset_epoch=False):
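    """Load model parameters from a pickled weights file onto the master XPU
    and broadcast them to the remaining XPUs.

    Returns the (start_epoch, lr, best_metric) recovered from the checkpoint.
    """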
    start_epoch, lr, best_metric = initialize_master_xpu_model_params(
        model, weights_file, opts, reset_epoch)
    broadcast_parameters(opts, model, num_xpus, broadcast_computed_param)
    return start_epoch, lr, best_metric


def initialize_master_xpu_model_params(model, weights_file, opts, reset_epoch):
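    """Load checkpointed blobs into the master XPU's workspace and recover the
    saved training state.

    Returns (start_epoch, lr, best_metric), falling back to defaults when the
    checkpoint does not contain them.
    """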
log.info("Initializing model params from file: {}".format(weights_file))
with open(weights_file, 'r') as fopen:
blobs = pickle.load(fopen)
if 'blobs' in blobs:
blobs = blobs['blobs']

    start_epoch = 0
    best_metric = float('-inf')
    if 'epoch' in blobs:
        log.info('epoch {} is found in model file'.format(blobs['epoch']))
        if not reset_epoch:
            start_epoch = blobs['epoch']
        else:
            log.info('Reset epoch')
    else:
        log.info('no epoch is found in model file')

    lr = opts['model_param']['base_learning_rate']
    if 'lr' in blobs:
        lr = blobs['lr']
    if 'best_metric' in blobs and not reset_epoch:
        best_metric = blobs['best_metric']
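
    # When a model is provided, copy every matching parameter from the
    # checkpoint into the workspace under the master XPU's name scope.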
    if model is not None:
        log.info('initialize model parameters using weights file: {}'.format(
            weights_file
        ))
        ws_blobs = workspace.Blobs()
        unscoped_blob_names = OrderedDict()
        for blob in model.GetAllParams():
            unscoped_blob_names[unscope_name(str(blob))] = True
        root_xpu_id = opts['distributed']['first_xpu_id']
        device = opts['distributed']['device']
        caffe2_pb2_DEVICE = \
            caffe2_pb2.CUDA if opts['distributed']['device'] == 'gpu' \
            else caffe2_pb2.CPU
        with core.NameScope('{}_{}'.format(device, root_xpu_id)):
            with core.DeviceScope(core.DeviceOption(caffe2_pb2_DEVICE, 0)):
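                # Feed each checkpointed blob into its scoped counterpart,
                # skipping blobs that are missing or have a shape mismatch.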
                for unscoped_blob_name in unscoped_blob_names.keys():
                    scoped_blob_name = scoped_name(unscoped_blob_name)
                    if unscoped_blob_name not in blobs:
                        log.info('{:s} not found'.format(unscoped_blob_name))
                        continue
                    log.info(
                        '{:s} loaded from weights file into: {:s}'.format(
                            unscoped_blob_name, scoped_blob_name
                        )
                    )
                    if scoped_blob_name in ws_blobs:
                        ws_blob = workspace.FetchBlob(scoped_blob_name)
                        if ws_blob.shape != blobs[unscoped_blob_name].shape:
                            log.info(
                                ('Workspace blob {} with shape {} does '
                                 'not match weights file shape {}').format(
                                    unscoped_blob_name, ws_blob.shape,
                                    blobs[unscoped_blob_name].shape)
                            )
                        else:
                            workspace.FeedBlob(
                                scoped_blob_name,
                                blobs[unscoped_blob_name].astype(
                                    np.float32, copy=False))
    else:
        log.info('Skip initializing model parameters from file: {}'.format(
            weights_file
        ))
    log.info('Complete initialize_master_xpu_model_params')
    return start_epoch, lr, best_metric


def broadcast_parameters(opts, model, num_xpus, broadcast_computed_param=False):
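    """Copy the master XPU's parameter values to every other XPU so that all
    devices start from identical weights.
    """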
    if num_xpus == 1:
        log.info("only 1 device. Skip parameter broadcast")
        return

    all_params = [model.GetParams()]
    if broadcast_computed_param:
        all_params.append(model.GetComputedParams())

    caffe2_pb2_DEVICE = \
        caffe2_pb2.CUDA if opts['distributed']['device'] == 'gpu' \
        else caffe2_pb2.CPU

    for params in all_params:
        assert len(params) % num_xpus == 0, \
            "Current model doesn't match device number when loading checkpoint"
        params_per_xpu = len(params) // num_xpus
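        # Parameters are grouped per device, so striding by params_per_xpu
        # picks out the same logical parameter on every XPU.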
        for idx in range(params_per_xpu):
            blobs = list(params[idx::params_per_xpu])
            data = workspace.FetchBlob(blobs[0])
            log.info('Broadcasting {} to'.format(str(blobs[0])))
            for i, p in enumerate(blobs[1:]):
                log.info(' |-> {}'.format(str(p)))
                with core.DeviceScope(core.DeviceOption(caffe2_pb2_DEVICE, i+1)):
                    workspace.FeedBlob(p, data)
    log.info("Complete parameter broadcast")


def save_model_params(is_checkpoint, model, checkpoint_path, epoch, opts,
                      best_metric):
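    """Write a checkpoint for the current epoch to checkpoint_path.

    best_metric is the best validation metric seen so far (callers typically
    pass float('-inf') when no metric has been recorded yet). Returns
    checkpoint_path, or None when no path is given.
    """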
    if checkpoint_path is None:
        return None

    try:
        save_model_params_blob(
            model, checkpoint_path, epoch, opts, best_metric
        )
    except Exception as e:
        log.warning('Exception from save_model_params {}'.format(str(e)))
    return checkpoint_path


def save_model_params_blob(model, params_file, epoch, opts, best_metric):
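    """Pickle the master XPU's parameters, together with the epoch, learning
    rate and best_metric (float('-inf') if no metric has been recorded yet),
    into params_file under unscoped blob names.
    """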
log.info("Saving model params...")
root_xpu_id = opts['distributed']['first_xpu_id']
device = opts['distributed']['device']
save_params = [str(param) for param in
model.GetParams('{}_{}'.format(device, root_xpu_id))]
save_computed_params = [str(param) for param in
model.GetComputedParams('{}_{}'
.format(device, root_xpu_id))]
save_blobs = {}
save_blobs['epoch'] = epoch
save_blobs['best_metric'] = best_metric
save_blobs['lr'] = \
workspace.FetchBlob('{}_{}/lr'.format(device, root_xpu_id))
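    # Store each parameter under its unscoped name so the checkpoint can be
    # loaded back under any device/name scope.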
    for param in save_params + save_computed_params:
        scoped_blob_name = str(param)
        unscoped_blob_name = unscope_name(scoped_blob_name)
        if unscoped_blob_name not in save_blobs:
            save_blobs[unscoped_blob_name] = workspace.FetchBlob(
                scoped_blob_name)
            log.debug(
                '{:s} -> {:s}'.format(scoped_blob_name, unscoped_blob_name))
    log.info('to weights file {}'.format(params_file))
    try:
        # Pickle with HIGHEST_PROTOCOL is binary, so write in binary mode.
        with open(params_file, 'wb') as fwrite:
            pickle.dump(
                dict(blobs=save_blobs), fwrite, pickle.HIGHEST_PROTOCOL)
    except IOError as e:
        log.error('I/O error({0}): {1}'.format(e.errno, e.strerror))


def unscope_name(blob_name):
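    """Strip any name-scope prefix (e.g. 'gpu_0/') from a blob name."""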
    return blob_name[blob_name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]


def scoped_name(blob_name):
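    """Prefix a blob name with the currently active name scope."""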
    return scope.CurrentNameScope() + blob_name
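

# Example usage (a minimal sketch): the opts layout below is inferred from the
# lookups above, and the model object and checkpoint path are placeholders.
#
#     opts = {
#         'model_param': {'base_learning_rate': 0.1},
#         'distributed': {'first_xpu_id': 0, 'device': 'gpu'},
#     }
#     start_epoch, lr, best_metric = initialize_params_from_file(
#         model, '/path/to/checkpoint.pkl', num_xpus=2, opts=opts)
#     ...
#     save_model_params(True, model, '/path/to/checkpoint.pkl', start_epoch,
#                       opts, best_metric)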