import { RawGreeksDataMap, GREEK_IDX } from "types";
import { TableData, VolatilityDataRow as TableDataRow, OptionLegType, ModelParameters, ProfitLossWithDomain, ProfitLossData, OptionLeg, TimeframeBoundsPoint } from "types/riskEngine";
import BlackScholes, { normCDF } from "./BlackScholes";
import * as d3 from 'd3';

/**
 * Lists the strike keys (converted to numbers) stored under the first key of
 * `currentGreeksObj`.
 *
 * "First" follows `Object.keys` order: integer-like keys ascending, then other
 * string keys in insertion order.
 * NOTE(review): the object is presumably shaped `timestamp -> strike -> greeks`;
 * only its first two key levels are read here.
 *
 * @param currentGreeksObj - nested greeks object; may be null/undefined.
 * @returns the numeric strikes, or `undefined` when the object is falsy or has no keys.
 */
export function getStrikePricesForFirstTimestamp(
  currentGreeksObj: any,
): number[] | undefined {
  if (!currentGreeksObj) return undefined;

  const [firstTimestamp] = Object.keys(currentGreeksObj);
  if (firstTimestamp === undefined) return undefined;

  return Object.keys(currentGreeksObj[firstTimestamp]).map(Number);
}

/**
 * Reshapes raw greeks data into one `TableData` row per time-to-expiry key.
 *
 * For every key `tte` of `data` and every strike key under it, the value at
 * `data[tte][strike][isCall ? 0 : 1][column]` is copied into
 * `row.strikePrices[strike]`. Slot 0 is read for calls and slot 1 for puts.
 * Rows are emitted in `Object.keys(data)` order.
 *
 * NOTE(review): downstream interpolation addresses rows by array index, so this
 * presumably expects the first key to be 0 DTE (as produced by
 * `getFilteredGreeksData`) with one row per day — confirm upstream.
 *
 * @param data - raw greeks keyed by tte, then by strike.
 * @param isCall - `true` reads the call slot, `false` the put slot.
 * @param column - index inside the selected slot (presumably a `GREEK_IDX` value — confirm).
 * @returns one row per tte key; empty when `data` has no keys.
 */
export function extractTableData(data: RawGreeksDataMap, isCall: boolean, column: number): TableData[] {
  const side = isCall ? 0 : 1;

  return Object.keys(data).map((tte) => {
    const row: TableData = { tte: Number(tte), strikePrices: {} };
    const strikes = data[tte];

    Object.keys(strikes).forEach((strike) => {
      row.strikePrices[Number(strike)] = strikes[strike][side][column];
    });

    return row;
  });
}

/**
 * Piecewise-linearly interpolates a strike → value row at each entry of `prices`.
 *
 * Prices below the lowest strike take the lowest strike's value. Prices at or above
 * the highest strike take the highest strike's value (flat extrapolation on both ends).
 *
 * @param data - sample points. They need not be sorted: a sorted copy is used, so the
 *   caller's array is not mutated. This matters because rows built from `Object.keys`
 *   place fractional strikes such as "102.5" after every integer-like key.
 * @param prices - query prices. Ascending input is handled in a single O(n + m) pass.
 *   A price lower than its predecessor restarts the scan, so any order is accepted.
 * @returns one value per price. When `data` is empty, the result is all zeros,
 *   matching `interpolateDataTable`'s empty-table result.
 */
export function interpolateDataRow(data: TableDataRow[], prices: number[]): number[] {
  if (data.length === 0) {
    return prices.map(() => 0);
  }

  // The cursor walk below requires points in ascending strike order.
  const points = [...data].sort((a, b) => a.strike - b.strike);
  const last = points.length - 1;
  let i = 0;
  let previousPrice = -Infinity;

  return prices.map((price) => {
    // Out-of-order query: the cursor only moves forward, so rewind it.
    if (price < previousPrice) {
      i = 0;
    }
    previousPrice = price;

    while (i < last && points[i + 1].strike <= price) {
      ++i;
    }

    if (i >= last) {
      // clamp right
      return points[last].value;
    }
    if (price < points[0].strike) {
      // clamp left
      return points[0].value;
    }

    const lo = points[i];
    const hi = points[i + 1];
    const t = (price - lo.strike) / (hi.strike - lo.strike);
    // Same formula as d3.interpolateNumber(lo.value, hi.value)(t), without a closure per price.
    return lo.value * (1 - t) + hi.value * t;
  });
}

