Revert "Revert "Make output of cppbor::parse safely mutable"" am: 1398a70364 am: 5a54201c94 am: e95b86d4c7

Original change: https://android-review.googlesource.com/c/platform/system/libcppbor/+/2474924

Change-Id: Ib7c703327cf7808ef3c790aaab3dffd11bcb9203
Signed-off-by: Automerger Merge Worker <[email protected]>
diff --git a/src/cppbor_parse.cpp b/src/cppbor_parse.cpp
index a221cf4..a9a9b7d 100644
--- a/src/cppbor_parse.cpp
+++ b/src/cppbor_parse.cpp
@@ -16,8 +16,11 @@
 
 #include "cppbor_parse.h"
 
+#include <memory>
 #include <sstream>
 #include <stack>
+#include <type_traits>
+#include "cppbor.h"
 
 #ifndef __TRUSTY__
 #include <android-base/logging.h>
@@ -110,8 +113,11 @@
 
 class IncompleteItem {
   public:
+    // Downcasts a parser-built compound Item; CHECK-fails for unsupported items.
+    static IncompleteItem* cast(Item* item);
     virtual ~IncompleteItem() {}
     virtual void add(std::unique_ptr<Item> item) = 0;
+    virtual std::unique_ptr<Item> finalize() && = 0;  // Moves contents into a new plain Item.
 };
 
 class IncompleteArray : public Array, public IncompleteItem {
@@ -126,6 +132,12 @@
         mEntries.push_back(std::move(item));
     }
 
+    std::unique_ptr<Item> finalize() && override {
+        // Hand make_unique an Array&& (not the IncompleteArray itself) so the compiler
+        // picks Array's move ctor rather than another Array ctor overload.
+        return std::make_unique<Array>(static_cast<Array&&>(*this));
+    }
+
   private:
     size_t mSize;
 };
@@ -146,6 +158,10 @@
         }
     }
 
+    std::unique_ptr<Item> finalize() && override {  // Moves the Map part into a plain Map.
+        return std::make_unique<Map>(static_cast<Map&&>(*this));  // Explicitly pick Map's move ctor.
+    }
+
   private:
     std::unique_ptr<Item> mKeyHeldForAdding;
     size_t mSize;
@@ -159,8 +175,36 @@
     size_t size() const override { return 1; }
 
     void add(std::unique_ptr<Item> item) override { mTaggedItem = std::move(item); }
+
+    std::unique_ptr<Item> finalize() && override {  // Moves the tag into a plain SemanticTag.
+        return std::make_unique<SemanticTag>(static_cast<SemanticTag&&>(*this));  // Move ctor.
+    }
 };
 
+IncompleteItem* IncompleteItem::cast(Item* item) {  // Downcasts a parser-built compound Item.
+    CHECK(item->isCompound());  // Array, Map and SemanticTag are the compound kinds handled below.
+    // Semantic tag must be checked first, because SemanticTag::type returns the wrapped item's type.
+    if (item->asSemanticTag()) {
+#if __has_feature(cxx_rtti)  // Verify each static_cast below when RTTI is available.
+        CHECK(dynamic_cast<IncompleteSemanticTag*>(item));
+#endif
+        return static_cast<IncompleteSemanticTag*>(item);
+    } else if (item->type() == ARRAY) {
+#if __has_feature(cxx_rtti)
+        CHECK(dynamic_cast<IncompleteArray*>(item));
+#endif
+        return static_cast<IncompleteArray*>(item);
+    } else if (item->type() == MAP) {
+#if __has_feature(cxx_rtti)
+        CHECK(dynamic_cast<IncompleteMap*>(item));
+#endif
+        return static_cast<IncompleteMap*>(item);
+    } else {
+        CHECK(false);  // Impossible to get here.
+    }
+    return nullptr;  // Only reached if CHECK(false) above does not terminate.
+}
+
 std::tuple<const uint8_t*, ParseClient*> handleEntries(size_t entryCount, const uint8_t* hdrBegin,
                                                        const uint8_t* pos, const uint8_t* end,
                                                        const std::string& typeName,
@@ -320,13 +364,15 @@
                                  const uint8_t* end) override {
         CHECK(item->isCompound() && item.get() == mParentStack.top());
         mParentStack.pop();
+        IncompleteItem* incompleteItem = IncompleteItem::cast(item.get());
+        std::unique_ptr<Item> finalizedItem = std::move(*incompleteItem).finalize();
 
         if (mParentStack.empty()) {
-            mTheItem = std::move(item);
+            mTheItem = std::move(finalizedItem);
             mPosition = end;
             return nullptr;  // We're done
         } else {
-            appendToLastParent(std::move(item));
+            appendToLastParent(std::move(finalizedItem));
             return this;
         }
     }
