import { useMemo } from 'react';

import { useConfusionMatrixQuery } from 'features/monitoring/widgets/store/confusionMatrix/useConfusionMatrix';
import { ConfusionMatrix } from 'shared/models/Monitoring/ConfusionMatrix';
import { ConfusionMatrixWidget } from 'shared/models/Monitoring/MonitoringModel/MonitoringPanel/MonitoringWidget/Widgets/ConfusionMatrixWidget';
import { isNotEmptyArray } from 'shared/utils/collection';
import { head } from 'shared/utils/opaqueTypes/NonEmptyArray';
import Heatmap from 'shared/view/charts/Heatmap/Heatmap';
import CellRendererNumeric from 'shared/view/elements/DataGrid/columns/CellRendererNumeric';
import { typeSafeConfiguration } from 'shared/view/elements/DataGrid/configuration/helpers/typeSafeConfiguration';
import { DataGridColumn } from 'shared/view/elements/DataGrid/DataGridColumn';
import { DataGridWithTypes } from 'shared/view/elements/DataGrid/DataGridWithTypes';
import { DefaultMatchRemoteDataOrError } from 'shared/view/elements/MatchRemoteDataComponents/DefaultMatchRemoteData';

import { MonitoringWidgetProps } from '../shared/types';
import { modelVersionColumn } from '../../shared/modelVersionColumn';

/**
 * Monitoring widget that shows the confusion matrix for the widget's output.
 *
 * Data is fetched through `useConfusionMatrixQuery`. Loading and error states
 * are delegated to `DefaultMatchRemoteDataOrError`. Once data is available:
 * - exactly one matrix is drawn as a 2x2 `Heatmap`;
 * - any other count falls back to a table with one row per loaded matrix.
 */
const ConfusionMatrixWidgetView = (
  props: MonitoringWidgetProps<ConfusionMatrixWidget>
) => {
  const { widgetExternalDeps, widget, size } = props;
  const query = useConfusionMatrixQuery({
    widgetExternalDeps,
    output: widget.output,
  });

  return (
    <DefaultMatchRemoteDataOrError
      data={query.data}
      communication={query.communication}
      context="loading confusion matrix"
    >
      {(loadedData) => {
        // Anything other than a single matrix is listed in the table view.
        // The `isNotEmptyArray` guard narrows the type so `head` below is
        // allowed once this early return has been passed.
        if (!(loadedData.length === 1 && isNotEmptyArray(loadedData))) {
          return (
            <div style={size}>
              <ConfusionMatrixTable data={loadedData} />
            </div>
          );
        }

        const { truePositives, falsePositives, falseNegatives, trueNegatives } =
          head(loadedData);
        // NOTE(review): layout is [[TP, FP], [FN, TN]], i.e. presumably rows =
        // predicted class and columns = actual class; confirm against Heatmap.
        return (
          <Heatmap
            size={size}
            matrix={[
              [truePositives, falsePositives],
              [falseNegatives, trueNegatives],
            ]}
            labels={['Positive', 'Negative']}
          />
        );
      }}
    </DefaultMatchRemoteDataOrError>
  );
};

/**
 * Builds a flexible-width DataGrid column showing one numeric count of a
 * `ConfusionMatrix` row.
 *
 * @param field - string used as the column's `field` (callers pass a
 *   human-readable label such as 'True positives').
 * @param key - the `ConfusionMatrix` property whose value is rendered with
 *   `CellRendererNumeric` and fed to the 'sort'/'filter' configuration as a
 *   'number'.
 */
const makeConfusionMatrixValueColumn = ({
  field,
  key,
}: {
  field: string;
  key: Exclude<keyof ConfusionMatrix, 'modelVersionId' | 'modelVersion'>;
}): DataGridColumn<ConfusionMatrix> => {
  const column: DataGridColumn<ConfusionMatrix> = {
    field,
    flex: 1,
    additionalConfiguration: typeSafeConfiguration(
      ['sort', 'filter'],
      'number',
      (params) => params.row[key]
    ),
    renderCell: (params) => <CellRendererNumeric value={params.row[key]} />,
  };
  return column;
};

/**
 * Columns of the multi-matrix table: the model version first, then the four
 * confusion-matrix counts in the order TP, TN, FP, FN.
 */
const columns: Array<DataGridColumn<ConfusionMatrix>> = [
  modelVersionColumn,
  ...(
    [
      ['truePositives', 'True positives'],
      ['trueNegatives', 'True negatives'],
      ['falsePositives', 'False positives'],
      ['falseNegatives', 'False negatives'],
    ] as const
  ).map(([key, field]) => makeConfusionMatrixValueColumn({ key, field })),
];

/**
 * Table listing every loaded confusion matrix, one grid row per entry.
 *
 * Each row copies the matrix and adds `id` set to its `modelVersionId`. The
 * rows array is memoised on the `data` reference. Because ids come from
 * `modelVersionId`, entries presumably have distinct model versions (verify
 * against the query).
 */
const ConfusionMatrixTable = ({ data }: { data: ConfusionMatrix[] }) => {
  const rows = useMemo(
    () => data.map((matrix) => ({ ...matrix, id: matrix.modelVersionId })),
    [data]
  );

  return <DataGridWithTypes rows={rows} columns={columns} heightType="parentHeight" />;
};

// The widget view is this module's default export.
export default ConfusionMatrixWidgetView;
