api, core, services: make ProtoReflectionService interceptor compatible (#6967)

Eliminate the hack of InternalNotifyOnBuild mechanism for letting ProtoReflectionService get access to the Sever instance, which makes ProtoReflectionService incompatible with server interceptors. This change put the Server instance into the Context and let the ProtoReflectionService RPC obtain it in its RPC Context. Also enhanced ProtoReflectionService so that one service instance can be used across multiple servers.
diff --git a/api/src/main/java/io/grpc/InternalNotifyOnServerBuild.java b/api/src/main/java/io/grpc/InternalNotifyOnServerBuild.java
deleted file mode 100644
index b52acfa..0000000
--- a/api/src/main/java/io/grpc/InternalNotifyOnServerBuild.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright 2016 The gRPC Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.grpc;
-
-/**
- * Provides a callback method for a service to receive a reference to its server. The contract with
- * {@link ServerBuilder} is that this method will be called on all registered services implementing
- * the interface after build() has been called and before the {@link Server} instance is returned.
- */
-@Internal
-public interface InternalNotifyOnServerBuild {
-  /** Notifies the service that the server has been built. */
-  void notifyOnBuild(Server server);
-}
diff --git a/api/src/main/java/io/grpc/InternalServer.java b/api/src/main/java/io/grpc/InternalServer.java
new file mode 100644
index 0000000..8a28c91
--- /dev/null
+++ b/api/src/main/java/io/grpc/InternalServer.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2020 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+/**
+ * Internal accessor for getting the {@link Server} instance inside server RPC {@link Context}.
+ * This is intended for usage internal to the gRPC team. If you think you need to use
+ * this, contact the gRPC team first.
+ */
+@Internal
+public class InternalServer {
+  public static final Context.Key<Server> SERVER_CONTEXT_KEY = Server.SERVER_CONTEXT_KEY;
+
+  // Prevent instantiation.
+  private InternalServer() {
+  }
+}
diff --git a/api/src/main/java/io/grpc/Server.java b/api/src/main/java/io/grpc/Server.java
index fc98fe2..781455b 100644
--- a/api/src/main/java/io/grpc/Server.java
+++ b/api/src/main/java/io/grpc/Server.java
@@ -31,6 +31,14 @@
 public abstract class Server {
 
   /**
+   * Key for accessing the {@link Server} instance inside server RPC {@link Context}. It's
+   * unclear to us what users would need. If you think you need to use this, please file an
+   * issue for us to discuss a public API.
+   */
+  static final Context.Key<Server> SERVER_CONTEXT_KEY =
+      Context.key("io.grpc.Server");
+
+  /**
    * Bind and start the server.  After this call returns, clients may begin connecting to the
    * listening socket(s).
    *
diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
index 6928434..21a1bec 100644
--- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
+++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
@@ -29,7 +29,6 @@
 import io.grpc.DecompressorRegistry;
 import io.grpc.HandlerRegistry;
 import io.grpc.InternalChannelz;
-import io.grpc.InternalNotifyOnServerBuild;
 import io.grpc.Server;
 import io.grpc.ServerBuilder;
 import io.grpc.ServerInterceptor;
@@ -77,7 +76,6 @@
       new InternalHandlerRegistry.Builder();
   final List<ServerTransportFilter> transportFilters = new ArrayList<>();
   final List<ServerInterceptor> interceptors = new ArrayList<>();
-  private final List<InternalNotifyOnServerBuild> notifyOnBuildList = new ArrayList<>();
   private final List<ServerStreamTracer.Factory> streamTracerFactories = new ArrayList<>();
   HandlerRegistry fallbackRegistry = DEFAULT_FALLBACK_REGISTRY;
   ObjectPool<? extends Executor> executorPool = DEFAULT_EXECUTOR_POOL;
@@ -114,9 +112,6 @@
 
   @Override
   public final T addService(BindableService bindableService) {
-    if (bindableService instanceof InternalNotifyOnServerBuild) {
-      notifyOnBuildList.add((InternalNotifyOnServerBuild) bindableService);
-    }
     return addService(checkNotNull(bindableService, "bindableService").bindService());
   }
 
@@ -222,14 +217,7 @@
 
   @Override
   public final Server build() {
-    ServerImpl server = new ServerImpl(
-        this,
-        buildTransportServers(getTracerFactories()),
-        Context.ROOT);
-    for (InternalNotifyOnServerBuild notifyTarget : notifyOnBuildList) {
-      notifyTarget.notifyOnBuild(server);
-    }
-    return server;
+    return new ServerImpl(this, buildTransportServers(getTracerFactories()), Context.ROOT);
   }
 
   @VisibleForTesting
diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java
index 6e9cb9b..5db33a2 100644
--- a/core/src/main/java/io/grpc/internal/ServerImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerImpl.java
@@ -593,7 +593,10 @@
         Metadata headers, StatsTraceContext statsTraceCtx) {
       Long timeoutNanos = headers.get(TIMEOUT_KEY);
 
-      Context baseContext = statsTraceCtx.serverFilterContext(rootContext);
+      Context baseContext =
+          statsTraceCtx
+              .serverFilterContext(rootContext)
+              .withValue(io.grpc.InternalServer.SERVER_CONTEXT_KEY, ServerImpl.this);
 
       if (timeoutNanos == null) {
         return baseContext.withCancellation();
diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java
index 2fe2692..d76e714 100644
--- a/core/src/test/java/io/grpc/internal/ServerImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java
@@ -561,6 +561,7 @@
     Context callContext = callContextReference.get();
     assertNotNull(callContext);
     assertEquals("context added by tracer", SERVER_TRACER_ADDED_KEY.get(callContext));
+    assertEquals(server, io.grpc.InternalServer.SERVER_CONTEXT_KEY.get(callContext));
 
     streamListener.messagesAvailable(new SingleMessageProducer(STRING_MARSHALLER.stream(request)));
     assertEquals(1, executor.runDueTasks());
diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java
new file mode 100644
index 0000000..cae08a3
--- /dev/null
+++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java
@@ -0,0 +1,288 @@
+package io.grpc.reflection.testing;
+
+import static io.grpc.MethodDescriptor.generateFullMethodName;
+import static io.grpc.stub.ClientCalls.asyncBidiStreamingCall;
+import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
+import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
+import static io.grpc.stub.ClientCalls.asyncUnaryCall;
+import static io.grpc.stub.ClientCalls.blockingServerStreamingCall;
+import static io.grpc.stub.ClientCalls.blockingUnaryCall;
+import static io.grpc.stub.ClientCalls.futureUnaryCall;
+import static io.grpc.stub.ServerCalls.asyncBidiStreamingCall;
+import static io.grpc.stub.ServerCalls.asyncClientStreamingCall;
+import static io.grpc.stub.ServerCalls.asyncServerStreamingCall;
+import static io.grpc.stub.ServerCalls.asyncUnaryCall;
+import static io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall;
+import static io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall;
+
+/**
+ */
[email protected](
+    value = "by gRPC proto compiler",
+    comments = "Source: io/grpc/reflection/testing/reflection_test.proto")
+public final class AnotherReflectableServiceGrpc {
+
+  private AnotherReflectableServiceGrpc() {}
+
+  public static final String SERVICE_NAME = "grpc.reflection.testing.AnotherReflectableService";
+
+  // Static method descriptors that strictly reflect the proto.
+  private static volatile io.grpc.MethodDescriptor<io.grpc.reflection.testing.Request,
+      io.grpc.reflection.testing.Reply> getMethodMethod;
+
+  @io.grpc.stub.annotations.RpcMethod(
+      fullMethodName = SERVICE_NAME + '/' + "Method",
+      requestType = io.grpc.reflection.testing.Request.class,
+      responseType = io.grpc.reflection.testing.Reply.class,
+      methodType = io.grpc.MethodDescriptor.MethodType.UNARY)
+  public static io.grpc.MethodDescriptor<io.grpc.reflection.testing.Request,
+      io.grpc.reflection.testing.Reply> getMethodMethod() {
+    io.grpc.MethodDescriptor<io.grpc.reflection.testing.Request, io.grpc.reflection.testing.Reply> getMethodMethod;
+    if ((getMethodMethod = AnotherReflectableServiceGrpc.getMethodMethod) == null) {
+      synchronized (AnotherReflectableServiceGrpc.class) {
+        if ((getMethodMethod = AnotherReflectableServiceGrpc.getMethodMethod) == null) {
+          AnotherReflectableServiceGrpc.getMethodMethod = getMethodMethod =
+              io.grpc.MethodDescriptor.<io.grpc.reflection.testing.Request, io.grpc.reflection.testing.Reply>newBuilder()
+              .setType(io.grpc.MethodDescriptor.MethodType.UNARY)
+              .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Method"))
+              .setSampledToLocalTracing(true)
+              .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
+                  io.grpc.reflection.testing.Request.getDefaultInstance()))
+              .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
+                  io.grpc.reflection.testing.Reply.getDefaultInstance()))
+              .setSchemaDescriptor(new AnotherReflectableServiceMethodDescriptorSupplier("Method"))
+              .build();
+        }
+      }
+    }
+    return getMethodMethod;
+  }
+
+  /**
+   * Creates a new async stub that supports all call types for the service
+   */
+  public static AnotherReflectableServiceStub newStub(io.grpc.Channel channel) {
+    io.grpc.stub.AbstractStub.StubFactory<AnotherReflectableServiceStub> factory =
+      new io.grpc.stub.AbstractStub.StubFactory<AnotherReflectableServiceStub>() {
+        @java.lang.Override
+        public AnotherReflectableServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+          return new AnotherReflectableServiceStub(channel, callOptions);
+        }
+      };
+    return AnotherReflectableServiceStub.newStub(factory, channel);
+  }
+
+  /**
+   * Creates a new blocking-style stub that supports unary and streaming output calls on the service
+   */
+  public static AnotherReflectableServiceBlockingStub newBlockingStub(
+      io.grpc.Channel channel) {
+    io.grpc.stub.AbstractStub.StubFactory<AnotherReflectableServiceBlockingStub> factory =
+      new io.grpc.stub.AbstractStub.StubFactory<AnotherReflectableServiceBlockingStub>() {
+        @java.lang.Override
+        public AnotherReflectableServiceBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+          return new AnotherReflectableServiceBlockingStub(channel, callOptions);
+        }
+      };
+    return AnotherReflectableServiceBlockingStub.newStub(factory, channel);
+  }
+
+  /**
+   * Creates a new ListenableFuture-style stub that supports unary calls on the service
+   */
+  public static AnotherReflectableServiceFutureStub newFutureStub(
+      io.grpc.Channel channel) {
+    io.grpc.stub.AbstractStub.StubFactory<AnotherReflectableServiceFutureStub> factory =
+      new io.grpc.stub.AbstractStub.StubFactory<AnotherReflectableServiceFutureStub>() {
+        @java.lang.Override
+        public AnotherReflectableServiceFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+          return new AnotherReflectableServiceFutureStub(channel, callOptions);
+        }
+      };
+    return AnotherReflectableServiceFutureStub.newStub(factory, channel);
+  }
+
+  /**
+   */
+  public static abstract class AnotherReflectableServiceImplBase implements io.grpc.BindableService {
+
+    /**
+     */
+    public void method(io.grpc.reflection.testing.Request request,
+        io.grpc.stub.StreamObserver<io.grpc.reflection.testing.Reply> responseObserver) {
+      asyncUnimplementedUnaryCall(getMethodMethod(), responseObserver);
+    }
+
+    @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() {
+      return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor())
+          .addMethod(
+            getMethodMethod(),
+            asyncUnaryCall(
+              new MethodHandlers<
+                io.grpc.reflection.testing.Request,
+                io.grpc.reflection.testing.Reply>(
+                  this, METHODID_METHOD)))
+          .build();
+    }
+  }
+
+  /**
+   */
+  public static final class AnotherReflectableServiceStub extends io.grpc.stub.AbstractAsyncStub<AnotherReflectableServiceStub> {
+    private AnotherReflectableServiceStub(
+        io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+      super(channel, callOptions);
+    }
+
+    @java.lang.Override
+    protected AnotherReflectableServiceStub build(
+        io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+      return new AnotherReflectableServiceStub(channel, callOptions);
+    }
+
+    /**
+     */
+    public void method(io.grpc.reflection.testing.Request request,
+        io.grpc.stub.StreamObserver<io.grpc.reflection.testing.Reply> responseObserver) {
+      asyncUnaryCall(
+          getChannel().newCall(getMethodMethod(), getCallOptions()), request, responseObserver);
+    }
+  }
+
+  /**
+   */
+  public static final class AnotherReflectableServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub<AnotherReflectableServiceBlockingStub> {
+    private AnotherReflectableServiceBlockingStub(
+        io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+      super(channel, callOptions);
+    }
+
+    @java.lang.Override
+    protected AnotherReflectableServiceBlockingStub build(
+        io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+      return new AnotherReflectableServiceBlockingStub(channel, callOptions);
+    }
+
+    /**
+     */
+    public io.grpc.reflection.testing.Reply method(io.grpc.reflection.testing.Request request) {
+      return blockingUnaryCall(
+          getChannel(), getMethodMethod(), getCallOptions(), request);
+    }
+  }
+
+  /**
+   */
+  public static final class AnotherReflectableServiceFutureStub extends io.grpc.stub.AbstractFutureStub<AnotherReflectableServiceFutureStub> {
+    private AnotherReflectableServiceFutureStub(
+        io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+      super(channel, callOptions);
+    }
+
+    @java.lang.Override
+    protected AnotherReflectableServiceFutureStub build(
+        io.grpc.Channel channel, io.grpc.CallOptions callOptions) {
+      return new AnotherReflectableServiceFutureStub(channel, callOptions);
+    }
+
+    /**
+     */
+    public com.google.common.util.concurrent.ListenableFuture<io.grpc.reflection.testing.Reply> method(
+        io.grpc.reflection.testing.Request request) {
+      return futureUnaryCall(
+          getChannel().newCall(getMethodMethod(), getCallOptions()), request);
+    }
+  }
+
+  private static final int METHODID_METHOD = 0;
+
+  private static final class MethodHandlers<Req, Resp> implements
+      io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,
+      io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>,
+      io.grpc.stub.ServerCalls.ClientStreamingMethod<Req, Resp>,
+      io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> {
+    private final AnotherReflectableServiceImplBase serviceImpl;
+    private final int methodId;
+
+    MethodHandlers(AnotherReflectableServiceImplBase serviceImpl, int methodId) {
+      this.serviceImpl = serviceImpl;
+      this.methodId = methodId;
+    }
+
+    @java.lang.Override
+    @java.lang.SuppressWarnings("unchecked")
+    public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) {
+      switch (methodId) {
+        case METHODID_METHOD:
+          serviceImpl.method((io.grpc.reflection.testing.Request) request,
+              (io.grpc.stub.StreamObserver<io.grpc.reflection.testing.Reply>) responseObserver);
+          break;
+        default:
+          throw new AssertionError();
+      }
+    }
+
+    @java.lang.Override
+    @java.lang.SuppressWarnings("unchecked")
+    public io.grpc.stub.StreamObserver<Req> invoke(
+        io.grpc.stub.StreamObserver<Resp> responseObserver) {
+      switch (methodId) {
+        default:
+          throw new AssertionError();
+      }
+    }
+  }
+
+  private static abstract class AnotherReflectableServiceBaseDescriptorSupplier
+      implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier {
+    AnotherReflectableServiceBaseDescriptorSupplier() {}
+
+    @java.lang.Override
+    public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() {
+      return io.grpc.reflection.testing.ReflectionTestProto.getDescriptor();
+    }
+
+    @java.lang.Override
+    public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() {
+      return getFileDescriptor().findServiceByName("AnotherReflectableService");
+    }
+  }
+
+  private static final class AnotherReflectableServiceFileDescriptorSupplier
+      extends AnotherReflectableServiceBaseDescriptorSupplier {
+    AnotherReflectableServiceFileDescriptorSupplier() {}
+  }
+
+  private static final class AnotherReflectableServiceMethodDescriptorSupplier
+      extends AnotherReflectableServiceBaseDescriptorSupplier
+      implements io.grpc.protobuf.ProtoMethodDescriptorSupplier {
+    private final String methodName;
+
+    AnotherReflectableServiceMethodDescriptorSupplier(String methodName) {
+      this.methodName = methodName;
+    }
+
+    @java.lang.Override
+    public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() {
+      return getServiceDescriptor().findMethodByName(methodName);
+    }
+  }
+
+  private static volatile io.grpc.ServiceDescriptor serviceDescriptor;
+
+  public static io.grpc.ServiceDescriptor getServiceDescriptor() {
+    io.grpc.ServiceDescriptor result = serviceDescriptor;
+    if (result == null) {
+      synchronized (AnotherReflectableServiceGrpc.class) {
+        result = serviceDescriptor;
+        if (result == null) {
+          serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME)
+              .setSchemaDescriptor(new AnotherReflectableServiceFileDescriptorSupplier())
+              .addMethod(getMethodMethod())
+              .build();
+        }
+      }
+    }
+    return result;
+  }
+}
diff --git a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java
index beadb0f..fbf6bf7 100644
--- a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java
+++ b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java
@@ -26,7 +26,7 @@
 import com.google.protobuf.Descriptors.ServiceDescriptor;
 import io.grpc.BindableService;
 import io.grpc.ExperimentalApi;
