import { FC, useMemo } from 'react';
import { scaleLinear } from '@visx/scale';
import { AxisRight } from '@visx/axis';
import { NumberValue } from 'd3-scale';
import { useTheme, Theme } from '@mui/material';

import {
  getChartSizes,
  Size,
  ChartMargin,
} from 'shared/utils/charts/chartSizes';

import ChartSVG from '../shared/ChartSVG/ChartSVG';
import LinearGradient from '../shared/LinearGradient/LinearGradient';

/** Props accepted by the Heatmap component. */
interface Props {
  /**
   * Cell values, indexed as `matrix[row][column]`. Rows may differ in length.
   */
  matrix: number[][];
  /** Outer chart size; its width and height are passed to `getChartSizes`. */
  size: Size;
  /**
   * Optional axis labels. `labels[i]` is drawn beside row `i` and, rotated,
   * under column `i`.
   */
  labels?: string[];
  /** Overrides the module's default chart margin when provided. */
  margin?: ChartMargin;
}

// Typography shared by the axis labels and the legend tick labels.
const fontSize = '12px';
const fontWeight = '500';

// Colour ramp for the cells: colorStart maps to 0 and colorEnd to the
// largest value in the matrix.
const colorStart = '#C7DDF0';
const colorEnd = '#004785';

// Id linking the legend <rect> fill to its <LinearGradient> definition.
const gradientId = 'heatmap-gradient-id';

// Horizontal gap between the grid and both the row labels and the legend.
const xOffset = 15;
// Vertical distance below the grid at which the column labels are anchored.
const yOffset = 25;
// Rotation (SVG degrees) of the labels under the grid; despite the name, it
// is applied to the bottom (column) labels.
const yLabelRotate = -25;
// Width of the legend gradient strip.
const gradientWidth = 28;
// Distance from the grid's right edge to the legend strip's right edge, where
// the value axis is placed.
const gradientAxisOffset = xOffset + gradientWidth;

/**
 * Builds the SVG text props for one tick label of the legend's value axis.
 *
 * @param value - Tick value; rendered via its `toString()`.
 * @param theme - MUI theme supplying the primary text colour.
 * @returns Text props using the shared label typography.
 */
const gradientAxisLabelProps = (value: NumberValue, theme: Theme) => {
  const label = value.toString();
  const textColor = theme.palette.text.primary;
  return { fontWeight, fontSize, children: label, fill: textColor };
};

// Default chart margin, used when the caller does not pass `margin`.
// Presumably sized so the row labels (left), the legend axis (right) and the
// rotated column labels (bottom) fit — adjust together with those offsets.
const margin: ChartMargin = { top: 30, right: 100, left: 100, bottom: 80 };

/**
 * Renders a numeric matrix as an SVG heatmap.
 *
 * Each entry `matrix[row][column]` becomes one cell. Rows are stacked top to
 * bottom, and entries within a row run left to right. Rows may differ in
 * length; the horizontal axis is sized for the longest row. Cell colour is
 * interpolated linearly from `colorStart` (value 0) to `colorEnd` (the
 * largest value in the matrix). The value is printed inside the cell in white
 * when it exceeds half that maximum, otherwise in black.
 *
 * To the right of the grid, a strip filled with the `LinearGradient` legend
 * is drawn, with a value axis (`AxisRight`) along its right edge.
 *
 * `labels[i]` is drawn to the left of row `i` and, rotated, under column `i`.
 * One label list therefore describes both axes, as in a square confusion
 * matrix.
 *
 * `margin` falls back to the module default; `size` is passed to
 * `getChartSizes` to derive the inner plotting area.
 */
