Follow-up fixes to MusicRecognitionManagerService

Added a check to only allow the default MusicRecognitionService to be called, unless the caller is the same app that hosts the MusicRecognitionService, as is the case in tests.

Refactored MusicRecognitionManagerPerUserService to not hold the client's callback as a field to prevent leaks and needing to null it out (then requiring null checks).  Instead, it's always a local variable or a final field in ephemeral objects.

Bug: 174744852
Bug: 168696277
Test: atest CtsMusicRecognitionTestCases
Change-Id: I70ff3c38da7096c60f82fe75b47674b4b233e672
diff --git a/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerPerUserService.java b/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerPerUserService.java
index 3531512..0cb729d 100644
--- a/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerPerUserService.java
+++ b/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerPerUserService.java
@@ -16,6 +16,7 @@
 
 package com.android.server.musicrecognition;
 
+import static android.media.musicrecognition.MusicRecognitionManager.RECOGNITION_FAILED_AUDIO_UNAVAILABLE;
 import static android.media.musicrecognition.MusicRecognitionManager.RECOGNITION_FAILED_SERVICE_KILLED;
 import static android.media.musicrecognition.MusicRecognitionManager.RECOGNITION_FAILED_SERVICE_UNAVAILABLE;
 import static android.media.musicrecognition.MusicRecognitionManager.RecognitionFailureCode;
@@ -64,10 +65,6 @@
     @GuardedBy("mLock")
     private RemoteMusicRecognitionService mRemoteService;
 
