blob: f972a8f6a34b2720485dc1b407b5d1040994d61d [file] [log] [blame]
#include <gtest/gtest.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/runtime/operator.h>
namespace torch {
namespace distributed {
namespace rpc {
using torch::distributed::autograd::DistAutogradContainer;
using torch::distributed::autograd::DistAutogradContext;
DistAutogradContainer* getDistAutogradContainer();
// Base fixture for end-to-end distributed autograd + RPC tests.
//
// SetUp() starts a server-side TCPStore, asks the concrete subclass to create
// `rpcAgent` via buildRpcAgent(), then registers, configures and starts that
// agent. TearDown() joins and shuts the agent down again.
//
// The static members numIters, numWorkers and serverAddress are only declared
// here; their definitions live elsewhere (presumably in each concrete test's
// translation unit -- verify).
class TestE2EBase : public ::testing::Test {
 protected:
  void SetUp() override {
    // Setup distributed autograd.
    autogradContainer = getDistAutogradContainer();

    // Setup server store.
    c10d::TCPStoreOptions opts{
        /* port */ 0,
        /* isServer */ true,
        numWorkers,
        /* waitWorkers */ true,
        /* timeout */ std::chrono::seconds(10)};

    store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts);

    buildRpcAgent();

    rpcAgentPostProcessing();
  }

  // Registers `rpcAgent` as the current agent, installs a type resolver and
  // starts the agent. Requires buildRpcAgent() to have set `rpcAgent`.
  void rpcAgentPostProcessing() {
    RpcAgent::setCurrentRpcAgent(rpcAgent);

    // The resolver is stored inside the agent and outlives this call, so it
    // deliberately captures nothing.
    std::shared_ptr<TypeResolver> typeResolver =
        std::make_shared<TypeResolver>([](const c10::QualifiedName& qn) {
          // For Dict that is used for device map: any qualified name that
          // contains "Dict" resolves to Dict[str, str].
          auto pos = qn.name().find("Dict");
          if (pos != std::string::npos) {
            return c10::StrongTypePtr(
                nullptr,
                c10::DictType::create(
                    c10::StringType::get(), c10::StringType::get()));
          }
          // Everything else is resolved as a Tensor type.
          return c10::StrongTypePtr(
              nullptr, c10::TensorType::create(at::Tensor()));
        });
    rpcAgent->setTypeResolver(typeResolver);
    rpcAgent->start();
  }

  void TearDown() override {
    // gtest still calls TearDown() when SetUp() fails. If the agent was never
    // built there is nothing to stop or unregister, and dereferencing the null
    // pointer would hide the original SetUp() failure.
    if (!rpcAgent) {
      return;
    }
    rpcAgent->join();
    rpcAgent->shutdown();
    RpcAgent::setCurrentRpcAgent(nullptr);
  }

  // Creates a locally owned RRef whose value is produced by running
  // `op(t1, t2, 1)` on the worker named "worker", and returns it.
  // The value is filled in asynchronously by the callback attached to the
  // RPC future.
  c10::intrusive_ptr<OwnerRRef> createRemoteRRef(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    auto& ctx = RRefContext::getInstance();
    auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1));
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);

    // The same RRef id is passed as both the RRef id and the fork id.
    ScriptRemoteCall scriptRemoteCall(
        op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId());
    auto jitFuture = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptRemoteCall).toMessage(),
        false);

    ownerRRef->registerOwnerCreationFuture(jitFuture);

    // Builtin operators do not return py::object, and hence do not require
    // the GIL for destructing the potentially deleted OwnerRRef.
    jitFuture->addCallback(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) {
          callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId);
        });
    return ownerRRef;
  }

  // Synchronously runs `op(t1, t2, alpha=1)` on the worker named "worker"
  // through an autograd-recorded RPC and returns the resulting tensor.
  // Throws if the RPC future completes with an error.
  at::Tensor remoteAdd(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1});

    // Send the RPC and return result.
    auto response = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptCall).toMessage());
    response->waitAndThrow();

    MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP;
    auto wrappedResponse = deserializeResponse(
        std::move(*response->value().toCustomClass<Message>()), messageType);
    return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor();
  }

  // Implemented by concrete fixtures. Must assign a non-null `rpcAgent`,
  // because SetUp() dereferences it right afterwards.
  virtual void buildRpcAgent() = 0;

  // Creates a new distributed autograd context on construction and releases
  // it (by context id) on destruction. Copy and move are deleted so a single
  // context can never be released twice.
  class AutogradContextGuard {
   public:
    explicit AutogradContextGuard()
        : context(DistAutogradContainer::getInstance().newContext()) {}

    AutogradContextGuard(const AutogradContextGuard&) = delete;
    AutogradContextGuard& operator=(const AutogradContextGuard&) = delete;
    AutogradContextGuard(AutogradContextGuard&&) = delete;
    AutogradContextGuard& operator=(AutogradContextGuard&&) = delete;

    ~AutogradContextGuard() {
      DistAutogradContainer::getInstance().releaseContext(context->contextId());
    }

   private:
    std::shared_ptr<DistAutogradContext> context;
  };

  // Runs `numIters` iterations. Each iteration runs inside its own autograd
  // context and does the following:
  //  - six chained remote aten::add calls;
  //  - one remote add materialized through an owner RRef;
  //  - a distributed backward pass over the sum of the final result.
  void runTrainingLoop() {
    auto options = at::TensorOptions().requires_grad(true);
    auto t1 = torch::ones({3, 3}, options);
    auto t2 = torch::ones({3, 3}, options);

    c10::OperatorName full_name("aten::add", "Tensor");
    auto matchedOp = torch::jit::findOperatorFor(full_name);
    ASSERT_TRUE(matchedOp);

    for (size_t i = 0; i < numIters; i++) {
      // Create the autograd context guard.
      AutogradContextGuard guard;

      // Multiple RPCs within one autograd context for the forward pass.
      auto result = remoteAdd(t1, t2, matchedOp);
      for (size_t j = 0; j < 5; j++) {
        result = remoteAdd(t1, result, matchedOp);
      }

      auto rref = createRemoteRRef(t1, result, matchedOp);
      result = rref->getValue().toTensor();

      // Run backward pass now.
      autograd::DistEngine::getInstance().execute(
          DistAutogradContainer::currentContextId(),
          {torch::sum(result)},
          /* retainGraph */ false);
    }
  }

  // Set in SetUp(); never freed by this fixture.
  DistAutogradContainer* autogradContainer = nullptr;
  // Built by the subclass in buildRpcAgent().
  std::shared_ptr<RpcAgent> rpcAgent;
  static const size_t numIters;
  static const size_t numWorkers;
  // Server-side store created in SetUp().
  c10::intrusive_ptr<c10d::Store> store;
  static const char* serverAddress;
};
} // namespace rpc
} // namespace distributed
} // namespace torch