/**
 * Scatter chart transformation utilities
 */
import { RetrieveAllResponse } from "@services/api/generated/retriever/models/retrieveAllResponse";
import { Aggregations } from "@services/api/generated/webserver/models/aggregations";
import { ColumnDataType } from "@services/api/generated/webserver/models/columnDataType";
import { DataVisualizationType } from "@services/api/generated/webserver/models/dataVisualizationType";
import { ChartConfig } from "@utils/chartTransformations";
import * as Highcharts from "highcharts";

import {
  aggregateValues,
  getColumnValue,
  getTopBreakdownValues,
  MAX_BREAKDOWN_VALUES,
  RowArray,
} from "./commonTransformations";

/**
 * Result of the scatter chart transformation: the two columns chosen for the
 * axes together with the Highcharts scatter series built from the rows.
 */
export interface ScatterChartData {
  /** Name of the column plotted on the x-axis ("" when the transformation bails out early). */
  xColumn: string;
  /** Name of the column plotted on the y-axis ("" when the transformation bails out early). */
  yColumn: string;
  /** Scatter series to render. */
  series: Highcharts.SeriesScatterOptions[];
  /** Index signature: allows additional, untyped properties on the object. */
  [key: string]: unknown;
}

/**
 * A single scatter point stored as a numeric array, with the x value at
 * index 0 and the y value at index 1.
 */
export interface ScatterDataPoint extends Array<number> {
  0: number; // x value
  1: number; // y value
}

/**
 * Caps the total number of data points across all scatter series.
 *
 * When the combined point count exceeds `maxPoints`, every series first keeps
 * up to a small minimum number of points (at most 5) so that no category
 * disappears entirely. The rest of the budget is then split between the
 * series in proportion to how many points each one still has available, and
 * any points lost to rounding are handed out one at a time so the full budget
 * is used. Points are always taken from the start of each series' `data`
 * array, so their original order is preserved. The input is never mutated.
 *
 * @param series Array of scatter chart series
 * @param maxPoints Maximum total number of data points to keep
 * @returns The original array when it is empty or already within the limit;
 * otherwise a new array of shallow series copies with truncated `data`
 */
const limitScatterDataPoints = (
  series: Highcharts.SeriesScatterOptions[],
  maxPoints: number
): Highcharts.SeriesScatterOptions[] => {
  if (!series.length) {
    return series;
  }

  // Calculate the total number of data points across all series
  const totalPoints = series.reduce((sum, s) => sum + (s.data?.length ?? 0), 0);

  if (totalPoints <= maxPoints) {
    return series;
  }

  // Work with a non-negative whole number so a negative or fractional limit
  // can never turn into a negative slice end or an over-allocation.
  const limit = Math.max(0, Math.floor(maxPoints));
  const minPointsPerSeries = Math.min(5, Math.floor(limit / series.length));

  // First pass: guarantee each series its minimum representation
  // (or all of its points when it has fewer than the minimum).
  const plan = series.map((source) => {
    const available = source.data?.length ?? 0;
    return { source, available, keep: Math.min(available, minPointsPerSeries) };
  });

  const guaranteed = plan.reduce((sum, entry) => sum + entry.keep, 0);
  const budget = limit - guaranteed;

  if (budget > 0) {
    // totalPoints > limit here, so totalLeftover > budget > 0: the division is
    // safe and every proportional share is strictly smaller than what the
    // series still has available.
    const totalLeftover = totalPoints - guaranteed;
    let unassigned = budget;

    // Second pass: proportional shares, all computed against the same fixed
    // budget so the position of a series in the array does not affect its share.
    for (const entry of plan) {
      const share = Math.floor((budget * (entry.available - entry.keep)) / totalLeftover);
      entry.keep += share;
      unassigned -= share;
    }

    // Third pass: flooring leaves fewer spare points than there are series
    // with data left, so a single sweep handing out one point each uses them up.
    for (const entry of plan) {
      if (unassigned <= 0) {
        break;
      }
      if (entry.keep < entry.available) {
        entry.keep += 1;
        unassigned -= 1;
      }
    }
  }

  return plan.map(({ source, keep }) => ({ ...source, data: source.data?.slice(0, keep) }));
};

