import {
  ColumnDef,
  OnChangeFn,
  Row,
  RowSelectionState,
  SortingState,
  flexRender,
  getCoreRowModel,
  getFilteredRowModel,
  useReactTable,
} from '@tanstack/react-table';
import { IconArrowUp } from '@unique/icons';
import cn from 'classnames';
import { MouseEvent, useState } from 'react';

/**
 * Props accepted by `TableSortable`.
 *
 * @typeParam T - Row object type; every entry of `data` becomes one table row.
 */
interface TableSortableProps<T extends object> {
  // --- Data ---------------------------------------------------------------

  /** Rows to render. Sorting is manual, so supply them already in display order. */
  data: T[];
  /** TanStack column definitions describing how to render `T`. */
  columns: ColumnDef<T>[];

  // --- Controlled table state --------------------------------------------

  /** Current sort state, fed to the table as controlled state. */
  sorting?: SortingState;
  /** Receives sort updates triggered by clicking sortable headers. */
  setSorting?: OnChangeFn<SortingState>;
  /** Current selection state. When omitted, body rows get no click handler at all. */
  rowSelection?: RowSelectionState | undefined;
  /** Receives selection updates. */
  setRowSelection?: OnChangeFn<RowSelectionState>;
  /** Decides per row whether it may be selected; without it every row is selectable. */
  isRowSelectable?: (row: Row<T>) => boolean;

  // --- Row interaction ------------------------------------------------------

  /** Invoked with the clicked row's original object (not for clicks on checkbox elements). */
  handleClickRowCustom?: (rowObject: T) => void;
  /** When true, the most recently clicked row keeps the "selected" background. */
  highlightClickedRow?: boolean;

  // --- Styling & DOM --------------------------------------------------------

  /** Extra classes for the `<table>` element. */
  className?: string;
  /** Extra classes for the scrollable `<div>` wrapping the table. */
  wrapperClasses?: string;
  /** DOM id for the `<table>` element. */
  id?: string;
  /** Extra classes merged onto every header `<tr>`. */
  headRowClasses?: string;
  /** Extra classes merged onto every header `<th>`. */
  headColClasses?: string;
  /** Extra classes merged onto every body `<tr>`. */
  rowClasses?: string;
  /** Content rendered in a final full-width body row (e.g. pagination controls). */
  paginationAtEnd?: React.ReactNode;
}

/** Base classes for every header `<tr>`; `headRowClasses` is merged after these. */
const DEFAULT_TABLE_ROW_HEAD_CLASSES =
  'vertical text-xs font-semibold uppercase tracking-widest text-on-background-dimmed';

/**
 * Base classes for every header `<th>`; `headColClasses` is merged after these.
 * `sticky top-0` keeps the header row pinned while the nearest scrolling ancestor scrolls
 * vertically. (A duplicated `sticky` token was removed; it had no effect.)
 * NOTE(review): `left-0` is applied to every header cell, not only the first one; if the
 * table scrolls horizontally the cells may stack at the left edge — confirm this is intended.
 */
const DEFAULT_TABLE_HEAD_CLASSES =
  'gap-2 whitespace-nowrap py-3.5 overline-text first-of-type:pl-4 last-of-type:pr-4 sm:first-of-type:pl-6 sm:last-of-type:pr-6 sticky top-0 left-0 z-1 bg-surface table-head-border';

/**
 * Base classes for every body `<tr>`. `focus-within:!z-20` lifts a row above its siblings
 * while one of its descendants has focus.
 */
const DEFAULT_TABLE_ROW_CLASSES =
  'z-0 border-b border-background-variant transition-colors duration-[200ms] last-of-type:border-none focus-within:!z-20';

/** Classes for every body `<td>`; first/last cells get extra horizontal padding. */
const DEFAULT_TABLE_COLUMN_CLASSES =
  'h-full items-center py-3 pr-6 body-2 leading-normal first-of-type:pl-4 last-of-type:pr-4 text-on-control-main sm:first-of-type:pl-6 sm:last-of-type:pr-6';

/**
 * TanStack Table's built-in default column size. A column still reporting this size is
 * treated as "unsized": no inline width is emitted and the browser's table layout decides.
 */
const TANSTACK_DEFAULT_COLUMN_SIZE = 150;

/** Inline width for a header/body cell, or `undefined` to leave an unsized column auto-laid-out. */
const getColumnWidth = (size: number): number | undefined =>
  size === TANSTACK_DEFAULT_COLUMN_SIZE ? undefined : size;

/**
 * Sort-direction icon rendered next to a header label.
 *
 * @param direction - The value of `column.getIsSorted()`.
 * @param canSort - Whether the column is sortable; only sortable columns get the faded
 *   "unsorted" hint.
 */
const renderSortIndicator = (direction: false | 'asc' | 'desc', canSort: boolean) => {
  if (direction === 'asc') {
    return <IconArrowUp />;
  }
  if (direction === 'desc') {
    return (
      <span className="rotate-180">
        <IconArrowUp />
      </span>
    );
  }
  // Unsorted: a faded arrow signals that the column can be sorted; otherwise render nothing.
  return canSort ? (
    <span className="rotate-180 opacity-30">
      <IconArrowUp />
    </span>
  ) : null;
};

