import { range } from 'lodash';
import { useCurrentProjectModelInfoQuery } from '@/serverStore/projectModels';
import {
  ClassificationConfusionMatrix,
  ConfusionMatrixPerThreshold,
  DatasetGroupOptions,
  PerformanceMetrics,
  SelectMediaOption,
  ModelStatus,
} from '@clef/shared/types';
import { useGetDatasetStatsQuery } from '@/serverStore/dataset';
import { useGetDatasetFilterOptionsQuery } from '@/serverStore/dataset';
import { useJobDetailForCurrentProject } from '@/serverStore/jobs';
import { CONFIDENCE_THRESHOLD_OPTIONS } from '@clef/shared/constants';

/**
 * Counts of a binary confusion matrix resolved for a single confidence threshold.
 */
export type ConfusionMatrix = {
  /** correctly predicted positives */
  truePositiveCount: number;
  /** negatives wrongly predicted as positive */
  falsePositiveCount: number;
  /** positives wrongly predicted as negative */
  falseNegativeCount: number;
  /** correctly predicted negatives */
  trueNegativeCount: number;
  /** misclassified count; absent when the source metrics do not provide it */
  misclassifiedCount?: number;
};
/**
 * Model performance resolved at one confidence threshold.
 */
export type ProjectModelPerformanceMetrics = {
  /** performance aggregated over media */
  mediaLevelPerformance: number;
  /** media-level performance keyed by split name */
  mediaLevelPerformancePerSplit: Record<string, number>;
  /** performance aggregated over bboxes / pixels / classifications */
  annotationLevelPerformance: number;
  /** annotation-level performance keyed by split name */
  annotationLevelPerformancePerSplit: Record<string, number>;
  /** annotation-level confusion matrix */
  annotationLevelConfusionMatrix: ConfusionMatrix;
  misclassifiedCount: number;
};

/** Whether the model's status is `ModelStatus.Failed`; a missing status is not a failure. */
export const isModelTrainingFailed = (status?: ModelStatus | null) => ModelStatus.Failed === status;

/**
 * Whether a model should still be treated as "training".
 *
 * - A missing/empty status is treated as not in progress.
 * - Actively running statuses are always in progress.
 * - Stopped / Failed / Terminated are never in progress: metrics will not be computed for them.
 * - Any other status is a final state whose metrics may still be computing, so it counts as
 *   in progress until `metricsReady` is true.
 */
export const isModelTrainingInProgress = (status?: ModelStatus | null, metricsReady?: boolean) => {
  if (!status) {
    // invalid status: training should be considered stopped
    return false;
  }

  const runningStatuses = [
    ModelStatus.Created,
    ModelStatus.Starting,
    ModelStatus.Evaluating,
    ModelStatus.Training,
    ModelStatus.LegacyRunning,
    ModelStatus.Publishing,
    ModelStatus.FirstBatch,
    ModelStatus.ALLBatches,
  ];
  if (runningStatuses.includes(status)) {
    return true;
  }

  const abortedStatuses = [ModelStatus.Stopped, ModelStatus.Failed, ModelStatus.Terminated];
  if (abortedStatuses.includes(status)) {
    return false;
  }

  // remaining final states: keep waiting until the metrics computation is ready
  return !metricsReady;
};

/**
 * Whether a learning curve is available; this depends only on the model status.
 */
export const isModelTrainingHasLearningCurve = (status?: ModelStatus | null) => {
  if (!status) {
    return false;
  }
  const statusesWithCurve = [
    ModelStatus.Training,
    ModelStatus.Evaluating,
    ModelStatus.LegacyRunning,
    ModelStatus.Publishing,
    ModelStatus.Succeed,
  ];
  return statusesWithCurve.includes(status);
};

// TODO: to let users see that training finished sooner, ModelStatus.Publishing should ideally
// count as "successful" here, while deployment stays limited to ModelStatus.Succeed.
/**
 * Whether training finished successfully: the status is Succeed or Saved AND `metricsReady`
 * is exactly `true`.
 */
