Backward operation for arrays
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| class(array_type), | intent(inout) | | | :: | array | |
| real(kind=real32), | intent(in), | | dimension(:,:) | :: | upstream_grad | |
| integer, | intent(in) | | | :: | depth | |
recursive subroutine reverse_mode(array, upstream_grad, depth)
  !! Backward operation for arrays.
  !!
  !! Hands `upstream_grad` to every associated operand of `array` whose
  !! `requires_grad` flag is set, dispatching to either
  !! `accumulate_gradient_samples` or `accumulate_gradient_single`.
  !!
  !! Returns immediately, leaving `array` untouched, when `depth` exceeds
  !! `diffstruc__max_recursion_depth` (a diagnostic is written to the
  !! error unit) or when neither operand needs a gradient.
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none

  class(array_type), intent(inout) :: array
    !! Node whose operands receive the gradient; `is_forward` is cleared
    !! once at least one operand is known to require a gradient.
  real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient arriving at this node; its second extent takes part in the
    !! sample-count computation below.
  integer, intent(in) :: depth
    !! Current recursion depth, forwarded unchanged to the accumulate
    !! routines.
    !! NOTE(review): presumably those routines call back into
    !! `reverse_mode` with an incremented depth - confirm.

  integer :: num_samples
  logical :: has_left, has_right

  ! write(*,'("Performing backward operation for: ",A,T60,"id: ",I0)') &
  !      trim(array%operation), array%id

  ! Guard against runaway recursion. Report on the standard error unit
  ! (portable `error_unit` rather than a hard-coded unit number).
  if(depth > diffstruc__max_recursion_depth)then
     write(error_unit,*) "MAX RECURSION DEPTH REACHED IN REVERSE MODE", depth
     return
  end if

  ! An operand is processed only if it is associated AND requires a
  ! gradient. The association test must come first so that an
  ! unassociated pointer is never dereferenced.
  has_left = associated(array%left_operand)
  if(has_left) has_left = array%left_operand%requires_grad

  has_right = associated(array%right_operand)
  if(has_right) has_right = array%right_operand%requires_grad

  ! Nothing to propagate: leave the node (including is_forward) untouched.
  if(.not.(has_left .or. has_right)) return

  array%is_forward = .false.

  ! Left operand. The literal .true. / .false. argument passed to the
  ! accumulate routines is the only difference between the two branches
  ! (apparently selecting left vs. right - confirm against the callee).
  if(has_left)then
     num_samples = max( &
          size(array%left_operand%val, 2), size(upstream_grad, 2) &
     )
     if(array%left_operand%is_sample_dependent .or. num_samples == 1)then
        call accumulate_gradient_samples( &
             array%left_operand, array, upstream_grad, &
             num_samples, .true., depth &
        )
     else
        call accumulate_gradient_single( &
             array%left_operand, array, upstream_grad, &
             num_samples, .true., depth &
        )
     end if
  end if

  ! Right operand: same dispatch as above, with the flag set to .false.
  if(has_right)then
     num_samples = max( &
          size(array%right_operand%val, 2), size(upstream_grad, 2) &
     )
     if(array%right_operand%is_sample_dependent .or. num_samples == 1)then
        call accumulate_gradient_samples( &
             array%right_operand, array, upstream_grad, &
             num_samples, .false., depth &
        )
     else
        call accumulate_gradient_single( &
             array%right_operand, array, upstream_grad, &
             num_samples, .false., depth &
        )
     end if
  end if

  ! write(*,*) "done operation: ", trim(array%operation)

end subroutine reverse_mode