grpclb: support "pick_first" child policy (#5438)

The PICK_FIRST mode puts all backend addresses in a single Subchannel. There are a few points where it's different from the default ROUND_ROBIN mode:

1. PICK_FIRST doesn't eagerly connect to backends like ROUND_ROBIN does. Instead, it requests for connections when the Subchannel is picked.

2. PICK_FIRST adds tokens to the headers via a different code path (`TokenAttachingTracerFactory`) than ROUND_ROBIN

3. For simple implementation, when the mode is changed by service config when the LoadBalancer is working, we will shut down `GrpclbState` and starts a new one with the new mode. All connections will be closed during the transition. We don't expect this to happen in practice given the specific use case of PICK_FIRST.
diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java
index 87e2b6b..65f4832 100644
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java
+++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java
@@ -16,6 +16,8 @@
 
 package io.grpc.grpclb;
 
+import io.grpc.Attributes;
+import io.grpc.EquivalentAddressGroup;
 import io.grpc.ExperimentalApi;
 import io.grpc.Metadata;
 
@@ -32,5 +34,12 @@
   public static final Metadata.Key<String> TOKEN_METADATA_KEY =
       Metadata.Key.of("lb-token", Metadata.ASCII_STRING_MARSHALLER);
 
+  /**
+   * For passing LB tokens via the EAG attributes.
+   */
+  @EquivalentAddressGroup.Attr
+  static final Attributes.Key<String> TOKEN_ATTRIBUTE_KEY =
+      Attributes.Key.create("lb-token");
+
   private GrpclbConstants() { }
 }
diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java
index 09ed32e..b0719f0 100644
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java
+++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java
@@ -17,6 +17,7 @@
 package io.grpc.grpclb;
 
 import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
 import io.grpc.Attributes;
@@ -51,7 +52,11 @@
   private static final Logger logger = Logger.getLogger(GrpclbLoadBalancer.class.getName());
 
   private final Helper helper;
+  private final TimeProvider time;
   private final SubchannelPool subchannelPool;
+  private final BackoffPolicy.Provider backoffPolicyProvider;
+
+  private Mode mode = Mode.ROUND_ROBIN;
 
   // All mutable states in this class are mutated ONLY from Channel Executor
   @Nullable
@@ -63,12 +68,12 @@
       TimeProvider time,
       BackoffPolicy.Provider backoffPolicyProvider) {
     this.helper = checkNotNull(helper, "helper");
-    checkNotNull(time, "time provider");
-    checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
+    this.time = checkNotNull(time, "time provider");
+    this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
     this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
     this.subchannelPool.init(helper);
-    grpclbState =
-        new GrpclbState(helper, subchannelPool, time, backoffPolicyProvider);
+    recreateStates();
+    checkNotNull(grpclbState, "grpclbState");
   }
 
   @Override
@@ -97,7 +102,12 @@
     newBackendServers = Collections.unmodifiableList(newBackendServers);
     Map<String, Object> rawLbConfigValue = attributes.get(ATTR_LOAD_BALANCING_CONFIG);
     Mode newMode = retrieveModeFromLbConfig(rawLbConfigValue, helper.getChannelLogger());
-    grpclbState.handleAddresses(newLbAddressGroups, newBackendServers, newMode);
+    if (!mode.equals(newMode)) {
+      mode = newMode;
+      helper.getChannelLogger().log(ChannelLogLevel.INFO, "Mode: " + newMode);
+      recreateStates();
+    }
+    grpclbState.handleAddresses(newLbAddressGroups, newBackendServers);
   }
 
   @VisibleForTesting
@@ -141,6 +151,12 @@
     }
   }
 