export const isModelTrainingSuccessful = (status?: ModelStatus | null, metricsReady?: boolean) => {
  if (!status || metricsReady !== true) {
    return false;
  }
  return [ModelStatus.Succeed, ModelStatus.Saved].includes(status);
};

/**
 * Whether a deployment is still pending. Metrics are treated as ready, so only the actively
 * running training statuses of isModelTrainingInProgress count as pending.
 */
export const isDeploymentPending = (status?: ModelStatus | null) => {
  const metricsReady = true;
  return isModelTrainingInProgress(status, metricsReady);
};

/**
 * Whether a model is ready to deploy. Metrics are treated as ready, so this reduces to the
 * status check of isModelTrainingSuccessful.
 */
export const isDeploymentReady = (status?: ModelStatus | null) => {
  const metricsReady = true;
  return isModelTrainingSuccessful(status, metricsReady);
};

/**
 * Extract the per-threshold confusion-matrix arrays from a model's performance metrics.
 *
 * Metrics with `version === '1.0'` nest the arrays under `confusionMatrix`; any other version
 * carries them at the top level. Every missing array becomes `[]`, including when a v1.0
 * payload has no `confusionMatrix` at all, so callers can index the result directly.
 *
 * @param performanceMetrics metrics to read; null/undefined yields `undefined`
 * @returns the five per-threshold arrays, or `undefined` when no metrics are given
 */
export const getConfusionMatrixPerThreshold = (
  performanceMetrics: PerformanceMetrics | null | undefined,
) => {
  if (!performanceMetrics) {
    return undefined;
  }
  const { version, confusionMatrix } = performanceMetrics;

  // Do not assert `confusionMatrix` exists: destructuring an absent v1.0 matrix used to throw.
  const source =
    version === '1.0'
      ? (confusionMatrix as ConfusionMatrixPerThreshold | undefined)
      : performanceMetrics;

  return {
    truePositives: source?.truePositives ?? [],
    falsePositives: source?.falsePositives ?? [],
    trueNegatives: source?.trueNegatives ?? [],
    falseNegatives: source?.falseNegatives ?? [],
    misclassified: source?.misclassified ?? [],
  } as ConfusionMatrixPerThreshold;
};

/**
 * Resolve the confusion matrix at a single confidence threshold.
 *
 * A per-threshold matrix (detected by its `truePositives` field) is indexed at the first entry
 * of `thresholds` that is >= `threshold`, with a tolerance of 10% of the grid step to absorb
 * floating-point drift. Any other matrix (a classification matrix) is returned unchanged.
 *
 * @param confusionMatrixPerThreshold per-threshold count arrays, or a classification matrix
 * @param thresholds ascending threshold grid the arrays are aligned with (assumed evenly spaced;
 *   the step is taken from its first two entries)
 * @param threshold threshold to resolve
 * @returns counts at `threshold`, the classification matrix as-is, or `undefined` when the
 *   matrix or threshold is missing or `threshold` lies above the grid
 */
