accumulate_gradient_samples Subroutine

recursive subroutine accumulate_gradient_samples(array, parent, upstream_grad, num_samples, is_left_operand, depth)

Accumulate the gradient contribution for `array`, fusing the direction scaling with the accumulation step where possible for speed.

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(inout) :: array
class(array_type), intent(inout) :: parent
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
integer, intent(in) :: num_samples
logical, intent(in) :: is_left_operand
integer, intent(in) :: depth

Source Code

  recursive subroutine accumulate_gradient_samples( &
       array, parent, upstream_grad, num_samples, is_left_operand, depth &
  )
    !! Accumulate the gradient contribution that `parent` sends to `array`
    !! (its left or right operand) and continue the reverse-mode sweep.
    !!
    !! The contribution is evaluated with `parent%get_partial_left_val` or
    !! `parent%get_partial_right_val` and is scaled element-wise by
    !! `array%direction` when that vector is allocated and non-empty. It is
    !! then either written into a freshly allocated `array%grad` (first
    !! contribution) or added in place to the existing `array%grad%val`.
    !! If `array` has a left or right operand, `reverse_mode` is called with
    !! the (scaled) contribution so that it propagates further down the graph.
    implicit none
    class(array_type), intent(inout) :: array
    !! Node receiving the gradient contribution
    class(array_type), intent(inout) :: parent
    !! Node whose partial derivative with respect to `array` is evaluated
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient flowing into `parent`
    integer, intent(in) :: num_samples
    !! Number of sample columns in the contribution
    logical, intent(in) :: is_left_operand
    !! `.true.` if `array` is the left operand of `parent`, `.false.` for right
    integer, intent(in) :: depth
    !! Current recursion depth; forwarded as `depth+1` to `reverse_mode`

    integer :: s, i, n_elem, n_acc
#ifdef __flang__
    real(real32), dimension(:, :), allocatable :: grad
#else
    real(real32), dimension(size(array%val, 1), num_samples) :: grad
#endif
    logical :: has_direction, needs_recurse

    n_elem = size(array%val, 1)

    ! Fortran's .and. does not short-circuit, so the allocation test must be
    ! done before querying the size of `direction`.
    has_direction = allocated(array%direction)
    if(has_direction) has_direction = (size(array%direction) .gt. 0)

    needs_recurse = associated(array%left_operand) .or. &
         associated(array%right_operand)

    !---------------------------------------------------------------------------
    ! First contribution: evaluate the partial straight into persistent storage
    ! (no scratch copy), scale it, and recurse from there.
    !---------------------------------------------------------------------------
    if(.not. associated(array%grad))then
       allocate(array%grad)
       call array%grad%allocate(array_shape=[array%shape, num_samples])
       if(is_left_operand)then
          call parent%get_partial_left_val(upstream_grad, array%grad%val)
       else
          call parent%get_partial_right_val(upstream_grad, array%grad%val)
       end if
       if(has_direction)then
          do concurrent( s = 1 : num_samples, i = 1 : n_elem )
             array%grad%val(i,s) = array%grad%val(i,s) * array%direction(i)
          end do
       end if
       array%grad%is_scalar = array%is_scalar
       array%grad%is_sample_dependent = array%is_sample_dependent
       array%grad%requires_grad = .not. array%is_scalar
       array%grad%grad => null()
       array%grad%owns_gradient = .false.
       array%owns_gradient = .true.
       array%grad%is_temporary = array%is_temporary
       ! array%grad%val now holds exactly this contribution, so it can serve
       ! as the upstream gradient for the children.
       ! NOTE(review): array%grad%val is passed alongside `array` itself; this
       ! relies on reverse_mode not modifying array%grad -- confirm.
       if(needs_recurse) call reverse_mode(array, array%grad%val, depth+1)
       return
    end if

    !---------------------------------------------------------------------------
    ! Subsequent contributions: array%grad is guaranteed to be associated here
    ! (the branch above always returns), so evaluate into a scratch buffer and
    ! add it in place.
    !---------------------------------------------------------------------------
#ifdef __flang__
    allocate(grad(n_elem, num_samples))
#endif
    if(is_left_operand)then
       call parent%get_partial_left_val(upstream_grad, grad)
    else
       call parent%get_partial_right_val(upstream_grad, grad)
    end if

    array%grad%is_temporary = .true.
    ! Never read past the last column of `grad`, even if the stored gradient
    ! was allocated with more sample columns than this call supplies.
    n_acc = min(size(array%grad%val, 2), num_samples)

    if(.not. needs_recurse)then
       ! Leaf node: nothing downstream needs the scaled contribution, so the
       ! direction scaling is fused into the accumulation.
       if(has_direction)then
          do concurrent( s = 1 : n_acc, i = 1 : n_elem )
             array%grad%val(i,s) = array%grad%val(i,s) + &
                  grad(i,s) * array%direction(i)
          end do
       else
          do concurrent( s = 1 : n_acc, i = 1 : n_elem )
             array%grad%val(i,s) = array%grad%val(i,s) + grad(i,s)
          end do
       end if
       return
    end if

    ! Interior node: the children need the scaled contribution as well, so
    ! scale `grad` in place before accumulating and recursing.
    if(has_direction)then
       do concurrent( s = 1 : num_samples, i = 1 : n_elem )
          grad(i,s) = grad(i,s) * array%direction(i)
       end do
    end if
    do concurrent( s = 1 : n_acc, i = 1 : n_elem )
       array%grad%val(i,s) = array%grad%val(i,s) + grad(i,s)
    end do
    call reverse_mode(array, grad, depth+1)
  end subroutine accumulate_gradient_samples