Merge "Rate limit calls to setWidgetPreview/removeWidgetPreview" into main
diff --git a/core/api/current.txt b/core/api/current.txt
index cd9c3ad..b9ad92b 100644
--- a/core/api/current.txt
+++ b/core/api/current.txt
@@ -9608,7 +9608,7 @@
     method public void partiallyUpdateAppWidget(int, android.widget.RemoteViews);
     method @FlaggedApi("android.appwidget.flags.generated_previews") public void removeWidgetPreview(@NonNull android.content.ComponentName, int);
     method public boolean requestPinAppWidget(@NonNull android.content.ComponentName, @Nullable android.os.Bundle, @Nullable android.app.PendingIntent);
-    method @FlaggedApi("android.appwidget.flags.generated_previews") public void setWidgetPreview(@NonNull android.content.ComponentName, int, @NonNull android.widget.RemoteViews);
+    method @FlaggedApi("android.appwidget.flags.generated_previews") public boolean setWidgetPreview(@NonNull android.content.ComponentName, int, @NonNull android.widget.RemoteViews);
     method public void updateAppWidget(int[], android.widget.RemoteViews);
     method public void updateAppWidget(int, android.widget.RemoteViews);
     method public void updateAppWidget(android.content.ComponentName, android.widget.RemoteViews);
diff --git a/core/java/android/appwidget/AppWidgetManager.java b/core/java/android/appwidget/AppWidgetManager.java
index eb82e1f..cda4d89 100644
--- a/core/java/android/appwidget/AppWidgetManager.java
+++ b/core/java/android/appwidget/AppWidgetManager.java
@@ -1417,13 +1417,15 @@
      * @see AppWidgetProviderInfo#WIDGET_CATEGORY_HOME_SCREEN
      * @see AppWidgetProviderInfo#WIDGET_CATEGORY_KEYGUARD
      * @see AppWidgetProviderInfo#WIDGET_CATEGORY_SEARCHBOX
+     *
+     * @return true if the call was successful, false if it was rate-limited.
      */
     @FlaggedApi(Flags.FLAG_GENERATED_PREVIEWS)
