import { useEffect } from "react";

/**
 * React hook that calls a function repeatedly at a fixed interval. The interval is cleared
 * when the effect is cleaned up, i.e. when the component unmounts or a dependency changes.
 *
 * Note: `timer` itself is not part of the effect's dependency list. The running interval keeps
 * calling the `timer` it was started with until `skipInitial`, `interval` or one of `deps`
 * changes, so list whatever `timer` depends on in `deps`.
 *
 * @param {TimerHandler} timer The function to be called once per interval.
 * @param {number} interval The interval in milliseconds between calls.
 * @param {boolean | undefined} skipInitial If true, the first call happens only after one interval
 * has elapsed; otherwise the timer is also called immediately when the effect runs.
 * @param {Array<*>} [deps=[]] Extra effect dependencies; a change restarts the interval.
 */
function useInterval(timer, interval, skipInitial, deps = []) {
    useEffect(() => {
        // Only the immediate call is optional; the periodic calls must always happen.
        if (!skipInitial) {
            timer();
        }

        const key = setInterval(() => timer(), interval);
        return () => clearInterval(key);
    }, [skipInitial, interval, ...deps]);
}

// Maps a key of the incoming data to the key under which prepareArgsForRender() stores its value.
const RENAME = { 'text_rank': 'rank_text' };

// List of parameter names to ignore.
// NOTE(review): no active code in the visible part of this file reads this list and it is not
// part of the export list — confirm whether it is still needed.
const IGNORED = ['alpha', 'cast_vae', 'crops_coords_top_left_h', 'crops_coords_top_left_w',
    'instance_data_dir', 'local_rank', 'logging_dir', 'negative_validation_prompt', 'num_class_images',
    'num_validation_images', 'output_dir', 'pre_offset', 'pretrained_model_name_or_path',
    'pretrained_vae_model_name_or_path', 'prior_generation_precision', 'prior_loss_weight', 'push_to_hub',
    'report_to', 'resolution', 'resume_from_checkpoint', 'revision', 'train_text_encoder',
    'use_8bit_adam', 'with_prior_preservation', 'validation_epochs', 'enable_xformers_memory_efficient_attention',
    'dataloader_num_workers', 'allow_tf32', 'autocaption', 'center_crop', 'checkpointing_steps',
    'checkpoints_total_limit', 'class_data_dir', 'class_prompt'];

/**
 * Returns a new object holding the same own enumerable entries as `args`, inserted in the
 * order produced by the default `Array.prototype.sort()` on the keys. The copy is shallow and
 * no key is filtered out; `args` itself is not modified.
 *
 * @param {Object} args The object whose keys should be ordered.
 * @returns {Object} A new object with the keys inserted in sorted order.
 */
const sortArgs = (args) => {
    const orderedKeys = Object.keys(args).sort();
    const result = {};

    for (const name of orderedKeys) {
        result[name] = args[name];
    }

    return result;
};


/**
 * Returns a deep copy of `obj` in which the keys of every nested object are inserted in the
 * order produced by the default `Array.prototype.sort()`.
 *
 * - Primitives, `null` and `undefined` are returned unchanged.
 * - Arrays stay arrays: element order is kept and each element is processed recursively.
 * - Every other object is rebuilt as a plain object from its own enumerable keys.
 *
 * @param {*} obj The value to normalize.
 * @returns {*} The normalized copy (or `obj` itself when it is not an object).
 */
const recursiveObjectSort = (obj) => {
    // typeof null === 'object', so null has to be excluded explicitly.
    if (obj === null || typeof obj !== 'object') {
        return obj;
    }

    // Arrays must not go through Object.keys(), which would turn them into index-keyed objects.
    if (Array.isArray(obj)) {
        return obj.map((item) => recursiveObjectSort(item));
    }

    return Object.keys(obj)
        .sort()
        .reduce((sorted, key) => {
            sorted[key] = recursiveObjectSort(obj[key]);
            return sorted;
        }, {});
};

/**
 * Builds a new object from `data` for rendering:
 *
 * - keys listed in RENAME are stored under their renamed key;
 * - `learning_rate_text` becomes the string "no" when `data.train_text_encoder` is falsy;
 * - otherwise `learning_rate` and `learning_rate_text` are parsed as floats and formatted in
 *   exponential notation with one fraction digit (e.g. "1.0e-4");
 * - every other value is copied as-is.
 *
 * @param {Object} data The raw arguments.
 * @returns {Object} A new object with the render-ready values; `data` is not modified.
 */