+  private void recreateStates() {
+    resetStates();
+    checkState(grpclbState == null, "Should've been cleared");
+    grpclbState = new GrpclbState(mode, helper, subchannelPool, time, backoffPolicyProvider);
+  }
+
   @Override
   public void shutdown() {
     resetStates();
diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java
index e4203da..4d87f32 100644
--- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java
+++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java
@@ -138,8 +138,8 @@
 
   @Nullable
   private LbStream lbStream;
-  private Map<EquivalentAddressGroup, Subchannel> subchannels = Collections.emptyMap();
-  private Mode mode;
+  private Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Collections.emptyMap();
+  private final Mode mode;
 
   // Has the same size as the round-robin list from the balancer.
   // A drop entry from the round-robin list becomes a DropEntry here.
@@ -151,10 +151,12 @@
       new RoundRobinPicker(Collections.<DropEntry>emptyList(), Arrays.asList(BUFFER_ENTRY));
 
   GrpclbState(
+      Mode mode,
       Helper helper,
       SubchannelPool subchannelPool,
       TimeProvider time,
       BackoffPolicy.Provider backoffPolicyProvider) {
+    this.mode = checkNotNull(mode, "mode");
     this.helper = checkNotNull(helper, "helper");
     this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
     this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool");
@@ -169,7 +171,7 @@
     if (newState.getState() == SHUTDOWN || !subchannels.values().contains(subchannel)) {
       return;
     }
-    if (newState.getState() == IDLE) {
+    if (mode == Mode.ROUND_ROBIN && newState.getState() == IDLE) {
       subchannel.requestConnection();
     }
     subchannel.getAttributes().get(STATE_INFO).set(newState);
@@ -182,8 +184,7 @@
    * not yet connected.
    */
   void handleAddresses(
-      List<LbAddressGroup> newLbAddressGroups, List<EquivalentAddressGroup> newBackendServers,
-      Mode mode) {
+      List<LbAddressGroup> newLbAddressGroups, List<EquivalentAddressGroup> newBackendServers) {
     if (newLbAddressGroups.isEmpty()) {
       propagateError(Status.UNAVAILABLE.withDescription(
               "NameResolver returned no LB address while asking for GRPCLB"));
@@ -305,10 +306,20 @@
 
   void shutdown() {
     shutdownLbComm();
-    // We close the subchannels through subchannelPool instead of helper just for convenience of
-    // testing.
-    for (Subchannel subchannel : subchannels.values()) {
-      subchannelPool.returnSubchannel(subchannel);
+    switch (mode) {
+      case ROUND_ROBIN:
+        // We close the subchannels through subchannelPool instead of helper just for convenience of
+        // testing.
+        for (Subchannel subchannel : subchannels.values()) {
+          subchannelPool.returnSubchannel(subchannel);
+        }
+        break;
+      case PICK_FIRST:
+        checkState(subchannels.size() == 1, "Excessive Subchannels: %s", subchannels);
+        subchannels.values().iterator().next().shutdown();
+        break;
+      default:
+        throw new AssertionError("Missing case for " + mode);
     }
     subchannels = Collections.emptyMap();
     subchannelPool.clear();
@@ -341,45 +352,74 @@
       @Nullable GrpclbClientLoadRecorder loadRecorder) {
     logger.log(
         ChannelLogLevel.INFO, "Using RR list={0}, drop={1}", newBackendAddrList, newDropList);
-    HashMap<EquivalentAddressGroup, Subchannel> newSubchannelMap =
+    HashMap<List<EquivalentAddressGroup>, Subchannel> newSubchannelMap =
         new HashMap<>();
     List<BackendEntry> newBackendList = new ArrayList<>();
 
-    for (BackendAddressGroup backendAddr : newBackendAddrList) {
-      EquivalentAddressGroup eag = backendAddr.getAddresses();
-      Subchannel subchannel = newSubchannelMap.get(eag);
-      if (subchannel == null) {
-        subchannel = subchannels.get(eag);
-        if (subchannel == null) {
-          Attributes subchannelAttrs = Attributes.newBuilder()
-              .set(STATE_INFO,
-                  new AtomicReference<>(
-                      ConnectivityStateInfo.forNonError(IDLE)))
-              .build();
-          subchannel = subchannelPool.takeOrCreateSubchannel(eag, subchannelAttrs);
-          subchannel.requestConnection();
+    switch (mode) {
+      case ROUND_ROBIN:
+        for (BackendAddressGroup backendAddr : newBackendAddrList) {
+          EquivalentAddressGroup eag = backendAddr.getAddresses();
+          List<EquivalentAddressGroup> eagAsList = Collections.singletonList(eag);
+          Subchannel subchannel = newSubchannelMap.get(eagAsList);
+          if (subchannel == null) {
+            subchannel = subchannels.get(eagAsList);
+            if (subchannel == null) {
+              subchannel = subchannelPool.takeOrCreateSubchannel(eag, createSubchannelAttrs());
+              subchannel.requestConnection();
+            }
+            newSubchannelMap.put(eagAsList, subchannel);
+          }
+          BackendEntry entry;
+          // Only picks with tokens are reported to LoadRecorder
+          if (backendAddr.getToken() == null) {
+            entry = new BackendEntry(subchannel);
+          } else {
+            entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken());
+          }
+          newBackendList.add(entry);
         }
-        newSubchannelMap.put(eag, subchannel);
-      }
-      BackendEntry entry;
-      // Only picks with tokens are reported to LoadRecorder
-      if (backendAddr.getToken() == null) {
-        entry = new BackendEntry(subchannel);
-      } else {
-        entry = new BackendEntry(subchannel, loadRecorder, backendAddr.getToken());
-      }
-      newBackendList.add(entry);
+        // Close Subchannels whose addresses have been delisted
+        for (Entry<List<EquivalentAddressGroup>, Subchannel> entry : subchannels.entrySet()) {
+          List<EquivalentAddressGroup> eagList = entry.getKey();
+          if (!newSubchannelMap.containsKey(eagList)) {
+            subchannelPool.returnSubchannel(entry.getValue());
+          }
+        }
+        subchannels = Collections.unmodifiableMap(newSubchannelMap);
+        break;
+      case PICK_FIRST:
+        List<EquivalentAddressGroup> eagList = new ArrayList<>();
+        // Because for PICK_FIRST, we create a single Subchannel for all addresses, we have to
+        // attach the tokens to the EAG attributes and use TokenAttachingLoadRecorder to put them on
+        // headers.
+        //
+        // The PICK_FIRST code path doesn't cache Subchannels.
+        for (BackendAddressGroup bag : newBackendAddrList) {
+          EquivalentAddressGroup origEag = bag.getAddresses();
+          Attributes eagAttrs = origEag.getAttributes();
+          if (bag.getToken() != null) {
+            eagAttrs = eagAttrs.toBuilder()
+                .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, bag.getToken()).build();
+          }
+          eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs));
+        }
+        Subchannel subchannel;
+        if (subchannels.isEmpty()) {
+          subchannel = helper.createSubchannel(eagList, createSubchannelAttrs());
+        } else {
+          checkState(subchannels.size() == 1, "Unexpected Subchannel count: %s", subchannels);
+          subchannel = subchannels.values().iterator().next();
+          helper.updateSubchannelAddresses(subchannel, eagList);
+        }
+        subchannels = Collections.singletonMap(eagList, subchannel);
+        newBackendList.add(
+            new BackendEntry(subchannel, new TokenAttachingTracerFactory(loadRecorder)));
+        break;
+      default:
+        throw new AssertionError("Missing case for " + mode);
     }
 
-    // Close Subchannels whose addresses have been delisted
-    for (Entry<EquivalentAddressGroup, Subchannel> entry : subchannels.entrySet()) {
-      EquivalentAddressGroup eag = entry.getKey();
-      if (!newSubchannelMap.containsKey(eag)) {
-        subchannelPool.returnSubchannel(entry.getValue());
-      }
-    }
-
-    subchannels = Collections.unmodifiableMap(newSubchannelMap);
     dropList = Collections.unmodifiableList(newDropList);
     backendList = Collections.unmodifiableList(newBackendList);
   }
@@ -619,32 +659,67 @@
    * changed since the last picker created.
    */
   private void maybeUpdatePicker() {
-    List<RoundRobinEntry> pickList = new ArrayList<>(backendList.size());
-    Status error = null;
-    boolean hasIdle = false;
-    for (BackendEntry entry : backendList) {
-      Subchannel subchannel = entry.result.getSubchannel();
-      Attributes attrs = subchannel.getAttributes();
-      ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get();
-      if (stateInfo.getState() == READY) {
-        pickList.add(entry);
-      } else if (stateInfo.getState() == TRANSIENT_FAILURE) {
-        error = stateInfo.getStatus();
-      } else if (stateInfo.getState() == IDLE) {
-        hasIdle = true;
-      }
-    }
+    List<RoundRobinEntry> pickList;
     ConnectivityState state;
-    if (pickList.isEmpty()) {
-      if (error != null && !hasIdle) {
-        pickList.add(new ErrorEntry(error));
-        state = TRANSIENT_FAILURE;
-      } else {
-        pickList.add(BUFFER_ENTRY);
-        state = CONNECTING;
-      }
-    } else {
-      state = READY;
+    switch (mode) {
+      case ROUND_ROBIN:
+        pickList = new ArrayList<>(backendList.size());
+        Status error = null;
+        boolean hasIdle = false;
+        for (BackendEntry entry : backendList) {
+          Subchannel subchannel = entry.subchannel;
+          Attributes attrs = subchannel.getAttributes();
+          ConnectivityStateInfo stateInfo = attrs.get(STATE_INFO).get();
+          if (stateInfo.getState() == READY) {
+            pickList.add(entry);
+          } else if (stateInfo.getState() == TRANSIENT_FAILURE) {
+            error = stateInfo.getStatus();
+          } else if (stateInfo.getState() == IDLE) {
+            hasIdle = true;
+          }
+        }
+        if (pickList.isEmpty()) {
+          if (error != null && !hasIdle) {
+            pickList.add(new ErrorEntry(error));
+            state = TRANSIENT_FAILURE;
+          } else {
+            pickList.add(BUFFER_ENTRY);
+            state = CONNECTING;
+          }
+        } else {
+          state = READY;
+        }
+        break;
+      case PICK_FIRST:
+        if (backendList.isEmpty()) {
+          pickList = Collections.singletonList(BUFFER_ENTRY);
+          // Have not received server addresses
+          state = CONNECTING;
+        } else {
+          checkState(backendList.size() == 1, "Excessive backend entries: %s", backendList);
+          BackendEntry onlyEntry = backendList.get(0);
+          ConnectivityStateInfo stateInfo =
+              onlyEntry.subchannel.getAttributes().get(STATE_INFO).get();
+          state = stateInfo.getState();
+          switch (state) {
+            case READY:
+              pickList = Collections.<RoundRobinEntry>singletonList(onlyEntry);
+              break;
+            case TRANSIENT_FAILURE:
+              pickList =
+                  Collections.<RoundRobinEntry>singletonList(new ErrorEntry(stateInfo.getStatus()));
+              break;
+            case CONNECTING:
+              pickList = Collections.singletonList(BUFFER_ENTRY);
+              break;
+            default:
+              pickList = Collections.<RoundRobinEntry>singletonList(
+                  new IdleSubchannelEntry(onlyEntry.subchannel));
+          }
+        }
+        break;
+      default:
+        throw new AssertionError("Missing case for " + mode);
     }
     maybeUpdatePicker(state, new RoundRobinPicker(dropList, pickList));
   }
@@ -704,6 +779,14 @@
     return new EquivalentAddressGroup(addrs, attrs);
   }
 
+  private static Attributes createSubchannelAttrs() {
+    return Attributes.newBuilder()
+        .set(STATE_INFO,
+            new AtomicReference<>(
+                ConnectivityStateInfo.forNonError(IDLE)))
+        .build();
+  }
+
   @VisibleForTesting
   static final class DropEntry {
     private final GrpclbClientLoadRecorder loadRecorder;
@@ -740,34 +823,45 @@
     }
   }
 
-  private interface RoundRobinEntry {
+  @VisibleForTesting
+  interface RoundRobinEntry {
     PickResult picked(Metadata headers);
   }
 
   @VisibleForTesting
   static final class BackendEntry implements RoundRobinEntry {
+    final Subchannel subchannel;
     @VisibleForTesting
     final PickResult result;
     @Nullable
-    private final GrpclbClientLoadRecorder loadRecorder;
-    @Nullable
     private final String token;
 
     /**
-     * Creates a BackendEntry whose usage will be reported to load recorder.
+     * For ROUND_ROBIN: creates a BackendEntry whose usage will be reported to load recorder.
      */
     BackendEntry(Subchannel subchannel, GrpclbClientLoadRecorder loadRecorder, String token) {
-      this.result = PickResult.withSubchannel(subchannel, loadRecorder);
-      this.loadRecorder = checkNotNull(loadRecorder, "loadRecorder");
+      this.subchannel = checkNotNull(subchannel, "subchannel");
+      this.result =
+          PickResult.withSubchannel(subchannel, checkNotNull(loadRecorder, "loadRecorder"));
       this.token = checkNotNull(token, "token");
     }
 
     /**
-     * Creates a BackendEntry whose usage will not be reported.
+     * For ROUND_ROBIN/PICK_FIRST: creates a BackendEntry whose usage will not be reported.
      */
     BackendEntry(Subchannel subchannel) {
+      this.subchannel = checkNotNull(subchannel, "subchannel");
       this.result = PickResult.withSubchannel(subchannel);
-      this.loadRecorder = null;
+      this.token = null;
+    }
+
+    /**
+     * For PICK_FIRST: creates a BackendEntry that includes all addresses.
+     */
+    BackendEntry(Subchannel subchannel, TokenAttachingTracerFactory tracerFactory) {
+      this.subchannel = checkNotNull(subchannel, "subchannel");
+      this.result =
+          PickResult.withSubchannel(subchannel, checkNotNull(tracerFactory, "tracerFactory"));
       this.token = null;
     }
 
@@ -783,12 +877,12 @@
     @Override
     public String toString() {
       // This is printed in logs.  Only give out useful information.
-      return "[" + result.getSubchannel().getAllAddresses().toString() + "(" + token + ")]";
+      return "[" + subchannel.getAllAddresses().toString() + "(" + token + ")]";
     }
 
     @Override
     public int hashCode() {
-      return Objects.hashCode(loadRecorder, result, token);
+      return Objects.hashCode(result, token);
     }
 
     @Override
@@ -797,8 +891,42 @@
         return false;
       }
       BackendEntry that = (BackendEntry) other;
-      return Objects.equal(result, that.result) && Objects.equal(token, that.token)
-          && Objects.equal(loadRecorder, that.loadRecorder);
+      return Objects.equal(result, that.result) && Objects.equal(token, that.token);
+    }
+  }
+
+  @VisibleForTesting
+  static final class IdleSubchannelEntry implements RoundRobinEntry {
+    private final Subchannel subchannel;
+
+    IdleSubchannelEntry(Subchannel subchannel) {
+      this.subchannel = checkNotNull(subchannel, "subchannel");
+    }
+
+    @Override
+    public PickResult picked(Metadata headers) {
+      subchannel.requestConnection();
+      return PickResult.withNoResult();
+    }
+
+    @Override
+    public String toString() {
+      // This is printed in logs.  Only give out useful information.
+      return "(idle)[" + subchannel.getAllAddresses().toString() + "]";
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(subchannel);
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      if (!(other instanceof IdleSubchannelEntry)) {
+        return false;
+      }
+      IdleSubchannelEntry that = (IdleSubchannelEntry) other;
+      return Objects.equal(subchannel, that.subchannel);
     }
   }
 
@@ -860,8 +988,7 @@
         // First round-robin on dropList. If a drop entry is selected, request will be dropped.  If
         // a non-drop entry is selected, then round-robin on pickList.  This makes sure requests are
         // dropped at the same proportion as the drop entries appear on the round-robin list from
-        // the balancer, while only READY backends (that make up pickList) are selected for the
-        // non-drop cases.
+        // the balancer, while only backends from pickList are selected for the non-drop cases.
         if (!dropList.isEmpty()) {
           DropEntry drop = dropList.get(dropIndex);
           dropIndex++;
@@ -881,5 +1008,14 @@
         return pick.picked(args.getHeaders());
       }
     }
+
+    @Override
+    public void requestConnection() {
+      for (RoundRobinEntry entry : pickList) {
+        if (entry instanceof IdleSubchannelEntry) {
+          ((IdleSubchannelEntry) entry).subchannel.requestConnection();
+        }
+      }
+    }
   }
 }