export const getConfusionMatrix = (
  confusionMatrixPerThreshold:
    | ConfusionMatrixPerThreshold
    | ClassificationConfusionMatrix
    | undefined,
  thresholds: number[],
  threshold?: number,
): ConfusionMatrix | ClassificationConfusionMatrix | undefined => {
  if (!confusionMatrixPerThreshold || threshold === undefined) {
    return undefined;
  }
  if (!('truePositives' in confusionMatrixPerThreshold)) {
    // classification matrices are not resolved per threshold
    return confusionMatrixPerThreshold;
  }

  const {
    truePositives = [],
    falsePositives = [],
    trueNegatives = [],
    falseNegatives = [],
    misclassified = [],
  } = confusionMatrixPerThreshold;

  // With fewer than two grid points there is no step; a NaN tolerance would never match.
  const thresholdsStep = thresholds.length > 1 ? thresholds[1] - thresholds[0] : 0;
  const thresholdIndex = thresholds.findIndex(n => n >= threshold - thresholdsStep * 0.1);
  if (thresholdIndex < 0) {
    // Threshold is above the grid (or the grid is empty): report "no data" instead of
    // fabricating an all-zero matrix.
    return undefined;
  }

  const result: ConfusionMatrix = {
    truePositiveCount: truePositives[thresholdIndex] ?? 0,
    falsePositiveCount: falsePositives[thresholdIndex] ?? 0,
    trueNegativeCount: trueNegatives[thresholdIndex] ?? 0,
    falseNegativeCount: falseNegatives[thresholdIndex] ?? 0,
    // left undefined (not 0) when the source has no misclassified counts
    misclassifiedCount: misclassified[thresholdIndex],
  };
  return result;
};

/**
 * Take the entry at `index` from every per-split list, e.g. `{ train: [a, b] }` at index 1
 * becomes `{ train: b }`. A missing map yields `{}`.
 */
const pickPerSplitAtIndex = (
  listPerSplit: { [split: string]: number[] } | null | undefined,
  index: number,
): { [split: string]: number } => {
  const picked: { [split: string]: number } = {};
  if (!listPerSplit) {
    return picked;
  }
  for (const [split, list] of Object.entries(listPerSplit)) {
    picked[split] = list[index];
  }
  return picked;
};

/**
 * Get performance metrics for the given confidence threshold from the current model's
 * dataset-level metrics.
 *
 * The threshold is mapped onto the grid `range(thresholdsMin, thresholdsMax, thresholdsStep)`
 * from CONFIDENCE_THRESHOLD_OPTIONS, picking the first grid value >= `threshold` with a
 * tolerance of 10% of a step to absorb floating-point drift.
 *
 * @param threshold confidence threshold to read metrics at
 * @returns metrics at that threshold, or `undefined` when there are no dataset-level metrics
 *   yet, `threshold` is undefined or above the grid, or no confusion matrix can be derived
 */
export const useProjectModelPerformanceMetrics = (threshold?: number) => {
  const { id: currentModelId } = useCurrentProjectModelInfoQuery();
  const { data: jobDetails } = useJobDetailForCurrentProject(currentModelId);
  const datasetLevelMetrics = jobDetails?.datasetLevelMetrics;

  if (!datasetLevelMetrics || threshold === undefined) {
    return undefined;
  }

  const {
    performance: annotationLevelPerformanceList,
    performancePerSplit: annotationLevelPerformanceListPerSplit,
    binarizedPerformance: mediaLevelPerformanceList = [],
    binarizedPerformancePerSplit: mediaLevelPerformanceListPerSplit = {},
  } = datasetLevelMetrics;

  // find the current threshold index
  const { thresholdsMin, thresholdsMax, thresholdsStep } = CONFIDENCE_THRESHOLD_OPTIONS;
  const thresholds = range(thresholdsMin, thresholdsMax, thresholdsStep);
  const thresholdIndex = thresholds.findIndex(n => n >= threshold - thresholdsStep * 0.1);

  // for compatibility: normalizes the v1.0 (nested) and older (top-level) matrix layouts
  const confusionMatrix = getConfusionMatrixPerThreshold(datasetLevelMetrics);

  if (thresholdIndex < 0 || !confusionMatrix) {
    return undefined;
  }

  const { truePositives, falsePositives, trueNegatives, falseNegatives, misclassified } =
    confusionMatrix;

  return {
    // guarded like the media-level list so a missing/null list yields undefined, not a crash
    annotationLevelPerformance: annotationLevelPerformanceList?.[thresholdIndex],
    annotationLevelPerformancePerSplit: pickPerSplitAtIndex(
      annotationLevelPerformanceListPerSplit,
      thresholdIndex,
    ),
    mediaLevelPerformance: mediaLevelPerformanceList?.[thresholdIndex],
    mediaLevelPerformancePerSplit: pickPerSplitAtIndex(
      mediaLevelPerformanceListPerSplit,
      thresholdIndex,
    ),
    annotationLevelConfusionMatrix: {
      truePositiveCount: truePositives?.[thresholdIndex] ?? 0,
      falsePositiveCount: falsePositives?.[thresholdIndex] ?? 0,
      trueNegativeCount: trueNegatives?.[thresholdIndex] ?? 0,
      falseNegativeCount: falseNegatives?.[thresholdIndex] ?? 0,
    },
    misclassifiedCount: misclassified?.[thresholdIndex] ?? 0,
  } as ProjectModelPerformanceMetrics;
};

