import React, { useMemo } from 'react';
import cx from 'classnames';
import { Box, makeStyles } from '@material-ui/core';
import { Typography } from '@clef/client-library';
import { AggregatedConfusionMatrix, RegisteredModel } from '@clef/shared/types';
import { useGetConfusionMatrixQuery } from '@/serverStore/modelAnalysis';
import useGetDefectNameById from '@/hooks/defect/useGetDefectNameById';
import { Theme } from '@material-ui/core';
import LoadingProgress from '../LoadingProgress';

/**
 * JSS styles for the confusion-matrix table.
 *
 * `transparency` is only read by `tableCellBackground` (the per-cell blue fill);
 * the table component itself calls this hook with `{}`.
 */
const useStyles = makeStyles<Theme, { transparency?: number }>(theme => ({
  root: {
    padding: theme.spacing(4),
    maxWidth: '1000px',
    display: 'flex',
    flexDirection: 'column',
    alignItems: 'center',
    marginBottom: theme.spacing(20),
  },
  // NOTE(review): not referenced in this file — candidate for removal.
  emptyCell: {
    width: theme.spacing(20),
    height: theme.spacing(4),
  },
  // Base style for every matrix cell: fixed size, centered content, light background.
  tableCell: {
    display: 'flex',
    justifyContent: 'center',
    alignItems: 'center',
    width: theme.spacing(20),
    height: theme.spacing(12),
    background: '#FAFCFE',
  },
  // Blue fill for non-zero cells. The alpha channel is `transparency`; when it is
  // missing (or 0) the solid #0167DC is used instead.
  // NOTE(review): both this rule and `tableCell` set `background` and are applied
  // together, so this one must come later in the cascade — keep it declared after
  // `tableCell`.
  tableCellBackground: {
    background: props =>
      props.transparency ? `rgba(1, 103, 220, ${props.transparency})` : '#0167DC',
  },
  // Greyed-out text for zero-count cells and the "No label" x "No label" placeholder.
  zeroCell: {
    color: theme.palette.grey[400],
  },
  // NOTE(review): not referenced in this file — candidate for removal.
  nonZeroCell: {
    background: '#0167DC',
  },
  // NOTE(review): not referenced in this file — candidate for removal.
  fixedHeight: {
    height: theme.spacing(7),
  },
  tableContainer: {
    marginTop: theme.spacing(2),
    fontSize: theme.spacing(3),
    fontFamily: 'Commissioner',
    display: 'flex',
    alignItems: 'flex-start',
  },
  groundTruthTitle: {
    fontWeight: 700,
    fontSize: theme.spacing(3),
    lineHeight: '16px',
    color: theme.palette.grey[400],
    paddingRight: theme.spacing(3),
  },
  // Left column: "Ground truth" title plus one right-aligned caption per row.
  firstColumn: {
    display: 'flex',
    flexDirection: 'column',
    textAlign: 'right',
    padding: theme.spacing(1, 0),
    color: theme.palette.grey[900],
  },
  // Height matches `tableCell` so each caption lines up with its matrix row.
  firstColumnCaption: {
    display: 'flex',
    alignItems: 'center',
    justifyContent: 'flex-end',
    height: theme.spacing(12),
    paddingRight: theme.spacing(3),
  },
  // Right column holding the per-row recall values.
  lastColumn: {
    display: 'flex',
    flexDirection: 'column',
    justifyContent: 'center',
    fontWeight: 500,
    color: theme.palette.grey[900],
    background: theme.palette.grey[100],
    padding: theme.spacing(1, 1, 0, 1),
    marginLeft: theme.spacing(3),
    borderRadius: '2px',
  },
  // Height matches `tableCell` so each recall value lines up with its matrix row.
  lastColumnCaption: {
    display: 'flex',
    alignItems: 'center',
    height: theme.spacing(12),
    fontWeight: 500,
  },
  recallTitle: {
    color: theme.palette.grey[500],
    fontSize: theme.spacing(3),
    lineHeight: '16px',
    fontWeight: 600,
  },
  precisionTitle: {
    background: theme.palette.grey[100],
    fontSize: theme.spacing(3),
    color: theme.palette.grey[500],
    marginTop: theme.spacing(2),
    padding: theme.spacing(1, 3, 1, 0),
    fontWeight: 600,
    lineHeight: '16px',
  },
  precisionRow: {
    display: 'flex',
    alignItems: 'center',
  },
  // Width matches `tableCell` so each precision value lines up with its matrix column.
  precisionCell: {
    display: 'flex',
    justifyContent: 'center',
    marginTop: theme.spacing(2),
    width: theme.spacing(20),
    height: theme.spacing(6),
    background: theme.palette.grey[100],
    fontWeight: 500,
    color: theme.palette.grey[900],
  },
  lastLabelRow: {
    display: 'flex',
    alignItems: 'center',
  },
  // Column captions under the matrix, rotated -45deg around their right edge so long
  // class names fit; also reused for the "Prediction" title.
  lastLabelRowCaption: {
    display: 'flex',
    justifyContent: 'flex-end',
    alignItems: 'center',
    width: theme.spacing(20),
    transform: 'rotate(-45deg)',
    paddingRight: theme.spacing(1),
    position: 'relative',
    right: theme.spacing(5),
    transformOrigin: 'right',
    color: theme.palette.grey[900],
  },
  predictionTitle: {
    color: theme.palette.grey[400],
    fontWeight: 700,
  },
  cursorPointer: {
    cursor: 'pointer',
  },
}));

