Jörg Bornschein and Yoshua Bengio, ICLR 2015 - Reweighted Wake-Sleep

**Concept Graph & Resume using Claude 3 Opus | Chat GPT4 | Gemini Adv | Llama 3:**

```mermaid
graph LR
    classDef helmholtz fill:#f9d4d4, font-weight:bold, font-size:14px;
    classDef inference fill:#d4f9d4, font-weight:bold, font-size:14px;
    classDef sampling fill:#d4d4f9, font-weight:bold, font-size:14px;
    classDef training fill:#f9f9d4, font-weight:bold, font-size:14px;
    classDef results fill:#f9d4f9, font-weight:bold, font-size:14px;
    A[Bornschein, Bengio, ICLR 2015] --> B[Helmholtz machines: directed generative, intractable inference. 1]
    B --> C[Inference network Q approximates latent variables. 2]
    B --> D[Recent Helmholtz models: VAEs, DARN, etc. 3]
    D --> E[Joint objective, inference and generative networks. 4]
    E --> F[Naive Q sampling: high variance estimators. 5]
    A --> G[Importance sampling, not variational approximations. 6]
    G --> H[Q as proposal distribution for trained models. 7]
    G --> I[Unbiased likelihood estimator, Q-dependent variance. 8]
    I --> J["Minimal variance: Q approximates P(H|X). 9"]
    G --> K[Importance sampling for model parameter updates. 10]
    K --> L[P gradient: K Q samples, importance weights. 11]
    L --> M[Weighted gradient average, no P backpropagation. 12]
    L --> N[Layer-wise targets, normalized importance weights. 13]
    G --> O[No clear Q training signal. 14]
    O --> P[Q proposes, affects only estimator variance. 15]
    P --> Q[Train Q to minimize estimator variance. 16]
    Q --> R[Q trains on real data wake or model data sleep. 17]
    G --> S[Wake and sleep phase gradient updates. 18]
    S --> T[Wake: Q gradients like P's estimator. 19]
    S --> U[P and Q updates: proposals, weights, averaging. 20]
    U --> V["Q updates minimize KL(P‖Q). 21"]
    V --> W["Variational methods typically minimize KL(Q‖P). 22"]
    R --> X[K=1, sleep only recovers classical wake-sleep. 23]
    A --> Y[5-10 samples improve on wake-sleep, near SOTA. 24]
    Y --> Z[Wake and sleep updates train Q best. 25]
    Y --> AA[Trains deep nets on MNIST competitively. 26]
    Y --> AB[More experiments: sample count sensitivity. 27]
    A --> AC[Goal: fast feedforward inference, generative net. 28]
    AC --> AD[Competitive with VAEs, 5-10x slower. 29]
    AC --> AE[2x slower than binary NVIL baseline MLP. 30]
    class A,B,D helmholtz;
    class C,E,F,O,P,Q,R,S,T,U,V,W inference;
    class G,H,I,J,K,L,M,N sampling;
    class X,AC,AD,AE training;
    class Y,Z,AA,AB results;
```


**Resume:**

**1.-**Helmholtz machines fit a directed generative model to data using maximum likelihood, which is generally intractable for models with many latent variables.

**2.-**An inference network Q is trained to help do approximate inference, running from observed data to latent variables.

**3.-**Recent work on training Helmholtz machines includes variational autoencoders, stochastic backpropagation, neural variational inference, and DARN models.

**4.-**These models rely on a variational approximation to obtain a joint objective function containing both the inference and generative networks.

**5.-**The joint objective cannot be naively trained using samples from Q as it results in high variance estimators.

**6.-**This work derives parameter update equations using only importance sampling, not variational approximations.

**7.-**Importance sampling has been used before to evaluate already trained models by interpreting Q as a proposal distribution.

**8.-**The likelihood estimator obtained is unbiased. Its variance depends on the quality of the proposal distribution Q.

**9.-**Variance is minimized when Q approximates the true intractable posterior P(H|X). Equality gives a zero variance estimator.
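Points 8-9 can be checked numerically. The sketch below uses a hypothetical two-state toy model (not from the paper) with one binary latent h and the observation fixed at x = 1, so the exact marginal likelihood is computable: a rough proposal Q gives an unbiased but noisy importance-sampling estimate, while setting Q equal to the true posterior makes every importance weight exactly p(x), i.e., a zero-variance estimator.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical two-state toy model (not from the paper): binary latent h,
# observation fixed to x = 1, so the exact marginal p(x) is computable.
p_h = np.array([0.3, 0.7])           # generative prior p(h)
p_x1_h = np.array([0.9, 0.2])        # p(x=1 | h)
p_joint = p_h * p_x1_h               # p(x=1, h)
p_x = p_joint.sum()                  # exact p(x=1) = 0.41

def likelihood_estimate(q, K):
    """Unbiased importance-sampling estimate of p(x): mean of p(x,h)/q(h|x)."""
    h = rng.choice(2, size=K, p=q)
    return (p_joint[h] / q[h]).mean()

# A rough proposal still gives an unbiased (but noisy) estimate.
est = likelihood_estimate(np.array([0.5, 0.5]), K=100_000)

# Proposal equal to the true posterior p(h|x): every weight equals p(x),
# so the estimator has zero variance (point 9).
q_post = p_joint / p_x
weights = p_joint / q_post           # identical for both values of h
```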

**10.-**The same importance sampling mechanism is used to derive parameter update rules for the model.

**11.-**The gradient estimator for the generative model P involves drawing K samples from Q and calculating importance weights.

**12.-**A weighted average of gradients for each sample is taken using the importance weights. No backpropagation through P is needed.

**13.-**Layer-wise targets allow each layer of P to get a local gradient. Importance weights are automatically normalized.
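The gradient estimator of points 11-13 can be sketched on the same kind of hypothetical toy model (not from the paper): draw K proposals from Q, compute importance weights, self-normalize them, and average the per-sample gradients of log p(x, h). For a single prior parameter theta the exact gradient of log p(x) is available for comparison.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy model (not from the paper): prior p(h=1) = theta,
# observation fixed to x = 1. We estimate d log p(x) / d theta.
theta = 0.7
p_h = np.array([1 - theta, theta])
p_x1_h = np.array([0.9, 0.2])        # p(x=1 | h), independent of theta
p_joint = p_h * p_x1_h
p_x = p_joint.sum()

# Exact gradient: d p(x)/d theta = p(x|h=1) - p(x|h=0), divided by p(x).
exact_grad = (p_x1_h[1] - p_x1_h[0]) / p_x

# RWS-style estimator (points 11-13): K proposals from Q, importance
# weights, self-normalization, then a weighted average of per-sample
# gradients of log p(x, h). No backpropagation through P is needed.
q = np.array([0.5, 0.5])                      # crude proposal q(h|x)
K = 100_000
h = rng.choice(2, size=K, p=q)
w = p_joint[h] / q[h]                          # unnormalized importance weights
omega = w / w.sum()                            # normalized weights (sum to 1)
grad_log_joint = np.where(h == 1, 1 / theta, -1 / (1 - theta))
grad_est = (omega * grad_log_joint).sum()      # weighted gradient average
```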

**14.-**No obvious signal exists for how to train the feedforward proposal network Q using this approach.

**15.-**Q acts only as a proposal and doesn't influence expected values of estimators, likelihood, or gradients, only their variance.

**16.-**Q is trained to minimize variance of estimators, which means approximating the intractable posterior P(H|X).

**17.-**Q can be trained on real data (wake phase) or imaginary data from the generative model (sleep phase).

**18.-**Simple gradient updates are derived for both wake and sleep phases. The sleep phase uses a single dreamed sample as Q's target.
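The sleep phase of points 17-18 can be illustrated on a hypothetical toy model (not from the paper): dream (h, x) pairs from the generative model P, then fit Q to predict the dreamed h from the dreamed x. Fitting Q by maximum likelihood on dream data converges to the true posterior; in the sketch a simple counting estimate stands in for gradient steps on log q(h|x).

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical toy model (not from the paper): binary latent h, binary x.
p_h = np.array([0.3, 0.7])           # prior p(h)
p_x1_h = np.array([0.9, 0.2])        # p(x=1 | h)

# Sleep phase: dream (h, x) pairs from the generative model P; each
# dreamed h is a single exact target for Q at the dreamed x (point 18).
N = 200_000
h = rng.choice(2, size=N, p=p_h)
x = (rng.random(N) < p_x1_h[h]).astype(int)

# Counting estimate of q(h=1 | x=1) from the dream data; the maximum-
# likelihood fit recovers the true posterior.
q_h1_given_x1 = h[x == 1].mean()

# Exact posterior p(h=1 | x=1) for comparison.
posterior = (p_h[1] * p_x1_h[1]) / (p_h * p_x1_h).sum()
```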

**19.-**The wake phase uses the same importance sampling mechanism to derive Q gradients, with the same structure as P's gradient estimator.

**20.-**To train both P and Q, proposals are drawn using Q, importance weights calculated, and used to average P and Q updates.

**21.-**The Q updates are equivalent to minimizing KL divergence between the true posterior P(H|X) and proposal Q.
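Points 20-21 can be made concrete on a hypothetical toy model (not from the paper): the normalized importance weights over proposals from the current Q form an empirical approximation to the true posterior p(h|x), so maximizing the weighted log q(h|x) minimizes KL(P(H|X) || Q) in expectation.

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical toy model (not from the paper), observation fixed to x = 1.
p_h = np.array([0.3, 0.7])
p_x1_h = np.array([0.9, 0.2])
p_joint = p_h * p_x1_h
posterior = p_joint / p_joint.sum()   # true p(h | x=1)

# Wake-phase Q target (points 20-21): proposals from the current Q,
# reweighted by normalized importance weights.
q = np.array([0.5, 0.5])              # current proposal q(h|x)
K = 100_000
h = rng.choice(2, size=K, p=q)
omega = p_joint[h] / q[h]
omega /= omega.sum()                  # normalized importance weights

# The weighted frequency of each h value is a Monte Carlo estimate of the
# posterior, so it serves as Q's training target.
weighted_target = np.array([omega[h == 0].sum(), omega[h == 1].sum()])
```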

**22.-**Variational approximations typically yield updates minimizing KL(Q||P); deriving the updates from importance sampling instead gives the reverse direction, KL(P||Q).

**23.-**Using K=1 samples and only sleep phase updates recovers the classical wake-sleep algorithm from the 90s.

**24.-**Empirical results on various datasets show 5-10 samples significantly improves on classical wake-sleep and approaches state-of-the-art.

**25.-**Best results come from using both wake and sleep phase updates to train the inference network.

**26.-**The approach can train relatively deep networks on real-world datasets like binary MNIST, achieving competitive test log likelihood.

**27.-**Further experiments examined sensitivity to the number of samples but were not covered in depth due to time constraints.

**28.-**The motivation was to combine fast feedforward inference with a generative network, an old idea that makes sense.

**29.-**Compared to variational methods like VAEs, this is competitive but 5-10x slower due to multiple samples needed.

**30.-**Compared to methods for binary variables like NVIL, this may be ~2x slower as NVIL needs a second MLP pass to estimate baselines.

Knowledge Vault built by David Vivancos 2024