diff --git a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java
new file mode 100644
index 0000000..03b9bdf
--- /dev/null
+++ b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.grpclb;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.common.base.Objects;
+import io.grpc.Attributes;
+import io.grpc.ClientStreamTracer;
+import io.grpc.Metadata;
+import io.grpc.internal.GrpcAttributes;
+import javax.annotation.Nullable;
+
+/**
+ * Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and
+ * attaches them to headers.  This is only used in the PICK_FIRST mode.
+ */
+final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory {
+  private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {};
+
+  @Nullable
+  private final ClientStreamTracer.Factory delegate;
+
+  TokenAttachingTracerFactory(@Nullable ClientStreamTracer.Factory delegate) {
+    this.delegate = delegate;
+  }
+
+  @Override
+  public ClientStreamTracer newClientStreamTracer(
+      ClientStreamTracer.StreamInfo info, Metadata headers) {
+    Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs");
+    Attributes eagAttrs =
+        checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs");
+    String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY);
+    headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY);
+    if (token != null) {
+      headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token);
+    }
+    if (delegate != null) {
+      return delegate.newClientStreamTracer(info, headers);
+    } else {
+      return NOOP_TRACER;
+    }
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(delegate);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof TokenAttachingTracerFactory)) {
+      return false;
+    }
+    return Objects.equal(delegate, ((TokenAttachingTracerFactory) other).delegate);
+  }
+}
diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java
index e4bb5f3..ee6e185 100644
--- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java
+++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java
@@ -56,6 +56,7 @@
 import io.grpc.ConnectivityState;
 import io.grpc.ConnectivityStateInfo;
 import io.grpc.EquivalentAddressGroup;
