import { defaultGraphMargins } from '../CONSTANTS';
import { PointDrawFn, ScatterPoints, SVGGraphProps } from '../types';
import { isPopulatedDomain } from '../types/typeguards';
import { nanoid } from '@reduxjs/toolkit';
import {
	extent as d3Extent,
} from 'd3-array';
import { axisLeft, axisBottom } from 'd3-axis';
import { scaleLinear } from 'd3-scale';
import { selectAll, select } from 'd3-selection';
import { FunctionComponent, useLayoutEffect, useMemo, useState } from 'react';
import styled from 'styled-components';

// SVG <circle> whose fill is read from the styled-components theme
// (`theme.palette.primary.main`).
const StyledPoint = styled.circle`
	fill: ${({ theme }) => theme.palette.primary.main};
`;

/**
 * Fallback point renderer: draws one small, half-transparent themed circle
 * centred on the pre-computed pixel position (`drawX`, `drawY`).
 */
const drawDefaultPoint: PointDrawFn<number, number> = (point) => {
	const { drawX: centerX, drawY: centerY } = point;

	// NOTE(review): the key comes from a fresh nanoid() call each time this
	// function runs, so it presumably differs between renders and makes React
	// remount every circle. Confirm, and prefer a stable per-point key if the
	// data can supply one.
	const pointKey = nanoid();

	return (
		<StyledPoint
			cx={centerX}
			cy={centerY}
			r={2}
			opacity={0.5}
			key={pointKey}
			data-testid="default-point"
		/>
	);
};

/**
 * Props accepted by `Scatterplot`.
 *
 * Extends `SVGGraphProps`; the component reads `width` and `height` from it
 * and spreads every remaining prop over `defaultGraphMargins` to obtain
 * `top`, `bottom`, `left` and `right` (presumably these are the optional
 * margin overrides declared on `SVGGraphProps` — confirm against that type).
 */
interface ScatterplotProps extends SVGGraphProps {
	/** Points to plot; each entry's `x` and `y` are fed to the scales. */
	facts: ScatterPoints;
	/**
	 * Optional custom renderer, called once per point with the point's own
	 * fields plus the computed `drawX` / `drawY` pixel position. When it is
	 * null or undefined, `drawDefaultPoint` is used instead.
	 */
	drawPoint?: PointDrawFn<number, number, any>;
}

/**
 * Renders `facts` as a scatterplot inside an `<svg>` of the given size.
 *
 * Points are produced by React (through `drawPoint`, or `drawDefaultPoint`
 * when none is supplied); the left and bottom axes are appended to the same
 * `<svg>` imperatively with d3 in a layout effect.
 *
 * @param width     Outer width of the `<svg>`, margins included.
 * @param height    Outer height of the `<svg>`, margins included.
 * @param facts     Points to plot. When empty, no points or axes are drawn.
 * @param drawPoint Optional per-point renderer.
 * @param margins   Remaining props; merged over `defaultGraphMargins` to get
 *                  `top`, `bottom`, `left` and `right`.
 */
const Scatterplot: FunctionComponent<ScatterplotProps> = ({
	width,
	height,
	facts,
	drawPoint,
	...margins
}) => {
	const draw = drawPoint ?? drawDefaultPoint;

	// Caller-supplied margins win over the defaults, key by key.
	const { top, bottom, left, right } = { ...defaultGraphMargins, ...margins };

	const yMargin = top + bottom;
	const xMargin = left + right;

	// The <svg> node is kept in state (callback ref) so that the axis effect
	// re-runs once the element is actually mounted.
	const [svgRef, setSVGRef] = useState<SVGSVGElement | null>(null);

	const { xScale, yScale, drawHeight } = useMemo(() => {
		if (facts.length === 0) {
			return {
				xScale: null,
				yScale: null,
				drawHeight: 0,
			};
		}

		// Drawable extent along one dimension: the outer size minus its
		// margins, never negative. Without the clamp, a size smaller than the
		// margins would yield a negative extent and an inverted scale range.
		const drawableExtent = (size: number, margin: number): number =>
			size > 0 ? Math.max(0, size - margin) : 0;

		const drawWidth = drawableExtent(width, xMargin);
		const drawHeight = drawableExtent(height, yMargin);

		// Specify some fallbacks here: d3.extent returns
		// [undefined, undefined] if it encounters data that cannot be
		// compared.
		const xExtent = d3Extent(facts, (point) => point.x);
		const xDomain = isPopulatedDomain(xExtent) ? xExtent : [0, 0];

		const yExtent = d3Extent(facts, (point) => point.y);
		const yDomain = isPopulatedDomain(yExtent) ? yExtent : [0, 0];

		const xScale = scaleLinear()
			.range([0, drawWidth])
			.domain(xDomain)
			.nice();

		// The range is flipped because SVG's y coordinate grows downwards.
		const yScale = scaleLinear()
			.range([drawHeight, 0])
			.domain(yDomain)
			.nice();

		return { xScale, yScale, drawHeight };
	}, [facts, width, height, xMargin, yMargin]);

	useLayoutEffect(() => {
		if (!svgRef) {
			return;
		}

		// Remove pre-existing axes, if any. The lookup is limited to this
		// component's own <svg>: a document-wide selector would also delete
		// the axes of every other Scatterplot on the page. This runs before
		// the scale check so that stale axes vanish when `facts` becomes
		// empty.
		selectAll(svgRef.querySelectorAll('.scatter-axis')).remove();

		if (!xScale || !yScale) {
			return;
		}

		const svg = select(svgRef);

		svg.append('g')
			.attr('transform', `translate(${left}, ${top})`)
			.classed('scatter-axis', true)
			.call(axisLeft(yScale).tickSize(0).tickPadding(6));

		svg.append('g')
			.attr('transform', `translate(${left}, ${drawHeight + top})`)
			.classed('scatter-axis', true)
			.call(axisBottom(xScale).tickSize(0).tickPadding(6));
	}, [xScale, yScale, svgRef, drawHeight, left, top]);

	return (
		<svg width={width} height={height} ref={setSVGRef}>
			{xScale &&
				yScale &&
				facts.map((b) => {
					const { x: xVal, y: yVal } = b;

					// TODO: all the zero-checking in this component should be
					// handled more elegantly.
					// Positions are offset by the margins and floored at 0; a
					// non-comparable result (e.g. NaN) also falls back to 0.
					const yCalc = yScale(yVal) + top;
					const drawY = yCalc > 0 ? yCalc : 0;

					const xCalc = left + xScale(xVal);
					const drawX = xCalc > 0 ? xCalc : 0;

					return draw({ ...b, drawX, drawY });
				})}
		</svg>
	);
};

export default Scatterplot;
