import React, { FC, useCallback, useEffect, useRef, useState } from 'react';
import Decimal from 'decimal.js';
import { ReactZoomPanPinchState, TransformComponent, TransformWrapper } from 'react-zoom-pan-pinch';
import { ReactComponent as ZoomIn } from '../../../assets/images/zoomIn.svg';
import { ReactComponent as ZoomOut } from '../../../assets/images/zoomOut.svg';
import FileHeatmapGradient from '../FileHeatmapGradient';
import './Heatmap.css';
import { Grid, IconButton, Typography } from '@mui/material';
import { Box } from '@mui/system';
import { HmsBoundaryConditions } from '../../../models/inputTypes/HmsFields';

/** Font size shared by the vertical and horizontal axis tick labels. */
const tickLabelFont = { fontSize: '0.65rem' } as const;

/**
 * Static style objects for the heatmap view: the canvas fills its container,
 * `zoomBox` frames the zoom buttons, and the tick-number styles set the
 * alignment of the axis labels.
 */
const styles = {
  canvas: { width: '100%', height: '100%' },
  zoomBox: {
    background: '#DFE1E5',
    borderRadius: '4px',
    padding: '2px',
  },
  verticalTickNumbers: { textAlign: 'right', ...tickLabelFont },
  horizontalTickNumbers: { textAlign: 'center', ...tickLabelFont },
} as const;

/** Mouse event types the canvas listens to; every one of them delivers a `MouseEvent`. */
type CanvasMouseEventType = 'mousedown' | 'mousemove' | 'mouseup' | 'dblclick';

/**
 * Maps a mouse event's viewport position to canvas backing-store pixels.
 *
 * The canvas is displayed scaled (CSS size plus the zoom transform), so the offset
 * inside its bounding box is multiplied by the backing-store / displayed-size ratio.
 * The result is `floor(x) - xShift` and `ceil(y + 1) + yShift`; callers choose the
 * shifts (frame selection uses 10, line drawing uses 1).
 */
const eventToCanvasPixels = (
  canvas: HTMLCanvasElement,
  event: MouseEvent,
  xShift: number,
  yShift: number,
): { x: number; y: number } => {
  const rect = canvas.getBoundingClientRect();
  const scaleX = new Decimal(canvas.width).div(rect.width).toNumber();
  const scaleY = new Decimal(canvas.height).div(rect.height).toNumber();
  return {
    x: new Decimal(event.clientX).minus(rect.left).mul(scaleX).floor().minus(xShift).toNumber(),
    y: new Decimal(event.clientY).minus(rect.top).mul(scaleY).plus(1).ceil().plus(yShift).toNumber(),
  };
};

/**
 * Strokes every wave generator line in red with a 5px width.
 * Stored y coordinates are measured from the bottom of the image, so they are
 * flipped with `height - y` into canvas (top-down) coordinates.
 */
const strokeWaveGenerators = (
  context: CanvasRenderingContext2D,
  lines: NonNullable<HmsBoundaryConditions['wave_generators_line_coordinates']>,
  height: number,
): void => {
  context.beginPath();
  context.lineWidth = 5;
  context.strokeStyle = 'red';
  lines.forEach((item) => {
    context.moveTo(item.wg_fpx, height - item.wg_fpy);
    context.lineTo(item.wg_lpx, height - item.wg_lpy);
    context.stroke();
  });
  context.closePath();
};

/**
 * Heatmap viewer for HMS boundary conditions.
 *
 * Renders a base64 PNG on a zoomable canvas with axis tick labels, a colour gradient
 * legend and zoom buttons. Existing wave generator lines are always painted in red.
 * The canvas has two interaction modes:
 *  - selection mode (no `drawStates` flag set): a click (mouse down/up without
 *    dragging) places a 20×20 frame and reports its position via `setFrameCoords`;
 *  - drawing mode (some `drawStates` flag set): panning is disabled, each mouseup adds
 *    a vertex, the pending segment follows the cursor, and a double click stores the
 *    first two vertices as wave generator line `index` in `inputState`, then clears
 *    all `drawStates` flags.
 */
