import { ModelStatus, RegisteredModel } from '@clef/shared/types';

import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useSnackbar } from 'notistack';
import ProjectModelAPI from '@/api/project_model_api';
import { projectModelQueryKeys } from './queries';
import { useTypedSelector } from '@/hooks/useTypedSelector';
import { modelAnalysisQueryKeys } from '../modelAnalysis';
import { RegisteredModelWithBundles } from '@/api/model_api';
import { layoutQueryKeys } from '../layout';

/**
 * Mutation hook that saves a model under the given name via
 * `ProjectModelAPI.saveModel`.
 *
 * On success, the new name is patched into two cached lists keyed by the
 * currently selected project (the project model list and the model-analysis
 * model list) and a success snackbar is shown. On failure, an error snackbar
 * carrying the error message is shown.
 *
 * @returns The `useMutation` result; mutate with `{ id, modelName }`.
 */
export const useSaveModelMutation = () => {
  const { enqueueSnackbar } = useSnackbar();
  const queryClient = useQueryClient();
  // Falls back to 0 when the store holds no selected project id (null/undefined).
  const selectedProjectId = useTypedSelector(state => state.project.selectedProjectId) ?? 0;
  return useMutation({
    mutationFn: async (params: { id: string; modelName: string }) => {
      await ProjectModelAPI.saveModel(params.id, params.modelName);
      // Pass the input through so onSuccess knows which model to patch.
      return params;
    },
    onSuccess: params => {
      queryClient.setQueryData(
        projectModelQueryKeys.list(selectedProjectId),
        // When nothing is cached yet, return undefined so the cache is left
        // untouched rather than being seeded with an empty model list.
        (previous?: RegisteredModel[]) =>
          previous?.map(model =>
            model.id === params.id ? { ...model, modelName: params.modelName } : model,
          ),
      );
      queryClient.setQueryData(
        modelAnalysisQueryKeys.modelList(selectedProjectId),
        (previous?: RegisteredModelWithBundles[]) =>
          previous?.map(model =>
            params.id === model.id ? { ...model, modelName: params.modelName } : { ...model },
          ),
      );
      // Keep the translation key static and pass the name as an interpolation
      // value, matching the `{{errorMessage}}` convention used in onError.
      enqueueSnackbar(
        t(`"{{modelName}}" is successfully saved.`, { modelName: params.modelName }),
        {
          variant: 'success',
        },
      );
    },
    onError: (e: Error) => {
      enqueueSnackbar(t(`Failed to save model, {{errorMessage}}`, { errorMessage: e.message }), {
        variant: 'error',
        autoHideDuration: 12000,
      });
    },
  });
};

/**
 * Mutation hook that stops training of a model via
 * `ProjectModelAPI.stopTraining`.
 *
 * On success, the matching entry in the cached project model list is marked
 * as no longer training with status `ModelStatus.Terminated`, the layout list
 * query for the project is invalidated, and a success snackbar is shown. On
 * failure, an error snackbar carrying the error message is shown.
 *
 * @returns The `useMutation` result; mutate with `{ projectId, modelId }`.
 */
export const useStopTrainingMutation = () => {
  const { enqueueSnackbar } = useSnackbar();
  const queryClient = useQueryClient();
  return useMutation({
    mutationFn: async (params: { projectId: number; modelId: string }) => {
      await ProjectModelAPI.stopTraining(params.projectId, params.modelId);
      // Pass the input through so onSuccess knows which model to patch.
      return params;
    },
    onSuccess: params => {
      queryClient.setQueryData(
        projectModelQueryKeys.list(params.projectId),
        // When nothing is cached yet, return undefined so the cache is left
        // untouched rather than being seeded with an empty model list.
        (previous?: RegisteredModel[]) =>
          previous?.map(model =>
            model.id === params.modelId
              ? { ...model, isTraining: false, status: ModelStatus.Terminated }
              : model,
          ),
      );
      queryClient.invalidateQueries(layoutQueryKeys.list(params.projectId));
      enqueueSnackbar(t(`Training stopped.`), { variant: 'success' });
    },
    onError: (e: Error) => {
      enqueueSnackbar(t(`Failed to stop training, {{errorMessage}}`, { errorMessage: e.message }), {
        variant: 'error',
        autoHideDuration: 12000,
      });
    },
  });
};
