import {useMemo, useState} from 'react';
import {useCallbackRef} from '@unthinkable/react-utils';

/**
 * React hook that manages row selection for flat or tree-shaped data.
 *
 * Selecting a row also selects all of its descendants (rows nested under
 * `recursiveKey`); deselecting a row clears them. Unless `singleSelection` is
 * set, ancestors are kept in sync: a parent is added to the selection once all
 * of its children are selected and removed as soon as one of them is not.
 *
 * @param {Object} [options]
 * @param {string} [options.rowKey='_id'] - Row property used as the selection key.
 * @param {Array} [options.data=[]] - Rows to select from; children are read from `recursiveKey`.
 * @param {Function} [options.onSelectionChange] - Called with the new array of selected ids on every commit.
 * @param {Function} [options.onSelect] - Called as `(row, selected, e)` when a row is toggled, where `e` is
 *   the third argument passed to `toggleSelection`.
 * @param {boolean} [options.singleSelection] - When truthy, selecting a row replaces the selection with that row's key.
 * @param {Array} [options.defaultSelectedIds=[]] - Initial selection (used as the `useState` initial value).
 * @param {string} [options.recursiveKey='children'] - Row property holding child rows.
 * @param {boolean} [options.selectionOnIndex] - Use the row's index among its siblings as its key instead of
 *   `row[rowKey]`. Ancestor syncing is not performed in this mode.
 * @returns {{selectedIds: Array, selectedData: Array, toggleSelection: Function, toggleAllSelection: Function,
 *   resetSelection: Function, isSelected: Function, isAllSelected: Function, isIndeterminate: Function,
 *   isRowIndeterminate: Function}}
 */
