import { useRef } from 'react'
import { useUnmount } from 'react-use'

import { useAuth0 } from '@auth0/auth0-react'
import _ from 'lodash'

import { RequestMessageType } from 'openapi/models/RequestMessageType'
import { ResponseMessageType } from 'openapi/models/ResponseMessageType'
import { SocketMessageResponse } from 'openapi/models/SocketMessageResponse'
import { WorkflowEventStatus } from 'openapi/models/WorkflowEventStatus'
import { Maybe } from 'types'

import { backendWebsocketUrl } from 'utils/server-data'
import { LONG_TOAST_DURATION, displayErrorMessage } from 'utils/toast'

import { KnowledgeSourceItem } from 'components/assistant/utils/assistant-knowledge-sources'
import { useAnalytics } from 'components/common/analytics/analytics-context'

import { Source } from './task'
import {
  HarveySocketCompletionStatus,
  getErrorMessage,
} from './use-harvey-socket-utils'
import { ToFrontendKeys } from './utils'

/**
 * Client-side state for one socket query: the server's `SocketMessageResponse`
 * fields combined with the extra fields below. useHarveySocket delivers
 * updates to this shape as partial objects through a `HarveySocketSetter`.
 */
export type HarveySocketTask = SocketMessageResponse & {
  // Query identity. `queryId` is replaced by the id the server returns on
  // authentication; `messageId` comes from RESPONSE messages.
  query: string
  queryId: string
  messageId: string

  // Lifecycle state maintained by the hook.
  isLoading: boolean
  progress: number
  completionStatus?: HarveySocketCompletionStatus

  // Populated from RESPONSE messages.
  sources: Source[]
  knowledgeSources: KnowledgeSourceItem[]
  workflow: WorkflowEventStatus

  // Not written by useHarveySocket itself; presumably set by callers — verify.
  dateRange: Maybe<{ from: Date; to: Date }>
}

/** Receives partial task updates; the caller decides how to merge them. */
export type HarveySocketSetter = (task: Partial<HarveySocketTask>) => void

/** Starts a new socket query with the given options. */
export type InitSocketAndSendQuery = (
  params: InitSocketAndSendQueryParams
) => void

/** Cancels the query currently in flight, if there is one. */
export type SendCancelRequest = () => void

/** Options accepted by the useHarveySocket hook. */
interface useHarveySocketProps {
  /** Endpoint path appended to `backendWebsocketUrl` when opening a socket. */
  path: string

  /** Receives partial task updates as a query progresses. */
  setter: HarveySocketSetter

  /**
   * Creates the socket and attaches `handlers` to it.
   * The hook defaults this to `window.WebSocketFactory`.
   */
  webSocketFactory?: (
    url: string,
    handlers: WebSocketEventHandlers
  ) => WebSocket

  /** Called with the query id and final status when a query ends. */
  endCallback?: (
    queryId: string,
    completionStatus: HarveySocketCompletionStatus
  ) => void

  /** Close the active socket when the component unmounts (hook default: true). */
  closeOnUnmount?: boolean
}

/**
 * Per-query options for `initSocketAndSendQuery`.
 *
 * Every record* callback receives `recordFields` merged with the latest
 * server-sent analytics, plus `event_id` (the query id) and `query_length`.
 */
export interface InitSocketAndSendQueryParams {
  /** Text sent as the `data` field of the REQUEST message. */
  query: string

  /** Called with the server-assigned query id once the socket authenticates. */
  onAuthCallback?: (queryId: string) => void

  /**
   * Extra fields spread into the AUTHENTICATE message. When
   * `additionalRequestParams.request_type` is `'retry'`, `event_id` is also
   * reused as the query id.
   */
  additionalAuthParams?: {
    [key: string]: any
  }

  /** Extra fields spread into the REQUEST message. */
  additionalRequestParams?: {
    task_type?: string
    source_event_id?: Maybe<number> | string
    documents?: any[]
    [key: string]: any
  }

