blob: 6699b395d04583873fb14f0bfee61e94a453814b [file] [log] [blame]
#include "deep_wide_pt.h"
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/script.h>
namespace {
// TorchScript source for the DeepAndWide model class.
//
// The ReplaceNaN step is intentionally left out, which removes the constant
// the model would otherwise carry; accordingly getDeepAndWideSciptModel()
// imports it with an empty constant table.
//
// NOTE: the embedded code uses TorchScript's Python block syntax, so the
// indentation inside every raw string literal in this file is significant.
const std::string deep_wide_pt = R"JIT(
class DeepAndWide(Module):
  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
  __buffers__ = []
  _mu : Tensor
  _sigma : Tensor
  _fc_w : Tensor
  _fc_b : Tensor
  training : bool
  def forward(self: __torch__.DeepAndWide,
      ad_emb_packed: Tensor,
      user_emb: Tensor,
      wide: Tensor) -> Tuple[Tensor]:
    _0 = self._fc_b
    _1 = self._fc_w
    _2 = self._sigma
    wide_offset = torch.add(wide, self._mu, alpha=1)
    wide_normalized = torch.mul(wide_offset, _2)
    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
    user_emb_t = torch.transpose(user_emb, 1, 2)
    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
    dp = torch.flatten(dp_unflatten, 1, -1)
    input = torch.cat([dp, wide_preproc], 1)
    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
    return (torch.sigmoid(fc1),)
)JIT";

// forward(a, b, c) returns a + b * c + [[3, 3], [3, 3]].
const std::string trivial_model_1 = R"JIT(
def forward(self, a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s
)JIT";

// Five chained leaky_relu calls with the negative slope fixed at the
// literal 0.1.
const std::string leaky_relu_model_const = R"JIT(
def forward(self, input):
    x = torch.leaky_relu(input, 0.1)
    x = torch.leaky_relu(x, 0.1)
    x = torch.leaky_relu(x, 0.1)
    x = torch.leaky_relu(x, 0.1)
    return torch.leaky_relu(x, 0.1)
)JIT";

// Five chained leaky_relu calls whose negative slope is supplied at call
// time through the `neg_slope` argument.
const std::string leaky_relu_model = R"JIT(
def forward(self, input, neg_slope):
    x = torch.leaky_relu(input, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    return torch.leaky_relu(x, neg_slope)
)JIT";

// Imports the TorchScript class `class_name` defined in `src` into `cu`.
//
// @param cu           compilation unit that receives the imported class.
// @param class_name   fully qualified class name, e.g. "__torch__.DeepAndWide".
// @param src          source returned for every qualifier the importer asks for.
// @param tensor_table constant table handed to the importer by address.
//                     NOTE(review): the importer keeps this pointer, so the
//                     vector must outlive any later use of the importer's
//                     state. Confirm that nothing reads it after this call.
void import_libs(
    std::shared_ptr<at::CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<torch::jit::Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  torch::jit::SourceImporter si(
      std::move(cu),
      &tensor_table,
      // Capture the shared_ptr by value. The loader is stored by the
      // importer and may outlive this call (TODO confirm). Holding a
      // reference to the caller's (possibly temporary) shared_ptr could
      // then dangle.
      [src](const std::string& /* unused */)
          -> std::shared_ptr<torch::jit::Source> { return src; },
      /*version=*/2);
  si.loadType(c10::QualifiedName(class_name));
}
} // namespace
// Builds the DeepAndWide benchmark module with randomly initialized weights.
//
// The class is imported from the embedded TorchScript source `deep_wide_pt`.
// Its four declared parameters are then registered (not as buffers), filled
// with torch::randn:
//   _mu, _sigma : [1, num_features]
//   _fc_w       : [1, num_features + 1]
//   _fc_b       : [1]
// NOTE(review): "Scipt" in the function name is a typo kept for callers.
torch::jit::Module getDeepAndWideSciptModel(int num_features) {
  const std::string class_path = "__torch__.DeepAndWide";

  auto unit = std::make_shared<at::CompilationUnit>();
  // The embedded model source references no serialized constants.
  std::vector<at::IValue> no_constants;
  auto model_source = std::make_shared<torch::jit::Source>(deep_wide_pt);
  import_libs(unit, class_path, model_source, no_constants);

  torch::jit::Module mod(unit, unit->get_class(c10::QualifiedName(class_path)));

  // Registration order is fixed, which also fixes the order in which the
  // random number generator is consumed.
  auto add_random_param = [&mod](const std::string& name,
                                 at::IntArrayRef sizes) {
    mod.register_parameter(name, torch::randn(sizes), /*is_buffer=*/false);
  };
  add_random_param("_mu", {1, num_features});
  add_random_param("_sigma", {1, num_features});
  add_random_param("_fc_w", {1, num_features + 1});
  add_random_param("_fc_b", {1});
  return mod;
}
// Returns a module named "m" whose forward(a, b, c) is defined by
// `trivial_model_1`, i.e. a + b * c + [[3, 3], [3, 3]].
torch::jit::Module getTrivialScriptModel() {
  auto trivial = torch::jit::Module("m");
  trivial.define(trivial_model_1);
  return trivial;
}
// Returns a module named "leaky_relu" whose forward(input, neg_slope)
// (from `leaky_relu_model`) applies torch.leaky_relu five times, taking the
// negative slope from the runtime `neg_slope` argument.
torch::jit::Module getLeakyReLUScriptModel() {
  auto relu_chain = torch::jit::Module("leaky_relu");
  relu_chain.define(leaky_relu_model);
  return relu_chain;
}
// Returns a module named "leaky_relu_const" whose forward(input) (from
// `leaky_relu_model_const`) applies torch.leaky_relu five times with the
// negative slope hard-coded to 0.1.
torch::jit::Module getLeakyReLUConstScriptModel() {
  auto const_relu_chain = torch::jit::Module("leaky_relu_const");
  const_relu_chain.define(leaky_relu_model_const);
  return const_relu_chain;
}
// TorchScript source used by getLongScriptModel(). It is a chain of five
// elementwise multiply + relu steps over the inputs a, b and c.
// Indentation inside the raw string is significant: TorchScript uses Python
// block syntax, so the statements must be indented under `def`.
const std::string long_model = R"JIT(
def forward(self, a, b, c):
    d = torch.relu(a * b)
    e = torch.relu(a * c)
    f = torch.relu(e * d)
    g = torch.relu(f * f)
    h = torch.relu(g * c)
    return h
)JIT";
// Returns a module named "m" whose forward(a, b, c) is the five-step
// multiply/relu chain defined in `long_model`.
torch::jit::Module getLongScriptModel() {
  auto chain = torch::jit::Module("m");
  chain.define(long_model);
  return chain;
}
// TorchScript source used by getSignedLog1pModel(): forward(a) returns
// sign(a) * log1p(|a|).
// Indentation inside the raw string is significant: TorchScript uses Python
// block syntax, so the statements must be indented under `def`.
const std::string signed_log1p_model = R"JIT(
def forward(self, a):
    b = torch.abs(a)
    c = torch.log1p(b)
    d = torch.sign(a)
    e = d * c
    return e
)JIT";
// Returns a module named "signed_log1p" whose forward(a), taken from
// `signed_log1p_model`, computes sign(a) * log1p(|a|).
torch::jit::Module getSignedLog1pModel() {
  auto signed_log1p = torch::jit::Module("signed_log1p");
  signed_log1p.define(signed_log1p_model);
  return signed_log1p;
}