Compute gradient w.r.t. left operand for matmul in reverse mode.
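For \(C = AB\), where \(A\) is the left operand:

- 2D case: \(\partial L/\partial A = (\partial L/\partial C)\,B^{T}\), which in flat (column-major) storage is computed as \(B\,(\partial L/\partial C)\).
- Outer-product case: \(\partial L/\partial A = \text{upstream\_grad} \otimes \text{right\_operand}\).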
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| class(array_type) | intent(in) | | | this |
| real(kind=real32) | intent(in) | | dimension(:,:) | upstream_grad |
| real(kind=real32) | intent(out) | | dimension(:,:) | output |
#ifdef USE_BLAS
subroutine get_partial_matmul_left_val(this, upstream_grad, output)
#else
pure subroutine get_partial_matmul_left_val(this, upstream_grad, output)
#endif
  !! Compute gradient w.r.t. left operand for matmul in reverse mode.
  !!
  !! For C = A * B (where A is the left operand):
  !!   dL/dA = dL/dC * B^T   (2D case, equivalent to B * dL/dC in flat storage)
  !!   dL/dA = upstream_grad (x) right_operand   (outer product case)
  !!
  !! The procedure is not `pure` when USE_BLAS is defined, because that
  !! path calls the external BLAS routine sgemm.
  !!
  !! this          : node whose right_operand (B) supplies shape, val and
  !!                 is_sample_dependent.
  !! upstream_grad : dL/dC, one column per sample (samples along dim 2).
  !! output        : dL/dA in flat storage, one column per sample.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  ! Local variables
  integer :: j, s, m, n, num_elements, num_batches, num_rhs


  num_batches = size(upstream_grad, 2)

  if(size(this%right_operand%shape).eq.2)then
     ! 2D right operand: B is an (m, n) matrix per sample.
     m = this%right_operand%shape(1)
     n = this%right_operand%shape(2)
     if(this%right_operand%is_sample_dependent)then
        ! Per-sample: weight matrix varies per sample.
        ! output(:,s) = W_s * upstream_grad(:,s), W_s = reshape(right%val(:,s), [m,n])
        ! This equals dL/dA_flat = dL/dC * W_s^T written as a matrix-vector
        ! product, so no transpose(W_s) temporary is built for each sample.
        ! Uses intrinsic matmul - BLAS overhead not worthwhile for
        ! per-sample vectors.
        block
          real(real32), dimension(m, n) :: w_s
          do s = 1, num_batches
             w_s = reshape(this%right_operand%val(:,s), [m, n])
             output(:,s) = matmul(w_s, upstream_grad(:,s))
          end do
        end block
     else
        ! Non-sample-dependent: single batched product across all samples.
        ! output(m, S) = W(m, n) * upstream_grad(n, S)
        ! Mathematically: dL/dA = dL/dC * B^T for each sample, batched.
#ifdef USE_BLAS
        block
          real(real32), pointer :: W(:,:)
          ! Pointer view avoids reshape: right_operand%val(:,1) is the
          ! contiguous column-major (m, n) data, so no copy is needed.
          ! NOTE(review): rank-remapping pointer assignment requires
          ! right_operand%val to be a valid pointer target (e.g. right_operand
          ! is itself a pointer component) - confirm against array_type.
          W(1:m, 1:n) => this%right_operand%val(:,1)
          ! sgemm: C = alpha * A * B + beta * C with
          !   A = W(m, n), B = upstream_grad(n, S), C = output(m, S)
          !   M = m (rows of W and output)
          !   N = num_batches (columns of upstream_grad and output)
          !   K = n (columns of W / rows of upstream_grad)
          !   lda = m, ldb = n, ldc = m
          call sgemm('N', 'N', m, num_batches, n, &
               1.0_real32, W, m, upstream_grad, n, &
               0.0_real32, output, m)
        end block
#else
        block
          real(real32), dimension(m, n) :: temp
          temp = reshape(this%right_operand%val(:,1), [m, n])
          output = matmul(temp, upstream_grad)
        end block
#endif
     end if
  else
     ! Outer product case:
     ! output(i + (j-1)*num_elements, s) = upstream_grad(i,s) * right(j,s)
     num_elements = size(upstream_grad,1)
     num_rhs = size(this%right_operand%val,1)
     if(this%right_operand%is_sample_dependent)then
        do concurrent(s = 1:num_batches, j = 1:num_rhs)
           output((j-1)*num_elements+1:j*num_elements, s) = &
                this%right_operand%val(j,s) * upstream_grad(:,s)
        end do
     else
        ! The right operand's first column is shared by every sample.
        do j = 1, num_rhs
           output((j-1)*num_elements+1:j*num_elements, :) = &
                this%right_operand%val(j, 1) * upstream_grad
        end do
     end if
  end if

end subroutine get_partial_matmul_left_val