@@ -346,21 +392,7 @@
   private:
     void appendToLastParent(std::unique_ptr<Item> item) {
         auto parent = mParentStack.top();
-#if __has_feature(cxx_rtti)
-        assert(dynamic_cast<IncompleteItem*>(parent));
-#endif
-
-        IncompleteItem* parentItem{};
-        if (parent->type() == ARRAY) {
-            parentItem = static_cast<IncompleteArray*>(parent);
-        } else if (parent->type() == MAP) {
-            parentItem = static_cast<IncompleteMap*>(parent);
-        } else if (parent->asSemanticTag()) {
-            parentItem = static_cast<IncompleteSemanticTag*>(parent);
-        } else {
-            CHECK(false);  // Impossible to get here.
-        }
-        parentItem->add(std::move(item));
+        IncompleteItem::cast(parent)->add(std::move(item));  // cast() CHECK-fails on non-compound parents.
     }
 
     std::unique_ptr<Item> mTheItem;
diff --git a/tests/cppbor_test.cpp b/tests/cppbor_test.cpp
index b9a2f35..3c4d97f 100644
--- a/tests/cppbor_test.cpp
+++ b/tests/cppbor_test.cpp
@@ -1642,6 +1642,39 @@
     EXPECT_EQ(arr[0]->asTstr()->value(), "hello");
 }
 
+// Items returned by parse() must be mutable in place and still re-encode to valid CBOR.
+TEST(FullParserTest, MutableOutput) {
+    Array nestedArray("pizza", 31415);
+    Map nestedMap("array", std::move(nestedArray));
+    Array input(std::move(nestedMap));
+    auto [updatedItem, ignoredPos, parseMessage] = parse(input.encode());
+    ASSERT_EQ("", parseMessage);
+    ASSERT_NE(nullptr, updatedItem);
+    updatedItem->asArray()->add(42);
+
+    // add some stuff to the map in our array
+    Map* parsedNestedMap = updatedItem->asArray()->get(0)->asMap();
+    ASSERT_NE(nullptr, parsedNestedMap);
+    parsedNestedMap->add("number", 10);
+    EXPECT_THAT(updatedItem->asArray()->get(0)->asMap()->get("number"), MatchesItem(Uint(10)));
+    parsedNestedMap->add(42, "the answer");
+    EXPECT_THAT(updatedItem->asArray()->get(0)->asMap()->get(42), MatchesItem(Tstr("the answer")));
+    // add some stuff to the array in the map that's in our array
+    Array* parsedNestedArray = parsedNestedMap->get("array")->asArray();
+    ASSERT_NE(nullptr, parsedNestedArray);
+    parsedNestedArray->add("pie");
+    EXPECT_THAT(updatedItem->asArray()->get(0)->asMap()->get("array")->asArray()->get(2),
+                MatchesItem(Tstr("pie")));
+
+    // encode the mutated item, then ensure the CBOR is valid
+    const auto encodedUpdatedItem = updatedItem->encode();
+    auto [parsedUpdatedItem, pos, message] = parse(encodedUpdatedItem);
+    EXPECT_EQ("", message);
+    EXPECT_EQ(pos, encodedUpdatedItem.data() + encodedUpdatedItem.size());
+    ASSERT_NE(nullptr, parsedUpdatedItem);
+    EXPECT_THAT(parsedUpdatedItem, MatchesItem(ByRef(*updatedItem)));
+}
+
 TEST(FullParserTest, Map) {
     Map val("hello", -4, 3, Bstr("hi"));
 
@@ -1663,6 +1696,20 @@
     EXPECT_THAT(item, MatchesItem(ByRef(val)));
 }
 
+TEST(FullParserTest, TaggedArray) {  // A tag wrapping an array must round-trip through parse().
+    SemanticTag val(10, Array().add(42));
+    auto [item, pos, message] = parse(val.encode());
+    EXPECT_EQ("", message);
+    EXPECT_THAT(item, MatchesItem(ByRef(val)));
+}
+
+TEST(FullParserTest, TaggedMap) {  // A tag wrapping a map must round-trip through parse().
+    SemanticTag val(100, Map().add("foo", "bar"));
+    auto [item, pos, message] = parse(val.encode());
+    EXPECT_EQ("", message);
+    EXPECT_THAT(item, MatchesItem(ByRef(val)));
+}
+
 TEST(FullParserTest, Complex) {
     vector<uint8_t> vec = {0x01, 0x02, 0x08, 0x03};
     Map val("Outer1",