// Copyright 2020-2024 Luminary Cloud, Inc. All Rights Reserved.
import React, { useEffect, useReducer, useState } from 'react';

import { useNavigate } from 'react-router-dom';

import { getAdValue } from '../../../lib/adUtils';
import assert from '../../../lib/assert';
import { colors } from '../../../lib/designSystem';
import { varSpecToParamState } from '../../../lib/explorationUtils';
import { workflowLink } from '../../../lib/navigation';
import { fromBigInt } from '../../../lib/number';
import { computeOutputsReply } from '../../../lib/output/computeOutputs';
import { createOutputs, getOutputNodeWarnings } from '../../../lib/outputNodeUtils';
import { getReferenceValues } from '../../../lib/referenceValueUtils';
import { disabledOutputCategories, getDerivative, getYRange } from '../../../lib/sensitivityUtils';
import * as basepb from '../../../proto/base/base_pb';
import * as explorationpb from '../../../proto/exploration/exploration_pb';
import * as frontendpb from '../../../proto/frontend/frontend_pb';
import * as feoutputpb from '../../../proto/frontend/output/output_pb';
import * as outputpb from '../../../proto/output/output_pb';
import { useEntityGroupData } from '../../../recoil/entityGroupState';
import { useGeometryTags } from '../../../recoil/geometry/geometryTagsState';
import { useJobState } from '../../../recoil/jobState';
import { useOutputNodes } from '../../../recoil/outputNodes';
import { useEnabledExperiments } from '../../../recoil/useExperimentConfig';
import { useStaticVolumes } from '../../../recoil/volumes';
import { useCurrentConfig } from '../../../recoil/workflowConfig';
import { useWorkflowState } from '../../../recoil/workflowState';
import { useSimulationParam } from '../../../state/external/project/simulation/param';
import { useSimulationParamScope } from '../../../state/external/project/simulation/paramScope';
import { useIsStaff } from '../../../state/external/user/frontendRole';
import { ActionButton } from '../../Button/ActionButton';
import EmptyState from '../../EmptyState';
import { JobDataCategory } from '../../JobPanel/JobDataCategory';
import JobDataPanel from '../../JobPanel/JobDataPanel';
import { getJobStatusData } from '../../JobPanel/JobStatus';
import JobVerticalDataTable from '../../JobPanel/JobVerticalDataTable';
import ResidualChart from '../../OutputChart/ResidualChart';
import { createStyles, makeStyles } from '../../Theme';
import { useProjectContext } from '../../context/ProjectContext';

import InputColumn from './InputColumn';
import OutputRowHeader from './OutputRowHeader';

const useStyles = makeStyles(
  () => {
    // The job-status card and the outputs card are drawn on the same surface.
    const cardSurface = {
      background: colors.surfaceMedium2,
      borderRadius: '4px',
    };

    return createStyles({
      page: { height: '100%', backgroundColor: colors.surfaceDark3 },
      root: { padding: '13px 0 0 8px' },
      jobStatus: { ...cardSurface, padding: '1px 0 8px 0' },
      columns: {
        display: 'flex',
        justifyContent: 'flex-start',
        alignItems: 'stretch',
        gap: '16px',
        position: 'relative',
      },
      leftColumn: {
        lineHeight: '24px',
        width: '407px',
        display: 'flex',
        flexDirection: 'column',
        justifyContent: 'space-between',
        alignItems: 'stretch',
        gap: '8px',
      },
      outputsPanel: { ...cardSurface, padding: '16px' },
      // Floating panel that hosts the residual chart, placed just right of the
      // 407px-wide left column.
      residualFlyOut: {
        position: 'absolute',
        border: `1px solid ${colors.neutral300}`,
        borderRadius: '4px',
        zIndex: 1,
        background: 'rgba(47, 48, 52, 0.9)',
        width: '250px',
        boxShadow: '5px 5px 8px rgba(0, 0, 0, 0.3)',
        top: '13px',
        left: '423px',
      },
      // CSS-border triangle pointing at the fly-out. `borderColor` must follow
      // the `border` shorthand so it overrides the shorthand's colour.
      arrow: {
        top: '66px',
        left: '407px',
        border: '10px solid',
        position: 'absolute',
        borderColor: `transparent ${colors.neutral250} transparent transparent`,
        zIndex: 3,
      },
      // Slightly offset, darker triangle drawn beneath `arrow` as its outline.
      arrowBorder: {
        top: '66px',
        left: '404px',
        border: '10px solid',
        position: 'absolute',
        borderColor: `transparent ${colors.neutral300} transparent transparent`,
        zIndex: 2,
      },
    });
  },
  { name: 'SensitivityAnalysisResults' },
);

