import React, { useMemo, useState } from 'react';
import cx from 'classnames';
import { useAtom } from 'jotai';
import { Box, makeStyles, MenuItem, Select, Tooltip } from '@material-ui/core';
import { Typography } from '@clef/client-library';
import { useGetConfusionMatrixQuery } from '@/serverStore/modelAnalysis';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { modelListFilterOptionsAtom } from '../atoms';
import ModelImageList from '../ModelImageList/ModelImageList';
import { RegisteredModelWithThreshold } from '@/api/model_api';
import useGetDefectNameById from '@/hooks/defect/useGetDefectNameById';
import LoadingProgress from '../LoadingProgress';
import InfoOutlined from '@material-ui/icons/InfoOutlined';
import { AggregatedConfusionMatrix } from '@clef/shared/types';
import { ConfusionMatrixDiffTable } from '../ModelDetailsPanel/ConfusionMatrixDiffTable';
import { ConfusionMatrixTable } from '../ModelDetailsPanel/ConfusionMatrixTable';

/**
 * Sentinel value of the class filter meaning "do not filter by class".
 * It is the initial filter state and the value of the "All Classes" menu item.
 */
const ALL_CLASSES_MENU_VALUE = -1;

/**
 * One row of the baseline-vs-candidate comparison: a single
 * (ground truth class, predicted class) cell with the count from each model.
 */
interface ComparisonMatrix {
  /** Ground truth class id of the cell; a falsy value is treated as "no label" by this file. */
  gtDefectId: number | null;
  /** Predicted class id of the cell. */
  predDefectId: number | null;
  /** Count of this cell in the baseline model's matrix (0 when the cell is absent). */
  baseline: number;
  /** Count of this cell in the candidate model's matrix (0 when the cell is absent). */
  candidate: number;
}

// Class names used by the comparison header row and the side-by-side matrix tables.
const useMatrixStyles = makeStyles(appTheme => ({
  // Bold heading text used in the comparison header row.
  title: {
    color: appTheme.palette.grey[900],
    fontWeight: 700,
  },
  textAlignRight: {
    textAlign: 'right',
  },
  paddingLeft16px: {
    paddingLeft: appTheme.spacing(4),
  },
  paddingRight16px: {
    paddingRight: appTheme.spacing(4),
  },
  // Centers the two matrix tables next to each other.
  matrixTableContainer: {
    display: 'flex',
    justifyContent: 'center',
    alignItems: 'center',
    gap: appTheme.spacing(7),
  },
  // Caption shown above each matrix table.
  matrixTableTitle: {
    color: appTheme.palette.greyModern[900],
    fontWeight: 700,
  },
}));

/** Props of {@link ModelComparisonConfusionMatrix}. */
export type ModelComparisonConfusionMatrixProps = {
  /** Model shown as the "Baseline model". */
  baseline: RegisteredModelWithThreshold;
  /** Model shown as the "Candidate model". */
  candidate: RegisteredModelWithThreshold;
  /** Evaluation set whose id is used to fetch both confusion matrices. */
  evaluationSet: EvaluationSetItem;
  /** Threshold passed to the confusion matrix query of the baseline model. */
  baselineThreshold: number;
  /** Threshold passed to the confusion matrix query of the candidate model. */
  candidateThreshold: number;
};

/**
 * Indexes the cells that have a positive count by their
 * `<gtClassId>-<predClassId>` key. If several cells share a key, the last one wins.
 */
const indexNonEmptyCells = (
  cells: AggregatedConfusionMatrix[],
): Map<string, AggregatedConfusionMatrix> => {
  const cellsByKey = new Map<string, AggregatedConfusionMatrix>();
  cells.forEach(cell => {
    if (cell.count > 0) {
      cellsByKey.set(`${cell.gtClassId}-${cell.predClassId}`, cell);
    }
  });
  return cellsByKey;
};

