Accumulate the gradient for an array, fusing the directional-derivative scaling with the accumulation step.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| class(array_type) | intent(inout) | | | array |
| class(array_type) | intent(inout) | | | parent |
| real(kind=real32) | intent(in) | | dimension(:,:) | upstream_grad |
| integer | intent(in) | | | num_samples |
| logical | intent(in) | | | is_left_operand |
| integer | intent(in) | | | depth |
recursive subroutine accumulate_gradient_samples( &
     array, parent, upstream_grad, num_samples, is_left_operand, depth &
)
  !! Accumulate the gradient contribution flowing from `parent` into `array`.
  !!
  !! The contribution comes from `parent%get_partial_left_val` (when
  !! `is_left_operand` is true) or `parent%get_partial_right_val`, applied
  !! to `upstream_grad`. When `array%direction` is allocated and non-empty,
  !! row `i` of the contribution is scaled by `array%direction(i)`. The
  !! result is added into `array%grad%val`, and `array%grad` is created on
  !! first use. If `array` has a left or right operand, `reverse_mode` is
  !! called with `depth + 1` to propagate further.
  !!
  !! There are three paths:
  !!
  !!  1. First contribution (`array%grad` not yet associated). The partial
  !!     is written directly into newly allocated `array%grad%val`, scaled
  !!     in place, and that storage is passed on to `reverse_mode`.
  !!  2. Leaf node with a direction. The scaling is fused into the
  !!     accumulation loop, the temporary is left untouched, and there is
  !!     no recursion.
  !!  3. Otherwise. The temporary is scaled in place (if needed), added
  !!     into `array%grad%val`, and passed on to `reverse_mode`.
  !!
  !! NOTE: only the first `min(size(array%grad%val, 2), num_samples)`
  !! sample columns are accumulated. The temporary holds exactly
  !! `num_samples` columns, so it is never read past its end.
  implicit none
  class(array_type), intent(inout) :: array
    !! Node receiving the gradient; `array%grad` is created or updated.
  class(array_type), intent(inout) :: parent
    !! Node whose partial-derivative routine produces the contribution.
  real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient passed unchanged to the parent's partial routine.
  integer, intent(in) :: num_samples
    !! Number of sample columns in the contribution for `array`.
  logical, intent(in) :: is_left_operand
    !! True if `array` is the left operand of `parent`, false if the right.
  integer, intent(in) :: depth
    !! Current recursion depth; `depth + 1` is passed to `reverse_mode`.

  integer :: s, i, n_elem, n_cols
#ifdef __flang__
  real(real32), dimension(:, :), allocatable :: grad
#else
  real(real32), dimension(size(array%val, 1), num_samples) :: grad
#endif
  logical :: has_direction, needs_recurse

  ! Number of rows (elements per sample) of the gradient.
  n_elem = size(array%val, 1)

  ! A zero-length direction vector is treated as "no direction".
  has_direction = allocated(array%direction)
  if(has_direction) has_direction = (size(array%direction) .gt. 0)

  needs_recurse = associated(array%left_operand) .or. &
       associated(array%right_operand)

  !---------------------------------------------------------------------------
  ! Path 1: first contribution. Write the partial straight into persistent
  ! storage, then optionally scale it and recurse from there.
  !---------------------------------------------------------------------------
  if(.not. associated(array%grad))then
     allocate(array%grad)
     call array%grad%allocate(array_shape=[array%shape, num_samples])
     if(is_left_operand)then
        call parent%get_partial_left_val(upstream_grad, array%grad%val)
     else
        call parent%get_partial_right_val(upstream_grad, array%grad%val)
     end if
     if(has_direction)then
        do concurrent( s = 1 : num_samples, i = 1 : n_elem )
           array%grad%val(i,s) = array%grad%val(i,s) * array%direction(i)
        end do
     end if
     array%grad%is_scalar = array%is_scalar
     array%grad%is_sample_dependent = array%is_sample_dependent
     array%grad%requires_grad = .not. array%is_scalar
     array%grad%grad => null()
     array%grad%owns_gradient = .false.
     array%owns_gradient = .true.
     array%grad%is_temporary = array%is_temporary
     if(needs_recurse) call reverse_mode(array, array%grad%val, depth+1)
     return
  end if

  ! From here on array%grad is always associated (path 1 returned otherwise),
  ! so compute this contribution into the temporary and add it in place.
#ifdef __flang__
  allocate(grad(n_elem, num_samples))
#endif
  if(is_left_operand)then
     call parent%get_partial_left_val(upstream_grad, grad)
  else
     call parent%get_partial_right_val(upstream_grad, grad)
  end if

  array%grad%is_temporary = .true.
  ! `grad` has exactly num_samples columns: never read beyond them, even if
  ! the existing gradient storage was created with more sample columns.
  n_cols = min(size(array%grad%val, 2), num_samples)

  !---------------------------------------------------------------------------
  ! Path 2: leaf node with a direction. Fuse the scaling into the
  ! accumulation; the temporary is not needed afterwards.
  !---------------------------------------------------------------------------
  if(has_direction .and. .not. needs_recurse)then
     do concurrent( s = 1 : n_cols, i = 1 : n_elem )
        array%grad%val(i,s) = array%grad%val(i,s) + &
             grad(i,s) * array%direction(i)
     end do
     return
  end if

  !---------------------------------------------------------------------------
  ! Path 3: scale the temporary in place (it is also the upstream gradient
  ! for the recursion), accumulate it, then recurse if needed.
  !---------------------------------------------------------------------------
  if(has_direction)then
     do concurrent( s = 1 : num_samples, i = 1 : n_elem )
        grad(i,s) = grad(i,s) * array%direction(i)
     end do
  end if

  do concurrent( s = 1 : n_cols, i = 1 : n_elem )
     array%grad%val(i,s) = array%grad%val(i,s) + grad(i,s)
  end do

  if(needs_recurse) call reverse_mode(array, grad, depth+1)

end subroutine accumulate_gradient_samples