  /**
   * Upper bound on automatic reconnects when the socket closes before the
   * query completes (falls back to DEFAULT_RETRY_COUNT).
   */
  maxRetryCount?: number

  /** Extra analytics fields included in every record* payload. */
  recordFields?: Record<string, string | number | string[]>

  /** Called when a query (or an automatic retry of it) is started. */
  recordQuerySubmitted?: (
    fields?: Record<string, string | number | string[]>
  ) => void

  /** Called when the socket closes after the server reported completion. */
  recordQueryCompletion?: (
    fields?: Record<string, string | number | string[]>
  ) => void

  /** Called when the socket closes after the query was cancelled. */
  recordQueryCancel?: (
    fields?: Record<string, string | number | string[]>
  ) => void

  /**
   * Called when the query fails without a server ERROR message and no retries
   * remain.
   */
  recordQueryError?: (
    fields?: Record<string, string | number | string[]>
  ) => void
}

/** Automatic reconnect attempts used when a query omits `maxRetryCount`. */
const DEFAULT_RETRY_COUNT = 0

/**
 * React hook that runs streaming queries over a backend websocket.
 *
 * Each query opens a socket at `${backendWebsocketUrl}/${path}`, sends an
 * AUTHENTICATE message carrying an Auth0 access token and, once the server
 * replies AUTHENTICATED, sends the REQUEST. Server messages are forwarded to
 * `setter` as partial task updates. When the socket closes, the query is
 * finalised as completed, cancelled, retried (up to `maxRetryCount` times) or
 * failed.
 *
 * Only one query is tracked at a time. Starting a new query makes the hook
 * ignore further events from the previous socket; that socket is not closed.
 *
 * @param props - See `useHarveySocketProps`.
 * @returns `initSocketAndSendQuery` to start a query and `sendCancelRequest`
 *   to cancel the one in flight.
 */
