xds:Allow big cluster total weight (#9864)

* xds:  allow sum of cluster weights above MAX_INT up to max of unsigned int.

* Define nextLong(long bound) method in FakeRandom for WeightedRandomPickerTest.

diff --git a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java
index 1e844ce..533ccee 100644
--- a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java
+++ b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java
@@ -25,6 +25,8 @@
 
   long nextLong();
 
+  long nextLong(long bound);
+
   final class ThreadSafeRandomImpl implements ThreadSafeRandom {
 
     static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
@@ -40,5 +42,10 @@
     public long nextLong() {
       return ThreadLocalRandom.current().nextLong();
     }
+
+    @Override
+    public long nextLong(long bound) {
+      return ThreadLocalRandom.current().nextLong(bound);
+    }
   }
 }
diff --git a/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java
index 1f5fc6d..904f387 100644
--- a/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java
+++ b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java
@@ -21,6 +21,7 @@
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.MoreObjects;
+import com.google.common.primitives.UnsignedInteger;
 import io.grpc.LoadBalancer.PickResult;
 import io.grpc.LoadBalancer.PickSubchannelArgs;
 import io.grpc.LoadBalancer.SubchannelPicker;
@@ -34,21 +35,22 @@
   final List<WeightedChildPicker> weightedChildPickers;
 
   private final ThreadSafeRandom random;
-  private final int totalWeight;
+  private final long totalWeight;
 
   static final class WeightedChildPicker {
-    private final int weight;
+    private final long weight;
     private final SubchannelPicker childPicker;
 
-    WeightedChildPicker(int weight, SubchannelPicker childPicker) {
+    WeightedChildPicker(long weight, SubchannelPicker childPicker) {
       checkArgument(weight >= 0, "weight is negative");
+      checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large");
       checkNotNull(childPicker, "childPicker is null");
 
       this.weight = weight;
       this.childPicker = childPicker;
     }
 
-    int getWeight() {
+    long getWeight() {
       return weight;
     }
 
@@ -93,12 +95,16 @@
 
     this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
 
-    int totalWeight = 0;
+    long totalWeight = 0;
     for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
-      int weight = weightedChildPicker.getWeight();
+      long weight = weightedChildPicker.getWeight();
+      checkArgument(weight >= 0, "weight is negative");
+      checkNotNull(weightedChildPicker.getPicker(), "childPicker is null");
       totalWeight += weight;
     }
     this.totalWeight = totalWeight;
+    checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(),
+        "total weight greater than unsigned int can hold");
 
     this.random = random;
   }
@@ -111,15 +117,15 @@
       childPicker =
           weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
     } else {
-      int rand = random.nextInt(totalWeight);
+      long rand = random.nextLong(totalWeight);
 
       // Find the first idx such that rand < accumulatedWeights[idx]
       // Not using Arrays.binarySearch for better readability.
-      int accumulatedWeight = 0;
-      for (int idx = 0; idx < weightedChildPickers.size(); idx++) {
-        accumulatedWeight += weightedChildPickers.get(idx).getWeight();
+      long accumulatedWeight = 0;
+      for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
+        accumulatedWeight += weightedChildPicker.getWeight();
         if (rand < accumulatedWeight) {
-          childPicker = weightedChildPickers.get(idx).getPicker();
+          childPicker = weightedChildPicker.getPicker();
           break;
         }
       }
diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java
index 094bb94..8a5992a 100644
--- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java
+++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java
@@ -437,12 +437,12 @@
         if (action.cluster() != null) {
           cluster = prefixedClusterName(action.cluster());
         } else if (action.weightedClusters() != null) {
-          int totalWeight = 0;
+          long totalWeight = 0;
           for (ClusterWeight weightedCluster : action.weightedClusters()) {
             totalWeight += weightedCluster.weight();
           }
-          int select = random.nextInt(totalWeight);
-          int accumulator = 0;
+          long select = random.nextLong(totalWeight);
+          long accumulator = 0;
           for (ClusterWeight weightedCluster : action.weightedClusters()) {
             accumulator += weightedCluster.weight();
             if (select < accumulator) {
diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java
index ed109fd..6ae2340 100644
--- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java
+++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java
@@ -24,6 +24,7 @@
 import com.google.common.base.Splitter;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.primitives.UnsignedInteger;
 import com.google.protobuf.Any;
 import com.google.protobuf.Duration;
 import com.google.protobuf.InvalidProtocolBufferException;
@@ -477,7 +478,7 @@
           return StructOrError.fromError("No cluster found in weighted cluster list");
         }
         List<ClusterWeight> weightedClusters = new ArrayList<>();
-        int clusterWeightSum = 0;
+        long clusterWeightSum = 0;
         for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
             : clusterWeights) {
           StructOrError<ClusterWeight> clusterWeightOrError =
@@ -492,6 +493,12 @@
         if (clusterWeightSum <= 0) {
           return StructOrError.fromError("Sum of cluster weights should be above 0.");
         }
+        if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) {
+          return StructOrError.fromError(String.format(
+              "Sum of cluster weights should be less than the maximum unsigned integer (%d), but"
+                  + " was %d. ",
+              UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum));
+        }
         return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
             weightedClusters, hashPolicies, timeoutNano, retryPolicy));
       case CLUSTER_SPECIFIER_PLUGIN:
