import { useMemo } from 'react';

import { MonitoringConfidenceMetricType } from 'generated/types';
import isNotNil from 'shared/utils/isNotNill';
import { parseGraphQLNumber } from 'shared/utils/graphql/parseGraphQLNumber';
import {
  isNotNullableRestrictedGraphqlError,
  isNotRestrictedGraphqlError,
} from 'shared/graphql/ErrorFragment';
import { MonitoringWidgetExternalDeps } from 'shared/models/Monitoring/MonitoringModel/MonitoringPanel/MonitoringWidget/MonitoringWidgetExternalDeps';
import { MonitoringIODescription } from 'shared/models/Monitoring/MonitoringModel/MonitoringIODescription';
import { ExtractByTypename } from 'shared/utils/types';

import { ConfidenceMetrics } from '../confidenceMetrics/graphql-types/useConfidenceMetrics.generated';
import { useConfidenceMetrics } from '../confidenceMetrics/useConfidenceMetrics';

/**
 * Inputs of `usePrecisionRecall`. Both fields are forwarded unchanged to
 * `useConfidenceMetrics`.
 */
type Props = {
  widgetExternalDeps: MonitoringWidgetExternalDeps;
  output: MonitoringIODescription;
};

/**
 * The member of `ConfidenceMetrics['monitoredEntity']` that `ExtractByTypename`
 * selects for the 'MonitoredEntity' typename; its `metrics.confidenceMetric`
 * list is what `convert` consumes.
 */
type MonitoredEntity = ExtractByTypename<
  ConfidenceMetrics['monitoredEntity'],
  'MonitoredEntity'
>;

/**
 * Confidence metric types requested from `useConfidenceMetrics`.
 *
 * PRECISION and RECALL are paired per model version by `convert`; TP, FP, TN
 * and FN are the entries `convert` sums into `totalPredictions` and
 * `totalPositives`. Declared at module scope so every call of the hook passes
 * the same array instance.
 */
const metricTypes = [
  MonitoringConfidenceMetricType.PRECISION,
  MonitoringConfidenceMetricType.RECALL,
  MonitoringConfidenceMetricType.TP,
  MonitoringConfidenceMetricType.FP,
  MonitoringConfidenceMetricType.TN,
  MonitoringConfidenceMetricType.FN,
];

/**
 * React hook that requests the precision/recall and TP/FP/TN/FN confidence
 * metrics for one monitored output and reshapes them for the precision-recall
 * widget.
 *
 * @param props - External widget dependencies and the output to query.
 * @returns `communication` exactly as reported by `useConfidenceMetrics`, and
 *   `data`: the converted metrics, or `null` while the fetched data does not
 *   pass `isNotNullableRestrictedGraphqlError`. The conversion is memoized on
 *   the identity of the fetched data.
 */
export const usePrecisionRecall = (props: Props) => {
  const { widgetExternalDeps, output } = props;

  const confidenceMetrics = useConfidenceMetrics({
    widgetExternalDeps,
    metricTypes,
    output,
  });
  const rawData = confidenceMetrics.data;

  const convertedData = useMemo(() => {
    // Missing or access-restricted data has nothing to convert.
    if (!isNotNullableRestrictedGraphqlError(rawData)) {
      return null;
    }
    return convert(rawData);
  }, [rawData]);

  return {
    communication: confidenceMetrics.communication,
    data: convertedData,
  };
};

/** One entry of a monitored entity's `metrics.confidenceMetric` list. */
type ConfidenceMetric = MonitoredEntity['metrics']['confidenceMetric'][number];

/**
 * Parses every value of the given metrics with `parseGraphQLNumber` and adds
 * them up. Missing metrics (`undefined`) are skipped, so an empty or
 * all-missing list sums to `0`.
 */
const sumMetricValues = (
  metrics: ReadonlyArray<ConfidenceMetric | undefined>
): number =>
  metrics
    .filter(isNotNil)
    .flatMap((metric) => metric.values)
    .map(parseGraphQLNumber)
    .reduce((sum, value) => sum + value, 0);

/**
 * Reshapes the raw confidence metrics of a monitored entity into the data
 * consumed by the precision-recall widget.
 *
 * @param data - Every confidence-metric entry returned for the entity.
 * @returns
 *   - `totalPredictions`: sum of the parsed values of the TN, FN, TP and FP
 *     entries.
 *   - `totalPositives`: sum of the parsed values of the TP and FP entries.
 *   - `precisionRecall`: one item per PRECISION entry that has a RECALL entry
 *     with the same `modelVersionId` (PRECISION entries without one are
 *     dropped). `modelVersion` is the model version's `version` when it is not
 *     a restricted GraphQL error, otherwise the raw `modelVersionId`.
 */
const convert = (data: MonitoredEntity['metrics']['confidenceMetric']) => {
  // Index RECALL entries by model version once instead of scanning the whole
  // list for every PRECISION entry. The first RECALL entry per model version
  // wins, i.e. the same entry a first-match search would return.
  const recallByModelVersionId = new Map<
    ConfidenceMetric['modelVersionId'],
    ConfidenceMetric
  >();
  for (const metric of data) {
    if (
      metric.type === MonitoringConfidenceMetricType.RECALL &&
      !recallByModelVersionId.has(metric.modelVersionId)
    ) {
      recallByModelVersionId.set(metric.modelVersionId, metric);
    }
  }

  const precisionRecall = data
    .filter(
      (metric) => metric.type === MonitoringConfidenceMetricType.PRECISION
    )
    .map((precisionMetric) => {
      const recallMetric = recallByModelVersionId.get(
        precisionMetric.modelVersionId
      );
      if (!recallMetric) {
        return null;
      }
      return {
        modelVersion: isNotRestrictedGraphqlError(precisionMetric.modelVersion)
          ? precisionMetric.modelVersion.version
          : precisionMetric.modelVersionId,
        precisionValues: precisionMetric.values.map(parseGraphQLNumber),
        recallValues: recallMetric.values.map(parseGraphQLNumber),
      };
    })
    .filter(isNotNil);

  // NOTE(review): only the FIRST entry of each count type is used, whereas
  // `precisionRecall` above is built per model version. If several model
  // versions are returned, the totals below cover a single entry per type —
  // confirm this is intended.
  const findMetric = (type: ConfidenceMetric['type']) =>
    data.find((metric) => metric.type === type);

  const trueNegatives = findMetric(MonitoringConfidenceMetricType.TN);
  const falseNegatives = findMetric(MonitoringConfidenceMetricType.FN);
  const truePositives = findMetric(MonitoringConfidenceMetricType.TP);
  const falsePositives = findMetric(MonitoringConfidenceMetricType.FP);

  const totalPredictions = sumMetricValues([
    trueNegatives,
    falseNegatives,
    truePositives,
    falsePositives,
  ]);

  // NOTE(review): TP + FP counts *predicted* positives; actual positives would
  // be TP + FN. Confirm which one the consumer of `totalPositives` expects.
  const totalPositives = sumMetricValues([truePositives, falsePositives]);

  return {
    totalPredictions,
    totalPositives,
    precisionRecall,
  };
};