-    public void setWidgetPreview(@NonNull ComponentName provider,
+    public boolean setWidgetPreview(@NonNull ComponentName provider,
             @AppWidgetProviderInfo.CategoryFlags int widgetCategories,
             @NonNull RemoteViews preview) {
         try {
-            mService.setWidgetPreview(provider, widgetCategories, preview);
+            return mService.setWidgetPreview(provider, widgetCategories, preview);
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
diff --git a/core/java/com/android/internal/appwidget/IAppWidgetService.aidl b/core/java/com/android/internal/appwidget/IAppWidgetService.aidl
index 85bdbb9..e55cdef 100644
--- a/core/java/com/android/internal/appwidget/IAppWidgetService.aidl
+++ b/core/java/com/android/internal/appwidget/IAppWidgetService.aidl
@@ -80,11 +80,10 @@
             in Bundle extras, in IntentSender resultIntent);
     boolean isRequestPinAppWidgetSupported();
     oneway void noteAppWidgetTapped(in String callingPackage, in int appWidgetId);
-    void setWidgetPreview(in ComponentName providerComponent, in int widgetCategories,
+    boolean setWidgetPreview(in ComponentName providerComponent, in int widgetCategories,
             in RemoteViews preview);
     @nullable RemoteViews getWidgetPreview(in String callingPackage,
             in ComponentName providerComponent, in int profileId, in int widgetCategory);
     void removeWidgetPreview(in ComponentName providerComponent, in int widgetCategories);
-
 }
 
diff --git a/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java b/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java
index bd806bf..91678c7 100644
--- a/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java
+++ b/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java
@@ -560,6 +560,19 @@
      */
     public static final String CURSOR_HOVER_STATES_ENABLED = "cursor_hover_states_enabled";
 
+
+    /**
+     * (long) The reset interval, in milliseconds, for generated preview API rate limiting.
+     */
+    public static final String GENERATED_PREVIEW_API_RESET_INTERVAL_MS =
+            "generated_preview_api_reset_interval_ms";
+
+    /**
+     * (int) The max number of generated preview API calls allowed per provider per interval.
+     */
+    public static final String GENERATED_PREVIEW_API_MAX_CALLS_PER_INTERVAL =
+            "generated_preview_api_max_calls_per_interval";
+
     private SystemUiDeviceConfigFlags() {
     }
 }
diff --git a/services/appwidget/java/com/android/server/appwidget/AppWidgetServiceImpl.java b/services/appwidget/java/com/android/server/appwidget/AppWidgetServiceImpl.java
index 29b9d44..fbd6709 100644
--- a/services/appwidget/java/com/android/server/appwidget/AppWidgetServiceImpl.java
+++ b/services/appwidget/java/com/android/server/appwidget/AppWidgetServiceImpl.java
@@ -84,6 +84,7 @@
 import android.os.Bundle;
 import android.os.Environment;
 import android.os.Handler;
+import android.os.HandlerExecutor;
 import android.os.IBinder;
 import android.os.Looper;
 import android.os.Message;
@@ -98,6 +99,7 @@
 import android.service.appwidget.AppWidgetServiceDumpProto;
 import android.service.appwidget.WidgetProto;
 import android.text.TextUtils;
+import android.util.ArrayMap;
 import android.util.ArraySet;
 import android.util.AtomicFile;
 import android.util.AttributeSet;
@@ -148,6 +150,7 @@
 import java.io.OutputStream;
 import java.io.PrintWriter;
 import java.nio.charset.StandardCharsets;
+import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -159,6 +162,7 @@
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.LongSupplier;
 
 class AppWidgetServiceImpl extends IAppWidgetService.Stub implements WidgetBackupProvider,
         OnCrossProfileWidgetProvidersChangeListener {
@@ -187,6 +191,13 @@
     // used to verify which request has successfully been received by the host.
     private static final AtomicLong UPDATE_COUNTER = new AtomicLong();
 
+    // Default reset interval (1 hour) for generated preview API rate limiting.
+    private static final long DEFAULT_GENERATED_PREVIEW_RESET_INTERVAL_MS =
+            Duration.ofHours(1).toMillis();
+    // Default max generated preview API calls allowed per provider per reset interval.
+    private static final int DEFAULT_GENERATED_PREVIEW_MAX_CALLS_PER_INTERVAL = 2;
+
+
     private final BroadcastReceiver mBroadcastReceiver = new BroadcastReceiver() {
         @Override
         public void onReceive(Context context, Intent intent) {
@@ -266,6 +277,8 @@
     // Mark widget lifecycle broadcasts as 'interactive'
     private Bundle mInteractiveBroadcast;
 
+    private ApiCounter mGeneratedPreviewsApiCounter;
+
     AppWidgetServiceImpl(Context context) {
         mContext = context;
     }
@@ -294,6 +307,17 @@
         mIsCombinedBroadcastEnabled = DeviceConfig.getBoolean(NAMESPACE_SYSTEMUI,
             SystemUiDeviceConfigFlags.COMBINED_BROADCAST_ENABLED, true);
 
+        // Rate limits for generated preview APIs; updated by handleSystemUiDeviceConfigChange.
+        mGeneratedPreviewsApiCounter = new ApiCounter(
+                DeviceConfig.getLong(NAMESPACE_SYSTEMUI,
+                        SystemUiDeviceConfigFlags.GENERATED_PREVIEW_API_RESET_INTERVAL_MS,
+                        DEFAULT_GENERATED_PREVIEW_RESET_INTERVAL_MS),
+                DeviceConfig.getInt(NAMESPACE_SYSTEMUI,
+                        SystemUiDeviceConfigFlags.GENERATED_PREVIEW_API_MAX_CALLS_PER_INTERVAL,
+                        DEFAULT_GENERATED_PREVIEW_MAX_CALLS_PER_INTERVAL));
+        DeviceConfig.addOnPropertiesChangedListener(NAMESPACE_SYSTEMUI,
+                new HandlerExecutor(mCallbackHandler), this::handleSystemUiDeviceConfigChange);
+
         BroadcastOptions opts = BroadcastOptions.makeBasic();
         opts.setBackgroundActivityStartsAllowed(false);
         opts.setInteractive(true);
@@ -2480,6 +2504,7 @@
     private void deleteProviderLocked(Provider provider) {
         deleteWidgetsLocked(provider, UserHandle.USER_ALL);
         mProviders.remove(provider);
+        mGeneratedPreviewsApiCounter.remove(provider.id);
 
         // no need to send the DISABLE broadcast, since the receiver is gone anyway
         cancelBroadcastsLocked(provider);
@@ -4004,7 +4029,7 @@
     }
 
     @Override
-    public void setWidgetPreview(@NonNull ComponentName providerComponent,
+    public boolean setWidgetPreview(@NonNull ComponentName providerComponent,
             @AppWidgetProviderInfo.CategoryFlags int widgetCategories,
             @NonNull RemoteViews preview) {
         final int userId = UserHandle.getCallingUserId();
@@ -4026,8 +4051,12 @@
                 throw new IllegalArgumentException(
                         providerComponent + " is not a valid AppWidget provider");
             }
-            provider.setGeneratedPreviewLocked(widgetCategories, preview);
-            scheduleNotifyGroupHostsForProvidersChangedLocked(userId);
+            if (mGeneratedPreviewsApiCounter.tryApiCall(providerId)) {
+                provider.setGeneratedPreviewLocked(widgetCategories, preview);
+                scheduleNotifyGroupHostsForProvidersChangedLocked(userId);
+                return true;
+            }
+            return false;
         }
     }
 
@@ -4068,6 +4097,26 @@
         }
     }
 
+    /** Applies DeviceConfig updates to the generated preview API rate limits. */
+    private void handleSystemUiDeviceConfigChange(DeviceConfig.Properties properties) {
+        final Set<String> changedKeys = properties.getKeyset();
+        final String intervalKey =
+                SystemUiDeviceConfigFlags.GENERATED_PREVIEW_API_RESET_INTERVAL_MS;
+        final String maxCallsKey =
+                SystemUiDeviceConfigFlags.GENERATED_PREVIEW_API_MAX_CALLS_PER_INTERVAL;
+        synchronized (mLock) {
+            // Only changed keys are applied; a missing value keeps the current limit.
+            if (changedKeys.contains(intervalKey)) {
+                mGeneratedPreviewsApiCounter.setResetIntervalMs(properties.getLong(
+                        intervalKey, mGeneratedPreviewsApiCounter.getResetIntervalMs()));
+            }
+            if (changedKeys.contains(maxCallsKey)) {
+                mGeneratedPreviewsApiCounter.setMaxCallsPerInterval(properties.getInt(
+                        maxCallsKey, mGeneratedPreviewsApiCounter.getMaxCallsPerInterval()));
+            }
+        }
+    }
+
     private final class CallbackHandler extends Handler {
         public static final int MSG_NOTIFY_UPDATE_APP_WIDGET = 1;
         public static final int MSG_NOTIFY_PROVIDER_CHANGED = 2;
@@ -4541,11 +4590,11 @@
         }
     }
 
