


/** Base URL of the dataset API; every request below is built relative to it. */
const EP = 'https://ctlzr-api.bitgate.workers.dev';

/**
 * A dataset as handed to the rest of the app: like the API's wire format,
 * except that `parameters` has already been decoded from its JSON string.
 */
export interface Dataset {
    /** Dataset identifier as returned by the API. */
    id: string;
    /** Timestamp string from the API (parsed with `new Date` when sorting). */
    time: string;
    /** Decoded parameter object. */
    parameters: Record<string, any>;
}

/**
 * Dataset record exactly as the API serialises it; `parameters` is still a
 * JSON-encoded string at this point.
 */
interface RawDataset {
    id: string;
    time: string;
    /** JSON text; decode with `JSON.parse` to obtain the parameter object. */
    parameters: string;
}

/**
 * Fetch every dataset from the API, newest first.
 *
 * Each entry's `parameters` field arrives as a JSON-encoded string and is
 * decoded into an object here.
 *
 * @returns Datasets sorted by `time`, most recent first.
 * @throws Error when the API responds with a non-2xx status.
 */
export async function listDatasets(): Promise<Dataset[]> {
    const r = await fetch(`${EP}/api/v1/datasets`);
    // Fail with a clear message instead of letting an error body blow up
    // later inside `.map` or `r.json()`.
    if (!r.ok) {
        throw new Error(`listDatasets failed: ${r.status} ${r.statusText}`);
    }
    const entries: RawDataset[] = await r.json();

    return entries
        .map(e => ({
            id: e.id,
            time: e.time,
            parameters: JSON.parse(e.parameters),
        }))
        .sort((a, b) => new Date(b.time).getTime() - new Date(a.time).getTime());
}

/**
 * Fetch a single dataset by key.
 *
 * @param key Dataset identifier. It is URL-encoded before being placed in the
 *   path, so keys containing `/`, `?`, `#` or spaces address the right resource.
 * @returns The dataset with `parameters` decoded, or `null` when the API
 *   responds with 404.
 * @throws Error for any other non-2xx response.
 */
export async function getDataset(key: string): Promise<Dataset | null> {
    const r = await fetch(`${EP}/api/v1/datasets/${encodeURIComponent(key)}`);
    if (r.status === 404) {
        return null;
    }
    if (!r.ok) {
        throw new Error(`getDataset(${key}) failed: ${r.status} ${r.statusText}`);
    }
    const entry: RawDataset = await r.json();
    return {
        id: entry.id,
        time: entry.time,
        // The API transports parameters as a JSON-encoded string.
        parameters: JSON.parse(entry.parameters),
    };
}

/**
 * Create a new dataset on the API.
 *
 * @param key Identifier for the new dataset (sent in the JSON body, not the URL).
 * @param data JSON-serialisable payload sent alongside the key.
 * @returns The raw `Response`; callers are responsible for checking its status.
 */
export async function createDataset(key: string, data: object): Promise<Response> {
    const requestBody = JSON.stringify({ key, data });
    const url = `${EP}/api/v1/datasets`;
    return fetch(url, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: requestBody,
    });
}

/**
 * Delete a dataset by key.
 *
 * @param key Dataset identifier. It is URL-encoded before being placed in the
 *   path, so keys containing reserved characters target the right resource.
 * @returns The raw `Response`; callers are responsible for checking its status.
 */
export async function deleteDataset(key: string): Promise<Response> {
    return await fetch(`${EP}/api/v1/datasets/${encodeURIComponent(key)}`, {
        method: 'DELETE',
    });
}

/**
 * Replace the stored data of an existing dataset.
 *
 * @param key Dataset identifier. It is URL-encoded before being placed in the
 *   path, so keys containing reserved characters target the right resource.
 * @param data JSON-serialisable payload, sent as `{ data }` in the request body.
 * @returns The raw `Response`; callers are responsible for checking its status.
 */
export async function updateDataset(key: string, data: object): Promise<Response> {
    return await fetch(`${EP}/api/v1/datasets/${encodeURIComponent(key)}`, {
        method: 'PUT',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({ data }),
    });
}

// Expose listDatasets as a property of the global `window` object
// (presumably for manual use from the browser devtools console — confirm
// whether this is still needed in production builds).
Object.assign(window, { listDatasets });


/**
 * Argument names excluded from FILTERED_ARGS and collected into HIDDEN_ARGS
 * instead. Membership is tested with `Array.prototype.includes` when those
 * maps are built.
 */
