import { useEffect, useMemo } from 'react';
import { useQuery, useQueryClient } from '@tanstack/react-query';
import { OrgId, ProjectId, RegisteredModelId, TrainHealthData, UserId } from '@clef/shared/types';
import { ApiErrorType } from '@/api/base_api';
import { useTypedSelector } from '@/hooks/useTypedSelector';

import PictorAPI, {
  GetSnowflakeDatabaseResponse,
  GetSnowflakeSchemaResponse,
  GetSnowflakeStageResponse,
  SnowflakeFolderDirectoryResponse,
  SnowflakeSyncTaskMonitoringResponse,
  SnowflakeSyncTaskStatus,
} from '@/api/pictor_api';
import { useSnackbar } from 'notistack';
import experiment_report_api from '@/api/experiment_report_api';
import { projectQueryKeys, useGetSelectedProjectQuery } from '../projects';
import { datasetQueryKeys } from '../dataset';
import { throttle } from 'lodash';

/**
 * React Query key factories for the Snowflake queries in this module.
 *
 * Every key starts with the user id and org id, then the shared `'snowflake'`
 * namespace, then a resource-specific segment followed by that resource's
 * parameters. React Query matches keys by prefix when invalidating, so each
 * resource must use its own distinct segment.
 */
export const snowflakeQueryKeys = {
  all: ['snowflake'] as const,
  databaseList: (userId: UserId, orgId: OrgId) =>
    [userId, orgId, ...snowflakeQueryKeys.all, 'databaseList'] as const,
  schemaList: (userId: UserId, orgId: OrgId, database: string | undefined) =>
    [userId, orgId, ...snowflakeQueryKeys.all, 'schemaList', database] as const,
  // Own 'stageList' segment: reusing 'schemaList' made every stage key a prefix match of
  // schemaList(userId, orgId, database), so invalidating schemas also invalidated stages.
  stageList: (
    userId: UserId,
    orgId: OrgId,
    database: string | undefined,
    schema: string | undefined,
  ) => [userId, orgId, ...snowflakeQueryKeys.all, 'stageList', database, schema] as const,
  folderPathList: (
    userId: UserId,
    orgId: OrgId,
    database: string | undefined,
    schema: string | undefined,
    stage: string | undefined,
  ) =>
    [userId, orgId, ...snowflakeQueryKeys.all, 'folderPathList', database, schema, stage] as const,
  syncTaskList: (userId: UserId, orgId: OrgId, projectId: ProjectId | undefined) =>
    [userId, orgId, ...snowflakeQueryKeys.all, 'syncTaskList', projectId] as const,
  trainHealth: (userId: UserId, orgId: OrgId, modelId?: RegisteredModelId) =>
    [userId, orgId, ...snowflakeQueryKeys.all, 'trainHealth', modelId] as const,
};

/**
 * Lists the Snowflake databases returned by `PictorAPI.getSnowflakeDatabase()`.
 *
 * The query is enabled only once the login state provides both an org id and a
 * user id. A nullish response is normalized to an empty list. Any failure shows
 * a generic error snackbar that asks the user to contact support.
 */
export const useGetSnowflakeDatabaseQuery = () => {
  const orgId = useTypedSelector(state => state.login.user?.orgId)!;
  const userId = useTypedSelector(state => state.login.user?.id)!;
  const { enqueueSnackbar } = useSnackbar();
  return useQuery<GetSnowflakeDatabaseResponse, ApiErrorType>({
    // The readonly key tuple is passed directly, like the sibling hooks do; the old
    // `as unknown as [unknown]` double cast only erased the key's real type.
    queryKey: snowflakeQueryKeys.databaseList(userId, orgId),
    queryFn: async () => {
      const res = await PictorAPI.getSnowflakeDatabase();
      return res ?? [];
    },
    onError: async () => {
      enqueueSnackbar(
        t(`Failed to list snowflake database. Please contact LandingAI support team.`),
        { variant: 'error', preventDuplicate: true, autoHideDuration: 12000 },
      );
    },
    enabled: !!orgId && !!userId,
  });
};

/**
 * Lists the schemas of a Snowflake database.
 *
 * The query runs only when a database is selected and the login state provides
 * both an org id and a user id. A nullish API response becomes an empty list.
 * On failure, the API error message is shown in an error snackbar.
 *
 * @param database - Database whose schemas are listed; `undefined` while none is selected.
 */
