Knowledge Vault 6/67 - ICML 2021
Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies
Paul Vicol · Luke Metz · Jascha Sohl-Dickstein
< Resume Image >

Concept Graph & Resume using Claude 3.5 Sonnet | Chat GPT4o | Llama 3:

```mermaid
graph LR
  classDef graphs fill:#f9d4d4, font-weight:bold, font-size:14px
  classDef strategies fill:#d4f9d4, font-weight:bold, font-size:14px
  classDef tasks fill:#d4d4f9, font-weight:bold, font-size:14px
  classDef applications fill:#f9f9d4, font-weight:bold, font-size:14px
  A[Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies] --> B[Computation Graphs]
  A --> C[Evolution Strategies]
  A --> D[Tasks and Applications]
  A --> E[PES Implementation]
  B --> B1[Dynamical systems with parameters. 1]
  B --> B2[Traditional ML optimization methods. 2]
  B --> B3[Optimization challenges in long unrolls. 3]
  B --> B4[Gaussian-smoothed, memory-efficient optimization. 4]
  B --> B5[Gradient estimate from parameter sequence. 6]
  B --> B6[Sequential gradient estimate summation. 7]
  C --> C1[Unbiased computation graph splitting. 5]
  C --> C2[ES with particle tracking. 8]
  C --> C3[Correlation-dependent gradient variance. 9]
  C --> C4[Partial unrolls, unbiased estimates. 16]
  C --> C5[PES navigates chaotic landscapes. 17]
  C --> C6[Easily parallelizable PES. 18]
  D --> D1[PES unbiasedness demonstration. 10]
  D --> D2[PES in chaotic regions. 11]
  D --> D3[Optimal region convergence. 12]
  D --> D4[PES for MLP on MNIST. 13]
  D --> D5[Consistent, lower losses with PES. 14]
  D --> D6[PES efficiency in swimmer task. 15]
  E --> E1[Positive, negative perturbation pairs. 22]
  E --> E2[PES example using JAX. 23]
  E --> E3[Particle state tracking difference. 24]
  E --> E4[PES in chaotic hyperparameter tasks. 25]
  E --> E5[PES in MLP training. 26]
  E --> E6[Full sequence correction terms. 28]
  class A,B,B1,B2,B3,B4,B5,B6 graphs
  class C,C1,C2,C3,C4,C5,C6 strategies
  class D,D1,D2,D3,D4,D5,D6 tasks
  class E,E1,E2,E3,E4,E5,E6 applications
```

Resume:

1.- Unrolled computation graphs: Represent dynamical systems with parameters governing state evolution over time, used in various machine learning applications.
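
As a concrete sketch (illustrative names, not the paper's code), the inner problem can be written as a state-update loop whose parameters are shared across all steps, with a loss accumulated along the unroll:

```python
import jax
import jax.numpy as jnp

def unrolled_loss(theta, s0, xs):
    """Toy unrolled computation graph: s_{t+1} = f(s_t, x_t; theta),
    with a per-step loss summed over the whole unroll. `f` stands in
    for an RNN cell, an inner optimizer step, a policy rollout, etc."""
    def f(s, x):
        return jnp.tanh(theta["w"] * s + x)     # toy state transition

    def scan_step(s, x):
        s_next = f(s, x)
        return s_next, jnp.sum(s_next ** 2)     # toy per-step loss

    _, step_losses = jax.lax.scan(scan_step, s0, xs)
    return jnp.sum(step_losses)                 # total loss L(theta)
```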

2.- Classic optimization approaches: Backpropagation through time (BPTT), truncated BPTT, real-time recurrent learning (RTRL), and its approximations, each with limitations such as memory cost, truncation bias, or compute cost.
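
For reference, a minimal sketch of the truncated variant (hypothetical names): gradients flow only through the last K steps of each window, which is exactly the source of the bias PES later removes.

```python
import jax
import jax.numpy as jnp

def truncated_bptt_grad(theta, s0, xs, K=10):
    """Gradient over one truncation window of length K, treating the
    incoming state s0 as a constant via stop_gradient. Cheap in memory,
    but biased: the influence of theta on earlier steps is ignored."""
    def window_loss(theta):
        s = jax.lax.stop_gradient(s0)
        total = 0.0
        for x in xs[:K]:
            s = jnp.tanh(theta["w"] * s + x)
            total = total + jnp.sum(s ** 2)
        return total
    return jax.grad(window_loss)(theta)
```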

3.- Chaotic loss landscapes: Long unrolls can lead to chaotic or poorly conditioned loss landscapes, making optimization challenging.

4.- Evolution Strategies (ES): Optimizes a Gaussian-smoothed meta-objective, requires no backprop, is memory-efficient, can optimize blackbox functions, and scales well on parallel compute.
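
A minimal sketch of the ES estimator of the gradient of the Gaussian-smoothed objective, assuming a flat parameter vector and an arbitrary blackbox `loss_fn` (names are illustrative):

```python
import jax
import jax.numpy as jnp

def es_grad(loss_fn, theta, key, sigma=0.1, n_samples=32):
    """Estimates grad_theta E_{eps~N(0,I)}[loss_fn(theta + sigma*eps)]
    from forward evaluations only, so loss_fn may be non-differentiable.
    Evaluations are independent, hence trivially parallelizable (vmap/pmap)."""
    eps = jax.random.normal(key, (n_samples, theta.shape[0]))
    losses = jax.vmap(lambda e: loss_fn(theta + sigma * e))(eps)
    return jnp.mean(losses[:, None] * eps, axis=0) / sigma
```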

5.- Persistent Evolution Strategies (PES): An unbiased approach that splits the computation graph into truncated unrolls and accumulates correction terms over the full sequence.

6.- PES derivation: Uses a shift in notation, treating the loss as a function of the entire sequence of per-step parameter copies, from which the gradient estimate is derived.

7.- PES decomposition: The estimator breaks down into a sum of per-unroll gradient estimates, with perturbations accumulated over successive unrolls (sketched below).
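
Roughly, in the paper's spirit (a sketch, not the exact derivation): with per-step parameter copies Θ = (θ₁, …, θ_T), per-unroll losses L_t, and perturbations ε_τ ~ N(0, σ²I) applied at each unroll, the PES estimate at unroll t weights the unroll loss by the accumulated noise ξ_t:

```latex
L(\Theta) = \sum_{t} L_t(\theta_1, \dots, \theta_t), \qquad
\xi_t = \sum_{\tau \le t} \epsilon_\tau, \qquad
\hat{g}^{\mathrm{PES}}_t \approx \tfrac{1}{\sigma^2}\, \xi_t\, L_t(\theta_1 + \epsilon_1, \dots, \theta_t + \epsilon_t).
```

Averaging this over particles and summing over unrolls yields the full gradient estimate.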

8.- PES implementation: Similar to ES, but each particle's inner state is persisted across unrolls and its perturbations are accumulated.
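
Concretely, each particle could carry something like the following (a hypothetical container, not the authors' data structure):

```python
from typing import NamedTuple
import jax.numpy as jnp

class PESParticle(NamedTuple):
    """State carried per particle across truncated unrolls."""
    inner_state: jnp.ndarray         # particle's own copy of the unrolled system's state
    accum_perturbation: jnp.ndarray  # xi = running sum of perturbations applied so far
```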

9.- PES variance: Depends on correlation between gradients at each unroll, can decrease with more unrolls under certain conditions.

10.- Synthetic influence balancing task: Demonstrates PES's unbiasedness, converging to correct solutions unlike truncated methods.

11.- Hyperparameter optimization: PES outperforms truncated methods on toy 2D regression task with chaotic regions.

12.- MNIST learning rate schedule: PES converges to optimal region for both differentiable and non-differentiable objectives.

13.- Multi-hyperparameter tuning: PES outperforms truncated ES and random search for tuning 20 hyperparameters of MLP on MNIST.

14.- Learned optimizer training: PES achieves lower losses and more consistency than ES when meta-training MLP-based optimizer.

15.- Continuous control policy learning: PES more efficient than ES on full episodes for swimmer task, while truncated ES fails.

16.- Unbiased gradient estimation: PES provides unbiased estimates from partial unrolls, unlike truncated methods.

17.- Loss surface smoothing: PES inherits this useful characteristic from ES, helping navigate chaotic landscapes.

18.- Parallelizability: PES is easily parallelizable, inheriting this advantage from ES.

19.- Non-differentiable objectives: PES can work with non-differentiable functions like accuracy instead of loss.

20.- Tractable compute and memory cost: PES achieves this while providing unbiased estimates from partial unrolls.

21.- Applications: PES applicable to hyperparameter optimization, training learned optimizers, and reinforcement learning.

22.- Antithetic sampling: Used in practice for PES, sampling pairs of positive and negative perturbations at each time step.
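
For a single unroll, the antithetic version of the estimate looks roughly like (a sketch):

```latex
\hat{g} = \frac{1}{2\sigma^2 N} \sum_{i=1}^{N}
\big( L(\theta + \epsilon_i) - L(\theta - \epsilon_i) \big)\, \epsilon_i,
\qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2 I),
```

with PES replacing ε_i by the accumulated perturbation ξ_i when forming the product.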

23.- JAX implementation: Example of PES estimator implementation using JAX, demonstrating simplicity and parallelization.
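
A minimal PES-style sketch in JAX, assuming a user-supplied `unroll_fn(theta, state) -> (loss, new_state)` that runs one truncated unroll; all names here are illustrative and not taken from the paper's reference code:

```python
import jax
import jax.numpy as jnp

def pes_step(unroll_fn, theta, states_pos, states_neg, xis, key, sigma=0.1):
    """One PES gradient estimate over a single truncated unroll.

    states_pos / states_neg: per-particle inner states for the +eps and
    -eps members of each antithetic pair, persisted across calls.
    xis: per-particle accumulated perturbations, shape (n_particles, dim),
    reset to zero whenever the inner problem is restarted."""
    n, dim = xis.shape
    eps = sigma * jax.random.normal(key, (n, dim))

    # Evaluate antithetic pairs on this unroll, threading each particle's state.
    loss_pos, states_pos = jax.vmap(lambda e, s: unroll_fn(theta + e, s))(eps, states_pos)
    loss_neg, states_neg = jax.vmap(lambda e, s: unroll_fn(theta - e, s))(eps, states_neg)

    # Accumulate perturbations across unrolls; using xis (not eps) below is
    # what removes the truncation bias relative to plain truncated ES.
    xis = xis + eps
    grad = jnp.mean((loss_pos - loss_neg)[:, None] * xis, axis=0) / (2 * sigma ** 2)
    return grad, states_pos, states_neg, xis
```

The only differences from a truncated ES step are the persistent particle states and the use of the accumulated xis in place of the fresh perturbations.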

24.- Comparison with truncated ES: PES differs in tracking particle states and accumulating perturbations over time.

25.- Meta-loss surface visualization: Illustrates chaotic regions in hyperparameter optimization tasks where PES excels.

26.- CIFAR-10 experiment: PES outperforms ES in meta-training learned optimizer for training MLP on CIFAR-10.

27.- MuJoCo swimmer task: Demonstrates PES's efficiency in learning continuous control policies using partial unrolls.

28.- Bias elimination: PES eliminates bias from truncations by accumulating correction terms over full sequence of unrolls.
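
Roughly why this works (a sketch for the Gaussian-smoothed objective): each per-unroll estimate covers the gradient contributions of all earlier parameter copies, so their sum matches the total-derivative decomposition:

```latex
\mathbb{E}\Big[\textstyle\sum_t \hat{g}^{\mathrm{PES}}_t\Big]
\;=\; \sum_t \sum_{\tau \le t} \frac{\partial L_t}{\partial \theta_\tau}
\;=\; \frac{dL}{d\theta},
```

where d/dθ denotes the total derivative with respect to the shared parameters.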

29.- Frequent parameter updates: PES allows for more frequent updates compared to full-unroll ES, improving efficiency.

30.- Easy implementation: PES is described as an easy-to-implement modification of ES, making it accessible for various applications.

Knowledge Vault built by David Vivancos 2024