// Copyright 2024 Luminary Cloud, Inc. All Rights Reserved.
import * as clusterconfigpb from '../proto/clusterconfig/clusterconfig_pb';
import * as workflowpb from '../proto/workflow/workflow_pb';

import { copyJobConfig } from './proto';

/**
 * Return a copy of `oldConfig` whose `jobConfigTemplate.gpuPref` is replaced by a
 * preference for `gpuType` with `nPu` processing units.
 *
 * `oldConfig` is not written to: the new value is assigned on the object returned
 * by `copyJobConfig`. NOTE(review): this relies on `copyJobConfig` giving the copy
 * its own `jobConfigTemplate` instance — confirm.
 *
 * Passing `GPUType.UNSPECIFIED` (the value used for the "Auto" picker entry) stores
 * an empty preference list instead of a single entry.
 *
 * @param oldConfig - Workflow config to copy; left untouched.
 * @param gpuType - Requested GPU type, or UNSPECIFIED for no preference.
 * @param nPu - Requested number of processing units; ignored when `gpuType` is UNSPECIFIED.
 * @returns The updated copy.
 */
export function updateGpuPref(
  oldConfig: workflowpb.Config,
  gpuType: clusterconfigpb.GPUType,
  nPu: number,
): workflowpb.Config {
  const updated = copyJobConfig(oldConfig);
  const preferences: workflowpb.GPUPreference[] =
    gpuType === clusterconfigpb.GPUType.UNSPECIFIED
      ? []
      : [new workflowpb.GPUPreference({ gpuType, nPu })];
  // NOTE(review): assumes the copy always has a jobConfigTemplate; if it does not,
  // the assignment below throws a TypeError.
  const template = updated.jobConfigTemplate!;
  template.gpuPref = preferences;
  return updated;
}

/**
 * One selectable entry of the GPU-type picker. `collateGpuTypes` builds these entries
 * from the backend's node pools (analyzer.ServerInfoReply.node_pool).
 */
interface GPUType {
  /** Label displayed in the dialog. */
  text: string;
  /** Tooltip help text. */
  tooltip: string;
  /**
   * clusterconfig.GPUType this entry stands for ("T4", "V100", "H100", "CPU", ...).
   * UNSPECIFIED is used for the "Auto" entry.
   */
  gpuType: clusterconfigpb.GPUType;
  /**
   * Selectable job sizes, counted in processing units (GPUs, or CPUs for CPU pools),
   * not in nodes. `collateGpuTypes` returns them sorted ascending without duplicates
   * and leaves the list empty for the "Auto" entry.
   */
  puCount: number[];
}

/**
 * Look up the picker entry whose `gpuType` equals `typ`.
 *
 * @param allTypes - Entries to search, such as the result of `collateGpuTypes`.
 * @param typ - GPU type to look for.
 * @returns The first entry with a matching `gpuType`.
 * @throws Error when no entry matches.
 */
export function getGPUType(allTypes: GPUType[], typ: clusterconfigpb.GPUType): GPUType {
  for (const candidate of allTypes) {
    if (candidate.gpuType === typ) {
      return candidate;
    }
  }
  throw new Error(`GPU type ${typ} not found`);
}

/**
 * Snap `value` to the closest entry of `allowed`.
 *
 * Closeness is the absolute difference. On a tie, the candidate that appears first in
 * `allowed` wins, so for an ascending list the smaller candidate is chosen.
 *
 * @param value - The requested value.
 * @param allowed - Candidate values; their order only matters for tie-breaking.
 * @returns The closest candidate. If `allowed` is empty there is nothing to snap to
 *   and `value` is returned unchanged (previously this returned `undefined` despite
 *   the `number` return type).
 */
export function pickNearestValue(value: number, allowed: number[]): number {
  let nearest: number | undefined;
  let minDiff = Number.POSITIVE_INFINITY;
  for (const candidate of allowed) {
    const diff = Math.abs(value - candidate);
    // The first candidate is always taken (matching the original seeding from
    // allowed[0]); later ones only replace it when strictly closer.
    if (nearest === undefined || diff < minDiff) {
      nearest = candidate;
      minDiff = diff;
    }
  }
  return nearest ?? value;
}

/**
 * Convert analyzer.ServerInfoReply.node_pool into the GPU-type picker entries.
 *
 * The result always starts with an "Auto" entry (gpuType UNSPECIFIED, no job sizes).
 * It is followed by one entry per distinct pool `gpuType`, in the order the types first
 * appear. Pools that share a `gpuType` are merged: the first such pool's description
 * supplies the label and tooltip, and every pool contributes job sizes.
 *
 * Job sizes are counted in processing units:
 *  - H100 pools: every count from 1 to nGpuPerNode * maxNodesPerJob.
 *  - CPU pools: nCpuPerNode * k for k = 1..maxNodesPerJob.
 *  - Other pools: nGpuPerNode * k for k = 1..maxNodesPerJob.
 * Each entry's `puCount` is returned sorted ascending with duplicates removed.
 *
 * @param nodePools - Node pools reported by the backend.
 * @returns Picker entries, "Auto" first.
 */
export function collateGpuTypes(nodePools: clusterconfigpb.NodePool[]): GPUType[] {
  // Map keeps insertion order, so "Auto" stays first and the others follow in first-seen order.
  const byType = new Map<clusterconfigpb.GPUType, GPUType>();
  byType.set(clusterconfigpb.GPUType.UNSPECIFIED, {
    text: 'Auto',
    tooltip: 'Auto-select',
    gpuType: clusterconfigpb.GPUType.UNSPECIFIED,
    puCount: [],
  });

  nodePools.forEach((pool) => {
    const { description, gpuType, nCpuPerNode, nGpuPerNode, maxNodesPerJob } = pool;

    let entry = byType.get(gpuType);
    if (entry === undefined) {
      entry = { text: description, tooltip: description, gpuType, puCount: [] };
      byType.set(gpuType, entry);
    }

    if (gpuType === clusterconfigpb.GPUType.H100) {
      // H100 pools are reserved nodes. Any number of PUs may be selected, since the
      // number of PUs per node used by a pod is no longer constrained.
      const maxGpus = nGpuPerNode * maxNodesPerJob;
      for (let count = 1; count <= maxGpus; count += 1) {
        entry.puCount.push(count);
      }
      return;
    }

    // Other pools only offer whole-node multiples.
    const perNode = gpuType === clusterconfigpb.GPUType.CPU ? nCpuPerNode : nGpuPerNode;
    for (let nodes = 1; nodes <= maxNodesPerJob; nodes += 1) {
      entry.puCount.push(perNode * nodes);
    }
  });

  const entries = [...byType.values()];
  for (const entry of entries) {
    // Sort first, then dedupe: the Set keeps first-occurrence order, so the list stays ascending.
    entry.puCount = [...new Set(entry.puCount.sort((a, b) => a - b))];
  }
  return entries;
}
