Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies

Paul Vicol · Luke Metz · Jascha Sohl-Dickstein

**Concept Graph & Resume using Claude 3.5 Sonnet | Chat GPT4o | Llama 3:**

graph LR
classDef graphs fill:#f9d4d4, font-weight:bold, font-size:14px
classDef strategies fill:#d4f9d4, font-weight:bold, font-size:14px
classDef tasks fill:#d4d4f9, font-weight:bold, font-size:14px
classDef applications fill:#f9f9d4, font-weight:bold, font-size:14px
A[Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies] --> B[Computation Graphs]
A --> C[Evolution Strategies]
A --> D[Tasks and Applications]
A --> E[PES Implementation]
B --> B1[Dynamical systems with parameters. 1]
B --> B2[Traditional ML optimization methods. 2]
B --> B3[Optimization challenges in long unrolls. 3]
B --> B4[Gaussian-smoothed, memory-efficient optimization. 4]
B --> B5[Gradient estimate from parameter sequence. 6]
B --> B6[Sequential gradient estimate summation. 7]
C --> C1[Unbiased computation graph splitting. 5]
C --> C2[ES with particle tracking. 8]
C --> C3[Correlation-dependent gradient variance. 9]
C --> C4[Partial unrolls, unbiased estimates. 16]
C --> C5[PES navigates chaotic landscapes. 17]
C --> C6[Easily parallelizable PES. 18]
D --> D1[PES's unbiasedness demonstration. 10]
D --> D2[PES in chaotic regions. 11]
D --> D3[Optimal region convergence. 12]
D --> D4[PES for MLP on MNIST. 13]
D --> D5[Consistent, lower losses with PES. 14]
D --> D6[PES efficiency in swimmer task. 15]
E --> E1[Positive, negative perturbation pairs. 22]
E --> E2[PES example using JAX. 23]
E --> E3[Particle state tracking difference. 24]
E --> E4[PES in chaotic hyperparameter tasks. 25]
E --> E5[PES in MLP training. 26]
E --> E6[Full sequence correction terms. 28]
class A,B,B1,B2,B3,B4,B5,B6 graphs
class C,C1,C2,C3,C4,C5,C6 strategies
class D,D1,D2,D3,D4,D5,D6 tasks
class E,E1,E2,E3,E4,E5,E6 applications

**Resume:**

**1.-** Unrolled computation graphs: Represent dynamical systems in which parameters govern how a state evolves over time; they arise in many machine learning applications, such as hyperparameter optimization, learned optimizers, and reinforcement learning.

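**2.-** Classic optimization approaches: Backpropagation through time (BPTT), truncated BPTT, real-time recurrent learning (RTRL), and approximations to RTRL. Each has limitations: full BPTT needs memory that grows with the unroll length and can only update after the whole unroll, truncated BPTT is biased, and RTRL is unbiased but expensive in compute and memory.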

**3.-** Chaotic loss landscapes: Long unrolls can lead to chaotic or poorly conditioned loss landscapes, making optimization challenging.

**4.-** Evolution Strategies (ES): Optimizes a Gaussian-smoothed meta-objective; it does not require backpropagation, is memory-efficient, can optimize black-box functions, and scales well on parallel compute.
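As a minimal sketch of the ES estimator described in item 4 (plain Python written for illustration; the function name `es_grad`, the toy objective, and the hyperparameters are assumptions, not taken from the paper), the following estimates the gradient of a Gaussian-smoothed objective using only function evaluations and antithetic pairs:

```python
import random


def es_grad(f, theta, sigma=0.1, num_pairs=10_000, seed=0):
    """Antithetic ES estimate of d/dtheta E[f(theta + eps)], eps ~ N(0, sigma^2).

    f is only evaluated, never differentiated, so it can be a black box.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_pairs):
        eps = rng.gauss(0.0, sigma)
        # Antithetic pair: evaluate at theta + eps and theta - eps.
        total += eps * (f(theta + eps) - f(theta - eps)) / (2 * sigma**2)
    return total / num_pairs


def f(theta):
    # Toy objective (assumption); its true gradient at theta = 0 is -6.
    return (theta - 3.0) ** 2


grad = es_grad(f, theta=0.0)
print(grad)  # approximately -6
```

For a quadratic like this, smoothing only adds the constant sigma^2 to the objective, so the smoothed and true gradients coincide and the estimate should land near -6; a larger `num_pairs` reduces its variance.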

**5.-** Persistent Evolution Strategies (PES): An unbiased approach that splits the computation graph into a sequence of truncated unrolls and accumulates correction terms over the full sequence of unrolls.

**6.-** PES derivation: Uses a shift in notation that treats the loss as a function of the entire sequence of per-unroll parameters, from which the gradient estimate is derived.

**7.-** PES decomposition: The full-sequence estimate breaks down into a sum of sequential per-unroll gradient estimates, with perturbations accumulated across unrolls.
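Items 6 and 7 can be written out as a short derivation. This is a restatement in our own notation (the symbols $\Theta$, $L_t$, and $\xi_t$ are chosen here and may not match the paper's): the unroll is split into $T$ truncations, truncation $t$ gets its own copy $\theta_t$ of the parameters (all equal to $\theta$, so $\nabla_\theta L$ is the sum of the partial derivatives), $L_t$ is the loss incurred in truncation $t$ and depends only on $\theta_1,\dots,\theta_t$, and each truncation receives an independent perturbation $\epsilon_t \sim \mathcal{N}(0, \sigma^2 I)$.

$$
\begin{aligned}
L(\Theta) &= \sum_{t=1}^{T} L_t(\theta_1,\dots,\theta_t), \qquad \nabla_\theta L = \sum_{t=1}^{T} \frac{\partial L}{\partial \theta_t}, \\
\hat{g}_{\mathrm{ES}} &= \frac{1}{\sigma^2}\Big(\sum_{\tau=1}^{T}\epsilon_\tau\Big)\sum_{t=1}^{T} L_t(\theta+\epsilon_1,\dots,\theta+\epsilon_t), \\
\hat{g}_{\mathrm{PES}} &= \frac{1}{\sigma^2}\sum_{t=1}^{T}\xi_t\,L_t(\theta+\epsilon_1,\dots,\theta+\epsilon_t), \qquad \xi_t = \sum_{\tau=1}^{t}\epsilon_\tau .
\end{aligned}
$$

The second line is ES applied to the whole parameter sequence. The third drops every term $\epsilon_\tau L_t$ with $\tau > t$; these have zero expectation because a later perturbation is independent of an earlier loss, so the estimate stays unbiased, and each summand $\xi_t L_t$ is available as soon as truncation $t$ finishes.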

**8.-** PES implementation: Similar to ES, but each particle's state persists across unrolls and its perturbations are accumulated.
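The PES loop in item 8 can be sketched in plain Python on a hypothetical toy system (our own choice, not one of the paper's tasks): a scalar state updated as s_{k+1} = s_k + theta with per-step loss (s_{k+1} - 1)^2, run for 4 steps split into 2 truncated unrolls. The names `unroll`, `pes_grad`, and `true_grad` are illustrative. Each antithetic particle pair keeps its own state across unrolls and accumulates its perturbations in `xi`; for this sanity check theta is held fixed and the per-unroll estimates are summed, whereas in practice each one would drive a parameter update.

```python
import random


def unroll(state, theta, num_steps):
    """Hypothetical toy dynamics: run num_steps steps, return (new_state, summed loss)."""
    loss = 0.0
    for _ in range(num_steps):
        state = state + theta
        loss += (state - 1.0) ** 2
    return state, loss


def pes_grad(theta, total_steps=4, K=2, sigma=0.1, num_pairs=20_000, seed=0):
    """Sum of antithetic PES estimates over all unrolls, averaged over particle pairs."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_pairs):
        s_pos = s_neg = 0.0  # persistent particle states (never reset between unrolls)
        xi = 0.0             # accumulated perturbations of the positive particle
        for _ in range(total_steps // K):
            eps = rng.gauss(0.0, sigma)  # fresh perturbation for this unroll
            s_pos, loss_pos = unroll(s_pos, theta + eps, K)
            s_neg, loss_neg = unroll(s_neg, theta - eps, K)
            xi += eps  # the negative particle's accumulated sum is -xi
            total += xi * (loss_pos - loss_neg) / (2 * sigma**2)
    return total / num_pairs


def true_grad(theta, total_steps=4):
    # Unperturbed state after k steps is k * theta, so the loss is sum_k (k * theta - 1)^2.
    return sum(2 * k * (k * theta - 1.0) for k in range(1, total_steps + 1))


print(pes_grad(0.5), true_grad(0.5))
```

Because this toy loss is quadratic in the perturbations, Gaussian smoothing does not shift its gradient, so the averaged estimate should be close to the analytic value 10.0 returned by `true_grad(0.5)`.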

**9.-** PES variance: Depends on the correlation between the gradients at different unrolls; under certain conditions it can decrease as the number of unrolls grows.

**10.-** Synthetic influence balancing task: Demonstrates PES's unbiasedness: PES converges to the correct solution, whereas truncated methods do not.

**11.-** Hyperparameter optimization: PES outperforms truncated methods on a toy 2D regression task with chaotic regions.

**12.-** MNIST learning rate schedule: PES converges to the optimal region for both differentiable and non-differentiable objectives.

**13.-** Multi-hyperparameter tuning: PES outperforms truncated ES and random search when tuning 20 hyperparameters of an MLP on MNIST.

**14.-** Learned optimizer training: PES achieves lower and more consistent losses than ES when meta-training an MLP-based learned optimizer.

**15.-** Continuous control policy learning: On the swimmer task, PES is more efficient than ES run on full episodes, while truncated ES fails.

**16.-** Unbiased gradient estimation: PES provides unbiased estimates from partial unrolls, unlike truncated methods.

**17.-** Loss surface smoothing: PES inherits ES's smoothing of the loss surface, which helps it navigate chaotic landscapes.

**18.-** Parallelizability: PES is easily parallelizable, inheriting this advantage from ES.

**19.-** Non-differentiable objectives: PES can optimize non-differentiable objectives, such as accuracy instead of loss.
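Item 19 relies on ES needing only function values. A minimal sketch (plain Python; the step objective, `sigma`, and sample count are assumptions for illustration) applies the antithetic ES estimator, which is what PES reduces to with a single unroll, to a step function standing in for an accuracy-like objective:

```python
import math
import random


def step(theta):
    """Non-differentiable, accuracy-like objective (assumed): 1 if theta > 0, else 0."""
    return 1.0 if theta > 0 else 0.0


def es_grad(f, theta, sigma=1.0, num_pairs=20_000, seed=0):
    """Antithetic ES estimate of the gradient of the Gaussian-smoothed f."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_pairs):
        eps = rng.gauss(0.0, sigma)
        total += eps * (f(theta + eps) - f(theta - eps)) / (2 * sigma**2)
    return total / num_pairs


# The smoothed step is Phi(theta / sigma); at theta = 0 its slope is
# 1 / (sigma * sqrt(2 * pi)), even though step() has zero gradient almost everywhere.
estimate = es_grad(step, theta=0.0)
exact = 1.0 / math.sqrt(2 * math.pi)
print(estimate, exact)
```

Backpropagation would return zero here, while the smoothed objective still provides a useful slope (about 0.399 for sigma = 1).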

**20.-** Tractable compute and memory cost: PES achieves this while providing unbiased estimates from partial unrolls.

**21.-** Applications: PES is applicable to hyperparameter optimization, training learned optimizers, and reinforcement learning.

**22.-** Antithetic sampling: Used in practice for PES; pairs of positive and negative perturbations are sampled at each time step.
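To illustrate why antithetic pairs are used, the sketch below compares one-sided and antithetic ES gradient samples for an assumed quadratic objective, drawing the same perturbations for both. The objective, `sigma = 0.5`, and the sample size are illustrative choices, not settings from the paper.

```python
import random
import statistics


def f(theta):
    # Assumed toy objective; true gradient at theta = 0 is -6.
    return (theta - 3.0) ** 2


def samples(theta=0.0, sigma=0.5, n=5_000, seed=0):
    """Per-sample ES gradient estimates: one-sided vs antithetic, same eps draws."""
    rng = random.Random(seed)
    one_sided, antithetic = [], []
    for _ in range(n):
        eps = rng.gauss(0.0, sigma)
        # One-sided: a single evaluation at theta + eps.
        one_sided.append(eps * f(theta + eps) / sigma**2)
        # Antithetic: the pair (theta + eps, theta - eps); even-order terms of f cancel.
        antithetic.append(eps * (f(theta + eps) - f(theta - eps)) / (2 * sigma**2))
    return one_sided, antithetic


one_sided, antithetic = samples()
print(statistics.mean(one_sided), statistics.variance(one_sided))
print(statistics.mean(antithetic), statistics.variance(antithetic))
```

Both estimators target the same gradient (-6 at theta = 0), but for this setting the per-sample variances are about 454 (one-sided) and 72 (antithetic), because the antithetic difference cancels the even-order terms of f, including the large constant f(0) = 9. An antithetic pair costs two evaluations, yet at an equal evaluation budget the averaged one-sided estimate still has variance of about 227/n versus 72/n for n antithetic pairs.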

**23.-** JAX implementation: An example implementation of the PES estimator in JAX, demonstrating its simplicity and ease of parallelization.

**24.-** Comparison with truncated ES: Unlike truncated ES, PES tracks particle states and accumulates perturbations over time.
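For contrast with PES, here is a sketch of truncated ES on a hypothetical toy system (our own choice): a scalar state updated as s_{k+1} = s_k + theta with per-step loss (s_{k+1} - 1)^2, 4 steps split into 2 unrolls. We assume, as one common setup, that both particles start every unroll from a shared state advanced with the unperturbed theta, and that only the current unroll's perturbation is used. The names `unroll` and `truncated_es_grad` are illustrative.

```python
import random


def unroll(state, theta, num_steps):
    """Hypothetical toy dynamics: run num_steps steps, return (new_state, summed loss)."""
    loss = 0.0
    for _ in range(num_steps):
        state = state + theta
        loss += (state - 1.0) ** 2
    return state, loss


def truncated_es_grad(theta, total_steps=4, K=2, sigma=0.1, num_pairs=20_000, seed=0):
    """Antithetic truncated ES: no persistent particle states, no accumulated perturbations."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_pairs):
        state = 0.0
        for _ in range(total_steps // K):
            eps = rng.gauss(0.0, sigma)
            _, loss_pos = unroll(state, theta + eps, K)
            _, loss_neg = unroll(state, theta - eps, K)
            total += eps * (loss_pos - loss_neg) / (2 * sigma**2)
            state, _ = unroll(state, theta, K)  # advance the shared state unperturbed
    return total / num_pairs


print(truncated_es_grad(0.5))  # approximately 4.0; the true gradient at 0.5 is 10.0
```

The estimate settles near 4.0 rather than the true 10.0: truncation ignores how the first unroll's parameters shape the state the second unroll starts from (a contribution of 6 in this toy), which is the bias PES removes by carrying particle states and accumulated perturbations forward.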

**25.-** Meta-loss surface visualization: Illustrates chaotic regions in hyperparameter optimization tasks where PES excels.

**26.-** CIFAR-10 experiment: PES outperforms ES when meta-training a learned optimizer that trains an MLP on CIFAR-10.

**27.-** MuJoCo swimmer task: Demonstrates PES's efficiency in learning continuous control policies from partial unrolls.

**28.-** Bias elimination: PES eliminates bias from truncations by accumulating correction terms over full sequence of unrolls.

**29.-** Frequent parameter updates: PES allows more frequent parameter updates than full-unroll ES, improving efficiency.

**30.-** Easy implementation: PES is described as an easy-to-implement modification of ES, making it accessible for various applications.

Knowledge Vault built by David Vivancos 2024