| /* |
| * Copyright 2019 Google LLC |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
#include "cppbor.h"

#include <inttypes.h>

#include <algorithm>
#include <cstdint>
#include <string_view>
#include <type_traits>

#include <openssl/sha.h>

#include "cppbor_parse.h"
| |
| using std::string; |
| using std::vector; |
| |
| #ifndef __TRUSTY__ |
| #include <android-base/logging.h> |
| #define LOG_TAG "CppBor" |
| #else |
| #define CHECK(x) (void)(x) |
| #endif |
| |
| namespace cppbor { |
| |
| namespace { |
| |
// Writes `value` to `pos` most-significant byte first, one byte per iterator step.
// Returns the iterator just past the last byte written.
//
// The default template argument restricts T to unsigned types. It must be
// std::enable_if_t (i.e. `::type`): a bare std::enable_if<false> is still a complete
// type, so the previous spelling never removed signed types from overload resolution.
template <typename T, typename Iterator, typename = std::enable_if_t<std::is_unsigned_v<T>>>
Iterator writeBigEndian(T value, Iterator pos) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        // Emit the current top byte, then shift the next byte into the top position.
        *pos++ = static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1)));
        value = static_cast<T>(value << 8);
    }
    return pos;
}
| |
// Streams `value` through `cb` most-significant byte first, one call per byte.
//
// As with the iterator overload, the constraint must use std::enable_if_t so that it
// actually rejects non-unsigned T; std::enable_if<...> alone is always a valid type.
template <typename T, typename = std::enable_if_t<std::is_unsigned_v<T>>>
void writeBigEndian(T value, std::function<void(uint8_t)>& cb) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        // Emit the current top byte, then shift the next byte into the top position.
        cb(static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1))));
        value = static_cast<T>(value << 8);
    }
}
| |
| bool cborAreAllElementsNonCompound(const Item* compoundItem) { |
| if (compoundItem->type() == ARRAY) { |
| const Array* array = compoundItem->asArray(); |
| for (size_t n = 0; n < array->size(); n++) { |
| const Item* entry = (*array)[n].get(); |
| switch (entry->type()) { |
| case ARRAY: |
| case MAP: |
| return false; |
| default: |
| break; |
| } |
| } |
| } else { |
| const Map* map = compoundItem->asMap(); |
| for (auto& [keyEntry, valueEntry] : *map) { |
| switch (keyEntry->type()) { |
| case ARRAY: |
| case MAP: |
| return false; |
| default: |
| break; |
| } |
| switch (valueEntry->type()) { |
| case ARRAY: |
| case MAP: |
| return false; |
| default: |
| break; |
| } |
| } |
| } |
| return true; |
| } |
| |
| bool prettyPrintInternal(const Item* item, string& out, size_t indent, size_t maxBStrSize, |
| const vector<string>& mapKeysToNotPrint) { |
| if (!item) { |
| out.append("<NULL>"); |
| return false; |
| } |
| |
| char buf[80]; |
| |
| string indentString(indent, ' '); |
| |
| size_t tagCount = item->semanticTagCount(); |
| while (tagCount > 0) { |
| --tagCount; |
| snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", item->semanticTag(tagCount)); |
| out.append(buf); |
| } |
| |
| switch (item->type()) { |
| case SEMANTIC: |
| // Handled above. |
| break; |
| |
| case UINT: |
| snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue()); |
| out.append(buf); |
| break; |
| |
| case NINT: |
| snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value()); |
| out.append(buf); |
| break; |
| |
| case BSTR: { |
| const uint8_t* valueData; |
| size_t valueSize; |
| const Bstr* bstr = item->asBstr(); |
| if (bstr != nullptr) { |
| const vector<uint8_t>& value = bstr->value(); |
| valueData = value.data(); |
| valueSize = value.size(); |
| } else { |
| const ViewBstr* viewBstr = item->asViewBstr(); |
| assert(viewBstr != nullptr); |
| |
| valueData = viewBstr->view().data(); |
| valueSize = viewBstr->view().size(); |
| } |
| |
| if (valueSize > maxBStrSize) { |
| unsigned char digest[SHA_DIGEST_LENGTH]; |
| SHA_CTX ctx; |
| SHA1_Init(&ctx); |
| SHA1_Update(&ctx, valueData, valueSize); |
| SHA1_Final(digest, &ctx); |
| char buf2[SHA_DIGEST_LENGTH * 2 + 1]; |
| for (size_t n = 0; n < SHA_DIGEST_LENGTH; n++) { |
| snprintf(buf2 + n * 2, 3, "%02x", digest[n]); |
| } |
| snprintf(buf, sizeof(buf), "<bstr size=%zd sha1=%s>", valueSize, buf2); |
| out.append(buf); |
| } else { |
| out.append("{"); |
| for (size_t n = 0; n < valueSize; n++) { |
| if (n > 0) { |
| out.append(", "); |
| } |
| snprintf(buf, sizeof(buf), "0x%02x", valueData[n]); |
| out.append(buf); |
| } |
| out.append("}"); |
| } |
| } break; |
| |
| case TSTR: |
| out.append("'"); |
| { |
| // TODO: escape "'" characters |
| if (item->asTstr() != nullptr) { |
| out.append(item->asTstr()->value().c_str()); |
| } else { |
| const ViewTstr* viewTstr = item->asViewTstr(); |
| assert(viewTstr != nullptr); |
| out.append(viewTstr->view()); |
| } |
| } |
| out.append("'"); |
| break; |
| |
| case ARRAY: { |
| const Array* array = item->asArray(); |
| if (array->size() == 0) { |
| out.append("[]"); |
| } else if (cborAreAllElementsNonCompound(array)) { |
| out.append("["); |
| for (size_t n = 0; n < array->size(); n++) { |
| if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize, |
| mapKeysToNotPrint)) { |
| return false; |
| } |
| out.append(", "); |
| } |
| out.append("]"); |
| } else { |
| out.append("[\n" + indentString); |
| for (size_t n = 0; n < array->size(); n++) { |
| out.append(" "); |
| if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize, |
| mapKeysToNotPrint)) { |
| return false; |
| } |
| out.append(",\n" + indentString); |
| } |
| out.append("]"); |
| } |
| } break; |
| |
| case MAP: { |
| const Map* map = item->asMap(); |
| |
| if (map->size() == 0) { |
| out.append("{}"); |
| } else { |
| out.append("{\n" + indentString); |
| for (auto& [map_key, map_value] : *map) { |
| out.append(" "); |
| |
| if (!prettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize, |
| mapKeysToNotPrint)) { |
| return false; |
| } |
| out.append(" : "); |
| if (map_key->type() == TSTR && |
| std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(), |
| map_key->asTstr()->value()) != mapKeysToNotPrint.end()) { |
| out.append("<not printed>"); |
| } else { |
| if (!prettyPrintInternal(map_value.get(), out, indent + 2, maxBStrSize, |
| mapKeysToNotPrint)) { |
| return false; |
| } |
| } |
| out.append(",\n" + indentString); |
| } |
| out.append("}"); |
| } |
| } break; |
| |
| case SIMPLE: |
| const Bool* asBool = item->asSimple()->asBool(); |
| const Null* asNull = item->asSimple()->asNull(); |
| if (asBool != nullptr) { |
| out.append(asBool->value() ? "true" : "false"); |
| } else if (asNull != nullptr) { |
| out.append("null"); |
| } else { |
| #ifndef __TRUSTY__ |
| LOG(ERROR) << "Only boolean/null is implemented for SIMPLE"; |
| #endif // __TRUSTY__ |
| return false; |
| } |
| break; |
| } |
| |
| return true; |
| } |
| |
| } // namespace |
| |
| size_t headerSize(uint64_t addlInfo) { |
| if (addlInfo < ONE_BYTE_LENGTH) return 1; |
| if (addlInfo <= std::numeric_limits<uint8_t>::max()) return 2; |
| if (addlInfo <= std::numeric_limits<uint16_t>::max()) return 3; |
| if (addlInfo <= std::numeric_limits<uint32_t>::max()) return 5; |
| return 9; |
| } |
| |
| uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end) { |
| size_t sz = headerSize(addlInfo); |
| if (end - pos < static_cast<ssize_t>(sz)) return nullptr; |
| switch (sz) { |
| case 1: |
| *pos++ = type | static_cast<uint8_t>(addlInfo); |
| return pos; |
| case 2: |
| *pos++ = type | ONE_BYTE_LENGTH; |
| *pos++ = static_cast<uint8_t>(addlInfo); |
| return pos; |
| case 3: |
| *pos++ = type | TWO_BYTE_LENGTH; |
| return writeBigEndian(static_cast<uint16_t>(addlInfo), pos); |
| case 5: |
| *pos++ = type | FOUR_BYTE_LENGTH; |
| return writeBigEndian(static_cast<uint32_t>(addlInfo), pos); |
| case 9: |
| *pos++ = type | EIGHT_BYTE_LENGTH; |
| return writeBigEndian(addlInfo, pos); |
| default: |
| CHECK(false); // Impossible to get here. |
| return nullptr; |
| } |
| } |
| |
| void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback) { |
| size_t sz = headerSize(addlInfo); |
| switch (sz) { |
| case 1: |
| encodeCallback(type | static_cast<uint8_t>(addlInfo)); |
| break; |
| case 2: |
| encodeCallback(type | ONE_BYTE_LENGTH); |
| encodeCallback(static_cast<uint8_t>(addlInfo)); |
| break; |
| case 3: |
| encodeCallback(type | TWO_BYTE_LENGTH); |
| writeBigEndian(static_cast<uint16_t>(addlInfo), encodeCallback); |
| break; |
| case 5: |
| encodeCallback(type | FOUR_BYTE_LENGTH); |
| writeBigEndian(static_cast<uint32_t>(addlInfo), encodeCallback); |
| break; |
| case 9: |
| encodeCallback(type | EIGHT_BYTE_LENGTH); |
| writeBigEndian(addlInfo, encodeCallback); |
| break; |
| default: |
| CHECK(false); // Impossible to get here. |
| } |
| } |
| |
| bool Item::operator==(const Item& other) const& { |
| if (type() != other.type()) return false; |
| switch (type()) { |
| case UINT: |
| return *asUint() == *(other.asUint()); |
| case NINT: |
| return *asNint() == *(other.asNint()); |
| case BSTR: |
| if (asBstr() != nullptr && other.asBstr() != nullptr) { |
| return *asBstr() == *(other.asBstr()); |
| } |
| if (asViewBstr() != nullptr && other.asViewBstr() != nullptr) { |
| return *asViewBstr() == *(other.asViewBstr()); |
| } |
| // Interesting corner case: comparing a Bstr and ViewBstr with |
| // identical contents. The function currently returns false for |
| // this case. |
| // TODO: if it should return true, this needs a deep comparison |
| return false; |
| case TSTR: |
| if (asTstr() != nullptr && other.asTstr() != nullptr) { |
| return *asTstr() == *(other.asTstr()); |
| } |
| if (asViewTstr() != nullptr && other.asViewTstr() != nullptr) { |
| return *asViewTstr() == *(other.asViewTstr()); |
| } |
| // Same corner case as Bstr |
| return false; |
| case ARRAY: |
| return *asArray() == *(other.asArray()); |
| case MAP: |
| return *asMap() == *(other.asMap()); |
| case SIMPLE: |
| return *asSimple() == *(other.asSimple()); |
| case SEMANTIC: |
| return *asSemanticTag() == *(other.asSemanticTag()); |
| default: |
| CHECK(false); // Impossible to get here. |
| return false; |
| } |
| } |
| |
| Nint::Nint(int64_t v) : mValue(v) { |
| CHECK(v < 0); |
| } |
| |
| bool Simple::operator==(const Simple& other) const& { |
| if (simpleType() != other.simpleType()) return false; |
| |
| switch (simpleType()) { |
| case BOOLEAN: |
| return *asBool() == *(other.asBool()); |
| case NULL_T: |
| return true; |
| default: |
| CHECK(false); // Impossible to get here. |
| return false; |
| } |
| } |
| |
| uint8_t* Bstr::encode(uint8_t* pos, const uint8_t* end) const { |
| pos = encodeHeader(mValue.size(), pos, end); |
| if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr; |
| return std::copy(mValue.begin(), mValue.end(), pos); |
| } |
| |
| void Bstr::encodeValue(EncodeCallback encodeCallback) const { |
| for (auto c : mValue) { |
| encodeCallback(c); |
| } |
| } |
| |
| uint8_t* ViewBstr::encode(uint8_t* pos, const uint8_t* end) const { |
| pos = encodeHeader(mView.size(), pos, end); |
| if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr; |
| return std::copy(mView.begin(), mView.end(), pos); |
| } |
| |
| void ViewBstr::encodeValue(EncodeCallback encodeCallback) const { |
| for (auto c : mView) { |
| encodeCallback(static_cast<uint8_t>(c)); |
| } |
| } |
| |
| uint8_t* Tstr::encode(uint8_t* pos, const uint8_t* end) const { |
| pos = encodeHeader(mValue.size(), pos, end); |
| if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr; |
| return std::copy(mValue.begin(), mValue.end(), pos); |
| } |
| |
| void Tstr::encodeValue(EncodeCallback encodeCallback) const { |
| for (auto c : mValue) { |
| encodeCallback(static_cast<uint8_t>(c)); |
| } |
| } |
| |
| uint8_t* ViewTstr::encode(uint8_t* pos, const uint8_t* end) const { |
| pos = encodeHeader(mView.size(), pos, end); |
| if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr; |
| return std::copy(mView.begin(), mView.end(), pos); |
| } |
| |
| void ViewTstr::encodeValue(EncodeCallback encodeCallback) const { |
| for (auto c : mView) { |
| encodeCallback(static_cast<uint8_t>(c)); |
| } |
| } |
| |
| bool Array::operator==(const Array& other) const& { |
| return size() == other.size() |
| // Can't use vector::operator== because the contents are pointers. std::equal lets us |
| // provide a predicate that does the dereferencing. |
| && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(), |
| [](auto& a, auto& b) -> bool { return *a == *b; }); |
| } |
| |
| uint8_t* Array::encode(uint8_t* pos, const uint8_t* end) const { |
| pos = encodeHeader(size(), pos, end); |
| if (!pos) return nullptr; |
| for (auto& entry : mEntries) { |
| pos = entry->encode(pos, end); |
| if (!pos) return nullptr; |
| } |
| return pos; |
| } |
| |
| void Array::encode(EncodeCallback encodeCallback) const { |
| encodeHeader(size(), encodeCallback); |
| for (auto& entry : mEntries) { |
| entry->encode(encodeCallback); |
| } |
| } |
| |
| std::unique_ptr<Item> Array::clone() const { |
| auto res = std::make_unique<Array>(); |
| for (size_t i = 0; i < mEntries.size(); i++) { |
| res->add(mEntries[i]->clone()); |
| } |
| return res; |
| } |
| |
| bool Map::operator==(const Map& other) const& { |
| return size() == other.size() |
| // Can't use vector::operator== because the contents are pairs of pointers. std::equal |
| // lets us provide a predicate that does the dereferencing. |
| && std::equal(begin(), end(), other.begin(), [](auto& a, auto& b) { |
| return *a.first == *b.first && *a.second == *b.second; |
| }); |
| } |
| |
| uint8_t* Map::encode(uint8_t* pos, const uint8_t* end) const { |
| pos = encodeHeader(size(), pos, end); |
| if (!pos) return nullptr; |
| for (auto& entry : mEntries) { |
| pos = entry.first->encode(pos, end); |
| if (!pos) return nullptr; |
| pos = entry.second->encode(pos, end); |
| if (!pos) return nullptr; |
| } |
| return pos; |
| } |
| |
| void Map::encode(EncodeCallback encodeCallback) const { |
| encodeHeader(size(), encodeCallback); |
| for (auto& entry : mEntries) { |
| entry.first->encode(encodeCallback); |
| entry.second->encode(encodeCallback); |
| } |
| } |
| |
| bool Map::keyLess(const Item* a, const Item* b) { |
| // CBOR map canonicalization rules are: |
| |
| // 1. If two keys have different lengths, the shorter one sorts earlier. |
| if (a->encodedSize() < b->encodedSize()) return true; |
| if (a->encodedSize() > b->encodedSize()) return false; |
| |
| // 2. If two keys have the same length, the one with the lower value in (byte-wise) lexical |
| // order sorts earlier. This requires encoding both items. |
| auto encodedA = a->encode(); |
| auto encodedB = b->encode(); |
| |
| return std::lexicographical_compare(encodedA.begin(), encodedA.end(), // |
| encodedB.begin(), encodedB.end()); |
| } |
| |
| void recursivelyCanonicalize(std::unique_ptr<Item>& item) { |
| switch (item->type()) { |
| case UINT: |
| case NINT: |
| case BSTR: |
| case TSTR: |
| case SIMPLE: |
| return; |
| |
| case ARRAY: |
| std::for_each(item->asArray()->begin(), item->asArray()->end(), |
| recursivelyCanonicalize); |
| return; |
| |
| case MAP: |
| item->asMap()->canonicalize(true /* recurse */); |
| return; |
| |
| case SEMANTIC: |
| // This can't happen. SemanticTags delegate their type() method to the contained Item's |
| // type. |
| assert(false); |
| return; |
| } |
| } |
| |
| Map& Map::canonicalize(bool recurse) & { |
| if (recurse) { |
| for (auto& entry : mEntries) { |
| recursivelyCanonicalize(entry.first); |
| recursivelyCanonicalize(entry.second); |
| } |
| } |
| |
| if (size() < 2 || mCanonicalized) { |
| // Trivially or already canonical; do nothing. |
| return *this; |
| } |
| |
| std::sort(begin(), end(), |
| [](auto& a, auto& b) { return keyLess(a.first.get(), b.first.get()); }); |
| mCanonicalized = true; |
| return *this; |
| } |
| |
| std::unique_ptr<Item> Map::clone() const { |
| auto res = std::make_unique<Map>(); |
| for (auto& [key, value] : *this) { |
| res->add(key->clone(), value->clone()); |
| } |
| res->mCanonicalized = mCanonicalized; |
| return res; |
| } |
| |
| std::unique_ptr<Item> SemanticTag::clone() const { |
| return std::make_unique<SemanticTag>(mValue, mTaggedItem->clone()); |
| } |
| |
| uint8_t* SemanticTag::encode(uint8_t* pos, const uint8_t* end) const { |
| // Can't use the encodeHeader() method that calls type() to get the major type, since that will |
| // return the tagged Item's type. |
| pos = ::cppbor::encodeHeader(kMajorType, mValue, pos, end); |
| if (!pos) return nullptr; |
| return mTaggedItem->encode(pos, end); |
| } |
| |
| void SemanticTag::encode(EncodeCallback encodeCallback) const { |
| // Can't use the encodeHeader() method that calls type() to get the major type, since that will |
| // return the tagged Item's type. |
| ::cppbor::encodeHeader(kMajorType, mValue, encodeCallback); |
| mTaggedItem->encode(encodeCallback); |
| } |
| |
| size_t SemanticTag::semanticTagCount() const { |
| size_t levelCount = 1; // Count this level. |
| const SemanticTag* cur = this; |
| while (cur->mTaggedItem && (cur = cur->mTaggedItem->asSemanticTag()) != nullptr) ++levelCount; |
| return levelCount; |
| } |
| |
| uint64_t SemanticTag::semanticTag(size_t nesting) const { |
| // Getting the value of a specific nested tag is a bit tricky, because we start with the outer |
| // tag and don't know how many are inside. We count the number of nesting levels to find out |
| // how many there are in total, then to get the one we want we have to walk down levelCount - |
| // nesting steps. |
| size_t levelCount = semanticTagCount(); |
| if (nesting >= levelCount) return 0; |
| |
| levelCount -= nesting; |
| const SemanticTag* cur = this; |
| while (--levelCount > 0) cur = cur->mTaggedItem->asSemanticTag(); |
| |
| return cur->mValue; |
| } |
| |
| string prettyPrint(const Item* item, size_t maxBStrSize, const vector<string>& mapKeysToNotPrint) { |
| string out; |
| prettyPrintInternal(item, out, 0, maxBStrSize, mapKeysToNotPrint); |
| return out; |
| } |
| string prettyPrint(const vector<uint8_t>& encodedCbor, size_t maxBStrSize, |
| const vector<string>& mapKeysToNotPrint) { |
| auto [item, _, message] = parse(encodedCbor); |
| if (item == nullptr) { |
| #ifndef __TRUSTY__ |
| LOG(ERROR) << "Data to pretty print is not valid CBOR: " << message; |
| #endif // __TRUSTY__ |
| return ""; |
| } |
| |
| return prettyPrint(item.get(), maxBStrSize, mapKeysToNotPrint); |
| } |
| |
| } // namespace cppbor |