module diffstruc__operations_hyp
  !! This module contains hyperbolic operations for the diffstruc library.
  use coreutils, only: real32
  use diffstruc__types, only: array_type, get_partial, &
       operator(+), operator(-), operator(*), operator(**)
  implicit none


  private

  public :: tanh


  ! Operation interfaces
  !-----------------------------------------------------------------------------
  interface tanh
     !! Extends the generic name `tanh` with an overload for array_type
     !! arguments, implemented by tanh_array.
     module procedure tanh_array
  end interface


contains

!###############################################################################
  function tanh_array(a) result(c)
    !! Element-wise hyperbolic tangent of an autodiff array.
    !!
    !! The result is obtained from a%create_result() and its values are
    !! filled with tanh of the corresponding entries of a%val. The
    !! partial-derivative procedure pointers are always attached; the link
    !! back to `a` is recorded only when `a` requires a gradient.
    implicit none
    class(array_type), intent(in), target :: a
    type(array_type), pointer :: c

    integer :: row, col
    integer :: n_rows, n_cols

    c => a%create_result()

    ! Forward pass: apply tanh to every entry of a%val
    n_rows = size(a%val, 1)
    n_cols = size(a%val, 2)
    do concurrent(col = 1:n_cols, row = 1:n_rows)
       c%val(row, col) = tanh(a%val(row, col))
    end do

    ! Record the operand and operation only when a gradient is needed
    if(a%requires_grad) then
       c%operation = 'tanh'
       c%left_operand => a
       c%owns_left_operand = a%is_temporary
       c%is_forward = a%is_forward
       c%requires_grad = .true.
    end if

    ! Partial-derivative callbacks for this operation
    c%get_partial_left_val_sum => get_partial_tanh_val_sum
    c%get_partial_left_val => get_partial_tanh_val
    c%get_partial_left => get_partial_tanh
  end function tanh_array
!-------------------------------------------------------------------------------
  function get_partial_tanh(this, upstream_grad) result(output)
    !! Partial derivative of tanh as an autodiff expression.
    !!
    !! Builds upstream_grad * tanh_reverse_array(this), where
    !! tanh_reverse_array evaluates 1 - val*val on its argument, and hands
    !! the resulting array over to `output`.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: grad_expr
    logical :: was_temporary

    ! Clear the temporary flag while the expression is built; the caller's
    ! value is put back once the product has been formed
    was_temporary = this%is_temporary
    this%is_temporary = .false.

    ! d/dx tanh(x) = 1 - tanh(x)^2
    grad_expr => upstream_grad * tanh_reverse_array(this)

    this%is_temporary = was_temporary
    call output%assign_and_deallocate_source(grad_expr)
  end function get_partial_tanh
!-------------------------------------------------------------------------------
  pure subroutine get_partial_tanh_val(this, upstream_grad, output)
    !! Raw-value partial derivative of tanh.
    !!
    !! Writes upstream_grad * (1 - this%val * this%val) into `output`,
    !! element by element.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    associate(y => this%val)
       output = upstream_grad * (1._real32 - y * y)
    end associate
  end subroutine get_partial_tanh_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_tanh_val_sum(this, upstream_grad, output)
    !! Fused partial+sum for tanh:
    !!   output(i) = sum over s of upstream_grad(i,s) * (1 - this%val(i,s)^2)
    !!
    !! When upstream_grad has no columns the sum is empty and `output` is
    !! set to zero.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:), intent(out) :: output

    integer :: s, n_samples

    n_samples = size(upstream_grad, 2)

    ! Guard the empty case: column 1 does not exist, so it must not be read
    if(n_samples .lt. 1) then
       output(:) = 0._real32
       return
    end if

    ! Seed the accumulator with the first column, then add the remaining ones
    output(:) = upstream_grad(:,1) * &
         (1._real32 - this%val(:,1) * this%val(:,1))
    do s = 2, n_samples
       output(:) = output(:) + upstream_grad(:,s) * &
            (1._real32 - this%val(:,s) * this%val(:,s))
    end do
  end subroutine get_partial_tanh_val_sum
!###############################################################################


!###############################################################################
  function tanh_reverse_array(a) result(c)
    !! Element-wise 1 - a*a for autodiff arrays.
    !!
    !! Used by get_partial_tanh to form the tanh derivative factor. The
    !! result is obtained from a%create_result(); the link back to `a` is
    !! recorded only when `a` requires a gradient.
    implicit none
    class(array_type), intent(in), target :: a
    type(array_type), pointer :: c

    integer :: row, col
    integer :: n_rows, n_cols

    c => a%create_result()

    ! Forward values: one minus the square of each entry of a%val
    n_rows = size(a%val, 1)
    n_cols = size(a%val, 2)
    do concurrent(col = 1:n_cols, row = 1:n_rows)
       c%val(row, col) = 1._real32 - a%val(row, col) * a%val(row, col)
    end do

    ! Record the operand and operation only when a gradient is needed
    if(a%requires_grad) then
       c%operation = 'tanh_reverse'
       c%left_operand => a
       c%owns_left_operand = a%is_temporary
       c%is_forward = a%is_forward
       c%requires_grad = .true.
    end if

    ! Partial-derivative callbacks for this operation
    c%get_partial_left_val => get_partial_tanh_reverse_val
    c%get_partial_left => get_partial_tanh_reverse
  end function tanh_reverse_array
!-------------------------------------------------------------------------------
  function get_partial_tanh_reverse(this, upstream_grad) result(output)
    !! Partial derivative of 1 - a*a as an autodiff expression.
    !!
    !! Builds (-2) * upstream_grad * this%left_operand and hands the
    !! resulting array over to `output`.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: grad_expr
    logical :: operand_was_temporary

    ! Clear the operand's temporary flag while the expression is built;
    ! its previous value is put back once the product has been formed
    operand_was_temporary = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.

    ! d/da (1 - a^2) = -2 a
    grad_expr => (-2._real32) * upstream_grad * this%left_operand

    this%left_operand%is_temporary = operand_was_temporary
    call output%assign_and_deallocate_source(grad_expr)
  end function get_partial_tanh_reverse
!-------------------------------------------------------------------------------
  pure subroutine get_partial_tanh_reverse_val(this, upstream_grad, output)
    !! Raw-value partial derivative of 1 - a*a.
    !!
    !! Writes (-2) * upstream_grad * this%left_operand%val into `output`,
    !! element by element.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    associate(x => this%left_operand%val)
       output = (-2._real32) * upstream_grad * x
    end associate
  end subroutine get_partial_tanh_reverse_val
!###############################################################################

end module diffstruc__operations_hyp
