| from torchgen.api.lazy import LazyArgument, LazyIrSchema |
| from torchgen.api.types import OptionalCType |
| |
| |
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Render the C++ body of a TorchScript lowering for one lazy IR node.

    The generated C++ does the following:

    * collects positional inputs into ``arguments`` and keyword inputs into
      ``kwarguments`` (both ``std::vector<torch::jit::NamedValue>``);
    * calls ``torch::lazy::LowerTSBuiltin(function, op().op, arguments,
      kwarguments)``;
    * checks with ``TORCH_CHECK_EQ`` that the number of outputs equals
      ``len(schema.returns)``;
    * returns the resulting ``TSOpVector``.

    Lazy-value positional arguments and all keyword values are read from the
    node's operands. A C++ counter ``i`` that starts at 0 tracks the next
    operand. Non-lazy positional arguments and keyword scalars are emitted
    by name.

    Args:
        schema: Lazy IR schema of the operator being lowered.

    Returns:
        The generated C++ source text.
    """

    def operand_expr(arg: LazyArgument) -> str:
        # Build the C++ expression that reads the next operand. An optional
        # value only reads an operand (and advances ``i``) when its
        # ``has_<name>`` flag is true; otherwise it becomes ``nullptr``.
        fetch = "loctx->GetOutputOp(operand(i++))"
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? {fetch} : nullptr"
        return fetch

    # Emission order matters: operands are consumed via ``i++`` in the order
    # these statements appear in the generated C++.
    positional = [
        operand_expr(arg) if arg.is_lazy_value else f'"{arg.name}", {arg.name}'
        for arg in schema.positional_args
    ]
    keyword = [
        f'"{arg.name}", {operand_expr(arg)}' for arg in schema.keyword_values
    ]
    keyword.extend(f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars)

    positional_stmts = "\n ".join(
        f"arguments.emplace_back({entry});" for entry in positional
    )
    keyword_stmts = "\n ".join(
        f"kwarguments.emplace_back({entry});" for entry in keyword
    )
    out_var = f"{schema.aten_name}_out"
    return f"""\
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(positional)});
kwarguments.reserve({len(keyword)});
size_t i = 0;
{positional_stmts}
{keyword_stmts}
torch::lazy::TSOpVector {out_var} = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
TORCH_CHECK_EQ({out_var}.size(), {len(schema.returns)});

return {out_var};
"""