import { gql } from '@apollo/client';
import { useCallback, useMemo } from 'react';

import { useCustomQuery } from 'shared/view/hooks/apollo/useCustomQuery';
import { MonitoringIODescription } from 'shared/models/Monitoring/MonitoringModel/MonitoringIODescription';
import { useMemoizedResultToCommunicationWithData } from 'shared/utils/graphql/queryResultToCommunicationWithData';
import { toGraphQLDate } from 'shared/utils/graphql/toGraphQLDate';
import { convertTimeRangeToDateRange } from 'shared/utils/TimeRange';
import { DistributionOverTime } from 'shared/models/Monitoring/Distribution/DistributionOverTime';
import parseGraphqlDate from 'shared/utils/graphql/parseGraphqlDate';
import { MonitoringDriftMetricType } from 'generated/types';
import {
  mapDataOrError,
  RESTRICTED_GRAPHQL_ERROR_FRAGMENT,
} from 'shared/graphql/ErrorFragment';
import { MonitoringWidgetExternalDeps } from 'shared/models/Monitoring/MonitoringModel/MonitoringPanel/MonitoringWidget/MonitoringWidgetExternalDeps';
import { convertMonitoringFilterToGraphQL } from 'shared/models/Monitoring/MonitoringFilters/MonitoringFilter';
import { Distribution } from 'shared/models/Monitoring/Distribution/Distribution';
import { ExtractByTypename } from 'shared/utils/types';

import {
  DistributionOverTimeQuery,
  DistributionOverTimeQueryVariables,
} from './graphql-types/useDistributionOverTime.generated';
import { convertIODescriptionToQuery } from '../shared/convertIODescriptionToQuery';

/**
 * Query for a single monitored entity's distribution widgets. It selects:
 * - `driftOverTime` for the single requested `$driftMetricType`, aggregated
 *   into `$aggregationMilliseconds` windows over `$distributionQuery`;
 * - `referenceDistribution` for `$distributionQuery`;
 * - `liveDistributionOverTime` over `$distributionQuery`, aggregated into the
 *   same `$aggregationMilliseconds` windows.
 *
 * `monitoredEntity` and each `modelVersion` are unions that may resolve to an
 * `Error`; those branches select the `ErrorData` fragment, presumably defined
 * by the interpolated RESTRICTED_GRAPHQL_ERROR_FRAGMENT (verify there).
 *
 * NOTE(review): `$aggregationMilliseconds` is declared `Int!`; GraphQL `Int`
 * is a 32-bit signed integer, so windows over ~24.8 days cannot be sent.
 */
const DISTRIBUTION_OVER_TIME_QUERY = gql`
  query DistributionOverTimeQuery(
    $monitoredEntityId: ID!
    $distributionQuery: MonitoringDistributionQuery!
    $aggregationMilliseconds: Int!
    $driftMetricType: MonitoringDriftMetricType!
  ) {
    monitoredEntity(id: $monitoredEntityId) {
      ... on Error {
        ...ErrorData
      }
      ... on MonitoredEntity {
        id
        metrics {
          driftOverTime(
            query: {
              base: { base: $distributionQuery, types: [$driftMetricType] }
              aggregationMilliseconds: $aggregationMilliseconds
            }
          ) {
            name
            modelVersionId
            time
            values
          }
          referenceDistribution(query: $distributionQuery) {
            name
            modelVersionId
            bucketLimits
            bucketCounts
            modelVersion {
              ... on Error {
                ...ErrorData
              }
              ... on RegisteredModelVersion {
                id
                version
              }
            }
          }
          liveDistributionOverTime(
            query: {
              aggregationMilliseconds: $aggregationMilliseconds
              base: $distributionQuery
            }
          ) {
            time
            name
            modelVersionId
            bucketLimits
            bucketCounts
            modelVersion {
              ... on Error {
                ...ErrorData
              }
              ... on RegisteredModelVersion {
                id
                version
              }
            }
          }
        }
      }
    }
  }
  ${RESTRICTED_GRAPHQL_ERROR_FRAGMENT}
`;

/** Inputs for `useDistributionOverTimeQuery`. */
interface Props {
  /** IO description to query; converted with `convertIODescriptionToQuery`. */
  description: MonitoringIODescription;
  /** Single drift metric type requested from `driftOverTime`. */
  driftMetricType: MonitoringDriftMetricType;
  /**
   * Shared widget context. The hook reads its time range, monitored entity
   * id, filters, aggregation time window and registered model version ids.
   */
  widgetExternalDeps: MonitoringWidgetExternalDeps;
}