const useHarveySocket = ({
  path,
  setter,
  endCallback = () => {},
  closeOnUnmount = true,
  webSocketFactory = window.WebSocketFactory,
}: useHarveySocketProps) => {
  const { getAccessTokenSilently } = useAuth0()
  const { trackEvent } = useAnalytics()

  // Socket of the query in flight; null when idle. Handlers compare against it
  // to detect that their own socket has been superseded.
  const socket = useRef<WebSocket | null>(null)
  // Automatic retries already made for the current query.
  const errorCount = useRef(0)
  const queryId = useRef('')
  // Set by a COMPLETE message or by sendCancelRequest.
  const isComplete = useRef(false)
  const isCancelled = useRef(false)
  // Set after an ERROR message, whose toast has already been displayed.
  const caughtError = useRef(false)
  // Set when the unmount cleanup closed the socket, so onclose will not retry.
  const closedOnUnmount = useRef(false)
  // Latest ANALYTICS payload, merged into the record* callbacks.
  const socketAnalytics = useRef<Record<string, string | number>>({})

  /**
   * Sends the AUTHENTICATE message on `ws` once an access token is available.
   * Does nothing if `ws` closed or was superseded while the token was fetched.
   * Rejects if the token cannot be obtained.
   */
  const authenticateSocket = async (ws: WebSocket, authParams?: object) => {
    const accessToken = await getAccessTokenSilently()

    if (socket.current !== ws || ws.readyState !== WebSocket.OPEN) return

    ws.send(
      JSON.stringify({
        type: RequestMessageType.AUTHENTICATE,
        data: accessToken,
        ...authParams,
      })
    )
  }

  /**
   * Runs one attempt of a query on a fresh socket. Resets all per-query state
   * except `errorCount`, so automatic retries can re-run it with the same
   * params without losing count of the retries already made.
   */
  const startAttempt = (params: InitSocketAndSendQueryParams) => {
    const {
      query,
      onAuthCallback,
      additionalAuthParams,
      additionalRequestParams,
      maxRetryCount,
      recordFields,
      recordQuerySubmitted,
      recordQueryCompletion,
      recordQueryCancel,
      recordQueryError,
    } = params
    const maxRetries = maxRetryCount ?? DEFAULT_RETRY_COUNT

    queryId.current = ''
    isComplete.current = false
    isCancelled.current = false
    caughtError.current = false

    // Payload for the record* callbacks, built on demand so it reflects the
    // latest server analytics and query id at the time of the call.
    const analyticsFields = () => ({
      ...recordFields,
      ...socketAnalytics.current,
      event_id: queryId.current,
      query_length: query.length,
    })

    if (
      additionalRequestParams?.request_type === 'retry' &&
      additionalAuthParams?.event_id
    ) {
      // Re-running an existing event: only its id and loading state are set.
      const eventId = additionalAuthParams.event_id as string
      queryId.current = eventId
      setter({
        queryId: eventId,
        isLoading: true,
        completionStatus: HarveySocketCompletionStatus.Loading,
      })
    } else {
      setter({
        query,
        response: '',
        sources: [],
        isLoading: true,
        completionStatus: HarveySocketCompletionStatus.Loading,
        progress: 0,
      })
    }

    recordQuerySubmitted?.(analyticsFields())

    // Assigned as soon as the factory returns; socket events only fire later.
    let ws: WebSocket | null = null
    // False once a newer query or retry has replaced this socket, so late
    // events from it cannot overwrite the state of the query that replaced it.
    const isCurrentSocket = () => ws !== null && socket.current === ws

    const eventHandlers: WebSocketEventHandlers = {
      onopen: () => {
        const openedSocket = ws
        if (openedSocket === null) return
        void authenticateSocket(openedSocket, additionalAuthParams).catch(
          () => {
            // Without a token the request can never be sent. Closing routes
            // the query through onclose's retry/error handling instead of
            // leaving it loading forever.
            openedSocket.close()
          }
        )
      },
      onmessage: (event) => {
        if (!isCurrentSocket()) return
        const data = JSON.parse(event.data)

        if (data.type === ResponseMessageType.AUTHENTICATED) {
          const authenticatedId = String(data.data)
          setter({
            queryId: authenticatedId,
          })
          queryId.current = authenticatedId
          onAuthCallback?.(authenticatedId)

          ws?.send(
            JSON.stringify({
              type: RequestMessageType.REQUEST,
              data: query,
              ...additionalRequestParams,
            })
          )
        }

        if (data.type === ResponseMessageType.RESPONSE) {
          // keeping the same for now, since certain workflows don't have all required keys
          setter({
            queryId: queryId.current,
            response: data.response ?? data.data,
            headerText: data.header_text ?? '',
            sources: data.sources?.length ? ToFrontendKeys(data.sources) : [],
            relatedQuestions: data.related_questions,
            annotations: data.annotations
              ? ToFrontendKeys(data.annotations)
              : {},
            metadata: ToFrontendKeys(data.metadata),
            caption: data.caption,
            progress: data.progress,
            messageId: data.message_id,
            knowledgeSources: data.knowledge_sources
              ? ToFrontendKeys(data.knowledge_sources)
              : [],
            workflow: data.data?.workflow,
          })
        }

        if (data.type === ResponseMessageType.ERROR) {
          // Flag before invoking callbacks, so a query started from
          // endCallback (which resets the flag) is not left with a stale value.
          caughtError.current = true
          setter({
            queryId: queryId.current,
            isLoading: false,
            completionStatus: HarveySocketCompletionStatus.Error,
          })
          endCallback(queryId.current, HarveySocketCompletionStatus.Error)
          const tooManyRequests = 'Too many requests'
          // _.includes does a substring check on strings (like indexOf) but
          // does not throw when the payload is missing or not a string.
          if (_.includes(data.data, tooManyRequests)) {
            displayErrorMessage(
              'Too many requests sent to Harvey. Please try again in a minute.',
              LONG_TOAST_DURATION
            )
            trackEvent('Too Many Requests Message Displayed', {
              status_code: data.status_code,
            })
          } else if (data.status_code === 413) {
            // need to pass in better errors from backend, this isn't the format for all errors passed.
            //cleanup as part of https://www.notion.so/harveyai/Clean-up-ErrorKind-and-clarify-the-backend-frontend-error-philosophy-05fa4da32485462ead6eb87de512812e?pvs=4
            displayErrorMessage(
              'Harvey was unable to process this request due to the combined length of the input and output, please reduce the length of your input.',
              LONG_TOAST_DURATION
            )
          } else {
            const errorMessage = getErrorMessage(data.data, data.status_code)
            displayErrorMessage(errorMessage, LONG_TOAST_DURATION)
          }
        }

        if (data.type === ResponseMessageType.ANALYTICS) {
          socketAnalytics.current = data.data || {}
        }

        if (data.type === ResponseMessageType.COMPLETE) {
          setter({
            queryId: queryId.current,
            completionStatus: HarveySocketCompletionStatus.Completed,
          })
          isComplete.current = true
        }
      },
      onclose: () => {
        // A superseded socket must not touch the shared refs; they now belong
        // to the newer query (e.g. cancel followed by an immediate resubmit).
        if (!isCurrentSocket()) return
        socket.current = null

        if (isComplete.current) {
          if (isCancelled.current) {
            // sendCancelRequest already updated the task and ran endCallback.
            recordQueryCancel?.(analyticsFields())
          } else {
            setter({
              queryId: queryId.current,
              isLoading: false,
            })
            endCallback(queryId.current, HarveySocketCompletionStatus.Completed)
            recordQueryCompletion?.(analyticsFields())
          }
        } else if (
          !closedOnUnmount.current &&
          errorCount.current < maxRetries
        ) {
          errorCount.current += 1
          // Retry with the caller's full params (including onAuthCallback and
          // maxRetryCount) without resetting the retry counter.
          startAttempt(params)
        } else {
          errorCount.current = 0

          if (caughtError.current) {
            // The ERROR message already displayed a specific toast.
            caughtError.current = false
          } else {
            displayErrorMessage(
              'Sorry, something went wrong. Please refresh the page.',
              LONG_TOAST_DURATION
            )
            recordQueryError?.(analyticsFields())
          }
          // NOTE: after an ERROR message endCallback has already run once with
          // Error; it is called again here.
          setter({
            queryId: queryId.current,
            isLoading: false,
            headerText: '',
            progress: 0,
            completionStatus: HarveySocketCompletionStatus.Error,
          })
          endCallback(queryId.current, HarveySocketCompletionStatus.Error)
        }

        // Reset socketAnalytics after use
        socketAnalytics.current = {}
      },
    }

    const url = `${backendWebsocketUrl}/${path}`
    ws = webSocketFactory(url, eventHandlers)
    socket.current = ws
  }

  /** Starts a brand-new query, resetting the retry counter. */
  const initSocketAndSendQuery = (params: InitSocketAndSendQueryParams) => {
    errorCount.current = 0
    startAttempt(params)
  }

  /**
   * Cancels the query in flight: sends CANCEL if the socket is open, marks the
   * task cancelled, closes the socket and reports Cancelled to endCallback.
   */
  const sendCancelRequest = () => {
    const activeSocket = socket.current
    if (!activeSocket) return
    isCancelled.current = true
    // Marking the query complete makes onclose take the cancel branch rather
    // than retrying or reporting an error.
    isComplete.current = true

    if (activeSocket.readyState === WebSocket.OPEN) {
      activeSocket.send(
        JSON.stringify({
          type: RequestMessageType.CANCEL,
        })
      )
    }
    setter({
      queryId: queryId.current,
      isLoading: false,
      headerText: '',
      caption: '',
      progress: 0,
      completionStatus: HarveySocketCompletionStatus.Cancelled,
    })
    activeSocket.close()
    endCallback(queryId.current, HarveySocketCompletionStatus.Cancelled)
  }

  useUnmount(() => {
    if (closeOnUnmount && socket.current) {
      // Prevent onclose from opening a retry socket after unmount.
      closedOnUnmount.current = true
      socket.current.close()
    }
  })

  return { initSocketAndSendQuery, sendCancelRequest }
}

export default useHarveySocket