+import io.grpc.LoadBalancer;
 import io.grpc.LoadBalancer.Helper;
 import io.grpc.LoadBalancer.PickResult;
 import io.grpc.LoadBalancer.PickSubchannelArgs;
@@ -68,7 +69,9 @@
 import io.grpc.grpclb.GrpclbState.BackendEntry;
 import io.grpc.grpclb.GrpclbState.DropEntry;
 import io.grpc.grpclb.GrpclbState.ErrorEntry;
+import io.grpc.grpclb.GrpclbState.IdleSubchannelEntry;
 import io.grpc.grpclb.GrpclbState.Mode;
+import io.grpc.grpclb.GrpclbState.RoundRobinEntry;
 import io.grpc.grpclb.GrpclbState.RoundRobinPicker;
 import io.grpc.inprocess.InProcessChannelBuilder;
 import io.grpc.inprocess.InProcessServerBuilder;
@@ -166,7 +169,8 @@
       new LinkedList<>();
   private final LinkedList<Subchannel> mockSubchannels = new LinkedList<>();
   private final LinkedList<ManagedChannel> fakeOobChannels = new LinkedList<>();
-  private final ArrayList<Subchannel> subchannelTracker = new ArrayList<>();
+  private final ArrayList<Subchannel> pooledSubchannelTracker = new ArrayList<>();
+  private final ArrayList<Subchannel> unpooledSubchannelTracker = new ArrayList<>();
   private final ArrayList<ManagedChannel> oobChannelTracker = new ArrayList<>();
   private final ArrayList<String> failingLbAuthorities = new ArrayList<>();
   private final SynchronizationContext syncContext = new SynchronizationContext(
@@ -251,11 +255,25 @@
           when(subchannel.getAllAddresses()).thenReturn(Arrays.asList(eag));
           when(subchannel.getAttributes()).thenReturn(attrs);
           mockSubchannels.add(subchannel);
-          subchannelTracker.add(subchannel);
+          pooledSubchannelTracker.add(subchannel);
           return subchannel;
         }
       }).when(subchannelPool).takeOrCreateSubchannel(
           any(EquivalentAddressGroup.class), any(Attributes.class));
