blob: a01b5b7b65d046a64d451c623d13d23927a8985f [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "caffe2/core/common.h"
#include "caffe2/core/event.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/observers/operator_attaching_net_observer.h"
namespace caffe2 {
/**
 * This observer displays a description of each operator executed in a network.
 * This includes input and output tensors (name, size, type), arguments, and
 * execution time. This can be used to analyze different performance
 * characteristics.
 * NOTE: Currently this observer only supports synchronous computation.
 **/
class ProfileObserver;
/// Mixin holding profiling state for observers that derive from it: a Timer
/// instance plus a recorded start time and run time, both starting at 0.
/// Every data member is protected, so only derived classes can touch them.
class ProfileCounter {
 public:
  explicit ProfileCounter() = default;

 protected:
  Timer timer_;
  float start_time_{0.0f};
  float run_time_{0.0f};
};
/// Per-operator profiling observer. It inherits timing state from
/// ProfileCounter and is always tied to a net-level ProfileObserver.
/// Start()/Stop()/Dump()/rnnCopy() are defined out of line.
class TORCH_API ProfileOperatorObserver final
    : public ProfileCounter,
      public ObserverBase<OperatorBase> {
 public:
  // An operator observer must always be created with its net observer.
  explicit ProfileOperatorObserver(OperatorBase* subject) = delete;

  /// @brief Observes `subject` on behalf of `netObserver`.
  /// @param subject     operator to observe; may be null, in which case the
  ///                    net position stays OperatorBase::kNoNetPositionSet.
  /// @param netObserver net-level observer; stored as a raw pointer and never
  ///                    deleted by this class.
  explicit ProfileOperatorObserver(
      OperatorBase* subject,
      ProfileObserver* netObserver)
      : ObserverBase<OperatorBase>(subject), netObserver_(netObserver) {
    if (subject) {
      net_position_ = subject->net_position();
    }
  }

  /// @brief Variant taking the net position and RNN order explicitly
  /// (presumably for RNN step copies made via rnnCopy(), where the position
  /// cannot be read from the operator — TODO confirm against the .cc).
  explicit ProfileOperatorObserver(
      OperatorBase* subject,
      ProfileObserver* netObserver,
      int net_position,
      int rnn_order)
      : ProfileOperatorObserver(subject, netObserver) {
    net_position_ = net_position;
    rnn_order_ = rnn_order;
  }

  std::unique_ptr<ObserverBase<OperatorBase>> rnnCopy(
      OperatorBase* subject,
      int rnn_order) const override;

  void Dump() const;

  /// @brief Identifier for the observed operator: "<net_position>", or
  /// "<net_position>-<rnn_order>" when an RNN order has been set.
  virtual std::string getId() const {
    std::string id = std::to_string(net_position_);
    if (rnn_order_ != OperatorBase::kNoNetPositionSet) {
      id += '-';
      id += std::to_string(rnn_order_);
    }
    return id;
  }

 protected:
  ProfileObserver* netObserver_;
  // Stored here because the position is not visible in the RNN Executor.
  // Default-initialized so getId() never reads an indeterminate value when
  // the observer was constructed with a null subject.
  int net_position_ = OperatorBase::kNoNetPositionSet;
  int rnn_order_ = OperatorBase::kNoNetPositionSet;

 private:
  void Start() override;
  void Stop() override;
};
/// Net-level profiling observer. Its constructor hands `subject` and `this`
/// to OperatorAttachingNetObserver<ProfileOperatorObserver, ProfileObserver>,
/// which presumably attaches a ProfileOperatorObserver to each operator of the
/// net (TODO confirm in operator_attaching_net_observer.h). This class keeps no
/// timing state of its own.
class TORCH_API ProfileObserver final : public OperatorAttachingNetObserver<
                                            ProfileOperatorObserver,
                                            ProfileObserver> {
 public:
  // NOTE: `this` is passed to the base before ProfileObserver is fully
  // constructed; receivers should only store the pointer at that point.
  explicit ProfileObserver(NetBase* subject)
      : OperatorAttachingNetObserver<ProfileOperatorObserver, ProfileObserver>(
            subject,
            this) {}

  // Intentionally empty: there is no net-level measurement to start or stop.
  void Start() override {}
  void Stop() override {}
};
} // namespace caffe2