-import io.grpc.InternalNotifyOnServerBuild;
+import io.grpc.InternalServer;
 import io.grpc.Server;
 import io.grpc.ServerServiceDefinition;
 import io.grpc.Status;
@@ -50,6 +50,7 @@
 import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
+import java.util.WeakHashMap;
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
 
@@ -61,41 +62,37 @@
  * extension.
  */
 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222")
-public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase
-    implements InternalNotifyOnServerBuild {
+public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase {
 
   private final Object lock = new Object();
 
   @GuardedBy("lock")
-  private ServerReflectionIndex serverReflectionIndex;
-
-  private Server server;
+  private final Map<Server, ServerReflectionIndex> serverReflectionIndexes = new WeakHashMap<>();
 
   private ProtoReflectionService() {}
 
+  /**
+   * Creates a instance of {@link ProtoReflectionService}.
+   */
   public static BindableService newInstance() {
     return new ProtoReflectionService();
   }
 
-  /** Receives a reference to the server at build time. */
-  @Override
-  public void notifyOnBuild(Server server) {
-    this.server = checkNotNull(server);
-  }
-
   /**
-   * Checks for updates to the server's mutable services and updates the index if any changes are
+   * Retrieves the index for services of the server that dispatches the current call. Computes
+   * one if not exist. The index is updated if any changes to the server's mutable services are
    * detected. A change is any addition or removal in the set of file descriptors attached to the
    * mutable services or a change in the service names.
-   *
-   * @return The (potentially updated) index.
    */
-  private ServerReflectionIndex updateIndexIfNecessary() {
+  private ServerReflectionIndex getRefreshedIndex() {
     synchronized (lock) {
-      if (serverReflectionIndex == null) {
-        serverReflectionIndex =
+      Server server = InternalServer.SERVER_CONTEXT_KEY.get();
+      ServerReflectionIndex index = serverReflectionIndexes.get(server);
+      if (index == null) {
+        index =
             new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices());
-        return serverReflectionIndex;
+        serverReflectionIndexes.put(server, index);
+        return index;
       }
 
       Set<FileDescriptor> serverFileDescriptors = new HashSet<>();
@@ -116,14 +113,15 @@
       // Replace the index if the underlying mutable services have changed. Check both the file
       // descriptors and the service names, because one file descriptor can define multiple
       // services.
-      FileDescriptorIndex mutableServicesIndex = serverReflectionIndex.getMutableServicesIndex();
+      FileDescriptorIndex mutableServicesIndex = index.getMutableServicesIndex();
       if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors)
           || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) {
-        serverReflectionIndex =
+        index =
             new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices);
+        serverReflectionIndexes.put(server, index);
       }
 
-      return serverReflectionIndex;
+      return index;
     }
   }
 