/**
 * Merges one category of the baseline and candidate confusion matrices into
 * comparison rows, one row per (ground truth, prediction) pair that has a
 * positive count in at least one of the two models.
 *
 * @param baselineConfusionMatrix cells of the baseline model
 * @param candidateConfusionMatrix cells of the candidate model
 * @param defectSets collector that is MUTATED: the class ids of every merged cell
 *   are added to it (null ids are recorded as 0), even for rows removed by the filter
 * @param filteredDefect class id a row must have as ground truth or prediction,
 *   or ALL_CLASSES_MENU_VALUE to keep every row
 * @param shouldNotIncludeNoLabel when true, rows with a falsy ground truth id
 *   ("no label") are dropped
 * @returns rows ordered by first appearance: baseline cells first, then cells that
 *   only exist in the candidate
 */
const getComparisonMatrix = (
  baselineConfusionMatrix: AggregatedConfusionMatrix[],
  candidateConfusionMatrix: AggregatedConfusionMatrix[],
  defectSets: Set<number>,
  filteredDefect: number,
  shouldNotIncludeNoLabel = false,
): ComparisonMatrix[] => {
  const baselineCells = indexNonEmptyCells(baselineConfusionMatrix);
  const candidateCells = indexNonEmptyCells(candidateConfusionMatrix);

  // A Set keeps first-insertion order, so baseline keys come before candidate-only keys.
  const cellKeys = new Set<string>();
  baselineCells.forEach((_cell, key) => cellKeys.add(key));
  candidateCells.forEach((_cell, key) => cellKeys.add(key));

  const rows: ComparisonMatrix[] = [];
  cellKeys.forEach(key => {
    const baselineCell = baselineCells.get(key);
    const candidateCell = candidateCells.get(key);
    // The class ids are identical on both sides because they are part of the key.
    const cell = baselineCell ?? candidateCell;
    if (!cell) {
      return;
    }
    const row: ComparisonMatrix = {
      gtDefectId: cell.gtClassId,
      predDefectId: cell.predClassId,
      baseline: baselineCell?.count ?? 0,
      candidate: candidateCell?.count ?? 0,
    };

    // Collected before filtering so the class menu always lists every class.
    defectSets.add(row.gtDefectId ?? 0);
    defectSets.add(row.predDefectId ?? 0);

    const matchesFilter =
      filteredDefect === ALL_CLASSES_MENU_VALUE ||
      filteredDefect === row.gtDefectId ||
      filteredDefect === row.predDefectId;
    if (!matchesFilter) {
      return;
    }
    if (shouldNotIncludeNoLabel && !row.gtDefectId) {
      return;
    }
    rows.push(row);
  });
  return rows;
};

/**
 * Shows the confusion matrices of a baseline and a candidate model side by side.
 * Once a matrix cell has been clicked (which sets the shared filter options),
 * it also shows per-category difference tables (false positive, false negative,
 * misclassified, correct) with a class filter, next to the model image list.
 *
 * Renders only a loading indicator while either confusion matrix query is loading.
 */
