accumulate_gradient_single Subroutine

recursive subroutine accumulate_gradient_single(array, parent, upstream_grad, num_samples, is_left_operand, depth)

Accumulate gradient for array, optimized with a sum-reduced path. When a sum-reduced partial derivative is available, computes sum(partial, dim=2) directly, avoiding the large (n_elem, S) allocation. Falls back to the full computation when the sum-reduced variant is not set.

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(inout) :: array
class(array_type), intent(inout) :: parent
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
integer, intent(in) :: num_samples
logical, intent(in) :: is_left_operand
integer, intent(in) :: depth

Source Code

  recursive subroutine accumulate_gradient_single( &
       array, parent, upstream_grad, num_samples, is_left_operand, depth &
  )
    !! Accumulate the gradient contribution of `parent` into `array%grad`.
    !!
    !! Two computation paths are used:
    !!
    !! * Sum-reduced path: taken when `num_samples > 1` and the parent has
    !!   the matching `get_partial_*_val_sum` procedure associated. The
    !!   procedure fills a single column directly, so the full
    !!   `(n_elem, num_samples)` partial is never reduced here.
    !! * Standard path: the full partial is obtained from
    !!   `get_partial_*_val` and summed over the sample dimension.
    !!
    !! In both paths the result is multiplied element-wise by
    !! `array%direction` when that component is allocated and non-empty,
    !! then either stored as a new `array%grad` or added to the existing
    !! one. If `array` has a left or right operand, `reverse_mode` is
    !! called with the reduced gradient and `depth + 1`.
    implicit none
    class(array_type), intent(inout) :: array
    !! Node whose `grad` component is created or accumulated into
    class(array_type), intent(inout) :: parent
    !! Node supplying the partial-derivative procedures
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient forwarded unchanged to the parent's partial procedures
    integer, intent(in) :: num_samples
    !! Second extent of the full partial; the sum-reduced path requires > 1
    logical, intent(in) :: is_left_operand
    !! Selects the parent's left (`.true.`) or right (`.false.`) partials
    integer, intent(in) :: depth
    !! Current depth; `depth + 1` is passed on to `reverse_mode`

    integer :: s, i, n_elem
    ! NOTE(review): outside flang both work arrays are automatic and sized at
    ! entry, so `grad` exists even when the sum-reduced path is taken. Only
    ! the flang build visibly defers / skips its allocation -- confirm
    ! whether that is intended.
#ifdef __flang__
    real(real32), dimension(:, :), allocatable :: grad
    real(real32), dimension(:, :), allocatable :: out_grad
#else
    real(real32), dimension(size(array%val, 1), num_samples) :: grad
    real(real32), dimension(size(array%val, 1), 1) :: out_grad
#endif
    logical :: has_direction, needs_recurse, use_sum_reduced

    ! Cache array dimension to avoid repeated calls
    n_elem = size(array%val, 1)

    ! Check direction once. The size test is a separate statement because
    ! .and. is not guaranteed to short-circuit on an unallocated component.
    has_direction = allocated(array%direction)
    if(has_direction) has_direction = (size(array%direction).gt.0)

    ! Check if recursion needed
    needs_recurse = associated(array%left_operand) .or. &
         associated(array%right_operand)

    ! Determine if we can use the sum-reduced path
    use_sum_reduced = .false.
    if(num_samples .gt. 1)then
       if(is_left_operand)then
          use_sum_reduced = associated(parent%get_partial_left_val_sum)
       else
          use_sum_reduced = associated(parent%get_partial_right_val_sum)
       end if
    end if

    if(use_sum_reduced)then
       ! For leaf nodes without directional derivatives, write the reduced
       ! gradient directly into the destination buffer and avoid an extra copy.
       if(.not. needs_recurse .and. .not. has_direction .and. &
            .not. associated(array%grad))then
          allocate(array%grad)
          call array%grad%allocate(array_shape=[array%shape, 1])
          if(is_left_operand)then
             call parent%get_partial_left_val_sum(upstream_grad, array%grad%val(:,1))
          else
             call parent%get_partial_right_val_sum(upstream_grad, array%grad%val(:,1))
          end if
          array%grad%is_scalar = array%is_scalar
          array%grad%is_sample_dependent = array%is_sample_dependent
          array%grad%requires_grad = .not. array%is_scalar
          array%grad%grad => null()
          array%grad%owns_gradient = .false.
          array%owns_gradient = .true.
          array%grad%is_temporary = array%is_temporary
          return
       end if

       !----------------------------------------------------------------------
       ! Sum-reduced fast path: the parent fills the reduced column directly
       !----------------------------------------------------------------------
#ifdef __flang__
       allocate(out_grad(n_elem, 1))
