accumulate_gradient_ptr Subroutine

recursive subroutine accumulate_gradient_ptr(array, grad, depth)

Accumulate gradient for array with safe memory management

Arguments

Type | Intent | Optional | Attributes | Name
type(array_type), intent(inout), target :: array
type(array_type), intent(in), pointer :: grad
integer, intent(in) :: depth

Source Code

  recursive subroutine accumulate_gradient_ptr(array, grad, depth)
    !! Accumulate gradient for array with safe memory management
    !!
    !! Adds the incoming gradient `grad` into `array%grad` and, when `array`
    !! has a left or right operand, continues the reverse-mode sweep by
    !! calling `reverse_mode_ptr` with depth `depth+1`.
    !!
    !! - If `array%direction` is allocated and non-empty, a new gradient
    !!   object is allocated as a copy of `grad`, and each column of its
    !!   `val` is scaled elementwise by `array%direction`.
    !! - If `array%is_sample_dependent` is false, the (possibly directional)
    !!   gradient is sum-reduced over dimension 2 before being stored/added.
    !! - On the first accumulation (`array%grad` not associated) the gradient
    !!   object is stored and its metadata flags are set from `array`;
    !!   otherwise `array%grad` is replaced by `array%grad + contribution`.
    !!
    !! NOTE(review): in the non-directional, sample-dependent first-time
    !! path, `array%grad` is pointer-assigned to the caller's `grad` target
    !! (no copy) and that target's flags are overwritten in place — confirm
    !! callers expect this aliasing.
    !! NOTE(review): the directional copy allocated here is never
    !! deallocated in this routine; presumably ownership passes to
    !! `array%grad` or to `reverse_mode_ptr` — TODO confirm no leak.
    implicit none
    type(array_type), intent(inout), target :: array
      !! Array whose gradient is accumulated (its `grad` may be replaced)
    type(array_type), intent(in), pointer :: grad
      !! Incoming gradient; presumably must be associated, since its
      !! components are dereferenced below
    integer, intent(in) :: depth
      !! Current recursion depth; forwarded as `depth+1`

    integer :: s
    logical :: is_directional
    type(array_type), pointer :: directional_grad

    ! A direction vector is only applied when it is present and non-empty.
    is_directional = .false.
    if(allocated(array%direction))then
       if(size(array%direction).gt.0) is_directional = .true.
    end if

    if(is_directional)then
       ! Work on a fresh copy so the caller's `grad` is left untouched.
       allocate(directional_grad)
       directional_grad = grad
       do s = 1, size(grad%val, 2)
          directional_grad%val(:, s) = grad%val(:, s) * array%direction
       end do
    else
       ! No direction: use the incoming gradient object directly (alias).
       directional_grad => grad
    end if

    if(.not. associated(array%grad))then
       ! First contribution: store the gradient and initialise its flags.
       if(array%is_sample_dependent)then
          array%grad => directional_grad
       else
          ! ! mean reduction
          ! array%grad => mean( directional_grad, dim = 2 )
          ! sum reduction
          array%grad => sum( directional_grad, dim = 2 )
       end if
       array%grad%is_scalar = array%is_scalar
       array%grad%is_sample_dependent = array%is_sample_dependent
       array%grad%requires_grad = .not. array%is_scalar
       array%grad%grad => null()
       array%grad%owns_gradient = .false.
       array%owns_gradient = .true.
       array%grad%is_temporary = array%is_temporary
    else

       ! Mark the existing gradient as temporary before it is consumed by
       ! the overloaded `+`; presumably the operator uses this flag to
       ! release its operand — TODO confirm. Must stay before the `+`.
       array%grad%is_temporary = .true.
       if(array%is_sample_dependent)then
          array%grad => array%grad + directional_grad
          !array%grad%val = array%grad%val + directional_grad%val
       else
          ! ! mean reduction
          ! array%grad => array%grad + mean( directional_grad, dim = 2 )
          ! sum reduction
          array%grad => array%grad + sum( directional_grad, dim = 2 )
          !array%grad%val(:,1) = array%grad%val(:,1) + sum( directional_grad%val, dim = 2 )
       end if
       ! Restore the flag on the newly produced gradient object.
       array%grad%is_temporary = array%is_temporary

    end if

    ! Propagate the (possibly directional) gradient to the operands.
    if(associated(array%left_operand).or.associated(array%right_operand))then
       call reverse_mode_ptr(array, directional_grad, depth+1)
    end if
  end subroutine accumulate_gradient_ptr