/**
 * Builds the `[x, y]` pairs for one scatter series.
 *
 * Without an aggregation, every row whose x and y values are both numbers
 * becomes one point. With an aggregation, rows are grouped by their x value
 * and the y values of each group are reduced with `aggregateValues`. In both
 * cases the y value is rounded to two decimal places.
 *
 * @param rows Array of data rows
 * @param xColumn X-axis column name
 * @param yColumn Y-axis column name
 * @param columnNames Array of all column names
 * @param aggregation Optional aggregation method
 * @returns Array of scatter data points
 */
const createScatterDataPoints = (
  rows: RowArray[],
  xColumn: string,
  yColumn: string,
  columnNames: string[],
  aggregation?: Aggregations
): ScatterDataPoint[] => {
  // Shared rounding rule for y values: two decimal places.
  const roundY = (value: number): number => Number(value.toFixed(2));

  // No aggregation requested: one point per row with numeric x and y.
  if (!aggregation) {
    return rows.reduce<ScatterDataPoint[]>((points, row) => {
      const x = getColumnValue(row, xColumn, columnNames);
      const y = getColumnValue(row, yColumn, columnNames);
      if (typeof x === "number" && typeof y === "number") {
        points.push([x, roundY(y)] as ScatterDataPoint);
      }
      return points;
    }, []);
  }

  // Aggregation requested: bucket the y values by their x value.
  const countingRows = aggregation === Aggregations.count;
  const yValuesByX = new Map<number, number[]>();

  rows.forEach((row) => {
    const x = getColumnValue(row, xColumn, columnNames);
    const y = getColumnValue(row, yColumn, columnNames);

    if (typeof x !== "number") {
      return;
    }
    // Rows without a numeric y only matter when counting rows.
    if (typeof y !== "number" && !countingRows) {
      return;
    }

    let bucket = yValuesByX.get(x);
    if (!bucket) {
      bucket = [];
      yValuesByX.set(x, bucket);
    }
    // For count aggregation a placeholder 0 stands in for a non-numeric y
    // so that the row is still counted.
    bucket.push(typeof y === "number" ? y : 0);
  });

  // Collapse every bucket into a single point, in first-seen x order.
  return Array.from(
    yValuesByX,
    ([x, yValues]) => [x, roundY(aggregateValues(yValues, aggregation))] as ScatterDataPoint
  );
};

/**
 * Builds one Highcharts scatter series from a set of rows.
 *
 * The points come from `createScatterDataPoints`; the series `type` is the
 * lower-cased `DataVisualizationType.SCATTER` value.
 *
 * @param rows Rows that belong to this series
 * @param xColumn X-axis column name
 * @param yColumn Y-axis column name
 * @param columnNames Array of column names
 * @param seriesName Name for the series
 * @param aggregation Optional aggregation method
 * @returns Scatter chart series options
 */
const createScatterSeries = (
  rows: RowArray[],
  xColumn: string,
  yColumn: string,
  columnNames: string[],
  seriesName: string,
  aggregation?: Aggregations
): Highcharts.SeriesScatterOptions => {
  const points = createScatterDataPoints(rows, xColumn, yColumn, columnNames, aggregation);
  const scatterType =
    DataVisualizationType.SCATTER.toLowerCase() as Highcharts.SeriesScatterOptions["type"];

  const scatterSeries: Highcharts.SeriesScatterOptions = {
    type: scatterType,
    name: seriesName,
    data: points,
  };

  return scatterSeries;
};

/**
 * Converts a raw retrieval response into the structure used to render a
 * scatter chart.
 *
 * Two selected numerical columns become the x and y axes. If a selected
 * categorical or temporal column exists, one series is built per category
 * returned by `getTopBreakdownValues`, plus an "Other" series for the
 * remaining rows; otherwise a single series named after the y column is
 * built. The result always contains at least one series (possibly empty) and
 * at most 500 points in total.
 *
 * @param data Raw data from API or test data
 * @param chartConfig Chart configuration containing columns,
 * column_types, column_aggregations, and column_grouping
 * @returns Transformed data ready for chart rendering; `xColumn` and
 * `yColumn` are "" and `series` is empty when the input is malformed or
 * fewer than two selected numerical columns are available
 */
