blob: 68e4ec377d622707ea59d193b967524f730d40bc [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Based on https://github.com/daijifeng001/caffe-rfcn/blob/r-fcn/src/caffe/layers/psroi_pooling_layer.cu
//
// ------------------------------------------------------------------
// R-FCN
// Copyright (c) 2016 Microsoft
// Licensed under The MIT License [see r-fcn/LICENSE for details]
// Written by Yi Li
// ------------------------------------------------------------------
//
// COPYRIGHT
//
// All contributions by the University of California:
// Copyright (c) 2014, 2015, The Regents of the University of California
// (Regents)
// All rights reserved.
//
// All other contributions:
// Copyright (c) 2014, 2015, the respective contributors
// All rights reserved.
//
// Caffe uses a shared copyright model: each contributor holds copyright over
// their contributions to Caffe. The project versioning records all such
// contribution and copyright details. If a contributor wants to further mark
// their specific copyright on a particular contribution, they should indicate
// their copyright solely in the commit message of the change when it is
// committed.
//
// LICENSE
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// CONTRIBUTION AGREEMENT
//
// By contributing to the BVLC/caffe repository through pull-request, comment,
// or otherwise, the contributor releases their content to the
// license and copyright terms herein.
#include <cfloat>
#include "caffe2/core/context_gpu.h"
#include "modules/detectron/ps_roi_pool_op.h"
namespace caffe2 {
namespace {
template <typename T>
inline __device__ T gpu_atomic_add(const T val, T* address);
template <>
inline __device__
float gpu_atomic_add(const float val, float* address) {
return atomicAdd(address, val);
}
template <typename T>
__global__ void PSRoIPoolForward(
const int nthreads,
const T* bottom_data,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* bottom_rois,
const int output_dim,
const int group_size,
T* top_data,
int* mapping_channel) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const T* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
T roi_start_w = static_cast<T>(
roundf(offset_bottom_rois[1])) * spatial_scale;
T roi_start_h = static_cast<T>(
roundf(offset_bottom_rois[2])) * spatial_scale;
T roi_end_w = static_cast<T>(
roundf(offset_bottom_rois[3]) + 1.) * spatial_scale;
T roi_end_h = static_cast<T>(
roundf(offset_bottom_rois[4]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_width = c10::cuda::compat::max(roi_end_w - roi_start_w, static_cast<T>(0.1)); // avoid 0
T roi_height = c10::cuda::compat::max(roi_end_h - roi_start_h, static_cast<T>(0.1));
// Compute w and h at bottom
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
// Add roi offsets and clip to input boundaries
int hstart = floor(
static_cast<T>(ph) * bin_size_h + roi_start_h);
int wstart = floor(
static_cast<T>(pw)* bin_size_w + roi_start_w);
int hend = ceil(
static_cast<T>(ph + 1) * bin_size_h + roi_start_h);
int wend = ceil(
static_cast<T>(pw + 1) * bin_size_w + roi_start_w);
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0),width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int gw = pw;
int gh = ph;
int c = (ctop * group_size + gh) * group_size + gw;
const T* offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;
T out_sum = 0;
for (int h = hstart; h < hend; ++h){
for (int w = wstart; w < wend; ++w){
int bottom_index = h*width + w;
out_sum += offset_bottom_data[bottom_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
top_data[index] = is_empty ? 0. : out_sum / bin_area;
mapping_channel[index] = c;
}
}
template <typename T>
__global__ void PSRoIPoolBackward(
const int nthreads,
const T* top_diff,
const int* mapping_channel,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int output_dim,
T* bottom_diff,
const T* bottom_rois) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const T* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
T roi_start_w = static_cast<T>(
roundf(offset_bottom_rois[1])) * spatial_scale;
T roi_start_h = static_cast<T>(
roundf(offset_bottom_rois[2])) * spatial_scale;
T roi_end_w = static_cast<T>(
roundf(offset_bottom_rois[3]) + 1.) * spatial_scale;
T roi_end_h = static_cast<T>(
roundf(offset_bottom_rois[4]) + 1.) * spatial_scale;
// Force too small ROIs to be 1x1
T roi_width = c10::cuda::compat::max(roi_end_w - roi_start_w, static_cast<T>(0.1)); //avoid 0
T roi_height = c10::cuda::compat::max(roi_end_h - roi_start_h, static_cast<T>(0.1));
// Compute w and h at bottom
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int hstart = floor(
static_cast<T>(ph)* bin_size_h + roi_start_h);
int wstart = floor(
static_cast<T>(pw)* bin_size_w + roi_start_w);
int hend = ceil(
static_cast<T>(ph + 1) * bin_size_h + roi_start_h);
int wend = ceil(
static_cast<T>(pw + 1) * bin_size_w + roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Compute c at bottom
int c = mapping_channel[index];
T* offset_bottom_diff =
bottom_diff + (roi_batch_ind * channels + c) * height * width;
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? 0. : top_diff[index] / bin_area;
for (int h = hstart; h < hend; ++h){
for (int w = wstart; w < wend; ++w){
int bottom_index = h * width + w;
gpu_atomic_add(diff_val, offset_bottom_diff + bottom_index);
}
}
}
}
} // namespace
template<>
bool PSRoIPoolOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
auto* Y = Output(0, {R.dim32(0), output_dim_, pooled_height_, pooled_width_}, at::dtype<float>()); // PSRoI pooled data
auto* A = Output(1, Y->sizes(), at::dtype<int>()); // mapping_channel
int output_size = Y->numel();
PSRoIPoolForward<float><<<CAFFE_GET_BLOCKS(output_size),
CAFFE_CUDA_NUM_THREADS,
0, context_.cuda_stream()>>>(
output_size, X.data<float>(), spatial_scale_, X.dim32(1), X.dim32(2),
X.dim32(3), pooled_height_, pooled_width_, R.data<float>(), output_dim_,
group_size_, Y->mutable_data<float>(), A->mutable_data<int>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
return true;
}
template<>
bool PSRoIPoolGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
auto& A = Input(2); // mapping channels
auto& dY = Input(3); // Gradient of net w.r.t. output of "forward" op
// (aka "gradOutput")
auto* dX = Output(0, X.sizes(), at::dtype<float>()); // Gradient of net w.r.t. input to "forward" op
// (aka "gradInput")
// Must zero-out dX before accumulating gradients
math::Set<float, CUDAContext>(
dX->size(), 0.f, dX->mutable_data<float>(), &context_);
PSRoIPoolBackward<float><<<CAFFE_GET_BLOCKS(dY.size()),
CAFFE_CUDA_NUM_THREADS,
0, context_.cuda_stream()>>>(
dY.size(), dY.data<float>(), A.data<int>(), R.dim32(0), spatial_scale_,
X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_,
output_dim_, dX->mutable_data<float>(), R.data<float>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
return true;
}
REGISTER_CUDA_OPERATOR(PSRoIPool,
PSRoIPoolOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(PSRoIPoolGradient,
PSRoIPoolGradientOp<float, CUDAContext>);
} // namespace caffe2