xds:Allow big cluster total weight (#9864)
* xds: allow sum of cluster weights above MAX_INT up to max of unsigned int.
* Define nextLong(long bound) method in FakeRandom for WeightedRandomPickerTest.
diff --git a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java
index 1e844ce..533ccee 100644
--- a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java
+++ b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java
@@ -25,6 +25,8 @@
long nextLong();
+ long nextLong(long bound);
+
final class ThreadSafeRandomImpl implements ThreadSafeRandom {
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
@@ -40,5 +42,10 @@
public long nextLong() {
return ThreadLocalRandom.current().nextLong();
}
+
+ @Override
+ public long nextLong(long bound) {
+ return ThreadLocalRandom.current().nextLong(bound);
+ }
}
}
diff --git a/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java
index 1f5fc6d..904f387 100644
--- a/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java
+++ b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java
@@ -21,6 +21,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
+import com.google.common.primitives.UnsignedInteger;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.SubchannelPicker;
@@ -34,21 +35,22 @@
final List<WeightedChildPicker> weightedChildPickers;
private final ThreadSafeRandom random;
- private final int totalWeight;
+ private final long totalWeight;
static final class WeightedChildPicker {
- private final int weight;
+ private final long weight;
private final SubchannelPicker childPicker;
- WeightedChildPicker(int weight, SubchannelPicker childPicker) {
+ WeightedChildPicker(long weight, SubchannelPicker childPicker) {
checkArgument(weight >= 0, "weight is negative");
+ checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large");
checkNotNull(childPicker, "childPicker is null");
this.weight = weight;
this.childPicker = childPicker;
}
- int getWeight() {
+ long getWeight() {
return weight;
}
@@ -93,12 +95,16 @@
this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
- int totalWeight = 0;
+ long totalWeight = 0;
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
- int weight = weightedChildPicker.getWeight();
+ long weight = weightedChildPicker.getWeight();
+ checkArgument(weight >= 0, "weight is negative");
+ checkNotNull(weightedChildPicker.getPicker(), "childPicker is null");
totalWeight += weight;
}
this.totalWeight = totalWeight;
+ checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(),
+ "total weight greater than unsigned int can hold");
this.random = random;
}
@@ -111,15 +117,15 @@
childPicker =
weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
} else {
- int rand = random.nextInt(totalWeight);
+ long rand = random.nextLong(totalWeight);
// Find the first idx such that rand < accumulatedWeights[idx]
// Not using Arrays.binarySearch for better readability.
- int accumulatedWeight = 0;
- for (int idx = 0; idx < weightedChildPickers.size(); idx++) {
- accumulatedWeight += weightedChildPickers.get(idx).getWeight();
+ long accumulatedWeight = 0;
+ for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
+ accumulatedWeight += weightedChildPicker.getWeight();
if (rand < accumulatedWeight) {
- childPicker = weightedChildPickers.get(idx).getPicker();
+ childPicker = weightedChildPicker.getPicker();
break;
}
}
diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java
index 094bb94..8a5992a 100644
--- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java
+++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java
@@ -437,12 +437,12 @@
if (action.cluster() != null) {
cluster = prefixedClusterName(action.cluster());
} else if (action.weightedClusters() != null) {
- int totalWeight = 0;
+ long totalWeight = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) {
totalWeight += weightedCluster.weight();
}
- int select = random.nextInt(totalWeight);
- int accumulator = 0;
+ long select = random.nextLong(totalWeight);
+ long accumulator = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) {
accumulator += weightedCluster.weight();
if (select < accumulator) {
diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java
index ed109fd..6ae2340 100644
--- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java
+++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java
@@ -24,6 +24,7 @@
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import com.google.common.primitives.UnsignedInteger;
import com.google.protobuf.Any;
import com.google.protobuf.Duration;
import com.google.protobuf.InvalidProtocolBufferException;
@@ -477,7 +478,7 @@
return StructOrError.fromError("No cluster found in weighted cluster list");
}
List<ClusterWeight> weightedClusters = new ArrayList<>();
- int clusterWeightSum = 0;
+ long clusterWeightSum = 0;
for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
: clusterWeights) {
StructOrError<ClusterWeight> clusterWeightOrError =
@@ -492,6 +493,12 @@
if (clusterWeightSum <= 0) {
return StructOrError.fromError("Sum of cluster weights should be above 0.");
}
+ if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) {
+ return StructOrError.fromError(String.format(
+ "Sum of cluster weights should be less than the maximum unsigned integer (%d), but"
+ + " was %d. ",
+ UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum));
+ }
return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
weightedClusters, hashPolicies, timeoutNano, retryPolicy));
case CLUSTER_SPECIFIER_PLUGIN:
@@ -499,7 +506,7 @@
String pluginName = proto.getClusterSpecifierPlugin();
PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
if (pluginConfig == null) {
- // Skip route if the plugin is not registered, but it's optional.
+ // Skip route if the plugin is not registered, but it is optional.
if (optionalPlugins.contains(pluginName)) {
return null;
}
diff --git a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java
index ecdd96a..d6240fb 100644
--- a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java
+++ b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java
@@ -87,7 +87,8 @@
private static final class FakeRandom implements ThreadSafeRandom {
int nextInt;
- int bound;
+ long bound;
+ Long nextLong;
@Override
public int nextInt(int bound) {
@@ -102,6 +103,23 @@
public long nextLong() {
throw new UnsupportedOperationException("Should not be called");
}
+
+ @Override
+ public long nextLong(long bound) {
+ this.bound = bound;
+
+ if (nextLong == null) {
+ assertThat(nextInt).isAtLeast(0);
+ if (bound <= Integer.MAX_VALUE) {
+ assertThat(nextInt).isLessThan((int)bound);
+ }
+ return nextInt;
+ }
+
+ assertThat(nextLong).isAtLeast(0);
+ assertThat(nextLong).isLessThan(bound);
+ return nextLong;
+ }
}
private final FakeRandom fakeRandom = new FakeRandom();
@@ -121,6 +139,24 @@
}
@Test
+ public void overWeightSingle() {
+ thrown.expect(IllegalArgumentException.class);
+ new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0);
+ }
+
+ @Test
+ public void overWeightAggregate() {
+
+ List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
+ new WeightedChildPicker(Integer.MAX_VALUE, childPicker0),
+ new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
+ new WeightedChildPicker(10, childPicker2));
+
+ thrown.expect(IllegalArgumentException.class);
+ new WeightedRandomPicker(weightedChildPickers, fakeRandom);
+ }
+
+ @Test
public void pickWithFakeRandom() {
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
WeightedChildPicker weightedChildPicker1 = new WeightedChildPicker(15, childPicker1);
@@ -157,6 +193,36 @@
}
@Test
+ public void pickFromLargeTotal() {
+
+ List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
+ new WeightedChildPicker(10, childPicker0),
+ new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
+ new WeightedChildPicker(10, childPicker2));
+ WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom);
+
+ long totalWeight = weightedChildPickers.stream()
+ .mapToLong(WeightedChildPicker::getWeight)
+ .reduce(0, Long::sum);
+
+ fakeRandom.nextLong = 5L;
+ assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0);
+ assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+
+ fakeRandom.nextLong = 16L;
+ assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1);
+ assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+
+ fakeRandom.nextLong = Integer.MAX_VALUE + 10L;
+ assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
+ assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+
+ fakeRandom.nextLong = Integer.MAX_VALUE + 15L;
+ assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
+ assertThat(fakeRandom.bound).isEqualTo(totalWeight);
+ }
+
+ @Test
public void allZeroWeights() {
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
WeightedChildPicker weightedChildPicker1 = new WeightedChildPicker(0, childPicker1);
diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java
index b6f8b3c..3d934e1 100644
--- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java
@@ -24,6 +24,7 @@
import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -994,6 +995,7 @@
@Test
public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
+ when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L);
resolver.start(mockListener);
FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
xdsClient.deliverLdsUpdate(