/**
 * Sortable, optionally row-selectable table built on TanStack Table.
 *
 * Sorting is manual (`manualSorting: true`): clicking a sortable header only reports the new
 * `SortingState` through `setSorting`; `data` must already be in the desired order.
 *
 * Body rows get a click handler only when `rowSelection` is provided (any object, including
 * `{}`). On a row click:
 * - with `highlightClickedRow`, the clicked row is remembered and gets the selected background;
 * - `handleClickRowCustom` receives `row.original` unless the click target carries a non-empty
 *   `data-is-checkbox` attribute;
 * - the row's selection toggles when the target is such a checkbox element, or when no
 *   `handleClickRowCustom` is given and the target is not marked `data-clickable="true"`.
 */
export const TableSortable = <T extends object>({
  data,
  columns,
  className = '',
  wrapperClasses = '',
  id = 'table',
  sorting,
  rowSelection,
  setSorting,
  setRowSelection,
  headRowClasses = '',
  headColClasses = '',
  rowClasses = '',
  handleClickRowCustom,
  highlightClickedRow = false,
  isRowSelectable,
  paginationAtEnd,
}: TableSortableProps<T>) => {
  const table = useReactTable({
    data,
    columns,
    getCoreRowModel: getCoreRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    state: {
      sorting,
      rowSelection,
    },
    onSortingChange: setSorting,
    onRowSelectionChange: setRowSelection,
    // The caller sorts `data`; the table only tracks and reports the sort state.
    manualSorting: true,
    // Without a predicate every row is selectable.
    enableRowSelection: isRowSelectable ?? true,
  });
  const tableRows = table.getRowModel().rows;
  const headerGroups = table.getHeaderGroups();
  // Number of rendered body cells per row; `columns.length` would undercount grouped
  // (nested) column definitions and overcount hidden columns.
  const leafColumnCount = table.getVisibleLeafColumns().length;
  const [clickedRowId, setClickedRowId] = useState<string>('');

  /**
   * Routes a body-row click to highlighting, the custom row handler and/or selection toggling.
   *
   * @param event - The row's click event; its `target` decides how the click is routed.
   * @param row - The clicked TanStack row.
   * @param rowSelectionHandler - The row's `getToggleSelectedHandler()` result.
   */
  const handleClickRow = (
    event: MouseEvent<HTMLTableRowElement>,
    row: Row<T>,
    rowSelectionHandler: (event: unknown) => void,
  ) => {
    if (highlightClickedRow) {
      setClickedRowId(row.id);
    }
    // NOTE(review): only the innermost clicked element is inspected, so a click on a child of
    // a `data-is-checkbox` / `data-clickable` element (e.g. an icon inside it) is not recognized.
    const target = event.target as HTMLElement;
    const isCheckbox = Boolean(target.dataset.isCheckbox);
    const isColumnClickable = target.dataset.clickable === 'true';

    if (handleClickRowCustom && !isCheckbox) {
      handleClickRowCustom(row.original);
    }

    // Toggle selection for explicit checkbox clicks, or for plain row clicks when no custom
    // handler owns the click and the target is not marked as an interactive ("clickable") element.
    if (isCheckbox || (!handleClickRowCustom && !isColumnClickable)) {
      rowSelectionHandler(event);
    }
  };

  return (
    <div className={`h-full w-full table-auto overflow-auto text-left ${wrapperClasses}`}>
      <table className={`w-full ${className}`} id={id}>
        <thead>
          {headerGroups.map((headerGroup) => (
            <tr key={headerGroup.id} className={cn(DEFAULT_TABLE_ROW_HEAD_CLASSES, headRowClasses)}>
              {headerGroup.headers.map((header) => {
                const canSort = header.column.getCanSort();
                return (
                  <th
                    key={header.id}
                    colSpan={header.colSpan}
                    className={cn(DEFAULT_TABLE_HEAD_CLASSES, headColClasses)}
                    style={{ width: getColumnWidth(header.column.getSize()) }}
                  >
                    {header.isPlaceholder ? null : (
                      <div
                        className={canSort ? 'cursor-pointer select-none flex items-center gap-x-2' : ''}
                        onClick={header.column.getToggleSortingHandler()}
                      >
                        {flexRender(header.column.columnDef.header, header.getContext())}
                        {renderSortIndicator(header.column.getIsSorted(), canSort)}
                      </div>
                    )}
                  </th>
                );
              })}
            </tr>
          ))}
        </thead>
        <tbody>
          {tableRows.map((row) => (
            <tr
              key={row.id}
              className={cn(
                DEFAULT_TABLE_ROW_CLASSES,
                {
                  'bg-background-variant':
                    row.getIsSelected() || (highlightClickedRow && row.id === clickedRowId),
                  'hover:bg-background-variant cursor-pointer': !!rowSelection,
                },
                rowClasses,
              )}
              onClick={
                rowSelection
                  ? (event) => handleClickRow(event, row, row.getToggleSelectedHandler())
                  : undefined
              }
            >
              {row.getVisibleCells().map((cell) => (
                <td
                  key={cell.id}
                  className={DEFAULT_TABLE_COLUMN_CLASSES}
                  style={{ width: getColumnWidth(cell.column.getSize()) }}
                >
                  {flexRender(cell.column.columnDef.cell, cell.getContext())}
                </td>
              ))}
            </tr>
          ))}

          {paginationAtEnd ? (
            <tr>
              <td colSpan={leafColumnCount} className="p-0">
                {paginationAtEnd}
              </td>
            </tr>
          ) : null}
        </tbody>
      </table>
    </div>
  );
};

export default TableSortable;
