import React, { useMemo } from 'react';
import {
  TableCell,
  Box,
  makeStyles,
  Tooltip,
  TextField,
  FormHelperText,
  FormControl,
  MenuItem,
  Slider,
} from '@material-ui/core';
import { useAtom } from 'jotai';
import { useSnackbar } from 'notistack';

import { TRANSFORM_TEXTS } from '@/constants/model_train';
import {
  getHyperParamDetails,
  getErrorsByHyperParams,
  getModelArchsByLabelType,
} from '@/utils/job_train_utils';
import { ClientFeatures, useFeatureGateEnabled } from '@/hooks/useFeatureGate';
import { modelsConfigListAtom } from '@/uiStates/customTraining/pageUIStates';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { LabelType } from '@clef/shared/types';
import { useGetModelArchSchemas } from '@/serverStore/train';

import { transferDatasetLimitsFormat } from './index';

// JSS styles for the hyper-parameter table cell.
// NOTE(review): `greyModern` is not a stock MUI palette key — presumably added to the
// project theme via palette augmentation; confirm before relying on it elsewhere.
const useStyles = makeStyles(theme => ({
  // Align the cell's content to the top of the table row.
  configCell: {
    verticalAlign: 'top',
  },
  // Wrapper (FormControl) for a single hyper-parameter; spacing(0, 4, 3, 0) is
  // top/right/bottom/left, i.e. right and bottom gaps only.
  parameterRoot: {
    display: 'flex',
    margin: theme.spacing(0, 4, 3, 0),
  },
  // Label and control on one line, pushed to opposite ends.
  parameterItem: {
    display: 'flex',
    alignItems: 'center',
    justifyContent: 'space-between',
  },
  parameterLabel: {
    color: theme.palette.greyModern[900],
  },
  // Borderless-looking text/select field: the border is transparent and SVG icons
  // (e.g. a select's dropdown arrow) are hidden until the field is hovered.
  parameterField: {
    border: `1px solid transparent`,
    borderRadius: 5,
    flexShrink: 0,
    '& .MuiInputBase-root': {
      backgroundColor: 'transparent',
      // Hide the ::before pseudo-element, which MUI uses to draw the input underline.
      '&::before': {
        display: 'none',
      },
    },
    '& .MuiInputBase-input': {
      padding: theme.spacing(2),
      width: '136px',
      color: theme.palette.greyModern[500],
    },
    '& .MuiSvgIcon-root': {
      display: 'none',
    },
    // On hover, reveal the border and the previously hidden icons.
    '&:hover': {
      borderColor: theme.palette.greyModern[300],
      '& .MuiSvgIcon-root': {
        display: 'inline-block',
      },
    },
  },
  // Root class applied to the NMS IoU-threshold Slider.
  parameterFieldSlider: {
    width: '54%',
  },
}));

/**
 * Builds the tooltip content shown next to a hyper-parameter label.
 *
 * - For any key other than `archName`, returns the `tooltip` entry from
 *   `TRANSFORM_TEXTS` (which may be `undefined`).
 * - For `archName` without a label type, returns `''`.
 * - For `archName` with a label type, returns a translated description of the
 *   architecture tiers from `getModelArchsByLabelType(labelType)`. Each tier name is
 *   rendered in bold. The level1 tier is appended on its own line only when it exists.
 *
 * @param key - Hyper-parameter name.
 * @param labelType - Project label type; only consulted for `archName`.
 */
export const getHyperParamsTooltip = (key: string, labelType?: LabelType) => {
  if (key !== 'archName') {
    return TRANSFORM_TEXTS[key]?.tooltip;
  }
  if (!labelType) {
    return '';
  }

  const { level1, level2, level3 } = getModelArchsByLabelType(labelType);

  const level3Text = t(
    `{{level3}}: Capture more complex patterns in the training data. Training and running inferences will take longer.`,
    { level3: <strong>{level3}</strong> },
  );
  const level2Text = t(`{{level2}}: Train and run inferences faster.`, {
    level2: <strong>{level2}</strong>,
  });
  // The optional level1 tier starts with its own line break so it renders below level2.
  const level1Text = level1
    ? t(`{{level1}}: Fastest training and inference times.`, {
        level1: (
          <>
            <br />
            <strong>{level1}</strong>
          </>
        ),
      })
    : '';

  return t(`{{level3Text}}{{break}}{{level2Text}}{{level1Text}}`, {
    level3Text,
    level2Text,
    break: <br />,
    level1Text,
  });
};