@@ -133,7 +131,7 @@
     final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver =
         (ServerCallStreamObserver<ServerReflectionResponse>) responseObserver;
     ProtoReflectionStreamObserver requestObserver =
-        new ProtoReflectionStreamObserver(updateIndexIfNecessary(), serverCallStreamObserver);
+        new ProtoReflectionStreamObserver(getRefreshedIndex(), serverCallStreamObserver);
     serverCallStreamObserver.setOnReadyHandler(requestObserver);
     serverCallStreamObserver.disableAutoInboundFlowControl();
     serverCallStreamObserver.request(1);
diff --git a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java
index 00cc42c..1e36f54 100644
--- a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java
+++ b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java
@@ -30,6 +30,7 @@
 import io.grpc.inprocess.InProcessServerBuilder;
 import io.grpc.internal.testing.StreamRecorder;
 import io.grpc.reflection.testing.AnotherDynamicServiceGrpc;
+import io.grpc.reflection.testing.AnotherReflectableServiceGrpc;
 import io.grpc.reflection.testing.DynamicReflectionTestDepthTwoProto;
 import io.grpc.reflection.testing.DynamicServiceGrpc;
 import io.grpc.reflection.testing.ReflectableServiceGrpc;
@@ -47,14 +48,17 @@
 import io.grpc.stub.ClientCallStreamObserver;
 import io.grpc.stub.ClientResponseObserver;
 import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
 import io.grpc.util.MutableHandlerRegistry;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
