Commit 506bd05

make ATen/native/cuda/NLLLoss2d.cu data_ptr-correct (pytorch#99179)
make ATen/native/cuda/NLLLoss2d.cu data_ptr-correct

Test Plan: Rely on CI.

Pull Request resolved: pytorch#99179
Approved by: https://github.com/ezyang
mikey dagitses authored and pytorchmergebot committed Apr 15, 2023
1 parent e9201ab commit 506bd05
Showing 1 changed file with 22 additions and 22 deletions.
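
For context, the patch applies PyTorch's const-correctness convention for raw tensor pointers: data a kernel only reads is obtained with Tensor::const_data_ptr<T>() and received as a const T* parameter, while data it writes is obtained with Tensor::mutable_data_ptr<T>(). The sketch below shows that split on a toy kernel; the kernel, wrapper, and tensor names are illustrative only and are not part of this patch.

#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Toy kernel: `out` is written, so it stays a plain pointer; `in` is only
// read, so it is declared const, matching the signatures changed in this diff.
template <typename scalar_t>
__global__ void scale_kernel(
    scalar_t* out,
    const scalar_t* in,
    scalar_t factor,
    int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * factor;
  }
}

// Host-side launch: the read-only tensor goes through const_data_ptr and the
// written tensor through mutable_data_ptr, instead of unqualified data_ptr.
void launch_scale(const at::Tensor& in, at::Tensor& out, float factor) {
  const int64_t n = in.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  scale_kernel<float><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      out.mutable_data_ptr<float>(),
      in.const_data_ptr<float>(),
      factor,
      n);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

The NLLLoss2d change below is this same substitution applied to the existing kernels, plus const qualifiers on the corresponding kernel parameters.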
aten/src/ATen/native/cuda/NLLLoss2d.cu: 22 additions & 22 deletions

@@ -36,8 +36,8 @@ inline Tensor optional_contiguous(const Tensor& source) {
 // Returns the address of the first element of a tensor
 // or nullptr if the tensor is undefined.
 template <typename scalar_t>
-inline scalar_t* optional_data(const Tensor& source) {
-  return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
+inline const scalar_t* optional_data(const Tensor& source) {
+  return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
 }
 
 using at::cuda::detail::CUDA_NUM_THREADS;
@@ -51,7 +51,7 @@ __global__ void nll_loss2d_forward_no_reduce_kernel(
     PackedTensorAccessor64<scalar_t, 4> input,
     PackedTensorAccessor64<int64_t, 3> target,
     PackedTensorAccessor64<scalar_t, 3> output,
-    scalar_t* weight,
+    const scalar_t* weight,
     int64_t ignore_index
 ) {
   int64_t batch_size = input.size(0);
@@ -79,9 +79,9 @@ C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 __global__ void nll_loss2d_forward_kernel(
     scalar_t* output,
     scalar_t* total_weight,
-    scalar_t* input,
-    int64_t* target,
-    scalar_t* weight,
+    const scalar_t* input,
+    const int64_t* target,
+    const scalar_t* weight,
     int n_classes,
     int map_nelem,
     int blocks_per_sample,
@@ -125,7 +125,7 @@ template <typename scalar_t>
 C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 __global__ void nll_loss2d_forward_size_average_kernel(
     scalar_t* output,
-    scalar_t* total_weight
+    const scalar_t* total_weight
 ) {
   *output /= *total_weight;
 }
@@ -137,7 +137,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel(
     PackedTensorAccessor64<int64_t, 3> target,
     PackedTensorAccessor64<scalar_t, 3> grad_output,
     PackedTensorAccessor64<scalar_t, 4> grad_input,
-    scalar_t* weight,
+    const scalar_t* weight,
     int64_t ignore_index
 ) {
   int64_t batch_size = target.size(0);
@@ -162,10 +162,10 @@ template <typename scalar_t>
 C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 __global__ void nll_loss2d_backward_kernel(
     scalar_t* grad_input,
-    scalar_t* grad_output,
-    int64_t* target,
-    scalar_t* weights,
-    scalar_t* total_weight,
+    const scalar_t* grad_output,
+    const int64_t* target,
+    const scalar_t* weights,
+    const scalar_t* total_weight,
     bool size_average,
     int n_classes,
     int map_nelem,
@@ -323,10 +323,10 @@ void nll_loss2d_forward_out_cuda_template(
              CUDA_NUM_THREADS,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
-              output.data_ptr<scalar_t>(),
-              total_weight.data_ptr<scalar_t>(),
-              input_.data_ptr<scalar_t>(),
-              target_.data_ptr<int64_t>(),
+              output.mutable_data_ptr<scalar_t>(),
+              total_weight.mutable_data_ptr<scalar_t>(),
+              input_.const_data_ptr<scalar_t>(),
+              target_.const_data_ptr<int64_t>(),
               optional_data<scalar_t>(weight_),
               input_.size(1),
               input_.size(2) * input_.size(3),
@@ -337,8 +337,8 @@
     if (reduction == at::Reduction::Mean) {
       nll_loss2d_forward_size_average_kernel<scalar_t>
          <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
-             output.data_ptr<scalar_t>(),
-             total_weight.data_ptr<scalar_t>());
+             output.mutable_data_ptr<scalar_t>(),
+             total_weight.const_data_ptr<scalar_t>());
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
   });
@@ -441,11 +441,11 @@ void nll_loss2d_backward_out_cuda_template(
              CUDA_NUM_THREADS,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
-              grad_input.data_ptr<scalar_t>(),
-              grad_output.data_ptr<scalar_t>(),
-              target_.data_ptr<int64_t>(),
+              grad_input.mutable_data_ptr<scalar_t>(),
+              grad_output.const_data_ptr<scalar_t>(),
+              target_.const_data_ptr<int64_t>(),
               optional_data<scalar_t>(weight_),
-              total_weight.data_ptr<scalar_t>(),
+              total_weight.const_data_ptr<scalar_t>(),
               reduction == at::Reduction::Mean,
              input.size(1),
              map_nelem,
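Two details worth noting from the hunks above. First, total_weight is taken via mutable_data_ptr in the launch of nll_loss2d_forward_kernel, which writes it, but via const_data_ptr in the launch of nll_loss2d_forward_size_average_kernel, which only reads it for the final division. Second, optional_data now returns const scalar_t*, so an undefined (absent) weight tensor still maps to nullptr while a defined one is read through const_data_ptr. The kernels rely on a nullptr-guarded weight lookup; a small sketch of that pattern is below, where the helper name class_weight is illustrative rather than taken from the file.

// Illustrative device helper: a per-class weight that may be absent.
// When the caller passed no weight tensor, optional_data returned nullptr
// and every class is implicitly weighted 1.
template <typename scalar_t>
__device__ inline scalar_t class_weight(const scalar_t* weight, int64_t cls) {
  return weight != nullptr ? weight[cls] : static_cast<scalar_t>(1);
}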