+    doAnswer(new Answer<Subchannel>() {
+        @Override
+        public Subchannel answer(InvocationOnMock invocation) throws Throwable {
+          Subchannel subchannel = mock(Subchannel.class);
+          List<EquivalentAddressGroup> eagList =
+              (List<EquivalentAddressGroup>) invocation.getArguments()[0];
+          Attributes attrs = (Attributes) invocation.getArguments()[1];
+          when(subchannel.getAllAddresses()).thenReturn(eagList);
+          when(subchannel.getAttributes()).thenReturn(attrs);
+          mockSubchannels.add(subchannel);
+          unpooledSubchannelTracker.add(subchannel);
+          return subchannel;
+        }
+      }).when(helper).createSubchannel(any(List.class), any(Attributes.class));
     when(helper.getSynchronizationContext()).thenReturn(syncContext);
     when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService());
     when(helper.getChannelLogger()).thenReturn(channelLogger);
@@ -294,14 +312,15 @@
         assertTrue(channel + " is terminated", channel.isTerminated());
       }
       // GRPCLB manages subchannels only through subchannelPool
-      for (Subchannel subchannel: subchannelTracker) {
+      for (Subchannel subchannel : pooledSubchannelTracker) {
         verify(subchannelPool).returnSubchannel(same(subchannel));
         // Our mock subchannelPool never calls Subchannel.shutdown(), thus we can tell if
         // LoadBalancer has called it expectedly.
         verify(subchannel, never()).shutdown();
       }
-      verify(helper, never())
-          .createSubchannel(any(List.class), any(Attributes.class));
+      for (Subchannel subchannel : unpooledSubchannelTracker) {
+        verify(subchannel).shutdown();
+      }
       // No timer should linger after shutdown
       assertThat(fakeClock.getPendingTasks()).isEmpty();
     } finally {
@@ -407,6 +426,65 @@
   }
 
   @Test
