forward_over_reverse Function

recursive function forward_over_reverse(this, variable, depth) result(output)

Arguments

Type | Intent | Optional | Attributes | Name
type(array_type), intent(inout) :: this
type(array_type), intent(in) :: variable
integer, intent(inout) :: depth

Return Value type(array_type), pointer


Source Code

  recursive function forward_over_reverse(this, variable, depth) result(output)
    !! Build the forward-mode derivative of the expression rooted at `this`
    !! with respect to `variable` by recursing through its operands.
    !!
    !! Three cases are distinguished:
    !!
    !! * `this` and `variable` share the same address: the result is a copy
    !!   of `this` whose values are the seed, i.e. `this%direction`
    !!   replicated along the second dimension of `val` when a direction is
    !!   allocated, otherwise 1 everywhere.
    !! * `this` has at least one associated operand: the derivative of each
    !!   operand that also has an associated partial-derivative procedure is
    !!   computed recursively and combined (see the inline comments).
    !! * otherwise (`this` has no operands and is not `variable`): the result
    !!   is a copy of `this` with all values set to 0.
    !!
    !! The result is a null pointer when the recursion limit
    !! `diffstruc__max_recursion_depth` was exceeded here or in a nested
    !! call; callers should test it with `associated` before use.
    implicit none
    type(array_type), intent(inout) :: this
    !! Node to differentiate. Its `is_forward` flag is forced to `.true.`
    !! while the node is processed and restored before returning.
    type(array_type), intent(in) :: variable
    !! Variable the derivative is taken with respect to; matched against
    !! `this` by address, not by value.
    integer, intent(inout) :: depth
    !! Call counter checked against `diffstruc__max_recursion_depth`. It is
    !! incremented on every entry and never decremented, so it accumulates
    !! over sibling branches as well as over nesting levels.
    type(array_type), pointer :: output
    !! Derivative node, flagged as forward-mode and temporary; null on
    !! failure.

    integer :: s, i, n_elem, n_samples
    logical :: is_right_a_variable, is_left_a_variable
    type(array_type), pointer :: left_deriv_tmp, right_deriv_tmp
    type(array_type), pointer :: left_deriv, right_deriv
    logical :: is_forward_local
    logical :: is_add_op, has_left, has_right, has_left_partial, has_right_partial

    ! A pointer function result starts with undefined association status.
    ! Nullify it so that every early exit hands back a testable null pointer
    ! rather than garbage.
    nullify(output)

    depth = depth + 1
    if(depth.gt.diffstruc__max_recursion_depth)then
       write(0,*) "MAX RECURSION DEPTH REACHED", depth
       return
    end if

    ! Remember the caller's flag so it can be restored on every exit path.
    is_forward_local = this%is_forward
    this%is_forward = .true.
    has_left = associated(this%left_operand)
    has_right = associated(this%right_operand)

    if(loc(this).eq.loc(variable))then
       ! Seed: d(variable)/d(variable).
       allocate(output)
       output = this
       if(allocated(this%direction))then
          n_elem = size(output%val, 1)
          n_samples = size(output%val, 2)
          do concurrent(s = 1:n_samples, i = 1:n_elem)
             output%val(i,s) = this%direction(i)
          end do
       else
          output%val = 1._real32
       end if
       if(allocated(output%direction)) deallocate(output%direction)

    elseif(has_left .or. has_right)then
       ! Cache association checks and operation type
       has_left_partial = associated(this%get_partial_left)
       has_right_partial = associated(this%get_partial_right)
       is_add_op = (trim(this%operation).eq.'add')

       is_left_a_variable = .false.
       if(has_left .and. has_left_partial)then
          is_left_a_variable = .true.
          left_deriv_tmp => &
               forward_over_reverse(this%left_operand, variable, depth)
          if(.not.associated(left_deriv_tmp))then
             ! The nested call failed (recursion limit): propagate the null
             ! result instead of handing it to a partial-derivative routine.
             this%is_forward = is_forward_local
             return
          end if
          if(is_add_op)then
             ! For 'add' the operand derivative is passed through unchanged.
             left_deriv => left_deriv_tmp
          else
             allocate(left_deriv)
             if(has_right_partial .and. .not.has_right)then
               ! NOTE(review): with no right operand but a right partial
               ! present, the right partial is applied to the left operand;
               ! presumably a convention for some single-operand
               ! operations - confirm against where the partials are set.
                left_deriv = this%get_partial_right(left_deriv_tmp)
             else
                left_deriv = this%get_partial_left(left_deriv_tmp)
             end if
          end if
       end if

       is_right_a_variable = .false.
       if(has_right .and. has_right_partial)then
          is_right_a_variable = .true.
          right_deriv_tmp => &
               forward_over_reverse(this%right_operand, variable, depth)
          if(.not.associated(right_deriv_tmp))then
             ! Same failure propagation as for the left operand.
             this%is_forward = is_forward_local
             return
          end if
          if(is_add_op)then
             right_deriv => right_deriv_tmp
          else
             allocate(right_deriv)
             right_deriv = this%get_partial_right(right_deriv_tmp)
          end if
       end if

       ! Combine the contributions of whichever operands were differentiated.
       if(is_left_a_variable.and.is_right_a_variable)then
          output => left_deriv + right_deriv
       elseif(is_left_a_variable)then
          output => left_deriv
       elseif(is_right_a_variable)then
          output => right_deriv
       else
          call stop_program("Neither operand is a variable in forward-over-reverse")
       end if

    else
       ! Leaf that is not `variable`: its derivative is zero.
       allocate(output)
       output = this
       if(allocated(output%direction)) deallocate(output%direction)
       output%val = 0._real32
    end if

    this%is_forward = is_forward_local
    ! `output` can only still be null here if stop_program returned control;
    ! guard the dereference so that case yields a null result, not a crash.
    if(associated(output))then
       output%is_forward = .true.
       output%is_temporary = .true.
    end if

  end function forward_over_reverse