Accumulate the incoming gradient `grad` into `array%grad` (with safe memory management) and propagate it to the array's operands.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| `type(array_type)` | `intent(inout)` | | `target` | `array` |
| `type(array_type)` | `intent(in)` | | `pointer` | `grad` |
| `integer` | `intent(in)` | | | `depth` |
recursive subroutine accumulate_gradient_ptr(array, grad, depth)
  !! Accumulate gradient for array with safe memory management.
  !!
  !! Adds the incoming gradient `grad` into `array%grad`, then forwards it
  !! to the operands that produced `array` (if any) via `reverse_mode_ptr`.
  !!
  !! * If `array%direction` is allocated and non-empty, a copy of `grad` is
  !!   made and each column of its `val` is multiplied element-wise by
  !!   `array%direction`; otherwise `grad` itself is used.
  !! * Sample-dependent arrays keep the gradient column-by-column; all
  !!   others receive the gradient summed over dimension 2.
  !! * On the first call `array%grad` is (re)associated and its flags are
  !!   initialised from `array`; on later calls the new contribution is
  !!   added to the existing gradient.
  implicit none

  ! Arguments
  type(array_type), intent(inout), target :: array
    !! Array whose gradient is accumulated; `array%grad` may be re-associated.
  type(array_type), intent(in), pointer :: grad
    !! Incoming gradient. `intent(in)` on a pointer dummy only fixes the
    !! pointer association -- the target itself can still be modified.
  integer, intent(in) :: depth
    !! Current recursion depth; forwarded to `reverse_mode_ptr` as `depth + 1`.

  ! Local variables
  integer :: s
    !! Column index into `grad%val`.
  logical :: is_directional
    !! True when `array` carries a non-empty direction vector.
  type(array_type), pointer :: directional_grad
    !! Gradient after the optional directional scaling.

  ! Fortran does not short-circuit `.and.`, so size() must only be
  ! evaluated inside the `allocated` guard.
  is_directional = .false.
  if (allocated(array%direction)) then
     if (size(array%direction) .gt. 0) is_directional = .true.
  end if

  if (is_directional) then
     ! Copy `grad`, then scale every column by the direction vector.
     ! NOTE(review): assumes size(array%direction) == size(grad%val, 1);
     ! a mismatch is non-conformable and is not checked here.
     ! NOTE(review): this allocation is never deallocated in this routine.
     ! It may end up referenced by `array%grad` or by operands through
     ! `reverse_mode_ptr`, so ownership must be released elsewhere -- verify.
     allocate(directional_grad)
     directional_grad = grad
     do s = 1, size(grad%val, 2)
        directional_grad%val(:, s) = grad%val(:, s) * array%direction
     end do
  else
     directional_grad => grad
  end if

  if (.not. associated(array%grad)) then
     ! First gradient reaching this array.
     if (array%is_sample_dependent) then
        ! NOTE(review): when not directional, this aliases the caller's
        ! `grad` target; the flag updates below then modify that shared
        ! object while `array%owns_gradient` is set to .true. -- confirm
        ! this ownership is intended.
        array%grad => directional_grad
     else
        ! Sum (not mean) reduction over dimension 2.
        array%grad => sum(directional_grad, dim = 2)
     end if
     array%grad%is_scalar = array%is_scalar
     array%grad%is_sample_dependent = array%is_sample_dependent
     array%grad%requires_grad = .not. array%is_scalar
     array%grad%grad => null()
     array%grad%owns_gradient = .false.
     array%owns_gradient = .true.
     array%grad%is_temporary = array%is_temporary
  else
     ! Subsequent gradient: add the new contribution to the existing one.
     ! The current gradient is flagged temporary before being replaced by
     ! the sum -- presumably so the overloaded operators may reclaim it;
     ! confirm against the `array_type` operator implementations.
     array%grad%is_temporary = .true.
     if (array%is_sample_dependent) then
        array%grad => array%grad + directional_grad
     else
        ! Sum (not mean) reduction over dimension 2.
        array%grad => array%grad + sum(directional_grad, dim = 2)
     end if
     array%grad%is_temporary = array%is_temporary
  end if

  ! Continue back-propagation into the operands that produced `array`.
  if (associated(array%left_operand) .or. associated(array%right_operand)) then
     call reverse_mode_ptr(array, directional_grad, depth + 1)
  end if
end subroutine accumulate_gradient_ptr