get_partial_matmul_left_val Subroutine

subroutine get_partial_matmul_left_val(this, upstream_grad, output)

Compute gradient w.r.t. left operand for matmul in reverse mode.

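For $C = A B$ (where $A$ is the left operand):

- 2D case: $\partial L / \partial A = (\partial L / \partial C)\, B^{T}$, which in flat (column-major) storage is equivalent to $B\,(\partial L / \partial C)$.
- Outer-product case: $\partial L / \partial A = \texttt{upstream\_grad} \otimes \texttt{right\_operand}$.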

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  subroutine get_partial_matmul_left_val(this, upstream_grad, output)
#else
  pure subroutine get_partial_matmul_left_val(this, upstream_grad, output)
#endif
    !! Compute gradient w.r.t. left operand for matmul in reverse mode.
    !!
    !! For C = A * B (where A is the left operand):
    !!   dL/dA = dL/dC * B^T  (2D case, equivalent to B * dL/dC in flat storage)
    !!   dL/dA = upstream_grad (x) right_operand  (outer product case)
    !!
    !! Layout conventions relied on by the indexing below:
    !!  * the second dimension of upstream_grad and output indexes samples;
    !!  * a rank-2 right operand of shape [m, n] is stored column-major and
    !!    flattened into the first dimension of right_operand%val (column s
    !!    when sample-dependent, column 1 otherwise);
    !!  * 2D case: output(:,s) has length m and upstream_grad(:,s) length n;
    !!  * outer-product case: rows (j-1)*size(upstream_grad,1)+1 to
    !!    j*size(upstream_grad,1) of output(:,s) hold
    !!    right_operand%val(j,s) * upstream_grad(:,s).
    implicit none
    class(array_type), intent(in) :: this
    !! Matmul node; its right_operand supplies B.
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient dL/dC, one column per sample.
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient dL/dA, one column per sample.

    integer :: j, s, m, n, num_elements, num_batches, num_rhs

    num_batches = size(upstream_grad, 2)

    if(size(this%right_operand%shape).eq.2)then
       m = this%right_operand%shape(1)
       n = this%right_operand%shape(2)
       if(this%right_operand%is_sample_dependent)then
          ! Per-sample: weight matrix varies per sample.
          ! output(:,s) = W_s * upstream_grad(:,s), W_s = reshape(right%val(:,s), [m,n]).
          ! Column j of W_s is the contiguous slice val((j-1)*m+1:j*m, s), so
          ! accumulate W_s * g column by column straight from flat storage.
          ! This skips the reshape copy and the transpose temporary and needs
          ! no m x n automatic local array (which could overflow the stack).
          do s = 1, num_batches
             output(:,s) = 0.0_real32
             do j = 1, n
                output(:,s) = output(:,s) + upstream_grad(j,s) * &
                     this%right_operand%val((j-1)*m+1:j*m, s)
             end do
          end do
       else
          ! Non-sample-dependent: single batch operation across all samples.
          ! output(m, S) = W(m, n) * upstream_grad(n, S)
          ! Mathematically: dL/dA = dL/dC * B^T for each sample, batched.
#ifdef USE_BLAS
          block
            real(real32), pointer :: W(:,:)
            ! Pointer view avoids reshape: right_operand%val(:,1) is column-major
            ! contiguous (m, n) data. No copy needed.
            W(1:m, 1:n) => this%right_operand%val(:,1)
            ! sgemm: C = alpha * A * B + beta * C
            ! A = W(m, n), B = upstream_grad(n, S), C = output(m, S)
            ! m_arg = m (rows of W and output)
            ! n_arg = num_batches (columns of upstream_grad and output)
            ! k = n (columns of W / rows of upstream_grad)
            ! lda = m (leading dim of W), ldb = n, ldc = m
            call sgemm('N', 'N', m, num_batches, n, &
                 1.0_real32, W, m, upstream_grad, n, &
                 0.0_real32, output, m)
          end block
#else
          block
            ! Allocatable (heap) instead of an automatic (stack) array:
            ! m*n can be large for dense weight matrices. Deallocated
            ! automatically when the block exits.
            real(real32), allocatable :: temp(:,:)
            temp = reshape(this%right_operand%val(:,1), [m, n])
            output = matmul(temp, upstream_grad)
          end block
#endif
       end if
    else
       ! Outer product case: output(i + (j-1)*num_el, s) = grad(i,s) * right(j,s)
       num_elements = size(upstream_grad,1)
       num_rhs = size(this%right_operand%val,1)
       if(this%right_operand%is_sample_dependent)then
          do concurrent(s = 1:num_batches, j = 1:num_rhs)
             output((j-1)*num_elements+1:j*num_elements, s) = &
                  this%right_operand%val(j,s) * upstream_grad(:,s)
          end do
       else
          ! Same right-operand vector (column 1) for every sample.
          do j = 1, num_rhs
             output((j-1)*num_elements+1:j*num_elements, :) = &
                  this%right_operand%val(j, 1) * upstream_grad
          end do
       end if
    end if

  end subroutine get_partial_matmul_left_val