export const HIDDEN = [
    'abort',
    'allow_tf32',
    'auto_loss_weights',
    'auto_loss_weights_last_n',
    'balance_losses',
    'barrier_breaking',
    'beta3',
    'compile_unet',
    'crop_square',
    'no_vae_cache',
    'use_predefined_original_size',
    'use_predefined_target_size',
    'multi_aspect_training',
    'batch_preprocessing',
    'classifier_zeroing',
    'crop_coords_multiple',
    'prescale_into_bucket',
    'zero_target_size',
    'new_crop_calc',
    'overwrite_original_size',
    'freeze_text_encoder_after',
    'inference_scheduler',
    'lora_fan_in_out',
    'lora_bias',
    'loss_mask_blur_strength',
    'loss_mask_empty_weight',
    'loss_mask_from_white',
    'loss_weighting_file',
    'lr_power',
    'lr_scheduler',
    'preserve_count',
    'p_preserve',
    'preservation_lr_multiplier',
    'resolution',
    'text_rank',
    'unet_as_eval',
    'use_8bit_adam',
    'inference_seed',
    'learning_rate_text',
    'logging_dir',
    'loss_clip',
    'lr_num_cycles',
    'merge_params',
    'mixed_precision',
    // 'mse_reduction',
    'num_validation_images',
    'output_dir',
    'param_exclude',
    'param_include',
    'patch_mlp',
    'pretrained_model_name_or_path',
    'pretrained_vae_model_name_or_path',
    // 'resolution', -- redundant: 'resolution' is already listed (active) earlier in this array
    'scale_lr',
    'seed',
    'timestep_weights_file',
    'train_text_encoder',
];


/**
 * Argument names treated as obsolete: they are excluded from FILTERED_ARGS
 * and collected into OBSOLETE_ARGS. (Presumably no longer honoured by the
 * training script — confirm before relying on that.)
 *
 * NOTE(review): 'autocaption' appears here AND in DATASET_KEYS, so it ends up
 * in both OBSOLETE_ARGS and DATASET_ARGS — confirm which grouping is intended.
 */
export const OBSOLETE = [
    'accumulate_text',
    'adam_weight_decay_text',
    'cast_vae',
    'center_crop',
    'checkpoints_total_limit',
    'class_data_dir',
    'class_prompt',
    'crops_coords_top_left_h',
    'crops_coords_top_left_w',
    'dataloader_num_workers',
    'enable_xformers_memory_efficient_attention',
    'autocaption',
    'force_input_grad',
    'gradient_checkpointing',
    'hflip',
    'hub_model_id',
    'hub_token',
    'infer_original_size_from_caption',
    'local_rank',
    'num_class_images',
    'num_train_epochs',
    'pre_offset',
    'prior_generation_precision',
    'prior_loss_weight',
    'push_to_hub',
    'report_to',
    'resume_from_checkpoint',
    'revision',
    'sample_batch_size',
    'validation_original_size',
    'with_prior_preservation'
];


/**
 * Argument names grouped as dataset settings (data directory, captions,
 * prompts). They are excluded from FILTERED_ARGS and collected into
 * DATASET_ARGS.
 */
export const DATASET_KEYS = [
    'autocaption',
    'caption_prefix',
    'caption_suffix',
    'captions_file',
    'dropout_caption',
    'instance_data_dir',
    'instance_prompt',
    'negative_validation_prompt',
    'nsfw',
    'validation_prompt'
];


/**
 * Default value for every training argument this module knows about.
 *
 * FILTERED_ARGS, OBSOLETE_ARGS, HIDDEN_ARGS and DATASET_ARGS are all derived
 * from this object via Object.entries, which preserves insertion order for
 * these string keys, so the order written here is the order of the entries
 * in those maps.
 *
 * NOTE(review): several defaults (captions_file, instance_data_dir,
 * instance_prompt, validation_prompt, output_dir, caption_prefix/suffix) look
 * like values from one specific project rather than neutral defaults —
 * confirm they are intended to ship as defaults.
 */
