| #include <torch/csrc/jit/ir/irparser.h> |
| |
| #include <ATen/EmptyTensor.h> |
| #include <torch/csrc/jit/frontend/lexer.h> |
| #include <torch/csrc/jit/frontend/parse_string_literal.h> |
| #include <torch/csrc/jit/frontend/schema_type_parser.h> |
| #include <torch/csrc/jit/ir/ir.h> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #else |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/empty_strided.h> |
| #endif |
| |
| #include <string> |
| #include <vector> |
| |
| namespace torch::jit { |
| |
| struct VarWithType; |
| struct ParsedLiteral; |
| |
| class IRParser { |
| friend void parseIR( |
| const std::string& str, |
| torch::jit::Graph* graph, |
| std::unordered_map<std::string, Value*>& vmap, |
| bool parse_tensor_constants); |
| IRParser( |
| const std::string& str, |
| torch::jit::Graph* graph, |
| std::unordered_map<std::string, Value*>& vmap, |
| bool parse_tensor_constants) |
| : L(std::make_shared<Source>(str)), |
| g(graph), |
| vmap(vmap), |
| type_parser(L, /*parse_complete_tensor_types*/ true), |
| parse_tensor_constants_(parse_tensor_constants) {} |
| |
| std::string parseVar(); |
| VarWithType parseVarWithType(bool allow_optional = false); |
| ParsedLiteral parseScalarLiteral(Node* n); |
| |
| void parse(); |
| void parseGraphInputs(); |
| void parseReturnOperator(); |
| |
| void parseBlocks(Node* parentNode); |
| void parseBlock(Node* parentNode); |
| void parseBlockInputs(Block* b); |
| void parseBlockOutputs(Block* b); |
| |
| void parseOperatorsList(Block* b); |
| void parseOperator(Block* b); |
| void parseOperatorOutputs(std::vector<VarWithType>* outs); |
| std::string parseOperatorName(); |
| void parseOperatorInputs(Node* n); |
| void parseAttrs(Node* n); |
| void parseAttr(Node* n); |
| |
| void parseList( |
| int begin, |
| int sep, |
| int end, |
| const std::function<void()>& callback); |
| |
| void bypassTypeAnnotationList(); |
| |
| Value* findValueInVMap(const std::string& name); |
| |
| torch::jit::Lexer L; |
| torch::jit::Graph* g = nullptr; |
| std::unordered_map<std::string, Value*>& vmap; |
| SchemaTypeParser type_parser; |
| bool parse_tensor_constants_; |
| std::vector<Node*> deferred_tensor_value_initializations_; |
| std::vector<Node*> deferred_empty_container_initializations_; |
| }; |
| |
| struct ParsedLiteral { |
| ParsedLiteral() = default; |
| |
| AttributeKind k = AttributeKind::t; |
| |
| int64_t i = 0; |
| std::string s = ""; |
| double f = 0.0; |
| c10::complex<double> c = c10::complex<double>(0, 0); |
| TypePtr ty; |
| std::vector<int64_t> is; |
| std::vector<std::string> ss; |
| std::vector<double> fs; |
| std::vector<c10::complex<double>> cs; |
| std::vector<TypePtr> tys; |
| }; |
| |
| struct VarWithType { |
| VarWithType() = default; |
| std::string name; |
| TypePtr type; |
| }; |
| |
| void parseIR( |
| const std::string& str, |
| torch::jit::Graph* graph, |
| std::unordered_map<std::string, Value*>& vmap, |
| bool parse_tensor_constants) { |
| torch::jit::IRParser p(str, graph, vmap, parse_tensor_constants); |
| p.parse(); |
| } |
| |
| void parseIR( |
| const std::string& str, |
| torch::jit::Graph* graph, |
| bool parse_tensor_constants) { |
| std::unordered_map<std::string, Value*> vmap; |
| parseIR(str, graph, vmap, parse_tensor_constants); |
| } |
| |
| VarWithType IRParser::parseVarWithType(bool allow_optional) { |
| VarWithType r; |
| r.name = parseVar(); |
| if (allow_optional) { |
| r.type = nullptr; |
| } else { |
| r.type = TensorType::get(); |
| } |
| if (L.nextIf(':')) { |
| auto type_alias = type_parser.parseType(); |
| AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled"); |
| r.type = type_alias.first; |
| } |
| return r; |
| } |
| |
| std::string IRParser::parseVar() { |
| L.expect('%'); |
| std::string name; |
| bool continue_parsing; |
| do { |
| if (L.cur().kind == TK_IDENT) { |
| name += L.expect(TK_IDENT).text(); |
| } else { |
| name += L.expect(TK_NUMBER).text(); |
| } |
| continue_parsing = false; |
| if (L.nextIf('.')) { |
| continue_parsing = true; |
| name += '.'; |
| } else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') { |
| continue_parsing = true; |
| } |
| } while (continue_parsing); |
| return name; |
| } |
| |
| void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) { |
| if (L.cur().kind != '%') { |
| return; |
| } |
| parseList(TK_NOTHING, ',', TK_NOTHING, [&] { |
| outs->push_back(parseVarWithType(true)); |
| }); |
| L.expect('='); |
| } |
| |
| // Parse string or numeric literal and return it along with its type. |
| ParsedLiteral IRParser::parseScalarLiteral(Node* n) { |
| auto token = L.cur(); |
| std::string str; |
| std::pair<TypePtr, c10::optional<c10::AliasInfo>> type_alias; |
| ParsedLiteral r; |
| switch (token.kind) { |
| case TK_STRINGLITERAL: |
| r.k = AttributeKind::s; |
| r.s = parseStringLiteral(token.range, token.text()); |
| L.next(); |
| return r; |
| case '-': |
| str = "-"; |
| L.next(); |
| if (L.cur().kind != TK_NUMBER) { |
| throw ErrorReport(token.range) |
| << "Expected a number after '-' but got:" << token.text(); |
| } |
| // Fallthrough |
| case TK_NUMBER: |
| str += L.cur().text(); |
| if (str.find('j') != std::string::npos) { |
| r.k = AttributeKind::c; |
| double imag = 0.0f; |
| try { |
| imag = c10::stod(str.substr(0, str.size() - 1)); |
| } catch (const std::invalid_argument& e) { |
| throw ErrorReport(token.range) |
| << "Number cannot be converted to double"; |
| } catch (const std::out_of_range& e) { |
| throw ErrorReport(token.range) |
| << "Number is too long to be represented in type double"; |
| } |
| r.c = c10::complex<double>(0, imag); |
| } else if ( |
| str.find('.') != std::string::npos || |
| str.find('e') != std::string::npos) { |
| r.k = AttributeKind::f; |
| try { |
| r.f = c10::stod(str); |
| } catch (const std::invalid_argument& e) { |
| throw ErrorReport(token.range) |
| << "Number cannot be converted to double"; |
| } catch (const std::out_of_range& e) { |
| throw ErrorReport(token.range) |
| << "Number is too long to be represented in type double"; |
| } |
| } else { |
| r.k = AttributeKind::i; |
| try { |
| r.i = c10::stoll(str); |
| } catch (const std::invalid_argument& e) { |
| throw ErrorReport(token.range) |
| << "Number cannot be converted to integer"; |
| } catch (const std::out_of_range& e) { |
| throw ErrorReport(token.range) << "Number is too big"; |
| } |
| } |
| L.next(); |
| return r; |
| case TK_IDENT: |
| // Type literal |
| r.k = AttributeKind::ty; |
| type_alias = type_parser.parseType(); |
| AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled"); |
| r.ty = type_alias.first; |
| return r; |
| case '<': { |
| L.next(); |
| auto text = L.expect(TK_IDENT); |
| if (text.text() != "Tensor") { |
| throw ErrorReport(token.range) |
| << "Could not parse literal" << token.text(); |
| } |
| if (!parse_tensor_constants_) { |
| throw ErrorReport(token.range) |
| << "Tensor constant encountered but `parse_tensor_constants` set to false" |
| << token.text(); |
| } |
| L.expect('>'); |
| // these values will be set with randomly initialized data in |
| // a post processing pass; |
| deferred_tensor_value_initializations_.push_back(n); |
| r.k = AttributeKind::t; |
| return r; |
| } |
| case '{': { |
| L.next(); |
| if (L.cur().kind == '-') { |
| L.next(); |
| } |
| auto text = L.expect(TK_NUMBER); |
| if (!parse_tensor_constants_) { |
| throw ErrorReport(token.range) |
| << "Single-element tensor constant encountered but " |
| << "`parse_tensor_constants` is set to false " << token.text(); |
| } |
| L.expect('}'); |
| deferred_tensor_value_initializations_.push_back(n); |
| r.k = AttributeKind::t; |
| return r; |
| } |
| default: |
| throw ErrorReport(token.range) |
| << "Could not parse literal" << token.text(); |
| } |
| } |
| |
| void IRParser::bypassTypeAnnotationList() { |
| int depth = 0; |
| bool bypassed_list = false; |
| while (depth != 0 || !bypassed_list) { |
| if (L.cur().kind == '[') { |
| bypassed_list = true; |
| depth++; |
| } else if (L.cur().kind == ']') { |
| depth--; |
| } |
| L.next(); |
| } |
| } |
| |
| /** \brief Parse attribute and add it to the node N. |
| * |
| * The function determines the attribute type (string, int, float, complex, list |
| * of strings, list of ints, list of floats, list of complex, and a list of |
| * tensors (currently only for empty lists)). An attribute looks like the |
| * following: AttrName=AttrValue Where AttrValue can be a list or a scalar |
| * literal, e.g.: size = 27 name = "Bob" coefs = [1.2, 3.4, 0.6] |
| */ |
| void IRParser::parseAttr(Node* n) { |
| std::string attrname = L.expect(TK_IDENT).text(); |
| L.expect('='); |
| if (L.cur().kind == '[') { |
| // list |
| AttributeKind k = AttributeKind::ts; |
| c10::List<int64_t> is; |
| c10::List<std::string> ss; |
| c10::List<double> fs; |
| c10::List<c10::complex<double>> cs; |
| std::vector<TypePtr> tys; |
| int elem_num = 0; |
| parseList('[', ',', ']', [&] { |
| ParsedLiteral r = parseScalarLiteral(n); |
| switch (r.k) { |
| case AttributeKind::s: |
| ss.push_back(r.s); |
| AT_ASSERT(!elem_num++ || k == AttributeKind::ss); |
| k = AttributeKind::ss; |
| break; |
| case AttributeKind::i: |
| is.push_back(r.i); |
| AT_ASSERT(!elem_num++ || k == AttributeKind::is); |
| k = AttributeKind::is; |
| break; |
| case AttributeKind::f: |
| fs.push_back(r.f); |
| AT_ASSERT(!elem_num++ || k == AttributeKind::fs); |
| k = AttributeKind::fs; |
| break; |
| case AttributeKind::c: |
| cs.push_back(r.c); |
| AT_ASSERT(!elem_num++ || k == AttributeKind::cs); |
| k = AttributeKind::cs; |
| break; |
| case AttributeKind::ty: |
| tys.push_back(r.ty); |
| AT_ASSERT(!elem_num++ || k == AttributeKind::tys); |
| k = AttributeKind::tys; |
| break; |
| default: |
| throw ErrorReport(L.cur().range) << "Unexpected attr type"; |
| } |
| }); |
| switch (k) { |
| case AttributeKind::ts: |
| n->ival_(Symbol::attr(attrname), IValue()); |
| break; |
| case AttributeKind::ss: |
| n->ival_(Symbol::attr(attrname), IValue(ss)); |
| break; |
| case AttributeKind::fs: |
| n->ival_(Symbol::attr(attrname), IValue(fs)); |
| break; |
| case AttributeKind::cs: |
| n->ival_(Symbol::attr(attrname), IValue(cs)); |
| break; |
| case AttributeKind::is: |
| n->ival_(Symbol::attr(attrname), IValue(is)); |
| break; |
| case AttributeKind::tys: |
| n->tys_(Symbol::attr(attrname), tys); |
| break; |
| default: |
| throw ErrorReport(L.cur().range) << "Unexpected attr type"; |
| } |
| } else if (L.cur().text() == "annotate") { |
| L.next(); |
| L.expect('('); |
| auto type = L.cur().text(); |
| if (type != "List" && type != "Dict") { |
| throw ErrorReport(L.cur().range) |
| << "Unexpected annotation (only List and Dict can be parsed)"; |
| } |
| L.next(); |
| // ignore the annotations on the IValue constants, and instead recover |
| // type from the Node output |
| // Note: we could also use script_type_parser |
| bypassTypeAnnotationList(); |
| L.expect(','); |
| // expect an empty definition (note - this isn't always true) |
| if (type == "Dict") { |
| L.expect('{'); |
| L.expect('}'); |
| } else if (type == "List") { |
| L.expect('['); |
| L.expect(']'); |
| } |
| L.expect(')'); |
| deferred_empty_container_initializations_.push_back(n); |
| } else { |
| // scalar |
| ParsedLiteral r = parseScalarLiteral(n); |
| switch (r.k) { |
| case AttributeKind::s: |
| n->s_(Symbol::attr(attrname), r.s); |
| break; |
| case AttributeKind::i: |
| n->i_(Symbol::attr(attrname), r.i); |
| break; |
| case AttributeKind::f: |
| n->f_(Symbol::attr(attrname), r.f); |
| break; |
| case AttributeKind::c: |
| n->c_(Symbol::attr(attrname), r.c); |
| break; |
| case AttributeKind::ty: |
| n->ty_(Symbol::attr(attrname), r.ty); |
| break; |
| case AttributeKind::t: |
| // initialized with random data later |
| break; |
| default: |
| throw ErrorReport(L.cur().range) << "Unexpected attr type"; |
| } |
| return; |
| } |
| } |
| |
| void IRParser::parseAttrs(Node* n) { |
| parseList('[', ',', ']', [&] { parseAttr(n); }); |
| } |
| |
| void IRParser::parseOperatorInputs(Node* n) { |
| if (L.cur().kind == '[') { |
| parseAttrs(n); |
| } |
| parseList('(', ',', ')', [&] { |
| std::string var_name = parseVar(); |
| n->addInput(findValueInVMap(var_name)); |
| }); |
| } |
| |
| void IRParser::parseBlocks(Node* parentNode) { |
| L.expect(TK_INDENT); |
| while (L.cur().kind != TK_DEDENT) { |
| parseBlock(parentNode); |
| } |
| L.expect(TK_DEDENT); |
| } |
| |
| void IRParser::parseBlockInputs(Block* b) { |
| parseList('(', ',', ')', [&] { |
| VarWithType v = parseVarWithType(); |
| // If the name isn't valid, don't use it |
| std::string uniq_name = Value::isValidName(v.name) ? v.name : ""; |
| vmap[v.name] = b->addInput(uniq_name); |
| vmap[v.name]->setType(v.type); |
| }); |
| } |
| |
| void IRParser::parseBlockOutputs(Block* b) { |
| L.expect(TK_ARROW); |
| parseList('(', ',', ')', [&] { |
| std::string var_name = parseVar(); |
| b->registerOutput(findValueInVMap(var_name)); |
| }); |
| L.expect(TK_NEWLINE); |
| L.expect(TK_DEDENT); |
| } |
| |
| /** \brief Parse a block. |
| * |
| * It should look like the following: |
| * blockName(input1, input2, input3, ...): |
| * op1 |
| * op2 |
| * ... |
| * opN |
| * -> (output1, output2, output3, ...) |
| */ |
| void IRParser::parseBlock(Node* parentNode) { |
| Block* b = parentNode->addBlock(); |
| L.expect(TK_IDENT).text(); // Block name is not used anywhere. |
| parseBlockInputs(b); |
| L.expect(':'); |
| parseOperatorsList(b); |
| parseBlockOutputs(b); |
| } |
| |
| /** \brief Parse a list of statements. |
| * |
| * It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or |
| * TK_ARROW. |
| */ |
| void IRParser::parseOperatorsList(Block* b) { |
| L.expect(TK_INDENT); |
| while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) { |
| parseOperator(b); |
| } |
| } |
| |
| std::string IRParser::parseOperatorName() { |
| std::string name = L.expect(TK_IDENT).text(); |
| L.expect(':'); |
| L.expect(':'); |
| name += "::" + L.expect(TK_IDENT).text(); |
| return name; |
| } |
| |
| /** \brief Parse a statement. |
| * |
| * It should look like the following: |
| * <outputs> = NodeName[<attributes>](<inputs>) |
| * <blocks> |
| * Outputs, blocks and attributes are optional. |
| */ |
| void IRParser::parseOperator(Block* b) { |
| // Parse lefthand side. |
| std::vector<VarWithType> outs; |
| parseOperatorOutputs(&outs); |
| |
| // Parse the name and create the corresponding node in the graph. |
| auto source_range = L.cur().range; |
| std::string name = parseOperatorName(); |
| Node* n = g->create(Symbol::fromQualString(name), {}, outs.size()) |
| ->setSourceRange(source_range); |
| |
| // Parse attributes and inputs. |
| parseOperatorInputs(n); |
| |
| const FunctionSchema* schema = n->maybeSchema(); |
| |
| // Register outputs. |
| unsigned idx = 0; |
| for (const VarWithType& v : outs) { |
| vmap[v.name] = n->outputs()[idx]; |
| if (schema && !schema->is_varret()) { |
| TORCH_CHECK( |
| schema->returns().size() > idx, |
| "Operator parsing error: out of bounds access at ", |
| idx, |
| " to schema->returns() which size is ", |
| schema->returns().size(), |
| " in size"); |
| auto schema_return_type = schema->returns().at(idx).type(); |
| if (!v.type) { |
| vmap[v.name]->setType(schema_return_type); |
| } else { |
| // Don't currently support checking against type variables |
| // TODO: support? |
| if (!schema_return_type->hasFreeVariables() && |
| !v.type->isSubtypeOf(*schema_return_type)) { |
| throw ErrorReport(source_range) |
| << "Annotated type " << v.type->repr_str() |
| << " does not match schema type " |
| << schema_return_type->repr_str() << " for operator " << *schema; |
| } |
| vmap[v.name]->setType(v.type); |
| } |
| } else { |
| vmap[v.name]->setType(v.type ? v.type : TensorType::get()); |
| } |
| idx++; |
| } |
| |
| // Insert the new node into block B. |
| b->appendNode(n); |
| |
| // If the statement has nested blocks, parse them: |
| if (L.cur().kind == TK_INDENT) { |
| parseBlocks(n); |
| } |
| L.nextIf(TK_NEWLINE); |
| } |
| |
| void IRParser::parseGraphInputs() { |
| parseList('(', ',', ')', [&] { |
| VarWithType v = parseVarWithType(); |
| // If the name isn't valid, don't use it |
| std::string uniq_name = Value::isValidName(v.name) ? v.name : ""; |
| vmap[v.name] = g->addInput(uniq_name); |
| vmap[v.name]->setType(v.type); |
| }); |
| } |
| |
| /** \brief Parse return statement. |
| * |
| * It should look like the following: |
| * return (x : TypeX, y : TypeY, z, ...) |
| */ |
| void IRParser::parseReturnOperator() { |
| L.expect(TK_RETURN); |
| |
| // Parse output names and types |
| parseList('(', ',', ')', [&] { |
| std::string var_name = parseVar(); |
| g->registerOutput(findValueInVMap(var_name)); |
| }); |
| |
| // Consume ending tokens |
| if (L.cur().kind != TK_EOF) { |
| L.expect(TK_NEWLINE); |
| L.expect(TK_DEDENT); |
| } |
| } |
| |
| /** \brief Parse entire graph. |
| * |
| * It should look like the following: |
| * graphName (input1, input2, ... inputN): |
| * op1 |
| * op2 |
| * ... |
| * opN |
| * return (output1, output2, ... outputN) |
| */ |
| void IRParser::parse() { |
| // Parse graph definition, it should look like the following: |
| // graphName (input1, input2, ... inputN): |
| std::string graphName = L.expect(TK_IDENT).text(); |
| parseGraphInputs(); |
| L.expect(':'); |
| |
| // After the definition we should have a list of statements, parse it: |
| parseOperatorsList(g->block()); |
| |
| // The last statement should be return, which specifies graph outputs |
| parseReturnOperator(); |
| |
| for (Node* n : deferred_tensor_value_initializations_) { |
| auto type = n->output()->type()->expect<TensorType>(); |
| auto tt = n->output()->type()->cast<TensorType>(); |
| TORCH_INTERNAL_ASSERT(tt, "expected tensor output ", *n); |
| auto sizes = tt->sizes().concrete_sizes(); |
| TORCH_INTERNAL_ASSERT(sizes); |
| auto strides = tt->strides().concrete_sizes(); |
| TORCH_INTERNAL_ASSERT(strides); |
| auto device = tt->device(); |
| TORCH_INTERNAL_ASSERT(device); |
| auto dtype = tt->scalarType(); |
| TORCH_INTERNAL_ASSERT(dtype); |
| auto options = at::TensorOptions(*device).dtype(*dtype); |
| auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options)); |
| (void)t; |
| } |
| |
| for (Node* n : deferred_empty_container_initializations_) { |
| auto type = n->output()->type(); |
| IValue val; |
| if (type->kind() == TypeKind::ListType) { |
| val = c10::impl::GenericList(type->containedType(0)); |
| } else if (type->kind() == TypeKind::DictType) { |
| val = c10::impl::GenericDict( |
| type->containedType(0), type->containedType(1)); |
| } |
| n->ival_(attr::value, val); |
| } |
| } |
| |
| void IRParser::parseList( |
| int begin, |
| int sep, |
| int end, |
| const std::function<void()>& callback) { |
| if (begin != TK_NOTHING) { |
| L.expect(begin); |
| } |
| if (L.cur().kind != end) { |
| do { |
| callback(); |
| } while (L.nextIf(sep)); |
| } |
| if (end != TK_NOTHING) { |
| L.expect(end); |
| } |
| } |
| |
| Value* IRParser::findValueInVMap(const std::string& name) { |
| if (!vmap.count(name)) { |
| throw ErrorReport(L.cur().range) |
| << "Cannot find a variable with name '" << name << "'"; |
| } |
| return vmap.at(name); |
| } |
| |
| } // namespace torch::jit |