import { gql } from '@apollo/client';
import { useCallback, useMemo } from 'react';

import { useCustomQuery } from 'shared/view/hooks/apollo/useCustomQuery';
import { convertTimeRangeToDateRange } from 'shared/utils/TimeRange';
import { toGraphQLDate } from 'shared/utils/graphql/toGraphQLDate';
import { MonitoringMetricType } from 'generated/types';
import { useMemoizedResultToCommunicationWithData } from 'shared/utils/graphql/queryResultToCommunicationWithData';
import { parseGraphQLNumber } from 'shared/utils/graphql/parseGraphQLNumber';
import {
  mapDataOrError,
  RESTRICTED_GRAPHQL_ERROR_FRAGMENT,
} from 'shared/graphql/ErrorFragment';
import { MonitoringWidgetExternalDeps } from 'shared/models/Monitoring/MonitoringModel/MonitoringPanel/MonitoringWidget/MonitoringWidgetExternalDeps';
import { MonitoringIODescription } from 'shared/models/Monitoring/MonitoringModel/MonitoringIODescription';
import { convertMonitoringFilterToGraphQL } from 'shared/models/Monitoring/MonitoringFilters/MonitoringFilter';
import isNotNil from 'shared/utils/isNotNill';
import { ConfusionMatrix } from 'shared/models/Monitoring/ConfusionMatrix';

import {
  ConfusionMatrixQuery,
  ConfusionMatrixQueryVariables,
} from './graphql-types/useConfusionMatrix.generated';

/**
 * GraphQL query: for one monitored entity, fetch the metrics selected by
 * `$confusionMatrixQuery` (a `MonitoringMetricQuery`).
 *
 * Each returned metric carries its `modelVersionId`, the resolved
 * `modelVersion` (either a `RegisteredModelVersion` with `id`/`version`, or an
 * `Error`), its raw `value`, and its metric `type`. The `monitoredEntity`
 * itself may also resolve to an `Error`.
 *
 * `RESTRICTED_GRAPHQL_ERROR_FRAGMENT` is appended so the `...ErrorData`
 * spreads above can resolve. Presumably it defines the `ErrorData` fragment;
 * confirm in shared/graphql/ErrorFragment.
 */
const CONFUSION_MATRIX_QUERY = gql`
  query ConfusionMatrixQuery(
    $monitoredEntityId: ID!
    $confusionMatrixQuery: MonitoringMetricQuery!
  ) {
    monitoredEntity(id: $monitoredEntityId) {
      ... on Error {
        ...ErrorData
      }
      ... on MonitoredEntity {
        id
        metrics {
          metric(query: $confusionMatrixQuery) {
            modelVersionId
            modelVersion {
              ... on Error {
                ...ErrorData
              }
              ... on RegisteredModelVersion {
                id
                version
              }
            }
            value
            type
          }
        }
      }
    }
  }
  ${RESTRICTED_GRAPHQL_ERROR_FRAGMENT}
`;

/** Inputs for {@link useConfusionMatrixQuery}. */
type Props = {
  /** Widget context: entity id, time range, filters and model version ids. */
  widgetExternalDeps: MonitoringWidgetExternalDeps;
  /** Output description forwarded verbatim into the metric query. */
  output: MonitoringIODescription;
};

/**
 * The four confusion-matrix cells requested from the backend, in the order
 * TN, TP, FN, FP. Sent as the `types` field of the metric query.
 */
const metricTypes: MonitoringMetricType[] = [
  MonitoringMetricType.TN,
  MonitoringMetricType.TP,
  MonitoringMetricType.FN,
  MonitoringMetricType.FP,
];

/**
 * Loads TN/TP/FN/FP metrics for the widget's monitored entity and converts
 * them into one `ConfusionMatrix` per id in
 * `widgetExternalDeps.registeredModelVersionIds`.
 *
 * Query variables are built from the widget's time range, filters and the
 * given `output`. They are memoized so the query only re-runs when one of
 * those inputs changes.
 *
 * A model version is dropped from the result unless all four cells are
 * present for it in the response. The model version attached to a matrix is
 * taken from its TN metric.
 */
export const useConfusionMatrixQuery = (props: Props) => {
  const { widgetExternalDeps, output } = props;
  const { registeredModelVersionIds } = widgetExternalDeps;

  const variables = useMemo((): ConfusionMatrixQueryVariables => {
    const { from, to } = convertTimeRangeToDateRange(
      widgetExternalDeps.timeRange
    );
    const startDate = toGraphQLDate(from);
    const endDate = toGraphQLDate(to);
    const graphQLFilters = widgetExternalDeps.filters.map(
      convertMonitoringFilterToGraphQL
    );

    return {
      monitoredEntityId: widgetExternalDeps.monitoredEntityId,
      confusionMatrixQuery: {
        startDate,
        endDate,
        types: metricTypes,
        output,
        filters: graphQLFilters,
      },
    };
  }, [
    output,
    widgetExternalDeps.monitoredEntityId,
    widgetExternalDeps.timeRange,
    widgetExternalDeps.filters,
  ]);

  const queryResult = useCustomQuery<
    ConfusionMatrixQuery,
    ConfusionMatrixQueryVariables
  >(CONFUSION_MATRIX_QUERY, { variables });

  const convert = useCallback(
    (res: ConfusionMatrixQuery) =>
      mapDataOrError(
        res.monitoredEntity,
        (monitoredEntity): ConfusionMatrix[] => {
          const allMetrics = monitoredEntity.metrics.metric;

          return registeredModelVersionIds
            .map((modelVersionId) => {
              const versionMetrics = allMetrics.filter(
                (entry) => entry.modelVersionId === modelVersionId
              );
              // First metric of the given type for this model version, if any.
              const pick = (type: MonitoringMetricType) =>
                versionMetrics.find((m) => m.type === type);

              const tn = pick(MonitoringMetricType.TN);
              const tp = pick(MonitoringMetricType.TP);
              const fn = pick(MonitoringMetricType.FN);
              const fp = pick(MonitoringMetricType.FP);

              // An incomplete matrix is not renderable; drop this version.
              if (!tn || !tp || !fn || !fp) {
                return null;
              }

              return {
                trueNegatives: parseGraphQLNumber(tn.value),
                truePositives: parseGraphQLNumber(tp.value),
                falseNegatives: parseGraphQLNumber(fn.value),
                falsePositives: parseGraphQLNumber(fp.value),
                modelVersionId,
                modelVersion: tn.modelVersion,
              };
            })
            .filter(isNotNil);
        }
      ),
    [registeredModelVersionIds]
  );

  return useMemoizedResultToCommunicationWithData({
    memoizedConvert: convert,
    queryResult,
  });
};