export const DEFAULT_ARGS = {
    abort: false,
    accumulate_text: false,
    adam_beta1: 0.9,
    adam_beta2: 0.99,
    adam_epsilon: 1e-8,
    adam_weight_decay: 0,
    adam_weight_decay_text: null,
    allow_tf32: true,
    alpha: 16,
    auto_loss_weights: true,
    auto_loss_weights_last_n: 1000,
    autocaption: false,
    balance_losses: false,
    barrier_breaking: false,
    beta3: 0.999499874937461,
    bucketing: true,
    caption_prefix: "pixel art: ",
    caption_suffix: " (with a grey background)",
    captions_file: "1024x1024/coa-bosses-again-south-gpt4.txt",
    cast_vae: false,
    center_crop: false,
    checkpointing_steps: 250,
    checkpoints_total_limit: null,
    class_data_dir: null,
    class_prompt: null,
    compel_no_truncate: false,
    compile_unet: false,
    compute_mean_std: false,
    crop_square: false,
    no_vae_cache: true,
    use_predefined_original_size: false,
    use_predefined_target_size: false,
    multi_aspect_training: true,
    batch_preprocessing: true,
    use_v2_transforms: false,
    classifier_zeroing: true,
    one_timestep_value: false,
    crop_coords_multiple: 8,
    prescale_into_bucket: true,
    zero_target_size: false,
    new_crop_calc: false,
    overwrite_original_size: true,
    crops_coords_top_left_h: 0,
    crops_coords_top_left_w: 0,
    dataloader_num_workers: 0,
    // String-encoded lists; ARG_TYPE_MAP declares both of these as "list".
    dataset_mean: "[0.5]",
    dataset_std: "[0.5]",
    denoiser_steps: 20,
    dropout_caption: "",
    dropout_set_zero: true,
    // NOTE(review): disabled here, but ARG_TYPE_MAP still has a type entry for it.
    // dropout_set_time_ids_zero: true,
    enable_xformers_memory_efficient_attention: false,
    force_input_grad: false,
    freeze_text_encoder_after: null,
    gradient_accumulation_steps: 1,
    gradient_checkpointing: false,
    guidance_scale: 7.5,
    hflip: 0,
    hub_model_id: null,
    hub_token: null,
    infer_original_size_from_caption: false,
    inference_height: 1024,
    inference_scheduler: "eulera",
    inference_seed: 979242106,
    inference_width: 1024,
    instance_data_dir: "1024x1024",
    // Multiple prompts are separated with "|".
    instance_prompt: "pixel art: A sprite asset showing a friendly blacksmith facing south.|pixel art: a game level with a beautiful lush forest and a bandit encampment",
    learning_rate: 0.000125,
    learning_rate_text: 0.0005,
    local_rank: -1,
    logging_dir: "logs",
    lora_dropout: 0.0,
    lora_init: true,
    lora_fan_in_out: false,
    lora_bias: "none", // "none", "all", "lora_only"
    lora_module_regex: "FINETUNE",
    loss_clip: 200000,
    loss_mask_blur_strength: 0,
    loss_mask_empty_weight: 0,
    loss_mask_from_white: false,
    loss_weighting_file: null,
    lr_multiplier: 1,
    lr_num_cycles: 1,
    lr_power: 1,
    lr_scheduler: "constant",
    lr_warmup_discard: false,
    lr_warmup_steps: 1000,
    max_grad_norm: 0,
    max_train_steps: 10000,
    merge_params: false,
    mixed_precision: "no",
    mse_reduction: "mean",
    negative_validation_prompt: "ugly",
    no_lora: false,
    noise_offset: 0.05,
    p_noise_offset: 1.0,
    noise_offset_exclude_dropout: false,
    noise_offset_randomize: false,
    noise_offset_decay: false,
    noise_offset_decay_inv: false,
    noise_offset_retain_target: false,
    noise_dims: 1,
    nonlinear_cutoff: 0,
    nonlinear_inverse: false,
    nonlinear_timesteps: false,
    nsfw: false,
    num_class_images: 100,
    num_train_epochs: 35,
    num_validation_images: 8,
    output_dir: "projects/pixels-randomized-coa2/24-07-19/11h33m04s",
    p_desaturate: 0.0,
    desaturate_min: 0.45,
    p_darken: 0.0,
    darken_min: 0.85,
    p_caption_dropout: 0.0,
    dropout_desaturation: 0.45,
    dropout_darken: 0.80,
    p_default_timestepping: 0.2,
    p_preserve: 0.25,
    preserve_count: 2,
    preservation_lr_multiplier: 1,
    param_exclude: null,
    param_include: null,
    patch_mlp: false,
    pre_offset: false,
    precompute_embeddings: true,
    precompute_latents: false,
    pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0",
    pretrained_vae_model_name_or_path: null,
    prior_generation_precision: null,
    prior_loss_weight: 1,
    prodigy: false,
    prodigy_bias_correction: false,
    prodigy_d0: 0.000001,
    prodigy_d_coef: 1,
    prodigy_max_growth_rate: null,
    prodigy_safeguard_warmup: true,
    push_to_hub: false,
    rank: 16,
    report_to: "tensorboard",
    resolution: 1024,
    resume_from_checkpoint: null,
    revision: null,
    sample_batch_size: 4,
    scale_lr: false,
    seed: 69069799,
    single_timestep_range: null,
    single_timestep_value: null,
    text_rank: 128,
    timestep_weights_file: null,
    train_batch_size: 4,
    train_text_encoder: false,
    unet_as_eval: false,
    use_8bit_adam: false,
    use_adagrad: false,
    use_compel: true,
    use_rslora: false,
    use_sgd: false,
    vae_no_sample: false,
    validation_epochs: 250,
    validation_original_size: null,
    // Same "|"-separated format as instance_prompt.
    validation_prompt: "pixel art: A sprite asset showing a friendly blacksmith facing south.|pixel art: a game level with a beautiful lush forest and a bandit encampment",
    with_prior_preservation: false
};