#endif
       if(is_left_operand)then
          call parent%get_partial_left_val_sum(upstream_grad, out_grad(:,1))
       else
          call parent%get_partial_right_val_sum(upstream_grad, out_grad(:,1))
       end if

       ! Apply direction if needed
       if(has_direction)then
          do concurrent( i = 1 : n_elem )
             out_grad(i,1) = out_grad(i,1) * array%direction(i)
          end do
       end if

       ! Fast path: leaf node with existing gradient
       if(.not. needs_recurse .and. associated(array%grad))then
          array%grad%is_temporary = .true.
          do concurrent( i = 1 : n_elem )
             array%grad%val(i,1) = array%grad%val(i,1) + out_grad(i,1)
          end do
#ifdef __flang__
          deallocate(out_grad)
#endif
          return
       end if
    else
       !----------------------------------------------------------------------
       ! Standard path: compute full gradient then reduce
       !----------------------------------------------------------------------
#ifdef __flang__
       allocate(grad(n_elem, num_samples))
#endif
       if(is_left_operand)then
          call parent%get_partial_left_val(upstream_grad, grad)
       else
          call parent%get_partial_right_val(upstream_grad, grad)
       end if

       ! Fast path: leaf node with existing gradient - direct accumulation
       if(.not. needs_recurse .and. associated(array%grad))then
          array%grad%is_temporary = .true.
          if(num_samples .eq. 1)then
             if(has_direction)then
                do concurrent( i = 1 : n_elem )
                   array%grad%val(i,1) = array%grad%val(i,1) + &
                        grad(i,1) * array%direction(i)
                end do
             else
                do concurrent( i = 1 : n_elem )
                   array%grad%val(i,1) = array%grad%val(i,1) + grad(i,1)
                end do
             end if
          else if(.not. has_direction)then
             ! The sample loop is a reduction into val(i,1): every s updates
             ! the same element, so it must be an ordinary sequential loop.
             ! Only the element index is dependency-free.
             do s = 1, num_samples
                do concurrent( i = 1 : n_elem )
                   array%grad%val(i,1) = array%grad%val(i,1) + grad(i,s)
                end do
             end do
          else
#ifdef __flang__
             allocate(out_grad(n_elem, 1))
#endif
             ! Reduce over samples first (sequential in s), then apply the
             ! direction once per element.
             out_grad(:,1) = grad(:,1)
             do s = 2, num_samples
                do concurrent( i = 1 : n_elem )
                   out_grad(i,1) = out_grad(i,1) + grad(i,s)
                end do
             end do
             do concurrent( i = 1 : n_elem )
                array%grad%val(i,1) = array%grad%val(i,1) + &
                     out_grad(i,1) * array%direction(i)
             end do
#ifdef __flang__
             deallocate(out_grad)
#endif
          end if
#ifdef __flang__
          deallocate(grad)
#endif
          return
       end if

       ! General path: compute out_grad with reduction
#ifdef __flang__
       allocate(out_grad(n_elem, 1))
#endif
       if(num_samples .eq. 1)then
          if(has_direction)then
             do concurrent( i = 1 : n_elem )
                out_grad(i,1) = grad(i,1) * array%direction(i)
             end do
          else
             out_grad(:,1) = grad(:,1)
          end if
       else
          ! Sequential reduction over samples; see the note in the leaf
          ! fast path above for why s cannot be a do concurrent index.
          out_grad(:,1) = grad(:,1)
          do s = 2, num_samples
             do concurrent( i = 1 : n_elem )
                out_grad(i,1) = out_grad(i,1) + grad(i,s)
             end do
          end do
          if(has_direction)then
             do concurrent( i = 1 : n_elem )
                out_grad(i,1) = out_grad(i,1) * array%direction(i)
             end do
          end if
       end if

       ! Free large workspace BEFORE recursion (flang build only; the
       ! automatic array of other builds lives until the procedure returns)
#ifdef __flang__
       deallocate(grad)
#endif
    end if

    ! Accumulate gradient: create array%grad on first contribution,
    ! otherwise add to the existing values
    if(.not. associated(array%grad))then
       allocate(array%grad)
       call array%grad%allocate(array_shape=[array%shape, 1])
       array%grad%val = out_grad
       array%grad%is_scalar = array%is_scalar
       array%grad%is_sample_dependent = array%is_sample_dependent
       array%grad%requires_grad = .not. array%is_scalar
       array%grad%grad => null()
       array%grad%owns_gradient = .false.
       array%owns_gradient = .true.
       array%grad%is_temporary = array%is_temporary
    else
       array%grad%is_temporary = .true.
       do concurrent( i = 1 : n_elem )
          array%grad%val(i,1) = array%grad%val(i,1) + out_grad(i,1)
       end do
    end if

    ! Recurse if needed
    if(needs_recurse)then
       call reverse_mode(array, out_grad, depth+1)
    end if
  end subroutine accumulate_gradient_single