const Heatmap: FC<React.PropsWithChildren<Props>> = (props) => {
  const theme = useTheme();

  const chartSizes = getChartSizes({
    margin: props.margin ?? margin,
    width: props.size.width,
    height: props.size.height,
  });

  // Rows map to the vertical axis and entries within a row to the horizontal
  // axis. The reduce is seeded with 0 so an empty matrix yields 0 columns
  // instead of Math.max()'s -Infinity.
  const rowCount = props.matrix.length;
  const columnCount = props.matrix.reduce(
    (longest, row) => Math.max(longest, row.length),
    0
  );

  const maxValue = useMemo(() => {
    const values = props.matrix.flat();
    // Without any values, fall back to 0 rather than -Infinity.
    return values.length > 0 ? Math.max(...values) : 0;
  }, [props.matrix]);

  const xScale = useMemo(
    () =>
      scaleLinear({
        domain: [0, columnCount],
        range: [0, chartSizes.innerWidth],
      }),
    [chartSizes.innerWidth, columnCount]
  );

  const yScale = useMemo(
    () =>
      scaleLinear({
        domain: [0, rowCount],
        range: [0, chartSizes.innerHeight],
      }),
    [chartSizes.innerHeight, rowCount]
  );

  const colorScale = useMemo(
    () =>
      scaleLinear({
        range: [colorStart, colorEnd],
        domain: [0, maxValue],
      }),
    [maxValue]
  );

  // Legend axis: 0 at the bottom of the strip, maxValue at the top.
  const valueScale = useMemo(
    () =>
      scaleLinear({
        range: [chartSizes.innerHeight, 0],
        domain: [0, maxValue],
      }),
    [chartSizes.innerHeight, maxValue]
  );

  // Guard the divisions so an empty matrix does not produce Infinity sizes,
  // which would otherwise leak into any rendered labels.
  const binWidth = columnCount > 0 ? chartSizes.innerWidth / columnCount : 0;
  const binHeight = rowCount > 0 ? chartSizes.innerHeight / rowCount : 0;

  return (
    <ChartSVG
      isNilData={props.matrix.every((row) => row.length === 0)}
      width={chartSizes.width}
      height={chartSizes.height}
      margin={chartSizes.margin}
    >
      {/* Row labels, right-aligned just left of the grid. */}
      {props.labels?.map((label, index) => (
        <text
          key={`y-label-${index}`}
          y={yScale(index) + binHeight / 2}
          x={-xOffset}
          textAnchor="end"
          dominantBaseline="middle"
          fontSize={fontSize}
          fontWeight={fontWeight}
          fill={theme.palette.text.primary}
        >
          {label}
        </text>
      ))}

      {/* Column labels below the grid, rotated by yLabelRotate degrees. */}
      {props.labels?.map((label, index) => (
        <text
          key={`x-label-${index}`}
          transform={`translate(${xScale(index) + binWidth / 2}, ${
            chartSizes.innerHeight + yOffset
          }) rotate(${yLabelRotate})`}
          textAnchor="end"
          fontSize={fontSize}
          fontWeight={fontWeight}
          fill={theme.palette.text.primary}
        >
          {label}
        </text>
      ))}

      {props.matrix.map((row, rowIndex) =>
        row.map((value, columnIndex) => (
          <g key={`heatmap-rect-${rowIndex}-${columnIndex}`}>
            <rect
              className="visx-heatmap-rect"
              width={binWidth}
              height={binHeight}
              x={xScale(columnIndex)}
              y={yScale(rowIndex)}
              fill={colorScale(value)}
            />
            {/* Values above half the maximum sit toward the dark colorEnd,
                so they get white text for contrast. */}
            <text
              x={xScale(columnIndex) + binWidth / 2}
              y={yScale(rowIndex) + binHeight / 2}
              textAnchor="middle"
              dominantBaseline="middle"
              fill={value > maxValue / 2 ? '#fff' : '#000'}
              fontSize="16px"
              fontWeight={fontWeight}
            >
              {value}
            </text>
          </g>
        ))
      )}

      {/* Colour legend: gradient strip right of the grid plus its value axis. */}
      <defs>
        <LinearGradient id={gradientId} colors={[colorEnd, colorStart]} />
      </defs>
      <rect
        fill={`url(#${gradientId})`}
        height={chartSizes.innerHeight}
        x={chartSizes.innerWidth + xOffset}
        y={0}
        width={gradientWidth}
      />
      <AxisRight
        scale={valueScale}
        left={chartSizes.innerWidth + gradientAxisOffset}
        hideAxisLine={true}
        hideTicks={true}
        numTicks={6}
        hideZero={true}
        tickLabelProps={(v) => gradientAxisLabelProps(v, theme)}
      />
    </ChartSVG>
  );
};

export default Heatmap;
