| import torch.fx as fx |
| |
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Recompile `gm` so its generated python code starts with a pdb breakpoint.

    When the recompiled `gm` is later run, it drops into pdb before executing
    the generated body.

    The breakpoint line is prepended after any code transformer already
    registered on ``gm.graph`` has been applied, so existing transformations
    are kept. The new transformer is registered through
    ``gm.graph.on_generate_code`` inside a ``with`` block. It is therefore
    only active for the ``recompile()`` call made here; the breakpoint ends up
    in the code compiled now.

    Args:
        gm: graph module to insert the breakpoint into. It is recompiled in
            place for the breakpoint to take effect.

    Returns:
        The same `gm` object, with the breakpoint inserted.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def _chain_with_breakpoint(previous_transform):
        # Build a transformer that first applies whatever transformer was
        # already registered (if any), then puts the breakpoint in front.
        def _transform(body):
            lines = previous_transform(body) if previous_transform else body
            return [breakpoint_line, *lines]

        return _transform

    with gm.graph.on_generate_code(make_transformer=_chain_with_breakpoint):
        gm.recompile()
    return gm