reverse_mode Subroutine

recursive subroutine reverse_mode(array, upstream_grad, depth)

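Backward (reverse-mode) operation for arrays: propagates the upstream gradient to the left and right operands that require gradients.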

Arguments

Type | Intent | Optional Attributes | Name
class(array_type), intent(inout) :: array
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
integer, intent(in) :: depth

Source Code

  recursive subroutine reverse_mode(array, upstream_grad, depth)
    !! Backward operation for arrays.
    !!
    !! Propagates `upstream_grad` from `array` to its left and/or right
    !! operand. Only operands that are associated and have `requires_grad`
    !! set receive a gradient. The left operand is always processed before
    !! the right one.
    !!
    !! For each operand, the sample count is the larger of the operand's
    !! second dimension and the second dimension of `upstream_grad`.
    !! `accumulate_gradient_samples` is used when the operand is sample
    !! dependent or there is exactly one sample. Otherwise
    !! `accumulate_gradient_single` is used. The logical argument passed to
    !! both routines is `.true.` for the left operand and `.false.` for the
    !! right operand.
    !!
    !! If `depth` exceeds `diffstruc__max_recursion_depth`, a message is
    !! written to standard error and the routine returns without
    !! propagating anything.
    use, intrinsic :: iso_fortran_env, only: error_unit
    implicit none
    class(array_type), intent(inout) :: array
    !! Node whose gradient is propagated. Its `is_forward` flag is cleared
    !! only when at least one operand will receive a gradient.
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient flowing into `array`. Dimension 2 is used as the sample
    !! count when choosing the accumulation routine.
    integer, intent(in) :: depth
    !! Current recursion depth of the backward pass. It is forwarded
    !! unchanged to the accumulation routines.

    integer :: num_samples
    logical :: left_needs_grad, right_needs_grad

    ! write(*,'("Performing backward operation for: ",A,T60,"id: ",I0)') &
    !      trim(array%operation), array%id
    if(depth .gt. diffstruc__max_recursion_depth)then
       ! Guard against runaway recursion. The gradient is NOT propagated
       ! past this node.
       write(error_unit,*) "MAX RECURSION DEPTH REACHED IN REVERSE MODE", depth
       return
    end if

    ! An operand needs a gradient only if it exists AND requires one.
    ! The tests are nested rather than joined with .and. because Fortran
    ! does not short-circuit. Nesting ensures requires_grad is never read
    ! through a disassociated pointer.
    left_needs_grad = .false.
    if(associated(array%left_operand)) &
         left_needs_grad = array%left_operand%requires_grad
    right_needs_grad = .false.
    if(associated(array%right_operand)) &
         right_needs_grad = array%right_operand%requires_grad

    ! Nothing to propagate: leave array (including is_forward) untouched.
    if(.not.(left_needs_grad .or. right_needs_grad)) return

    array%is_forward = .false.

    ! Left operand (verified above to be associated and to require grad).
    if(left_needs_grad)then
       num_samples = max(size(array%left_operand%val, 2), size(upstream_grad, 2))
       if(array%left_operand%is_sample_dependent .or. num_samples .eq. 1)then
          call accumulate_gradient_samples( &
               array%left_operand, array, upstream_grad, num_samples, .true., depth &
          )
       else
          call accumulate_gradient_single( &
               array%left_operand, array, upstream_grad, num_samples, .true., depth &
          )
       end if
    end if

    ! Right operand (verified above to be associated and to require grad).
    if(right_needs_grad)then
       num_samples = max(size(array%right_operand%val, 2), size(upstream_grad, 2))
       if(array%right_operand%is_sample_dependent .or. num_samples .eq. 1)then
          call accumulate_gradient_samples( &
               array%right_operand, array, upstream_grad, num_samples, .false., depth &
          )
       else
          call accumulate_gradient_single( &
               array%right_operand, array, upstream_grad, num_samples, .false., depth &
          )
       end if
    end if
    ! write(*,*) "done operation: ", trim(array%operation)
  end subroutine reverse_mode