matmul_arrays Module Function

module function matmul_arrays(a, b) result(c)

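Matrix multiplication of two autodiff arrays. At least one operand must be sample-independent (a weight matrix); the other supplies one column per sample.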

Arguments

| Type              | Intent | Optional Attributes | Name |
|-------------------|--------|---------------------|------|
| class(array_type) | in     | target              | a    |
| class(array_type) | in     | target              | b    |

Return Value: type(array_type), pointer


Source Code

  module function matmul_arrays(a, b) result(c)
    !! Matrix multiplication of two autodiff arrays.
    !!
    !! At least one operand must be sample-independent (a weight matrix).
    !! The other operand supplies one column per sample:
    !!
    !! * `a` sample-independent (rank 2), `b` rank 1 per sample:
    !!   C(m, S) = A(m, k) * B(k, S)
    !! * otherwise, `b` sample-independent (rank 2), `a` rank 1 per sample:
    !!   C(n, S) = B^T(n, k) * A(k, S)
    !!
    !! Both operands being sample-dependent is not supported.
    !! On any validation error `stop_program` is called and the function
    !! returns with `c` disassociated.
    implicit none
    class(array_type), intent(in), target :: a, b
    type(array_type), pointer :: c

    integer :: m, k, n, num_samples
    character(len=128) :: err_msg
    real(real32), pointer :: temp(:,:)

    ! Give callers a well-defined (disassociated) result on early return.
    nullify(c)

    if(.not.a%is_sample_dependent)then
       if(size(b%shape).ne.1)then
          write(err_msg,'("Matrix multiplication not implemented for array ''b'' &
               &rank: ",I0)') size(b%shape)
          call stop_program(err_msg)
          return
       end if
       ! a%shape(2) is read below, so the weight must be rank 2.
       if(size(a%shape).ne.2)then
          write(err_msg,'("Matrix multiplication requires array ''a'' to be &
               &rank 2, got rank: ",I0)') size(a%shape)
          call stop_program(err_msg)
          return
       end if
       ! C(m, S) = A(m, k) * B(k, S) where A is the weight matrix
       m = a%shape(1)
       k = a%shape(2)
       ! Inner dimensions must agree; sgemm would otherwise be given a
       ! wrong leading dimension (ldb = k) for b%val.
       if(b%shape(1).ne.k)then
          write(err_msg,'("Matrix multiplication shape mismatch: ''a'' has ",I0, &
               &" columns but ''b'' has ",I0," rows")') k, b%shape(1)
          call stop_program(err_msg)
          return
       end if
       num_samples = size(b%val, 2)
       c => a%create_result(array_shape=[m, num_samples])
       ! View the weight storage as an (m, k) matrix without copying.
       temp(1:m, 1:k) => a%val
#ifdef USE_BLAS
       ! sgemm: C = alpha * A * B + beta * C
       ! m = rows of A and C, n = columns of B and C, k = cols of A / rows of B
       ! lda = m (leading dim of A), ldb = k (leading dim of B), ldc = m
       call sgemm('N', 'N', m, num_samples, k, &
            1.0_real32, temp, m, b%val, k, 0.0_real32, c%val, m)
#else
       c%val = matmul(temp, b%val)
#endif
    elseif(.not.b%is_sample_dependent)then
       if(size(a%shape).ne.1)then
          write(err_msg,'("Matrix multiplication not implemented for array ''a'' &
               &rank: ",I0)') size(a%shape)
          call stop_program(err_msg)
          return
       end if
       ! b%shape(2) is read below, so the weight must be rank 2.
       if(size(b%shape).ne.2)then
          write(err_msg,'("Matrix multiplication requires array ''b'' to be &
               &rank 2, got rank: ",I0)') size(b%shape)
          call stop_program(err_msg)
          return
       end if
       ! C(n, S) = B^T(n, k) * A(k, S) where B is the weight matrix
       k = b%shape(1)
       n = b%shape(2)
       ! Inner dimensions must agree; sgemm would otherwise be given a
       ! wrong leading dimension (ldb = k) for a%val.
       if(a%shape(1).ne.k)then
          write(err_msg,'("Matrix multiplication shape mismatch: ''b'' has ",I0, &
               &" rows but ''a'' has ",I0," rows")') k, a%shape(1)
          call stop_program(err_msg)
          return
       end if
       num_samples = size(a%val, 2)
       c => b%create_result(array_shape=[n, num_samples])
       ! View the weight storage as a (k, n) matrix without copying.
       temp(1:k, 1:n) => b%val
#ifdef USE_BLAS
       ! sgemm: C = alpha * op(A) * B + beta * C with transa='T'
       ! Computes C(n, S) = temp^T(n, k) * a%val(k, S)
       ! m = n (rows of result), n_arg = S, k = k (shared dim)
       ! lda = k (leading dim of temp before transpose), ldb = k, ldc = n
       call sgemm('T', 'N', n, num_samples, k, &
            1.0_real32, temp, k, a%val, k, 0.0_real32, c%val, n)
#else
       c%val = matmul(transpose(temp), a%val)
#endif
    else
       ! Previously this printed a placeholder and executed `stop 0`, which
       ! reported success to the OS on an error path.
       err_msg = "Matrix multiplication not implemented when both arrays &
            &are sample dependent"
       call stop_program(err_msg)
       return
    end if

    ! NOTE(review): the result is flagged sample-dependent even when both
    ! operands are sample-independent (first branch); kept as-is because
    ! downstream gradient code may rely on it — confirm.
    c%is_sample_dependent = .true.
    c%get_partial_left => get_partial_matmul_left
    c%get_partial_right => get_partial_matmul_right
    c%get_partial_left_val => get_partial_matmul_left_val
    c%get_partial_right_val => get_partial_matmul_right_val
    c%get_partial_left_val_sum => get_partial_matmul_left_val_sum
    c%get_partial_right_val_sum => get_partial_matmul_right_val_sum
    ! Record the graph edge only when a gradient may flow through it.
    if(a%requires_grad .or. b%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward .or. b%is_forward
       c%operation = 'matmul'
       c%left_operand => a
       c%right_operand => b
       ! The result takes ownership of temporary operands.
       c%owns_left_operand = a%is_temporary
       c%owns_right_operand = b%is_temporary
    end if
  end function matmul_arrays