/**
 * React hook that fetches drift over time, reference distributions and live
 * distributions over time for one IO description of a monitored entity.
 *
 * The query variables are built from `props.widgetExternalDeps` (time range,
 * monitored entity id, filters, aggregation time window),
 * `props.description` and `props.driftMetricType`. They are memoized, so
 * their identity only changes when one of those inputs changes.
 *
 * The converted result keeps only entries whose `modelVersionId` is listed in
 * `props.widgetExternalDeps.registeredModelVersionIds`. Each live distribution
 * series is expanded into one `DistributionOverTime` per timestamp.
 *
 * @param props - See `Props`.
 * @returns The value produced by `useMemoizedResultToCommunicationWithData`
 *   for this query and its conversion.
 */
export const useDistributionOverTimeQuery = (props: Props) => {
  const variables = useMemo((): DistributionOverTimeQueryVariables => {
    const dateRange = convertTimeRangeToDateRange(
      props.widgetExternalDeps.timeRange
    );
    return {
      monitoredEntityId: props.widgetExternalDeps.monitoredEntityId,
      distributionQuery: {
        startDate: toGraphQLDate(dateRange.from),
        endDate: toGraphQLDate(dateRange.to),
        ioDescriptions: [convertIODescriptionToQuery(props.description)],
        filters: props.widgetExternalDeps.filters.map(
          convertMonitoringFilterToGraphQL
        ),
      },
      // NOTE(review): the query declares this variable as GraphQL `Int!`
      // (32-bit signed), so windows longer than ~24.8 days in milliseconds
      // would not fit. Confirm the possible values of `aggregation.timeWindow`.
      aggregationMilliseconds: props.widgetExternalDeps.aggregation.timeWindow,
      driftMetricType: props.driftMetricType,
    };
  }, [
    props.widgetExternalDeps.timeRange,
    props.widgetExternalDeps.monitoredEntityId,
    props.widgetExternalDeps.aggregation.timeWindow,
    props.widgetExternalDeps.filters,
    props.description,
    props.driftMetricType,
  ]);

  const query = useCustomQuery<
    DistributionOverTimeQuery,
    DistributionOverTimeQueryVariables
  >(DISTRIBUTION_OVER_TIME_QUERY, {
    variables,
  });

  const convert = useCallback(
    (res: DistributionOverTimeQuery) => {
      return mapDataOrError(
        res.monitoredEntity,
        (
          monitoredEntity
        ): {
          liveDistributions: DistributionOverTime[];
          referenceDistributions: Distribution[];
          drift: ExtractByTypename<
            DistributionOverTimeQuery['monitoredEntity'],
            'MonitoredEntity'
          >['metrics']['driftOverTime'];
        } => {
          // Build the lookup once per conversion. The three filters below then
          // each do an O(1) membership check instead of rescanning the id list
          // for every entry. Set#has uses the same SameValueZero equality as
          // Array#includes, so the selected entries are unchanged.
          const selectedModelVersionIds = new Set(
            props.widgetExternalDeps.registeredModelVersionIds
          );
          const { metrics } = monitoredEntity;

          const referenceDistributions = metrics.referenceDistribution.filter(
            (r) => selectedModelVersionIds.has(r.modelVersionId)
          );
          const liveDistributions = metrics.liveDistributionOverTime
            .filter((l) => selectedModelVersionIds.has(l.modelVersionId))
            .flatMap(convertDistributionOverTime);
          const drift = metrics.driftOverTime.filter((d) =>
            selectedModelVersionIds.has(d.modelVersionId)
          );

          return {
            liveDistributions,
            referenceDistributions,
            drift,
          };
        }
      );
    },
    [props.widgetExternalDeps.registeredModelVersionIds]
  );

  return useMemoizedResultToCommunicationWithData({
    memoizedConvert: convert,
    queryResult: query,
  });
};

/**
 * Expands one live-distribution series into per-timestamp entries.
 *
 * The entry at position `i` takes its timestamp from `time[i]`, parsed with
 * `parseGraphqlDate`, and its counts from `bucketCounts[i]`. It shares the
 * series' `bucketLimits`, `name`, `modelVersionId` and `modelVersion`.
 *
 * @param data - One element of `metrics.liveDistributionOverTime` from the
 *   query result.
 * @returns One `DistributionOverTime` per element of `data.time`, in order.
 */
const convertDistributionOverTime = (
  data: ExtractByTypename<
    DistributionOverTimeQuery['monitoredEntity'],
    'MonitoredEntity'
  >['metrics']['liveDistributionOverTime'][0]
): DistributionOverTime[] => {
  const {
    time: timestamps,
    bucketCounts,
    bucketLimits,
    name,
    modelVersionId,
    modelVersion,
  } = data;

  return timestamps.map(
    (timestamp, position): DistributionOverTime => ({
      time: parseGraphqlDate(timestamp),
      bucketCounts: bucketCounts[position],
      bucketLimits,
      name,
      modelVersionId,
      modelVersion,
    })
  );
};
