submodule(diffstruc__operations_linalg) diffstruc__operations_linalg_sub
  !! Submodule containing implementations of linear algebra operations
  !!
  !! Provides matmul, outer product, dot product and transpose for autodiff
  !! arrays, together with the procedures that evaluate their partial
  !! derivatives. When the preprocessor macro USE_BLAS is defined, the
  !! non-sample-dependent matrix products are evaluated with SGEMM/SGEMV;
  !! otherwise the intrinsic matmul is used.
  use coreutils, only: stop_program

#ifdef USE_BLAS
  ! Explicit interfaces for the single-precision BLAS routines called below.
  interface
     subroutine sgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, &
          beta, C, ldc)
       import :: real32
       character(len=1), intent(in) :: transa, transb
       integer, intent(in) :: m, n, k, lda, ldb, ldc
       real(real32), intent(in) :: alpha, beta
       real(real32), intent(in) :: A(lda,*), B(ldb,*)
       real(real32), intent(inout) :: C(ldc,*)
     end subroutine sgemm
     subroutine sgemv(trans, m, n, alpha, A, lda, x, incx, beta, y, incy)
       import :: real32
       character(len=1), intent(in) :: trans
       integer, intent(in) :: m, n, lda, incx, incy
       real(real32), intent(in) :: alpha, beta
       real(real32), intent(in) :: A(lda,*), x(*)
       real(real32), intent(inout) :: y(*)
     end subroutine sgemv
  end interface
#endif

contains

