## @package sparse_lookup
# Module caffe2.python.layers.sparse_lookup


from caffe2.python.optimizer import FP16_ENGINES, Optimizer
from caffe2.python.helpers.arg_scope import get_current_scope
from caffe2.python import schema
from caffe2.python.layers.layers import (
    get_categorical_limit,
    get_key,
    IdList,
    IdScoreList,
    IdListWithEvicted,
    IdScoreListWithEvicted,
    LayerPsParam,
    ModelLayer,
    almost_equal_schemas,
)
import collections
import functools
import logging
import math
import numpy as np
import operator

logger = logging.getLogger(__name__)


def get_trainer_version_based_on_optim(optim_def):
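    """Infer the trainer precision from an optimizer definition.

    Returns 'fp16' when the optimizer exposes an engine listed in
    FP16_ENGINES, and 'fp32' otherwise. Illustrative sketch (assumes
    'SIMD_Q_FP16' is one of the FP16_ENGINES):

        optim = Optimizer()
        optim.engine = 'SIMD_Q_FP16'
        get_trainer_version_based_on_optim(optim)  # -> 'fp16'
    """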
    if isinstance(optim_def, Optimizer) and hasattr(optim_def, "engine"):
        logger.info(
            "Attempting to set trainer version for engine {}".format(optim_def.engine)
        )
        if optim_def.engine in FP16_ENGINES:
            logger.info("Setting FP16 trainer for engine {}".format(optim_def.engine))
            return "fp16"
        else:
            logger.info("Setting FP32 trainer for engine {}".format(optim_def.engine))
            return "fp32"
    else:
        return "fp32"


def get_sparse_lookup_predictor_version(
    version,
    blob_size=None,
    min_blob_size_4bits=None,
    embedding_dim=None,
    sparse_feature_name=None,
):
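    """Validate `version` and resolve the 4-bit fallback rules.

    'fused_uint4rowwise' falls back to 'fused_uint8rowwise' when the lookup
    table is small (blob_size < min_blob_size_4bits) or when embedding_dim
    is odd (presumably because two 4-bit values are packed per byte).

    Illustrative call (values are made up for the example):

        get_sparse_lookup_predictor_version(
            'fused_uint4rowwise',
            blob_size=1000,
            min_blob_size_4bits=2000,
            embedding_dim=32,
            sparse_feature_name='my_feature',
        )  # -> 'fused_uint8rowwise': table too small for the 4-bit path
    """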
    assert version in {
        'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise', 'fused_uint4rowwise'
    }, "Unexpected version of sparse_lookup layer {0}".format(version)
    if version == 'fused_uint4rowwise':
        if (
            blob_size is not None
            and min_blob_size_4bits is not None
            and embedding_dim is not None
        ):
            if blob_size < min_blob_size_4bits:
                logger.info(
                    "{} falls back to uint8 because lookup table size {} < min_blob_size_4bits {}".format(
                        sparse_feature_name,
                        blob_size,
                        min_blob_size_4bits,
                    )
                )
                version = 'fused_uint8rowwise'

            if embedding_dim % 2 == 1:
                logger.info(
                    "{} falls back to uint8 because lookup table dimension {} is not divisible by 2".format(
                        sparse_feature_name, embedding_dim
                    )
                )
                version = 'fused_uint8rowwise'
        else:
            raise ValueError(
                (
                    "4-bit quantization for {} (sparse lookup predictor "
                    "version {}) requires all of blob_size ({}), "
                    "min_blob_size_4bits ({}) and embedding_dim ({}) "
                    "to be provided"
                ).format(
                    sparse_feature_name,
                    version,
                    blob_size,
                    min_blob_size_4bits,
                    embedding_dim
                )
            )
    return version


def get_sparse_lookup_trainer_version(version):
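    """Validate a trainer version; only 'fp32' and 'fp16' can be trained."""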
    assert version in {'fp32', 'fp16'},\
        "Unexpected version of sparse_lookup layer {0}".format(version)
    return version


def _is_id_list(input_record):
    return almost_equal_schemas(input_record, IdList)


def _is_id_score_list(input_record):
    return almost_equal_schemas(input_record,
                                IdScoreList,
                                check_field_types=False)


class SparseLookup(ModelLayer):
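    """Layer that looks up embeddings for sparse ids and pools them.

    Supports IdList and IdScoreList input records with a configurable
    reducer ('Sum', 'Mean', 'WeightedSum', 'PositionWeighted', ...).
    A minimal construction sketch (names are illustrative; assumes `model`
    is a layer model helper and `model.input_feature` is an IdList record):

        lookup = SparseLookup(
            model, model.input_feature,
            inner_shape=64,   # embedding dimension
            reducer='Sum',    # pool the embeddings of all ids in a list
        )

    The output schema is a float32 Scalar with shape `inner_shape`.
    """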
    _id_list_supported_reducers = [
        'LogMeanExp', 'LogSumExp', 'Max', 'Mean', 'Sum',
        'WeightedSum', 'WeightedMean', 'Sqrt', 'None']

    _id_score_list_supported_reducers = [
        'PositionWeighted', 'RecencyWeighted', 'Mean', 'Sum', 'WeightedSum',
        'WeightedMean', 'None'
    ]

    _fp16_compatible_init_op_types = [
        'Float16UniformFill'
    ]

    _fp16_compatible_reducers = [
        'Sum', 'Mean', 'Sqrt', 'PositionWeighted', 'RecencyWeighted',
    ]

    def __init__(self, model, input_record, inner_shape, reducer,
                 weight_init=None, weight_optim=None,
                 name='sparse_lookup', regularizer=None, use_external_weights=False,
                 uniform_weight_init_scale_numerator=1.0, **kwargs):

        super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

        self.sparse_key = get_key(self.input_record)()
        logger.info("Set up the sparse lookup layer for " + self.sparse_key)

        # TODO Add some asserts about input type
        if isinstance(inner_shape, int):
            inner_shape = [inner_shape]
        assert isinstance(inner_shape, list) or isinstance(inner_shape, tuple),\
            "Unexpected type for inner_shape, expected list or tuple, got {0} for {1}".\
            format(type(inner_shape), self.sparse_key)

        if reducer == "PositionWeighted":
            assert _is_id_score_list(self.input_record), (
                "PositionWeighted only supports IdScoreList, but got {} for {}. "
                "Please use the PositionWeighted layer to convert IdList "
                "to IdScoreList."
            ).format(repr(self.input_record), self.sparse_key)
            self.external_weights = self.input_record.values()

        elif reducer == "RecencyWeighted":
            assert _is_id_score_list(self.input_record), (
                "RecencyWeighted only supports IdScoreList, "
                "while the sparse feature {} is not.".format(self.sparse_key)
            )
            self.external_weights = self.input_record.values()
        # TODO: create a new type of reducer with external weights to wrap
        # this and the above two cases since essentially their input formats
        # are the same.
        elif use_external_weights:
            assert _is_id_score_list(self.input_record), (
                "use_external_weights only supports IdScoreList, "
                "while the sparse feature {} is not.".format(self.sparse_key)
            )
            assert reducer in ["Sum", "WeightedSum"], (
                "use_external_weights only supports Sum or WeightedSum reducers, "
                "while the reducer is {}.".format(reducer)
            )
            self.external_weights = self.input_record.values()
        self.reducer = reducer
        self.use_external_weights = use_external_weights

        input_dim = get_categorical_limit(self.input_record)
        assert input_dim > 0, "{} should have categorical limit > 0, but got {}".format(
            self.sparse_key, input_dim
        )

        self.input_dim = input_dim
        self.shape = [input_dim] + inner_shape

        self.trainer_version = get_trainer_version_based_on_optim(
            weight_optim
        )

        self.uniform_weight_init_scale_numerator = uniform_weight_init_scale_numerator
        default_init_op = self._get_default_init_op()

        self.weight_init = weight_init or default_init_op

        self.evicted_values = None
        if schema.equal_schemas(
            self.input_record, IdListWithEvicted
        ) or schema.equal_schemas(
            self.input_record, IdScoreListWithEvicted, check_field_types=False
        ):
            self.evicted_values = self.input_record._evicted_values

        # If fp16 is used, make sure an fp16 init op is used
        if self.trainer_version == "fp16":
            assert self.reducer in self._fp16_compatible_reducers or use_external_weights, (
                "Fp16 training is enabled. The reducer specified is not supported. "
                "Got {}. Supported reducers: {}. In general, sum, mean and "
                "positional pooling are supported; attention is not. Please check "
                "whether any fp16-trained sparse features use advanced pooling.".format(
                    self.reducer, self._fp16_compatible_reducers)
            )

            # if the init op is UniformFill, we replace it directly
            if self.weight_init[0] == "UniformFill":
                self.weight_init = ("Float16UniformFill", self.weight_init[1])
            assert self.weight_init[0] in self._fp16_compatible_init_op_types, (
                "Fp16 training is enabled. The init op for the weight parameter "
                "must be fp16 compatible. Got {}. Supported ops: {}".format(
                    self.weight_init[0],
                    self._fp16_compatible_init_op_types)
            )

            assert regularizer is None, "Regularizer is not compatible with fp16"

        if self.input_record.lengths.metadata:
            avg_length = self.input_record.lengths.metadata.expected_value
        else:
            avg_length = None

        self.w = self.create_param(
            param_name='w',
            shape=self.shape,
            initializer=self.weight_init,
            optimizer=weight_optim,
            ps_param=LayerPsParam(
                sparse_key=self.sparse_key,
                average_length=avg_length),
            regularizer=regularizer
        )
        if self.evicted_values:
            self.reinit_vec = self.create_param(
                param_name="reinit_vec",
                shape=inner_shape,
                initializer=self.weight_init,
                optimizer=model.NoOptim,
                regularizer=None,
            )

        self.scale_bias_init = ('ConstantFill', {'value': 0.0})

        self.scale_bias = self.create_param(
            param_name='scale_bias',
            shape=[],
            initializer=self.scale_bias_init,
            optimizer=model.NoOptim,
        )

        self.output_schema = schema.Scalar(
            (np.float32, inner_shape),
            self.get_next_blob_reference('output'),
        )

    def get_memory_usage(self):
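        # Rough estimate dominated by the embedding table; assumes fp32
        # storage at 4 bytes per element.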
        return functools.reduce(operator.mul, self.shape) * 4

    def get_fp16_compatible_parameters(self):
        return [self.w]

    def support_8bit(self):
        # Rowwise quantization makes sense only if the shape is a 2D matrix
        # whose second dimension is >= 8
        if len(self.shape) != 2 or self.shape[1] < 8:
            return False
        return True

    def get_8bits_compatible_parameters(self, fused=True):
        if not self.support_8bit():
            return []
        if fused:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w'
            )
            return [RowwiseQuantized8BitsWeight(self.w)]
        else:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w, scale_bias'
            )
            return [RowwiseQuantized8BitsWeight(self.w, self.scale_bias)]

    def _get_default_init_op(self):
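        """Default uniform init with scale sqrt(numerator / num_rows).

        E.g., with uniform_weight_init_scale_numerator=1.0 and
        input_dim=10000, the weights are drawn from U(-0.01, 0.01).
        """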
        scale = math.sqrt(self.uniform_weight_init_scale_numerator / self.input_dim)

        if self.trainer_version == 'fp32':
            default_weight_init = ('UniformFill', {'min': -scale, 'max': scale})
        elif self.trainer_version == 'fp16':
            default_weight_init = ("Float16UniformFill", {'min': -scale, 'max': scale})
        else:
            raise NotImplementedError(
                "Trainer version {} is not currently supported for sparse feature {}".format(
                    self.trainer_version, self.sparse_key
                )
            )

        return default_weight_init

    def _gather_wrapper(self, net, version, in_indices, out):
        # Gather can work on all kinds of input data types, and outputs
        # data with the same type. Convert the output of Gather to float,
        # because the follow-up ops expect fp32.
        if version == 'fp32':
            return net.Gather([self.w, in_indices], out)
        elif version == 'fp16':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.HalfToFloat(gathered_w, out)
        elif version == 'uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            gathered_scale_bias = net.Gather(
                [self.scale_bias, in_indices],
                'gathered_scale_bias'
            )

            return net.Rowwise8BitQuantizedToFloat(
                [gathered_w, gathered_scale_bias], out)
        elif version == 'fused_uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.Fused8BitRowwiseQuantizedToFloat(gathered_w, out)
        elif version == 'fused_uint4rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.Fused4BitRowwiseQuantizedToFloat(gathered_w, out)

        else:
            raise ValueError(
                "Unsupported version of operators in SparseLookup "
                "layer: {0} for sparse feature {1}".format(
                    version, self.sparse_key
                )
            )

    def _sparse_lengths_weighted_reducer(
        self,
        in_indices,
        weights,
        reducer,
        net,
        version,
        grad_on_weights=0,
    ):
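        """Emit a SparseLengths<reducer> op that pools weighted rows.

        As built below, the SparseLengths* weighted ops take inputs in the
        order [embedding matrix, per-id weights, ids, per-example lengths].
        """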
        op_input = [
            self.w,
            weights,
            in_indices,
            self.input_record.lengths(),
        ]
        layer_name = 'SparseLengths' + reducer

        if version in ['fp32', 'fp16']:
            # SparseLengths* ops will accept either an fp16 or fp32 embedding
            # matrix and output an fp32 pooled embedding.
            # A special case here is that we need the FP16 engine for
            # SparseLengthsWeightedSum when fp16 embeddings are used, for
            # correct backward updates
            if reducer == "WeightedSum" and version == "fp16":
                net.SparseLengthsWeightedSum(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                    engine='FP16',
                )
            else:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                )
        elif version == 'uint8rowwise':
            op_input.append(self.scale_bias)
            net.__getattr__(layer_name + '8BitsRowwise')(
                op_input, self.output_schema.field_blobs())
        elif version == 'fused_uint8rowwise':
            net.__getattr__(layer_name + 'Fused8BitRowwise')(
                op_input, self.output_schema.field_blobs())
        elif version == 'fused_uint4rowwise':
            net.__getattr__(layer_name + 'Fused4BitRowwise')(
                op_input, self.output_schema.field_blobs())
        else:
            raise ValueError(
                "Unsupported version of operator in SparseLookup "
                "layer: {0} for sparse feature {1}".format(
                    version, self.sparse_key
                )
            )

    # deal with sparse features of id_list type
    def _add_ops_id_list(self, net, version):
        assert self.reducer in self._id_list_supported_reducers, (
            "Unsupported reducer: {} for ID_LIST {}".format(
                self.reducer, self.sparse_key
            )
        )
        if self.reducer in ['Sum', 'Mean', 'WeightedSum', 'WeightedMean']:
            op_input = [self.w,
                        self.input_record.items(),
                        self.input_record.lengths()]

            # For id list features, the behaviors of 'Sum' and 'WeightedSum'
            # are identical, since we can regard the weight on each id as 1.
            # Similarly for 'Mean' and 'WeightedMean'.
            if self.reducer == 'WeightedSum':
                self.reducer = 'Sum'
            elif self.reducer == 'WeightedMean':
                self.reducer = 'Mean'

            layer_name = 'SparseLengths' + self.reducer
            if version in ['fp32', 'fp16']:
                # SparseLengths* ops will accept either an fp16 or fp32
                # embedding matrix and output an fp32 pooled embedding
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                )
            elif version == 'uint8rowwise':
                op_input.append(self.scale_bias)
                net.__getattr__(layer_name + '8BitsRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint8rowwise':
                net.__getattr__(layer_name + 'Fused8BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint4rowwise':
                net.__getattr__(layer_name + 'Fused4BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            else:
                raise ValueError(
                    "Unsupported version of operator in SparseLookup "
                    "layer: {0} for sparse feature {1}".format(
                        version, self.sparse_key
                    )
                )

        elif self.reducer == 'Sqrt':
            sqrt_weight = net.LengthsToWeights(
                [self.input_record.lengths()],
                [net.NextScopedBlob('lengths_sqrt')],
                power=0.5,
            )
            self._sparse_lengths_weighted_reducer(
                self.input_record.items(),
                sqrt_weight,
                'WeightedSum', net, version)

        elif self.reducer == 'None':
            # The Gather operator will gather the embedding for each id of
            # each IdList.
            self._gather_wrapper(net, version, self.input_record.items(),
                                 self.output_schema.field_blobs())

        else:
            table_rows = self._gather_wrapper(
                net, version, self.input_record.items(), 'table_rows')

            segment_ids = net.LengthsToSegmentIds(
                self.input_record.lengths(),
                net.NextScopedBlob(self.input_record.lengths() + '_sid'))
            net.__getattr__('SortedSegmentRange' + self.reducer)(
                [table_rows, segment_ids],
                self.output_schema.field_blobs(),
            )

    # deal with sparse features of id_score_list type
    def _add_ops_id_score_list(self, net, version):
        assert self.reducer in self._id_score_list_supported_reducers, (
            "Unsupported reducer: {} for ID_SCORE_LIST {}".format(
                self.reducer, self.sparse_key
            )
        )
        if self.reducer in ['WeightedSum', 'WeightedMean']:
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.input_record.values(),
                self.reducer, net, version)

        elif self.reducer in ['PositionWeighted', 'RecencyWeighted'] or self.use_external_weights:
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.external_weights,
                'WeightedSum', net, version, grad_on_weights=1)

        elif self.reducer in ['Sum', 'Mean']:
            op_input = [self.w,
                        self.input_record.keys(),
                        self.input_record.lengths()]

            layer_name = 'SparseLengths' + self.reducer

            if version in ['fp32', 'fp16']:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                )
            elif version == 'uint8rowwise':
                op_input.append(self.scale_bias)
                net.__getattr__(layer_name + '8BitsRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint8rowwise':
                net.__getattr__(layer_name + 'Fused8BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint4rowwise':
                net.__getattr__(layer_name + 'Fused4BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            else:
                raise ValueError(
                    "Unsupported version of operator in SparseLookup "
                    "layer: {0} for sparse feature {1}".format(
                        version, self.sparse_key
                    )
                )

        elif self.reducer == 'None':
            # The Gather operator will gather the embedding for each id of
            # each IdScoreList.
            self._gather_wrapper(net, version, self.input_record.keys(),
                                 self.output_schema.field_blobs())
        else:
            raise ValueError(
                "Unsupported reducer {} for IdScoreList sparse feature {}".format(
                    self.reducer, self.sparse_key
                )
            )

    def _add_ops(self, net, version='fp32', is_train=True):
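        # During training, rows evicted from the id mapping are reset to the
        # shared reinit vector before the lookup runs.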
        if self.evicted_values and is_train:
            net.CopyRowsToTensor(
                [self.w, self.evicted_values.get(), self.reinit_vec], [self.w])
        if _is_id_list(self.input_record):
            self._add_ops_id_list(net, version=version)
        elif _is_id_score_list(self.input_record):
            self._add_ops_id_score_list(net, version=version)
        else:
            raise ValueError("Unsupported input type {0}".format(self.input_record))

    def add_train_ops(self, net):
        self._add_ops(net, self.trainer_version, is_train=True)

    def add_ops(self, net):
        version_info = get_current_scope().get(
            get_sparse_lookup_predictor_version.__name__, {'version': 'fp32'}
        )
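        # Total element count of the lookup table; used to decide whether the
        # 4-bit path must fall back to 8-bit for small tables.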
        lookup_table_blob_size = self.shape[0] * self.shape[1]
        version = get_sparse_lookup_predictor_version(
            version_info['version'],
            blob_size=lookup_table_blob_size,
            min_blob_size_4bits=version_info.get('min_blob_size_4bits'),
            embedding_dim=self.shape[1],
            sparse_feature_name=self.sparse_key,
        )

        # TODO(amalevich): Layer should not be responsible for decisions about
        # quantization.
        if not self.support_8bit() and version in {'uint8rowwise',
                                                   'fused_uint8rowwise',
                                                   'fused_uint4rowwise'}:
            version = 'fp16'

        self._add_ops(net, version, is_train=False)