/**
 * One cell of the confusion matrix.
 *
 * Rendering modes, checked in this order:
 * - `isNoLabelCell`: a greyed-out "--" placeholder.
 * - `count` is defined: the count itself. Zero counts are greyed out. Non-zero
 *   counts get a blue fill whose alpha is `transparency`, and clicking them calls
 *   `onClick` (zero-count cells ignore clicks).
 * - otherwise: nothing.
 */
const TableCell: React.FC<{
  count?: number;
  transparency?: number;
  isNoLabelCell?: boolean;
  onClick?: () => void;
}> = ({ count, transparency, isNoLabelCell, onClick }) => {
  const classes = useStyles({ transparency });

  if (isNoLabelCell) {
    return (
      <Box className={cx(classes.tableCell, classes.zeroCell)}>
        <Typography>--</Typography>
      </Box>
    );
  }

  if (count === undefined) {
    return null;
  }

  const hasCount = count > 0;
  const isClickable = hasCount && Boolean(onClick);
  // Only forward clicks for cells that actually contain items.
  const handleClick = () => (count && onClick ? onClick() : null);

  return (
    <Box
      className={cx(classes.tableCell, {
        [classes.zeroCell]: count === 0,
        [classes.tableCellBackground]: hasCount,
        [classes.cursorPointer]: isClickable,
      })}
      onClick={handleClick}
    >
      <Typography>{count}</Typography>
    </Box>
  );
};

/** Props for `ConfusionMatrixTable`. */
interface IProps {
  /** Model whose confusion matrix is shown; its `id` is forwarded to the query. */
  model?: RegisteredModel;
  /** Evaluation set forwarded to the confusion-matrix query. */
  evaluationSetId?: number;
  /** Threshold forwarded to the confusion-matrix query. */
  threshold?: number;
  /**
   * Called with (ground-truth class id, predicted class id) when a cell with a
   * non-zero count is clicked.
   */
  onClick?: (gtClassId: number, predClassId: number) => void;
}

/**
 * Formats `numerator / denominator` as a percentage with one decimal place.
 * Returns '--' when the denominator is 0: the rate is undefined there, and
 * dividing would otherwise render "NaN%".
 */
const formatRate = (numerator: number, denominator: number): string =>
  denominator > 0 ? `${((numerator / denominator) * 100).toFixed(1)}%` : '--';

/**
 * Renders a confusion matrix for a model on an evaluation set.
 *
 * Layout: rows are ground-truth classes and columns are predicted classes. A
 * precision row (per predicted class) sits under the matrix, and a recall column
 * (per ground-truth class) sits to its right. The "No label" class (id 0) gets no
 * precision/recall entry.
 *
 * NOTE(review): `t` is not imported in this file; presumably a globally available
 * i18n function — confirm.
 */
