Knowledge Vault 2/19 - ICLR 2014-2023
Jorg Bornschein and Yoshua Bengio ICLR 2015 - Reweighted Wake-Sleep

Concept Graph & Resume using Claude 3 Opus | Chat GPT4 | Gemini Adv | Llama 3:

graph LR
  classDef helmholtz fill:#f9d4d4, font-weight:bold, font-size:14px;
  classDef inference fill:#d4f9d4, font-weight:bold, font-size:14px;
  classDef sampling fill:#d4d4f9, font-weight:bold, font-size:14px;
  classDef training fill:#f9f9d4, font-weight:bold, font-size:14px;
  classDef results fill:#f9d4f9, font-weight:bold, font-size:14px;
  A[Bornschein, Bengio ICLR 2015] --> B[Helmholtz machines: directed generative, intractable inference. 1]
  B --> C[Inference network Q approximates latent variables. 2]
  B --> D[Recent Helmholtz models: VAEs, DARN, etc. 3]
  D --> E[Joint objective, inference and generative networks. 4]
  E --> F[Naive Q sampling: high variance estimators. 5]
  A --> G[Importance sampling, not variational approximations. 6]
  G --> H[Q as proposal distribution for trained models. 7]
  G --> I[Unbiased likelihood estimator, Q-dependent variance. 8]
  I --> J[Minimal variance: Q approximates P H X . 9]
  G --> K[Importance sampling for model parameter updates. 10]
  K --> L[P gradient: K Q samples, importance weights. 11]
  L --> M[Weighted gradient average, no P backpropagation. 12]
  L --> N[Layer-wise targets, normalized importance weights. 13]
  G --> O[No clear Q training signal. 14]
  O --> P[Q proposes, affects only estimator variance. 15]
  P --> Q[Train Q to minimize estimator variance. 16]
  Q --> R[Q trains on real wake or model data sleep. 17]
  G --> S[Wake and sleep phase gradient updates. 18]
  S --> T[Wake: Q gradients like P's estimator. 19]
  S --> U[P and Q updates: proposals, weights, averaging. 20]
  U --> V[Q updates minimize KL P Q . 21]
  V --> W[Variational methods typically minimize KL Q P . 22]
  R --> X[K=1, sleep only recovers classical wake-sleep. 23]
  A --> Y[5-10 samples improve on wake-sleep, near SOTA. 24]
  Y --> Z[Wake and sleep updates train Q best. 25]
  Y --> AA[Trains deep nets on MNIST competitively. 26]
  Y --> AB[More experiments: sample count sensitivity. 27]
  A --> AC[Goal: fast feedforward inference, generative net. 28]
  AC --> AD[Competitive with VAEs, 5-10x slower. 29]
  AC --> AE[2x slower than binary NVIL baseline MLP. 30]
  class A,B,D helmholtz;
  class C,E,F,O,P,Q,R,S,T,U,V,W inference;
  class G,H,I,J,K,L,M,N sampling;
  class X,AC,AD,AE training;
  class Y,Z,AA,AB results;

Resume:

1.-Helmholtz machines fit a directed generative model to data using maximum likelihood, which is generally intractable for models with many latent variables.

2.-An inference network Q is trained to perform approximate inference, mapping observed data to latent variables.

3.-Recent work on training Helmholtz machines includes variational autoencoders, stochastic backpropagation, neural variational inference, and DARN models.

4.-These models rely on a variational approximation to obtain a joint objective function containing both the inference and generative networks.
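
For reference, the joint objective in question is the standard evidence lower bound (ELBO); the notation below (θ for the generative network, φ for the inference network) is the usual one and is not taken from this summary:

```latex
\log p_\theta(x) \;\ge\; \mathcal{L}(\theta, \phi; x)
  = \mathbb{E}_{q_\phi(h \mid x)}\!\left[\log p_\theta(x, h) - \log q_\phi(h \mid x)\right]
  = \log p_\theta(x) - \mathrm{KL}\big(q_\phi(h \mid x) \,\|\, p_\theta(h \mid x)\big)
```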

5.-The joint objective cannot be naively trained using samples from Q, as this yields high-variance estimators.

6.-This work derives parameter update equations using only importance sampling, not variational approximations.

7.-Importance sampling has been used before to evaluate already trained models by interpreting Q as a proposal distribution.

8.-The likelihood estimator obtained is unbiased. Its variance depends on the quality of the proposal distribution Q.

9.-Variance is minimized when Q approximates the true, intractable posterior P(H|X); exact equality gives a zero-variance estimator.
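
As a concrete illustration, here is a minimal NumPy sketch of this estimator on a hypothetical one-layer sigmoid belief net; the parameters (b, W, c for P; V, d for Q) and their random initialization are illustrative assumptions, not the models or code from the paper:

```python
# Importance-sampling estimate p(x) ~= (1/K) * sum_k p(x, h_k) / q(h_k | x),
# with the K proposals h_k drawn from the inference network q(h|x).
import numpy as np

rng = np.random.default_rng(0)
D, H, K = 10, 4, 100                                   # visible units, latent units, samples

b = rng.normal(size=H)                                 # prior logits: p(h) = Bern(sigmoid(b))
W, c = rng.normal(size=(D, H)), rng.normal(size=D)     # likelihood: p(x|h) = Bern(sigmoid(W h + c))
V, d = rng.normal(size=(H, D)), rng.normal(size=H)     # proposal: q(h|x) = Bern(sigmoid(V x + d))

sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
log_bern = lambda z, p: np.sum(z * np.log(p) + (1 - z) * np.log(1 - p), axis=-1)

x = rng.integers(0, 2, size=D).astype(float)           # one (synthetic) observation
qp = sigmoid(x @ V.T + d)                              # per-latent proposal probabilities
h = (rng.random((K, H)) < qp).astype(float)            # K proposals h_k ~ q(h|x)

log_w = (log_bern(h, sigmoid(b))                       # log p(h_k)
         + log_bern(x, sigmoid(h @ W.T + c))           # + log p(x|h_k)
         - log_bern(h, qp))                            # - log q(h_k|x)
m = log_w.max()
log_p_x = m + np.log(np.exp(log_w - m).mean())         # log of (1/K) sum_k w_k, computed stably
print("estimated log p(x):", log_p_x)
# If q(h|x) equaled the true posterior p(h|x), every w_k would equal p(x)
# exactly and the estimate would have zero variance.
```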

10.-The same importance sampling mechanism is used to derive parameter update rules for the model.

11.-The gradient estimator for the generative model P involves drawing K samples from Q and calculating importance weights.

12.-A weighted average of gradients for each sample is taken using the importance weights. No backpropagation through P is needed.

13.-Layer-wise targets allow each layer of P to get a local gradient. Importance weights are automatically normalized.
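
The sketch below, again on a hypothetical one-layer sigmoid belief net (not the authors' architecture or code), puts points 10-13 together: K proposals from Q, self-normalized importance weights, and a weighted average of per-sample local gradients used to update P without backpropagating through the sampling:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, K, lr = 10, 4, 100, 0.01
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
log_bern = lambda z, p: np.sum(z * np.log(p) + (1 - z) * np.log(1 - p), axis=-1)

b = rng.normal(size=H)                                 # p(h) logits
W, c = rng.normal(size=(D, H)), rng.normal(size=D)     # p(x|h) parameters
V, d = rng.normal(size=(H, D)), rng.normal(size=H)     # q(h|x) parameters

x = rng.integers(0, 2, size=D).astype(float)
qp = sigmoid(x @ V.T + d)
h = (rng.random((K, H)) < qp).astype(float)            # K proposals h_k ~ q(h|x)

log_w = (log_bern(h, sigmoid(b)) + log_bern(x, sigmoid(h @ W.T + c))
         - log_bern(h, qp))                            # log w_k = log p(x,h_k) - log q(h_k|x)
w_tilde = np.exp(log_w - log_w.max())
w_tilde /= w_tilde.sum()                               # self-normalized importance weights

# Per-sample local gradients of log p(x, h_k): for Bernoulli(z; sigmoid(a))
# the gradient w.r.t. the logit a is (z - sigmoid(a)), so each layer gets a
# local target and no backpropagation through P is required.
grad_b = h - sigmoid(b)                                # (K, H)
resid = x - sigmoid(h @ W.T + c)                       # (K, D)
grad_W = resid[:, :, None] * h[:, None, :]             # (K, D, H)

# Importance-weighted average over the K samples, then a gradient ascent step.
b += lr * np.einsum('k,kh->h', w_tilde, grad_b)
c += lr * np.einsum('k,kd->d', w_tilde, resid)
W += lr * np.einsum('k,kdh->dh', w_tilde, grad_W)
```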

14.-No obvious signal exists for how to train the feedforward proposal network Q using this approach.

15.-Q acts only as a proposal: it doesn't influence the expected values of the likelihood or gradient estimators, only their variance.

16.-Q is trained to minimize variance of estimators, which means approximating the intractable posterior P(H|X).

17.-Q can be trained on real data (wake phase) or imaginary data from the generative model (sleep phase).

18.-Simple gradient updates are derived for both wake and sleep phases. The sleep phase uses a single sample from the generative model as Q's target.
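
A minimal sketch of the sleep-phase update under the same toy-model assumptions as above: a "dreamed" pair (h, x) is drawn from the generative model and Q is moved toward reproducing the latent h that actually generated x:

```python
import numpy as np

rng = np.random.default_rng(1)
D, H, lr = 10, 4, 0.01
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))

b = rng.normal(size=H)                                 # p(h) logits
W, c = rng.normal(size=(D, H)), rng.normal(size=D)     # p(x|h) parameters
V, d = rng.normal(size=(H, D)), rng.normal(size=H)     # q(h|x) parameters

h = (rng.random(H) < sigmoid(b)).astype(float)         # h ~ p(h)
x = (rng.random(D) < sigmoid(W @ h + c)).astype(float) # x ~ p(x|h)

# Gradient ascent on log q(h|x): the single dreamed h is Q's target.
resid = h - sigmoid(V @ x + d)                         # (H,)
V += lr * np.outer(resid, x)
d += lr * resid
```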

19.-The wake phase uses the same importance sampling mechanism to derive Q gradients, with the same structure as P's gradient estimator.

20.-To train both P and Q, proposals are drawn from Q, importance weights are calculated, and the same weights are used to average both the P and the Q updates.
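
Putting this together, the sketch below (same hypothetical toy model) shows a single wake-phase step in which one set of proposals and normalized weights serves both networks. Only the Q update is spelled out; P would be updated with the same w_tilde exactly as in the sketch under point 13:

```python
import numpy as np

rng = np.random.default_rng(2)
D, H, K, lr = 10, 4, 100, 0.01
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
log_bern = lambda z, p: np.sum(z * np.log(p) + (1 - z) * np.log(1 - p), axis=-1)

b = rng.normal(size=H)                                 # p(h) logits
W, c = rng.normal(size=(D, H)), rng.normal(size=D)     # p(x|h) parameters
V, d = rng.normal(size=(H, D)), rng.normal(size=H)     # q(h|x) parameters

x = rng.integers(0, 2, size=D).astype(float)           # a real (wake) datapoint
qp = sigmoid(x @ V.T + d)
h = (rng.random((K, H)) < qp).astype(float)            # K proposals h_k ~ q(h|x)

log_w = (log_bern(h, sigmoid(b)) + log_bern(x, sigmoid(h @ W.T + c))
         - log_bern(h, qp))
w_tilde = np.exp(log_w - log_w.max())
w_tilde /= w_tilde.sum()                               # shared normalized weights

# Wake-phase Q update: importance-weighted gradient of log q(h_k|x),
# pulling Q toward the latents that explain the observed x well.
resid_q = h - qp                                       # (K, H)
V += lr * np.einsum('k,kh,j->hj', w_tilde, resid_q, x)
d += lr * np.einsum('k,kh->h', w_tilde, resid_q)
```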

21.-The Q updates are equivalent to minimizing KL divergence between the true posterior P(H|X) and proposal Q.

22.-Variational approximations typically yield updates that minimize KL(Q||P); deriving the updates from importance sampling instead gives the reverse direction, KL(P||Q).
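
In symbols (standard definitions; the weights w̃_k are the self-normalized importance weights from the sketches above):

```latex
% Direction minimized by variational methods (via the ELBO):
\mathrm{KL}\big(Q \,\|\, P\big)
  = \mathbb{E}_{q(h \mid x)}\!\left[\log \frac{q(h \mid x)}{p(h \mid x)}\right]

% Direction descended by the importance-sampling (wake-phase) Q update:
\mathrm{KL}\big(P \,\|\, Q\big)
  = \mathbb{E}_{p(h \mid x)}\!\left[\log \frac{p(h \mid x)}{q(h \mid x)}\right]
  \approx \mathrm{const} - \sum_{k=1}^{K} \tilde{w}_k \,\log q(h_k \mid x),
  \qquad h_k \sim q(h \mid x).
```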

23.-Using K=1 samples and only sleep phase updates recovers the classical wake-sleep algorithm from the 90s.

24.-Empirical results on various datasets show that using 5-10 samples significantly improves on classical wake-sleep and approaches state-of-the-art.

25.-Best results come from using both wake and sleep phase updates to train the inference network.

26.-The approach can train relatively deep networks on real-world datasets like binary MNIST, achieving competitive test log likelihood.

27.-More experiments showed sensitivity to the number of samples, but these were not covered in depth due to time constraints.

28.-The motivation was to combine fast feedforward inference with a generative network, an old idea that makes sense.

29.-Compared to variational methods like VAEs, this is competitive but 5-10x slower due to the multiple samples needed.

30.-Compared to methods for binary latent variables like NVIL, this may be only ~2x slower, since NVIL itself needs a second MLP pass to estimate its baseline.

Knowledge Vault built by David Vivancos 2024