+  public void roundRobinPickerWithIdleEntry_noDrop() {
+    Subchannel subchannel = mock(Subchannel.class);
+    IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel);
+
+    RoundRobinPicker picker =
+        new RoundRobinPicker(Collections.<DropEntry>emptyList(), Collections.singletonList(entry));
+    PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+    verify(subchannel, never()).requestConnection();
+    assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
+    verify(subchannel).requestConnection();
+  }
+
+  @Test
+  public void roundRobinPickerWithIdleEntry_andDrop() {
+    GrpclbClientLoadRecorder loadRecorder =
+        new GrpclbClientLoadRecorder(fakeClock.getTimeProvider());
+    // 1 out of 2 requests are to be dropped
+    DropEntry d = new DropEntry(loadRecorder, "LBTOKEN0003");
+    List<DropEntry> dropList = Arrays.asList(null, d);
+
+    Subchannel subchannel = mock(Subchannel.class);
+    IdleSubchannelEntry entry = new IdleSubchannelEntry(subchannel);
+
+    RoundRobinPicker picker = new RoundRobinPicker(dropList, Collections.singletonList(entry));
+    PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+    verify(subchannel, never()).requestConnection();
+    assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
+    verify(subchannel).requestConnection();
+
+    assertThat(picker.pickSubchannel(args)).isSameAs(DROP_PICK_RESULT);
+
+    verify(subchannel).requestConnection();
+    assertThat(picker.pickSubchannel(args)).isSameAs(PickResult.withNoResult());
+    verify(subchannel, times(2)).requestConnection();
+  }
+
+  @Test
+  public void roundRobinPicker_requestConnection() {
+    // requestConnection() on RoundRobinPicker is only passed to IdleSubchannelEntry
+
+    Subchannel subchannel1 = mock(Subchannel.class);
+    Subchannel subchannel2 = mock(Subchannel.class);
+
+    RoundRobinPicker picker = new RoundRobinPicker(
+        Collections.<DropEntry>emptyList(),
+        Arrays.<RoundRobinEntry>asList(
+            new BackendEntry(subchannel1), new IdleSubchannelEntry(subchannel2),
+            new ErrorEntry(Status.UNAVAILABLE)));
+
+    verify(subchannel2, never()).requestConnection();
+
+    picker.requestConnection();
+    verify(subchannel2).requestConnection();
+    verify(subchannel1, never()).requestConnection();
+  }
+
+  @Test
   public void loadReporting() {
     Metadata headers = new Metadata();
     PickSubchannelArgs args = mock(PickSubchannelArgs.class);
@@ -1591,6 +1669,297 @@
     verify(helper, times(4)).refreshNameResolution();
   }
 