const HeatmapHms: FC<{
  /** One flag per wave generator; any `true` flag switches the canvas into drawing mode. */
  drawStates: boolean[];
  setDrawStates: (drawStates: boolean[]) => void;
  inputState: HmsBoundaryConditions;
  setInputState: (value: HmsBoundaryConditions) => void;
  /** Slot of `wave_generators_line_coordinates` that a finished drawing is written to. */
  index: number;
  setIndex: (index: number) => void;
  /** PNG image data, base64-encoded without the `data:` URL prefix. */
  imageBase64: string | undefined;
  /** Image size as `[height, width]` in canvas pixels. */
  dimensions: number[] | undefined;
  values:
    | {
        max: number;
        min: number;
        originalMax: number;
        originalMin: number;
      }
    | undefined;
  /** Bottom-left corner of the 20×20 selection frame in canvas pixels (y grows downward). */
  frameCoords: { x: number; y: number };
  setFrameCoords: (frameCoords: { x: number; y: number }) => void;
  gradientColors: any[] | undefined;
}> = ({
  drawStates,
  setDrawStates,
  inputState,
  setInputState,
  index,
  setIndex,
  imageBase64,
  dimensions,
  values,
  frameCoords,
  setFrameCoords,
  gradientColors,
  ...rest
}) => {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  // Set once the pointer moves while pressed, so the following mouseup (end of a pan) does not move the frame.
  const [moving, setMoving] = useState<boolean>(false);
  const [mouseDown, setMouseDown] = useState<boolean>(false);
  // Most recent vertex of the line being drawn; y is measured from the bottom of the image.
  const [point, setPoint] = useState({ x: 0, y: 0 });
  // Every vertex clicked since the current line was started (same convention as `point`).
  const [points, setPoints] = useState<{ x: number; y: number }[]>([]);
  // True from the first click in drawing mode until the finishing double click.
  const [drawing, setDrawing] = useState<boolean>(false);

  // Derived straight from props rather than mirrored into state by an effect,
  // so drawing mode never lags one render behind `drawStates`.
  const isDraw = drawStates.some((state) => state);

  // Report the active wave generator; when several flags are set the last one wins.
  useEffect(() => {
    const activeIndex = drawStates.lastIndexOf(true);
    if (activeIndex !== -1) {
      setIndex(activeIndex);
    }
  }, [drawStates]);

  // --- Selection mode handlers -------------------------------------------------

  const onMouseDown = useCallback(() => {
    setMoving(false);
    setMouseDown(true);
  }, []);

  const onMouseMove = useCallback(
    (event: MouseEvent) => {
      if (moving) return;

      // Movement in ANY direction while pressed is a drag; checking only positive
      // deltas used to misreport left/up pans as clicks.
      if (mouseDown && (event.movementX !== 0 || event.movementY !== 0)) {
        setMoving(true);
      }
    },
    [mouseDown, moving],
  );

  const onMouseUp = useCallback(
    (event: MouseEvent) => {
      const canvas = canvasRef.current;
      if (canvas) {
        let { x, y } = eventToCanvasPixels(canvas, event, 10, 10);

        // Keep the 20×20 frame inside the image horizontally...
        if (dimensions && x + 20 > dimensions[1]) {
          x = dimensions[1] - 20;
        }
        if (x < 0) {
          x = 0;
        }

        // ...and vertically (y is the frame's bottom edge).
        if (y - 20 < 0) {
          y = 20;
        }
        if (dimensions && y > dimensions[0]) {
          y = dimensions[0];
        }

        if (!moving) {
          setFrameCoords({ x, y });
        }
      }

      setMoving(false);
      setMouseDown(false);
    },
    [dimensions, moving, setFrameCoords],
  );

  // --- Drawing mode handlers ---------------------------------------------------

  const onMouseDownDraw = useCallback(() => {
    setMouseDown(true);
  }, []);

  const onMouseUpDraw = useCallback(
    (event: MouseEvent) => {
      setDrawing(true);
      const canvas = canvasRef.current;
      if (canvas && dimensions) {
        const { x, y } = eventToCanvasPixels(canvas, event, 1, 1);
        const vertex = { x, y: dimensions[0] - y };
        setPoint(vertex);
        // Functional update: consecutive mouseups (e.g. a double click) may fire before
        // this listener is re-bound with a fresh `points` closure.
        setPoints((previous) => [...previous, vertex]);
      }
      setMouseDown(false);
    },
    [dimensions],
  );

  const onDbClick = useCallback(() => {
    const first = points[0];
    const second = points[1];
    // A line needs two vertices; ignore the double click rather than crash on `points[1]`.
    if (!first || !second) return;

    setDrawing(false);
    setPoints([]);
    const generatorsArray = inputState.wave_generators_line_coordinates.slice();
    generatorsArray[index] = { wg_fpx: first.x, wg_fpy: first.y, wg_lpx: second.x, wg_lpy: second.y };
    setInputState({ ...inputState, wave_generators_line_coordinates: generatorsArray });
    setDrawStates(new Array(drawStates.length).fill(false));
  }, [points, inputState, index, drawStates.length, setInputState, setDrawStates]);

  const onMouseMoveDraw = useCallback(
    (event: MouseEvent) => {
      if (!drawing) return;
      const canvas = canvasRef.current;
      if (!canvas || !imageBase64) return;
      const context = canvas.getContext('2d');

      // Repaint the image and the stored lines, then the pending black segment from
      // the last vertex to the cursor.
      const image = new Image();
      image.onload = () => {
        if (!context) return;
        context.clearRect(0, 0, canvas.width, canvas.height);
        const { x, y } = eventToCanvasPixels(canvas, event, 1, 1);

        context.drawImage(image, 0, 0, canvas.width, canvas.height);

        if (inputState.wave_generators_line_coordinates && dimensions) {
          strokeWaveGenerators(context, inputState.wave_generators_line_coordinates, dimensions[0]);
        }
        context.beginPath();
        context.lineWidth = 2;
        context.strokeStyle = 'black';
        if (dimensions) context.moveTo(point.x, dimensions[0] - point.y);
        context.lineTo(x, y);
        context.stroke();
        context.closePath();
      };
      image.src = 'data:image/png;base64,' + imageBase64;
    },
    [drawing, imageBase64, dimensions, inputState, point],
  );

  // --- Painting & listener wiring ----------------------------------------------

  // Paint the heatmap, the stored wave generator lines and the selection frame.
  // `isDraw` is a dependency so switching modes repaints from scratch, discarding the
  // temporary black segment painted while drawing.
  useEffect(() => {
    const canvas = canvasRef.current;
    const context = canvas?.getContext('2d');
    if (!canvas || !context || !imageBase64) return undefined;

    // Ignore a load that finishes after this effect was superseded, so an older frame
    // position or line set cannot overwrite a newer paint.
    let superseded = false;
    const image = new Image();
    image.onload = () => {
      if (superseded) return;
      context.clearRect(0, 0, canvas.width, canvas.height);
      if (dimensions) {
        canvas.width = dimensions[1];
        canvas.height = dimensions[0];
      }
      context.drawImage(image, 0, 0, dimensions?.[1] ?? 0, dimensions?.[0] ?? 0);

      if (inputState.wave_generators_line_coordinates && dimensions) {
        strokeWaveGenerators(context, inputState.wave_generators_line_coordinates, dimensions[0]);
      }

      // 20×20 selection frame whose bottom-left corner is `frameCoords`.
      context.beginPath();
      context.rect(frameCoords.x, frameCoords.y - 20, 20, 20);
      context.lineWidth = 2;
      context.strokeStyle = 'black';
      context.stroke();
      context.closePath();
    };
    image.src = 'data:image/png;base64,' + imageBase64;

    return () => {
      superseded = true;
    };
  }, [imageBase64, dimensions, frameCoords, inputState, isDraw]);

  // Bind the handler set for the current mode. Depending on the handlers themselves
  // re-binds whenever their closures change; `imageBase64`/`dimensions` re-run it when
  // the canvas (rendered only when both are present) mounts or unmounts.
  useEffect(() => {
    const canvas = canvasRef.current;
    if (!canvas) return undefined;

    const handlers: Array<[CanvasMouseEventType, (event: MouseEvent) => void]> = isDraw
      ? [
          ['mousedown', onMouseDownDraw],
          ['mousemove', onMouseMoveDraw],
          ['mouseup', onMouseUpDraw],
          ['dblclick', onDbClick],
        ]
      : [
          ['mousedown', onMouseDown],
          ['mousemove', onMouseMove],
          ['mouseup', onMouseUp],
        ];

    handlers.forEach(([type, handler]) => canvas.addEventListener(type, handler));
    return () => {
      handlers.forEach(([type, handler]) => canvas.removeEventListener(type, handler));
    };
  }, [
    imageBase64,
    dimensions,
    isDraw,
    onMouseDown,
    onMouseMove,
    onMouseUp,
    onMouseDownDraw,
    onMouseMoveDraw,
    onMouseUpDraw,
    onDbClick,
  ]);

  // --- Axis ticks ---------------------------------------------------------------

  /**
   * Seven vertical tick values (top to bottom, truncated) for the currently visible
   * image rows, accounting for the zoom scale and vertical pan offset.
   */
  const getVerticalAxisNumbers = useCallback(
    (state: ReactZoomPanPinchState) => {
      const canvas: HTMLCanvasElement | null = canvasRef.current;

      const verticalAxis: number[] = [];
      const height = new Decimal(dimensions?.[0] ?? 0).div(state.scale);
      // `|| 1` also covers a zero-height (not yet laid out) canvas, which would divide by zero.
      const ratio = new Decimal(dimensions?.[0] ?? 0).div(canvas?.clientHeight || 1).div(state.scale);
      const offset = new Decimal(state?.positionY ?? 0).mul(ratio).mul(-1);
      const min = new Decimal(dimensions?.[0] ?? 0).minus(offset).minus(height);
      const max = new Decimal(dimensions?.[0] ?? 0).minus(offset);
      const step = new Decimal(max).minus(min).div(6);

      for (let st = 0; st <= 6; st++) {
        const sub = new Decimal(step).mul(st);
        const val = new Decimal(max).sub(sub).trunc();
        verticalAxis.push(val.toNumber());
      }

      return verticalAxis;
    },
    [dimensions],
  );

  /**
   * Seven horizontal tick values (left to right, truncated) for the currently visible
   * image columns, accounting for the zoom scale and horizontal pan offset.
   */
  const getHorizontalAxisNumbers = useCallback(
    (state: ReactZoomPanPinchState) => {
      const canvas: HTMLCanvasElement | null = canvasRef.current;

      const horizontalAxis: number[] = [];
      const width = new Decimal(dimensions?.[1] ?? 0).div(state.scale);
      // `|| 1` also covers a zero-width (not yet laid out) canvas, which would divide by zero.
      const ratio = new Decimal(dimensions?.[1] ?? 0).div(canvas?.clientWidth || 1).div(state.scale);
      const offset = new Decimal(state?.positionX ?? 0).mul(ratio).mul(-1);
      const min = new Decimal(0).plus(offset);
      const max = new Decimal(0).plus(offset).plus(width);
      const step = new Decimal(max).minus(min).div(6);

      for (let st = 0; st <= 6; st++) {
        const sub = new Decimal(step).mul(st);
        const val = new Decimal(max).sub(sub).trunc();
        horizontalAxis.push(val.toNumber());
      }

      return horizontalAxis.reverse();
    },
    [dimensions],
  );

  return (
    <Box py={1} {...rest}>
      {imageBase64 && dimensions ? (
        <TransformWrapper panning={{ disabled: isDraw }} doubleClick={{ disabled: isDraw }}>
          {({ zoomIn, zoomOut, resetTransform, state, ...transformRest }) => {
            return (
              <Grid container spacing={1}>
                <Grid
                  item
                  xs={3}
                  container
                  justifyContent={'space-between'}
                  direction={'column'}
                  style={{ borderRight: '2px solid' }}
                >
                  {getVerticalAxisNumbers(state).map((num, ind) => (
                    <Typography sx={styles.verticalTickNumbers} key={ind} variant={'caption'}>
                      {ind === 0 || ind === 6 ? '' : num}
                    </Typography>
                  ))}
                </Grid>
                <Grid item xs={6}>
                  <TransformComponent
                    wrapperClass={'wrapper'}
                    wrapperStyle={{
                      paddingBottom: `${
                        dimensions ? new Decimal(dimensions?.[0]).div(dimensions?.[1]).mul(100).toNumber() : 0
                      }%`,
                    }}
                    contentClass={'content'}
                    {...transformRest}
                  >
                    <canvas style={styles.canvas} ref={canvasRef} {...rest} />
                  </TransformComponent>
                </Grid>
                <Grid item xs={3} container justifyContent={'space-between'}>
                  {gradientColors ? (
                    <FileHeatmapGradient levels={11} gradientColors={gradientColors} values={values} />
                  ) : (
                    <Box />
                  )}
                  <Box display={'flex'} alignItems={'center'} flexDirection={'column'}>
                    <Box sx={styles.zoomBox}>
                      <IconButton color={'primary'} size={'small'} onClick={() => zoomOut()}>
                        <ZoomOut />
                      </IconButton>
                      <IconButton color={'primary'} size={'small'} onClick={() => zoomIn()}>
                        <ZoomIn />
                      </IconButton>
                    </Box>
                  </Box>
                </Grid>
                <Grid item xs={3} />
                <Grid item xs={6} container justifyContent={'space-between'} style={{ borderTop: '2px solid' }}>
                  {getHorizontalAxisNumbers(state).map((num, ind) => (
                    <Typography sx={styles.horizontalTickNumbers} key={ind} variant={'caption'}>
                      {ind === 0 || ind === 6 ? '' : num}
                    </Typography>
                  ))}
                </Grid>
                <Grid item xs={3} />
              </Grid>
            );
          }}
        </TransformWrapper>
      ) : (
        <Grid container spacing={1}>
          <Grid item xs={3} />
          <Grid item xs={6} />
          <Grid item xs={3} container justifyContent={'space-between'}>
            <Box minHeight={360} />
            <Box display={'flex'} alignItems={'center'} flexDirection={'column'}>
              <Box>
                <IconButton color={'primary'} size={'small'}>
                  <ZoomOut />
                </IconButton>
                <IconButton color={'primary'} size={'small'}>
                  <ZoomIn />
                </IconButton>
              </Box>
              <Typography variant={'body2'} color={'primary'}>
                {`Scale: x1.00`}
              </Typography>
            </Box>
          </Grid>
        </Grid>
      )}
    </Box>
  );
};
export default HeatmapHms;