!###############################################################################
  module function matmul_arrays(a, b) result(c)
    !! Matrix multiplication of two autodiff arrays
    !!
    !! Supported cases:
    !!   - a is not sample dependent (rank-2 weight matrix), b is rank 1:
    !!       C(m, S) = A(m, k) * B(k, S)
    !!   - b is not sample dependent (rank-2 weight matrix), a is rank 1:
    !!       C(n, S) = B^T(n, k) * A(k, S)
    !! Any other combination is reported through stop_program.
    implicit none
    class(array_type), intent(in), target :: a, b
    !! Left and right operands
    type(array_type), pointer :: c
    !! Result of the multiplication (sample dependent)

    integer :: m, k, n, num_samples
    character(len=128) :: err_msg
    real(real32), pointer :: temp(:,:)

    if(.not.a%is_sample_dependent)then
       if(size(b%shape).ne.1)then
          write(err_msg,'("Matrix multiplication not implemented for array ''b'' &
               &rank: ",I0)') size(b%shape)
          call stop_program(err_msg)
          return
       end if
       ! C(m, S) = A(m, k) * B(k, S) where A is the weight matrix
       m = a%shape(1)
       k = a%shape(2)
       num_samples = size(b%val, 2)
       c => a%create_result(array_shape=[m, num_samples])
       ! rank-2 (m, k) view of the stored values of a, no copy
       temp(1:m, 1:k) => a%val
#ifdef USE_BLAS
       ! sgemm: C = alpha * A * B + beta * C
       ! m = rows of A and C, n = columns of B and C, k = cols of A / rows of B
       ! lda = m (leading dim of A), ldb = k (leading dim of B), ldc = m
       call sgemm('N', 'N', m, num_samples, k, &
            1.0_real32, temp, m, b%val, k, 0.0_real32, c%val, m)
#else
       c%val = matmul(temp, b%val)
#endif
    elseif(.not.b%is_sample_dependent)then
       if(size(a%shape).ne.1)then
          write(err_msg,'("Matrix multiplication not implemented for array ''a'' &
               &rank: ",I0)') size(a%shape)
          call stop_program(err_msg)
          return
       end if
       ! C(n, S) = B^T(n, k) * A(k, S) where B is the weight matrix
       k = b%shape(1)
       n = b%shape(2)
       num_samples = size(a%val, 2)
       c => b%create_result(array_shape=[n, num_samples])
       ! rank-2 (k, n) view of the stored values of b, no copy
       temp(1:k, 1:n) => b%val
#ifdef USE_BLAS
       ! sgemm: C = alpha * op(A) * B + beta * C with transa='T'
       ! Computes C(n, S) = temp^T(n, k) * a%val(k, S)
       ! m = n (rows of result), n_arg = S, k = k (shared dim)
       ! lda = k (leading dim of temp before transpose), ldb = k, ldc = n
       call sgemm('T', 'N', n, num_samples, k, &
            1.0_real32, temp, k, a%val, k, 0.0_real32, c%val, n)
#else
       c%val = matmul(transpose(temp), a%val)
#endif
    else
       ! Both operands are sample dependent: not implemented.
       ! Reported through stop_program (as for the other unsupported cases)
       ! rather than terminating with a success exit status.
       call stop_program( &
            "matmul_arrays: matrix multiplication of two sample-dependent &
            &arrays is not implemented" &
       )
       return
    end if

    c%is_sample_dependent = .true.
    c%get_partial_left => get_partial_matmul_left
    c%get_partial_right => get_partial_matmul_right
    c%get_partial_left_val => get_partial_matmul_left_val
    c%get_partial_right_val => get_partial_matmul_right_val
    c%get_partial_left_val_sum => get_partial_matmul_left_val_sum
    c%get_partial_right_val_sum => get_partial_matmul_right_val_sum
    if(a%requires_grad .or. b%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward .or. b%is_forward
       c%operation = 'matmul'
       c%left_operand => a
       c%right_operand => b
       c%owns_left_operand = a%is_temporary
       c%owns_right_operand = b%is_temporary
    end if
  end function matmul_arrays
!-------------------------------------------------------------------------------
  module function matmul_real2d(a, b) result(c)
    !! Matrix multiplication of a real array and an autodiff array
    !! Computes C = a * b where a is autodiff (vector per sample) and b is a
    !! real 2D matrix. Equivalent to C(:,s) = b^T * a(:,s) for each sample s.
    !!
    !! The real matrix b is copied into a newly allocated, non-differentiable
    !! array_type that becomes the right operand owned by the result.
    implicit none
    class(array_type), intent(in), target :: a
    !! Autodiff left operand
    real(real32), dimension(:,:), intent(in) :: b
    !! Real right operand
    type(array_type), pointer :: c
    !! Result of the multiplication

    type(array_type), pointer :: b_array
    integer :: rows, cols, num_samples

    rows = size(b, 1)
    cols = size(b, 2)
    num_samples = size(a%val, 2)
    c => a%create_result(array_shape = [cols, num_samples])
#ifdef USE_BLAS
    ! C(cols, S) = b^T(cols, rows) * a%val(rows, S)
    ! sgemm: transa='T' transposes b in-place, no temporary needed
    ! m = cols, n = S, k = rows
    ! lda = rows (leading dim of b before transpose), ldb = rows, ldc = cols
    call sgemm('T', 'N', cols, num_samples, rows, &
         1.0_real32, b, rows, a%val, rows, 0.0_real32, c%val, cols)
#else
    c%val = matmul(transpose(b), a%val)
#endif

    c%is_sample_dependent = a%is_sample_dependent
    c%get_partial_left => get_partial_matmul_left
    c%get_partial_right => get_partial_matmul_right
    c%get_partial_left_val => get_partial_matmul_left_val
    c%get_partial_right_val => get_partial_matmul_right_val
    c%get_partial_left_val_sum => get_partial_matmul_left_val_sum
    c%get_partial_right_val_sum => get_partial_matmul_right_val_sum
    if(a%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward
       c%operation = 'matmul_scalar'
       c%left_operand => a
       c%owns_left_operand = a%is_temporary
    end if

    ! wrap the real matrix so the partial-derivative procedures can read it
    allocate(b_array)
    b_array%is_sample_dependent = .false.
    b_array%shape = shape(b)
    b_array%requires_grad = .false.
    call b_array%allocate(array_shape=[size(b,1), size(b,2), 1])
    b_array%val(:,1) = reshape(b, [size(b,1)*size(b,2)])
    c%right_operand => b_array
    c%owns_right_operand = .true.
  end function matmul_real2d
!-------------------------------------------------------------------------------
  module function real2d_matmul(a, b) result(c)
    !! Matrix multiplication of a real 2D matrix and an autodiff array
    !! Computes C = a * b where a is a real matrix and b is autodiff
    !!
    !! The real matrix a is copied into a newly allocated, non-differentiable
    !! array_type that becomes the left operand owned by the result.
    implicit none
    real(real32), dimension(:,:), intent(in) :: a
    !! Real left operand
    class(array_type), intent(in), target :: b
    !! Autodiff right operand
    type(array_type), pointer :: c
    !! Result of the multiplication

    type(array_type), pointer :: a_array
    integer :: m, k, num_samples

    m = size(a, 1)
    k = size(a, 2)
    num_samples = size(b%val, 2)
    c => b%create_result(array_shape = [m, num_samples])
#ifdef USE_BLAS
    ! C(m, S) = a(m, k) * b%val(k, S)
    ! sgemm: standard C = alpha * A * B, no transposes
    ! m = m, n = S, k = k
    ! lda = m, ldb = k, ldc = m
    call sgemm('N', 'N', m, num_samples, k, &
         1.0_real32, a, m, b%val, k, 0.0_real32, c%val, m)
#else
    c%val = matmul(a, b%val)
#endif

    c%is_sample_dependent = b%is_sample_dependent
    c%get_partial_left => get_partial_matmul_left
    c%get_partial_right => get_partial_matmul_right
    c%get_partial_left_val => get_partial_matmul_left_val
    c%get_partial_right_val => get_partial_matmul_right_val
    c%get_partial_left_val_sum => get_partial_matmul_left_val_sum
    c%get_partial_right_val_sum => get_partial_matmul_right_val_sum
    if(b%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = b%is_forward
       c%operation = 'matmul_scalar'
       c%right_operand => b
       c%owns_right_operand = b%is_temporary
    end if

    ! wrap the real matrix so the partial-derivative procedures can read it
    allocate(a_array)
    a_array%is_sample_dependent = .false.
    a_array%shape = shape(a)
    a_array%requires_grad = .false.
    call a_array%allocate(array_shape=[size(a,1), size(a,2), 1])
    a_array%val(:,1) = reshape(a, [size(a,1)*size(a,2)])
    c%left_operand => a_array
    c%owns_left_operand = .true.
  end function real2d_matmul
!-------------------------------------------------------------------------------
  function get_partial_matmul_left(this, upstream_grad) result(output)
    !! Get partial derivative with respect to left operand of matmul
    !!
    !! The is_temporary flag of the right operand is cleared while the
    !! derivative expression is built and restored afterwards.
    ! NOTE(review): presumably this stops the operand being treated as a
    ! disposable temporary by the operations it is passed to - confirm.
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the matmul operation
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Partial derivative with respect to the left operand

    logical :: right_is_temporary_local
    type(array_type), pointer :: ptr

    right_is_temporary_local = this%right_operand%is_temporary
    this%right_operand%is_temporary = .false.
    if(size(this%right_operand%shape).eq.2)then
       if(this%is_forward)then
          ptr => matmul( upstream_grad, this%right_operand )
       else
          ptr => matmul( upstream_grad, transpose(this%right_operand) )
       end if
    elseif(size(upstream_grad%shape).eq.2)then
       if(this%is_forward)then
          ptr => matmul( upstream_grad, this%right_operand )
       else
          ptr => matmul( transpose(upstream_grad), this%right_operand )
       end if
    else
       ptr => upstream_grad .outer. this%right_operand
    end if
    this%right_operand%is_temporary = right_is_temporary_local
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_matmul_left
!-------------------------------------------------------------------------------
  function get_partial_matmul_right(this, upstream_grad) result(output)
    !! Get partial derivative with respect to right operand of matmul
    !!
    !! The is_temporary flag of the left operand is cleared while the
    !! derivative expression is built and restored afterwards.
    ! NOTE(review): presumably this stops the operand being treated as a
    ! disposable temporary by the operations it is passed to - confirm.
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the matmul operation
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Partial derivative with respect to the right operand

    logical :: left_is_temporary_local
    type(array_type), pointer :: ptr

    left_is_temporary_local = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.
    if(size(this%left_operand%shape).eq.2)then
       if(this%is_forward)then
          ptr => matmul(this%left_operand, upstream_grad)
       else
          ptr => matmul(transpose(this%left_operand), upstream_grad)
       end if
    elseif(size(upstream_grad%shape).eq.2)then
       if(this%is_forward)then
          ptr => matmul(this%left_operand, upstream_grad)
       else
          ptr => matmul(this%left_operand, transpose(upstream_grad))
       end if
    else
       ptr => this%left_operand .outer. upstream_grad
    end if
    this%left_operand%is_temporary = left_is_temporary_local
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_matmul_right
!-------------------------------------------------------------------------------
#ifdef USE_BLAS
  subroutine get_partial_matmul_left_val(this, upstream_grad, output)
#else
  pure subroutine get_partial_matmul_left_val(this, upstream_grad, output)
#endif
    !! Compute gradient w.r.t. left operand for matmul in reverse mode.
    !!
    !! For C = A * B (where A is the left operand):
    !!   dL/dA = dL/dC * B^T  (2D case, equivalent to B * dL/dC in flat storage)
    !!   dL/dA = upstream_grad (x) right_operand  (outer product case)
    !!
    !! The procedure is pure only when USE_BLAS is not defined, because the
    !! BLAS interfaces are not declared pure.
    implicit none
    class(array_type), intent(in) :: this
    !! Result array of the matmul operation
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, one column per sample
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient values for the left operand, one column per sample

    integer :: j, s, m, n, num_elements, num_batches, num_rhs

    num_batches = size(upstream_grad, 2)
    if(size(this%right_operand%shape).eq.2)then
       m = this%right_operand%shape(1)
       n = this%right_operand%shape(2)
       if(this%right_operand%is_sample_dependent)then
          ! Per-sample: weight matrix varies per sample.
          ! output(:,s) = W_s * upstream_grad(:,s)
          !    where W_s = reshape(right%val(:,s), [m,n])
          ! Equivalent to dL/dA_flat = dL/dC * B_s^T per sample.
          ! Uses intrinsic matmul - BLAS overhead not worthwhile for
          ! per-sample vectors.
          block
            real(real32), dimension(m, n) :: temp
            do s = 1, num_batches
               temp = reshape(this%right_operand%val(:,s), [m, n])
               output(:,s) = matmul(upstream_grad(:,s), transpose(temp))
            end do
          end block
       else
          ! Non-sample-dependent: single batch operation across all samples.
          ! output(m, S) = W(m, n) * upstream_grad(n, S)
          ! Mathematically: dL/dA = dL/dC * B^T for each sample, batched.
#ifdef USE_BLAS
          block
            real(real32), pointer :: W(:,:)
            ! Pointer view avoids reshape: right_operand%val(:,1) is
            ! column-major contiguous (m, n) data. No copy needed.
            W(1:m, 1:n) => this%right_operand%val(:,1)
            ! sgemm: C = alpha * A * B + beta * C
            ! A = W(m, n), B = upstream_grad(n, S), C = output(m, S)
            ! m_arg = m (rows of W and output)
            ! n_arg = num_batches (columns of upstream_grad and output)
            ! k = n (columns of W / rows of upstream_grad)
            ! lda = m (leading dim of W), ldb = n, ldc = m
            call sgemm('N', 'N', m, num_batches, n, &
                 1.0_real32, W, m, upstream_grad, n, &
                 0.0_real32, output, m)
          end block
#else
          block
            real(real32), dimension(m, n) :: temp
            temp = reshape(this%right_operand%val(:,1), [m, n])
            output = matmul(temp, upstream_grad)
          end block
#endif
       end if
    else
       ! Outer product case:
       ! output(i + (j-1)*num_el, s) = grad(i,s) * right(j,s)
       num_elements = size(upstream_grad,1)
       num_rhs = size(this%right_operand%val,1)
       if(this%right_operand%is_sample_dependent)then
          do concurrent(s = 1:num_batches, j = 1:num_rhs)
             output((j-1)*num_elements+1:j*num_elements, s) = &
                  this%right_operand%val(j,s) * upstream_grad(:,s)
          end do
       else
          do j = 1, num_rhs
             output((j-1)*num_elements+1:j*num_elements, :) = &
                  this%right_operand%val(j, 1) * upstream_grad
          end do
       end if
    end if
  end subroutine get_partial_matmul_left_val
!-------------------------------------------------------------------------------
#ifdef USE_BLAS
  subroutine get_partial_matmul_right_val(this, upstream_grad, output)
#else
  pure subroutine get_partial_matmul_right_val(this, upstream_grad, output)
#endif
    !! Compute gradient w.r.t. right operand for matmul in reverse mode.
    !!
    !! For C = A * B (where B is the right operand):
    !!   dL/dB = A^T * dL/dC  (2D case)
    !!   dL/dB = left_operand (x) upstream_grad  (outer product case)
    !!
    !! The procedure is pure only when USE_BLAS is not defined, because the
    !! BLAS interfaces are not declared pure.
    implicit none
    class(array_type), intent(in) :: this
    !! Result array of the matmul operation
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, one column per sample
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient values for the right operand, one column per sample

    integer :: j, s, m, n, num_elements, num_batches, num_upstream

    num_batches = size(upstream_grad, 2)
    if(size(this%left_operand%shape).eq.2)then
       m = this%left_operand%shape(1)
       n = this%left_operand%shape(2)
       if(this%left_operand%is_sample_dependent)then
          ! Per-sample: weight matrix varies per sample.
          ! output(:,s) = W_s^T * upstream_grad(:,s)
          !    where W_s = reshape(left%val(:,s), [m, n])
          ! Mathematically: dL/dB_flat = A_s^T * dL/dC per sample.
          ! Uses intrinsic matmul - BLAS overhead not worthwhile for
          ! per-sample vectors.
          block
            real(real32), dimension(m, n) :: temp
            do s = 1, num_batches
               temp = reshape(this%left_operand%val(:,s), [m, n])
               output(:,s) = matmul(transpose(temp), upstream_grad(:,s))
            end do
          end block
       else
          ! Non-sample-dependent: single batch operation across all samples.
          ! output(n, S) = W^T(n, m) * upstream_grad(m, S)
          ! Mathematically: dL/dB = A^T * dL/dC for each sample, batched.
#ifdef USE_BLAS
          block
            real(real32), pointer :: W(:,:)
            ! Pointer view avoids reshape and transpose: left_operand%val(:,1)
            ! is column-major contiguous (m, n) data. SGEMM transposes
            ! in-place.
            W(1:m, 1:n) => this%left_operand%val(:,1)
            ! sgemm: C = alpha * A^T * B + beta * C
            ! A = W(m, n), transposed to W^T(n, m); B = upstream_grad(m, S)
            ! m_arg = n (rows of W^T and output)
            ! n_arg = num_batches (columns of upstream_grad and output)
            ! k = m (columns of W^T / rows of upstream_grad)
            ! lda = m (leading dim of W before transpose), ldb = m, ldc = n
            call sgemm('T', 'N', n, num_batches, m, &
                 1.0_real32, W, m, upstream_grad, m, &
                 0.0_real32, output, n)
          end block
#else
          block
            real(real32), dimension(n, m) :: temp_t
            temp_t = transpose(reshape(this%left_operand%val(:,1), [m, n]))
            output = matmul(temp_t, upstream_grad)
          end block
#endif
       end if
    else
       ! Outer product case:
       ! output(i + (j-1)*num_el, s) = left(i,s) * grad(j,s)
       num_elements = size(this%left_operand%val,1)
       num_upstream = size(upstream_grad, 1)
       if(this%left_operand%is_sample_dependent)then
          do concurrent(s = 1:num_batches, j = 1:num_upstream)
             output((j-1)*num_elements+1:j*num_elements, s) = &
                  upstream_grad(j,s) * this%left_operand%val(:,s)
          end do
       else
          do concurrent(s=1:num_batches, j=1:num_upstream)
             output((j-1)*num_elements+1:j*num_elements, s) = &
                  this%left_operand%val(:, 1) * upstream_grad(j, s)
          end do
       end if
    end if
  end subroutine get_partial_matmul_right_val
!###############################################################################


!###############################################################################
! Sum-reduced gradient functions for matmul
! These compute sum(partial(upstream_grad), dim=2) directly, avoiding
! the large (n_elem, num_samples) intermediate array allocation.
!###############################################################################
#ifdef USE_BLAS
  subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)
#else
  pure subroutine get_partial_matmul_left_val_sum(this, upstream_grad, output)
#endif
    !! Sum-reduced gradient w.r.t. left operand for matmul.
    !! For the outer product case (rank-1 right operand), this computes:
    !!   output = sum_s(upstream(:,s) (x) right(:,s)) = upstream * right^T
    !! using a single SGEMM call instead of computing the full (n_elem, S)
    !! array (the SGEMM path is only taken when USE_BLAS is defined).
    implicit none
    class(array_type), intent(in) :: this
    !! Result array of the matmul operation
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, one column per sample
    real(real32), dimension(:), intent(out) :: output
    !! Gradient values for the left operand, summed over samples

    integer :: i, j, s, m, n, num_elements, num_batches, num_rhs

    num_batches = size(upstream_grad, 2)
    if(size(this%right_operand%shape).eq.2)then
       ! 2D case: output = sum_s(W * upstream(:,s)) = W * sum(upstream, dim=2)
       m = this%right_operand%shape(1)
       n = this%right_operand%shape(2)
       if(this%right_operand%is_sample_dependent)then
          ! Per-sample weight matrices: fall back to explicit loop
          block
            real(real32), dimension(m, n) :: temp
            real(real32), dimension(m) :: row_result
            output = 0.0_real32
            do s = 1, num_batches
               temp = reshape(this%right_operand%val(:,s), [m, n])
               row_result = matmul(upstream_grad(:,s), transpose(temp))
               output = output + row_result
            end do
          end block
       else
          ! Non-sample-dependent: W * sum(upstream, dim=2) via matvec
#ifdef USE_BLAS
          block
            real(real32), pointer :: W(:,:)
            real(real32), dimension(size(upstream_grad,1)) :: upstream_sum
            W(1:m, 1:n) => this%right_operand%val(:,1)
            upstream_sum = sum(upstream_grad, dim=2)
            call sgemv('N', m, n, 1.0_real32, W, m, upstream_sum, 1, &
                 0.0_real32, output, 1)
          end block
#else
          block
            real(real32), dimension(m, n) :: temp
            real(real32), dimension(size(upstream_grad,1)) :: upstream_sum
            temp = reshape(this%right_operand%val(:,1), [m, n])
            upstream_sum = sum(upstream_grad, dim=2)
            output = matmul(temp, upstream_sum)
          end block
#endif
       end if
    else
       ! Outer product case: sum_s(upstream(i,s) * right(j,s))
       ! = matmul(upstream, right^T) stored column-major
       num_elements = size(upstream_grad, 1)
       num_rhs = size(this%right_operand%val, 1)
       if(this%right_operand%is_sample_dependent)then
#ifdef USE_BLAS
          ! SGEMM: result(num_elements, num_rhs) = upstream * right^T
          ! The rank-1 output is sequence associated with the rank-2 C dummy.
          call sgemm('N', 'T', num_elements, num_rhs, num_batches, &
               1.0_real32, upstream_grad, num_elements, &
               this%right_operand%val, num_rhs, &
               0.0_real32, output, num_elements)
#else
          ! Manual fused outer product + sum
          output = 0.0_real32
          do s = 1, num_batches
             do j = 1, num_rhs
                do i = 1, num_elements
                   output((j-1)*num_elements + i) = &
                        output((j-1)*num_elements + i) + &
                        this%right_operand%val(j,s) * upstream_grad(i,s)
                end do
             end do
          end do
#endif
       else
          ! Right operand not sample-dependent: outer product with scalar right
          ! sum_s(right(j,1) * upstream(i,s)) = right(j,1) * sum_s(upstream(i,s))
          block
            real(real32), dimension(num_elements) :: upstream_sum
            upstream_sum = sum(upstream_grad, dim=2)
            do j = 1, num_rhs
               output((j-1)*num_elements+1:j*num_elements) = &
                    this%right_operand%val(j, 1) * upstream_sum
            end do
          end block
       end if
    end if
  end subroutine get_partial_matmul_left_val_sum
!###############################################################################


!###############################################################################
#ifdef USE_BLAS
  subroutine get_partial_matmul_right_val_sum(this, upstream_grad, output)
#else
  pure subroutine get_partial_matmul_right_val_sum(this, upstream_grad, output)
#endif
    !! Sum-reduced gradient w.r.t. right operand for matmul.
    !! For the outer product case (rank-1 left operand), this computes:
    !!   output = sum_s(left(:,s) (x) upstream(:,s)) = left * upstream^T
    !! using a single SGEMM call (only when USE_BLAS is defined).
    implicit none
    class(array_type), intent(in) :: this
    !! Result array of the matmul operation
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, one column per sample
    real(real32), dimension(:), intent(out) :: output
    !! Gradient values for the right operand, summed over samples

    integer :: i, j, s, m, n, num_elements, num_batches, num_upstream

    num_batches = size(upstream_grad, 2)
    if(size(this%left_operand%shape).eq.2)then
       ! 2D case:
       ! output = sum_s(W^T * upstream(:,s)) = W^T * sum(upstream, dim=2)
       m = this%left_operand%shape(1)
       n = this%left_operand%shape(2)
       if(this%left_operand%is_sample_dependent)then
          block
            real(real32), dimension(m, n) :: temp
            real(real32), dimension(n) :: col_result
            output = 0.0_real32
            do s = 1, num_batches
               temp = reshape(this%left_operand%val(:,s), [m, n])
               col_result = matmul(transpose(temp), upstream_grad(:,s))
               output = output + col_result
            end do
          end block
       else
#ifdef USE_BLAS
          block
            real(real32), pointer :: W(:,:)
            real(real32), dimension(size(upstream_grad,1)) :: upstream_sum
            W(1:m, 1:n) => this%left_operand%val(:,1)
            upstream_sum = sum(upstream_grad, dim=2)
            ! W^T * upstream_sum: sgemv with 'T'
            call sgemv('T', m, n, 1.0_real32, W, m, upstream_sum, 1, &
                 0.0_real32, output, 1)
          end block
#else
          block
            real(real32), dimension(n, m) :: temp_t
            real(real32), dimension(size(upstream_grad,1)) :: upstream_sum
            temp_t = transpose(reshape(this%left_operand%val(:,1), [m, n]))
            upstream_sum = sum(upstream_grad, dim=2)
            output = matmul(temp_t, upstream_sum)
          end block
#endif
       end if
    else
       ! Outer product case: sum_s(left(i,s) * upstream(j,s))
       ! = matmul(left, upstream^T) stored column-major
       num_elements = size(this%left_operand%val, 1)
       num_upstream = size(upstream_grad, 1)
       if(this%left_operand%is_sample_dependent)then
#ifdef USE_BLAS
          ! The rank-1 output is sequence associated with the rank-2 C dummy.
          call sgemm('N', 'T', num_elements, num_upstream, num_batches, &
               1.0_real32, this%left_operand%val, num_elements, &
               upstream_grad, num_upstream, &
               0.0_real32, output, num_elements)
#else
          output = 0.0_real32
          do s = 1, num_batches
             do j = 1, num_upstream
                do i = 1, num_elements
                   output((j-1)*num_elements + i) = &
                        output((j-1)*num_elements + i) + &
                        upstream_grad(j,s) * this%left_operand%val(i,s)
                end do
             end do
          end do
#endif
       else
          ! Left operand not sample-dependent:
          ! sum_s(left(i,1) * upstream(j,s)) = left(i,1) * sum_s(upstream(j,s))
          block
            real(real32), dimension(num_upstream) :: upstream_sum
            upstream_sum = sum(upstream_grad, dim=2)
            do j = 1, num_upstream
               output((j-1)*num_elements+1:j*num_elements) = &
                    this%left_operand%val(:, 1) * upstream_sum(j)
            end do
          end block
       end if
    end if
  end subroutine get_partial_matmul_right_val_sum
!###############################################################################


!###############################################################################
  module function outer_product_arrays(a, b) result(c)
    !! Outer product of two autodiff arrays
    !!
    !! Both operands must be rank 1 and their value arrays must have the same
    !! extent along the second dimension; violations are reported through
    !! stop_program.
    implicit none
    class(array_type), intent(in), target :: a, b
    !! Left and right operands
    type(array_type), pointer :: c
    !! Result, stored per column as c(i + (j-1)*size(a%val,1)) = a(i) * b(j)

    integer :: i, j, s

    ! check shapes
    if(size(a%shape).ne.1 .or. size(b%shape).ne.1)then
       call stop_program("outer_product_arrays: only 1D arrays supported")
    elseif(size(a%val,2).ne.size(b%val,2))then
       call stop_program("outer_product_arrays: array length mismatch")
    end if

    c => a%create_result( &
         array_shape = [size(a%val,1), size(b%val,1), size(a%val,2)] &
    )
    ! outer product 1D array by using shape to swap dimensions
    do concurrent(s=1:size(a%val,2))
       do concurrent(i=1:size(a%val,1), j=1:size(b%val,1))
          c%val(i + (j-1)*size(a%val,1),s) = a%val(i,s) * b%val(j,s)
       end do
    end do

    c%get_partial_left => get_partial_outer_product_left
    c%get_partial_right => get_partial_outer_product_right
    c%is_sample_dependent = a%is_sample_dependent
    if(a%requires_grad .or. b%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward .or. b%is_forward
       c%operation = 'outer_product'
       c%left_operand => a
       c%right_operand => b
       c%owns_left_operand = a%is_temporary
       c%owns_right_operand = b%is_temporary
    end if
  end function outer_product_arrays
!-------------------------------------------------------------------------------
  function get_partial_outer_product_left(this, upstream_grad) result(output)
    !! Get partial derivative with respect to left operand of outer product
    !!
    !! Forward mode: upstream_grad .outer. right_operand.
    !! Otherwise:    matmul(upstream_grad, right_operand).
    !! The is_temporary flag of the right operand is cleared while the
    !! expression is built and restored afterwards.
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the outer product operation
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Partial derivative with respect to the left operand

    logical :: right_is_temporary_local
    type(array_type), pointer :: ptr

    right_is_temporary_local = this%right_operand%is_temporary
    this%right_operand%is_temporary = .false.
    if(this%is_forward)then
       ptr => upstream_grad .outer. this%right_operand
    else
       ptr => matmul(upstream_grad, this%right_operand)
    end if
    this%right_operand%is_temporary = right_is_temporary_local
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_outer_product_left
!-------------------------------------------------------------------------------
  function get_partial_outer_product_right(this, upstream_grad) result(output)
    !! Get partial derivative with respect to right operand of outer product
    !!
    !! Forward mode: left_operand .outer. upstream_grad.
    !! Otherwise:    matmul(left_operand, upstream_grad).
    !! The is_temporary flag of the left operand is cleared while the
    !! expression is built and restored afterwards.
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the outer product operation
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Partial derivative with respect to the right operand

    logical :: left_is_temporary_local
    type(array_type), pointer :: ptr

    left_is_temporary_local = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.
    if(this%is_forward)then
       ptr => this%left_operand .outer. upstream_grad
    else
       ! mathematically should be
       !    ptr => transpose(upstream_grad) .mmul. this%left_operand
       ! but for how we store vectors, this SHOULD BE equivalent
       ptr => matmul(this%left_operand, upstream_grad)
    end if
    this%left_operand%is_temporary = left_is_temporary_local
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_outer_product_right
!###############################################################################


!###############################################################################
  module function dot_product_arrays(a, b) result(c)
    !! Dot product of two autodiff arrays
    !!
    !! Both operands must be rank 1 with identically shaped value arrays;
    !! violations are reported through stop_program. One dot product is
    !! evaluated per column of the value arrays.
    implicit none
    class(array_type), intent(in), target :: a, b
    !! Left and right operands
    type(array_type), pointer :: c
    !! Result holding a single value per column

    integer :: s

    ! check shapes
    if(size(a%shape).ne.1 .or. size(b%shape).ne.1)then
       call stop_program("dot_product_arrays: only 1D arrays supported")
    elseif(any(shape(a%val).ne.shape(b%val)))then
       call stop_program("dot_product_arrays: array length mismatch")
    end if

    c => a%create_result(array_shape = [1, size(a%val,2)])
    do concurrent(s=1:size(a%val,2))
       c%val(1,s) = dot_product(a%val(:,s), b%val(:,s))
    end do

    c%get_partial_left => get_partial_dot_product_left
    c%get_partial_right => get_partial_dot_product_right
    c%get_partial_left_val => get_partial_dot_product_left_val
    c%get_partial_right_val => get_partial_dot_product_right_val
    c%is_sample_dependent = a%is_sample_dependent
    if(a%requires_grad .or. b%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward .or. b%is_forward
       c%operation = 'dot_product'
       c%left_operand => a
       c%right_operand => b
       c%owns_left_operand = a%is_temporary
       c%owns_right_operand = b%is_temporary
    end if
  end function dot_product_arrays
!-------------------------------------------------------------------------------
  function get_partial_dot_product_left(this, upstream_grad) result(output)
    !! Get partial derivative with respect to left operand of dot product
    !!
    !! Forward mode: dot_product(upstream_grad, right_operand).
    !! Otherwise:    upstream_grad * right_operand.
    !! The is_temporary flag of the right operand is cleared while the
    !! expression is built and restored afterwards.
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the dot product operation
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Partial derivative with respect to the left operand

    logical :: right_is_temporary_local
    type(array_type), pointer :: ptr

    right_is_temporary_local = this%right_operand%is_temporary
    this%right_operand%is_temporary = .false.
    if(this%is_forward)then
       ptr => dot_product(upstream_grad, this%right_operand)
    else
       ptr => upstream_grad * this%right_operand
    end if
    this%right_operand%is_temporary = right_is_temporary_local
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_dot_product_left
!-------------------------------------------------------------------------------
  function get_partial_dot_product_right(this, upstream_grad) result(output)
    !! Get partial derivative with respect to right operand of dot product
    !!
    !! Forward mode: dot_product(left_operand, upstream_grad).
    !! Otherwise:    upstream_grad * left_operand.
    !! The is_temporary flag of the left operand is cleared while the
    !! expression is built and restored afterwards.
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the dot product operation
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Partial derivative with respect to the right operand

    logical :: left_is_temporary_local
    type(array_type), pointer :: ptr

    left_is_temporary_local = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.
    if(this%is_forward)then
       ptr => dot_product(this%left_operand, upstream_grad)
    else
       ptr => upstream_grad * this%left_operand
    end if
    this%left_operand%is_temporary = left_is_temporary_local
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_dot_product_right
!-------------------------------------------------------------------------------
  pure subroutine get_partial_dot_product_left_val(this, upstream_grad, output)
    !! Compute gradient values w.r.t. left operand of dot product:
    !!   output(i,s) = upstream_grad(1,s) * right_operand%val(i,s)
    implicit none
    class(array_type), intent(in) :: this
    !! Result array of the dot product operation
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values; only row 1 is read
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient values for the left operand

    integer :: i, s

    do concurrent(s=1:size(upstream_grad,2), i=1:size(this%right_operand%val,1))
       output(i,s) = upstream_grad(1,s) * this%right_operand%val(i,s)
    end do
  end subroutine get_partial_dot_product_left_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_dot_product_right_val(this, upstream_grad, output)
    !! Compute gradient values w.r.t. right operand of dot product:
    !!   output(i,s) = upstream_grad(1,s) * left_operand%val(i,s)
    implicit none
    class(array_type), intent(in) :: this
    !! Result array of the dot product operation
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values; only row 1 is read
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient values for the right operand

    integer :: i, s

    do concurrent(s=1:size(upstream_grad,2), i=1:size(this%left_operand%val,1))
       output(i,s) = upstream_grad(1,s) * this%left_operand%val(i,s)
    end do
  end subroutine get_partial_dot_product_right_val
!###############################################################################


!###############################################################################
  module function transpose_array(a) result(c)
    !! Transpose an autodiff array
    !!
    !! Only rank-2 arrays are supported; other ranks are reported through
    !! stop_program after the offending shape is printed to standard output.
    implicit none
    class(array_type), intent(in), target :: a
    !! Array to transpose
    type(array_type), pointer :: c
    !! Transposed array with the first two shape entries swapped

    integer :: i, j, s

    if(size(a%shape) .ne. 2)then
       write(*,*) "ashape", a%shape
       call stop_program("transpose_array: only 2D arrays can be transposed")
    end if

    c => a%create_result(array_shape=[a%shape(2), a%shape(1), size(a%val,2)])
    ! transpose 1D array by using shape to swap dimensions
    do concurrent(s=1:size(a%val,2))
       do concurrent(i=1:a%shape(1), j=1:a%shape(2))
          c%val( (i-1)*a%shape(2) + j, s) = a%val( (j-1)*a%shape(1) + i, s)
       end do
    end do

    c%get_partial_left => get_partial_transpose_left
    ! c%get_partial_right => get_partial_transpose_right
    c%is_sample_dependent = a%is_sample_dependent
    if(a%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward
       c%operation = 'transpose'
       c%left_operand => a
       c%owns_left_operand = a%is_temporary
    end if
  end function transpose_array
!-------------------------------------------------------------------------------
  function get_partial_transpose_left(this, upstream_grad) result(output)
    !! Get partial derivative of transpose: the transposed upstream gradient
    implicit none
    class(array_type), intent(inout) :: this
    !! Result array of the transpose operation (not read here)
    type(array_type), intent(in) :: upstream_grad
    !! Gradient arriving from the operation that consumed this result
    type(array_type) :: output
    !! Transpose of upstream_grad

    type(array_type), pointer :: ptr

    ptr => transpose(upstream_grad)
    call output%assign_and_deallocate_source(ptr)
  end function get_partial_transpose_left
! !-------------------------------------------------------------------------------
! ! function get_partial_transpose_right(this, upstream_grad) result(output)
! !   class(array_type), intent(inout) :: this
! !   type(array_type), intent(in) :: upstream_grad
! !   type(array_type) :: output
! !   output = transpose(this%left_operand)
! ! end function get_partial_transpose_right
!###############################################################################

end submodule diffstruc__operations_linalg_sub