grpclb: support "pick_first" child policy (#5438)
The PICK_FIRST mode puts all backend addresses in a single Subchannel. There are a few points where it's different from the default ROUND_ROBIN mode:
1. PICK_FIRST doesn't eagerly connect to backends like ROUND_ROBIN does. Instead, it requests for connections when the Subchannel is picked.
2. PICK_FIRST adds tokens to the headers via a different code path (`TokenAttachingTracerFactory`) than ROUND_ROBIN
3. For simple implementation, when the mode is changed by service config when the LoadBalancer is working, we will shut down `GrpclbState` and starts a new one with the new mode. All connections will be closed during the transition. We don't expect this to happen in practice given the specific use case of PICK_FIRST.
diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java
index 87e2b6b..65f4832 100644
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java
+++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java
@@ -16,6 +16,8 @@
package io.grpc.grpclb;
+import io.grpc.Attributes;
+import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.Metadata;
@@ -32,5 +34,12 @@
public static final Metadata.Key<String> TOKEN_METADATA_KEY =
Metadata.Key.of("lb-token", Metadata.ASCII_STRING_MARSHALLER);
+ /**
+ * For passing LB tokens via the EAG attributes.
+ */
+ @EquivalentAddressGroup.Attr
+ static final Attributes.Key<String> TOKEN_ATTRIBUTE_KEY =
+ Attributes.Key.create("lb-token");
+
private GrpclbConstants() { }
}
diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java
index 09ed32e..b0719f0 100644
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java
+++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java
@@ -17,6 +17,7 @@
package io.grpc.grpclb;
import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Attributes;
@@ -51,7 +52,11 @@
private static final Logger logger = Logger.getLogger(GrpclbLoadBalancer.class.getName());
private final Helper helper;
+ private final TimeProvider time;
private final SubchannelPool subchannelPool;
+ private final BackoffPolicy.Provider backoffPolicyProvider;
+
+ private Mode mode = Mode.ROUND_ROBIN;
// All mutable states in this class are mutated ONLY from Channel Executor
@Nullable
@@ -63,12 +68,12 @@
TimeProvider time,
BackoffPolicy.Provider backoffPolicyProvider) {
this.helper = checkNotNull(helper, "helper");
- checkNotNull(time, "time provider");
- checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
+ this.time = checkNotNull(time, "time provider");
+ this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
this.subchannelPool.init(helper);
- grpclbState =
- new GrpclbState(helper, subchannelPool, time, backoffPolicyProvider);
+ recreateStates();
+ checkNotNull(grpclbState, "grpclbState");
}
@Override
@@ -97,7 +102,12 @@
newBackendServers = Collections.unmodifiableList(newBackendServers);
Map<String, Object> rawLbConfigValue = attributes.get(ATTR_LOAD_BALANCING_CONFIG);
Mode newMode = retrieveModeFromLbConfig(rawLbConfigValue, helper.getChannelLogger());
- grpclbState.handleAddresses(newLbAddressGroups, newBackendServers, newMode);
+ if (!mode.equals(newMode)) {
+ mode = newMode;
+ helper.getChannelLogger().log(ChannelLogLevel.INFO, "Mode: " + newMode);
+ recreateStates();
+ }
+ grpclbState.handleAddresses(newLbAddressGroups, newBackendServers);
}
@VisibleForTesting
@@ -141,6 +151,12 @@
}
}
+ private void recreateStates() {
+ resetStates();
+ checkState(grpclbState == null, "Should've been cleared");
+ grpclbState = new GrpclbState(mode, helper, subchannelPool, time, backoffPolicyProvider);
+ }
+
@Override
public void shutdown() {
resetStates();
diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java
index e4203da..4d87f32 100644
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java
+++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java
@@ -138,8 +138,8 @@
@Nullable
private LbStream lbStream;
- private Map<EquivalentAddressGroup, Subchannel> subchannels = Collections.emptyMap();
- private Mode mode;
+ private Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Collections.emptyMap();
+ private final Mode mode;
// Has the same size as the round-robin list from the balancer.
// A drop entry from the round-robin list becomes a DropEntry here.
@@ -151,10 +151,12 @@
new RoundRobinPicker(Collections.<DropEntry>emptyList(), Arrays.asList(BUFFER_ENTRY));
GrpclbState(
+ Mode mode,
Helper helper,
SubchannelPool subchannelPool,
TimeProvider time,
BackoffPolicy.Provider backoffPolicyProvider) {
+ this.mode = checkNotNull(mode, "mode");
this.helper = checkNotNull(helper, "helper");
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
@@ -169,7 +171,7 @@
if (newState.getState() == SHUTDOWN || !subchannels.values().contains(subchannel)) {
return;
}
- if (newState.getState() == IDLE) {
+ if (mode == Mode.ROUND_ROBIN && newState.getState() == IDLE) {
subchannel.requestConnection();
}
subchannel.getAttributes().get(STATE_INFO).set(newState);
@@ -182,8 +184,7 @@
* not yet connected.
*/
void handleAddresses(
- List<LbAddressGroup> newLbAddressGroups, List<EquivalentAddressGroup> newBackendServers,
- Mode mode) {
+ List<LbAddressGroup> newLbAddressGroups, List<EquivalentAddressGroup> newBackendServers) {
if (newLbAddressGroups.isEmpty()) {
propagateError(Status.UNAVAILABLE.withDescription(
"NameResolver returned no LB address while asking for GRPCLB"));
@@ -305,10 +306,20 @@
void shutdown() {
shutdownLbComm();
- // We close the subchannels through subchannelPool instead of helper just for convenience of
- // testing.
- for (Subchannel subchannel : subchannels.values()) {
- subchannelPool.returnSubchannel(subchannel);
+ switch (mode) {
+ case ROUND_ROBIN:
+ // We close the subchannels through subchannelPool instead of helper just for convenience of
+ // testing.
+ for (Subchannel subchannel : subchannels.values()) {
+ subchannelPool.returnSubchannel(subchannel);
+ }
+ break;
+ case PICK_FIRST:
+ checkState(subchannels.size() == 1, "Excessive Subchannels: %s", subchannels);
+ subchannels.values().iterator().next().shutdown();
+ break;
+ default:
+ throw new AssertionError("Missing case for " + mode);
}
subchannels = Collections.emptyMap();
subchannelPool.clear();
@@ -341,45 +352,74 @@
@Nullable GrpclbClientLoadRecorder loadRecorder) {
logger.log(
ChannelLogLevel.INFO, "Using RR list={0}, drop={1}", newBackendAddrList, newDropList);
- HashMap<EquivalentAddressGroup, Subchannel> newSubchannelMap =
+ HashMap<List<EquivalentAddressGroup>, Subchannel> newSubchannelMap =
new HashMap<>();
List<BackendEntry> newBackendList = new ArrayList<>();
- for (BackendAddressGroup backendAddr : newBackendAddrList) {
- EquivalentAddressGroup eag = backendAddr.getAddresses();
- Subchannel subchannel = newSubchannelMap.get(eag);
- if (subchannel == null) {
- subchannel = subchannels.get(eag);
- if (subchannel == null) {
- Attributes subchannelAttrs = Attributes.newBuilder()
- .set(STATE_INFO,
- new AtomicReference<>(
- ConnectivityStateInfo.forNonError(IDLE)))
- .build();
- subchannel = subchannelPool.takeOrCreateSubchannel(eag, subchannelAttrs);
- subchannel.requestConnection();
+ switch (mode) {
+ case ROUND_ROBIN:
+ for (BackendAddressGroup backendAddr : newBackendAddrList) {
+ EquivalentAddressGroup eag = backendAddr.getAddresses();
+ List<EquivalentAddressGroup> eagAsList = Collections.singletonList(eag);
+ Subchannel subchannel = newSubchannelMap.get(eagAsList);
+ if (subchannel == null) {
+ subchannel = subchannels.get(eagAsList);
+ if (subchannel == null) {
+ subchannel = subchannelPool.takeOrCreateSubchannel(eag, createSubchannelAttrs());
+ subchannel.requestConnection();
+ }
+ newSubchannelMap.put(eagAsList, subchannel);
+ }
+ BackendEntry entry;
+ // Only picks with tokens are reported to LoadRecorder
+ if (backendAddr.getToken() == null) {
+ entry = new BackendEntry(subchannel);
+ } else {
+ entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken());
+ }
+ newBackendList.add(entry);
}
- newSubchannelMap.put(eag, subchannel);
- }
- BackendEntry entry;
- // Only picks with tokens are reported to LoadRecorder
- if (backendAddr.getToken() == null) {
- entry = new BackendEntry(subchannel);
- } else {
- entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken());
- }
- newBackendList.add(entry);
+ // Close Subchannels whose addresses have been delisted
+ for (Entry<List<EquivalentAddressGroup>, Subchannel> entry : subchannels.entrySet()) {
+ List<EquivalentAddressGroup> eagList = entry.getKey();
+ if (!newSubchannelMap.containsKey(eagList)) {
+ subchannelPool.returnSubchannel(entry.getValue());
+ }
+ }
+ subchannels = Collections.unmodifiableMap(newSubchannelMap);
+ break;
+ case PICK_FIRST:
+ List<EquivalentAddressGroup> eagList = new ArrayList<>();
+ // Because for PICK_FIRST, we create a single Subchannel for all addresses, we have to
+ // attach the tokens to the EAG attributes and use TokenAttachingLoadRecorder to put them on
+ // headers.
+ //
+ // The PICK_FIRST code path doesn't cache Subchannels.
+ for (BackendAddressGroup bag : newBackendAddrList) {
+ EquivalentAddressGroup origEag = bag.getAddresses();
+ Attributes eagAttrs = origEag.getAttributes();
+ if (bag.getToken() != null) {
+ eagAttrs = eagAttrs.toBuilder()
+ .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, bag.getToken()).build();
+ }
+ eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs));
+ }
+ Subchannel subchannel;
+ if (subchannels.isEmpty()) {
+ subchannel = helper.createSubchannel(eagList, createSubchannelAttrs());
+ } else {
+ checkState(subchannels.size() == 1, "Unexpected Subchannel count: %s", subchannels);
+ subchannel = subchannels.values().iterator().next();
+ helper.updateSubchannelAddresses(subchannel, eagList);
+ }
+ subchannels = Collections.singletonMap(eagList, subchannel);
+ newBackendList.add(
+ new BackendEntry(subchannel, new TokenAttachingTracerFactory(loadRecorder)));
+ break;
+ default:
+ throw new AssertionError("Missing case for " + mode);
}
- // Close Subchannels whose addresses have been delisted
- for (Entry<EquivalentAddressGroup, Subchannel> entry : subchannels.entrySet()) {
- EquivalentAddressGroup eag = entry.getKey();
- if (!newSubchannelMap.containsKey(eag)) {
- subchannelPool.returnSubchannel(entry.getValue());
- }
- }
-
- subchannels = Collections.unmodifiableMap(newSubchannelMap);
dropList = Collections.unmodifiableList(newDropList);
backendList = Collections.unmodifiableList(newBackendList);
}
@@ -619,32 +659,67 @@
* changed since the last picker created.
*/
private void maybeUpdatePicker() {
- List<RoundRobinEntry> pickList = new ArrayList<>(backendList.size());
- Status error = null;
- boolean hasIdle = false;
- for (BackendEntry entry : backendList) {
- Subchannel subchannel = entry.result.getSubchannel();
- Attributes attrs = subchannel.getAttributes();
- ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get();
- if (stateInfo.getState() == READY) {
- pickList.add(entry);
- } else if (stateInfo.getState() == TRANSIENT_FAILURE) {
- error = stateInfo.getStatus();
- } else if (stateInfo.getState() == IDLE) {
- hasIdle = true;
- }
- }
+ List<RoundRobinEntry> pickList;
ConnectivityState state;
- if (pickList.isEmpty()) {
- if (error != null && !hasIdle) {
- pickList.add(new ErrorEntry(error));
- state = TRANSIENT_FAILURE;
- } else {
- pickList.add(BUFFER_ENTRY);
- state = CONNECTING;
- }
- } else {
- state = READY;
+ switch (mode) {
+ case ROUND_ROBIN:
+ pickList = new ArrayList<>(backendList.size());
+ Status error = null;
+ boolean hasIdle = false;
+ for (BackendEntry entry : backendList) {
+ Subchannel subchannel = entry.subchannel;
+ Attributes attrs = subchannel.getAttributes();
+ ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get();
+ if (stateInfo.getState() == READY) {
+ pickList.add(entry);
+ } else if (stateInfo.getState() == TRANSIENT_FAILURE) {
+ error = stateInfo.getStatus();
+ } else if (stateInfo.getState() == IDLE) {
+ hasIdle = true;
+ }
+ }
+ if (pickList.isEmpty()) {
+ if (error != null && !hasIdle) {
+ pickList.add(new ErrorEntry(error));
+ state = TRANSIENT_FAILURE;
+ } else {
+ pickList.add(BUFFER_ENTRY);
+ state = CONNECTING;
+ }
+ } else {
+ state = READY;
+ }
+ break;
+ case PICK_FIRST:
+ if (backendList.isEmpty()) {
+ pickList = Collections.singletonList(BUFFER_ENTRY);
+ // Have not received server addresses
+ state = CONNECTING;
+ } else {
+ checkState(backendList.size() == 1, "Excessive backend entries: %s", backendList);
+ BackendEntry onlyEntry = backendList.get(0);
+ ConnectivityStateInfo stateInfo =
+ onlyEntry.subchannel.getAttributes().get(STATE_INFO).get();
+ state = stateInfo.getState();
+ switch (state) {
+ case READY:
+ pickList = Collections.<RoundRobinEntry>singletonList(onlyEntry);
+ break;
+ case TRANSIENT_FAILURE:
+ pickList =
+ Collections.<RoundRobinEntry>singletonList(new ErrorEntry(stateInfo.getStatus()));
+ break;
+ case CONNECTING:
+ pickList = Collections.singletonList(BUFFER_ENTRY);
+ break;
+ default:
+ pickList = Collections.<RoundRobinEntry>singletonList(
+ new IdleSubchannelEntry(onlyEntry.subchannel));
+ }
+ }
+ break;
+ default:
+ throw new AssertionError("Missing case for " + mode);
}
maybeUpdatePicker(state, new RoundRobinPicker(dropList, pickList));
}
@@ -704,6 +779,14 @@
return new EquivalentAddressGroup(addrs, attrs);
}
+ private static Attributes createSubchannelAttrs() {
+ return Attributes.newBuilder()
+ .set(STATE_INFO,
+ new AtomicReference<>(
+ ConnectivityStateInfo.forNonError(IDLE)))
+ .build();
+ }
+
@VisibleForTesting
static final class DropEntry {
private final GrpclbClientLoadRecorder loadRecorder;
@@ -740,34 +823,45 @@
}
}
- private interface RoundRobinEntry {
+ @VisibleForTesting
+ interface RoundRobinEntry {
PickResult picked(Metadata headers);
}
@VisibleForTesting
static final class BackendEntry implements RoundRobinEntry {
+ final Subchannel subchannel;
@VisibleForTesting
final PickResult result;
@Nullable
- private final GrpclbClientLoadRecorder loadRecorder;
- @Nullable
private final String token;
/**
- * Creates a BackendEntry whose usage will be reported to load recorder.
+ * For ROUND_ROBIN: creates a BackendEntry whose usage will be reported to load recorder.
*/
BackendEntry(Subchannel subchannel, GrpclbClientLoadRecorder loadRecorder, String token) {
- this.result = PickResult.withSubchannel(subchannel, loadRecorder);
- this.loadRecorder = checkNotNull(loadRecorder, "loadRecorder");
+ this.subchannel = checkNotNull(subchannel, "subchannel");
+ this.result =
+ PickResult.withSubchannel(subchannel, checkNotNull(loadRecorder, "loadRecorder"));
this.token = checkNotNull(token, "token");
}
/**
- * Creates a BackendEntry whose usage will not be reported.
+ * For ROUND_ROBIN/PICK_FIRST: creates a BackendEntry whose usage will not be reported.
*/
BackendEntry(Subchannel subchannel) {
+ this.subchannel = checkNotNull(subchannel, "subchannel");
this.result = PickResult.withSubchannel(subchannel);
- this.loadRecorder = null;
+ this.token = null;
+ }
+
+ /**
+ * For PICK_FIRST: creates a BackendEntry that includes all addresses.
+ */
+ BackendEntry(Subchannel subchannel, TokenAttachingTracerFactory tracerFactory) {
+ this.subchannel = checkNotNull(subchannel, "subchannel");
+ this.result =
+ PickResult.withSubchannel(subchannel, checkNotNull(tracerFactory, "tracerFactory"));
this.token = null;
}
@@ -783,12 +877,12 @@
@Override
public String toString() {
// This is printed in logs. Only give out useful information.
- return "[" + result.getSubchannel().getAllAddresses().toString() + "(" + token + ")]";
+ return "[" + subchannel.getAllAddresses().toString() + "(" + token + ")]";
}
@Override
public int hashCode() {
- return Objects.hashCode(loadRecorder, result, token);
+ return Objects.hashCode(result, token);
}
@Override
@@ -797,8 +891,42 @@
return false;
}
BackendEntry that = (BackendEntry) other;
- return Objects.equal(result, that.result) && Objects.equal(token, that.token)
- && Objects.equal(loadRecorder, that.loadRecorder);
+ return Objects.equal(result, that.result) && Objects.equal(token, that.token);
+ }
+ }
+
+ @VisibleForTesting
+ static final class IdleSubchannelEntry implements RoundRobinEntry {
+ private final Subchannel subchannel;
+
+ IdleSubchannelEntry(Subchannel subchannel) {
+ this.subchannel = checkNotNull(subchannel, "subchannel");
+ }
+
+ @Override
+ public PickResult picked(Metadata headers) {
+ subchannel.requestConnection();
+ return PickResult.withNoResult();
+ }
+
+ @Override
+ public String toString() {
+ // This is printed in logs. Only give out useful information.
+ return "(idle)[" + subchannel.getAllAddresses().toString() + "]";
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(subchannel);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof IdleSubchannelEntry)) {
+ return false;
+ }
+ IdleSubchannelEntry that = (IdleSubchannelEntry) other;
+ return Objects.equal(subchannel, that.subchannel);
}
}
@@ -860,8 +988,7 @@
// First round-robin on dropList. If a drop entry is selected, request will be dropped. If
// a non-drop entry is selected, then round-robin on pickList. This makes sure requests are
// dropped at the same proportion as the drop entries appear on the round-robin list from
- // the balancer, while only READY backends (that make up pickList) are selected for the
- // non-drop cases.
+ // the balancer, while only backends from pickList are selected for the non-drop cases.
if (!dropList.isEmpty()) {
DropEntry drop = dropList.get(dropIndex);
dropIndex++;
@@ -881,5 +1008,14 @@
return pick.picked(args.getHeaders());
}
}
+
+ @Override
+ public void requestConnection() {
+ for (RoundRobinEntry entry : pickList) {
+ if (entry instanceof IdleSubchannelEntry) {
+ ((IdleSubchannelEntry) entry).subchannel.requestConnection();
+ }
+ }
+ }
}
}
diff --git a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java
new file mode 100644
index 0000000..03b9bdf
--- /dev/null
+++ b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.grpclb;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.common.base.Objects;
+import io.grpc.Attributes;
+import io.grpc.ClientStreamTracer;
+import io.grpc.Metadata;
+import io.grpc.internal.GrpcAttributes;
+import javax.annotation.Nullable;
+
+/**
+ * Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and
+ * attaches them to headers. This is only used in the PICK_FIRST mode.
+ */
+final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory {
+ private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {};
+
+ @Nullable
+ private final ClientStreamTracer.Factory delegate;
+
+ TokenAttachingTracerFactory(@Nullable ClientStreamTracer.Factory delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public ClientStreamTracer newClientStreamTracer(
+ ClientStreamTracer.StreamInfo info, Metadata headers) {
+ Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs");
+ Attributes eagAttrs =
+ checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs");
+ String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY);
+ headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY);
+ if (token != null) {
+ headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token);
+ }
+ if (delegate != null) {
+ return delegate.newClientStreamTracer(info, headers);
+ } else {
+ return NOOP_TRACER;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(delegate);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof TokenAttachingTracerFactory)) {
+ return false;
+ }
+ return Objects.equal(delegate, ((TokenAttachingTracerFactory) other).delegate);
+ }
+}
diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java
index e4bb5f3..ee6e185 100644
--- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java
+++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java
@@ -56,6 +56,7 @@
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
+import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
@@ -68,7 +69,9 @@
import io.grpc.grpclb.GrpclbState.BackendEntry;
import io.grpc.grpclb.GrpclbState.DropEntry;
import io.grpc.grpclb.GrpclbState.ErrorEntry;
+import io.grpc.grpclb.GrpclbState.IdleSubchannelEntry;
import io.grpc.grpclb.GrpclbState.Mode;
+import io.grpc.grpclb.GrpclbState.RoundRobinEntry;
import io.grpc.grpclb.GrpclbState.RoundRobinPicker;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
@@ -166,7 +169,8 @@
new LinkedList<>();
private final LinkedList<Subchannel> mockSubchannels = new LinkedList<>();
private final LinkedList<ManagedChannel> fakeOobChannels = new LinkedList<>();
- private final ArrayList<Subchannel> subchannelTracker = new ArrayList<>();
+ private final ArrayList<Subchannel> pooledSubchannelTracker = new ArrayList<>();
+ private final ArrayList<Subchannel> unpooledSubchannelTracker = new ArrayList<>();
private final ArrayList<ManagedChannel> oobChannelTracker = new ArrayList<>();
private final ArrayList<String> failingLbAuthorities = new ArrayList<>();
private final SynchronizationContext syncContext = new SynchronizationContext(
@@ -251,11 +255,25 @@
when(subchannel.getAllAddresses()).thenReturn(Arrays.asList(eag));
when(subchannel.getAttributes()).thenReturn(attrs);
mockSubchannels.add(subchannel);
- subchannelTracker.add(subchannel);
+ pooledSubchannelTracker.add(subchannel);
return subchannel;
}
}).when(subchannelPool).takeOrCreateSubchannel(
any(EquivalentAddressGroup.class), any(Attributes.class));
+ doAnswer(new Answer<Subchannel>() {
+ @Override
+ public Subchannel answer(InvocationOnMock invocation) throws Throwable {
+ Subchannel subchannel = mock(Subchannel.class);
+ List<EquivalentAddressGroup> eagList =
+ (List<EquivalentAddressGroup>) invocation.getArguments()[0];
+ Attributes attrs = (Attributes) invocation.getArguments()[1];
+ when(subchannel.getAllAddresses()).thenReturn(eagList);
+ when(subchannel.getAttributes()).thenReturn(attrs);
+ mockSubchannels.add(subchannel);
+ unpooledSubchannelTracker.add(subchannel);
+ return subchannel;
+ }
+ }).when(helper).createSubchannel(any(List.class), any(Attributes.class));
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService());
when(helper.getChannelLogger()).thenReturn(channelLogger);
@@ -294,14 +312,15 @@
assertTrue(channel + " is terminated", channel.isTerminated());
}
// GRPCLB manages subchannels only through subchannelPool
- for (Subchannel subchannel: subchannelTracker) {
+ for (Subchannel subchannel : pooledSubchannelTracker) {
verify(subchannelPool).returnSubchannel(same(subchannel));
// Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if
// LoadBalancer has called it expectedly.
verify(subchannel, never()).shutdown();
}
- verify(helper, never())
- .createSubchannel(any(List.class), any(Attributes.class));
+ for (Subchannel subchannel : unpooledSubchannelTracker) {
+ verify(subchannel).shutdown();
+ }
// No timer should linger after shutdown
assertThat(fakeClock.getPendingTasks()).isEmpty();
} finally {
@@ -407,6 +426,65 @@
}
@Test
+ public void roundRobinPickerWithIdleEntry_noDrop() {
+ Subchannel subchannel = mock(Subchannel.class);
+ IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel);
+
+ RoundRobinPicker picker =
+ new RoundRobinPicker(Collections.<DropEntry>emptyList(), Collections.singletonList(entry));
+ PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+ verify(subchannel, never()).requestConnection();
+ assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
+ verify(subchannel).requestConnection();
+ }
+
+ @Test
+ public void roundRobinPickerWithIdleEntry_andDrop() {
+ GrpclbClientLoadRecorder loadRecorder =
+ new GrpclbClientLoadRecorder(fakeClock.getTimeProvider());
+ // 1 out of 2 requests are to be dropped
+ DropEntry d = new DropEntry(loadRecorder, "LBTOKEN0003");
+ List<DropEntry> dropList = Arrays.asList(null, d);
+
+ Subchannel subchannel = mock(Subchannel.class);
+ IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel);
+
+ RoundRobinPicker picker = new RoundRobinPicker(dropList, Collections.singletonList(entry));
+ PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+ verify(subchannel, never()).requestConnection();
+ assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
+ verify(subchannel).requestConnection();
+
+ assertThat(picker.pickSubchannel(args)).isSameAs(DROP_PICK_RESULT);
+
+ verify(subchannel).requestConnection();
+ assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
+ verify(subchannel, times(2)).requestConnection();
+ }
+
+ @Test
+ public void roundRobinPicker_requestConnection() {
+ // requestConnection() on RoundRobinPicker is only passed to IdleSubchannelEntry
+
+ Subchannel subchannel1 = mock(Subchannel.class);
+ Subchannel subchannel2 = mock(Subchannel.class);
+
+ RoundRobinPicker picker = new RoundRobinPicker(
+ Collections.<DropEntry>emptyList(),
+ Arrays.<RoundRobinEntry>asList(
+ new BackendEntry(subchannel1), new IdleSubchannelEntry(subchannel2),
+ new ErrorEntry(Status.UNAVAILABLE)));
+
+ verify(subchannel2, never()).requestConnection();
+
+ picker.requestConnection();
+ verify(subchannel2).requestConnection();
+ verify(subchannel1, never()).requestConnection();
+ }
+
+ @Test
public void loadReporting() {
Metadata headers = new Metadata();
PickSubchannelArgs args = mock(PickSubchannelArgs.class);
@@ -1591,6 +1669,297 @@
verify(helper, times(4)).refreshNameResolution();
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void grpclbWorking_pickFirstMode() throws Exception {
+ InOrder inOrder = inOrder(helper);
+
+ String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
+ List<EquivalentAddressGroup> grpclbResolutionList = createResolvedServerAddresses(true);
+ Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
+ LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+
+ deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+ assertEquals(1, fakeOobChannels.size());
+ ManagedChannel oobChannel = fakeOobChannels.poll();
+ verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
+ StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
+ assertEquals(1, lbRequestObservers.size());
+ StreamObserver<LoadBalanceRequest> lbRequestObserver = lbRequestObservers.poll();
+ verify(lbRequestObserver).onNext(
+ eq(LoadBalanceRequest.newBuilder().setInitialRequest(
+ InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
+ .build()));
+
+ // Simulate receiving LB response
+ List<ServerEntry> backends1 = Arrays.asList(
+ new ServerEntry("127.0.0.1", 2000, "token0001"),
+ new ServerEntry("127.0.0.1", 2010, "token0002"));
+ inOrder.verify(helper, never())
+ .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+ lbResponseObserver.onNext(buildInitialResponse());
+ lbResponseObserver.onNext(buildLbResponse(backends1));
+
+ inOrder.verify(helper).createSubchannel(
+ eq(Arrays.asList(
+ new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
+ new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
+ any(Attributes.class));
+
+ // Initially IDLE
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue();
+
+ // Only one subchannel is created
+ assertThat(mockSubchannels).hasSize(1);
+ Subchannel subchannel = mockSubchannels.poll();
+ assertThat(picker0.dropList).containsExactly(null, null);
+ assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel));
+
+ // PICK_FIRST doesn't eagerly connect
+ verify(subchannel, never()).requestConnection();
+
+ // CONNECTING
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
+
+ inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
+ RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
+ assertThat(picker1.dropList).containsExactly(null, null);
+ assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY);
+
+ // TRANSIENT_FAILURE
+ Status error = Status.UNAVAILABLE.withDescription("Simulated connection error");
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error));
+ inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+ RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue();
+ assertThat(picker2.dropList).containsExactly(null, null);
+ assertThat(picker2.pickList).containsExactly(new ErrorEntry(error));
+
+ // READY
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue();
+ assertThat(picker3.dropList).containsExactly(null, null);
+ assertThat(picker3.pickList).containsExactly(
+ new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
+
+
+ // New server list with drops
+ List<ServerEntry> backends2 = Arrays.asList(
+ new ServerEntry("127.0.0.1", 2000, "token0001"),
+ new ServerEntry("token0003"), // drop
+ new ServerEntry("127.0.0.1", 2020, "token0004"));
+ inOrder.verify(helper, never())
+ .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+ lbResponseObserver.onNext(buildLbResponse(backends2));
+
+ // new addresses will be updated to the existing subchannel
+ // createSubchannel() has ever been called only once
+ verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
+ assertThat(mockSubchannels).isEmpty();
+ inOrder.verify(helper).updateSubchannelAddresses(
+ same(subchannel),
+ eq(Arrays.asList(
+ new EquivalentAddressGroup(backends2.get(0).addr, eagAttrsWithToken("token0001")),
+ new EquivalentAddressGroup(backends2.get(2).addr,
+ eagAttrsWithToken("token0004")))));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue();
+ assertThat(picker4.dropList).containsExactly(
+ null, new DropEntry(getLoadRecorder(), "token0003"), null);
+ assertThat(picker4.pickList).containsExactly(
+ new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
+
+ // Subchannel goes IDLE, but PICK_FIRST will not try to reconnect
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ RoundRobinPicker picker5 = (RoundRobinPicker) pickerCaptor.getValue();
+ verify(subchannel, never()).requestConnection();
+
+ // ... until it's selected
+ PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+ PickResult pick = picker5.pickSubchannel(args);
+ assertThat(pick).isSameAs(PickResult.withNoResult());
+ verify(subchannel).requestConnection();
+
+ // ... or requested by application
+ picker5.requestConnection();
+ verify(subchannel, times(2)).requestConnection();
+
+ // PICK_FIRST doesn't use subchannelPool
+ verify(subchannelPool, never())
+ .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
+ verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void pickFirstMode_fallback() throws Exception {
+ InOrder inOrder = inOrder(helper);
+
+ String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
+
+ // Name resolver returns a mix of balancer and backend addresses
+ List<EquivalentAddressGroup> grpclbResolutionList =
+ createResolvedServerAddresses(false, true, false);
+ Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
+ LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+ deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+ // Attempted to connect to balancer
+ assertEquals(1, fakeOobChannels.size());
+ ManagedChannel oobChannel = fakeOobChannels.poll();
+ verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
+ StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
+ assertEquals(1, lbRequestObservers.size());
+
+ // Fallback timer expires with no response
+ fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+
+ // Entering fallback mode
+ inOrder.verify(helper).createSubchannel(
+ eq(Arrays.asList(grpclbResolutionList.get(0), grpclbResolutionList.get(2))),
+ any(Attributes.class));
+
+ assertThat(mockSubchannels).hasSize(1);
+ Subchannel subchannel = mockSubchannels.poll();
+
+ // Initially IDLE
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue();
+
+ // READY
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
+ assertThat(picker1.dropList).containsExactly(null, null);
+ assertThat(picker1.pickList).containsExactly(
+ new BackendEntry(subchannel, new TokenAttachingTracerFactory(null)));
+
+ assertThat(picker0.dropList).containsExactly(null, null);
+ assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel));
+
+
+ // Finally, an LB response, which brings us out of fallback
+ List<ServerEntry> backends1 = Arrays.asList(
+ new ServerEntry("127.0.0.1", 2000, "token0001"),
+ new ServerEntry("127.0.0.1", 2010, "token0002"));
+ inOrder.verify(helper, never())
+ .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+ lbResponseObserver.onNext(buildInitialResponse());
+ lbResponseObserver.onNext(buildLbResponse(backends1));
+
+ // new addresses will be updated to the existing subchannel
+ // createSubchannel() has ever been called only once
+ verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
+ assertThat(mockSubchannels).isEmpty();
+ inOrder.verify(helper).updateSubchannelAddresses(
+ same(subchannel),
+ eq(Arrays.asList(
+ new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
+ new EquivalentAddressGroup(backends1.get(1).addr,
+ eagAttrsWithToken("token0002")))));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue();
+ assertThat(picker2.dropList).containsExactly(null, null);
+ assertThat(picker2.pickList).containsExactly(
+ new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
+
+ // PICK_FIRST doesn't use subchannelPool
+ verify(subchannelPool, never())
+ .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
+ verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
+ }
+
+ @Test
+ public void switchMode() throws Exception {
+ InOrder inOrder = inOrder(helper);
+
+ String lbConfig = "{\"childPolicy\" : [ {\"round_robin\" : {}} ]}";
+ List<EquivalentAddressGroup> grpclbResolutionList = createResolvedServerAddresses(true);
+ Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
+ LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+
+ deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+ assertEquals(1, fakeOobChannels.size());
+ ManagedChannel oobChannel = fakeOobChannels.poll();
+ verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
+ StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
+ assertEquals(1, lbRequestObservers.size());
+ StreamObserver<LoadBalanceRequest> lbRequestObserver = lbRequestObservers.poll();
+ verify(lbRequestObserver).onNext(
+ eq(LoadBalanceRequest.newBuilder().setInitialRequest(
+ InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
+ .build()));
+
+ // Simulate receiving LB response
+ List<ServerEntry> backends1 = Arrays.asList(
+ new ServerEntry("127.0.0.1", 2000, "token0001"),
+ new ServerEntry("127.0.0.1", 2010, "token0002"));
+ inOrder.verify(helper, never())
+ .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+ lbResponseObserver.onNext(buildInitialResponse());
+ lbResponseObserver.onNext(buildLbResponse(backends1));
+
+ // ROUND_ROBIN: create one subchannel per server
+ verify(subchannelPool).takeOrCreateSubchannel(
+ eq(new EquivalentAddressGroup(backends1.get(0).addr, LB_BACKEND_ATTRS)),
+ any(Attributes.class));
+ verify(subchannelPool).takeOrCreateSubchannel(
+ eq(new EquivalentAddressGroup(backends1.get(1).addr, LB_BACKEND_ATTRS)),
+ any(Attributes.class));
+ inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+ assertEquals(2, mockSubchannels.size());
+ Subchannel subchannel1 = mockSubchannels.poll();
+ Subchannel subchannel2 = mockSubchannels.poll();
+ verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
+
+ // Switch to PICK_FIRST
+ lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
+ grpclbResolutionAttrs = Attributes.newBuilder().set(
+ LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+ deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+
+ // GrpclbState will be shutdown, and a new one will be created
+ assertThat(oobChannel.isShutdown()).isTrue();
+ verify(subchannelPool).returnSubchannel(same(subchannel1));
+ verify(subchannelPool).returnSubchannel(same(subchannel2));
+
+ // A new LB stream is created
+ assertEquals(1, fakeOobChannels.size());
+ oobChannel = fakeOobChannels.poll();
+ verify(mockLbService, times(2)).balanceLoad(lbResponseObserverCaptor.capture());
+ lbResponseObserver = lbResponseObserverCaptor.getValue();
+ assertEquals(1, lbRequestObservers.size());
+ lbRequestObserver = lbRequestObservers.poll();
+ verify(lbRequestObserver).onNext(
+ eq(LoadBalanceRequest.newBuilder().setInitialRequest(
+ InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
+ .build()));
+
+ // Simulate receiving LB response
+ inOrder.verify(helper, never())
+ .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+ lbResponseObserver.onNext(buildInitialResponse());
+ lbResponseObserver.onNext(buildLbResponse(backends1));
+
+ // PICK_FIRST Subchannel
+ inOrder.verify(helper).createSubchannel(
+ eq(Arrays.asList(
+ new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
+ new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
+ any(Attributes.class));
+
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ }
+
+ private static Attributes eagAttrsWithToken(String token) {
+ return LB_BACKEND_ATTRS.toBuilder().set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, token).build();
+ }
+
@Test
public void retrieveModeFromLbConfig_pickFirst() throws Exception {
String lbConfig = "{\"childPolicy\" : [{\"pick_first\" : {}}, {\"round_robin\" : {}}]}";
diff --git a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java
new file mode 100644
index 0000000..469372b
--- /dev/null
+++ b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.grpclb;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+import io.grpc.Attributes;
+import io.grpc.CallOptions;
+import io.grpc.ClientStreamTracer;
+import io.grpc.Metadata;
+import io.grpc.internal.GrpcAttributes;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link TokenAttachingTracerFactory}. */
+@RunWith(JUnit4.class)
+public class TokenAttachingTracerFactoryTest {
+ private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {};
+
+ private final ClientStreamTracer.Factory delegate = mock(
+ ClientStreamTracer.Factory.class,
+ delegatesTo(
+ new ClientStreamTracer.Factory() {
+ @Override
+ public ClientStreamTracer newClientStreamTracer(
+ ClientStreamTracer.StreamInfo info, Metadata headers) {
+ return fakeTracer;
+ }
+ }));
+
+ @Test
+ public void hasToken() {
+ TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate);
+ ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
+ @Override
+ public Attributes getTransportAttrs() {
+ Attributes eagAttrs = Attributes.newBuilder()
+ .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build();
+ return Attributes.newBuilder()
+ .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build();
+ }
+
+ @Override
+ public CallOptions getCallOptions() {
+ return CallOptions.DEFAULT;
+ }
+ };
+ Metadata headers = new Metadata();
+ // Preexisting token should be replaced
+ headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token");
+
+ ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
+ verify(delegate).newClientStreamTracer(same(info), same(headers));
+ assertThat(tracer).isSameAs(fakeTracer);
+ assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001");
+ }
+
+ @Test
+ public void noToken() {
+ TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate);
+ ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
+ @Override
+ public Attributes getTransportAttrs() {
+ return Attributes.newBuilder()
+ .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build();
+ }
+
+ @Override
+ public CallOptions getCallOptions() {
+ return CallOptions.DEFAULT;
+ }
+ };
+
+ Metadata headers = new Metadata();
+ // Preexisting token should be removed
+ headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token");
+
+ ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
+ verify(delegate).newClientStreamTracer(same(info), same(headers));
+ assertThat(tracer).isSameAs(fakeTracer);
+ assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull();
+ }
+
+ @Test
+ public void nullDelegate() {
+ TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null);
+ ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
+ @Override
+ public Attributes getTransportAttrs() {
+ return Attributes.newBuilder()
+ .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build();
+ }
+
+ @Override
+ public CallOptions getCallOptions() {
+ return CallOptions.DEFAULT;
+ }
+ };
+ Metadata headers = new Metadata();
+
+ ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
+ assertThat(tracer).isNotNull();
+ assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull();
+ }
+}