-    private static final class ProviderId {
+    static final class ProviderId {
         final int uid;
         final ComponentName componentName;
 
-        private ProviderId(int uid, ComponentName componentName) {
+        ProviderId(int uid, ComponentName componentName) {
             this.uid = uid;
             this.componentName = componentName;
         }
@@ -4788,6 +4837,96 @@
         }
     }
 
+    /**
+     * This class keeps track of API calls and implements rate limiting. One instance of this class
+     * tracks calls from all providers for one API, or a group of APIs that should share the same
+     * rate limit. Each provider may make up to {@code maxCallsPerInterval} calls per interval;
+     * a provider's count is reset by the first call made more than {@code resetIntervalMs}
+     * after its last reset. Not thread-safe: callers must synchronize access externally.
+     */
+    static final class ApiCounter {
+
+        private static final class ApiCallRecord {
+            // Number of times the API has been called for this provider.
+            public int apiCallCount = 0;
+            // The last time (from the counter's monotonic clock) the api call count was reset.
+            public long lastResetTimeMs = 0;
+
+            void reset(long nowMs) {
+                apiCallCount = 0;
+                lastResetTimeMs = nowMs;
+            }
+        }
+
+        private final Map<ProviderId, ApiCallRecord> mCallCount = new ArrayMap<>();
+        // The interval at which the call count is reset.
+        private long mResetIntervalMs;
+        // The max number of API calls per interval.
+        private int mMaxCallsPerInterval;
+        // Returns the current time (monotonic). By default this is SystemClock.elapsedRealtime.
+        private final LongSupplier mMonotonicClock;
+
+        ApiCounter(long resetIntervalMs, int maxCallsPerInterval) {
+            this(resetIntervalMs, maxCallsPerInterval, SystemClock::elapsedRealtime);
+        }
+
+        ApiCounter(long resetIntervalMs, int maxCallsPerInterval,
+                LongSupplier monotonicClock) {
+            mResetIntervalMs = resetIntervalMs;
+            mMaxCallsPerInterval = maxCallsPerInterval;
+            mMonotonicClock = monotonicClock;
+        }
+
+        public void setResetIntervalMs(long resetIntervalMs) {
+            mResetIntervalMs = resetIntervalMs;
+        }
+
+        public long getResetIntervalMs() {
+            return mResetIntervalMs;
+        }
+
+        public void setMaxCallsPerInterval(int maxCallsPerInterval) {
+            mMaxCallsPerInterval = maxCallsPerInterval;
+        }
+
+        public int getMaxCallsPerInterval() {
+            return mMaxCallsPerInterval;
+        }
+
+        /**
+         * Returns true if the API call for the provider should be allowed, false if it should be
+         * rate-limited. Allowed calls are counted against the provider's current interval.
+         */
+        public boolean tryApiCall(@NonNull ProviderId provider) {
+            final ApiCallRecord record = getOrCreateRecord(provider);
+            final long now = mMonotonicClock.getAsLong();
+            final long timeSinceLastResetMs = now - record.lastResetTimeMs;
+            // If the last reset was beyond the reset interval, reset now.
+            if (timeSinceLastResetMs > mResetIntervalMs) {
+                record.reset(now);
+            }
+            if (record.apiCallCount < mMaxCallsPerInterval) {
+                record.apiCallCount++;
+                return true;
+            }
+            return false;
+        }
+
+        /**
+         * Remove the provider's call record from this counter, when the provider is no longer
+         * tracked.
+         */
+        public void remove(@NonNull ProviderId id) {
+            mCallCount.remove(id);
+        }
+
+        /** Returns the provider's call record, creating an empty one on first use. */
+        @NonNull
+        private ApiCallRecord getOrCreateRecord(@NonNull ProviderId provider) {
+            return mCallCount.computeIfAbsent(provider, unused -> new ApiCallRecord());
+        }
+    }
+
     private class LoadedWidgetState {
         final Widget widget;
         final int hostTag;
diff --git a/services/tests/servicestests/src/com/android/server/appwidget/ApiCounterTest.kt b/services/tests/servicestests/src/com/android/server/appwidget/ApiCounterTest.kt
new file mode 100644
index 0000000..79766f8
--- /dev/null
+++ b/services/tests/servicestests/src/com/android/server/appwidget/ApiCounterTest.kt
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.server.appwidget
+
+import android.content.ComponentName
+import com.android.server.appwidget.AppWidgetServiceImpl.ApiCounter
+import com.google.common.truth.Truth.assertThat
+import org.junit.Test
+
+/** Unit tests for [ApiCounter], driven by a fake monotonic clock ([currentTime]). */
+class ApiCounterTest {
+    private companion object {
+        const val RESET_INTERVAL_MS = 10L
+        const val MAX_CALLS_PER_INTERVAL = 2
+    }
+
+    private var currentTime = 0L
+
+    private val id =
+        AppWidgetServiceImpl.ProviderId(
+            /* uid= */ 123,
+            ComponentName("com.android.server.appwidget", "FakeProviderClass")
+        )
+    private val counter = ApiCounter(RESET_INTERVAL_MS, MAX_CALLS_PER_INTERVAL) { currentTime }
+
+    @Test
+    fun tryApiCall() {
+        repeat(MAX_CALLS_PER_INTERVAL) { assertThat(counter.tryApiCall(id)).isTrue() }
+        assertThat(counter.tryApiCall(id)).isFalse()
+        currentTime = RESET_INTERVAL_MS / 2
+        assertThat(counter.tryApiCall(id)).isFalse()
+        // The count only resets once strictly more than RESET_INTERVAL_MS has elapsed.
+        currentTime = RESET_INTERVAL_MS
+        assertThat(counter.tryApiCall(id)).isFalse()
+        currentTime = RESET_INTERVAL_MS + 1
+        assertThat(counter.tryApiCall(id)).isTrue()
+    }
+
+    @Test
+    fun remove() {
+        repeat(MAX_CALLS_PER_INTERVAL) { assertThat(counter.tryApiCall(id)).isTrue() }
+        assertThat(counter.tryApiCall(id)).isFalse()
+        // remove should cause the call count to be 0 on the next tryApiCall
+        counter.remove(id)
+        assertThat(counter.tryApiCall(id)).isTrue()
+    }
+}