| #include "deep_wide_pt.h" |
| |
| #include <torch/csrc/jit/serialization/import_source.h> |
| #include <torch/script.h> |
| |
| namespace { |
// TorchScript source defining the `DeepAndWide` module class, which is loaded
// under the qualified name `__torch__.DeepAndWide`.
//
// The class declares four tensor parameters (`_mu`, `_sigma`, `_fc_w`,
// `_fc_b`) and no buffers. Its `forward(ad_emb_packed, user_emb, wide)`:
//   1. offsets `wide` by `_mu`, scales it by `_sigma`, and clamps the result
//      to [0, 10];
//   2. batch-multiplies `ad_emb_packed` with `user_emb` transposed on dims
//      1 and 2, then flattens everything from dim 1 onward;
//   3. concatenates both results along dim 1 and applies
//      `addmm(_fc_b, input, t(_fc_w))`;
//   4. returns a one-element tuple holding the sigmoid of that output.
//
// NOTE(review): TorchScript is indentation-sensitive and the text between
// R"JIT( and )JIT" is handed to the importer verbatim -- confirm the
// literal's layout is intact in the checked-in file.
| // No ReplaceNaN (this removes the constant in the model) |
| const std::string deep_wide_pt = R"JIT( |
| class DeepAndWide(Module): |
| __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ] |
| __buffers__ = [] |
| _mu : Tensor |
| _sigma : Tensor |
| _fc_w : Tensor |
| _fc_b : Tensor |
| training : bool |
| def forward(self: __torch__.DeepAndWide, |
| ad_emb_packed: Tensor, |
| user_emb: Tensor, |
| wide: Tensor) -> Tuple[Tensor]: |
| _0 = self._fc_b |
| _1 = self._fc_w |
| _2 = self._sigma |
| wide_offset = torch.add(wide, self._mu, alpha=1) |
| wide_normalized = torch.mul(wide_offset, _2) |
| wide_preproc = torch.clamp(wide_normalized, 0., 10.) |
| user_emb_t = torch.transpose(user_emb, 1, 2) |
| dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t) |
| dp = torch.flatten(dp_unflatten, 1, -1) |
| input = torch.cat([dp, wide_preproc], 1) |
| fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1) |
| return (torch.sigmoid(fc1),) |
| )JIT"; |
| |
// TorchScript source for a three-input `forward(a, b, c)` that builds a
// constant 2x2 tensor of threes (`s`) and returns `a + b * c + s`.
// NOTE(review): the literal is passed to the TorchScript compiler verbatim;
// confirm its indentation is intact in the checked-in file.
| const std::string trivial_model_1 = R"JIT( |
| def forward(self, a, b, c): |
| s = torch.tensor([[3, 3], [3, 3]]) |
| return a + b * c + s |
| )JIT"; |
| |
// TorchScript source for a single-input `forward(input)` that applies
// `torch.leaky_relu` five times in a row, each time with the literal
// negative slope 0.1.
// NOTE(review): the literal is passed to the TorchScript compiler verbatim;
// confirm its indentation is intact in the checked-in file.
| const std::string leaky_relu_model_const = R"JIT( |
| def forward(self, input): |
| x = torch.leaky_relu(input, 0.1) |
| x = torch.leaky_relu(x, 0.1) |
| x = torch.leaky_relu(x, 0.1) |
| x = torch.leaky_relu(x, 0.1) |
| return torch.leaky_relu(x, 0.1) |
| )JIT"; |
| |
// TorchScript source for `forward(input, neg_slope)`: the same chain of five
// `torch.leaky_relu` calls as `leaky_relu_model_const`, except the negative
// slope is the runtime argument `neg_slope` rather than a literal.
// NOTE(review): the literal is passed to the TorchScript compiler verbatim;
// confirm its indentation is intact in the checked-in file.
| const std::string leaky_relu_model = R"JIT( |
| def forward(self, input, neg_slope): |
| x = torch.leaky_relu(input, neg_slope) |
| x = torch.leaky_relu(x, neg_slope) |
| x = torch.leaky_relu(x, neg_slope) |
| x = torch.leaky_relu(x, neg_slope) |
| return torch.leaky_relu(x, neg_slope) |
| )JIT"; |
| |
| void import_libs( |
| std::shared_ptr<at::CompilationUnit> cu, |
| const std::string& class_name, |
| const std::shared_ptr<torch::jit::Source>& src, |
| const std::vector<at::IValue>& tensor_table) { |
| torch::jit::SourceImporter si( |
| cu, |
| &tensor_table, |
| [&](const std::string& /* unused */) |
| -> std::shared_ptr<torch::jit::Source> { return src; }, |
| /*version=*/2); |
| si.loadType(c10::QualifiedName(class_name)); |
| } |
| } // namespace |
| |
| torch::jit::Module getDeepAndWideSciptModel(int num_features) { |
| auto cu = std::make_shared<at::CompilationUnit>(); |
| std::vector<at::IValue> constantTable; |
| import_libs( |
| cu, |
| "__torch__.DeepAndWide", |
| std::make_shared<torch::jit::Source>(deep_wide_pt), |
| constantTable); |
| c10::QualifiedName base("__torch__"); |
| auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide")); |
| |
| torch::jit::Module mod(cu, clstype); |
| |
| mod.register_parameter("_mu", torch::randn({1, num_features}), false); |
| mod.register_parameter("_sigma", torch::randn({1, num_features}), false); |
| mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false); |
| mod.register_parameter("_fc_b", torch::randn({1}), false); |
| |
| // mod.dump(true, true, true); |
| return mod; |
| } |
| |
| torch::jit::Module getTrivialScriptModel() { |
| torch::jit::Module module("m"); |
| module.define(trivial_model_1); |
| return module; |
| } |
| |
| torch::jit::Module getLeakyReLUScriptModel() { |
| torch::jit::Module module("leaky_relu"); |
| module.define(leaky_relu_model); |
| return module; |
| } |
| |
| torch::jit::Module getLeakyReLUConstScriptModel() { |
| torch::jit::Module module("leaky_relu_const"); |
| module.define(leaky_relu_model_const); |
| return module; |
| } |
| |
// TorchScript source for `forward(a, b, c)`: a chain of five `torch.relu`
// calls over elementwise products (a*b, a*c, e*d, f*f, g*c), returning the
// last result `h`.
// NOTE(review): the literal is passed to the TorchScript compiler verbatim;
// confirm its indentation is intact in the checked-in file.
| const std::string long_model = R"JIT( |
| def forward(self, a, b, c): |
| d = torch.relu(a * b) |
| e = torch.relu(a * c) |
| f = torch.relu(e * d) |
| g = torch.relu(f * f) |
| h = torch.relu(g * c) |
| return h |
| )JIT"; |
| |
| torch::jit::Module getLongScriptModel() { |
| torch::jit::Module module("m"); |
| module.define(long_model); |
| return module; |
| } |
| |
// TorchScript source for `forward(a)` computing
// `torch.sign(a) * torch.log1p(torch.abs(a))`.
// NOTE(review): the literal is passed to the TorchScript compiler verbatim;
// confirm its indentation is intact in the checked-in file.
| const std::string signed_log1p_model = R"JIT( |
| def forward(self, a): |
| b = torch.abs(a) |
| c = torch.log1p(b) |
| d = torch.sign(a) |
| e = d * c |
| return e |
| )JIT"; |
| |
| torch::jit::Module getSignedLog1pModel() { |
| torch::jit::Module module("signed_log1p"); |
| module.define(signed_log1p_model); |
| return module; |
| } |