import { useTheme } from '@mui/material';
import { GridColumns, GridRows } from '@visx/grid';
import { Group } from '@visx/group';
import { scaleBand, scaleLinear } from '@visx/scale';
import { ScaleBand } from 'd3-scale';
import { uniq } from 'ramda';
import { useMemo } from 'react';

import { getChartSizes, Size } from 'shared/utils/charts/chartSizes';
import { getLinearScaleDomain } from 'shared/utils/charts/getLinearScaleDomain';

import AxisBottomWithRotatedTicks from '../Axis/AxisBottomWithRotatedTicks/AxisBottomWithRotatedTicks';
import CustomAxisLeft from '../Axis/CustomAxisLeft/CustomAxisLeft';
import ChartSVG from '../shared/ChartSVG/ChartSVG';
import { defaultChartMargin } from '../shared/margin';

/** A single bar: the series key it belongs to and the value that sizes it. */
interface ChartBar {
  barKey: string;
  value: number;
}

/** The bars rendered together inside one `groupKey` band. */
interface ChartGroup {
  groupKey: string;
  bars: ChartBar[];
}

/** One level of the chart, identified by `levelKey`, with its groups of bars. */
interface ChartData {
  levelKey: string;
  groups: ChartGroup[];
}

/** Props accepted by `BaseMultiLevelBarChart`. */
export interface BaseMultiLevelBarChartProps {
  /** Chart content: one entry per level, each holding its groups of bars. */
  data: ChartData[];
  /** Returns the color used for both the fill and the stroke of bars with the given `barKey`. */
  getColorByBarKey: (key: string) => string;
  /** Outer size of the chart; its `width` and `height` are passed to `getChartSizes`. */
  size: Size;
  /** Domain of the horizontal band scale that positions the groups. */
  groupKeys: string[];
  /**
   * Returns a horizontal offset that is added to every group's left position
   * and to the offset of the vertical grid lines.
   */
  getGroupLeftOffset: (props: { groupScale: ScaleBand<string> }) => number;
}

/**
 * Builds the domain for the value scale from every bar value in `data`.
 *
 * The bounds come from `getLinearScaleDomain`. A strictly positive lower bound
 * is replaced by 0; any other lower bound is kept as returned.
 */
const getValueDomain = (data: ChartData[]) => {
  const allValues = data.flatMap((level) =>
    level.groups.flatMap((group) => group.bars.map((bar) => bar.value))
  );
  const [lowerBound, upperBound] = getLinearScaleDomain(allValues);
  const domainStart = lowerBound > 0 ? 0 : lowerBound;

  return [domainStart, upperBound];
};

/**
 * Bar chart with two band axes: levels (`levelKey`) are stacked along the
 * vertical axis and groups (`groupKey`) are laid out along the horizontal axis.
 * Each (level, group) cell contains one bar per `barKey`.
 *
 * Bars are anchored to the bottom edge of their level band, and the value
 * scale maps onto at most 90% of the band height.
 *
 * @param props.data Levels, groups and bars to draw.
 * @param props.getColorByBarKey Color lookup used for each bar's fill and stroke.
 * @param props.size Outer chart size; the inner drawing area is derived via `getChartSizes`.
 * @param props.groupKeys Domain of the horizontal group scale.
 * @param props.getGroupLeftOffset Extra horizontal offset applied to groups and vertical grid lines.
 */
const BaseMultiLevelBarChart = (props: BaseMultiLevelBarChartProps) => {
  const { width, height, innerHeight, innerWidth, margin } = getChartSizes({
    margin: defaultChartMargin,
    width: props.size.width,
    height: props.size.height,
  });

  // Distinct bar keys across the whole data set; domain of the per-group bar scale.
  const barKeys = useMemo(
    () =>
      uniq(
        props.data.flatMap((d) =>
          d.groups.flatMap((g) => g.bars.map((b) => b.barKey))
        )
      ),
    [props.data]
  );

  // Distinct level keys; domain of the vertical band scale.
  const levelKeys = useMemo(
    () => uniq(props.data.map((d) => d.levelKey)),
    [props.data]
  );

  // Vertical band scale; its range runs from `innerHeight` down to 0.
  const levelScale = useMemo(
    () =>
      scaleBand({
        domain: levelKeys,
        range: [innerHeight, 0],
        padding: 0,
      }),
    [innerHeight, levelKeys]
  );

  // Horizontal band scale positioning each group.
  const groupScale = useMemo(
    () =>
      scaleBand({
        domain: props.groupKeys,
        range: [0, innerWidth],
        padding: 0.1,
      }),
    [innerWidth, props.groupKeys]
  );

  // Band scale positioning individual bars inside a single group band.
  const barScale = useMemo(
    () =>
      scaleBand({
        domain: barKeys,
        range: [0, groupScale.bandwidth()],
        align: 1,
        padding: 0.1,
        paddingInner: 0.2,
      }),
    [groupScale, barKeys]
  );

  // Maps a bar value to a pixel height, capped at 90% of a level band.
  const valueScale = useMemo(
    () =>
      scaleLinear({
        domain: getValueDomain(props.data),
        range: [0, levelScale.bandwidth() * 0.9],
      }),
    [levelScale, props.data]
  );

  const theme = useTheme();

  // These values are constant within one render, so they are computed once here
  // instead of once per group / per bar inside the loops below.
  const groupLeftOffset = props.getGroupLeftOffset({ groupScale });
  const levelBandwidth = levelScale.bandwidth();
  const groupBandwidth = groupScale.bandwidth();
  const barWidth = barScale.bandwidth();

  return (
    <ChartSVG
      width={width}
      height={height}
      margin={margin}
      // True when no group contains any bar (also true for empty `data`).
      isNilData={props.data.every((d) =>
        d.groups.every((g) => g.bars.length === 0)
      )}
    >
      <GridRows
        scale={levelScale}
        width={innerWidth}
        stroke={theme.palette.charts.accentColor}
        strokeWidth={0.5}
        offset={levelBandwidth / 2}
      />

      <GridColumns
        scale={groupScale}
        height={innerHeight}
        stroke={theme.palette.charts.accentColor}
        strokeWidth={1}
        offset={-groupBandwidth / 2 + groupLeftOffset}
      />

      <CustomAxisLeft scale={levelScale} />

      <AxisBottomWithRotatedTicks
        scale={groupScale}
        top={innerHeight}
        tickStroke={theme.palette.charts.accentColor}
      />

      {props.data.map((d) => {
        return (
          <Group key={d.levelKey} top={levelScale(d.levelKey) ?? 0}>
            {d.groups.map((g) => (
              <Group
                key={g.groupKey + d.levelKey}
                // Groups whose key is missing from `groupKeys` fall back to position 0.
                left={(groupScale(g.groupKey) ?? 0) + groupLeftOffset}
              >
                {g.bars.map((b) => {
                  const color = props.getColorByBarKey(b.barKey);
                  const barHeight = valueScale(b.value);

                  return (
                    <rect
                      key={b.barKey + d.levelKey + g.groupKey}
                      fill={color}
                      fillOpacity={0.5}
                      width={barWidth}
                      height={barHeight}
                      // Anchor the bar to the bottom edge of its level band.
                      y={levelBandwidth - barHeight}
                      x={barScale(b.barKey)}
                      strokeWidth={1}
                      stroke={color}
                    />
                  );
                })}
              </Group>
            ))}
          </Group>
        );
      })}
    </ChartSVG>
  );
};

export default BaseMultiLevelBarChart;
