| #include "rebatching_queue.h" |
| #include "caffe2/utils/smart_tensor_printer.h" |
| |
| namespace caffe2 { |
| |
| namespace { |
| |
| // This concat function always creates a new leading dimension along which the |
| // per-row input tensors are concatenated. |
| void concat( |
| CPUContext& context, |
| const std::vector<std::vector<TensorCPU>>& inputs, |
| const std::vector<TensorCPU*>& outputs) { |
| CAFFE_ENFORCE(!inputs.empty()); |
| |
| const auto& inputZero = inputs[0]; |
| const auto numTensors = inputZero.size(); |
| const auto numRows = inputs.size(); |
| |
| // Precompute the output sizes to avoid resizing |
| std::vector<std::vector<int64_t>> outputDims(numTensors); |
| |
| for (size_t i = 0; i < numTensors; ++i) { |
| SmartTensorPrinter::PrintTensor(inputZero.at(i)); |
| outputDims[i] = inputZero.at(i).sizes().vec(); |
| outputDims[i].insert(outputDims[i].begin(), numRows); |
| } |
| |
| // Resize the outputs to their final sizes and record the destination pointers |
| std::vector<void*> destinations(numTensors); |
| for (size_t i = 0; i < numTensors; ++i) { |
| outputs[i]->Resize(outputDims[i]); |
| destinations[i] = outputs[i]->raw_mutable_data(inputZero[i].meta()); |
| } |
| |
| for (size_t i = 0; i < numRows; ++i) { |
| CAFFE_ENFORCE_EQ(inputs[i].size(), numTensors); |
| |
| for (size_t j = 0; j < numTensors; ++j) { |
| const auto& input = inputs[i][j]; |
| |
| CAFFE_ENFORCE(inputZero[j].dtype() == input.dtype()); |
| CAFFE_ENFORCE_EQ(inputZero[j].itemsize(), input.itemsize()); |
| CAFFE_ENFORCE_EQ(inputZero[j].dim(), input.dim()); |
| for (int k = 0; k < input.dim(); ++k) { |
| CAFFE_ENFORCE_EQ(input.sizes()[k], inputZero[j].size(k)); |
| } |
| |
| // Skip empty tensors |
| if (input.numel() == 0) { |
| continue; |
| } |
| |
| context.CopyItemsToCPU( |
| input.dtype(), |
| input.numel(), |
| input.raw_data() /* src */, |
| destinations[j] /* dst */ |
| ); |
| |
| destinations[j] = |
| (char*)destinations[j] + input.numel() * input.itemsize(); |
| } |
| } |
| } |
| |
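| // Splits each batched input along its first dimension into per-row tensors. |
| // The result is indexed as [row][blob] and is the inverse of concat above. |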
| std::vector<std::vector<TensorCPU>> split( |
| CPUContext& context, |
| const std::vector<const TensorCPU*>& inputs) { |
| CAFFE_ENFORCE(!inputs.empty()); |
| |
| const auto outputSize = inputs[0]->sizes().at(0); |
| std::vector<std::vector<TensorCPU>> outputs(outputSize); |
| |
| for (const auto* inputPtr : inputs) { |
| CAFFE_ENFORCE(inputPtr); |
| |
| const auto& input = *inputPtr; |
| const auto innerSize = input.size_from_dim(1); |
| const auto itemSize = input.dtype().itemsize(); |
| |
| auto outputDims = input.sizes().vec(); |
| CAFFE_ENFORCE(!outputDims.empty()); |
| outputDims.erase(outputDims.begin()); |
| CAFFE_ENFORCE_EQ(input.sizes().at(0), outputSize); |
| |
| for (int i = 0; i < outputSize; ++i) { |
| outputs[i].push_back(Tensor(outputDims, CPU)); |
| context.CopyItemsToCPU( |
| input.dtype(), |
| innerSize, |
| (char*)input.raw_data() + i * innerSize * itemSize /* src */, |
| outputs[i].back().raw_mutable_data(input.dtype()) /* dst */); |
| } |
| } |
| |
| return outputs; |
| } |
| } // anonymous namespace |
| |
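| // Usage sketch (illustrative values only): |
| //   RebatchingQueue queue(/*capacity=*/64, /*numBlobs=*/2); |
| //   queue.enqueueMany(context, inputs);  // producer: one batched tensor per blob |
| //   queue.dequeue(context, /*numElements=*/16, outputs);  // consumer: re-batched rows |
| //   queue.close();  // unblocks any waiting producers and consumers |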
| RebatchingQueue::RebatchingQueue(size_t capacity, size_t numBlobs) |
| : capacity_(capacity), numBlobs_(numBlobs), queue_(capacity) {} |
| |
| RebatchingQueue::~RebatchingQueue() { |
| close(); |
| } |
| |
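| // head_ and tail_ are monotonically increasing counters; the ring-buffer slot |
| // is obtained by taking them modulo capacity(). There is something to read |
| // whenever the reader (tail_) lags behind the writer (head_). |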
| bool RebatchingQueue::canRead() const { |
| return tail_ < head_; |
| } |
| |
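| // Blocks until numElements rows have been collected (or the queue is closed), |
| // then concatenates them into outputs along a new leading dimension. Returns |
| // false only if the queue was closed before anything could be dequeued. |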
| bool RebatchingQueue::dequeue( |
| CPUContext& context, |
| size_t numElements, |
| const std::vector<TensorCPU*>& outputs) { |
| std::vector<std::vector<TensorCPU>> results; |
| results.reserve(numElements); |
| |
| for (;;) { |
| if (results.size() == numElements) { |
| break; |
| } |
| |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| |
| cvEmpty_.wait(lock, [this] { return canRead() || isClosed_; }); |
| |
| // We only want to stop reading if the queue is empty and closed |
| if (!canRead() && isClosed_) { |
| break; |
| } |
| |
| do { |
| results.push_back(std::move(queue_[tail_++ % capacity()])); |
| } while (canRead() && results.size() < numElements); |
| } |
| |
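| // When a single element was requested, at most one slot was freed, so waking |
| // a single writer is enough; otherwise wake them all. |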
| if (numElements == 1) { |
| cvOverflow_.notify_one(); |
| } else { |
| cvOverflow_.notify_all(); |
| } |
| } |
| |
| if (results.empty()) { |
| return false; |
| } |
| |
| concat(context, results, outputs); |
| |
| return true; |
| } |
| |
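| // There is room to write while fewer than capacity() elements are in flight, |
| // i.e. while head_ - tail_ < capacity(). |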
| bool RebatchingQueue::canWrite() const { |
| return tail_ + capacity() > head_; |
| } |
| |
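| // Enqueues a single row by cloning each input tensor; the CPUContext |
| // parameter is unused here. |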
| bool RebatchingQueue::enqueueOne( |
| CPUContext& /*context*/, |
| const std::vector<const TensorCPU*>& inputs) { |
| std::vector<std::vector<TensorCPU>> splittedInputs; |
| splittedInputs.emplace_back(); |
| auto& tensorVector = splittedInputs.back(); |
| tensorVector.reserve(inputs.size()); |
| for (const auto* tensorPtr : inputs) { |
| tensorVector.push_back(tensorPtr->Clone()); |
| } |
| |
| return enqueue(std::move(splittedInputs)); |
| } |
| |
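| // Splits the batched inputs along their first dimension and enqueues one row |
| // per slice. |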
| bool RebatchingQueue::enqueueMany( |
| CPUContext& context, |
| const std::vector<const TensorCPU*>& inputs) { |
| CAFFE_ENFORCE_EQ(numBlobs_, inputs.size()); |
| |
| std::vector<std::vector<TensorCPU>> splittedInputs = split(context, inputs); |
| return enqueue(std::move(splittedInputs)); |
| } |
| |
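| // Pushes the prepared rows into the ring buffer, blocking while the queue is |
| // full. Returns false if the queue is closed before all rows are enqueued. |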
| bool RebatchingQueue::enqueue( |
| std::vector<std::vector<TensorCPU>> splittedInputs) { |
| size_t idx = 0; |
| for (;;) { |
| if (idx >= splittedInputs.size()) { |
| break; |
| } |
| |
| { |
| std::unique_lock<std::mutex> lock(mutex_); |
| |
| cvOverflow_.wait(lock, [this] { return canWrite() || isClosed_; }); |
| |
| if (isClosed_) { |
| // Reaching this point means the entire batch has not been enqueued yet; |
| // being closed in the middle of enqueueing is treated as a failure. |
| return false; |
| } |
| |
| do { |
| queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]); |
| } while (canWrite() && idx < splittedInputs.size()); |
| } |
| |
| cvEmpty_.notify_all(); |
| } |
| |
| return true; |
| } |
| |
| size_t RebatchingQueue::capacity() const { |
| return capacity_; |
| } |
| |
| size_t RebatchingQueue::numBlobs() const { |
| return numBlobs_; |
| } |
| |
| bool RebatchingQueue::isClosed() const { |
| std::lock_guard<std::mutex> g(mutex_); |
| return isClosed_; |
| } |
| |
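| // Marks the queue as closed and wakes all blocked readers and writers so they |
| // can observe the closed state. |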
| void RebatchingQueue::close() { |
| { |
| std::lock_guard<std::mutex> g(mutex_); |
| isClosed_ = true; |
| } |
| |
| cvEmpty_.notify_all(); |
| cvOverflow_.notify_all(); |
| } |
| } // namespace caffe2 |