Simplify ExecutionCallback

This CL simplifies ExecutionCallback by (1) merging the functionality of
CallbackBase into ExecutionCallback, (2) restricts permission to methods
that are only used internally in ExecutionCallback, and (3) simplifies
some of the documentation.

Bug: 118624080
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: I8e0cb6cec0e71350a742be0ec05fb95a288dd78d
Merged-In: I8e0cb6cec0e71350a742be0ec05fb95a288dd78d
(cherry picked from commit a55536bc9c00e8b4a85b721a001e19072e22a65d)
diff --git a/runtime/Callbacks.cpp b/runtime/Callbacks.cpp
index 0c58a46..660def7 100644
--- a/runtime/Callbacks.cpp
+++ b/runtime/Callbacks.cpp
@@ -15,80 +15,17 @@
  */
 
 #include "Callbacks.h"
+
 #include <android-base/logging.h>
 
-namespace android {
-namespace hardware {
-namespace neuralnetworks {
-namespace V1_2 {
-namespace implementation {
+#include <limits>
 
-CallbackBase::CallbackBase() : mNotified(false) {}
+namespace android::hardware::neuralnetworks::V1_2::implementation {
 
-CallbackBase::~CallbackBase() {
-    // Note that we cannot call CallbackBase::join_thread from here:
-    // CallbackBase is intended to be reference counted, and it is possible that
-    // the reference count drops to zero in the bound thread, causing the
-    // bound thread to call this destructor. If a thread tries to join
-    // itself, it throws an exception, producing a message like the
-    // following:
-    //
-    //     terminating with uncaught exception of type std::__1::system_error:
-    //     thread::join failed: Resource deadlock would occur
-}
+constexpr Timing kNoTiming = {.timeOnDevice = std::numeric_limits<uint64_t>::max(),
+                              .timeInDriver = std::numeric_limits<uint64_t>::max()};
 
-void CallbackBase::wait() {
-    std::unique_lock<std::mutex> lock(mMutex);
-    mCondition.wait(lock, [this]{return mNotified;});
-    join_thread_locked();
-}
-
-bool CallbackBase::on_finish(std::function<void()> post_work) {
-    std::lock_guard<std::mutex> lock(mMutex);
-    if (mPostWork != nullptr) {
-        LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
-                   "this callback object";
-        return false;
-    }
-    if (post_work == nullptr) {
-        LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
-        return false;
-    }
-    mPostWork = std::move(post_work);
-    return true;
-}
-
-bool CallbackBase::bind_thread(std::thread&& asyncThread) {
-    std::lock_guard<std::mutex> lock(mMutex);
-    if (mThread.joinable()) {
-        LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
-                   "callback object";
-        return false;
-    }
-    if (!asyncThread.joinable()) {
-        LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
-        return false;
-    }
-    mThread = std::move(asyncThread);
-    return true;
-}
-
-void CallbackBase::notify() {
-    {
-        std::lock_guard<std::mutex> lock(mMutex);
-        mNotified = true;
-        if (mPostWork) {
-            mPostWork();
-        }
-    }
-    mCondition.notify_all();
-}
-
-void CallbackBase::join_thread_locked() {
-    if (mThread.joinable()) {
-        mThread.join();
-    }
-}
+// PreparedModelCallback methods begin here
 
 Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
                                            const sp<V1_0::IPreparedModel>& preparedModel) {
@@ -130,73 +67,141 @@
     return mPreparedModel;
 }
 
-ExecutionCallback::ExecutionCallback()
-    : mErrorStatus(ErrorStatus::GENERAL_FAILURE), mOnFinish(nullptr) {
-    on_finish([this] {
-        if (mOnFinish != nullptr) {
-            ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
-            if (status != ErrorStatus::NONE) {
-                mErrorStatus = status;
-            }
-        }
-    });
-}
-
-ExecutionCallback::~ExecutionCallback() {}
+// ExecutionCallback methods begin here
 
 Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
-    mErrorStatus = errorStatus;
-    mOutputShapes = {};
-    mTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
-    CallbackBase::notify();
+    notifyInternal(errorStatus, {}, kNoTiming);
     return Void();
 }
 
 Return<void> ExecutionCallback::notify_1_2(ErrorStatus errorStatus,
                                            const hidl_vec<OutputShape>& outputShapes,
                                            const Timing& timing) {
-    mErrorStatus = errorStatus;
-    mOutputShapes = outputShapes;
-    mTiming = timing;
-    if (mErrorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
-        // mOutputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
-        if (mOutputShapes.size() == 0) {
+    if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
+        // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
+        if (outputShapes.size() == 0) {
             LOG(ERROR) << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
-            mErrorStatus = ErrorStatus::GENERAL_FAILURE;
-            mOutputShapes = {};
-            mTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
+            notifyInternal(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+            return Void();
         }
-    } else if (mErrorStatus != ErrorStatus::NONE) {
-        // mOutputShapes must be empty if mErrorStatus is neither NONE nor OUTPUT_INSUFFICIENT_SIZE.
-        if (mOutputShapes.size() != 0) {
+    } else if (errorStatus != ErrorStatus::NONE) {
+        // outputShapes must be empty if errorStatus is neither NONE nor OUTPUT_INSUFFICIENT_SIZE.
+        if (outputShapes.size() != 0) {
             LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
                           "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
-            mErrorStatus = ErrorStatus::GENERAL_FAILURE;
-            mOutputShapes = {};
-            mTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
+            notifyInternal(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+            return Void();
         }
     }
-    CallbackBase::notify();
+    notifyInternal(errorStatus, outputShapes, timing);
     return Void();
 }
 
-ErrorStatus ExecutionCallback::getStatus() {
+void ExecutionCallback::wait() const {
+    std::unique_lock<std::mutex> lock(mMutex);
+    mCondition.wait(lock, [this] { return mNotified; });
+
+    /*
+     * Note that we cannot call std::thread::join from ExecutionCallback's
+     * destructor: ExecutionCallback is intended to be reference counted, and it
+     * is possible that the reference count drops to zero in the bound thread,
+     * causing the bound thread to call this destructor. If a thread tries to
+     * join itself, it throws an exception, producing a message like the
+     * following:
+     *
+     *     terminating with uncaught exception of type std::__1::system_error:
+     *     thread::join failed: Resource deadlock would occur
+     */
+    if (mThread.joinable()) {
+        mThread.join();
+    }
+}
+
+ErrorStatus ExecutionCallback::getStatus() const {
     wait();
     return mErrorStatus;
 }
 
-const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() {
+const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
     wait();
     return mOutputShapes;
 }
 
-Timing ExecutionCallback::getTiming() {
+Timing ExecutionCallback::getTiming() const {
     wait();
     return mTiming;
 }
 
-}  // namespace implementation
-}  // namespace V1_2
-}  // namespace neuralnetworks
-}  // namespace hardware
-}  // namespace android
+bool ExecutionCallback::bindThread(std::thread asyncThread) {
+    std::lock_guard<std::mutex> lock(mMutex);
+
+    // Ensure ExecutionCallback object does not already have a thread bound
+    if (mThread.joinable()) {
+        LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
+                      "callback object";
+        return false;
+    }
+
+    // Ensure the new thread is valid
+    if (!asyncThread.joinable()) {
+        LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
+        return false;
+    }
+
+    mThread = std::move(asyncThread);
+    return true;
+}
+
+void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
+    std::lock_guard<std::mutex> hold(mMutex);
+
+    // Ensure ExecutionCallback object does not already have a "finish" callback
+    if (mOnFinish != nullptr) {
+        LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
+        return;
+    }
+
+    // Ensure new "finish" callback is valid
+    if (finish == nullptr) {
+        LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
+        return;
+    }
+
+    // Essure ExecutionCallback object has not already been notified
+    if (mNotified) {
+        LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
+                      "notified with results";
+        return;
+    }
+
+    mOnFinish = finish;
+}
+
+void ExecutionCallback::notifyInternal(ErrorStatus errorStatus,
+                                       const hidl_vec<OutputShape>& outputShapes,
+                                       const Timing& timing) {
+    {
+        std::lock_guard<std::mutex> hold(mMutex);
+
+        // quick-return if object has already been notified
+        if (mNotified) {
+            return;
+        }
+
+        mErrorStatus = errorStatus;
+        mOutputShapes = outputShapes;
+        mTiming = timing;
+        mNotified = true;
+
+        if (mOnFinish != nullptr) {
+            ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
+            mOnFinish = nullptr;
+            if (status != ErrorStatus::NONE) {
+                mErrorStatus = status;
+            }
+        }
+    }
+    mCondition.notify_all();
+}
+
+}  // namespace android::hardware::neuralnetworks::V1_2::implementation
diff --git a/runtime/Callbacks.h b/runtime/Callbacks.h
index 96c21a2..227e9ea 100644
--- a/runtime/Callbacks.h
+++ b/runtime/Callbacks.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
-#define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
+#ifndef ANDROID_ML_NN_RUNTIME_CALLBACKS_H
+#define ANDROID_ML_NN_RUNTIME_CALLBACKS_H
 
 #include <android-base/thread_annotations.h>
 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
@@ -29,121 +29,27 @@
 #include <mutex>
 #include <thread>
 
-namespace android {
-namespace hardware {
-namespace neuralnetworks {
-namespace V1_2 {
-namespace implementation {
-
-using V1_0::ErrorStatus;
-
-/**
- * The CallbackBase class is used internally by the NeuralNetworks runtime to
+/*
+ * The Callback classes are used internally by the NeuralNetworks runtime to
  * synchronize between different threads. An asynchronous task is launched
  * paired with a callback object. When a client thread requires the output being
  * generated by the asynchronous task, the client thread can wait for the result
  * and be blocked until it has completed. Any wait may safely be called
  * concurrently, even on the same callback object. When the asynchronous task
- * has finished its workload, it must immediately call "notify". If the
+ * has finished its workload, it must immediately call "notify*". If the
  * asynchronous task has failed to launch, the function that tried to launch the
- * asynchronous task must immediately call "notify". This "notify" call awakens
- * any client threads waiting on the callback object.
+ * asynchronous task must immediately call "notify*". This "notify*" call
+ * awakens any client threads waiting on the callback object.
  *
- * The CallbackBase class implements some of the base synchronization common to
- * both PrepareModelCallback and ExecutionCallback. For consistency, any HIDL
- * callback class must inherit from CallbackBase as well as the HIDL callback
- * interface it implements.
- *
- * This class exists to enable synchronization across HIDL. When synchronization
- * is only required in the same process, consider using std::future, std::mutex,
- * std::condition_variable, or std::experimental::latch instead.
+ * These classes exist to enable synchronization across HIDL. When
+ * synchronization is only required in the same process, consider using
+ * std::future, std::mutex, std::condition_variable, or std::experimental::latch
+ * instead.
  */
-class CallbackBase {
-   public:
-    CallbackBase();
-    ~CallbackBase();
 
-    /**
-     * CallbackBase::wait blocks until notify has been called on the callback
-     * object.
-     */
-    void wait();
+namespace android::hardware::neuralnetworks::V1_2::implementation {
 
-    /**
-     * CallbackBase::bind_thread binds a thread to the callback object for later
-     * use by CallbackBase::join_thread_locked.
-     *
-     * The thread must be passed using std::move.
-     *
-     * Once a thread is bound with CallbackBase::bind_thread, the client code
-     * should ensure that CallbackBase::wait has been called before the callback
-     * object is destroyed.
-     *
-     * The bound thread shall not call any CallbackBase method with the
-     * exception of CallbackBase::notify, which it must call when the thread has
-     * finished its computation.
-     *
-     * CallbackBase::bind_thread can be called at most once on a given callback
-     * object.
-     *
-     * @param asyncThread Thread to be bound to the callback object. The thread
-     *     object must represent a thread of execution -- i.e.,
-     *     asyncThread.joinable() must be true.
-     * @return bool True if successful, false if thread was not properly bound.
-     */
-    bool bind_thread(std::thread&& asyncThread);
-
-   protected:
-    /**
-     * CallbackBase::notify enables all prior and future wait calls on the
-     * callback object to proceed. The call to CallbackBase::notify happens
-     * before any wait calls on this callback object return. The asynchronous
-     * call the callback object is paired with must ensure that any update to
-     * state that should be visible to the caller of wait happens before the
-     * call to CallbackBase::notify.
-     *
-     * CallbackBase::notify must be called exactly once on a given callback
-     * object.
-     */
-    void notify();
-
-    /**
-     * CallbackBase::on_finish binds a function to the callback object. This
-     * bound function will be executed when CallbackBase::notify is called,
-     * before any calls to wait return.
-     *
-     * The bound function must not synchronize with or otherwise access the
-     * callback object it is bound to, as this could cause a deadlock.
-     *
-     * CallbackBase::on_finish can be called at most once on a given callback
-     * object, and the call to CallbackBase::on_finish must finish before
-     * CallbackBase::notify is called.
-     *
-     * @param post_work Function to be invoked the first time
-     *     CallbackBase::notify is called. Must have a target -- i.e., must not
-     *     compare equal to nullptr.
-     * @return bool True if the function was successfully bound, false if
-     *     unsuccessful.
-     */
-    bool on_finish(std::function<void()> post_work);
-
-   private:
-    /**
-     * CallbackBase::join_thread_locked ensures that the thread (if any) bound
-     * to this callback object with CallbackBase::bind_thread has fully finished
-     * and cleaned its resources.
-     *
-     * CallbackBase::join_thread_locked can be called multiple times. When
-     * called, it must be called while the object's mutex is locked.
-     */
-    void join_thread_locked();
-
-    bool mNotified;
-    std::mutex mMutex;
-    std::condition_variable mCondition;
-    std::function<void()> mPostWork;
-    std::thread mThread;
-};
+using V1_0::ErrorStatus;
 
 /**
  * The PreparedModelCallback class is used to receive the error status of
@@ -154,7 +60,7 @@
  * until the asynchronous task has either called notify or notify_1_2.
  *
  * If the callback object is notified more than once, only the results of the
- * first call to notify are used, and the results from subsequent calls are
+ * first call to notify* are used, and the results from subsequent calls are
  * discarded.
  *
  * This callback object is passed as an argument to IDevice::prepareModel*.
@@ -172,7 +78,7 @@
      * PreparedModelCallback object.
      *
      * If the callback object is notified more than once, only the results of
-     * the first call to notify is used, and the results from subsequent calls
+     * the first call to notify* are used, and the results from subsequent calls
      * are discarded.
      *
      * @param status Error status returned from asynchronously preparing the
@@ -197,7 +103,7 @@
      * PreparedModelCallback object.
      *
      * If the callback object is notified more than once, only the results of
-     * the first call to notify are used, and the results from subsequent calls
+     * the first call to notify* are used, and the results from subsequent calls
      * are discarded.
      *
      * @param status Error status returned from asynchronously preparing the
@@ -254,36 +160,36 @@
 };
 
 /**
- * The ExecutionCallback class is used to receive the error status of the
- * execution from a task executing asynchronously with respect to the runtime.
- * If a calling thread calls wait or get* on a PreparedModelCallback object and
- * the corresponding asynchronous task has not finished the execution, the
- * calling thread will block until the asynchronous task has either called notify
- * or notify_1_2. For more information on the synchronization behavior, refer to
- * the CallbackBase class.
+ * The ExecutionCallback class is used to receive the results of the execution
+ * from a task executing asynchronously with respect to the runtime. If a
+ * calling thread calls wait or get* on a ExecutionCallback object and the
+ * corresponding asynchronous task has not finished the execution, the calling
+ * thread will block until the asynchronous task has either called notify or
+ * notify_1_2.
  *
- * This class inherits the basic blocking and signaling calls from
- * CallbackBase, and implements the HIDL notify and notify_1_2 calls from
- * IExecutionCallback. This callback object is passed as an argument to
- * IPreparedModel::execute.
+ * If the callback object is notified more than once, only the results of the
+ * first call to notify* are used, and the results from subsequent calls are
+ * discarded.
+ *
+ * This callback object is passed as an argument to IPreparedModel::execute*.
  */
-class ExecutionCallback : public CallbackBase, public IExecutionCallback {
+class ExecutionCallback : public IExecutionCallback {
     using ExecutionFinish =
             std::function<ErrorStatus(ErrorStatus, const std::vector<OutputShape>&)>;
 
    public:
-    ExecutionCallback();
-    ~ExecutionCallback() override;
-
     /**
-     * IExecutionCallback::notify and IExecutionCallback::notify_1_2 mark the
-     * callback object with the return status of the asynchronous execution that
-     * held this callback and enable all prior and future wait calls on the
-     * ExecutionCallback object to proceed. For more information on the
-     * synchronization behavior, refer to the CallbackBase class.
+     * IExecutionCallback::notify marks the callback object with the return
+     * status of the asynchronous execution that held this callback and enables
+     * all prior and future wait calls on the ExecutionCallback object to
+     * proceed.
      *
      * Either IExecutionCallback::notify or IExecutionCallback::notify_1_2 must
-     * be called exactly once on a given ExecutionCallback object.
+     * be called on a given ExecutionCallback object.
+     *
+     * If the callback object is notified more than once, only the results of
+     * the first call to notify* are used, and the results from subsequent calls
+     * are discarded.
      *
      * @param status Error status returned from launching the asynchronous task
      *     (if the launch fails) or from the asynchronous task itself (if the
@@ -298,8 +204,17 @@
     Return<void> notify(ErrorStatus status) override;
 
     /**
-     * Similar to IExecutionCallback::notify, but for V1_2::IPreparedModel to
-     * also notify output shapes along with error status.
+     * IExecutionCallback::notify_1_2 marks the callback object with the results
+     * (error status, dynamic output shapes, and timing information) of the
+     * asynchronous execution that held this callback and enables all prior and
+     * future wait calls on the ExecutionCallback object to proceed.
+     *
+     * Either IExecutionCallback::notify or IExecutionCallback::notify_1_2 must
+     * be called on a given ExecutionCallback object.
+     *
+     * If the callback object is notified more than once, only the results of
+     * the first call to notify* are used, and the results from subsequent calls
+     * are discarded.
      *
      * @param status Error status returned from launching the asynchronous task
      *     (if the launch fails) or from the asynchronous task itself (if the
@@ -316,11 +231,10 @@
      *     The index into "outputShapes" corresponds to the index of the output
      *     operand in the Request outputs vector. outputShapes must be empty
      *     unless the status is either NONE or OUTPUT_INSUFFICIENT_SIZE.
-     * @return Timing Duration of execution. Unless MeasureTiming::YES was
-     *     passed when launching the execution and status is NONE, all times
-     *     must be reported as UINT64_MAX. A driver may choose to report any
-     *     time as UINT64_MAX, indicating that particular measurement is not
-     *     available.
+     * @param Timing Duration of execution. Unless MeasureTiming::YES was passed
+     *     when launching the execution and status is NONE, all times must be
+     *     reported as UINT64_MAX. A driver may choose to report any time as
+     *     UINT64_MAX, indicating that particular measurement is not available.
      */
     Return<void> notify_1_2(ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
                             const Timing& timing) override;
@@ -332,6 +246,12 @@
     }
 
     /**
+     * ExecutionCallback::wait blocks until notify* has been called on the
+     * callback object.
+     */
+    void wait() const;
+
+    /**
      * Retrieves the error status returned from the asynchronous task launched
      * by either IPreparedModel::execute or IPreparedModel::execute_1_2. If
      * IPreparedModel::execute or IPreparedModel::execute_1_2 has not finished
@@ -350,7 +270,7 @@
      *     - INVALID_ARGUMENT if one of the input arguments to prepareModel is
      *         invalid
      */
-    ErrorStatus getStatus();
+    ErrorStatus getStatus() const;
 
     /**
      * Retrieves the output shapes returned from the asynchronous task launched
@@ -372,7 +292,7 @@
      *     OUTPUT_INSUFFICIENT_SIZE, or if the status is NONE and the model has
      *     at least one output operand that is not fully-specified.
      */
-    const std::vector<OutputShape>& getOutputShapes();
+    const std::vector<OutputShape>& getOutputShapes() const;
 
     /**
      * Retrieves the duration of execution of the asynchronous task launched by
@@ -386,23 +306,75 @@
      * @return timing Duration of the execution. Every time must be UINT64_MAX
      *     unless the status is NONE.
      */
-    Timing getTiming();
+    Timing getTiming() const;
 
-    // The callback will invoke finish(mErrorStatus) on finish.
-    void setOnFinish(const ExecutionFinish& finish) { mOnFinish = finish; }
+    /**
+     * ExecutionCallback::bindThread binds a thread to the ExecutionCallback
+     * object. The bound thread is later joined by ExecutionCallback::wait or
+     * ExecutionCallback::get*.
+     *
+     * Once a thread is bound with ExecutionCallback::bindThread, the client
+     * code must ensure that ExecutionCallback::wait or ExecutionCallback::get*
+     * has been called before the ExecutionCallback object is destroyed.
+     *
+     * The bound thread must not call any ExecutionCallback method with the
+     * exception of ExecutionCallback::notify*, which it must call when the
+     * thread has finished its computation.
+     *
+     * ExecutionCallback::bindThread can be called at most once on a given
+     * callback object.
+     *
+     * @param asyncThread Thread to be bound to the callback object. The thread
+     *     object must represent a thread of execution -- i.e.,
+     *     std::thread::joinable() must be true.
+     * @return bool True if successful, false if thread was not properly bound.
+     */
+    bool bindThread(std::thread asyncThread);
+
+    /**
+     * ExecutionCallback::setOnFinish binds a callback to the ExecutionCallback
+     * object that will be executed during one of the ExecutionCallback::notify*
+     * calls but before any calls to wait or get* return. This provided callback
+     * is provided with both the ErrorStatus and the output shapes from
+     * ExecutionCallback::notify*.
+     *
+     * The bound function must not synchronize with or otherwise access the
+     * callback object it is bound to, as this could cause a deadlock.
+     *
+     * This call will not bind the provided callback if any of the following
+     * occur:
+     * (1) the provided callback is invalid (i.e., "(bool) finish" is false)
+     * (2) ExecutionCallback already contains a bound callback
+     * (3) ExecutionCallback has already been notified with results
+     *
+     * @param finish Callback to be executed when ExecutionCallback is notified
+     *     with results.
+     */
+    void setOnFinish(const ExecutionFinish& finish);
 
    private:
+    /*
+     * ExecutionCallback::notifyInternal stores the results of the execution
+     * (status, output shapes, and timing information) in the ExecutionCallback
+     * object and invokes the bound callback function "mOnFinish" (if present)
+     * before any call to wait or get* return. It then enables all prior and
+     * future wait calls on the ExecutionCallback object to proceed.
+     */
+    void notifyInternal(ErrorStatus errorStatus, const hidl_vec<OutputShape>& outputShapes,
+                        const Timing& timing);
+
+    // members
+    mutable std::mutex mMutex;
+    mutable std::condition_variable mCondition;
+    mutable std::thread mThread GUARDED_BY(mMutex);
+    ExecutionFinish mOnFinish GUARDED_BY(mMutex);
+    bool mNotified GUARDED_BY(mMutex) = false;
     ErrorStatus mErrorStatus = ErrorStatus::GENERAL_FAILURE;
     std::vector<OutputShape> mOutputShapes = {};
     Timing mTiming = {};
-    ExecutionFinish mOnFinish;
 };
 
-}  // namespace implementation
-}  // namespace V1_2
-}  // namespace neuralnetworks
-}  // namespace hardware
-}  // namespace android
+}  // namespace android::hardware::neuralnetworks::V1_2::implementation
 
 namespace android::nn {
 
@@ -411,4 +383,4 @@
 
 }  // namespace android::nn
 
-#endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
+#endif  // ANDROID_ML_NN_RUNTIME_CALLBACKS_H
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index e3e0c9f..8866a5a 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -595,7 +595,7 @@
             VLOG(EXECUTION) << "ExecutionBuilder::compute (asynchronous API)";
             std::thread thread(asyncStartComputePartitioned, this, mPlan, controller, allowFallback,
                                executionCallback);
-            executionCallback->bind_thread(std::move(thread));
+            executionCallback->bindThread(std::move(thread));
         }
         *synchronizationCallback = executionCallback;
         return ANEURALNETWORKS_NO_ERROR;
@@ -1043,7 +1043,7 @@
         // TODO: should model be moved with a std::cref?
         std::thread thread(computeOnCpu, model, std::move(request), std::move(modelPoolInfos),
                            std::move(requestPoolInfos), executionCallback);
-        executionCallback->bind_thread(std::move(thread));
+        executionCallback->bindThread(std::move(thread));
     }
 
     *synchronizationCallback = executionCallback;