/** Sub-map of DEFAULT_ARGS holding only the keys accepted by `keep`, in DEFAULT_ARGS order. */
const pickDefaultArgs = (keep: (key: string) => boolean) =>
    Object.fromEntries(Object.entries(DEFAULT_ARGS).filter(([key]) => keep(key)));

/** Every default that is not obsolete, hidden, or dataset-specific. */
export const FILTERED_ARGS = pickDefaultArgs(
    key => !(OBSOLETE.includes(key) || HIDDEN.includes(key) || DATASET_KEYS.includes(key)),
);
/** Defaults whose key is listed in OBSOLETE. */
export const OBSOLETE_ARGS = pickDefaultArgs(key => OBSOLETE.includes(key));
/** Defaults whose key is listed in HIDDEN. */
export const HIDDEN_ARGS = pickDefaultArgs(key => HIDDEN.includes(key));
/** Defaults whose key is listed in DATASET_KEYS. */
export const DATASET_ARGS = pickDefaultArgs(key => DATASET_KEYS.includes(key));

/**
 * Type names an argument may be declared with. The values used in
 * ARG_TYPE_MAP are drawn from this list ('dict' is currently unused there).
 * NOTE(review): not referenced anywhere in the visible part of this file.
 */
const DATA_TYPES = [
    "int",
    "float",
    "str",
    "bool",
    "list",
    "dict",
];

// Declared type (one of the DATA_TYPES names) for each argument name,
// e.g. "adam_beta1": "float". Commented-out entries are currently untyped.
//
// NOTE(review): many DEFAULT_ARGS keys have no entry here at all (e.g.
// bucketing, compel_no_truncate, lora_dropout, lora_init, p_noise_offset,
// noise_dims, p_desaturate, p_preserve) — confirm how consumers of this map
// treat arguments without a declared type.

