import { createAction, createSlice } from "@reduxjs/toolkit";
import type {
  CreateSliceOptions,
  Middleware,
  Slice,
  SliceCaseReducers
} from "@reduxjs/toolkit";
import get from "lodash/get";
import merge from "lodash/merge";
import mergeWith from "lodash/mergeWith";
import set from "lodash/set";
import { isEqual } from "../../util/isEqual";

/**
 * Creates a Redux Toolkit slice whose state is persisted to `localStorage`
 * under the key `persist/<name>` and kept in sync with other tabs through
 * `storage` events.
 *
 * - Storage is read once, when this function is called. The slice's initial
 *   state is the configured `initialState` deep-merged with that stored
 *   snapshot. Arrays from storage replace the default arrays instead of being
 *   merged element by element.
 * - The returned middleware writes the slice state back to storage after any
 *   dispatch that changed it. When `whiteList` is given, only those paths are
 *   written and compared.
 * - Storage problems (storage unavailable, quota exceeded, unreadable JSON)
 *   are logged with `console.warn` and otherwise ignored. A storage failure
 *   therefore does not make slice creation or a dispatch throw.
 *
 * @param options - The usual `createSlice` options, plus an optional
 *   `whiteList` of lodash-style property paths (e.g. `"user.theme"`) to
 *   persist. Without a `whiteList`, the whole slice state is persisted.
 * @returns A tuple of the slice and its persistence middleware.
 *   - The middleware reads the slice state from `rootState[name]`, so the
 *     slice reducer must be mounted under the slice's `name`.
 *   - Each store the middleware is applied to registers one `storage`
 *     listener on `window`. That listener is never removed.
 */
function createPersistedSlice<
  State,
  CaseReducers extends SliceCaseReducers<State>,
  Name extends string = string
>(
  options: CreateSliceOptions<State, CaseReducers, Name> & {
    whiteList?: string[];
  }
): [Slice<State, CaseReducers, Name>, Middleware] {
  const { whiteList, ...sliceOptions } = options;
  const storageKey = `persist/${sliceOptions.name}`;
  const rehydrateAction = createAction<Partial<State>>(
    `persist/${sliceOptions.name}/rehydrate`
  );

  // lodash `merge` combines arrays index by index. Merging a shorter array
  // into a longer one therefore keeps the longer one's trailing items; for
  // example, an item removed in another tab would survive. Returning a deep
  // copy of the incoming array makes arrays replace instead. Returning
  // `undefined` defers to the default `merge` behaviour.
  const replaceArrays = (_target: unknown, source: unknown): unknown =>
    Array.isArray(source) ? merge([], source) : undefined;

  // Parses a stored snapshot. Returns `undefined` for unreadable JSON or for
  // non-object values; callers treat that as "nothing usable stored".
  const parseSnapshot = (raw: string): Partial<State> | undefined => {
    let parsed: unknown;
    try {
      parsed = JSON.parse(raw);
    } catch (error: unknown) {
      console.warn(`Ignoring unreadable data in "${storageKey}"`, error);
      return undefined;
    }
    return typeof parsed === "object" && parsed !== null
      ? (parsed as Partial<State>)
      : undefined;
  };

  // Reads the snapshot saved by a previous session. Touching localStorage
  // can throw (e.g. when it is unavailable), and that must not break slice
  // creation.
  const readStoredState = (): Partial<State> => {
    try {
      const raw = localStorage.getItem(storageKey);
      return (raw == null ? undefined : parseSnapshot(raw)) ?? {};
    } catch (error: unknown) {
      console.warn(`Could not read "${storageKey}" from localStorage`, error);
      return {};
    }
  };

  // Extracts the part of the slice state that is written to storage.
  const toSnapshot = (state: State): Partial<State> => {
    if (whiteList == null) {
      return state;
    }
    const snapshot: Partial<State> = {};
    for (const key of whiteList) {
      set(snapshot, key, get(state, key));
    }
    return snapshot;
  };

  // Writes a snapshot. Failures (e.g. QuotaExceededError) are only logged:
  // the reducer has already updated the in-memory state, so the dispatch
  // should still succeed.
  const writeSnapshot = (snapshot: Partial<State>): void => {
    try {
      localStorage.setItem(storageKey, JSON.stringify(snapshot));
    } catch (error: unknown) {
      console.warn(`Could not persist "${storageKey}" to localStorage`, error);
    }
  };

  // Looks the slice up as a top-level key of the root state.
  const selectSliceState = (rootState: unknown): State =>
    get(rootState, [sliceOptions.name]) as State;

  const persistedState = readStoredState();

  const persistedSliceOptions: typeof sliceOptions = {
    ...sliceOptions,
    initialState: () => {
      const defaults =
        sliceOptions.initialState instanceof Function
          ? sliceOptions.initialState()
          : sliceOptions.initialState;
      return mergeWith({}, defaults, persistedState, replaceArrays) as State;
    },
    extraReducers: (builder) => {
      builder.addCase(rehydrateAction, (state, action) => {
        mergeWith(state, action.payload, replaceArrays);
      });

      if (typeof sliceOptions.extraReducers === "function") {
        sliceOptions.extraReducers(builder);
      }
    }
  };

  const persistedSlice = createSlice(persistedSliceOptions);

  const persistMiddleware: Middleware = (store) => {
    // When another tab writes to our key, pull the relevant changes into
    // this store.
    window.addEventListener("storage", (event) => {
      // `storage` events also fire for sessionStorage, so filter on
      // `storageArea`. A null `newValue` means the key was removed; in that
      // case the in-memory state is left as is.
      if (
        event.key !== storageKey ||
        event.storageArea !== localStorage ||
        event.newValue == null
      ) {
        return;
      }
      const incoming = parseSnapshot(event.newValue);
      if (incoming === undefined) {
        return;
      }

      let changes: Partial<State> = incoming;
      if (whiteList != null) {
        const currentState = selectSliceState(store.getState());
        let hasChanges = false;
        changes = {};
        for (const key of whiteList) {
          const incomingValue = get(incoming, key);
          if (!isEqual(get(currentState, key), incomingValue)) {
            set(changes, key, incomingValue);
            hasChanges = true;
          }
        }
        if (!hasChanges) {
          return;
        }
      }

      store.dispatch(rehydrateAction(changes));
    });

    return (next) => (action) => {
      const previousState = selectSliceState(store.getState());
      // Middleware must hand back the downstream result so that
      // `dispatch(...)` still returns what the rest of the chain returned.
      const result = next(action);

      // Rehydrated state came from storage already. Writing it back would
      // only echo another `storage` event to the other tabs.
      if (rehydrateAction.match(action)) {
        return result;
      }

      const currentState = selectSliceState(store.getState());
      const hasChanged =
        whiteList == null
          ? !isEqual(currentState, previousState)
          : whiteList.some(
              (key) => !isEqual(get(currentState, key), get(previousState, key))
            );

      if (hasChanged) {
        writeSnapshot(toSnapshot(currentState));
      }
      return result;
    };
  };

  return [persistedSlice, persistMiddleware];
}

export default createPersistedSlice;
