Skip to content

Commit e0ff3e9

Browse files
committed
Merge branch 'main' of github.com:wgrathwohl/VERA into main
2 parents c29c3a8 + 45d7932 commit e0ff3e9

1 file changed

Lines changed: 16 additions & 7 deletions

File tree

README.md

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Code for the paper:
1010
<img align="middle" src="./assets/fig1.png" width="500" />
1111
</p>
1212

13-
Code for implementing **V**ariational **E**ntropy **R**egularized **A**pproximate maximum likelihood (VERA). Contains scripts for training VERA and using VERA for [JEM](https://github.com/wgrathwohl/JEM) training. Code is also available for training semi-supervised models on tabular data, mode counting experiments, and tractable likelihood models.
13+
Code for implementing **V**ariational **E**ntropy **R**egularized **A**pproximate maximum likelihood (VERA). Contains scripts for training VERA and using VERA for [JEM](https://github.com/wgrathwohl/JEM) training. Code is also available for training semi-supervised models on tabular data, mode counting experiments, and tractable likelihood models experiments.
1414

1515
For more info on me and my work please checkout my [website](http://www.cs.toronto.edu/~wgrathwohl/), [twitter](https://twitter.com/wgrathwohl), or [Google Scholar](https://scholar.google.ca/citations?user=ZbClz98AAAAJ&hl=en).
1616

@@ -33,20 +33,29 @@ tqdm
3333
### Hyperparameters
3434

3535
A brief explanation of hyperparameters that can be set from flags and their names in the paper.
36-
- `--clf_weight` Classification weight (`\alpha` in the paper)
37-
- `--pg_control` Gradient norm penalty (`\gamma` in the paper)
38-
- `--ent_weight` Entropy regularization weight (`\lambda` in the paper)
39-
- `--clf_ent_weight` Classification entropy (`\beta` in the paper)
36+
- `--clf_weight` Classification weight (`\alpha`)
37+
- `--pg_control` Gradient norm penalty (`\gamma`)
38+
- `--ent_weight` Entropy regularization weight (`\lambda`)
39+
- `--clf_ent_weight` Classification entropy (`\beta`)
4040

4141
### Training
4242

43-
An explanation of flags for different modes of training
43+
An explanation of flags for different modes of training. Without any of these flags, an unsupervised VERA model will be trained.
4444

4545
- `--clf_only` For training a classifier on its own, i.e. without an EBM as in JEM.
4646
- `--jem` Do JEM training.
4747
- `--labels_per_class` If this is greater than zero, use this many labels per class for semi-supervised learning. If zero (default), do full-label training.
4848

49-
For example, to train a CIFAR10 JEM model: # TODO
49+
To train a CIFAR10/CIFAR100 JEM model as in the paper, run:
50+
51+
```markdown
52+
python train.py --dataset DATASET # cifar10 or cifar100
53+
--ent_weight 0.0001 --noise_dim 128 \
54+
--viz_every 1000 --save_dir /YOUR/SAVE/DIR --data_aug --dropout .3 --thicc_resnet \
55+
--ckpt_path /PATH/TO/YOUR/MODEL.pt --generator_type vera --n_epochs 200 --print_every 100 \
56+
--lr .00003 --glr .00006 --post_lr .00003 --batch_size 40 --pg_control .1 \
57+
--decay_epochs 150 175 --jem --warmup_iters 2500 --clf_weight 100. --g_feats 256
58+
```
5059

5160
### Evaluation
5261

0 commit comments

Comments
 (0)