This repository contains the code for the paper "OATS: Online Data Augmentation for Time Series Foundation Models".
OATS introduces a novel online data augmentation framework designed to enhance the training of time series foundation models (TSFMs). Unlike traditional offline augmentation methods, which pre-generate synthetic data, OATS generates synthetic data during training, using training samples with high data attribution scores as guiding signals.
OATS consists of three key components:
- Time-series Influence Scores (TSIS) integrate data attribution with time series-specific knowledge to dynamically assess the quality of each training sample, creating a generation guiding signal.
- High-quality Guided Data Augmentation leverages the guiding signal to condition a diffusion model trained on a small subset of the TSFM training data for synthetic data generation.
- Explore-Exploit Mechanism reduces computational overhead by balancing the reuse of previously calculated scores against the scoring of new samples. The influence scores are stochastically re-evaluated to incorporate model training dynamics ("explore") while preserving previously identified high-quality data ("exploit").
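The explore-exploit idea described above can be illustrated with a minimal Python sketch. This is not the repository's implementation: `update_scores`, `select_guides`, `score_fn`, `explore_prob`, and `top_ratio` are hypothetical names chosen for illustration, and the scoring function is left abstract because the exact TSIS computation is not described here. The sketch only shows how cached scores might be reused most of the time and re-evaluated stochastically.

```python
import random
from typing import Callable, Dict, Hashable, Iterable, List


def update_scores(
    sample_ids: Iterable[Hashable],
    cached_scores: Dict[Hashable, float],
    score_fn: Callable[[Hashable], float],
    explore_prob: float,
    rng: random.Random,
) -> int:
    """Refresh cached scores in place and return the number of score_fn calls.

    score_fn stands in for an (expensive) influence score computation;
    explore_prob is an assumed hyperparameter, not a documented OATS option.
    """
    n_evaluated = 0
    for sid in sample_ids:
        # Explore: unseen samples are always scored, and seen samples are
        # re-scored with probability explore_prob to follow training dynamics.
        if sid not in cached_scores or rng.random() < explore_prob:
            cached_scores[sid] = score_fn(sid)
            n_evaluated += 1
        # Exploit: otherwise the cached score is reused at no extra cost.
    return n_evaluated


def select_guides(cached_scores: Dict[Hashable, float], top_ratio: float) -> List[Hashable]:
    """Return the ids of the highest-scoring samples, used as guiding signals."""
    k = max(1, int(len(cached_scores) * top_ratio))
    ranked = sorted(cached_scores, key=cached_scores.get, reverse=True)
    return ranked[:k]
```

With `explore_prob=0.0` every known sample reuses its cached score, and with `explore_prob=1.0` every sample is re-scored on each pass; values in between trade scoring cost against score freshness.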
Download the dataset for TSFM training from here. The directory structure is as follows:
- dataset_train
|- Lotsa16B
|- Lotsa1B
|- Lotsa100M
|- Lotsa10M
- dataset_test
|- Lotsa16B
|- Lotsa1B
|- Lotsa100M
|- Lotsa10M
|- LSF
|- Monash
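Before launching training, it may help to confirm that the downloaded data matches the layout above. The following is a small sketch under two assumptions: the tree is read literally (every `|-` entry is a direct subdirectory of the preceding top-level folder), and the dataset root is whatever path you pass in. `EXPECTED_LAYOUT` and `missing_dirs` are hypothetical helpers, not part of the repository.

```python
from pathlib import Path
from typing import Dict, List

# Layout transcribed from the directory tree above (assumed to be literal).
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "dataset_train": ["Lotsa16B", "Lotsa1B", "Lotsa100M", "Lotsa10M"],
    "dataset_test": ["Lotsa16B", "Lotsa1B", "Lotsa100M", "Lotsa10M", "LSF", "Monash"],
}


def missing_dirs(root: Path, layout: Dict[str, List[str]] = EXPECTED_LAYOUT) -> List[str]:
    """Return the expected subdirectories that do not exist under root."""
    missing = []
    for top, subs in layout.items():
        for sub in subs:
            rel = Path(top) / sub
            if not (root / rel).is_dir():
                missing.append(rel.as_posix())
    return missing
```

For example, `missing_dirs(Path("."))` returns an empty list when all ten directories are present under the current directory, and otherwise lists the relative paths that are absent.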
Extract the dataset for the diffusion model. The data is drawn from the Lotsa100M dataset by sampling 5% of the data in 20 selected subdatasets.
python extract_data_generation.py -cp cli/conf/pretrain \
  -cn default_ddp_val_enc \
  model=encoder_10M \
  model.enable_influence_scoring=true \
  data=lotsa100M_weighted \
  trainer.max_epochs=0 \
  model.num_warmup_steps=0
The extracted files are organized as follows:
extracted_label_patches_australian_electricity_demand.npy
extracted_label_patches_azure_vm_traces_2017.npy
extracted_label_patches_buildings_900k.npy
extracted_label_patches_CloudOpsTSF_dataset.npy
extracted_label_patches_CMIP6_dataset.npy
...
# Clone the repository
git clone https://github.com/microsoft/TimeCraft.git
cd TimeCraft/OATS
# Create and activate conda environment
conda env create -f environment.yml
conda activate oats
Step 1. Train a time series generation model on the extracted sampled data.
cd models/gen_model
python main_train.py --base configs/multi_domain_timedp_local.yaml --gpus 0, --logdir ./logs/ -sl 320 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0
Step 2. Train the time series foundation model.
python -m cli.train_val \
  -cp conf/pretrain \
  -cn default_ddp_val_enc \
  model=encoder \
  model.enable_influence_scoring=true \
  data=lotsa100M_weighted \
  val_data=all \
  trainer.logger.project=TSFM_PRETRAIN \
  run_name=encoder10M_etth1_develop \
  model.generate_after_epoch=0 \
  model.influence_filter_ratio=1.0 \
  model.select_from_generated=false
Outputs: The results can be found in the wandb log and in ./outputs/pretrain/
