xds: Correctly start LRS clients in federation situations (#10000)

xds: Correctly start LRS clients in federation situations

The old code used a single member variable to indicate if load reporting
had already been started by XdsClientImpl. This boolean was used to
avoid starting a LoadReportClient more than once. This works fine with
a single control plane server.

The problem occurs in federation situations where there is more than one
control plane and thus more than one LoadReportClient. Once the first
LoadReportClient is started, the member variable boolean is flipped to
true and no other LoadReportClients would be started.

This change removes the boolean member variable and relies on the fact
that starting an already started LoadReportClient is a no-op.
diff --git a/xds/src/main/java/io/grpc/xds/LoadReportClient.java b/xds/src/main/java/io/grpc/xds/LoadReportClient.java
index d6a3679..b28357c 100644
--- a/xds/src/main/java/io/grpc/xds/LoadReportClient.java
+++ b/xds/src/main/java/io/grpc/xds/LoadReportClient.java
@@ -69,7 +69,8 @@
   @Nullable
   private ScheduledHandle lrsRpcRetryTimer;
   @Nullable
-  private LrsStream lrsStream;
+  @VisibleForTesting
+  LrsStream lrsStream;
 
   LoadReportClient(
       LoadStatsManager2 loadStatsManager,
diff --git a/xds/src/main/java/io/grpc/xds/XdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClient.java
index 591c4d7..a66671b 100644
--- a/xds/src/main/java/io/grpc/xds/XdsClient.java
+++ b/xds/src/main/java/io/grpc/xds/XdsClient.java
@@ -19,6 +19,7 @@
 import static com.google.common.base.Preconditions.checkNotNull;
 import static io.grpc.xds.Bootstrapper.XDSTP_SCHEME;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Joiner;
 import com.google.common.base.Splitter;
 import com.google.common.net.UrlEscapers;
@@ -343,6 +344,15 @@
     throw new UnsupportedOperationException();
   }
 
+  /**
+   * Returns a map of control plane server info objects to the LoadReportClients that are
+   * responsible for sending load reports to the control plane servers.
+   */
+  @VisibleForTesting
+  Map<ServerInfo, LoadReportClient> getServerLrsClientMap() {
+    throw new UnsupportedOperationException();
+  }
+
   interface XdsResponseHandler {
     /** Called when a xds response is received. */
     void handleResourceResponse(
diff --git a/xds/src/main/java/io/grpc/xds/XdsClientImpl.java b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java
index 0216296..a9eb348 100644
--- a/xds/src/main/java/io/grpc/xds/XdsClientImpl.java
+++ b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java
@@ -107,7 +107,6 @@
   private final BackoffPolicy.Provider backoffPolicyProvider;
   private final Supplier<Stopwatch> stopwatchSupplier;
   private final TimeProvider timeProvider;
-  private boolean reportingLoad;
   private final TlsContextManager tlsContextManager;
   private final InternalLogId logId;
   private final XdsLogger logger;
@@ -221,10 +220,8 @@
             for (ControlPlaneClient xdsChannel : serverChannelMap.values()) {
               xdsChannel.shutdown();
             }
-            if (reportingLoad) {
-              for (final LoadReportClient lrsClient : serverLrsClientMap.values()) {
-                lrsClient.stopLoadReporting();
-              }
+            for (final LoadReportClient lrsClient : serverLrsClientMap.values()) {
+              lrsClient.stopLoadReporting();
             }
             cleanUpResourceTimers();
           }
@@ -350,10 +347,7 @@
     syncContext.execute(new Runnable() {
       @Override
       public void run() {
-        if (!reportingLoad) {
-          serverLrsClientMap.get(serverInfo).startLoadReporting();
-          reportingLoad = true;
-        }
+        serverLrsClientMap.get(serverInfo).startLoadReporting();
       }
     });
     return dropCounter;
@@ -368,10 +362,7 @@
     syncContext.execute(new Runnable() {
       @Override
       public void run() {
-        if (!reportingLoad) {
-          serverLrsClientMap.get(serverInfo).startLoadReporting();
-          reportingLoad = true;
-        }
+        serverLrsClientMap.get(serverInfo).startLoadReporting();
       }
     });
     return loadCounter;
@@ -382,6 +373,12 @@
     return bootstrapInfo;
   }
 
+  @VisibleForTesting
+  @Override
+  Map<ServerInfo, LoadReportClient> getServerLrsClientMap() {
+    return ImmutableMap.copyOf(serverLrsClientMap);
+  }
+
   @Override
   public String toString() {
     return logId.toString();
diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java
index 54d428b..aea892c 100644
--- a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java
@@ -16,6 +16,7 @@
 
 package io.grpc.xds;
 
+import static com.google.common.truth.Truth.assertThat;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -23,11 +24,13 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import io.grpc.internal.ObjectPool;
+import io.grpc.xds.Bootstrapper.ServerInfo;
 import io.grpc.xds.Filter.NamedFilterConfig;
 import io.grpc.xds.XdsClient.ResourceWatcher;
 import io.grpc.xds.XdsListenerResource.LdsUpdate;
 import java.util.Collections;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.UUID;
 import org.junit.After;
 import org.junit.Before;
@@ -116,6 +119,61 @@
             "xdstp://server-one/envoy.config.listener.v3.Listener/test-server");
   }
 
+  /**
+   * Ensures that when an {@link XdsClient} is asked to add cluster locality stats it appropriately
+   * starts {@link LoadReportClient}s to do that.
+   */
+  @Test
+  public void lrsClientsStartedForLocalityStats() throws InterruptedException {
+    trafficdirector.setLdsConfig(ControlPlaneRule.buildServerListener(),
+        ControlPlaneRule.buildClientListener("test-server"));
+    directpathPa.setLdsConfig(ControlPlaneRule.buildServerListener(),
+        ControlPlaneRule.buildClientListener(
+            "xdstp://server-one/envoy.config.listener.v3.Listener/test-server"));
+
+    xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "test-server", mockWatcher);
+    xdsClient.watchXdsResource(XdsListenerResource.getInstance(),
+        "xdstp://server-one/envoy.config.listener.v3.Listener/test-server", mockDirectPathWatcher);
+
+    // With two control planes and a watcher for each, there should be two LRS clients.
+    assertThat(xdsClient.getServerLrsClientMap().size()).isEqualTo(2);
+
+    // When the XdsClient is asked to report locality stats for a control plane server, the
+    // corresponding LRS client should be started
+    for (Entry<ServerInfo, LoadReportClient> entry : xdsClient.getServerLrsClientMap().entrySet()) {
+      xdsClient.addClusterLocalityStats(entry.getKey(), "clusterName", "edsServiceName",
+          Locality.create("", "", ""));
+      assertThat(entry.getValue().lrsStream).isNotNull();
+    }
+  }
+
+  /**
+   * Ensures that when an {@link XdsClient} is asked to add cluster drop stats it appropriately
+   * starts {@link LoadReportClient}s to do that.
+   */
+  @Test
+  public void lrsClientsStartedForDropStats() throws InterruptedException {
+    trafficdirector.setLdsConfig(ControlPlaneRule.buildServerListener(),
+        ControlPlaneRule.buildClientListener("test-server"));
+    directpathPa.setLdsConfig(ControlPlaneRule.buildServerListener(),
+        ControlPlaneRule.buildClientListener(
+            "xdstp://server-one/envoy.config.listener.v3.Listener/test-server"));
+
+    xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "test-server", mockWatcher);
+    xdsClient.watchXdsResource(XdsListenerResource.getInstance(),
+        "xdstp://server-one/envoy.config.listener.v3.Listener/test-server", mockDirectPathWatcher);
+
+    // With two control planes and a watcher for each, there should be two LRS clients.
+    assertThat(xdsClient.getServerLrsClientMap().size()).isEqualTo(2);
+
+    // When the XdsClient is asked to report drop stats for a control plane server, the
+    // corresponding LRS client should be started
+    for (Entry<ServerInfo, LoadReportClient> entry : xdsClient.getServerLrsClientMap().entrySet()) {
+      xdsClient.addClusterDropStats(entry.getKey(), "clusterName", "edsServiceName");
+      assertThat(entry.getValue().lrsStream).isNotNull();
+    }
+  }
+
   private Map<String, ?> defaultBootstrapOverride() {
     return ImmutableMap.of(
         "node", ImmutableMap.of(