export const ConfusionMatrixTable: React.FC<IProps> = ({
  model,
  evaluationSetId,
  threshold,
  onClick,
}) => {
  const classes = useStyles({});
  const { data: confusionMatrixData, isLoading: isConfusionMatrixDataLoading } =
    useGetConfusionMatrixQuery(model?.id, evaluationSetId, threshold);

  const getDefectNameById = useGetDefectNameById();

  // Flatten the four split matrices into lookup structures (zero counts are skipped):
  // - defectMap: class id -> caption (a falsy id is captioned 'No label')
  // - countSet: distinct non-zero cell counts, used to shade cells by rank
  // - confusionMatrixMap: `${gtClassId}-${predClassId}` -> count
  const { countSet, defectMap, confusionMatrixMap } = useMemo(() => {
    const defectMap = new Map<number, string>();
    const countSet = new Set<number>();
    const confusionMatrixMap = new Map<string, number>();
    const setConfusionMatrixMap = (confusionMatrix: AggregatedConfusionMatrix[]) => {
      confusionMatrix
        .filter(m => m.count > 0)
        .forEach(item => {
          const gtCaption = item.gtClassId ? getDefectNameById(item.gtClassId) : 'No label';
          const predictionCaption = item.predClassId
            ? getDefectNameById(item.predClassId)
            : 'No label';
          defectMap.set(item.gtClassId, gtCaption);
          defectMap.set(item.predClassId, predictionCaption);
          confusionMatrixMap.set(`${item.gtClassId}-${item.predClassId}`, item.count);
          countSet.add(item.count);
        });
    };

    const { correct, misClassification, falseNegative, falsePositive } =
      confusionMatrixData?.splitConfusionMatrices ?? {};
    setConfusionMatrixMap(correct?.data ?? []);
    setConfusionMatrixMap(misClassification?.data ?? []);
    setConfusionMatrixMap(falseNegative?.data ?? []);
    setConfusionMatrixMap(falsePositive?.data ?? []);
    return {
      defectMap,
      countSet,
      confusionMatrixMap,
    };
  }, [confusionMatrixData, getDefectNameById]);

  // Descending id order; this puts id 0 ("No label") last, assuming real class ids
  // are positive.
  const defectIds = Array.from(defectMap.keys()).sort((a, b) => b - a);

  // Shade each non-zero cell by the rank of its count among all distinct counts,
  // so the largest count is fully opaque.
  const orderedCounts = Array.from(countSet).sort((a, b) => a - b);
  const transparencyByCount = new Map<number, number>(
    orderedCounts.map((count, index): [number, number] => [
      count,
      (index + 1) / orderedCounts.length,
    ]),
  );

  /** Items with ground truth `gtId` predicted as `predId` (0 when absent). */
  const getCellCount = (gtId: number, predId: number): number =>
    confusionMatrixMap.get(`${gtId}-${predId}`) ?? 0;

  if (isConfusionMatrixDataLoading) {
    return (
      <Box className={classes.root}>
        <LoadingProgress />
      </Box>
    );
  }

  return (
    <Box className={classes.root}>
      <Box className={classes.tableContainer}>
        <Box className={classes.firstColumn}>
          <Typography className={classes.groundTruthTitle}>{t('Ground truth')}</Typography>
          {defectIds.map(defectId => (
            <Typography key={`ground-truth-${defectId}`} className={classes.firstColumnCaption}>
              {defectMap.get(defectId)}
            </Typography>
          ))}
          <Typography className={classes.precisionTitle}>{t('Precision')}</Typography>
        </Box>
        <Box>
          <Box height={20} />
          {defectIds.map(gtId => (
            <Box key={`row-${gtId}`} display="flex">
              {defectIds.map(predId => {
                // "No label" vs. "No label" is not a meaningful cell.
                if (gtId === predId && gtId === 0) {
                  return <TableCell key={`cell-${gtId}-${predId}`} isNoLabelCell />;
                }
                const count = getCellCount(gtId, predId);
                return (
                  <TableCell
                    key={`cell-${gtId}-${predId}`}
                    count={count}
                    transparency={count ? transparencyByCount.get(count) : undefined}
                    onClick={() => onClick?.(gtId, predId)}
                  />
                );
              })}
            </Box>
          ))}
          <Box className={classes.precisionRow}>
            {defectIds.map(predId => {
              if (predId === 0) {
                return null;
              }
              // Precision = correct / everything predicted as this class (column sum).
              const predictedTotal = defectIds.reduce(
                (sum, gtId) => sum + getCellCount(gtId, predId),
                0,
              );
              return (
                <Typography className={classes.precisionCell} key={`precision-${predId}`}>
                  {formatRate(getCellCount(predId, predId), predictedTotal)}
                </Typography>
              );
            })}
          </Box>
          <Box className={classes.lastLabelRow}>
            {defectIds.map(defectId => (
              <Typography key={`last-label-${defectId}`} className={classes.lastLabelRowCaption}>
                {defectMap.get(defectId)}
              </Typography>
            ))}
          </Box>
        </Box>
        <Box display="flex" flexDirection="column">
          <Box className={classes.lastColumn}>
            <Typography className={classes.recallTitle}>{t('Recall')}</Typography>
            {defectIds.map(gtId => {
              if (gtId === 0) {
                return null;
              }
              // Recall = correct / everything whose ground truth is this class (row sum).
              const groundTruthTotal = defectIds.reduce(
                (sum, predId) => sum + getCellCount(gtId, predId),
                0,
              );
              return (
                <Typography className={classes.lastColumnCaption} key={`recall-${gtId}`}>
                  {formatRate(getCellCount(gtId, gtId), groundTruthTotal)}
                </Typography>
              );
            })}
          </Box>
          <Box display="flex" flexDirection="column">
            <Box height={80} />
            <Typography className={cx(classes.lastLabelRowCaption, classes.predictionTitle)}>
              {t('Prediction')}
            </Typography>
          </Box>
        </Box>
      </Box>
    </Box>
  );
};
