import { useEffect, useMemo, useRef } from "react";
import * as d3 from "d3";
import { Chart, format_value, useDims } from "../ChartProvider";

/**
 * Bubble chart: one circle per row of `data`.
 * The chart places each circle by `date` (x) and `value` (y).
 * Circle size comes from `secondary_value`, and colour comes from `category`.
 *
 * @param title - optional chart title, forwarded to `Chart`.
 * @param data - rows to plot. This component does not mutate it.
 * @param colors - optional explicit category → colour map. Categories listed here keep
 *   their colour. Any other category takes colours from `d3.schemeTableau10`.
 */
export function BubblePlot({
	title,
	data,
	colors
}: {
	title?: string;
	data: { category: string; value: number; date: Date; secondary_value?: number }[];
	colors?: { [key: string]: string };
}) {
	// Ordinal colour scale. Explicitly coloured categories come first in the domain and are
	// paired positionally with their colours (Object.values follows Object.keys order).
	// The remaining categories continue into the Tableau10 palette.
	// `colors` is a dependency so that a changed colour map rebuilds the scale.
	const c = useMemo(() => {
		const explicit = colors ?? {};
		const explicitKeys = Object.keys(explicit);
		const explicitSet = new Set(explicitKeys);
		const implicitCategories = data.map((d) => d.category).filter((category) => !explicitSet.has(category));
		return d3
			.scaleOrdinal<string, string>()
			.domain([...explicitKeys, ...implicitCategories])
			.range([...Object.values(explicit), ...d3.schemeTableau10]);
	}, [data, colors]);

	// One legend entry per distinct category, sorted alphabetically, coloured by the same scale.
	const legend = useMemo(
		() => [...new Set(data.map((d) => d.category))].sort().map((category) => ({ category, color: c(category), outline: true })),
		[data, c]
	);

	return (
		<Chart
			legend={legend}
			title={title}
		>
			<Series
				data={data}
				c={c}
			/>
		</Chart>
	);
}

/**
 * Draws the bubbles, value labels and both axes of a BubblePlot into an SVG.
 * Chart dimensions and margins come from `useDims()`.
 *
 * @param data - rows with `date`, `value`, `category` and optional `secondary_value`.
 *   This component does not mutate it; it sorts a copy for drawing.
 * @param c - colour scale mapping a category to a fill/stroke colour.
 */
function Series({ data, c }) {
	const dims = useDims();
	const containerRef = useRef(null);
	const seriesRef = useRef(null);
	const labelsRef = useRef(null);
	const xAxisRef = useRef(null);
	const yAxisRef = useRef(null);

	// Horizontal position: UTC time scale spanning the data's date extent.
	const x = useMemo(
		() =>
			d3
				.scaleUtc()
				.domain(d3.extent(data, (d) => d.date))
				.range([dims.marginLeft, dims.width - dims.marginRight]),
		[data, dims]
	);

	// Vertical position: linear from 0 up to the largest value (SVG y grows downward).
	const y = useMemo(
		() =>
			d3
				.scaleLinear()
				.domain([0, d3.max(data, (d) => d.value)])
				.range([dims.height - dims.marginBottom, dims.marginTop]),
		[data, dims]
	);

	// Radius: square-root scale over secondary_value, mapped to 2..20 px.
	// Depends only on data; dims are not used here.
	const size = useMemo(
		() =>
			d3
				.scaleSqrt()
				.domain([d3.min(data, (d) => d.secondary_value), d3.max(data, (d) => d.secondary_value)])
				.range([2, 20]),
		[data]
	);

	useEffect(() => {
		// The original `(a, b && …)` comma expression discarded the seriesRef check.
		if (!(seriesRef.current && containerRef.current && x && y && c && data)) {
			return;
		}

		// Sort a copy, never the `data` prop itself. With this ascending order, rows with
		// larger secondary_value are joined later, so their circles paint on top.
		const ordered = [...data].sort((a, b) => a.secondary_value - b.secondary_value);

		d3.select(seriesRef.current)
			.selectAll("circle")
			.data(ordered)
			.join(
				(enter) =>
					enter
						.append("circle")
						.attr("fill", (d) => c(d.category))
						.attr("fill-opacity", 0.3)
						.attr("stroke", (d) => c(d.category))
						.attr("stroke-width", 2)
						.attr("stroke-opacity", 0.5)
						.attr("r", (d) => size(d.secondary_value))
						.attr("cx", (d) => x(d.date))
						.attr("cy", (d) => y(d.value)),
				(update) =>
					update
						.transition()
						.attr("fill", (d) => c(d.category))
						.attr("fill-opacity", 0.3)
						.attr("stroke", (d) => c(d.category))
						.attr("stroke-width", 2)
						.attr("stroke-opacity", 0.5)
						.attr("r", (d) => size(d.secondary_value))
						.attr("cx", (d) => x(d.date))
						.attr("cy", (d) => y(d.value)),
				(exit) => exit.remove()
			);

		// Value labels just above-right of each bubble. Labels use the same ordering as the
		// circles, as they previously did through the in-place sort.
		d3.select(labelsRef.current)
			.selectAll("text")
			.data(ordered)
			.join("text")
			.attr("text-anchor", "start")
			.attr("font-size", "0.65em")
			.attr("font-weight", "600")
			.attr("x", (d) => {
				const xPos = x(d.date) + 6;
				if (xPos <= dims.marginLeft) {
					return xPos + 2;
				}
				// Pull labels that would start past the right plot edge back to the left.
				if (xPos >= dims.width - dims.marginRight) {
					return xPos - dims.marginRight - 24;
				}
				return xPos;
			})
			.attr("y", (d) => y(d.value) - 3)
			.attr("fill", (d) => c(d.category))
			.text((d) => format_value(d.value));

		d3.select(xAxisRef.current)
			.attr("transform", `translate(0, ${dims.height - dims.marginBottom})`)
			.call(
				d3
					.axisBottom(x)
					.ticks(7)
					.tickFormat((d) => d.toLocaleDateString("en-US", { timeZone: "UTC" }))
			);

		d3.select(yAxisRef.current)
			.attr("transform", `translate(${dims.marginLeft}, 0)`)
			.call(d3.axisLeft(y).ticks(5).tickFormat(format_value));
		// Refs are omitted here: `ref.current` is mutable and is not a valid effect dependency.
	}, [x, y, size, c, dims, data]);

	return (
		<svg
			ref={containerRef}
			width={"100%"}
			height={dims.height}
		>
			<g ref={seriesRef} />
			<g ref={xAxisRef} />
			<g ref={yAxisRef} />
			<g ref={labelsRef} />
		</svg>
	);
}
