import { sum } from 'ramda';

import { isNotRestrictedGraphqlError } from 'shared/graphql/ErrorFragment';
import { formatWithDefaultPrecision } from 'shared/utils/formatters/formatWithDefaultPrecision';
import { IntervalValue } from 'shared/view/charts/shared/intervals';

import { Distribution } from './Distribution';
import { discreteDistributionKeys } from './DistributionDescription';

/**
 * Chart-oriented view of a `Distribution`, as produced by
 * `normalizeDistribution`: raw bucket counts alongside the same counts
 * expressed as percentages of the total.
 */
export interface NormalizedDistribution {
  /**
   * `'discrete'` when the source distribution has exactly four bucket counts,
   * `'float'` otherwise.
   */
  type: 'float' | 'discrete';
  /** Name copied from the source distribution. */
  name: string;
  /**
   * Bucket labels. For `'float'` these are `'-infinity'`, the formatted
   * bucket limits converted to numbers, then `'infinity+'`; for `'discrete'`
   * they are `discreteDistributionKeys`.
   */
  buckets: IntervalValue[];
  /**
   * Each entry of `counts` as a percentage of `totalCount`, passed through
   * `formatWithDefaultPrecision` and converted back to a number.
   */
  normalizedValues: number[];
  /**
   * Raw counts. For `'float'` this is the source `bucketCounts` array itself;
   * for `'discrete'` only the source counts at indices 0 and 2 are kept.
   */
  counts: number[];
  /**
   * Sum of all source bucket counts (for `'discrete'` this includes the
   * counts that are not exposed in `counts`).
   */
  totalCount: number;
  /**
   * Version taken from the source `modelVersion` when it passes
   * `isNotRestrictedGraphqlError`; otherwise falls back to `modelVersionId`.
   */
  modelVersion: string;
  /** Model version id copied from the source distribution. */
  modelVersionId: string;
}

/**
 * Converts a raw `Distribution` into the `NormalizedDistribution` shape,
 * expressing bucket counts as percentages of the total count.
 *
 * A distribution with exactly four bucket counts is reported as `'discrete'`:
 * only the counts at indices 0 and 2 are kept and they are paired with
 * `discreteDistributionKeys`. Every other distribution is reported as
 * `'float'`, with bucket labels built from `bucketLimits` and bracketed by
 * `'-infinity'` and `'infinity+'`.
 *
 * Percentages are always relative to the sum of *all* bucket counts, so in
 * the discrete case the two reported percentages need not add up to 100.
 *
 * NOTE(review): the "four counts means discrete" rule and the choice of
 * indices 0 and 2 are kept exactly as found — confirm against whatever
 * produces `Distribution`.
 *
 * @param distribution - Raw distribution to normalize; it is not mutated, but
 *   in the float case its `bucketCounts` array is returned by reference.
 * @returns The normalized, chart-ready distribution.
 */
export const normalizeDistribution = (
  distribution: Distribution
): NormalizedDistribution => {
  const { bucketCounts, bucketLimits, name, modelVersionId } = distribution;

  // Number of bucket counts that marks a distribution as discrete.
  const discreteBucketCountsLength = 4;

  const totalCount = sum(bucketCounts);
  // Size of one percent of the total; an empty total falls back to 1 so the
  // division below never divides by zero.
  const onePercent = (totalCount || 1) / 100;

  // Maps raw counts to percentages rounded with the default precision.
  const toPercentages = (values: readonly number[]): number[] =>
    values.map((value) =>
      Number(formatWithDefaultPrecision(value / onePercent))
    );

  // Prefer the human-readable version; fall back to the id when the model
  // version itself is not accessible.
  let modelVersion: string;
  if (isNotRestrictedGraphqlError(distribution.modelVersion)) {
    modelVersion = distribution.modelVersion.version;
  } else {
    modelVersion = modelVersionId;
  }

  if (bucketCounts.length === discreteBucketCountsLength) {
    const counts = [bucketCounts[0], bucketCounts[2]];

    return {
      type: 'discrete',
      buckets: discreteDistributionKeys,
      counts,
      normalizedValues: toPercentages(counts),
      name,
      totalCount,
      modelVersion,
      modelVersionId,
    };
  }

  const normalizedValues = toPercentages(bucketCounts);

  const formattedLimits = bucketLimits.map(formatWithDefaultPrecision);
  const buckets = ['-infinity', ...formattedLimits.map(Number), 'infinity+'];

  return {
    type: 'float',
    name,
    buckets,
    normalizedValues,
    counts: bucketCounts,
    totalCount,
    modelVersion,
    modelVersionId,
  };
};
