| // This file defines OptionalArrayRef<T>, a class that has almost the same |
| // exact functionality as c10::optional<ArrayRef<T>>, except that its |
| // converting constructor fixes a dangling pointer issue. |
| // |
| // The implicit converting constructor of both c10::optional<ArrayRef<T>> and |
| // std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store |
| // a dangling pointer. OptionalArrayRef<T> prevents this by wrapping |
| // a c10::optional<ArrayRef<T>> and fixing the constructor implementation. |
| // |
| // See https://github.com/pytorch/pytorch/issues/63645 for more on this. |
| |
| #pragma once |
| |
| #include <c10/util/ArrayRef.h> |
| #include <c10/util/Optional.h> |
| #include <cstdint> |
| #include <initializer_list> |
| #include <type_traits> |
| #include <utility> |
| |
| namespace c10 { |
| |
| template <typename T> |
| class OptionalArrayRef final { |
| public: |
| // Constructors |
| |
| constexpr OptionalArrayRef() noexcept = default; |
| |
| constexpr OptionalArrayRef(nullopt_t) noexcept {} |
| |
| OptionalArrayRef(const OptionalArrayRef& other) = default; |
| |
| OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; |
| |
| constexpr OptionalArrayRef(const optional<ArrayRef<T>>& other) noexcept |
| : wrapped_opt_array_ref(other) {} |
| |
| constexpr OptionalArrayRef(optional<ArrayRef<T>>&& other) noexcept |
| : wrapped_opt_array_ref(std::move(other)) {} |
| |
| constexpr OptionalArrayRef(const T& value) noexcept |
| : wrapped_opt_array_ref(value) {} |
| |
| template < |
| typename U = ArrayRef<T>, |
| std::enable_if_t< |
| !std::is_same_v<std::decay_t<U>, OptionalArrayRef> && |
| !std::is_same_v<std::decay_t<U>, std::in_place_t> && |
| std::is_constructible_v<ArrayRef<T>, U&&> && |
| std::is_convertible_v<U&&, ArrayRef<T>> && |
| !std::is_convertible_v<U&&, T>, |
| bool> = false> |
| constexpr OptionalArrayRef(U&& value) noexcept( |
| std::is_nothrow_constructible_v<ArrayRef<T>, U&&>) |
| : wrapped_opt_array_ref(std::forward<U>(value)) {} |
| |
| template < |
| typename U = ArrayRef<T>, |
| std::enable_if_t< |
| !std::is_same_v<std::decay_t<U>, OptionalArrayRef> && |
| !std::is_same_v<std::decay_t<U>, std::in_place_t> && |
| std::is_constructible_v<ArrayRef<T>, U&&> && |
| !std::is_convertible_v<U&&, ArrayRef<T>>, |
| bool> = false> |
| constexpr explicit OptionalArrayRef(U&& value) noexcept( |
| std::is_nothrow_constructible_v<ArrayRef<T>, U&&>) |
| : wrapped_opt_array_ref(std::forward<U>(value)) {} |
| |
| template <typename... Args> |
| constexpr explicit OptionalArrayRef( |
| std::in_place_t ip, |
| Args&&... args) noexcept |
| : wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {} |
| |
| template <typename U, typename... Args> |
| constexpr explicit OptionalArrayRef( |
| std::in_place_t ip, |
| std::initializer_list<U> il, |
| Args&&... args) |
| : wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {} |
| |
| constexpr OptionalArrayRef(const std::initializer_list<T>& Vec) |
| : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {} |
| |
| // Destructor |
| |
| ~OptionalArrayRef() = default; |
| |
| // Assignment |
| |
| constexpr OptionalArrayRef& operator=(nullopt_t) noexcept { |
| wrapped_opt_array_ref = c10::nullopt; |
| return *this; |
| } |
| |
| OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; |
| |
| OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; |
| |
| constexpr OptionalArrayRef& operator=( |
| const optional<ArrayRef<T>>& other) noexcept { |
| wrapped_opt_array_ref = other; |
| return *this; |
| } |
| |
| constexpr OptionalArrayRef& operator=( |
| optional<ArrayRef<T>>&& other) noexcept { |
| wrapped_opt_array_ref = std::move(other); |
| return *this; |
| } |
| |
| template < |
| typename U = ArrayRef<T>, |
| typename = std::enable_if_t< |
| !std::is_same_v<std::decay_t<U>, OptionalArrayRef> && |
| std::is_constructible_v<ArrayRef<T>, U&&> && |
| std::is_assignable_v<ArrayRef<T>&, U&&>>> |
| constexpr OptionalArrayRef& operator=(U&& value) noexcept( |
| std::is_nothrow_constructible_v<ArrayRef<T>, U&&> && |
| std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) { |
| wrapped_opt_array_ref = std::forward<U>(value); |
| return *this; |
| } |
| |
| // Observers |
| |
| constexpr ArrayRef<T>* operator->() noexcept { |
| return &wrapped_opt_array_ref.value(); |
| } |
| |
| constexpr const ArrayRef<T>* operator->() const noexcept { |
| return &wrapped_opt_array_ref.value(); |
| } |
| |
| constexpr ArrayRef<T>& operator*() & noexcept { |
| return wrapped_opt_array_ref.value(); |
| } |
| |
| constexpr const ArrayRef<T>& operator*() const& noexcept { |
| return wrapped_opt_array_ref.value(); |
| } |
| |
| constexpr ArrayRef<T>&& operator*() && noexcept { |
| return std::move(wrapped_opt_array_ref.value()); |
| } |
| |
| constexpr const ArrayRef<T>&& operator*() const&& noexcept { |
| return std::move(wrapped_opt_array_ref.value()); |
| } |
| |
| constexpr explicit operator bool() const noexcept { |
| return wrapped_opt_array_ref.has_value(); |
| } |
| |
| constexpr bool has_value() const noexcept { |
| return wrapped_opt_array_ref.has_value(); |
| } |
| |
| constexpr ArrayRef<T>& value() & { |
| return wrapped_opt_array_ref.value(); |
| } |
| |
| constexpr const ArrayRef<T>& value() const& { |
| return wrapped_opt_array_ref.value(); |
| } |
| |
| constexpr ArrayRef<T>&& value() && { |
| return std::move(wrapped_opt_array_ref.value()); |
| } |
| |
| constexpr const ArrayRef<T>&& value() const&& { |
| return std::move(wrapped_opt_array_ref.value()); |
| } |
| |
| template <typename U> |
| constexpr std:: |
| enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>> |
| value_or(U&& default_value) const& { |
| return wrapped_opt_array_ref.value_or(std::forward<U>(default_value)); |
| } |
| |
| template <typename U> |
| constexpr std:: |
| enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>> |
| value_or(U&& default_value) && { |
| return wrapped_opt_array_ref.value_or(std::forward<U>(default_value)); |
| } |
| |
| // Modifiers |
| |
| constexpr void swap(OptionalArrayRef& other) noexcept { |
| std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); |
| } |
| |
| constexpr void reset() noexcept { |
| wrapped_opt_array_ref.reset(); |
| } |
| |
| template <typename... Args> |
| constexpr std:: |
| enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&> |
| emplace(Args&&... args) noexcept( |
| std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) { |
| return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...); |
| } |
| |
| template <typename U, typename... Args> |
| constexpr ArrayRef<T>& emplace( |
| std::initializer_list<U> il, |
| Args&&... args) noexcept { |
| return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...); |
| } |
| |
| private: |
| optional<ArrayRef<T>> wrapped_opt_array_ref; |
| }; |
| |
| using OptionalIntArrayRef = OptionalArrayRef<int64_t>; |
| |
| inline bool operator==( |
| const OptionalIntArrayRef& a1, |
| const IntArrayRef& other) { |
| if (!a1.has_value()) { |
| return false; |
| } |
| return a1.value() == other; |
| } |
| |
| inline bool operator==( |
| const c10::IntArrayRef& a1, |
| const c10::OptionalIntArrayRef& a2) { |
| return a2 == a1; |
| } |
| |
| } // namespace c10 |