/**
 * Action dispatched to computedOutputsReducer. It carries one compute-output
 * reply together with the input columns that reply should be written into.
 */
interface computedOutputsAction {
  /** Total number of inputs, i.e. how many columns the state should have. */
  numInputs: number;
  /** Indices of the input columns to fill from `reply`. */
  inputIndices: number[];
  /** Reply whose per-output results supply the new column values. */
  reply: frontendpb.ComputeOutputReply;
}

// Updates the computed outputs using the data in a computedOutputsAction.
//
// The state is a matrix indexed as computedOutputs[inputIndex][outputIndex].
// First the matrix is resized:
// - If the number of inputs (columns) changed, all previous data is discarded.
// - Each column is then padded with fresh AdFloatType values, or truncated, so
//   that it has exactly one row per result in the reply.
// Finally, for every column listed in action.inputIndices, row k is set to the
// first value of the k-th result in the reply.
//
// React requires reducers to be pure, so the previous state is never mutated.
// Every column is copied before it is resized or written.
function computedOutputsReducer(
  prevComputedOutputs: basepb.AdFloatType[][],
  action: computedOutputsAction,
): basepb.AdFloatType[][] {
  const resultList = action.reply.result;
  if (!resultList) {
    // Nothing to merge; returning the same state lets React skip the re-render.
    return prevComputedOutputs;
  }

  // Start from empty columns if the number of inputs changed. Otherwise copy
  // each column so the resizing and writes below leave the previous state intact.
  const computedOutputs: basepb.AdFloatType[][] =
    prevComputedOutputs.length === action.numInputs ?
      prevComputedOutputs.map((column) => column.slice()) :
      Array.from({ length: action.numInputs }, () => []);

  // Resize every column to exactly one row per output, growing or shrinking
  // as needed. With zero inputs there are no columns and this is a no-op.
  const numOutputs = resultList.length;
  computedOutputs.forEach((column) => {
    column.splice(numOutputs);
    while (column.length < numOutputs) {
      column.push(new basepb.AdFloatType());
    }
  });

  // Fill in the columns belonging to the inputs this reply was computed for.
  action.inputIndices.forEach((inputIndex) => {
    const column = computedOutputs[inputIndex];
    if (!column) {
      return;
    }
    resultList.forEach((result, outputIndex) => {
      [column[outputIndex]] = result.values;
    });
  });
  return computedOutputs;
}

/**
 * Results page for a sensitivity-analysis exploration.
 *
 * It renders:
 * - the job status;
 * - an optional residuals fly-out;
 * - the list of outputs with their initial and estimated values;
 * - one adjustable column per non-synthetic exploration variable.
 *
 * For the i-th variable, outputs are computed on the job whose i-th exploration
 * value has a first-order tangent of 1. Each reply is merged into
 * `computedOutputs` via computedOutputsReducer.
 */
