Zangwei Zheng, zangwei@u.nus.edu
National University of Singapore
ICLR 2024 Oral
Other versions: [arXiv] [Code] [Chinese]
TL;DR
Training for many epochs wastes time on easy, well‑learned samples. InfoBatch speeds up training by dynamically pruning such samples and rescaling the loss of the remaining ones to preserve performance. It delivers 20–40% faster training on image classification, semantic segmentation, vision pretraining, diffusion models, and LLM instruction fine‑tuning, without losing accuracy.
How does InfoBatch work?
We provide a plug‑and‑play PyTorch implementation of InfoBatch (under active development). Three changes are enough to plug InfoBatch into your training code.
Here is a brief overview of the InfoBatch algorithm.
- First, InfoBatch randomly drops a fraction of the samples whose loss is below the average loss, where each sample's most recently recorded loss is averaged over the whole dataset. The paper discusses more advanced strategies, but this simple rule already works very well.
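- Second, for the below‑average‑loss samples that remain, InfoBatch rescales their loss by $1/(1-r)$, where $r$ is the fraction dropped in the first step, so that their expected contribution matches full‑data training and overall training stays unbiased.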
- Third, at the end of training, InfoBatch runs through all samples once to mitigate forgetting.
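The first two steps can be sketched as a single pruning round. This is a minimal pure‑Python sketch, not the InfoBatch code: `prune_round` and its arguments are hypothetical names, `scores` is assumed to hold each sample's most recently recorded loss, the mean is taken over all samples, and the drop fraction is written `prune_ratio` (the $r$ above). Each below‑mean sample survives with probability about $1-r$ and is then weighted by $1/(1-r)$, so its expected contribution to the loss is unchanged.

```python
import random
from typing import List, Sequence, Tuple


def prune_round(
    scores: Sequence[float],
    prune_ratio: float,
    rng: random.Random,
) -> Tuple[List[int], List[float]]:
    """Pick the indices to train on this epoch and their loss weights.

    scores:      latest recorded loss of every sample (assumed bookkeeping).
    prune_ratio: fraction r of below-mean samples to drop, 0 <= r < 1.
    Returns (kept_indices, weights); weights[i] multiplies sample i's loss
    (entries for dropped samples are left at 1.0 and never used).
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    mean = sum(scores) / len(scores)
    weights = [1.0] * len(scores)
    # Samples at or above the mean loss are always kept with weight 1.
    kept = [i for i, s in enumerate(scores) if s >= mean]
    # Well-learned samples: loss strictly below the mean.
    easy = [i for i, s in enumerate(scores) if s < mean]
    # Keep a random (1 - r) share of them (rounded to a whole count).
    n_keep = round((1.0 - prune_ratio) * len(easy))
    survivors = rng.sample(easy, n_keep)
    # Upweight survivors by 1 / (1 - r); exactly unbiased when the
    # kept count needs no rounding.
    for i in survivors:
        weights[i] = 1.0 / (1.0 - prune_ratio)
    kept.extend(survivors)
    rng.shuffle(kept)
    return kept, weights
```

For example, with scores `[0.1, 0.2, 0.3, 0.4, 1.0, 2.0]` and `prune_ratio=0.5`, the two above‑mean samples are always kept, two of the four below‑mean samples survive with weight 2.0, and four indices are returned.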
The hyperparameter controls the fraction of epochs that perform on‑the‑fly pruning. A good starting point is .
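Read together with the third step, one way to implement this schedule is to prune only during an initial share of the epochs and train on every sample afterwards. The sketch below shows that schedule plus a back‑of‑the‑envelope estimate of the remaining sample visits. `delta`, `pruning_active`, and `expected_cost_fraction` are hypothetical names; the estimate assumes a fixed fraction of samples falls below the mean in every pruned epoch, which is a simplification, and it counts sample visits rather than wall‑clock time. The example values are illustrative only, not the recommended setting.

```python
def pruning_active(epoch: int, num_epochs: int, delta: float) -> bool:
    """True if on-the-fly pruning runs in this 0-based epoch.

    Pruning covers the first int(delta * num_epochs) epochs; the remaining
    epochs use the full dataset (the final pass that mitigates forgetting).
    """
    return epoch < int(delta * num_epochs)


def expected_cost_fraction(num_epochs: int, delta: float,
                           prune_ratio: float, frac_below_mean: float) -> float:
    """Expected sample visits relative to full-data training.

    Simplifying assumption: in every pruned epoch a fixed share
    `frac_below_mean` of samples is below the mean loss, and a
    `prune_ratio` share of those is skipped.
    """
    pruned_epochs = sum(pruning_active(e, num_epochs, delta) for e in range(num_epochs))
    pruned_epoch_cost = 1.0 - prune_ratio * frac_below_mean
    full_epochs = num_epochs - pruned_epochs
    return (pruned_epochs * pruned_epoch_cost + full_epochs) / num_epochs
```

With 8 epochs, `delta=0.75`, `prune_ratio=0.5`, and half the samples below the mean, `expected_cost_fraction(8, 0.75, 0.5, 0.5)` returns 0.8125: six pruned epochs at 75% of the data plus two full epochs.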
The three changes are: (1) wrap the dataset with InfoBatch, which manages the index order; (2) pass the InfoBatch sampler to the DataLoader constructor; and (3) between the forward and backward passes, rescale the loss and update the sampler with the per‑sample losses. For more mathematical discussion and ablations, see the paper. For parallel training, see the code.
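The three changes can be illustrated end to end with a toy, dependency‑free loop. This is a hedged sketch, not the InfoBatch API: `PruningSampler`, `set_epoch`, `update`, and `train` are hypothetical names, the "model" is a single scalar fitted by squared error, and the schedule and prune fraction are illustrative values. In a real PyTorch setup the sampler would be handed to `torch.utils.data.DataLoader` through its `sampler` argument, and the rescaled loss would be backpropagated.

```python
import random
from typing import Dict, Iterator, List, Sequence, Tuple


class PruningSampler:
    """Hypothetical stand-in for an InfoBatch-style wrapper/sampler.

    Not the real InfoBatch API: it only mirrors the three roles described
    in the text (manage index order, feed the loader, take per-sample losses).
    """

    def __init__(self, num_samples: int, num_epochs: int,
                 prune_ratio: float = 0.5, delta: float = 0.75, seed: int = 0):
        self.scores = [1.0] * num_samples  # latest recorded loss per sample
        self.num_epochs = num_epochs
        self.prune_ratio = prune_ratio
        self.delta = delta  # share of epochs with pruning (assumed name)
        self.rng = random.Random(seed)
        self.epoch = 0
        self.weights: Dict[int, float] = {}  # loss weight per index this epoch

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        n = len(self.scores)
        self.weights = {i: 1.0 for i in range(n)}
        if self.epoch < int(self.delta * self.num_epochs):
            mean = sum(self.scores) / n
            easy = [i for i in range(n) if self.scores[i] < mean]
            # Drop a random prune_ratio share of the below-mean samples...
            for i in self.rng.sample(easy, round(self.prune_ratio * len(easy))):
                del self.weights[i]
            # ...and upweight the surviving below-mean samples by 1 / (1 - r).
            for i in easy:
                if i in self.weights:
                    self.weights[i] = 1.0 / (1.0 - self.prune_ratio)
        order = list(self.weights)
        self.rng.shuffle(order)
        return iter(order)

    def update(self, indices: Sequence[int], losses: Sequence[float]) -> List[float]:
        """Record fresh per-sample losses and return their rescale factors."""
        for i, loss in zip(indices, losses):
            self.scores[i] = loss
        return [self.weights[i] for i in indices]


def train(data: Sequence[float], num_epochs: int = 8,
          batch_size: int = 4, lr: float = 0.1) -> Tuple[float, int]:
    """Fit a scalar w to `data` with per-sample loss (x - w) ** 2 (toy model)."""
    sampler = PruningSampler(len(data), num_epochs)  # change (1): wrap the data
    w, visited = 0.0, 0
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        order = list(sampler)  # change (2): the loader draws indices from the sampler
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            losses = [(data[i] - w) ** 2 for i in batch]  # forward: per-sample losses
            weights = sampler.update(batch, losses)  # change (3): record and rescale
            # Gradient of the weighted mean loss sum(wt * loss) / len(batch) w.r.t. w;
            # in PyTorch this is where backward() on the rescaled loss would go.
            grad = sum(wt * -2.0 * (data[i] - w) for wt, i in zip(weights, batch)) / len(batch)
            w -= lr * grad
            visited += len(batch)
    return w, visited
```

Running `train([float(i % 5) for i in range(40)])` visits fewer than the 320 samples that eight full epochs would take, because below‑mean samples are skipped in epochs 1 to 5. Epoch 0 is not pruned since no losses have been recorded yet (all scores start equal), and the last two epochs use the full data.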
Applications
The idea behind InfoBatch is simple but effective across many applications.
- Image classification: 40% speedup with no accuracy drop, unlike prior methods.
- MAE pretraining: 20% time saved for ViT and Swin, with no downstream accuracy loss.
- Semantic segmentation: 40% time saved with no mIoU degradation.
- Diffusion models: 27% time saved with comparable FID.
- LLM instruction fine‑tuning: 20% time saved.