import { useRequiredContext } from "@redotech/react-util/context";
import { env, FeatureExtractionPipeline, pipeline } from "@xenova/transformers";
import {
  createContext,
  useCallback,
  useEffect,
  useMemo,
  useReducer,
  useRef,
} from "react";

/** Model id handed to `pipeline("feature-extraction", …)` when loading. */
const MODEL_ID = "Xenova/all-mpnet-base-v2";
/**
 * Milliseconds `getEmbedding` will wait for the model to become ready; also
 * the quiet period after which the load-retry counter starts over.
 */
const MAX_WAIT_TIME = 90 * 1000;
/** Milliseconds between readiness polls and between load retries. */
const CHECK_INTERVAL = 3 * 1000;

/** Lifecycle phase of the embedding model. */
type ModelStatus = "uninitialized" | "loading" | "ready" | "error";

/** Reducer state, also exposed to consumers through the context. */
type SearchModelState = {
  /** Loaded feature-extraction pipeline, or `null` when none is available. */
  model: FeatureExtractionPipeline | null;
  /** Current lifecycle phase. */
  status: ModelStatus;
  /** Most recent load failure, or `null` when there is none. */
  error: Error | null;
};

/** Begin (or retry) loading the model. */
type StartLoadingAction = { type: "START_LOADING" };

/** The pipeline was created and its warm-up call succeeded. */
type ModelLoadedAction = {
  type: "MODEL_LOADED";
  payload: FeatureExtractionPipeline;
};

/** Loading failed with the carried error. */
type ErrorAction = { type: "ERROR"; payload: Error };

/** Return to the initial, empty state. */
type ResetAction = { type: "RESET" };

/** Every action accepted by `searchModelReducer`, discriminated on `type`. */
type SearchModelAction =
  | StartLoadingAction
  | ModelLoadedAction
  | ErrorAction
  | ResetAction;

/** Value published by `SearchModelProvider`. */
type SearchModelContextValue = {
  /** Current model lifecycle state. */
  state: SearchModelState;
  /** Computes an embedding vector for `text` using the loaded model. */
  getEmbedding: (text: string) => Promise<number[]>;
};

/** Context carrying the search model; its default is `undefined` (no provider). */
const SearchModelContext = createContext<SearchModelContextValue | undefined>(
  undefined,
);

/**
 * Pure reducer for the model lifecycle.
 *
 * - `MODEL_LOADED`: store the pipeline, status `"ready"`, clear the error.
 * - `RESET`: back to `{ model: null, status: "uninitialized", error: null }`.
 * - `START_LOADING`: keep the current model, status `"loading"`, clear the error.
 * - `ERROR`: keep the current model, status `"error"`, record the error.
 *
 * Any unrecognized action yields the very same `state` object.
 */
function searchModelReducer(
  state: SearchModelState,
  action: SearchModelAction,
): SearchModelState {
  if (action.type === "MODEL_LOADED") {
    return { model: action.payload, status: "ready", error: null };
  }
  if (action.type === "RESET") {
    return { model: null, status: "uninitialized", error: null };
  }
  if (action.type === "START_LOADING") {
    return { ...state, status: "loading", error: null };
  }
  if (action.type === "ERROR") {
    return { ...state, status: "error", error: action.payload };
  }
  return state;
}

/** Maximum number of automatic load retries within one retry window. */
const MAX_LOAD_RETRIES = 3;

/** Normalizes an arbitrary thrown value into an `Error` instance. */
function toError(value: unknown): Error {
  return value instanceof Error ? value : new Error(String(value));
}

/**
 * Deletes every Cache Storage entry for this origin, when the Cache API exists.
 *
 * Failures are logged and swallowed so that a cache problem cannot stop the
 * caller from scheduling a retry.
 *
 * NOTE(review): this removes *all* caches for the origin, not only the ones
 * holding model files — presumably meant to purge a corrupted model download.
 * Confirm no other cache (e.g. a service worker's) must survive this.
 */
async function clearBrowserCaches(): Promise<void> {
  if (typeof window === "undefined" || !window.caches) return;
  try {
    const keys = await caches.keys();
    await Promise.all(keys.map((key) => caches.delete(key)));
  } catch (cacheError) {
    console.warn("Failed to clear browser caches:", cacheError);
  }
}

/**
 * Loads the sentence-embedding model once on mount and exposes its state and a
 * `getEmbedding` helper through `SearchModelContext`.
 *
 * Loading creates a quantized feature-extraction pipeline for `MODEL_ID`. It
 * then runs a warm-up call and marks the model ready. On failure it records
 * the error and clears browser caches. It then retries up to
 * `MAX_LOAD_RETRIES` times, `CHECK_INTERVAL` ms apart. The retry counter
 * resets once more than `MAX_WAIT_TIME` ms have passed since the last retry.
 * Unmounting aborts in-flight work, cancels any pending retry, and resets
 * the state.
 */