const SensitivityAnalysisResults = () => {
  // == Contexts
  const { projectId, workflowId, jobId } = useProjectContext();

  // == Recoil
  const config = useCurrentConfig(projectId, workflowId, jobId);
  const simParam = useSimulationParam(projectId, workflowId, jobId);
  const [outputNodes] = useOutputNodes(projectId, '', '');
  const workflowState = useWorkflowState(projectId, workflowId);
  const geometryTags = useGeometryTags(projectId, workflowId, jobId);
  assert(!!workflowState, 'no workflow');

  // == Hooks
  const classes = useStyles();

  // The outputs without the output coefficients.
  const experiment = config.exploration!;
  // This code assumes we have read workflow state already.
  const selectedJob = workflowState.job[jobId];
  const selectedJobs = selectedJob ? [selectedJob] : [];
  const workflowIds = selectedJob ? [workflowId] : [];
  const params = selectedJob ? [simParam] : [];
  const lastIters = selectedJobs.map((jobA) => jobA.latestIter);
  const [residualsOpen, setResidualsOpen] = useState(false);
  const [computedOutputs, dispatch] = useReducer(computedOutputsReducer, []);
  const isStaff = useIsStaff();
  const jobStatusData = getJobStatusData(
    projectId,
    workflowIds,
    selectedJobs,
    params,
    false,
    isStaff,
    false,
  );
  const jobState = useJobState(projectId, workflowId, jobId);
  const entityGroupData = useEntityGroupData(projectId, workflowId, jobId);
  const staticVolumes = useStaticVolumes(projectId, workflowId, jobId);
  const navigate = useNavigate();
  const experimentConfig = useEnabledExperiments();
  const paramScope = useSimulationParamScope(projectId, workflowId, jobId);
  // The initial central value for each input.
  const [xCenters, setXCenters] = useState<number[]>([]);
  // The adjusted value for each input.
  const [xAdjusted, setXAdjusted] = useState<number[]>([]);

  // Add in the output coefficients to the initial list. This list is filled
  // during render by createCategory below, before the effects run.
  const outputList: outputpb.Output[] = [];

  const varList = experiment.var.filter((variable) => !variable.synthetic);
  useEffect(() => {
    if (!lastIters[0] || outputList.length === 0) {
      return;
    }
    const refValues =
      getReferenceValues(outputNodes, simParam, experimentConfig, geometryTags, staticVolumes);
    for (let i = 0; i < varList.length; i += 1) {
      // Find the Job ID for the i-th variable. It will have a tangent value of
      // 1 for the i-th experiment value.
      let currJobId = '';
      let iter = 0;
      Object.values(workflowState.job).forEach((job) => {
        // A job without exploration values cannot belong to any variable.
        const value = job.explorationValues?.value[i];
        if (!value) {
          return;
        }
        const realValue = value.typ.case === 'real' ? value.typ.value : null;
        const firstOrder = realValue?.adTypes.case === 'firstOrder' ?
          realValue.adTypes.value : null;
        if (firstOrder && firstOrder.tangent[0] === 1) {
          currJobId = job.jobId;
          iter = fromBigInt(job.latestIter);
        }
      });
      if (currJobId && iter) {
        // Give every request its own copies of the outputs, each limited to
        // this job's latest iteration. The outputList entries are shared by
        // all loop iterations. Mutating them in place would let a later
        // variable overwrite the range of a request that is still in flight.
        const jobOutputs = outputList.map((output) => {
          const jobOutput = output.clone();
          jobOutput.range = new outputpb.IterationRange({ end: iter });
          return jobOutput;
        });

        computeOutputsReply(
          projectId,
          currJobId,
          jobOutputs,
          refValues,
        ).then((reply) => {
          dispatch({
            inputIndices: [i],
            numInputs: varList.length,
            reply,
          });
        }).catch((error) => {
          throw Error(`Error occured while computing outputs: ${error}`);
        });
      } else {
        throw Error(
          `Job for variable ${varList[i].spec!.text} not found.`,
        );
      }
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [experiment, outputNodes, lastIters[0], entityGroupData, geometryTags, staticVolumes]);

  useEffect(() => {
    const newXCenters = varList.map((variable: explorationpb.Var) => {
      // TODO: Get this to work for AdVector3 types.
      const paramDef =
        varSpecToParamState(variable.spec!, simParam, paramScope, geometryTags, staticVolumes)!;
      return getAdValue(paramDef.value);
    });
    setXCenters(newXCenters);
    setXAdjusted(newXCenters.slice());
  }, [experiment, geometryTags, staticVolumes]); // eslint-disable-line react-hooks/exhaustive-deps

  const outputDataCategories: JobDataCategory[] = [];

  // Builds the row header for one output. The estimate is a first-order
  // (linear) extrapolation from the central value: each input's derivative is
  // multiplied by how far that input has been moved from its center.
  const outputCell = (name: string, outputIndex: number, marginTop: number) => {
    const yRange = getYRange(outputIndex, computedOutputs);
    let yAdjusted = yRange.center;
    const numVars = Math.min(varList.length, xAdjusted.length, xCenters.length);
    for (let inputIndex = 0; inputIndex < numVars; inputIndex += 1) {
      const derivative = getDerivative(inputIndex, outputIndex, computedOutputs);
      yAdjusted += derivative * (xAdjusted[inputIndex] - xCenters[inputIndex]);
    }
    return (
      <OutputRowHeader
        key={name}
        marginTop={marginTop}
        name={name}
        yAdjusted={yAdjusted}
        yRange={yRange}
      />
    );
  };

  // Running index of every output across all categories, used to address the
  // rows of computedOutputs.
  let outputIndex = 0;
  const createCategory = (
    outputNode: feoutputpb.OutputNode,
    categoryName: string,
  ): JobDataCategory => {
    const newCategory: JobDataCategory = {
      nameLines: [`${categoryName.toUpperCase()} `],
      subcategories: [],
      values: [],
    };

    const outputNodeWarnings = simParam && paramScope ? getOutputNodeWarnings(
      outputNode,
      outputNodes,
      simParam,
      entityGroupData,
      paramScope,
      staticVolumes,
      geometryTags,
      outputNodes.referenceValues?.referenceValueType,
    ) : [];

    // Output nodes with warnings contribute an empty category and no outputs.
    if (!outputNodeWarnings.length) {
      const { outputList: outputs } = createOutputs(
        outputNode,
        outputNodes,
        simParam,
        entityGroupData,
        false,
        false,
      );
      outputs.forEach((output: outputpb.Output, index: number) => {
        const marginTop = index === 0 ? 12 : 24;
        const newOutput = output.clone();
        outputList.push(newOutput);
        newCategory.subcategories.push(
          {
            nameLines: [newOutput.name],
            subcategories: [],
            values: [outputCell(newOutput.name, outputIndex, marginTop)],
            showName: false,
            marginTop,
          },
        );
        outputIndex += 1;
      });
    }
    return newCategory;
  };

  outputNodes.nodes.forEach((node) => {
    if (!disabledOutputCategories.includes(node.nodeProps.case) && node.type) {
      outputDataCategories.push(createCategory(node, node.name));
    }
  });

  const residualGraph = (jobState ? (
    <ResidualChart
      height={175}
      jobState={jobState}
    />
  ) :
    <EmptyState title="No solution found." />
  );

  const inputColumns = varList.map((variable: explorationpb.Var, inputIndex: number) => {
    // TODO: Get this to work for AdVector3 types.
    const xCenter = xCenters[inputIndex] || 0;
    // Slider range: +/-10% around a non-zero center, otherwise [-10, 10].
    const xRange = xCenter !== 0 ? {
      min: 0.9 * xCenter,
      center: xCenter,
      max: 1.1 * xCenter,
    } : {
      min: -10,
      center: 0,
      max: 10,
    };
    if (xAdjusted.length <= inputIndex) {
      return null;
    }
    return (
      <InputColumn
        adjustedValue={xAdjusted[inputIndex] || 0}
        computedOutputs={computedOutputs}
        inputIndex={inputIndex}
        key={variable.spec!.text}
        name={variable.spec!.text}
        outputCategories={outputDataCategories}
        setAdjustedValue={(newValue: number) => {
          // Use a functional update so successive changes never start from a
          // stale copy of xAdjusted captured by an earlier render.
          setXAdjusted((prevAdjusted) => {
            const newValues = prevAdjusted.slice();
            newValues[inputIndex] = newValue;
            return newValues;
          });
        }}
        xRange={xRange}
      />
    );
  });

  return (
    <div className={classes.page}>
      <div className={classes.root}>
        <div className={classes.columns}>
          <div className={classes.leftColumn}>
            <div className={classes.jobStatus}>
              <div style={{
                margin: '11px 15px',
                display: 'flex',
                justifyContent: 'space-between',
                alignItems: 'flex-start',
              }}>
                Exploration
                <ActionButton
                  kind="secondary"
                  onClick={() => {
                    navigate(workflowLink(projectId, workflowId, false));
                  }}>
                  View Simulation
                </ActionButton>
              </div>
              <JobVerticalDataTable
                categories={jobStatusData}
                extraPadding
                maxRows={4}
                name="results"
              />
            </div>
            {residualsOpen && (
              <>
                <div className={classes.arrow} />
                <div className={classes.arrowBorder} />
                <div className={classes.residualFlyOut}>
                  <JobDataPanel
                    body={residualGraph}
                    onClose={() => setResidualsOpen(false)}
                    title={<span>Residuals</span>}
                  />
                </div>
              </>
            )}
            <div className={classes.outputsPanel}>
              <div>Initial & Estimated Outputs</div>
              <JobVerticalDataTable
                categories={outputDataCategories}
                name="output"
              />
            </div>
          </div>
          {inputColumns}
        </div>
      </div>
    </div>
  );
};

export default SensitivityAnalysisResults;
