import { useEffect, useMemo, useRef } from "react";
import * as d3 from "d3";
import { useChartColors } from "@sage/state";
import { Chart, format_tick, format_value, useDims } from "../ChartProvider";

/**
 * Grouped bar chart with an optional trendline overlay.
 *
 * Builds an ordinal color scale keyed by category, derives the legend from the
 * categories present in `data` (sorted) followed by those in `trendline`, and
 * delegates the actual SVG drawing to `Series` inside the shared `Chart` shell.
 *
 * @param title       Optional chart title, forwarded to `Chart`.
 * @param data        Bar records; bars are grouped by `date` and split by `category`.
 * @param trendline   Optional records drawn as line(s) + marks on top of the bars.
 * @param colors      Optional explicit `category -> color` overrides. These take the
 *                    first slots of the scale so they always win over the palette.
 * @param updateColor Optional callback forwarded to `Chart`.
 */
export function BarChart({
	title,
	data,
	trendline,
	colors,
	updateColor
}: {
	title?: string;
	data: { category: string; value: number; value_label?: string; date: Date; date_label?: string }[];
	trendline?: { category: string; value: number; value_label?: string; date: Date; date_label?: string }[];
	colors?: { [key: string]: string };
	updateColor?: (category: string, color: string) => void;
}) {
	const { chart_colors } = useChartColors();

	const c = useMemo(() => {
		// Explicit overrides come first in both domain and range so that the
		// i-th override category maps onto the i-th override color.
		const overrides = Object.entries(colors || {});
		const overrideKeys = overrides.map(([category]) => category);
		const overrideKeySet = new Set(overrideKeys);

		const palette =
			chart_colors?.length > 0
				? chart_colors
				: ["#0b84a5", "#f6c85f", "#6f4e7c", "#9dd866", "#ca472f", "#ffa056", "#8dddd0", ...d3.schemeTableau10];

		return d3
			.scaleOrdinal()
			.domain([...overrideKeys, ...data.map((d) => d.category).filter((category) => !overrideKeySet.has(category))])
			.range([...overrides.map(([, color]) => color), ...palette]);
		// chart_colors is read above, so it must be a dependency or palette
		// changes would never rebuild the scale.
	}, [data, colors, chart_colors]);

	const legend = useMemo(() => {
		const barCategories = [...new Set(data.map((d) => d.category))].sort();
		const trendlineCategories = trendline ? [...new Set(trendline.map((d) => d.category))] : [];
		return [...barCategories, ...trendlineCategories].map((category) => ({ category, color: c(category) }));
		// trendline contributes legend entries, so it must be tracked as well.
	}, [data, trendline, c]);

	return (
		<Chart
			legend={legend}
			title={title}
			updateColor={updateColor || undefined}
			data={data}
		>
			<Series
				data={data}
				trendline={trendline}
				c={c}
			/>
		</Chart>
	);
}

/**
 * Draws the bars, value labels, axes, grid lines and (optionally) the trendline
 * of a `BarChart` into a single `<svg>` using d3.
 *
 * Layout: an outer band scale (`offset`) positions one group per distinct
 * `date`; an inner band scale (`x`) positions each category's bar within that
 * group; `y` maps values to pixels. The trendline uses its own `trendlineY`
 * scale so it is not constrained by the bar value range.
 *
 * @param data      Bar records (`category`, `value`, `date`, optional `value_label`).
 * @param c         Color scale: called with a category, returns a fill/stroke color.
 * @param trendline Optional records rendered as stacked line(s), marks and labels.
 */