export const useGetSnowflakeSchemaQuery = (database: string | undefined) => {
  const orgId = useTypedSelector(state => state.login.user?.orgId)!;
  const userId = useTypedSelector(state => state.login.user?.id)!;
  const { enqueueSnackbar } = useSnackbar();
  const isReady = !!orgId && !!userId && !!database;

  return useQuery<GetSnowflakeSchemaResponse, ApiErrorType>({
    queryKey: snowflakeQueryKeys.schemaList(userId, orgId, database),
    enabled: isReady,
    queryFn: async () => {
      // `enabled` normally prevents this; the guard also narrows `database` to string.
      if (database === undefined) return [];
      const schemas = await PictorAPI.getSnowflakeSchema({ database });
      return schemas ?? [];
    },
    onError: async error => {
      const errorMessage = error.message;
      enqueueSnackbar(t(`Failed to load snowflake schema list, {{errorMessage}}`, { errorMessage }), {
        variant: 'error',
        preventDuplicate: true,
        autoHideDuration: 12000,
      });
    },
  });
};

/**
 * Lists the stages defined in a Snowflake `database`.`schema` pair.
 *
 * The query runs only when both a database and a schema are selected and the
 * login state provides an org id and a user id. A nullish API response becomes
 * an empty list. On failure, the API error message is shown in an error snackbar.
 *
 * @param database - Selected database, or `undefined` when none is selected.
 * @param schema - Selected schema in that database, or `undefined` when none is selected.
 */
export const useGetSnowflakeStageQuery = (
  database: string | undefined,
  schema: string | undefined,
) => {
  const orgId = useTypedSelector(state => state.login.user?.orgId)!;
  const userId = useTypedSelector(state => state.login.user?.id)!;
  const { enqueueSnackbar } = useSnackbar();
  const isReady = !!orgId && !!userId && !!database && !!schema;

  return useQuery<GetSnowflakeStageResponse, ApiErrorType>({
    queryKey: snowflakeQueryKeys.stageList(userId, orgId, database, schema),
    enabled: isReady,
    queryFn: async () => {
      // `enabled` normally prevents this; the guard also narrows both names to string.
      if (database === undefined || schema === undefined) return [];
      const stages = await PictorAPI.getSnowflakeStage({ database, schema });
      return stages ?? [];
    },
    onError: async error => {
      const errorMessage = error.message;
      enqueueSnackbar(t(`Failed to load snowflake stage list, {{errorMessage}}`, { errorMessage }), {
        variant: 'error',
        preventDuplicate: true,
        autoHideDuration: 12000,
      });
    },
  });
};

/** Returns true when `value` is missing, empty, or contains only whitespace. */
const isBlankString = (value: string | undefined): boolean => !value || value.trim() === '';

/**
 * Fetches the folder directory of a Snowflake stage.
 *
 * The stage is addressed by the qualified name `<database>.<schema>.<stage>`.
 * `enabled` only checks that each part is truthy. A whitespace-only part still
 * enables the query, but the query then resolves to `null` without calling the
 * API. A nullish API response becomes `{}`. On failure, the API error message
 * is shown in an error snackbar.
 *
 * @param database - Selected database name.
 * @param schema - Selected schema name.
 * @param stage - Selected stage name.
 */
export const useGetSnowflakeFolderPathListQuery = (
  database: string | undefined,
  schema: string | undefined,
  stage: string | undefined,
) => {
  const orgId = useTypedSelector(state => state.login.user?.orgId)!;
  const userId = useTypedSelector(state => state.login.user?.id)!;
  const { enqueueSnackbar } = useSnackbar();
  const isReady = !!orgId && !!userId && !!database && !!schema && !!stage;

  return useQuery<SnowflakeFolderDirectoryResponse | null, ApiErrorType>({
    queryKey: snowflakeQueryKeys.folderPathList(userId, orgId, database, schema, stage),
    enabled: isReady,
    queryFn: async () => {
      const hasBlankPart = [database, schema, stage].some(part => isBlankString(part));
      if (hasBlankPart) return null;
      const qualifiedStage = `${database}.${schema}.${stage}`;
      const directory = await PictorAPI.getSnowflakeFolderPathList({ stage: qualifiedStage });
      return directory ?? {};
    },
    onError: async error => {
      const errorMessage = error.message;
      enqueueSnackbar(t(`Failed to load folder list, {{errorMessage}}`, { errorMessage }), {
        variant: 'error',
        preventDuplicate: true,
        autoHideDuration: 12000,
      });
    },
  });
};