const transformForScatterChart = (
  data: RetrieveAllResponse,
  chartConfig: ChartConfig
): ScatterChartData => {
  // Upper bound on rendered points, to avoid performance issues.
  const maxScatterPoints = 500;

  // A fresh object per call so that bail-out results are never shared.
  const emptyResult = (): ScatterChartData => ({ xColumn: "", yColumn: "", series: [] });

  // Bail out unless the response has the expected structure.
  if (!(data && Array.isArray(data.rows) && Array.isArray(data.column_names))) {
    console.error("Invalid data format for scatter chart transformation", data);
    return emptyResult();
  }

  const rows = data.rows;
  const columnNames = data.column_names.map((name) => name.toLowerCase());

  const selectedColumns = chartConfig.columns;
  const typedColumns = Object.entries(chartConfig.column_types || {});

  // Selected columns that are declared numerical.
  const numericalColumns = typedColumns
    .filter(([col, type]) => type === ColumnDataType.numerical && selectedColumns.includes(col))
    .map(([col]) => col);

  // A scatter chart needs one numerical column per axis.
  if (numericalColumns.length < 2) {
    console.error("At least two numerical columns are required for scatter chart", chartConfig);
    return emptyResult();
  }

  // Emulate the client-side logic: the second selected column (if there is
  // one) is skipped when picking the x column.
  const excludedFromX = selectedColumns.slice(1, 2);
  const xColumn =
    numericalColumns.find((col) => !excludedFromX.includes(col)) || numericalColumns[0];
  const yColumn = numericalColumns.find((col) => col !== xColumn) || numericalColumns[1];

  // First selected categorical or temporal column, if any.
  const categoryColumn = typedColumns.find(
    ([col, type]) =>
      (type === ColumnDataType.categorical || type === ColumnDataType.temporal) &&
      selectedColumns.includes(col)
  )?.[0];

  // Aggregation configured for the y column, if any.
  const aggregation = chartConfig.column_aggregations?.[yColumn] as Aggregations;

  const collectedSeries: Highcharts.SeriesScatterOptions[] = [];

  // Builds a series from the given rows and keeps it only if it has points.
  const addSeries = (seriesRows: RowArray[], seriesName: string): void => {
    const built = createScatterSeries(
      seriesRows,
      xColumn,
      yColumn,
      columnNames,
      seriesName,
      aggregation
    );
    if (built.data && built.data.length > 0) {
      collectedSeries.push(built);
    }
  };

  if (categoryColumn) {
    // Restrict the breakdown to the categories picked by getTopBreakdownValues.
    const { values: topCategories } = getTopBreakdownValues(
      rows,
      categoryColumn,
      columnNames,
      MAX_BREAKDOWN_VALUES
    );

    const categoryOf = (row: RowArray): string =>
      String(getColumnValue(row, categoryColumn, columnNames));

    // One series per top category.
    for (const category of topCategories) {
      addSeries(
        rows.filter((row) => categoryOf(row) === category),
        category
      );
    }

    // Everything outside the top categories is pooled into "Other".
    const otherRows = rows.filter((row) => !topCategories.includes(categoryOf(row)));
    if (otherRows.length > 0) {
      addSeries(otherRows, "Other");
    }
  } else {
    // No category column: all rows go into a single series.
    addSeries(rows, yColumn);
  }

  // Always hand back at least one series so consumers get a consistent shape.
  if (collectedSeries.length === 0) {
    collectedSeries.push({
      type: DataVisualizationType.SCATTER.toLowerCase() as Highcharts.SeriesScatterOptions["type"],
      name: yColumn,
      data: [],
    });
  }

  const transformedData: ScatterChartData = {
    xColumn,
    yColumn,
    series: limitScatterDataPoints(collectedSeries, maxScatterPoints),
  };

  return transformedData;
};

export { transformForScatterChart };