@@ -499,7 +506,7 @@
           String pluginName = proto.getClusterSpecifierPlugin();
           PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
           if (pluginConfig == null) {
-            // Skip route if the plugin is not registered, but it's optional.
+            // Skip route if the plugin is not registered, but it is optional.
             if (optionalPlugins.contains(pluginName)) {
               return null;
             }
diff --git a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java
index ecdd96a..d6240fb 100644
--- a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java
+++ b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java
@@ -87,7 +87,8 @@
 
   private static final class FakeRandom implements ThreadSafeRandom {
     int nextInt;
-    int bound;
+    long bound;
+    Long nextLong;
 
     @Override
     public int nextInt(int bound) {
@@ -102,6 +103,23 @@
     public long nextLong() {
       throw new UnsupportedOperationException("Should not be called");
     }
+
+    @Override
+    public long nextLong(long bound) {
+      this.bound = bound;
+
+      if (nextLong == null) {
+        assertThat(nextInt).isAtLeast(0);
+        if (bound <= Integer.MAX_VALUE) {
+          assertThat(nextInt).isLessThan((int)bound);
+        }
+        return nextInt;
+      }
+
+      assertThat(nextLong).isAtLeast(0);
+      assertThat(nextLong).isLessThan(bound);
+      return nextLong;
+    }
   }
 
   private final FakeRandom fakeRandom = new FakeRandom();
@@ -121,6 +139,24 @@
   }
 
   @Test
+  public void overWeightSingle() {
+    thrown.expect(IllegalArgumentException.class);
+    new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0);
+  }
+
+  @Test
+  public void overWeightAggregate() {
+
+    List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
+        new WeightedChildPicker(Integer.MAX_VALUE, childPicker0),
+        new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
+        new WeightedChildPicker(10, childPicker2));
+
+    thrown.expect(IllegalArgumentException.class);
+    new WeightedRandomPicker(weightedChildPickers, fakeRandom);
+  }
+
+  @Test
   public void pickWithFakeRandom() {
     WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
     WeightedChildPicker weightedChildPicker1 = new WeightedChildPicker(15, childPicker1);
@@ -157,6 +193,36 @@
   }
 
   @Test
+  public void pickFromLargeTotal() {
+
+    List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
+        new WeightedChildPicker(10, childPicker0),
+        new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
+        new WeightedChildPicker(10, childPicker2));
+    WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom);
+
+    long totalWeight = weightedChildPickers.stream()
+        .mapToLong(WeightedChildPicker::getWeight)
+        .reduce(0, Long::sum);
+
+    fakeRandom.nextLong = 5L;
+    assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0);
+    assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+
+    fakeRandom.nextLong = 16L;
+    assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1);
+    assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+
+    fakeRandom.nextLong = Integer.MAX_VALUE + 10L;
+    assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
+    assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+
+    fakeRandom.nextLong = Integer.MAX_VALUE + 15L;
+    assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
+    assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+  }
+
+  @Test
   public void allZeroWeights() {
     WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
     WeightedChildPicker weightedChildPicker1 = new WeightedChildPicker(0, childPicker1);
diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java
index b6f8b3c..3d934e1 100644
--- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java
@@ -24,6 +24,7 @@
 import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -994,6 +995,7 @@
   @Test
   public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
     when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
+    when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L);
     resolver.start(mockListener);
     FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
     xdsClient.deliverLdsUpdate(