Accumulate the gradient for `array`, optimized with a sum-reduced path. When a sum-reduced partial derivative is available, the subroutine computes `sum(partial, dim=2)` directly, avoiding the large `(n_elem, S)` allocation (where `S` is `num_samples`). It falls back to the full computation when the sum-reduced variant is not set.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| class(array_type) | intent(inout) | | | array |
| class(array_type) | intent(inout) | | | parent |
| real(kind=real32) | intent(in) | | dimension(:,:) | upstream_grad |
| integer | intent(in) | | | num_samples |
| logical | intent(in) | | | is_left_operand |
| integer | intent(in) | | | depth |
recursive subroutine accumulate_gradient_single( &
     array, parent, upstream_grad, num_samples, is_left_operand, depth &
)
  !! Accumulate gradient for array - optimized with sum-reduced path
  !!
  !! When num_samples > 1 and the parent provides a sum-reduced partial
  !! (get_partial_left_val_sum / get_partial_right_val_sum), the
  !! sample-summed gradient is computed directly into an (n_elem, 1)
  !! buffer and the (n_elem, num_samples) workspace is never allocated.
  !! Otherwise the full partial is computed into that workspace and then
  !! reduced over samples.
  !!
  !! The reduced gradient is multiplied element-wise by array%direction
  !! (when allocated and non-empty) and added into array%grad. array%grad
  !! is allocated if it is not yet associated. If array has a left or
  !! right operand, the reduced gradient is passed on via reverse_mode.
  !!
  !! Both workspaces are allocatable on every compiler. Automatic arrays
  !! sized (n_elem, num_samples) would be created on entry whichever path
  !! is taken, which defeats the sum-reduced path. They would also stay on
  !! the stack at every level of the recursion.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: array
  !! Node receiving the gradient contribution
  class(array_type), intent(inout) :: parent
  !! Node whose partial-derivative procedures are evaluated
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Upstream gradient passed to the parent's partial-derivative procedures
  integer, intent(in) :: num_samples
  !! Number of samples (second extent of the full partial)
  logical, intent(in) :: is_left_operand
  !! True to use the parent's left-operand partials, false for the right
  integer, intent(in) :: depth
  !! Recursion depth; forwarded as depth+1 to reverse_mode

  ! Local variables
  integer :: s, i, n_elem
  real(real32), dimension(:,:), allocatable :: grad
  !! Full per-sample partial, shape (n_elem, num_samples); standard path only
  real(real32), dimension(:,:), allocatable :: out_grad
  !! Sample-reduced (and direction-scaled) gradient, shape (n_elem, 1)
  logical :: has_direction, needs_recurse, use_sum_reduced

  ! Cache array dimension to avoid repeated calls
  n_elem = size(array%val, 1)

  ! A direction only applies when it is allocated and non-empty
  has_direction = allocated(array%direction)
  if(has_direction) has_direction = size(array%direction) .gt. 0

  ! Non-leaf nodes propagate the gradient further
  needs_recurse = associated(array%left_operand) .or. &
       associated(array%right_operand)

  ! The sum-reduced path is only worthwhile (and only taken) for S > 1
  use_sum_reduced = .false.
  if(num_samples .gt. 1)then
     if(is_left_operand)then
        use_sum_reduced = associated(parent%get_partial_left_val_sum)
     else
        use_sum_reduced = associated(parent%get_partial_right_val_sum)
     end if
  end if

  if(use_sum_reduced)then
     !-------------------------------------------------------------------
     ! Sum-reduced path: the parent computes sum(partial, dim=2) directly,
     ! so grad(n_elem, num_samples) is never allocated.
     !-------------------------------------------------------------------

     ! Leaf without direction or existing gradient: write the reduced
     ! gradient straight into a fresh array%grad, avoiding an extra copy.
     if(.not. needs_recurse .and. .not. has_direction .and. &
          .not. associated(array%grad))then
        allocate(array%grad)
        call array%grad%allocate(array_shape=[array%shape, 1])
        if(is_left_operand)then
           call parent%get_partial_left_val_sum( &
                upstream_grad, array%grad%val(:,1))
        else
           call parent%get_partial_right_val_sum( &
                upstream_grad, array%grad%val(:,1))
        end if
        array%grad%is_scalar = array%is_scalar
        array%grad%is_sample_dependent = array%is_sample_dependent
        array%grad%requires_grad = .not. array%is_scalar
        array%grad%grad => null()
        array%grad%owns_gradient = .false.
        array%owns_gradient = .true.
        array%grad%is_temporary = array%is_temporary
        return
     end if

     allocate(out_grad(n_elem, 1))
     if(is_left_operand)then
        call parent%get_partial_left_val_sum(upstream_grad, out_grad(:,1))
     else
        call parent%get_partial_right_val_sum(upstream_grad, out_grad(:,1))
     end if

     ! Apply direction if needed
     if(has_direction)then
        do concurrent(i = 1:n_elem)
           out_grad(i,1) = out_grad(i,1) * array%direction(i)
        end do
     end if

     ! Leaf with an existing gradient: accumulate and stop
     if(.not. needs_recurse .and. associated(array%grad))then
        array%grad%is_temporary = .true.
        do concurrent(i = 1:n_elem)
           array%grad%val(i,1) = array%grad%val(i,1) + out_grad(i,1)
        end do
        return
     end if
  else
     !-------------------------------------------------------------------
     ! Standard path: compute the full per-sample partial, then reduce
     !-------------------------------------------------------------------
     allocate(grad(n_elem, num_samples))
     if(is_left_operand)then
        call parent%get_partial_left_val(upstream_grad, grad)
     else
        call parent%get_partial_right_val(upstream_grad, grad)
     end if

     ! Leaf with an existing gradient: accumulate directly into it
     if(.not. needs_recurse .and. associated(array%grad))then
        array%grad%is_temporary = .true.
        if(num_samples .eq. 1)then
           if(has_direction)then
              do concurrent(i = 1:n_elem)
                 array%grad%val(i,1) = array%grad%val(i,1) + &
                      grad(i,1) * array%direction(i)
              end do
           else
              do concurrent(i = 1:n_elem)
                 array%grad%val(i,1) = array%grad%val(i,1) + grad(i,1)
              end do
           end if
        else if(.not. has_direction)then
           ! Serial over samples, concurrent over elements. Each inner
           ! iteration then updates a distinct element, as DO CONCURRENT
           ! requires; access to grad stays column-contiguous.
           do s = 1, num_samples
              do concurrent(i = 1:n_elem)
                 array%grad%val(i,1) = array%grad%val(i,1) + grad(i,s)
              end do
           end do
        else
           allocate(out_grad(n_elem, 1))
           out_grad(:,1) = grad(:,1)
           do s = 2, num_samples
              do concurrent(i = 1:n_elem)
                 out_grad(i,1) = out_grad(i,1) + grad(i,s)
              end do
           end do
           do concurrent(i = 1:n_elem)
              array%grad%val(i,1) = array%grad%val(i,1) + &
                   out_grad(i,1) * array%direction(i)
           end do
        end if
        return
     end if

     ! General path: reduce over samples into out_grad, then scale.
     ! For num_samples == 1 the sample loop is empty and out_grad is a copy.
     allocate(out_grad(n_elem, 1))
     out_grad(:,1) = grad(:,1)
     do s = 2, num_samples
        do concurrent(i = 1:n_elem)
           out_grad(i,1) = out_grad(i,1) + grad(i,s)
        end do
     end do
     if(has_direction)then
        do concurrent(i = 1:n_elem)
           out_grad(i,1) = out_grad(i,1) * array%direction(i)
        end do
     end if

     ! Free large workspace BEFORE recursion
     deallocate(grad)
  end if

  ! Accumulate gradient
  if(.not. associated(array%grad))then
     allocate(array%grad)
     call array%grad%allocate(array_shape=[array%shape, 1])
     array%grad%val = out_grad
     array%grad%is_scalar = array%is_scalar
     array%grad%is_sample_dependent = array%is_sample_dependent
     array%grad%requires_grad = .not. array%is_scalar
     array%grad%grad => null()
     array%grad%owns_gradient = .false.
     array%owns_gradient = .true.
     array%grad%is_temporary = array%is_temporary
  else
     array%grad%is_temporary = .true.
     do concurrent(i = 1:n_elem)
        array%grad%val(i,1) = array%grad%val(i,1) + out_grad(i,1)
     end do
  end if

  ! Recurse if needed
  if(needs_recurse)then
     call reverse_mode(array, out_grad, depth+1)
  end if
end subroutine accumulate_gradient_single