export const useSelection = ({
  rowKey = '_id',
  data = [],
  onSelectionChange,
  onSelect,
  singleSelection,
  defaultSelectedIds = [],
  recursiveKey = 'children',
  selectionOnIndex,
} = {}) => {
  const [selectedIds, setSelectedIds] = useState(defaultSelectedIds);

  // Set view of the current selection so membership checks are O(1).
  const selectedIdSet = useMemo(() => new Set(selectedIds), [selectedIds]);

  // Stores the new selection and notifies the consumer.
  const commitSelection = useCallbackRef(nextSelectedIds => {
    setSelectedIds(nextSelectedIds);
    onSelectionChange?.(nextSelectedIds);
  });

  const resetSelection = useCallbackRef(() => {
    commitSelection([]);
  });

  // Keys such as 0 or '' are valid; only null/undefined mean "no key".
  const hasKey = keyValue => keyValue !== undefined && keyValue !== null;

  // Returns the key a row is selected by: its sibling index when
  // `selectionOnIndex` is set, otherwise `row[rowKey]`.
  const getSelectionKeyValue = (row, index) => {
    if (selectionOnIndex) {
      return index;
    }
    return row?.[rowKey];
  };

  // Collects the keys of `node` and all of its descendants, depth-first with
  // the node itself first. Nodes without a key are skipped.
  const getDescendantKeys = (node, index, keys = []) => {
    const keyValue = getSelectionKeyValue(node, index);
    if (hasKey(keyValue)) {
      keys.push(keyValue);
    }
    const children = node?.[recursiveKey];
    if (children) {
      children.forEach((child, childIndex) =>
        getDescendantKeys(child, childIndex, keys),
      );
    }
    return keys;
  };

  // Keys of every row in `data`, at any depth.
  const getAllKeys = () => getDescendantKeys({[recursiveKey]: data});

  // Searches `nodes` recursively for the node whose direct children include a
  // node with the same key as `node`. Returns null when there is none or when
  // `node` has no key.
  const getParentNode = (node, nodes) => {
    const nodeKeyValue = getSelectionKeyValue(node);
    if (!hasKey(nodeKeyValue)) {
      return null;
    }
    const findParent = candidates => {
      for (const parentNode of candidates ?? []) {
        const children = parentNode?.[recursiveKey];
        if (!children) {
          continue;
        }
        if (
          children.some(child => getSelectionKeyValue(child) === nodeKeyValue)
        ) {
          return parentNode;
        }
        const parent = findParent(children);
        if (parent) {
          return parent;
        }
      }
      return null;
    };
    return findParent(nodes);
  };

  // Walks up from `node`. Each ancestor is added to `selected` when all of its
  // children are selected and removed otherwise. Mutates `selected` in place.
  const updateAncestorSelection = (node, nodes, selected) => {
    const parentNode = getParentNode(node, nodes);
    if (!parentNode) {
      return;
    }
    const allChildrenSelected = parentNode[recursiveKey].every(child =>
      selected.includes(getSelectionKeyValue(child)),
    );
    const parentNodeKeyValue = getSelectionKeyValue(parentNode);
    // An ancestor without a key of its own (e.g. a grouping row) cannot be
    // selected; never store `undefined` as an id for it.
    if (hasKey(parentNodeKeyValue)) {
      const parentIndex = selected.indexOf(parentNodeKeyValue);
      if (allChildrenSelected && parentIndex === -1) {
        selected.push(parentNodeKeyValue);
      } else if (!allChildrenSelected && parentIndex > -1) {
        selected.splice(parentIndex, 1);
      }
    }
    // Continue updating up the tree.
    updateAncestorSelection(parentNode, nodes, selected);
  };

  // Toggles `row` (and its descendants), syncs ancestors, then notifies
  // `onSelect` and commits. Rows without a key are ignored.
  const toggleSelection = useCallbackRef((row, index, e) => {
    const rowId = getSelectionKeyValue(row, index);
    if (!hasKey(rowId)) {
      return;
    }

    const wasSelected = selectedIdSet.has(rowId);
    const descendantKeys = getDescendantKeys(row, index);

    let newSelectedIds;
    if (wasSelected) {
      // Deselect the item and its descendants.
      const keysToRemove = new Set(descendantKeys);
      newSelectedIds = selectedIds.filter(id => !keysToRemove.has(id));
    } else if (singleSelection) {
      newSelectedIds = [rowId];
    } else {
      // Select the item and its descendants, without duplicates.
      newSelectedIds = [...new Set([...selectedIds, ...descendantKeys])];
    }

    // Sibling indexes repeat across tree levels, so parents cannot be
    // resolved when selecting by index.
    if (!singleSelection && !selectionOnIndex) {
      updateAncestorSelection(row, data, newSelectedIds);
    }

    onSelect?.(row, !wasSelected, e);
    commitSelection(newSelectedIds);
  });

  // True when `data` has at least one keyed row and every keyed row is selected.
  const isAllSelected = useCallbackRef(() => {
    const allKeys = getAllKeys();
    return allKeys.length > 0 && allKeys.every(key => selectedIdSet.has(key));
  });

  // True when some, but not all, keyed rows in `data` are selected.
  const isIndeterminate = useCallbackRef(() => {
    const allKeys = getAllKeys();
    const selectedCount = allKeys.filter(key => selectedIdSet.has(key)).length;
    return selectedCount > 0 && selectedCount < allKeys.length;
  });

  // True when some, but not all, of `row` and its descendants are selected.
  const isRowIndeterminate = useCallbackRef((row, index) => {
    const rowId = getSelectionKeyValue(row, index);
    if (!hasKey(rowId)) {
      return false;
    }
    const descendantKeys = getDescendantKeys(row, index);
    const selectedCount = descendantKeys.filter(key =>
      selectedIdSet.has(key),
    ).length;
    return selectedCount > 0 && selectedCount < descendantKeys.length;
  });

  // Clears the selection if everything is selected; otherwise selects every
  // keyed row at any depth.
  const toggleAllSelection = useCallbackRef(() => {
    commitSelection(isAllSelected() ? [] : getAllKeys());
  });

  const isSelected = useCallbackRef((row, index) => {
    const rowId = getSelectionKeyValue(row, index);
    return hasKey(rowId) && selectedIdSet.has(rowId);
  });

  // Top-level rows of `data` whose key is selected (nested rows are not included).
  const selectedData = useMemo(() => {
    if (!selectedIds?.length) {
      return [];
    }
    return data?.filter((row, index) =>
      selectedIdSet.has(getSelectionKeyValue(row, index)),
    );
  }, [data, selectedIds, selectedIdSet, rowKey, selectionOnIndex]);

  return {
    selectedIds,
    selectedData,
    toggleSelection,
    toggleAllSelection,
    resetSelection,
    isSelected,
    isAllSelected,
    isIndeterminate,
    isRowIndeterminate,
  };
};
