diff --git a/aten/src/ATen/native/cuda/NLLLoss2d.cu b/aten/src/ATen/native/cuda/NLLLoss2d.cu index ba98f18427d65..8e7199815517e 100644 --- a/aten/src/ATen/native/cuda/NLLLoss2d.cu +++ b/aten/src/ATen/native/cuda/NLLLoss2d.cu @@ -36,8 +36,8 @@ inline Tensor optional_contiguous(const Tensor& source) { // Returns the address of the first element of a tensor // or nullptr if the tensor is undefined. template -inline scalar_t* optional_data(const Tensor& source) { - return source.defined() ? source.data_ptr() : nullptr; +inline const scalar_t* optional_data(const Tensor& source) { + return source.defined() ? source.const_data_ptr() : nullptr; } using at::cuda::detail::CUDA_NUM_THREADS; @@ -51,7 +51,7 @@ __global__ void nll_loss2d_forward_no_reduce_kernel( PackedTensorAccessor64 input, PackedTensorAccessor64 target, PackedTensorAccessor64 output, - scalar_t* weight, + const scalar_t* weight, int64_t ignore_index ) { int64_t batch_size = input.size(0); @@ -79,9 +79,9 @@ C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) __global__ void nll_loss2d_forward_kernel( scalar_t* output, scalar_t* total_weight, - scalar_t* input, - int64_t* target, - scalar_t* weight, + const scalar_t* input, + const int64_t* target, + const scalar_t* weight, int n_classes, int map_nelem, int blocks_per_sample, @@ -125,7 +125,7 @@ template C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) __global__ void nll_loss2d_forward_size_average_kernel( scalar_t* output, - scalar_t* total_weight + const scalar_t* total_weight ) { *output /= *total_weight; } @@ -137,7 +137,7 @@ __global__ void nll_loss2d_backward_no_reduce_kernel( PackedTensorAccessor64 target, PackedTensorAccessor64 grad_output, PackedTensorAccessor64 grad_input, - scalar_t* weight, + const scalar_t* weight, int64_t ignore_index ) { int64_t batch_size = target.size(0); @@ -162,10 +162,10 @@ template C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) __global__ void nll_loss2d_backward_kernel( scalar_t* grad_input, - scalar_t* grad_output, - int64_t* target, - scalar_t* weights, - scalar_t* total_weight, + const scalar_t* grad_output, + const int64_t* target, + const scalar_t* weights, + const scalar_t* total_weight, bool size_average, int n_classes, int map_nelem, @@ -323,10 +323,10 @@ void nll_loss2d_forward_out_cuda_template( CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( - output.data_ptr(), - total_weight.data_ptr(), - input_.data_ptr(), - target_.data_ptr(), + output.mutable_data_ptr(), + total_weight.mutable_data_ptr(), + input_.const_data_ptr(), + target_.const_data_ptr(), optional_data(weight_), input_.size(1), input_.size(2) * input_.size(3), @@ -337,8 +337,8 @@ void nll_loss2d_forward_out_cuda_template( if (reduction == at::Reduction::Mean) { nll_loss2d_forward_size_average_kernel <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - output.data_ptr(), - total_weight.data_ptr()); + output.mutable_data_ptr(), + total_weight.const_data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); @@ -441,11 +441,11 @@ void nll_loss2d_backward_out_cuda_template( CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( - grad_input.data_ptr(), - grad_output.data_ptr(), - target_.data_ptr(), + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + target_.const_data_ptr(), optional_data(weight_), - total_weight.data_ptr(), + total_weight.const_data_ptr(), reduction == at::Reduction::Mean, input.size(1), map_nelem,