-    private MusicRecognitionServiceCallback mRemoteServiceCallback =
-            new MusicRecognitionServiceCallback();
-    private IMusicRecognitionManagerCallback mCallback;
-
     MusicRecognitionManagerPerUserService(
             @NonNull MusicRecognitionManagerService primary,
             @NonNull Object lock, int userId) {
@@ -100,7 +97,8 @@
 
     @GuardedBy("mLock")
     @Nullable
-    private RemoteMusicRecognitionService ensureRemoteServiceLocked() {
+    private RemoteMusicRecognitionService ensureRemoteServiceLocked(
+            IMusicRecognitionManagerCallback clientCallback) {
         if (mRemoteService == null) {
             final String serviceName = getComponentNameLocked();
             if (serviceName == null) {
@@ -113,7 +111,8 @@
 
             mRemoteService = new RemoteMusicRecognitionService(getContext(),
                     serviceComponent, mUserId, this,
-                    mRemoteServiceCallback, mMaster.isBindInstantServiceAllowed(),
+                    new MusicRecognitionServiceCallback(clientCallback),
+                    mMaster.isBindInstantServiceAllowed(),
                     mMaster.verbose);
         }
 
@@ -130,13 +129,14 @@
             @NonNull IBinder callback) {
         int maxAudioLengthSeconds = Math.min(recognitionRequest.getMaxAudioLengthSeconds(),
                 MAX_STREAMING_SECONDS);
-        mCallback = IMusicRecognitionManagerCallback.Stub.asInterface(callback);
+        IMusicRecognitionManagerCallback clientCallback =
+                IMusicRecognitionManagerCallback.Stub.asInterface(callback);
         AudioRecord audioRecord = createAudioRecord(recognitionRequest, maxAudioLengthSeconds);
 
-        mRemoteService = ensureRemoteServiceLocked();
+        mRemoteService = ensureRemoteServiceLocked(clientCallback);
         if (mRemoteService == null) {
             try {
-                mCallback.onRecognitionFailed(
+                clientCallback.onRecognitionFailed(
                         RECOGNITION_FAILED_SERVICE_UNAVAILABLE);
             } catch (RemoteException e) {
                 // Ignored.
@@ -147,7 +147,8 @@
         Pair<ParcelFileDescriptor, ParcelFileDescriptor> clientPipe = createPipe();
         if (clientPipe == null) {
             try {
-                mCallback.onAudioStreamClosed();
+                clientCallback.onRecognitionFailed(
+                        RECOGNITION_FAILED_AUDIO_UNAVAILABLE);
             } catch (RemoteException ignored) {
                 // Ignored.
             }
@@ -192,11 +193,10 @@
             } finally {
                 audioRecord.release();
                 try {
-                    mCallback.onAudioStreamClosed();
+                    clientCallback.onAudioStreamClosed();
                 } catch (RemoteException ignored) {
                     // Ignored.
                 }
-
             }
         });
         // Send the pipe down to the lookup service while we write to it asynchronously.
@@ -207,13 +207,20 @@
      * Callback invoked by {@link android.service.musicrecognition.MusicRecognitionService} to pass
      * back the music search result.
      */
-    private final class MusicRecognitionServiceCallback extends
+    final class MusicRecognitionServiceCallback extends
             IMusicRecognitionServiceCallback.Stub {
+
+        private final IMusicRecognitionManagerCallback mClientCallback;
+
+        private MusicRecognitionServiceCallback(IMusicRecognitionManagerCallback clientCallback) {
+            mClientCallback = clientCallback;
+        }
+
         @Override
         public void onRecognitionSucceeded(MediaMetadata result, Bundle extras) {
             try {
                 sanitizeBundle(extras);
-                mCallback.onRecognitionSucceeded(result, extras);
+                mClientCallback.onRecognitionSucceeded(result, extras);
             } catch (RemoteException ignored) {
                 // Ignored.
             }
@@ -223,18 +230,23 @@
         @Override
         public void onRecognitionFailed(@RecognitionFailureCode int failureCode) {
             try {
-                mCallback.onRecognitionFailed(failureCode);
+                mClientCallback.onRecognitionFailed(failureCode);
             } catch (RemoteException ignored) {
                 // Ignored.
             }
             destroyService();
         }
+
+        private IMusicRecognitionManagerCallback getClientCallback() {
+            return mClientCallback;
+        }
     }
 
     @Override
     public void onServiceDied(@NonNull RemoteMusicRecognitionService service) {
         try {
-            mCallback.onRecognitionFailed(RECOGNITION_FAILED_SERVICE_KILLED);
+            service.getServerCallback().getClientCallback().onRecognitionFailed(
+                    RECOGNITION_FAILED_SERVICE_KILLED);
         } catch (RemoteException e) {
             // Ignored.
         }
diff --git a/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerService.java b/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerService.java
index 9123daf..38f43138 100644
--- a/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerService.java
+++ b/services/musicrecognition/java/com/android/server/musicrecognition/MusicRecognitionManagerService.java
@@ -22,7 +22,9 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.content.ComponentName;
 import android.content.Context;
+import android.content.pm.PackageManager;
 import android.media.musicrecognition.IMusicRecognitionManager;
 import android.media.musicrecognition.IMusicRecognitionManagerCallback;
 import android.media.musicrecognition.RecognitionRequest;
@@ -32,7 +34,9 @@
 import android.os.ResultReceiver;
 import android.os.ShellCallback;
 import android.os.UserHandle;
+import android.util.Slog;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.server.infra.AbstractMasterSystemService;
 import com.android.server.infra.FrameworkResourcesServiceNameResolver;
 
@@ -113,9 +117,11 @@
             enforceCaller("beginRecognition");
 
             synchronized (mLock) {
+                int userId = UserHandle.getCallingUserId();
                 final MusicRecognitionManagerPerUserService service = getServiceForUserLocked(
-                        UserHandle.getCallingUserId());
-                if (service != null) {
+                        userId);
+                if (service != null && (isDefaultServiceLocked(userId)
+                        || isCalledByServiceAppLocked("beginRecognition"))) {
                     service.beginRecognitionLocked(recognitionRequest, callback);
                 } else {
                     try {
@@ -139,5 +145,55 @@
                     MusicRecognitionManagerService.this).exec(this, in, out, err, args, callback,
                     resultReceiver);
         }
+
+        /** True if the currently set handler service is not overridden by the shell. */
+        @GuardedBy("mLock")
+        private boolean isDefaultServiceLocked(int userId) {
+            final String defaultServiceName = mServiceNameResolver.getDefaultServiceName(userId);
+            if (defaultServiceName == null) {
+                return false;
+            }
+
+            final String currentServiceName = mServiceNameResolver.getServiceName(userId);
+            return defaultServiceName.equals(currentServiceName);
+        }
+
+        /** True if the caller of the api is the same app which hosts the default service. */
+        @GuardedBy("mLock")
+        private boolean isCalledByServiceAppLocked(@NonNull String methodName) {
+            final int userId = UserHandle.getCallingUserId();
+            final int callingUid = Binder.getCallingUid();
+            final String serviceName = mServiceNameResolver.getServiceName(userId);
+            if (serviceName == null) {
+                Slog.e(TAG, methodName + ": called by UID " + callingUid
+                        + ", but there's no service set for user " + userId);
+                return false;
+            }
+
+            final ComponentName serviceComponent = ComponentName.unflattenFromString(serviceName);
+            if (serviceComponent == null) {
+                Slog.w(TAG, methodName + ": invalid service name: " + serviceName);
+                return false;
+            }
+
+            final String servicePackageName = serviceComponent.getPackageName();
+
+            final PackageManager pm = getContext().getPackageManager();
+            final int serviceUid;
+            try {
+                serviceUid = pm.getPackageUidAsUser(servicePackageName,
+                        UserHandle.getCallingUserId());
+            } catch (PackageManager.NameNotFoundException e) {
+                Slog.w(TAG, methodName + ": could not verify UID for " + serviceName);
+                return false;
+            }
+            if (callingUid != serviceUid) {
+                Slog.e(TAG, methodName + ": called by UID " + callingUid + ", but service UID is "
+                        + serviceUid);
+                return false;
+            }
+
+            return true;
+        }
     }
 }
diff --git a/services/musicrecognition/java/com/android/server/musicrecognition/RemoteMusicRecognitionService.java b/services/musicrecognition/java/com/android/server/musicrecognition/RemoteMusicRecognitionService.java
index 4814a82..6c7d673 100644
--- a/services/musicrecognition/java/com/android/server/musicrecognition/RemoteMusicRecognitionService.java
+++ b/services/musicrecognition/java/com/android/server/musicrecognition/RemoteMusicRecognitionService.java
@@ -21,13 +21,13 @@
 import android.content.Context;
 import android.media.AudioFormat;
 import android.media.musicrecognition.IMusicRecognitionService;
-import android.media.musicrecognition.IMusicRecognitionServiceCallback;
 import android.media.musicrecognition.MusicRecognitionService;
 import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
 import android.text.format.DateUtils;
 
 import com.android.internal.infra.AbstractMultiplePendingRequestsRemoteService;
+import com.android.server.musicrecognition.MusicRecognitionManagerPerUserService.MusicRecognitionServiceCallback;
 
 /** Remote connection to an instance of {@link MusicRecognitionService}. */
 public class RemoteMusicRecognitionService extends
@@ -39,11 +39,12 @@
     private static final long TIMEOUT_IDLE_BIND_MILLIS = 40 * DateUtils.SECOND_IN_MILLIS;
 
     // Allows the remote service to send back a result.
-    private final IMusicRecognitionServiceCallback mServerCallback;
+    private final MusicRecognitionServiceCallback
+            mServerCallback;
 
     public RemoteMusicRecognitionService(Context context, ComponentName serviceName,
             int userId, MusicRecognitionManagerPerUserService perUserService,
-            IMusicRecognitionServiceCallback callback,
+            MusicRecognitionServiceCallback callback,
             boolean bindInstantServiceAllowed, boolean verbose) {
         super(context, MusicRecognitionService.ACTION_MUSIC_SEARCH_LOOKUP, serviceName, userId,
                 perUserService,
@@ -66,6 +67,10 @@
         return TIMEOUT_IDLE_BIND_MILLIS;
     }
 
+    MusicRecognitionServiceCallback getServerCallback() {
+        return mServerCallback;
+    }
+
     /**
      * Required, but empty since we don't need to notify the callback implementation of the request
      * results.