recursive function forward_over_reverse(this, variable, depth) result(output)
  !! Forward-over-reverse traversal of the expression graph rooted at `this`.
  !!
  !! Builds a new node holding the derivative of `this` with respect to
  !! `variable`:
  !!  * if `this` IS `variable` (same address), the result is a copy of it
  !!    whose values are `this%direction` copied into every column (sample),
  !!    or all ones when no direction is allocated;
  !!  * if `this` has operands, the derivative of each operand that has a
  !!    partial is computed recursively, passed through get_partial_left /
  !!    get_partial_right (skipped for 'add'), and the two are summed;
  !!  * otherwise (a leaf that is not `variable`) the result is a copy with
  !!    all values zero.
  !! NOTE(review): presumably a directional (forward-mode) derivative seeded
  !! by `direction`; confirm against the get_partial_* implementations.
  !!
  !! this     - node to differentiate. Its is_forward flag is forced to
  !!            .true. while the call runs and restored before returning.
  !! variable - node to differentiate with respect to, matched by address
  !!            (loc), not by value.
  !! depth    - current recursion depth. Incremented on entry and
  !!            decremented on every exit, so it returns to its entry value.
  !! output   - node built by this call, flagged is_forward and is_temporary.
  !!            Disassociated if diffstruc__max_recursion_depth is exceeded
  !!            anywhere in the traversal.
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none
  type(array_type), intent(inout) :: this
  type(array_type), intent(in) :: variable
  integer, intent(inout) :: depth
  type(array_type), pointer :: output

  integer :: s, i, n_elem, n_samples
  logical :: is_right_a_variable, is_left_a_variable
  type(array_type), pointer :: left_deriv_tmp, right_deriv_tmp
  type(array_type), pointer :: left_deriv, right_deriv
  logical :: is_forward_local
  logical :: is_add_op, has_left, has_right, has_left_partial, has_right_partial

  ! Give the result a defined association status so that every early exit
  ! hands back a testable (disassociated) pointer instead of an undefined one.
  nullify(output)

  depth = depth + 1
  if(depth.gt.diffstruc__max_recursion_depth)then
     write(error_unit,*) "MAX RECURSION DEPTH REACHED", depth
     depth = depth - 1
     return
  end if

  is_forward_local = this%is_forward
  this%is_forward = .true.
  has_left = associated(this%left_operand)
  has_right = associated(this%right_operand)

  if(loc(this).eq.loc(variable))then
     ! `this` is the differentiation variable: seed it with the direction
     ! vector (repeated for every sample column) or with ones.
     allocate(output)
     output = this
     if(allocated(this%direction))then
        n_elem = size(output%val, 1)
        n_samples = size(output%val, 2)
        do concurrent(s = 1:n_samples, i = 1:n_elem)
           output%val(i,s) = this%direction(i)
        end do
     else
        output%val = 1._real32
     end if
     if(allocated(output%direction)) deallocate(output%direction)
  elseif(has_left .or. has_right)then
     ! Cache association checks and operation type
     has_left_partial = associated(this%get_partial_left)
     has_right_partial = associated(this%get_partial_right)
     is_add_op = (trim(this%operation).eq.'add')

     is_left_a_variable = .false.
     if(has_left .and. has_left_partial)then
        is_left_a_variable = .true.
        left_deriv_tmp => &
             forward_over_reverse(this%left_operand, variable, depth)
        if(.not.associated(left_deriv_tmp))then
           ! Depth limit was hit further down: propagate the failure.
           this%is_forward = is_forward_local
           depth = depth - 1
           return
        end if
        if(is_add_op)then
           ! d(a+b) = da + db, so the operand derivative is used unchanged.
           left_deriv => left_deriv_tmp
        else
           allocate(left_deriv)
           ! NOTE(review): a node with a right partial but no right operand
           ! (presumably a unary op) keeps its partial in get_partial_right.
           if(has_right_partial .and. .not.has_right)then
              left_deriv = this%get_partial_right(left_deriv_tmp)
           else
              left_deriv = this%get_partial_left(left_deriv_tmp)
           end if
        end if
     end if

     is_right_a_variable = .false.
     if(has_right .and. has_right_partial)then
        is_right_a_variable = .true.
        right_deriv_tmp => &
             forward_over_reverse(this%right_operand, variable, depth)
        if(.not.associated(right_deriv_tmp))then
           ! Depth limit was hit further down: propagate the failure.
           this%is_forward = is_forward_local
           depth = depth - 1
           return
        end if
        if(is_add_op)then
           right_deriv => right_deriv_tmp
        else
           allocate(right_deriv)
           right_deriv = this%get_partial_right(right_deriv_tmp)
        end if
     end if

     if(is_left_a_variable.and.is_right_a_variable)then
        output => left_deriv + right_deriv
     elseif(is_left_a_variable)then
        output => left_deriv
     elseif(is_right_a_variable)then
        output => right_deriv
     else
        call stop_program("Neither operand is a variable in forward-over-reverse")
     end if
  else
     ! Leaf node that is not the differentiation variable: zero derivative.
     allocate(output)
     output = this
     if(allocated(output%direction)) deallocate(output%direction)
     output%val = 0._real32
  end if

  this%is_forward = is_forward_local
  ! stop_program may return (e.g. in a non-fatal error mode), leaving
  ! output unset; only flag a result that actually exists.
  if(associated(output))then
     output%is_forward = .true.
     output%is_temporary = .true.
  end if
  depth = depth - 1
end function forward_over_reverse