+  @SuppressWarnings("unchecked")
+  @Test
+  public void grpclbWorking_pickFirstMode() throws Exception {
+    InOrder inOrder = inOrder(helper);
+
+    String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
+    List<EquivalentAddressGroup> grpclbResolutionList = createResolvedServerAddresses(true);
+    Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
+        LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+
+    deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+    assertEquals(1, fakeOobChannels.size());
+    ManagedChannel oobChannel = fakeOobChannels.poll();
+    verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
+    StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
+    assertEquals(1, lbRequestObservers.size());
+    StreamObserver<LoadBalanceRequest> lbRequestObserver = lbRequestObservers.poll();
+    verify(lbRequestObserver).onNext(
+        eq(LoadBalanceRequest.newBuilder().setInitialRequest(
+                InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
+            .build()));
+
+    // Simulate receiving LB response
+    List<ServerEntry> backends1 = Arrays.asList(
+        new ServerEntry("127.0.0.1", 2000, "token0001"),
+        new ServerEntry("127.0.0.1", 2010, "token0002"));
+    inOrder.verify(helper, never())
+        .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+    lbResponseObserver.onNext(buildInitialResponse());
+    lbResponseObserver.onNext(buildLbResponse(backends1));
+
+    inOrder.verify(helper).createSubchannel(
+        eq(Arrays.asList(
+                new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
+                new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
+        any(Attributes.class));
+
+    // Initially IDLE
+    inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+    RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue();
+
+    // Only one subchannel is created
+    assertThat(mockSubchannels).hasSize(1);
+    Subchannel subchannel = mockSubchannels.poll();
+    assertThat(picker0.dropList).containsExactly(null, null);
+    assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel));
+
+    // PICK_FIRST doesn't eagerly connect
+    verify(subchannel, never()).requestConnection();
+
+    // CONNECTING
+    deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
+
+    inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
+    RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
+    assertThat(picker1.dropList).containsExactly(null, null);
+    assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY);
+
+    // TRANSIENT_FAILURE
+    Status error = Status.UNAVAILABLE.withDescription("Simulated connection error");
+    deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error));
+    inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+    RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue();
+    assertThat(picker2.dropList).containsExactly(null, null);
+    assertThat(picker2.pickList).containsExactly(new ErrorEntry(error));
+
+    // READY
+    deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+    inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+    RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue();
+    assertThat(picker3.dropList).containsExactly(null, null);
+    assertThat(picker3.pickList).containsExactly(
+        new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
+
+
+    // New server list with drops
+    List<ServerEntry> backends2 = Arrays.asList(
+        new ServerEntry("127.0.0.1", 2000, "token0001"),
+        new ServerEntry("token0003"),  // drop
+        new ServerEntry("127.0.0.1", 2020, "token0004"));
+    inOrder.verify(helper, never())
+        .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+    lbResponseObserver.onNext(buildLbResponse(backends2));
+
+    // new addresses will be updated to the existing subchannel
+    // createSubchannel() has ever been called only once
+    verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
+    assertThat(mockSubchannels).isEmpty();
+    inOrder.verify(helper).updateSubchannelAddresses(
+        same(subchannel),
+        eq(Arrays.asList(
+                new EquivalentAddressGroup(backends2.get(0).addr, eagAttrsWithToken("token0001")),
+                new EquivalentAddressGroup(backends2.get(2).addr,
+                    eagAttrsWithToken("token0004")))));
+    inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+    RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue();
+    assertThat(picker4.dropList).containsExactly(
+        null, new DropEntry(getLoadRecorder(), "token0003"), null);
+    assertThat(picker4.pickList).containsExactly(
+        new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
+
+    // Subchannel goes IDLE, but PICK_FIRST will not try to reconnect
+    deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
+    inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+    RoundRobinPicker picker5 = (RoundRobinPicker) pickerCaptor.getValue();
+    verify(subchannel, never()).requestConnection();
+
+    // ... until it's selected
+    PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+    PickResult pick = picker5.pickSubchannel(args);
+    assertThat(pick).isSameAs(PickResult.withNoResult());
+    verify(subchannel).requestConnection();
+
+    // ... or requested by application
+    picker5.requestConnection();
+    verify(subchannel, times(2)).requestConnection();
+
+    // PICK_FIRST doesn't use subchannelPool
+    verify(subchannelPool, never())
+        .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
+    verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void pickFirstMode_fallback() throws Exception {
+    InOrder inOrder = inOrder(helper);
+
+    String lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
+
+    // Name resolver returns a mix of balancer and backend addresses
+    List<EquivalentAddressGroup> grpclbResolutionList =
+        createResolvedServerAddresses(false, true, false);
+    Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
+        LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+    deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+    // Attempted to connect to balancer
+    assertEquals(1, fakeOobChannels.size());
+    ManagedChannel oobChannel = fakeOobChannels.poll();
+    verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
+    StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
+    assertEquals(1, lbRequestObservers.size());
+
+    // Fallback timer expires with no response
+    fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+
+    // Entering fallback mode
+    inOrder.verify(helper).createSubchannel(
+        eq(Arrays.asList(grpclbResolutionList.get(0), grpclbResolutionList.get(2))),
+        any(Attributes.class));
+
+    assertThat(mockSubchannels).hasSize(1);
+    Subchannel subchannel = mockSubchannels.poll();
+
+    // Initially IDLE
+    inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+    RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue();
+
+    // READY
+    deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+    inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+    RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue();
+    assertThat(picker1.dropList).containsExactly(null, null);
+    assertThat(picker1.pickList).containsExactly(
+        new BackendEntry(subchannel, new TokenAttachingTracerFactory(null)));
+
+    assertThat(picker0.dropList).containsExactly(null, null);
+    assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel));
+
+
+    // Finally, an LB response, which brings us out of fallback
+    List<ServerEntry> backends1 = Arrays.asList(
+        new ServerEntry("127.0.0.1", 2000, "token0001"),
+        new ServerEntry("127.0.0.1", 2010, "token0002"));
+    inOrder.verify(helper, never())
+        .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+    lbResponseObserver.onNext(buildInitialResponse());
+    lbResponseObserver.onNext(buildLbResponse(backends1));
+
+    // new addresses will be updated to the existing subchannel
+    // createSubchannel() has ever been called only once
+    verify(helper, times(1)).createSubchannel(any(List.class), any(Attributes.class));
+    assertThat(mockSubchannels).isEmpty();
+    inOrder.verify(helper).updateSubchannelAddresses(
+        same(subchannel),
+        eq(Arrays.asList(
+                new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
+                new EquivalentAddressGroup(backends1.get(1).addr,
+                    eagAttrsWithToken("token0002")))));
+    inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+    RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue();
+    assertThat(picker2.dropList).containsExactly(null, null);
+    assertThat(picker2.pickList).containsExactly(
+        new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder())));
+
+    // PICK_FIRST doesn't use subchannelPool
+    verify(subchannelPool, never())
+        .takeOrCreateSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
+    verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
+  }
+
+  @Test
+  public void switchMode() throws Exception {
+    InOrder inOrder = inOrder(helper);
+
+    String lbConfig = "{\"childPolicy\" : [ {\"round_robin\" : {}} ]}";
+    List<EquivalentAddressGroup> grpclbResolutionList = createResolvedServerAddresses(true);
+    Attributes grpclbResolutionAttrs = Attributes.newBuilder().set(
+        LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+
+    deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+    assertEquals(1, fakeOobChannels.size());
+    ManagedChannel oobChannel = fakeOobChannels.poll();
+    verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture());
+    StreamObserver<LoadBalanceResponse> lbResponseObserver = lbResponseObserverCaptor.getValue();
+    assertEquals(1, lbRequestObservers.size());
+    StreamObserver<LoadBalanceRequest> lbRequestObserver = lbRequestObservers.poll();
+    verify(lbRequestObserver).onNext(
+        eq(LoadBalanceRequest.newBuilder().setInitialRequest(
+                InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
+            .build()));
+
+    // Simulate receiving LB response
+    List<ServerEntry> backends1 = Arrays.asList(
+        new ServerEntry("127.0.0.1", 2000, "token0001"),
+        new ServerEntry("127.0.0.1", 2010, "token0002"));
+    inOrder.verify(helper, never())
+        .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+    lbResponseObserver.onNext(buildInitialResponse());
+    lbResponseObserver.onNext(buildLbResponse(backends1));
+
+    // ROUND_ROBIN: create one subchannel per server
+    verify(subchannelPool).takeOrCreateSubchannel(
+        eq(new EquivalentAddressGroup(backends1.get(0).addr, LB_BACKEND_ATTRS)),
+        any(Attributes.class));
+    verify(subchannelPool).takeOrCreateSubchannel(
+        eq(new EquivalentAddressGroup(backends1.get(1).addr, LB_BACKEND_ATTRS)),
+        any(Attributes.class));
+    inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+    assertEquals(2, mockSubchannels.size());
+    Subchannel subchannel1 = mockSubchannels.poll();
+    Subchannel subchannel2 = mockSubchannels.poll();
+    verify(subchannelPool, never()).returnSubchannel(any(Subchannel.class));
+
+    // Switch to PICK_FIRST
+    lbConfig = "{\"childPolicy\" : [ {\"pick_first\" : {}} ]}";
+    grpclbResolutionAttrs = Attributes.newBuilder().set(
+        LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, parseJsonObject(lbConfig)).build();
+    deliverResolvedAddresses(grpclbResolutionList, grpclbResolutionAttrs);
+
+
+    // GrpclbState will be shutdown, and a new one will be created
+    assertThat(oobChannel.isShutdown()).isTrue();
+    verify(subchannelPool).returnSubchannel(same(subchannel1));
+    verify(subchannelPool).returnSubchannel(same(subchannel2));
+
+    // A new LB stream is created
+    assertEquals(1, fakeOobChannels.size());
+    oobChannel = fakeOobChannels.poll();
+    verify(mockLbService, times(2)).balanceLoad(lbResponseObserverCaptor.capture());
+    lbResponseObserver = lbResponseObserverCaptor.getValue();
+    assertEquals(1, lbRequestObservers.size());
+    lbRequestObserver = lbRequestObservers.poll();
+    verify(lbRequestObserver).onNext(
+        eq(LoadBalanceRequest.newBuilder().setInitialRequest(
+                InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build())
+            .build()));
+
+    // Simulate receiving LB response
+    inOrder.verify(helper, never())
+        .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class));
+    lbResponseObserver.onNext(buildInitialResponse());
+    lbResponseObserver.onNext(buildLbResponse(backends1));
+
+    // PICK_FIRST Subchannel
+    inOrder.verify(helper).createSubchannel(
+        eq(Arrays.asList(
+                new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")),
+                new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002")))),
+        any(Attributes.class));
+    
+    inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+  }
+
+  private static Attributes eagAttrsWithToken(String token) {
+    return LB_BACKEND_ATTRS.toBuilder().set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, token).build();
+  }
+
   @Test
   public void retrieveModeFromLbConfig_pickFirst() throws Exception {
     String lbConfig = "{\"childPolicy\" : [{\"pick_first\" : {}}, {\"round_robin\" : {}}]}";
diff --git a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java
new file mode 100644
index 0000000..469372b
--- /dev/null
+++ b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.grpclb;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+import io.grpc.Attributes;
+import io.grpc.CallOptions;
+import io.grpc.ClientStreamTracer;
+import io.grpc.Metadata;
+import io.grpc.internal.GrpcAttributes;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link TokenAttachingTracerFactory}. */
+@RunWith(JUnit4.class)
+public class TokenAttachingTracerFactoryTest {
+  private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {};
+
+  private final ClientStreamTracer.Factory delegate = mock(
+      ClientStreamTracer.Factory.class,
+      delegatesTo(
+          new ClientStreamTracer.Factory() {
+            @Override
+            public ClientStreamTracer newClientStreamTracer(
+                ClientStreamTracer.StreamInfo info, Metadata headers) {
+              return fakeTracer;
+            }
+          }));
+
+  @Test
+  public void hasToken() {
+    TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate);
+    ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
+        @Override
+        public Attributes getTransportAttrs() {
+          Attributes eagAttrs = Attributes.newBuilder()
+              .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build();
+          return Attributes.newBuilder()
+              .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build();
+        }
+
+        @Override
+        public CallOptions getCallOptions() {
+          return CallOptions.DEFAULT;
+        }
+      };
+    Metadata headers = new Metadata();
+    // Preexisting token should be replaced
+    headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token");
+
+    ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
+    verify(delegate).newClientStreamTracer(same(info), same(headers));
+    assertThat(tracer).isSameAs(fakeTracer);
+    assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001");
+  }
+
+  @Test
+  public void noToken() {
+    TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate);
+    ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
+        @Override
+        public Attributes getTransportAttrs() {
+          return Attributes.newBuilder()
+              .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build();
+        }
+
+        @Override
+        public CallOptions getCallOptions() {
+          return CallOptions.DEFAULT;
+        }
+      };
+
+    Metadata headers = new Metadata();
+    // Preexisting token should be removed
+    headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token");
+
+    ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
+    verify(delegate).newClientStreamTracer(same(info), same(headers));
+    assertThat(tracer).isSameAs(fakeTracer);
+    assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull();
+  }
+
+  @Test
+  public void nullDelegate() {
+    TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null);
+    ClientStreamTracer.StreamInfo info = new ClientStreamTracer.StreamInfo() {
+        @Override
+        public Attributes getTransportAttrs() {
+          return Attributes.newBuilder()
+              .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build();
+        }
+
+        @Override
+        public CallOptions getCallOptions() {
+          return CallOptions.DEFAULT;
+        }
+      };
+    Metadata headers = new Metadata();
+
+    ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers);
+    assertThat(tracer).isNotNull();
+    assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull();
+  }
+}