blob: 5020e5432c2a536c61f02f3b016966fa06f913bb [file] [log] [blame]
## @package dot_product
# Module caffe2.python.layers.dot_product
from caffe2.python import schema
from caffe2.python.layers.layers import (
ModelLayer,
)
class PairwiseSimilarity(ModelLayer):
    """Layer computing pairwise similarity scores between embedding sets.

    ``input_record`` must be a ``schema.Struct`` holding exactly one of:

    * ``all_embeddings`` -- the embeddings are compared against themselves;
    * ``x_embeddings`` and ``y_embeddings`` -- x is compared against y.

    An optional ``indices_to_gather`` scalar field selects entries from the
    flattened similarity result; without it the whole flattened result is
    the layer output.
    """

    def __init__(self, model, input_record, output_dim,
                 pairwise_similarity_func='dot',
                 name='pairwise_similarity', **kwargs):
        """Validate the input record and declare the output schema.

        Args:
            model: model passed through to ``ModelLayer``.
            input_record: ``schema.Struct`` as described in the class
                docstring.
            output_dim: size of the output's trailing dimension, as declared
                in ``output_schema``.
            pairwise_similarity_func: ``'dot'`` or ``'cosine_similarity'``.
                Any other value is rejected later, in ``add_ops``, with
                ``NotImplementedError``.
            name: layer name.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: (when assertions are enabled) if the record is not
                a Struct, the embedding fields are not given in exactly one of
                the accepted forms, or any used field is not a Scalar.
        """
        super(PairwiseSimilarity, self).__init__(
            model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))
        # Exactly one of the two input forms may be present.
        assert (
            ('all_embeddings' in input_record) ^
            ('x_embeddings' in input_record and 'y_embeddings' in input_record)
        ), (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )

        self.pairwise_similarity_func = pairwise_similarity_func

        if 'all_embeddings' in input_record:
            # Self-similarity: both sides of the product read the same field.
            x_embeddings = input_record['all_embeddings']
            y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(y_embeddings)
        )

        # None means "no gather step"; add_ops checks for it explicitly.
        self.indices_to_gather = None
        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output uses the same base dtype as the x embeddings.
        dtype = x_embeddings.field_types()[0].base
        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the similarity (and optional gather) operators to ``net``.

        Raises:
            NotImplementedError: if ``pairwise_similarity_func`` is neither
                ``'dot'`` nor ``'cosine_similarity'``.
        """
        if self.pairwise_similarity_func == "cosine_similarity":
            # Normalizing both sides first makes the dot products below cosine
            # similarities (assumes Normalize is an L2 normalization along
            # axis 1 -- TODO confirm against the operator docs).
            x_input = net.Normalize(self.x_embeddings(), axis=1)
            y_input = net.Normalize(self.y_embeddings(), axis=1)
        elif self.pairwise_similarity_func == "dot":
            x_input = self.x_embeddings()
            y_input = self.y_embeddings()
        else:
            raise NotImplementedError(
                "pairwise_similarity_func={} is not valid".format(
                    self.pairwise_similarity_func
                )
            )

        # trans_b=1 multiplies x by the transpose of y, so each result entry
        # is the dot product of one x row with one y row (assumes embeddings
        # are laid out as (batch, num, dim) -- TODO confirm).
        Y = net.BatchMatMul(
            [x_input, y_input],
            [self.get_next_blob_reference(x_input + '_matmul')],
            trans_b=1,
        )

        if self.indices_to_gather is not None:
            # Flatten the similarity result, then keep only the requested
            # entries.
            flattened = net.Flatten(Y, Y + '_flatten')
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(Y, self.output_schema())