get_partial_matmul_left_val_sum Subroutine

subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)

Sum-reduced gradient w.r.t. the left operand of a matmul. For the outer-product case (rank-1 right operand), this computes $\mathrm{output} = \sum_s \mathrm{upstream}(:,s) \otimes \mathrm{right}(:,s) = \mathrm{upstream} \cdot \mathrm{right}^{T}$ using a single SGEMM call instead of computing the full (n_elem, S) array.

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:) :: output

Source Code

  subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)
#else
  pure subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)
#endif
    !! Gradient with respect to the left matmul operand, already reduced
    !! (summed) over the second dimension of `upstream_grad`.
    !!
    !! Two layouts of the right operand are distinguished by the number of
    !! entries in `this%right_operand%shape`:
    !!
    !! * two entries (an m-by-n matrix W, taken from one column of `val`):
    !!     output = sum_s( W_s * upstream(:,s) )
    !!   which collapses to W * sum(upstream, dim=2) when W is shared by
    !!   all samples.
    !! * anything else (outer product case):
    !!     output = sum_s( upstream(:,s) (x) right(:,s) ) = upstream * right^T
    !!   written flat in column-major order, so entry (i,j) is stored at
    !!   output((j-1)*size(upstream_grad,1) + i).
    !!
    !! With USE_BLAS defined the shared-matrix and per-sample outer product
    !! reductions are delegated to SGEMV and SGEMM respectively, so the full
    !! (n_elem, S) array is never formed.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), intent(in) :: upstream_grad(:,:)
    real(real32), intent(out) :: output(:)

    integer :: col, sample, offset
    integer :: n_rows, n_cols, n_grad, n_samples, n_right

    n_samples = size(upstream_grad, 2)

    select case (size(this%right_operand%shape))
    case (2)
       ! Matrix right operand, viewed as W(n_rows, n_cols).
       n_rows = this%right_operand%shape(1)
       n_cols = this%right_operand%shape(2)
       if (.not. this%right_operand%is_sample_dependent) then
          ! One W for every sample: reduce upstream first, then one matvec.
#ifdef USE_BLAS
          block
            real(real32), pointer :: weights(:,:)
            real(real32) :: grad_total(size(upstream_grad,1))
            grad_total = sum(upstream_grad, dim=2)
            ! Rank-2 view of the first column of val, no copy.
            weights(1:n_rows, 1:n_cols) => this%right_operand%val(:,1)
            call sgemv('N', n_rows, n_cols, 1.0_real32, weights, n_rows, &
                 grad_total, 1, 0.0_real32, output, 1)
          end block
#else
          block
            real(real32) :: weights(n_rows, n_cols)
            real(real32) :: grad_total(size(upstream_grad,1))
            grad_total = sum(upstream_grad, dim=2)
            weights = reshape(this%right_operand%val(:,1), [n_rows, n_cols])
            output = matmul(weights, grad_total)
          end block
#endif
       else
          ! A different W per sample: accumulate sample by sample.
          block
            real(real32) :: weights(n_rows, n_cols)
            real(real32) :: contribution(n_rows)
            output = 0.0_real32
            do sample = 1, n_samples
               weights = reshape(this%right_operand%val(:,sample), &
                    [n_rows, n_cols])
               contribution = matmul(upstream_grad(:,sample), &
                    transpose(weights))
               output = output + contribution
            end do
          end block
       end if

    case default
       ! Outer product case: output holds an (n_grad, n_right) matrix,
       ! flattened column by column.
       n_grad = size(upstream_grad, 1)
       n_right = size(this%right_operand%val, 1)
       if (.not. this%right_operand%is_sample_dependent) then
          ! Right operand shared by all samples (only column 1 is read):
          ! sum_s(right(j,1) * upstream(i,s)) = right(j,1) * sum_s(upstream(i,s))
          block
            real(real32) :: grad_total(n_grad)
            grad_total = sum(upstream_grad, dim=2)
            do col = 1, n_right
               offset = (col - 1) * n_grad
               output(offset+1:offset+n_grad) = &
                    this%right_operand%val(col, 1) * grad_total
            end do
          end block
       else
#ifdef USE_BLAS
          ! SGEMM: result(n_grad, n_right) = upstream * right^T
          call sgemm('N', 'T', n_grad, n_right, n_samples, &
               1.0_real32, upstream_grad, n_grad, &
               this%right_operand%val, n_right, &
               0.0_real32, output, n_grad)
#else
          ! Fused outer product and sample sum, one output column at a time.
          output = 0.0_real32
          do sample = 1, n_samples
             do col = 1, n_right
                offset = (col - 1) * n_grad
                output(offset+1:offset+n_grad) = &
                     output(offset+1:offset+n_grad) + &
                     this%right_operand%val(col, sample) * &
                     upstream_grad(:, sample)
             end do
          end do
#endif
       end if
    end select

  end subroutine get_partial_matmul_left_val_sum