const ModelComparisonConfusionMatrix: React.FC<ModelComparisonConfusionMatrixProps> = props => {
  const { baseline, candidate, evaluationSet, baselineThreshold, candidateThreshold } = props;
  const { data: baselineConfusionMatrixData, isLoading: isLoadingBaseline } =
    useGetConfusionMatrixQuery(baseline.id, evaluationSet.id, baselineThreshold);
  const { data: candidateConfusionMatrixData, isLoading: isLoadingCandidate } =
    useGetConfusionMatrixQuery(candidate.id, evaluationSet.id, candidateThreshold);
  const [filterOptions, setFilterOptions] = useAtom(modelListFilterOptionsAtom);
  const [filteredDefect, setFilteredDefect] = useState<number>(ALL_CLASSES_MENU_VALUE);
  const getDefectNameById = useGetDefectNameById();
  const styles = useMatrixStyles();

  const { splitConfusionMatrices: baselineConfusionMatrices } = baselineConfusionMatrixData ?? {};
  const { splitConfusionMatrices: candidateConfusionMatrices } = candidateConfusionMatrixData ?? {};

  const {
    defectSets,
    correctConfusionMatrix,
    misClassificationConfusionMatrix,
    falseNegativeConfusionMatrix,
    falsePositiveConfusionMatrix,
  } = useMemo(() => {
    // Filled as a side effect of every getComparisonMatrix call below.
    const defectSets: Set<number> = new Set();
    const {
      correct: baselineCorrect,
      falsePositive: baselineFP,
      falseNegative: baselineFN,
      misClassification: baselineMC,
    } = baselineConfusionMatrices ?? {};
    const {
      correct: candidateCorrect,
      falsePositive: candidateFP,
      falseNegative: candidateFN,
      misClassification: candidateMC,
    } = candidateConfusionMatrices ?? {};

    const correctConfusionMatrix = getComparisonMatrix(
      baselineCorrect?.data ?? [],
      candidateCorrect?.data ?? [],
      defectSets,
      filteredDefect,
      true, // No Label should not be included on correct mapping
    );
    const falsePositiveConfusionMatrix = getComparisonMatrix(
      baselineFP?.data ?? [],
      candidateFP?.data ?? [],
      defectSets,
      filteredDefect,
    );
    const falseNegativeConfusionMatrix = getComparisonMatrix(
      baselineFN?.data ?? [],
      candidateFN?.data ?? [],
      defectSets,
      filteredDefect,
    );
    const misClassificationConfusionMatrix = getComparisonMatrix(
      baselineMC?.data ?? [],
      candidateMC?.data ?? [],
      defectSets,
      filteredDefect,
    );
    return {
      correctConfusionMatrix,
      falsePositiveConfusionMatrix,
      falseNegativeConfusionMatrix,
      misClassificationConfusionMatrix,
      defectSets,
    };
  }, [filteredDefect, baselineConfusionMatrices, candidateConfusionMatrices]);

  if (isLoadingBaseline || isLoadingCandidate) {
    return <LoadingProgress size={24} />;
  }

  // One descriptor per difference table, in display order.
  const diffSections = [
    {
      key: 'false-positive',
      title: t('False Positive'),
      titleTooltip: t(
        'The model predicted that an object of interest was present, but the model was incorrect.',
      ),
      isCorrectMapping: false,
      matrix: falsePositiveConfusionMatrix,
    },
    {
      key: 'false-negative',
      title: t('False Negative'),
      titleTooltip: t(
        'The model predicted that an object of interest was not present, but the model was incorrect.',
      ),
      isCorrectMapping: false,
      matrix: falseNegativeConfusionMatrix,
    },
    {
      key: 'mis-classified',
      title: t('Misclassified'),
      titleTooltip: t(
        'The model correctly predicted that an object of interest was present, but it predicted the wrong class.',
      ),
      isCorrectMapping: false,
      matrix: misClassificationConfusionMatrix,
    },
    {
      key: 'correct',
      title: t('Correct'),
      titleTooltip: t('The model’s prediction was correct.'),
      isCorrectMapping: true,
      matrix: correctConfusionMatrix,
    },
  ];

  // NOTE(review): the queries above use the baselineThreshold / candidateThreshold props,
  // while the tables and the image list below receive baseline.threshold /
  // candidate.threshold. Confirm that these are always the same values.
  return (
    <>
      <Box className={styles.matrixTableContainer}>
        <Box display="flex" flexDirection="column" alignItems="center">
          <Typography className={styles.matrixTableTitle}>{t('Baseline model')}</Typography>
          <ConfusionMatrixTable
            model={baseline}
            evaluationSetId={evaluationSet.id}
            threshold={baseline.threshold}
            onClick={(gtClassId, predClassId) => setFilterOptions({ gtClassId, predClassId })}
          />
        </Box>
        <Box display="flex" flexDirection="column" alignItems="center">
          <Typography className={styles.matrixTableTitle}>{t('Candidate model')}</Typography>
          <ConfusionMatrixTable
            model={candidate}
            evaluationSetId={evaluationSet.id}
            threshold={candidate.threshold}
            onClick={(gtClassId, predClassId) => setFilterOptions({ gtClassId, predClassId })}
          />
        </Box>
      </Box>
      {/* NOTE(review): everything below is only rendered when filterOptions is set, so the
          `filterOptions ? … : …` fallbacks and the inner `!!filterOptions` guard can never
          take their falsy branch. Confirm whether this outer guard is intended. */}
      {!!filterOptions && (
        <Box display="flex">
          <Box
            width={filterOptions ? 600 : 850}
            flexShrink={0}
            flexGrow={0}
            style={{
              transition: 'width 0.3s ease-in-out',
            }}
          >
            <Box display="flex" marginBottom={3} alignItems="center">
              <Box className={cx(styles.title, styles.paddingLeft16px)} width={350}>
                <Select
                  value={filteredDefect}
                  onChange={(e: React.ChangeEvent<{ value: unknown }>) =>
                    setFilteredDefect(e.target.value as number)
                  }
                >
                  <MenuItem value={ALL_CLASSES_MENU_VALUE}>{t('All Classes')}</MenuItem>
                  {Array.from(defectSets).map(id =>
                    // Id 0 stands for "no label" and gets no menu entry. Return null
                    // instead of `id && …`, which would pass a bare 0 as a child.
                    id ? (
                      <MenuItem key={id} value={id}>
                        {getDefectNameById(id)}
                      </MenuItem>
                    ) : null,
                  )}
                </Select>
              </Box>
              <Typography className={cx(styles.title, styles.textAlignRight)} width={145}>
                {filterOptions ? t('Baseline') : t('Baseline model')}
              </Typography>
              <Typography
                className={cx(styles.title, styles.textAlignRight, {
                  [styles.paddingRight16px]: !filterOptions,
                })}
                width={145}
              >
                {filterOptions ? t('Candidate') : t('Candidate model')}
              </Typography>
              <Box display="flex" alignItems="center">
                <Typography className={cx(styles.title, styles.textAlignRight)} width={90}>
                  {t('Differences')}
                </Typography>
                <Box marginLeft={1} />
                <Tooltip
                  placement="top"
                  arrow={true}
                  title={
                    <>
                      <Typography>
                        {t(
                          `Shows if the candidate model performed better (Fixed) or worse (New Error) than the baseline candidate for each error type. For differences shown as a percentage, the percentage is calculated as:`,
                        )}
                      </Typography>
                      <Typography>{t('((candidate - baseline) / baseline) * 100')}</Typography>
                    </>
                  }
                >
                  <InfoOutlined fontSize="small" />
                </Tooltip>
              </Box>
            </Box>
            {diffSections.map(
              ({ key, title, titleTooltip, isCorrectMapping, matrix }) =>
                // Categories without any row are not rendered at all.
                matrix.length > 0 && (
                  <ConfusionMatrixDiffTable
                    key={key}
                    title={title}
                    titleTooltip={titleTooltip}
                    isCorrectMapping={isCorrectMapping}
                    comparisonSum={{
                      base: matrix.reduce((accum, val) => accum + val.baseline, 0),
                      candidate: matrix.reduce((accum, val) => accum + val.candidate, 0),
                    }}
                    comparsionMatrices={matrix}
                  />
                ),
            )}
          </Box>
          {!!filterOptions && (
            <ModelImageList
              model={baseline}
              candidate={candidate}
              threshold={baseline.threshold}
              candidateThreshold={candidate.threshold}
              evaluationSet={evaluationSet}
            />
          )}
        </Box>
      )}
    </>
  );
};

export default ModelComparisonConfusionMatrix;
