Sum-reduced gradient with respect to the left operand of a matmul. For the outer-product case (rank-1 right operand), this computes $\text{output} = \sum_s \text{upstream}(:,s) \otimes \text{right}(:,s) = \text{upstream}\,\text{right}^{T}$ using a single SGEMM call (in `USE_BLAS` builds) instead of computing the full `(n_elem, S)` array.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| class(array_type) | intent(in) | | | this |
| real(kind=real32) | intent(in) | | dimension(:,:) | upstream_grad |
| real(kind=real32) | intent(out) | | dimension(:) | output |
subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)
#else
pure subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)
#endif
  !! Sum-reduced gradient w.r.t. the left operand of a matmul.
  !!
  !! The gradient is summed over the sample (second) dimension of
  !! `upstream_grad`, so the per-sample (num_elements, S) gradient array is
  !! never materialised.
  !!
  !! * Rank-2 right operand W (m x n, stored flattened column-major in each
  !!   column of `right_operand%val`):
  !!     output = sum_s W_s * upstream(:,s)
  !!   which reduces to W * sum(upstream, dim=2) when W is shared by all
  !!   samples.
  !! * Rank-1 right operand (outer-product case):
  !!     output = sum_s upstream(:,s) (x) right(:,s) = upstream * right^T
  !!   returned flattened column-major (index (j-1)*num_elements + i).
  !!   With USE_BLAS and a sample-dependent right operand this is a single
  !!   SGEMM call.
  !!
  !! NOTE(review): the size of `output` is not checked here; it is assumed
  !! to be m (rank-2 case) or num_elements * size(right_operand%val, 1)
  !! (outer-product case) -- confirm against callers.
  implicit none
  class(array_type), intent(in) :: this
  !! Object whose `right_operand` supplies the right matmul operand.
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Upstream gradient, one column per sample.
  real(real32), dimension(:), intent(out) :: output
  !! Sample-summed gradient w.r.t. the left operand (flattened).

  integer :: j, s, m, n, num_elements, num_batches, num_rhs

  num_batches = size(upstream_grad, 2)

  if (size(this%right_operand%shape) .eq. 2) then
    ! Rank-2 right operand W(m, n): output = sum_s(W * upstream(:,s)).
    m = this%right_operand%shape(1)
    n = this%right_operand%shape(2)

    if (this%right_operand%is_sample_dependent) then
      ! A different W per sample, so the sum cannot be pulled inside the
      ! product; accumulate one mat-vec per sample.
      block
        real(real32), dimension(m, n) :: w_s
        output = 0.0_real32
        do s = 1, num_batches
          w_s = reshape(this%right_operand%val(:,s), [m, n])
          output = output + matmul(w_s, upstream_grad(:,s))
        end do
      end block
    else
      ! W is shared: sum_s(W * upstream(:,s)) = W * sum(upstream, dim=2).
#ifdef USE_BLAS
      block
        real(real32), pointer :: w_shared(:,:)
        real(real32), dimension(size(upstream_grad, 1)) :: upstream_sum
        w_shared(1:m, 1:n) => this%right_operand%val(:,1)
        upstream_sum = sum(upstream_grad, dim=2)
        ! SGEMV quick-returns without writing y when n == 0, so define
        ! `output` first; LDA must also be >= max(1, m) even when m == 0.
        output = 0.0_real32
        call sgemv('N', m, n, 1.0_real32, w_shared, max(1, m), &
             upstream_sum, 1, 0.0_real32, output, 1)
      end block
#else
      block
        real(real32), dimension(m, n) :: w_shared
        real(real32), dimension(size(upstream_grad, 1)) :: upstream_sum
        w_shared = reshape(this%right_operand%val(:,1), [m, n])
        upstream_sum = sum(upstream_grad, dim=2)
        output = matmul(w_shared, upstream_sum)
      end block
#endif
    end if
  else
    ! Rank-1 right operand (outer product):
    !   sum_s upstream(i,s) * right(j,s) = (upstream * right^T)(i,j),
    ! stored column-major in the flat `output`.
    num_elements = size(upstream_grad, 1)
    num_rhs = size(this%right_operand%val, 1)

    if (this%right_operand%is_sample_dependent) then
#ifdef USE_BLAS
      ! result(num_elements, num_rhs) = upstream * right^T in one SGEMM.
      ! Leading dimensions must be >= 1 even for empty operands.
      call sgemm('N', 'T', num_elements, num_rhs, num_batches, &
           1.0_real32, upstream_grad, max(1, num_elements), &
           this%right_operand%val, max(1, num_rhs), &
           0.0_real32, output, max(1, num_elements))
#else
      ! Fused outer product + sum: column j of the result accumulates
      ! right(j,s) * upstream(:,s) over the samples, in sample order.
      output = 0.0_real32
      do s = 1, num_batches
        do j = 1, num_rhs
          output((j-1)*num_elements+1:j*num_elements) = &
               output((j-1)*num_elements+1:j*num_elements) + &
               this%right_operand%val(j, s) * upstream_grad(:, s)
        end do
      end do
#endif
    else
      ! Right operand vector shared by all samples (column 1 of `val`):
      !   sum_s(right(j) * upstream(i,s)) = right(j) * sum_s(upstream(i,s)).
      block
        real(real32), dimension(num_elements) :: upstream_sum
        upstream_sum = sum(upstream_grad, dim=2)
        do j = 1, num_rhs
          output((j-1)*num_elements+1:j*num_elements) = &
               this%right_operand%val(j, 1) * upstream_sum
        end do
      end block
    end if
  end if
end subroutine get_partial_matmul_left_val_sum