/** Converts one table row's strike → value map into points for `interpolateDataRow`. */
function tableDataToRow(entry: TableData): TableDataRow[] {
  return Object.keys(entry.strikePrices).map((strike) => ({
    strike: Number(strike),
    value: entry.strikePrices[Number(strike)],
  }));
}

/**
 * Interpolates a day-indexed table at a fractional day `dte`, for each of `prices`.
 *
 * Rows are addressed by array index, not by their `tte` field: `data[k]` is treated as day k.
 * NOTE(review): this presumes one row per day starting at 0 DTE — confirm with the producer.
 *
 * - `dte < 0` uses the first row.
 * - `dte >= data.length - 1` uses the last row.
 * - Otherwise rows `floor(dte)` and `floor(dte) + 1` are blended linearly by the fractional part.
 *
 * NOTE(review): a NaN `dte` is not handled and throws a TypeError.
 *
 * @returns one value per price; all zeros when `data` is empty.
 */
export function interpolateDataTable(data: TableData[], prices: number[], dte: number): number[] {
  if (data.length === 0) {
    return prices.map(() => 0);
  }

  const r0 = Math.floor(dte);
  const r1 = r0 + 1;

  if (r0 < 0) {
    return interpolateDataRow(tableDataToRow(data[0]), prices);
  }
  if (r1 >= data.length) {
    return interpolateDataRow(tableDataToRow(data[data.length - 1]), prices);
  }

  const w1 = dte - r0;
  const lower = interpolateDataRow(tableDataToRow(data[r0]), prices);

  // Integer day: the upper row has zero weight. Skipping it halves the work and keeps
  // non-finite neighbour values from turning the result into NaN (Infinity * 0).
  if (w1 === 0) {
    return lower;
  }

  const w0 = 1 - w1;
  const upper = interpolateDataRow(tableDataToRow(data[r1]), prices);
  return lower.map((v, i) => v * w0 + upper[i] * w1);
}

/**
 * Builds an evenly spaced, ascending price grid from `minValue` to `maxValue` inclusive.
 *
 * Uses `numSteps` points, reduced when necessary so consecutive prices are at least
 * 0.01 apart.
 *
 * @param minValue - first price of the grid.
 * @param maxValue - last price of the grid.
 * @param numSteps - desired number of points (default 1000). Fractional values are floored.
 * @returns
 *   - the grid;
 *   - `[minValue]` when only one point fits;
 *   - `[]` for an empty, inverted or NaN range, or for `numSteps <= 0`.
 */
export function generatePrices(
  minValue: number,
  maxValue: number,
  numSteps: number = 1_000,
): number[] {
  const minStep = 0.01;
  const range = maxValue - minValue;

  // Cap on the point count: with count <= ceil(range / minStep), range / (count - 1) >= minStep.
  const maxStepsForSpacing = Math.ceil(range / minStep);
  const count = Math.floor(Math.min(numSteps, maxStepsForSpacing));

  if (!(count > 0)) {
    return [];
  }
  // A single point has no spacing; range / (count - 1) would be Infinity and 0 * Infinity = NaN.
  if (count === 1) {
    return [minValue];
  }

  const stepSize = range / (count - 1);
  return Array.from({ length: count }, (_, i) => minValue + i * stepSize);
}

/** Returns `true` for call legs (long or short) and `false` for every other leg type. */
export function isCall(type: OptionLegType): boolean {
  return type === OptionLegType.LONG_CALL || type === OptionLegType.SHORT_CALL;
}

/** Returns `true` for long legs (call or put) and `false` for every other leg type. */
export function isLong(type: OptionLegType): boolean {
  return type === OptionLegType.LONG_CALL || type === OptionLegType.LONG_PUT;
}

/**
 * Profit/loss of a single option leg if the underlying trades at `currentStockPrice`.
 *
 * With time remaining (`daysToExpiration > 0`), the per-share value comes from Black-Scholes.
 * It uses `yte = daysToExpiration / 365` and `modelParameters.volatility` when set, otherwise
 * `statisticalVolatility`. At or past expiration the intrinsic value is used instead.
 *
 * The per-share value is multiplied by `shares` and compared with `premium`:
 * - long legs return `value - premium`;
 * - short legs return `premium - value`.
 * `premium` is therefore expected to be the leg's total premium, not a per-share amount.
 *
 * @param leg - the option leg (type, strike, share count, premium).
 * @param currentStockPrice - hypothetical underlying price.
 * @param modelParameters - days to expiration, risk-free rate and optional volatility override.
 * @param statisticalVolatility - volatility used when `modelParameters.volatility` is null/undefined.
 * @returns the leg's profit (negative for a loss).
 */