function Series({ data, c, trendline }) {
	const dims = useDims();
	const containerRef = useRef(null);
	const seriesRef = useRef(null);
	const labelsRef = useRef(null);
	const xAxisRef = useRef(null);
	const xGridRef = useRef(null);
	const yAxisRef = useRef(null);
	const yGridRef = useRef(null);
	const trendlineSeriesRef = useRef(null);
	const trendlineLabelsRef = useRef(null);
	const trendlineMarksRef = useRef(null);

	// Number of distinct categories; drives label font size and trendline x-offset.
	const categories = useMemo(() => new Set(data.map((d) => d.category)).size, [data]);

	// Outer band scale: one band per distinct date across the plot width.
	const offset = useMemo(
		() =>
			d3
				.scaleBand()
				.domain(new Set(data.map((d) => d.date)))
				.rangeRound([dims.marginLeft, dims.width - dims.marginRight])
				.paddingInner(0.1),
		[data, dims]
	);

	// Inner band scale: one band per category inside a single date group.
	const x = useMemo(
		() =>
			d3
				.scaleBand()
				.domain(data.map((d) => d.category))
				.rangeRound([0, offset.bandwidth()])
				.padding(0.2),
		[data, offset]
	);

	// d3.max returns undefined for empty input; fall back to 0 to avoid a NaN domain.
	const y = useMemo(
		() =>
			d3
				.scaleLinear()
				.domain([0, (d3.max(data, (d) => d.value) ?? 0) * 1.1])
				.range([dims.height - dims.marginBottom, dims.marginTop]),
		[data, dims]
	);

	const trendlineY = useMemo(() => {
		if (trendline) {
			return d3
				.scaleLinear()
				.domain([0, (d3.max(trendline, (d) => d.value) ?? 0) * 1.05])
				.range([dims.height - dims.marginBottom, dims.marginTop]);
		}
	}, [trendline, dims]);

	// Horizontal shift from the start of a date band to the trendline point,
	// shared by the path, the marks and the labels so they stay aligned.
	const trendlinePad = useMemo(() => x.bandwidth() * (categories || 1) * 0.75, [x, categories]);

	const line = useMemo(() => {
		if (trendline) {
			return d3
				.line()
				.x((d) => offset(d.data[0]) + trendlinePad)
				.y((d) => trendlineY(d[1] - d[0]))
				.curve(d3.curveCatmullRom.alpha(0.5));
		}
	}, [trendline, trendlineY, offset, trendlinePad]);

	// Stack the trendline by category, keyed by date, so each category yields one series.
	const groupedTrendlineData = useMemo(() => {
		if (trendline) {
			const stack = d3.stack();
			const keys = stack.keys(d3.union(trendline.map((d) => d.category)));
			const values = keys.value(([, map], key) => map.get(key)?.value || 0)(
				d3.index(
					trendline,
					(d) => d.date,
					(d) => d.category
				)
			);
			return values;
		}
	}, [trendline]);

	const groupedData = useMemo(() => d3.group(data, (d) => d.date), [data]);

	// More categories -> narrower bars -> smaller labels. Memoized so the
	// drawing effect does not see a new scale instance on every render.
	const fontSize = useMemo(() => d3.scaleLinear().domain([0, 5]).range([0.85, 0.55]), []);

	useEffect(() => {
		// All <g> refs are attached in the same commit, so checking these two is
		// enough to know the SVG is mounted.
		if (!(seriesRef.current && containerRef.current && x && y && c)) {
			return;
		}

		const plotBottom = dims.height - dims.marginBottom;

		// Bars: one <g> per date, one <rect> per record in that date.
		d3.select(seriesRef.current)
			.selectAll("g")
			.data(groupedData)
			.join("g")
			.attr("transform", (d) => `translate(${offset(d[0])}, -2)`)
			.selectAll("rect")
			.data(([, d]) => d)
			.join("rect")
			.attr("x", (d) => x(d.category))
			.attr("y", (d) => y(d.value))
			.attr("width", x.bandwidth())
			.attr("height", (d) => y(0) - y(d.value))
			.attr("fill", (d) => c(d.category))
			.attr("rx", x.bandwidth() / 10)
			.attr("ry", x.bandwidth() / 10);

		// Value labels centered above each bar.
		d3.select(labelsRef.current)
			.selectAll("g")
			.data(groupedData)
			.join("g")
			.attr("transform", (d) => `translate(${offset(d[0])}, -3)`)
			.selectAll("text")
			.data(([, d]) => d)
			.join("text")
			.attr("text-anchor", "middle")
			.attr("font-size", `${fontSize(categories)}rem`)
			.attr("x", (d) => x(d.category) + x.bandwidth() / 2)
			.attr("y", (d) => y(d.value))
			.attr("dy", "-0.25rem")
			.text((d) => d.value_label || format_value(d.value))
			.attr("fill", "black")
			.attr("font-weight", "500");

		d3.select(xAxisRef.current)
			.attr("transform", `translate(0, ${plotBottom})`)
			.call(d3.axisBottom(offset).tickFormat(formatDateTick))
			.style("stroke-opacity", 0.6);

		d3.select(yAxisRef.current)
			.attr("transform", `translate(${dims.marginLeft}, 0)`)
			.call(d3.axisLeft(y).ticks(6).tickFormat(format_tick))
			.style("stroke-opacity", 0.6);

		// Grid lines are axes with full-length ticks and no tick text.
		d3.select(yGridRef.current)
			.attr("transform", `translate(${dims.marginLeft}, 0)`)
			.call(
				d3
					.axisLeft(y)
					.tickSize(-(dims.width - dims.marginRight - dims.marginLeft))
					.tickFormat("")
			)
			.style("stroke-dasharray", "2,2")
			.style("stroke-opacity", 0.24);

		d3.select(xGridRef.current)
			.attr("transform", `translate(0, ${plotBottom})`)
			.call(
				d3
					.axisBottom(offset)
					.tickSize(-(dims.height - dims.marginTop - dims.marginBottom))
					.tickFormat("")
			)
			.style("stroke-dasharray", "2,2")
			.style("stroke-opacity", 0.24);

		if (groupedTrendlineData) {
			// Set every attribute on the merged enter+update selection so freshly
			// entered paths get their stroke on the first draw, not a later one.
			d3.select(trendlineSeriesRef.current)
				.selectAll("path")
				.data(groupedTrendlineData)
				.join("path")
				.attr("fill", "none")
				.attr("stroke-width", 1.5)
				.attr("stroke", (d) => c(d.key))
				.attr("d", line);

			d3.select(trendlineLabelsRef.current)
				.selectAll("text")
				.data(trendline)
				.join("text")
				.attr("text-anchor", "middle")
				.attr("font-size", `${fontSize(categories)}rem`)
				.attr("font-weight", "500")
				.attr("x", (d) => offset(d.date) + trendlinePad)
				.attr("y", (d) => {
					// Labels sit below their mark unless that would push them too
					// low, in which case they flip above it.
					// NOTE(review): the threshold uses marginTop; marginBottom may
					// have been intended — confirm before changing.
					const yPos = trendlineY(d.value) + 18;
					if (yPos >= dims.height - dims.marginTop - 18) {
						return yPos - 28;
					}
					return yPos;
				})
				.attr("fill", "black")
				.text((d) => d.value_label || format_value(d.value));

			d3.select(trendlineMarksRef.current)
				.selectAll("circle")
				.data(trendline)
				.join("circle")
				.attr("fill", (d) => c(d.category))
				.attr("stroke", "white")
				.attr("stroke-width", 1)
				.attr("cx", (d) => offset(d.date) + trendlinePad)
				.attr("cy", (d) => trendlineY(d.value))
				.attr("r", 4);
		} else {
			// The trendline prop was removed: clear what a previous draw left behind.
			d3.select(trendlineSeriesRef.current).selectAll("path").remove();
			d3.select(trendlineLabelsRef.current).selectAll("text").remove();
			d3.select(trendlineMarksRef.current).selectAll("circle").remove();
		}
		// Every value read above is listed. `ref.current` is deliberately not:
		// mutating a ref does not trigger a render, so it is not a useful dependency.
	}, [
		x,
		y,
		c,
		offset,
		line,
		dims,
		categories,
		fontSize,
		groupedData,
		groupedTrendlineData,
		trendline,
		trendlineY,
		trendlinePad
	]);

	return (
		<svg
			ref={containerRef}
			width={"100%"}
			height={dims.height}
		>
			<g ref={xAxisRef} />
			<g ref={xGridRef} />
			<g ref={yAxisRef} />
			<g ref={yGridRef} />
			<g ref={seriesRef} />
			<g ref={trendlineSeriesRef} />
			<g ref={trendlineMarksRef} />
			<g ref={labelsRef} />
			<g ref={trendlineLabelsRef} />
		</svg>
	);
}

/**
 * Formats an x-axis tick value as a UTC date.
 *
 * - Values that do not parse as a date are returned unchanged.
 * - January 1 and December 31 are shortened to just the year.
 * - Everything else is rendered as e.g. "March 5, 2024".
 *
 * The year-boundary check uses the UTC month/day numbers rather than substring
 * matching on the formatted text, which would also match January 10–19.
 */
function formatDateTick(d) {
	try {
		const date = new Date(d);
		if (Number.isNaN(date.getTime())) {
			return d;
		}
		const month = date.getUTCMonth();
		const day = date.getUTCDate();
		const isYearBoundary = (month === 0 && day === 1) || (month === 11 && day === 31);
		if (isYearBoundary) {
			return date.toLocaleDateString("en-US", { year: "numeric", timeZone: "UTC" });
		}
		return date.toLocaleDateString("en-US", {
			month: "long",
			day: "numeric",
			year: "numeric",
			timeZone: "UTC"
		});
	} catch (e) {
		return d.toLocaleString("en-US", { timeZone: "UTC" });
	}
}