/**
 * Media-level (OK/NG) confusion matrix for one split, built from dataset stats grouped by
 * ground-truth media-level label. NG is treated as the positive class.
 *
 * @param split split name to restrict the stats to
 * @returns `[matrix, isLoading]`; `matrix` is undefined until stats are available. Stats are
 *   only requested when `split` is given: `!!selectOptions` is passed as the second argument of
 *   useGetDatasetStatsQuery (presumably its `enabled` flag — TODO confirm).
 */
export const useMediaLevelConfusionMatrix = (
  split?: string,
): [ConfusionMatrix | undefined, boolean] => {
  const { data: allFilters } = useGetDatasetFilterOptionsQuery();
  // 'Split' is in columnFilterMap: the project stores split in datasetContent's splitSet column.
  // 'split' is in fieldFilterMap: the project has split metadata.
  // Always prefer the 'Split' column in case the user manually created 'split' or 'Split' metadata.
  const splitFilterOption = allFilters
    ? allFilters.find(value => value.filterName === 'Split' && value.filterType === 'column') ||
      allFilters.find(value => value.filterName === 'split')
    : undefined;
  const splitId = splitFilterOption?.value?.[split ?? ''] as number | undefined;
  let selectOptions: SelectMediaOption | undefined = undefined;
  if (split) {
    selectOptions = {
      selectedMedia: [],
      unselectedMedia: [],
      isUnselectMode: true,
      fieldFilterMap:
        splitFilterOption?.filterType === 'field'
          ? { [splitFilterOption.fieldId!]: { CONTAINS_ANY: [split] } }
          : {},
      // Compare against null instead of truthiness so a split id of 0 still filters.
      // NOTE(review): if the split name has no id, no column filter is applied — confirm intended.
      columnFilterMap:
        splitFilterOption?.filterType === 'column' && splitId != null
          ? { datasetContent: { splitSet: { CONTAINS_ANY: [splitId] } } }
          : {},
    };
  }

  const { data: stats, isLoading: statsLoading } = useGetDatasetStatsQuery(
    {
      selectOptions,
      groupOptions: [DatasetGroupOptions.GROUND_TRUTH_MEDIA_LEVEL_LABEL],
    },
    !!selectOptions,
  );

  if (!stats) {
    return [undefined, statsLoading];
  }

  // count of the stats group with the given ground-truth / predicted labels (0 if absent)
  const countFor = (groundTruth: string, prediction: string) =>
    stats.find(
      ({ ground_truth_media_level_label: gt, prediction_media_level_label: predict }) =>
        gt === groundTruth && predict === prediction,
    )?.count ?? 0;

  return [
    {
      // true positive: NG predicted as NG
      truePositiveCount: countFor('NG', 'NG'),
      // false positive: OK predicted as NG
      falsePositiveCount: countFor('OK', 'NG'),
      // false negative: NG predicted as OK
      falseNegativeCount: countFor('NG', 'OK'),
      // true negative: OK predicted as OK
      trueNegativeCount: countFor('OK', 'OK'),
    },
    statsLoading,
  ];
};