export function calculateOptionLegProfit(
  { legType, strikePrice, shares, premium }: OptionLeg,
  currentStockPrice: number,
  modelParameters: ModelParameters,
  statisticalVolatility: number
): number {
  const call = isCall(legType);
  const { daysToExpiration, riskFreeRate, volatility } = modelParameters;

  let optionValue: number;
  if (daysToExpiration > 0) {
    // Only build the pricer when it is used. This runs once per leg per chart price,
    // and an expired leg must not feed a non-positive time to expiry into the model.
    optionValue = new BlackScholes({
      isCall: call,
      currentPrice: currentStockPrice,
      strike: strikePrice,
      riskFreeRate,
      yte: daysToExpiration / 365.0,
      iVol: volatility ?? statisticalVolatility,
    }).price();
  } else {
    optionValue = call
      ? Math.max(0, currentStockPrice - strikePrice)
      : Math.max(0, strikePrice - currentStockPrice);
  }

  const legValue = optionValue * shares;
  return isLong(legType) ? legValue - premium : premium - legValue;
}

/**
 * Volatility band for `daysToExpiry`, interpolated from `timeframes`.
 *
 * - Before the first point (inclusive), the first point's band is used.
 * - Beyond the last point, the last point's band is used.
 * - Between two neighbouring points, both band edges are interpolated linearly.
 *
 * Every branch returns a fresh, ordered `[lower, upper]` pair. Callers can therefore index
 * `[0]`/`[1]` consistently and cannot mutate `timeframes` through the result.
 *
 * NOTE(review): the scan presumes `timeframes` is sorted by ascending `daysToExpiry` — confirm.
 *
 * @returns `[0, 0]` when `timeframes` is missing or empty.
 */
function getBounds(timeframes: TimeframeBoundsPoint[], daysToExpiry: number): number[] {
  if (!timeframes || timeframes.length === 0) {
    return [0, 0];
  }

  const ordered = (a: number, b: number): number[] => [Math.min(a, b), Math.max(a, b)];

  const first = timeframes[0];
  if (daysToExpiry <= first.daysToExpiry) {
    return ordered(first['bounds-timeframe'][0], first['bounds-timeframe'][1]);
  }

  for (let i = 0; i < timeframes.length - 1; i++) {
    const left = timeframes[i];
    const right = timeframes[i + 1];
    if (daysToExpiry >= left.daysToExpiry && daysToExpiry <= right.daysToExpiry) {
      const w = (daysToExpiry - left.daysToExpiry) / (right.daysToExpiry - left.daysToExpiry);
      const lb = left['bounds-timeframe'];
      const rb = right['bounds-timeframe'];

      return ordered(lb[0] + (rb[0] - lb[0]) * w, lb[1] + (rb[1] - lb[1]) * w);
    }
  }

  const lastBounds = timeframes[timeframes.length - 1]['bounds-timeframe'];
  return ordered(lastBounds[0], lastBounds[1]);
}

/**
 * Chart x-axis window derived from where the P/L line changes sign (profit→loss or loss→profit).
 *
 * The window starts at the price just before the first sign change and ends at the price
 * just after the last one. It is then:
 * - widened on each side by `max(1, 15% of its width)`;
 * - rounded outward to whole numbers;
 * - clamped at 0.
 *
 * With fewer than two sign changes, the result is `[0, <last price>]` (`[0, 0]` for no data).
 * NOTE(review): that fallback starts at 0 even when the caller supplied a price domain
 * starting higher — confirm this is the intended axis.
 */
