// Copyright 2024 Luminary Cloud, Inc. All Rights Reserved.
import { CallbackInterface } from 'recoil';

import { createResNode, updateOutputNodes } from '../lib/outputNodeUtils';
import { getPhysicsId, getPhysicsName } from '../lib/physicsUtils';
import { DEFAULT_STOP_COND } from '../lib/stoppingCondsUtils';
import * as simulationpb from '../proto/client/simulation_pb';
import { simulationParamState } from '../state/external/project/simulation/param';

import { geometryTagsState } from './geometry/geometryTagsState';
import { outputNodesState } from './outputNodes';
import { enabledExperimentsState } from './useExperimentConfig';
import { stoppingConditionsSelectorUpdate } from './useStoppingConditions';
import { staticVolumesState } from './volumes';

/**
 * Returns a callback function that, for each physics passed to it, adds a new residual output node
 * and an associated stopping condition (a clone of DEFAULT_STOP_COND that references the new node).
 *
 * The returned callback takes:
 *   - newPhysics: the physics to create residual nodes / stopping conditions for
 *   - projectId, workflowId, jobId: identify the job whose output nodes and stopping conditions
 *     are read and then written back
 * It writes the updated output nodes to `outputNodesState` and the updated stopping conditions to
 * `stoppingConditionsSelectorUpdate` for that job. Physics for which `createResNode` does not
 * produce a node of case 'residual' are skipped (no node and no stopping condition is added).
 *
 * @param cbInterface recoil callback interface
 * @returns async callback function
 */
export function addPhysicsResidualsCallback(cbInterface: CallbackInterface) {
  const { snapshot: { getPromise }, set } = cbInterface;
  return async (
    newPhysics: simulationpb.Physics[],
    projectId: string,
    workflowId: string,
    jobId: string,
  ): Promise<void> => {
    // Single key object shared by every job-scoped read and write below.
    const jobKey = { projectId, workflowId, jobId };

    // None of these reads depends on another, so fetch them concurrently.
    const [
      simParam,
      outputNodes,
      experimentConfig,
      stopConds,
      geometryTags,
      staticVolumes,
    ] = await Promise.all([
      getPromise(simulationParamState(jobKey)),
      getPromise(outputNodesState(jobKey)),
      getPromise(enabledExperimentsState),
      getPromise(stoppingConditionsSelectorUpdate(jobKey)),
      getPromise(geometryTagsState(jobKey)),
      getPromise(staticVolumesState(jobKey)),
    ]);

    // Edit clones rather than the values read from the snapshot.
    const newOutputNodes = outputNodes.clone();
    const newStopConds = stopConds.clone();
    for (const physics of newPhysics) {
      // Add a new residual node for each physics
      const resNode = createResNode(simParam, experimentConfig, geometryTags, staticVolumes);
      resNode.name = `Residuals for ${getPhysicsName(physics, simParam)}`;
      // Only a 'residual' node carries a physicsId; any other node case is dropped.
      if (resNode.nodeProps.case === 'residual') {
        resNode.nodeProps.value.physicsId = getPhysicsId(physics);
        newOutputNodes.nodes.push(resNode);

        // Add a stopping condition that monitors this residual node.
        // NOTE(review): the node is cloned here, before updateOutputNodes() runs below; confirm
        // that updateOutputNodes does not alter residual nodes in a way the stop condition should
        // mirror.
        const newCond = DEFAULT_STOP_COND.clone();
        newCond.node = resNode.clone();
        newStopConds.cond.push(newCond);
      }
    }

    updateOutputNodes(newOutputNodes, simParam, experimentConfig, geometryTags, staticVolumes);

    set(outputNodesState(jobKey), newOutputNodes);
    set(stoppingConditionsSelectorUpdate(jobKey), newStopConds);
  };
}