-import org.junit.After;
+import java.util.concurrent.ExecutionException;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -62,6 +66,9 @@
 /** Tests for {@link ProtoReflectionService}. */
 @RunWith(JUnit4.class)
 public class ProtoReflectionServiceTest {
+  @Rule
+  public GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
+
   private static final String TEST_HOST = "localhost";
   private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry();
   private BindableService reflectionService;
@@ -69,14 +76,12 @@
       new DynamicServiceGrpc.DynamicServiceImplBase() {}.bindService();
   private ServerServiceDefinition anotherDynamicService =
       new AnotherDynamicServiceGrpc.AnotherDynamicServiceImplBase() {}.bindService();
-  private Server server;
-  private ManagedChannel channel;
   private ServerReflectionGrpc.ServerReflectionStub stub;
 
   @Before
   public void setUp() throws Exception {
     reflectionService = ProtoReflectionService.newInstance();
-    server =
+    Server server =
         InProcessServerBuilder.forName("proto-reflection-test")
             .directExecutor()
             .addService(reflectionService)
@@ -84,20 +89,13 @@
             .fallbackHandlerRegistry(handlerRegistry)
             .build()
             .start();
-    channel = InProcessChannelBuilder.forName("proto-reflection-test").directExecutor().build();
+    grpcCleanupRule.register(server);
+    ManagedChannel channel =
+        grpcCleanupRule.register(
+            InProcessChannelBuilder.forName("proto-reflection-test").directExecutor().build());
     stub = ServerReflectionGrpc.newStub(channel);
   }
 
-  @After
-  public void tearDown() {
-    if (server != null) {
-      server.shutdownNow();
-    }
-    if (channel != null) {
-      channel.shutdownNow();
-    }
-  }
-
   @Test
   public void listServices() throws Exception {
     Set<ServiceResponse> originalServices =
@@ -525,6 +523,40 @@
   }
 
   @Test
+  public void sharedServiceBetweenServers()
+      throws IOException, ExecutionException, InterruptedException {
+    Server anotherServer = InProcessServerBuilder.forName("proto-reflection-test-2")
+        .directExecutor()
+        .addService(reflectionService)
+        .addService(new AnotherReflectableServiceGrpc.AnotherReflectableServiceImplBase() {})
+        .build()
+        .start();
+    grpcCleanupRule.register(anotherServer);
+    ManagedChannel anotherChannel = grpcCleanupRule.register(
+        InProcessChannelBuilder.forName("proto-reflection-test-2").directExecutor().build());
+    ServerReflectionGrpc.ServerReflectionStub stub2 = ServerReflectionGrpc.newStub(anotherChannel);
+
+    ServerReflectionRequest request =
+        ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build();
+    StreamRecorder<ServerReflectionResponse> responseObserver = StreamRecorder.create();
+    StreamObserver<ServerReflectionRequest> requestObserver =
+        stub2.serverReflectionInfo(responseObserver);
+    requestObserver.onNext(request);
+    requestObserver.onCompleted();
+    List<ServiceResponse> response =
+        responseObserver.firstValue().get().getListServicesResponse().getServiceList();
+    assertEquals(new HashSet<>(
+        Arrays.asList(
+            ServiceResponse.newBuilder()
+                .setName("grpc.reflection.v1alpha.ServerReflection")
+                .build(),
+            ServiceResponse.newBuilder()
+                .setName("grpc.reflection.testing.AnotherReflectableService")
+                .build())),
+        new HashSet<>(response));
+  }
+
+  @Test
   public void flowControl() throws Exception {
     FlowControlClientResponseObserver clientResponseObserver =
         new FlowControlClientResponseObserver();
diff --git a/services/src/test/proto/io/grpc/reflection/testing/reflection_test.proto b/services/src/test/proto/io/grpc/reflection/testing/reflection_test.proto
index 12f3969..3d0cd02 100644
--- a/services/src/test/proto/io/grpc/reflection/testing/reflection_test.proto
+++ b/services/src/test/proto/io/grpc/reflection/testing/reflection_test.proto
@@ -32,3 +32,7 @@
 service ReflectableService {
   rpc Method (Request) returns (Reply) {}
 }
+
+service AnotherReflectableService {
+  rpc Method (Request) returns (Reply) {}
+}