/** Props for the `HyperParameters` table cell. */
interface HyperParametersProps {
  /** Index of the row in `modelsConfigListAtom` whose hyper-parameters this cell edits. */
  rowIndex: number;
}

/**
 * Table cell that renders editable hyper-parameter controls for one row of the
 * custom-training models config list (`modelsConfigListAtom[rowIndex]`).
 *
 * Each parameter from `getHyperParamDetails` is rendered as one of:
 * - a `Slider` for `nmsParams.iou_threshold`. This parameter is hidden unless the NMS
 *   feature gate is on and the project is bounding-box. The Slider is used only when
 *   the current schema defines `NonMaxSuppressionParams`.
 * - a numeric `TextField` for `number` parameters;
 * - otherwise a (select) `TextField` whose value is looked up by name in the model
 *   arch schemas. Changing it also swaps the row's schema and dataset limits.
 *
 * All writes go through `updateRowConfig`, which rebuilds the row immutably from the
 * latest atom state. The previous state object is never mutated.
 */
const HyperParameters: React.FC<HyperParametersProps> = ({ rowIndex }) => {
  const styles = useStyles();
  const { labelType } = useGetSelectedProjectQuery().data ?? {};

  const { enqueueSnackbar } = useSnackbar();

  // The feature-gate hook is the left operand, so it runs on every render
  // (hooks must not be called conditionally).
  const enableNMSThreshold =
    useFeatureGateEnabled(ClientFeatures.NMSThresholdAdvTrain) &&
    labelType === LabelType.BoundingBox;

  const { data: modelArchSchemas } = useGetModelArchSchemas();

  const [modelsConfigList, setModelsConfigList] = useAtom(modelsConfigListAtom);
  type RowConfig = (typeof modelsConfigList)[number]['config'];
  type HyperParamsModel = NonNullable<
    NonNullable<RowConfig['trainingParams']>['hyperParams']
  >['model'];

  const { trainingParams, availableModelSizes, currentSchema } = modelsConfigList[rowIndex].config;
  const hyperParamsFormatted = trainingParams
    ? getHyperParamDetails(trainingParams.hyperParams!.model, {
        availableModelSizes,
      })
    : [];

  const errorsByHyperParams = useMemo(() => {
    return getErrorsByHyperParams(
      currentSchema?.schema,
      Number(trainingParams?.hyperParams?.model['learningParams.epochs']),
    );
  }, [currentSchema, trainingParams?.hyperParams?.model]);

  /**
   * Replaces this row's config with `buildConfig(latestConfig)`.
   *
   * A new array, row and config object are produced instead of mutating the
   * previous atom state in place. `buildConfig` receives the latest state rather
   * than the render-time closure.
   */
  const updateRowConfig = (buildConfig: (config: RowConfig) => RowConfig) => {
    setModelsConfigList(prev =>
      prev.map((item, index) =>
        index === rowIndex ? { ...item, config: buildConfig(item.config) } : item,
      ),
    );
  };

  /** Returns a copy of `config` whose `trainingParams.hyperParams.model` is `model`. */
  const withHyperParamsModel = (config: RowConfig, model: HyperParamsModel): RowConfig => ({
    ...config,
    trainingParams: {
      ...config.trainingParams!,
      hyperParams: {
        ...config.trainingParams!.hyperParams!,
        model,
      },
    },
  });

  return (
    <TableCell className={styles.configCell}>
      {hyperParamsFormatted.map(param => {
        if (param.name === 'nmsParams.iou_threshold' && !enableNMSThreshold) return null;

        const errorMessage = errorsByHyperParams?.model?.[param.name];

        return (
          <FormControl error={!!errorMessage} key={param.name} className={styles.parameterRoot}>
            <Box className={styles.parameterItem}>
              <Tooltip
                title={getHyperParamsTooltip(param.name, labelType ?? undefined) ?? ''}
                arrow
                placement="top"
              >
                <Box className={styles.parameterLabel}>{param.label}</Box>
              </Tooltip>

              {param.name === 'nmsParams.iou_threshold' &&
              !!currentSchema?.schema.definitions.NonMaxSuppressionParams ? (
                // NOTE(review): this Slider is uncontrolled (`defaultValue` is the schema
                // default), so on remount it does not reflect a previously stored value.
                <Slider
                  classes={{ root: styles.parameterFieldSlider }}
                  step={0.05}
                  min={
                    currentSchema?.schema.definitions.NonMaxSuppressionParams.properties
                      .iou_threshold.minimum
                  }
                  max={
                    currentSchema?.schema.definitions.NonMaxSuppressionParams.properties
                      .iou_threshold.maximum
                  }
                  defaultValue={currentSchema?.schema.properties.nmsParams.default.iou_threshold}
                  valueLabelDisplay="auto"
                  marks={[
                    { value: 0, label: '0' },
                    { value: 1, label: '1' },
                  ]}
                  onChange={(_, newValue) => {
                    const threshold = Number(newValue);
                    updateRowConfig(config => {
                      const model = {
                        ...config.trainingParams!.hyperParams!.model,
                        [param.name]: threshold,
                      };
                      return withHyperParamsModel(config, model);
                    });
                  }}
                />
              ) : param.type === 'number' ? (
                <TextField
                  classes={{ root: styles.parameterField }}
                  type={param.type}
                  value={param.default.toString()}
                  onChange={event => {
                    // Read the event synchronously; the state updater may run later.
                    const numericValue = Number(event.target.value);
                    updateRowConfig(config => {
                      const model = {
                        ...config.trainingParams!.hyperParams!.model,
                        [param.name]: numericValue,
                      };
                      return withHyperParamsModel(config, model);
                    });
                  }}
                  // NOTE(review): the epochs minimum and integer step are applied to every
                  // numeric parameter, not only `learningParams.epochs` — confirm intent.
                  inputProps={{
                    min: currentSchema?.schema.definitions.LearningParams.properties.epochs.minimum,
                    step: 1,
                  }}
                />
              ) : (
                <TextField
                  classes={{ root: styles.parameterField }}
                  type={param.type}
                  value={param.default}
                  select={!!param.options}
                  onChange={event => {
                    // Read the event synchronously; the state updater may run later.
                    const selectedValue = event.target.value;
                    const newSchema = modelArchSchemas?.find(
                      schema => schema.name === selectedValue,
                    );
                    if (!newSchema) {
                      enqueueSnackbar(t('Invalid schema for the model size.'), {
                        variant: 'error',
                        autoHideDuration: 12000,
                      });
                      return;
                    }
                    updateRowConfig(config => {
                      const model = {
                        ...config.trainingParams!.hyperParams!.model,
                        [param.name]: selectedValue,
                        modelSize: newSchema.modelSize,
                      };
                      // Switching schema also changes the dataset limits for this row.
                      return {
                        ...withHyperParamsModel(config, model),
                        currentSchema: newSchema,
                        limits: transferDatasetLimitsFormat(newSchema.datasetLimits),
                      };
                    });
                  }}
                >
                  {param.options?.map(modelSizeOption => (
                    <MenuItem key={modelSizeOption} value={modelSizeOption}>
                      {modelSizeOption}
                    </MenuItem>
                  ))}
                </TextField>
              )}
            </Box>
            {!!errorMessage && <FormHelperText>{errorMessage}</FormHelperText>}
          </FormControl>
        );
      })}
    </TableCell>
  );
};

export default HyperParameters;