/**
 * Fetches the Snowflake sync tasks of the currently selected project.
 *
 * While at least one task has status RUNNING, the query is refetched every 10 s.
 * Polling stops after an error or once no task is running. Each successful fetch
 * triggers a throttled, trailing-edge (3 s) invalidation of the selected project's
 * media, media-count, defect, and filter-option queries. On failure, the API error
 * message is shown in an error snackbar.
 */
export const useGetSnowflakeSyncTaskList = () => {
  const orgId = useTypedSelector(state => state.login.user?.orgId)!;
  const userId = useTypedSelector(state => state.login.user?.id)!;
  const { data: selectedProject } = useGetSelectedProjectQuery();
  const { enqueueSnackbar } = useSnackbar();
  const queryClient = useQueryClient();
  const selectedProjectId = selectedProject?.id;

  // Memoized so the throttle's timer survives re-renders. Previously a new throttled
  // function was built on every render, so successive successes never shared a
  // throttle window and stale pending calls could not be cancelled.
  const refreshMediasThrottled = useMemo(
    () =>
      throttle(
        () => {
          if (!selectedProjectId) {
            return;
          }
          void queryClient.invalidateQueries(datasetQueryKeys.medias(selectedProjectId), undefined, {
            cancelRefetch: false,
          });
          void queryClient.invalidateQueries(datasetQueryKeys.mediaCount(selectedProjectId));
          void queryClient.invalidateQueries(projectQueryKeys.defects(selectedProjectId));
          void queryClient.invalidateQueries(datasetQueryKeys.filterOptions(selectedProjectId));
        },
        3000,
        { leading: false, trailing: true },
      ),
    [queryClient, selectedProjectId],
  );

  // Drop any pending trailing refresh when the project changes or the component unmounts.
  useEffect(() => () => refreshMediasThrottled.cancel(), [refreshMediasThrottled]);

  return useQuery<SnowflakeSyncTaskMonitoringResponse, ApiErrorType>({
    queryKey: snowflakeQueryKeys.syncTaskList(userId, orgId, selectedProjectId),
    queryFn: async () => {
      const res = await PictorAPI.getSnowflakeSyncTaskList(selectedProjectId);
      return res ?? [];
    },
    onSuccess: async () => {
      refreshMediasThrottled();
    },
    onError: async e => {
      enqueueSnackbar(
        t(`Failed to load snowflake sync task list, {{errorMessage}}`, { errorMessage: e.message }),
        { variant: 'error', preventDuplicate: true, autoHideDuration: 12000 },
      );
    },
    enabled: !!orgId && !!userId,
    refetchInterval: (data, query) => {
      // stop polling on any error
      if (query.state.error) {
        return false;
      }
      const hasSyncInProgress = data?.some(task => task.status === SnowflakeSyncTaskStatus.RUNNING);
      if (hasSyncInProgress) {
        return 10000;
      }
      return false;
    },
  });
};

/**
 * Loads training-health data through `experiment_report_api.getTrainHealth()`.
 *
 * The query is enabled only when `modelId` is set and the login state provides
 * an org id and a user id. A nullish response becomes `[]`. On failure, the API
 * error message is shown in an error snackbar.
 *
 * NOTE(review): `modelId` only scopes the cache key and the `enabled` flag. It is
 * not passed to `getTrainHealth()`, so every model hits the same endpoint with no
 * arguments. Confirm whether the API is meant to receive it.
 *
 * @param modelId - Registered model the health data is cached under.
 */
export const useGetTrainHealthQuery = (modelId?: RegisteredModelId) => {
  const orgId = useTypedSelector(state => state.login.user?.orgId)!;
  const userId = useTypedSelector(state => state.login.user?.id)!;
  const { enqueueSnackbar } = useSnackbar();
  const isReady = !!orgId && !!userId && !!modelId;

  return useQuery<TrainHealthData, ApiErrorType>({
    queryKey: snowflakeQueryKeys.trainHealth(userId, orgId, modelId),
    enabled: isReady,
    queryFn: async () => {
      const health = await experiment_report_api.getTrainHealth();
      return health ?? [];
    },
    onError: async error => {
      const errorMessage = error.message;
      enqueueSnackbar(t(`Failed to get training health, {{errorMessage}}`, { errorMessage }), {
        variant: 'error',
        preventDuplicate: true,
        autoHideDuration: 12000,
      });
    },
  });
};
