import {
  Table as MuiTable,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
} from '@material-ui/core';
import { createStyles, makeStyles } from '@material-ui/core/styles';
import clsx from 'clsx';
import { ChevronDownIcon } from 'icons/ChevronDownIcon';
import { ChevronUpDownIcon } from 'icons/ChevronUpDownIcon';
import { ChevronUpIcon } from 'icons/ChevronUpIcon';
import get from 'lodash.get';
import { ReactElement, useCallback, useMemo, useState } from 'react';
import { Pagination } from './Pagination';
import { ESortDirection, ITableColumn, ITableProps, ITableRow } from './types';

/**
 * JSS classes for the table: clickable header cells and body rows, the flex
 * wrapper that places the sort icon next to the header label (with its
 * center/right alignment modifiers), and the spacing above the pagination.
 */
const useStyles = makeStyles((theme) =>
  createStyles({
    table: { border: 0 },
    headerCell: { cursor: 'pointer' },
    tableRow: {
      cursor: 'pointer',
      // Highlight the row under the pointer to signal it is clickable.
      '&:hover': { backgroundColor: theme.palette.grey2.main },
    },
    sortIconWrapper: {
      display: 'flex',
      alignItems: 'center',
      gap: '5px',
    },
    sortIconWrapperCenterAligned: { justifyContent: 'center' },
    sortIconWrapperRightAligned: { justifyContent: 'flex-end' },
    paginationContainer: {
      display: 'flex',
      justifyContent: 'flex-end',
      marginTop: theme.spacing(2),
    },
  })
);

/**
 * Sort-direction cycle applied when the currently sorted header is clicked
 * again: ascending → descending → unsorted → ascending.
 */
const nextSortDirections = {
  [ESortDirection.ASC]: ESortDirection.DESC,
  [ESortDirection.DESC]: ESortDirection.NONE,
  [ESortDirection.NONE]: ESortDirection.ASC,
} as const;

/**
 * Default comparator used when a column does not provide its own `sort`.
 *
 * In ascending order: numbers and dates compare by value (so `9` sorts before
 * `10`), `null`/`undefined` sort after every other value, and anything else
 * falls back to a locale-aware string comparison with numeric collation
 * (`"item2"` before `"item10"`).
 */
function compareCellValues(a: unknown, b: unknown): number {
  if (a == null || b == null) {
    if (a == null && b == null) return 0;
    return a == null ? 1 : -1;
  }
  if (typeof a === 'number' && typeof b === 'number') return a - b;
  if (a instanceof Date && b instanceof Date) return a.getTime() - b.getTime();
  return String(a).localeCompare(String(b), undefined, { numeric: true });
}

/**
 * The common table component used across the app.
 *
 * Rows are sorted client-side by the column whose header was last clicked
 * (cycling ascending → descending → unsorted; columns with `sortDisabled`
 * ignore clicks) and then sliced to the current page. Clicking a row calls
 * `onRowClick` with that row's `id`.
 *
 * @param props - the table props; `page` is 1-based (the visible slice starts
 *   at `(page - 1) * pageSize`).
 * @returns the rendered table followed by its pagination controls.
 */
export function Table<T extends ITableRow>({
  className,
  rows,
  columns,
  pageSize,
  page,
  onRowClick,
  onPageChange,
}: ITableProps<T>): ReactElement {
  const classes = useStyles();
  // Track the sorted column by its `field` rather than by object identity:
  // parents often rebuild `columns` on every render, which would otherwise
  // break the ASC → DESC → NONE cycle and keep sorting with a stale comparator.
  const [sortKey, setSortKey] = useState<Nullable<ITableColumn<T>['field']>>(
    null
  );
  const [sortDirection, setSortDirection] = useState(ESortDirection.ASC);

  // Resolve the active column from the *current* column definitions.
  const sortColumn = useMemo(
    () => columns.find((column) => column.field === sortKey),
    [columns, sortKey]
  );

  const pageCount = Math.ceil(rows.length / pageSize);

  const sortedRows = useMemo(() => {
    if (!sortColumn || sortDirection === ESortDirection.NONE) return rows;

    const sorted = [...rows].sort((a, b) => {
      if (sortColumn.sort) return sortColumn.sort(a, b);
      return compareCellValues(
        get(a, sortColumn.field),
        get(b, sortColumn.field)
      );
    });

    // `sorted` is a fresh copy, so reversing in place never mutates `rows`.
    return sortDirection === ESortDirection.DESC ? sorted.reverse() : sorted;
  }, [rows, sortColumn, sortDirection]);

  // Rows visible on the current (1-based) page.
  const pageRows = useMemo(
    () => sortedRows.slice((page - 1) * pageSize, page * pageSize),
    [sortedRows, page, pageSize]
  );

  const getSortIcon = (column: ITableColumn<T>) => {
    if (column.field !== sortKey) return <ChevronUpDownIcon />;
    if (sortDirection === ESortDirection.ASC) return <ChevronDownIcon />;
    if (sortDirection === ESortDirection.DESC) return <ChevronUpIcon />;
    return <ChevronUpDownIcon />;
  };

  const handleChangeSort = useCallback(
    (column: ITableColumn<T>) => {
      if (column.sortDisabled) return;

      if (column.field === sortKey) {
        setSortDirection(nextSortDirections[sortDirection]);
      } else {
        setSortKey(column.field);
        setSortDirection(ESortDirection.ASC);
      }
    },
    [sortKey, sortDirection]
  );

  return (
    <TableContainer>
      <MuiTable className={clsx(classes.table, className)}>
        <TableHead>
          <TableRow>
            {columns.map((column) => {
              // The flex wrapper decides where header content sits, so it must
              // follow the header's own alignment (falling back to the body
              // cell alignment), otherwise `headerAlign` has no visible effect.
              const headerAlign = column.headerAlign ?? column.align;

              return (
                <TableCell
                  key={column.field}
                  align={column.headerAlign}
                  className={classes.headerCell}
                  onClick={() => handleChangeSort(column)}
                >
                  <div
                    className={clsx(
                      classes.sortIconWrapper,
                      headerAlign === 'center' &&
                        classes.sortIconWrapperCenterAligned,
                      headerAlign === 'right' &&
                        classes.sortIconWrapperRightAligned
                    )}
                  >
                    {column.renderHeader ? column.renderHeader() : column.header}
                    {getSortIcon(column)}
                  </div>
                </TableCell>
              );
            })}
          </TableRow>
        </TableHead>
        <TableBody>
          {pageRows.map((row) => (
            <TableRow
              key={row.id}
              className={classes.tableRow}
              onClick={() => onRowClick(row.id)}
            >
              {columns.map((column) => (
                <TableCell key={column.field} align={column.align}>
                  {column.renderCell
                    ? column.renderCell(row)
                    : get(row, column.field)}
                </TableCell>
              ))}
            </TableRow>
          ))}
        </TableBody>
      </MuiTable>

      <div className={classes.paginationContainer}>
        <Pagination
          page={page}
          pageCount={pageCount}
          pageSize={pageSize}
          onPageChange={onPageChange}
        />
      </div>
    </TableContainer>
  );
}