function prepareArgsForRender(data) {
    const args = {};

    Object.keys(data).forEach((k) => {
        const kdest = Object.prototype.hasOwnProperty.call(RENAME, k) ? RENAME[k] : k;

        // Read the flag from `data`, not from the partially built `args`, so the result does
        // not depend on the order in which the keys of `data` are visited.
        if (kdest === 'learning_rate_text' && !data['train_text_encoder']) {
            args[kdest] = "no";
            return;
        }

        if (kdest === 'learning_rate' || kdest === 'learning_rate_text') {
            args[kdest] = Number.parseFloat(data[k]).toExponential(1);
            return;
        }

        args[kdest] = data[k];
    });

    return args;
}



// Exported list of parameter keys.
// NOTE(review): presumably the parameters considered relevant for display/comparison — confirm
// against the modules that import it.
// NOTE(review): this list contains 'text_rank', while prepareArgsForRender() stores that value
// under 'rank_text' (see RENAME) — confirm which key the consumers expect.
const RELEVANT = ['accumulate_text', 'adam_weight_decay', 'gradient_accumulation_steps', 'learning_rate',
    'learning_rate_text', 'lr_warmup_steps', 'merge_params', 'mixed_precision', 'mse_reduction', 'no_lora',
    'nonlinear_timesteps', 'num_train_epochs', 'patch_mlp', 'rank', 'text_rank', 'train_batch_size',
    'nonlinear_cutoff', 'adam_beta1', 'adam_beta2', 'adam_epsilon', 'prodigy_d0', 'prodigy_max_growth_rate',
    'prodigy_safeguard_warmup', 'prodigy_bias_correction', 'single_timestep_value', 'prodigy_d_coef', 'single_timestep_range'];


// Short display labels, keyed by parameter name.
const FRIENDLY_NAMES = {
    'accumulate_text': 'Text accum.',
    'adam_weight_decay': 'Weight decay',
    'gradient_accumulation_steps': 'Grad. steps',
    'learning_rate': 'LR',
    'learning_rate_text': 'Text LR',
    'lr_warmup_steps': 'Warmup',
    'merge_params': 'Merge params',
    'mixed_precision': 'Float type',
    'mse_reduction': 'MSE mode',
    'no_lora': 'Disable LoRA',
    'nonlinear_timesteps': 'Nonlinear steps',
    'num_train_epochs': 'Epochs',
    'patch_mlp': 'Patch MLP',
    'rank': 'Rank',
    'text_rank': 'Text rank',
    'train_batch_size': 'Batch Size',
};

/**
 * Returns the display label for a parameter key, or the key itself when no label is defined.
 *
 * Only own entries of FRIENDLY_NAMES are considered, so keys that happen to match inherited
 * Object.prototype members (e.g. 'constructor', 'toString') fall back to the key as well.
 *
 * @param {string} k The parameter key.
 * @returns {string} The friendly label, or `k` when none is defined.
 */
const friendlyName = (k) =>
    (Object.prototype.hasOwnProperty.call(FRIENDLY_NAMES, k) ? FRIENDLY_NAMES[k] : k);


/**
 * Reports the direction in which parameter `p` changed.
 *
 * Looks up the first entry of `changes` whose `param` equals `p` and compares its `from` and
 * `to` values. Only entries whose `from` is a number are compared.
 *
 * @param {Array<{param: string, from: *, to: *}>} changes The list of recorded changes.
 * @param {string} p The parameter name to look up.
 * @returns {number} 1 if the value increased, -1 if it decreased, and 0 if it stayed the same,
 * is not numeric, or has no entry in `changes`.
 */
const paramDelta = (changes, p) => {
    const change = changes.find((c) => c.param === p);

    // No entry, or nothing that can be compared numerically.
    if (!change || typeof change.from !== 'number') {
        return 0;
    }

    if (change.to > change.from) {
        return 1;
    }
    if (change.to < change.from) {
        return -1;
    }

    // Equal (or not comparable) values count as "unchanged".
    return 0;
};

/**
 * Translates the result of paramDelta() for parameter `p` into a class name.
 *
 * @param {Array} changes The list of recorded changes, passed through to paramDelta().
 * @param {string} p The parameter name to look up.
 * @returns {string} 'inc' when paramDelta() yields 1, 'dec' when it yields -1, '' otherwise.
 */
const paramClass = (changes, p) => {
    switch (paramDelta(changes, p)) {
        case 1:
            return 'inc';
        case -1:
            return 'dec';
        default:
            return '';
    }
};

export {
    useInterval,
    sortArgs,
    prepareArgsForRender,
    friendlyName,
    paramDelta, paramClass,
    recursiveObjectSort,
    RELEVANT
}