xds: prevent concurrent priority LB picker updates (#9363)

If a child policy triggers an update to the parent priority policy
it will be ignored if an update is already in process.
diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java
index 5cae913..7a6bff4 100644
--- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java
@@ -58,7 +58,7 @@
   private final XdsLogger logger;
 
   // Includes all active and deactivated children. Mutable. New entries are only added from priority
-  // 0 up to the selected priority. An entry is only deleted 15 minutes after the its deactivation.
+  // 0 up to the selected priority. An entry is only deleted 15 minutes after its deactivation.
   private final Map<String, ChildLbState> children = new HashMap<>();
 
   // Following fields are only null initially.
@@ -70,6 +70,8 @@
   @Nullable private String currentPriority;
   private ConnectivityState currentConnectivityState;
   private SubchannelPicker currentPicker;
+  // Set to true if currently in the process of handling resolved addresses.
+  private boolean handlingResolvedAddresses;
 
   PriorityLoadBalancer(Helper helper) {
     this.helper = checkNotNull(helper, "helper");
@@ -82,6 +84,15 @@
 
   @Override
   public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
+    try {
+      handlingResolvedAddresses = true;
+      handleResolvedAddressesInternal(resolvedAddresses);
+    } finally {
+      handlingResolvedAddresses = false;
+    }
+  }
+
+  public void handleResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) {
     logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
     this.resolvedAddresses = resolvedAddresses;
     PriorityLbConfig config = (PriorityLbConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
@@ -297,32 +308,34 @@
       @Override
       public void updateBalancingState(final ConnectivityState newState,
           final SubchannelPicker newPicker) {
-        syncContext.execute(new Runnable() {
-          @Override
-          public void run() {
-            if (!children.containsKey(priority)) {
-              return;
-            }
-            connectivityState = newState;
-            picker = newPicker;
-            if (deletionTimer != null && deletionTimer.isPending()) {
-              return;
-            }
-            if (newState.equals(CONNECTING) ) {
-              if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) {
-                failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS,
-                    executor);
-              }
-            } else if (newState.equals(READY) || newState.equals(IDLE)) {
-              seenReadyOrIdleSinceTransientFailure = true;
-              failOverTimer.cancel();
-            } else if (newState.equals(TRANSIENT_FAILURE)) {
-              seenReadyOrIdleSinceTransientFailure = false;
-              failOverTimer.cancel();
-            }
-            tryNextPriority();
+        if (!children.containsKey(priority)) {
+          return;
+        }
+        connectivityState = newState;
+        picker = newPicker;
+
+        if (deletionTimer != null && deletionTimer.isPending()) {
+          return;
+        }
+        if (newState.equals(CONNECTING)) {
+          if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) {
+            failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS,
+                executor);
           }
-        });
+        } else if (newState.equals(READY) || newState.equals(IDLE)) {
+          seenReadyOrIdleSinceTransientFailure = true;
+          failOverTimer.cancel();
+        } else if (newState.equals(TRANSIENT_FAILURE)) {
+          seenReadyOrIdleSinceTransientFailure = false;
+          failOverTimer.cancel();
+        }
+
+        // If we are currently handling newly resolved addresses, let's not try to reconfigure as
+        // the address handling process will take care of that to provide an atomic config update.
+        if (handlingResolvedAddresses) {
+          return;
+        }
+        tryNextPriority();
       }
 
       @Override
diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java
index 6b1fae4..7227dcb 100644
--- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java
@@ -22,8 +22,8 @@
 import static io.grpc.ConnectivityState.READY;
 import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
 import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.ArgumentMatchers.isA;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doReturn;
@@ -676,7 +676,7 @@
             .setAddresses(ImmutableList.<EquivalentAddressGroup>of())
             .setLoadBalancingPolicyConfig(priorityLbConfig)
             .build());
-    verify(helper, times(2)).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class));
+    verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER));
 
     // LB shutdown and subchannel state change can happen simultaneously. If shutdown runs first,
     // any further balancing state update should be ignored.
@@ -686,6 +686,26 @@
     verifyNoMoreInteractions(helper);
   }
 
+  @Test
+  public void noDuplicateOverallBalancingStateUpdate() {
+    FakeLoadBalancerProvider fakeLbProvider = new FakeLoadBalancerProvider();
+
+    PriorityChildConfig priorityChildConfig0 =
+        new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), true);
+    PriorityChildConfig priorityChildConfig1 =
+        new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), false);
+    PriorityLbConfig priorityLbConfig =
+        new PriorityLbConfig(
+            ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1),
+            ImmutableList.of("p0", "p1"));
+    priorityLb.handleResolvedAddresses(
+        ResolvedAddresses.newBuilder()
+            .setAddresses(ImmutableList.<EquivalentAddressGroup>of())
+            .setLoadBalancingPolicyConfig(priorityLbConfig)
+            .build());
+    verify(helper, times(1)).updateBalancingState(any(), any());
+  }
+
   private void assertLatestConnectivityState(ConnectivityState expectedState) {
     verify(helper, atLeastOnce())
         .updateBalancingState(connectivityStateCaptor.capture(), pickerCaptor.capture());
@@ -714,4 +734,49 @@
     PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class));
     assertThat(pickResult).isEqualTo(PickResult.withNoResult());
   }
+
+  private static class FakeLoadBalancerProvider extends LoadBalancerProvider {
+
+    @Override
+    public boolean isAvailable() {
+      return true;
+    }
+
+    @Override
+    public int getPriority() {
+      return 5;
+    }
+
+    @Override
+    public String getPolicyName() {
+      return "foo";
+    }
+
+    @Override
+    public LoadBalancer newLoadBalancer(Helper helper) {
+      return new FakeLoadBalancer(helper);
+    }
+  }
+
+  static class FakeLoadBalancer extends LoadBalancer {
+
+    private Helper helper;
+
+    FakeLoadBalancer(Helper helper) {
+      this.helper = helper;
+    }
+
+    @Override
+    public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
+      helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.INTERNAL));
+    }
+
+    @Override
+    public void handleNameResolutionError(Status error) {
+    }
+
+    @Override
+    public void shutdown() {
+    }
+  }
 }