export const SearchModelProvider: React.FC<{ children: React.ReactNode }> = ({
  children,
}) => {
  const [state, dispatch] = useReducer(searchModelReducer, {
    model: null,
    status: "uninitialized",
    error: null,
  });

  // Mirror of the latest reducer state. `getEmbedding` polls this ref rather
  // than a `state` captured in its closure: a captured value never changes
  // while the loop is waiting, so it could never observe "ready".
  const stateRef = useRef(state);
  useEffect(() => {
    stateRef.current = state;
  }, [state]);

  // Retry bookkeeping; kept in a ref so it survives effect re-runs.
  const initializationRef = useRef({ attempt: 0, lastAttempt: 0 });

  useEffect(() => {
    const abortController = new AbortController();
    let retryTimer: ReturnType<typeof setTimeout> | undefined;

    /** Schedules another `loadModel` call if the retry budget allows it. */
    function scheduleRetry(): void {
      const tracker = initializationRef.current;
      // A long quiet period since the last retry starts a fresh budget.
      if (Date.now() - tracker.lastAttempt > MAX_WAIT_TIME) {
        tracker.attempt = 0;
      }
      if (tracker.attempt < MAX_LOAD_RETRIES) {
        tracker.attempt++;
        tracker.lastAttempt = Date.now();
        retryTimer = setTimeout(() => {
          void loadModel();
        }, CHECK_INTERVAL);
      }
    }

    /** Creates and warms up the pipeline, dispatching lifecycle actions. */
    async function loadModel(): Promise<void> {
      // Never start (or restart) a download after this effect was cleaned up.
      if (abortController.signal.aborted) return;

      dispatch({ type: "START_LOADING" });
      try {
        env.allowLocalModels = false;
        const extractor = await pipeline("feature-extraction", MODEL_ID, {
          quantized: true,
        });
        if (abortController.signal.aborted) return;

        // Warm-up call so that "ready" means the pipeline actually runs.
        await extractor("test string", { pooling: "mean", normalize: true });
        if (abortController.signal.aborted) return;

        dispatch({ type: "MODEL_LOADED", payload: extractor });
        initializationRef.current.attempt = 0;
      } catch (error) {
        if (abortController.signal.aborted) return;

        console.warn("Model initialization error:", error);
        dispatch({ type: "ERROR", payload: toError(error) });
        await clearBrowserCaches();
        if (abortController.signal.aborted) return;
        scheduleRetry();
      }
    }

    void loadModel();

    return () => {
      abortController.abort();
      if (retryTimer !== undefined) clearTimeout(retryTimer);
      dispatch({ type: "RESET" });
    };
  }, []);

  /**
   * Returns the mean-pooled, normalized embedding of `text`.
   *
   * Polls every `CHECK_INTERVAL` ms until the model is ready.
   *
   * @throws the recorded load error (or a generic one) if the model is in the
   *   `"error"` state when polled.
   * @throws if the model is not ready within `MAX_WAIT_TIME` ms.
   * @throws any error raised by the pipeline while embedding `text`.
   */
  const getEmbedding = useCallback(async (text: string): Promise<number[]> => {
    const startTime = Date.now();

    while (stateRef.current.status !== "ready") {
      if (Date.now() - startTime >= MAX_WAIT_TIME) {
        throw new Error("Model failed to load within the expected time");
      }
      if (stateRef.current.status === "error") {
        throw stateRef.current.error ?? new Error("Model initialization failed");
      }
      await new Promise((resolve) => setTimeout(resolve, CHECK_INTERVAL));
    }

    const { model } = stateRef.current;
    if (!model) {
      throw new Error("Model unexpectedly null despite ready status");
    }

    try {
      const output = await model(text, {
        pooling: "mean",
        normalize: true,
      });

      return Array.from(output.data);
    } catch (error) {
      console.error("Embedding generation error:", error);
      throw error;
    }
  }, []);

  // Stable value identity so consumers only re-render when `state` changes.
  const contextValue = useMemo(
    () => ({ state, getEmbedding }),
    [state, getEmbedding],
  );

  return (
    <SearchModelContext.Provider value={contextValue}>
      {children}
    </SearchModelContext.Provider>
  );
};

/**
 * Reads `{ state, getEmbedding }` from the nearest `SearchModelProvider`.
 *
 * Delegates to `useRequiredContext`, which presumably throws when no provider
 * is mounted — confirm against `@redotech/react-util/context`.
 */
export function useSearchModel() {
  return useRequiredContext(SearchModelContext);
}