export const ARG_TYPE_MAP = {
    // "abort": "bool",
    "accumulate_text": "bool",
    "adam_beta1": "float",
    "adam_beta2": "float",
    "adam_epsilon": "float",
    "adam_weight_decay": "float",
    "adam_weight_decay_text": "float",
    "allow_tf32": "bool",
    "alpha": "int",
    "auto_loss_weights": "bool",
    "auto_loss_weights_last_n": "int",
    "autocaption": "bool",
    "balance_losses": "bool",
    "barrier_breaking": "bool",
    "beta3": "float",
    "caption_prefix": "str",
    "caption_suffix": "str",
    "captions_file": "str",
    "cast_vae": "bool",
    "center_crop": "bool",
    "checkpointing_steps": "int",
    "checkpoints_total_limit": "int",
    // "class_data_dir": "str",
    // "class_prompt": "str",
    "compile_unet": "bool",
    "compute_mean_std": "bool",
    // "crops_coords_top_left_h": "int",
    // "crops_coords_top_left_w": "int",
    // "dataloader_num_workers": "int",
    "dataset_mean": "list",
    "dataset_std": "list",
    "denoiser_steps": "int",
    "dropout_caption": "str",
    // NOTE(review): the matching DEFAULT_ARGS entry is commented out.
    "dropout_set_time_ids_zero": "bool",
    "dropout_set_zero": "bool",
    // "enable_xformers_memory_efficient_attention": "bool",
    "force_input_grad": "bool",
    "freeze_text_encoder_after": "int",
    "gradient_accumulation_steps": "int",
    "gradient_checkpointing": "bool",
    "guidance_scale": "float",
    "hflip": "float",
    // "hub_model_id": "str",
    // "hub_token": "str",
    "infer_original_size_from_caption": "bool",
    "inference_height": "int",
    "inference_scheduler": "str",
    "inference_seed": "int",
    "inference_width": "int",
    "instance_data_dir": "str",
    "instance_prompt": "str",
    "learning_rate": "float",
    "learning_rate_text": "float",
    "local_rank": "int",
    // "logging_dir": "str",
    "lora_module_regex": "str",
    "loss_clip": "float",
    "loss_mask_blur_strength": "float",
    "loss_mask_empty_weight": "float",
    "loss_mask_from_white": "bool",
    "loss_weighting_file": "str",
    "lr_multiplier": "float",
    "lr_num_cycles": "int",
    "lr_power": "float",
    "lr_scheduler": "str",
    "lr_warmup_discard": "bool",
    "lr_warmup_steps": "int",
    "max_grad_norm": "int",
    "max_train_steps": "int",
    "merge_params": "bool",
    // "mixed_precision": "str",
    "mse_reduction": "str",
    "negative_validation_prompt": "str",
    "no_lora": "bool",
    "noise_offset": "float",
    "nonlinear_cutoff": "int",
    "nonlinear_inverse": "bool",
    "nonlinear_timesteps": "bool",
    "nsfw": "bool",
    // "num_class_images": "int",
    // "num_train_epochs": "int",
    "num_validation_images": "int",
    // "output_dir": "str",
    "p_caption_dropout": "float",
    "p_default_timestepping": "float",
    "param_exclude": "str",
    "param_include": "str",
    "patch_mlp": "bool",
    "pre_offset": "bool",
    "precompute_embeddings": "bool",
    "precompute_latents": "bool",
    "pretrained_model_name_or_path": "str",
    "pretrained_vae_model_name_or_path": "str",
    // "prior_generation_precision": "str",
    // "prior_loss_weight": "float",
    "prodigy": "bool",
    "prodigy_bias_correction": "bool",
    "prodigy_d0": "float",
    "prodigy_d_coef": "int",
    "prodigy_max_growth_rate": "int",
    "prodigy_safeguard_warmup": "bool",
    // "push_to_hub": "bool",
    "rank": "int",
    // "report_to": "str",
    // "resolution": "int",
    // "resume_from_checkpoint": "str",
    // "revision": "str",
    "sample_batch_size": "int",
    "scale_lr": "bool",
    "seed": "int",
    "single_timestep_range": "str",
    "single_timestep_value": "str",
    "text_rank": "int",
    "timestep_weights_file": "str",
    "train_batch_size": "int",
    "train_text_encoder": "bool",
    "unet_as_eval": "bool",
    "use_8bit_adam": "bool",
    "use_adagrad": "bool",
    "use_compel": "bool",
    "use_rslora": "bool",
    "use_sgd": "bool",
    "vae_no_sample": "bool",
    "validation_epochs": "int",
    "validation_original_size": "str",
    "validation_prompt": "str",
    // "with_prior_preservation": "bool"
};

/**
 * Default list of code files, each given as a target filename plus the URL
 * its contents come from (presumably fetched and placed alongside the
 * training job under `filename` — confirm against the consumer).
 *
 * NOTE(review): two URLs point at gist.githubusercontent.com directly, while
 * the others use gist.github.com/.../raw and presumably rely on GitHub's
 * redirect to the raw host — consider normalising them.
 */
export const DEFAULT_CODE = [
    {
        filename: "train-lora.py",
        url: "https://gist.githubusercontent.com/aristotaloss/afaaf43dfefde86344ab482df6ab279d/raw"
    },
    {
        filename: "monitoring.py",
        url: "https://gist.github.com/aristotaloss/d3e05c6e05ddadb5bfeedcb74bb37421/raw"
    },
    {
        filename: "decelerate.py",
        url: "https://gist.githubusercontent.com/aristotaloss/19e590b0868f928594b2b828fdf15ff2/raw"
    },
    {
        filename: "loss_weight.py",
        url: "https://gist.github.com/aristotaloss/a3b4ea87aefc6c79af3de72910117fda/raw"
    },
    {
        filename: "loss-weights.json",
        url: "https://gist.github.com/aristotaloss/a009fe86c1988f80eca8fa7ed9660ffa/raw"
    }
];