function getProfitDomain(profitLossData: ProfitLossData[]): number[] {
  let minDomain: number | null = null;
  let maxDomain: number | null = null;

  for (let i = 0; i + 1 < profitLossData.length; i++) {
    const here = profitLossData[i].profit;
    const next = profitLossData[i + 1].profit;
    const changesSign = (here >= 0 && next < 0) || (here < 0 && next >= 0);
    if (!changesSign) {
      continue;
    }

    if (minDomain === null) {
      minDomain = profitLossData[i].price;
    } else {
      maxDomain = profitLossData[i + 1].price;
    }
  }

  if (minDomain === null || maxDomain === null) {
    const lastPrice = profitLossData.length > 0 ? profitLossData[profitLossData.length - 1].price : 0;
    return [0, lastPrice];
  }

  // MINIMUM_CHART_X_DOMAIN placeholder: pad by at least one price unit.
  const padding = Math.max(1, (maxDomain - minDomain) * 0.15);
  const lower = Math.max(0, Math.floor(minDomain - padding));
  const upper = Math.ceil(maxDomain + padding);

  return [lower, upper];
}

/**
 * Builds the P/L curve of a multi-leg option position plus the chart domain to display.
 *
 * Prices come from one of two grids:
 * - `generatePrices` over `priceDomain` when it is given;
 * - otherwise `[0, 2 * highest strike]`.
 *
 * Each leg is valued with the call or put volatility curve (interpolated at
 * `modelParameters.daysToExpiration`) as its fallback volatility.
 *
 * When `timeframes` has entries, `profitBounds` re-prices every leg at the lower and upper
 * volatility from `getBounds`. Otherwise both bounds equal `profit`.
 *
 * Only points inside the domain computed by `getProfitDomain` are returned.
 *
 * @returns the visible P/L points and the `[min, max]` price domain; empty data and
 *   `[0, 0]` when there are no legs.
 */
export function calculateProfitLoss(optionLegs: OptionLeg[], volDataCall: TableData[], volDataPut: TableData[], modelParameters: ModelParameters, priceDomain: number[] | null, timeframes: TimeframeBoundsPoint[] | null): ProfitLossWithDomain {
  if (optionLegs.length === 0) {
    return {
      profitLossData: [],
      domain: [0, 0],
    };
  }

  const { daysToExpiration } = modelParameters;

  const prices = priceDomain
    ? generatePrices(priceDomain[0], priceDomain[1])
    : generatePrices(0, 2 * Math.max(...optionLegs.map((leg) => leg.strikePrice)));
  const priceVolCall = interpolateDataTable(volDataCall, prices, daysToExpiration);
  const priceVolPut = interpolateDataTable(volDataPut, prices, daysToExpiration);

  // getBounds() returns a [0, 0] placeholder for an empty list. Pricing every leg at zero
  // volatility would draw a meaningless band, so an empty list is treated like null.
  const ivBounds = timeframes && timeframes.length > 0 ? getBounds(timeframes, daysToExpiration) : null;

  // Loop-invariant inputs, built once instead of per price point.
  const lowerBandParams: ModelParameters | null = ivBounds ? { ...modelParameters, volatility: ivBounds[0] } : null;
  const upperBandParams: ModelParameters | null = ivBounds ? { ...modelParameters, volatility: ivBounds[1] } : null;
  const legVolCurves = optionLegs.map((leg) => (isCall(leg.legType) ? priceVolCall : priceVolPut));

  const allProfitLossData: ProfitLossData[] = prices.map((price, index) => {
    let profit = 0.0;
    let lowerBand = 0.0;
    let upperBand = 0.0;

    optionLegs.forEach((leg, legIndex) => {
      const statisticalVolatility = legVolCurves[legIndex][index];
      profit += calculateOptionLegProfit(leg, price, modelParameters, statisticalVolatility);
      if (lowerBandParams && upperBandParams) {
        lowerBand += calculateOptionLegProfit(leg, price, lowerBandParams, statisticalVolatility);
        upperBand += calculateOptionLegProfit(leg, price, upperBandParams, statisticalVolatility);
      }
    });

    const profitBounds = ivBounds ? [lowerBand, upperBand] : [profit, profit];

    return {
      price,
      profit,
      // NOTE(review): the call curve is reported even for put-only positions — confirm intent.
      volatility: modelParameters.volatility ?? priceVolCall[index],
      profitBounds,
    };
  });

  const [minPriceDomain, maxPriceDomain] = getProfitDomain(allProfitLossData);

  return {
    profitLossData: allProfitLossData.filter(
      (point) => point.price >= minPriceDomain && point.price <= maxPriceDomain,
    ),
    domain: [minPriceDomain, maxPriceDomain],
  };
}