Memory as a lens to understand learning and optimization

Speaker

USC

Host

Justin Chen, Lily Chung, John Kuszmaul
CSAIL, EECS

What is the role of memory in learning and optimization? The optimal convergence rates (measured in terms of the number of oracle queries or samples needed) for various optimization problems are achieved by computationally expensive optimization techniques, such as second-order methods and cutting-plane methods. We will discuss whether simpler, faster and memory-limited algorithms such as gradient descent can achieve these optimal convergence rates for the prototypical optimization problem of minimizing a convex function with access to a gradient oracle. Our results hint at a perhaps curious dichotomy: it is not possible to significantly improve on the convergence rate of known memory-efficient techniques (which, for many of these problems, are linear-memory variants of gradient descent) without using substantially more memory (quadratic memory for many of these problems). Therefore, memory could be a useful discerning factor that provides a clear separation between 'efficient' and 'expensive' techniques. This perspective can also be applied to understand the mechanisms that transformers use to solve certain algorithmic tasks. We will show empirical evidence that transformers learn to achieve second-order convergence rates when solving linear-regression tasks, leveraging some of the theory of optimization and memory-limited learning.
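To make the linear- versus quadratic-memory contrast concrete, here is a minimal illustrative sketch, not taken from the talk or the underlying papers. On a toy least-squares problem f(w) = (1/(2n)) * ||Xw - y||^2 (the linear-regression setting mentioned above), gradient descent carries only the d-dimensional iterate between steps, while Newton's method, a second-order method, also stores the d x d Hessian. Because f is quadratic, a single Newton step lands on the minimizer, whereas gradient descent needs many steps. The toy data, the step size 0.5, and all function names are hypothetical choices for illustration; the sketch does not reproduce the memory lower bounds discussed in the talk.

```python
# Hypothetical toy sketch: gradient descent (O(d) state between iterations)
# vs. Newton's method (stores the d x d Hessian) on least squares
#   f(w) = (1/(2n)) * ||X w - y||^2.

def gradient(X, y, w):
    """Gradient of f: (1/n) * X^T (X w - y)."""
    n, d = len(X), len(w)
    residuals = [sum(X[i][j] * w[j] for j in range(d)) - y[i] for i in range(n)]
    return [sum(X[i][j] * residuals[i] for i in range(n)) / n for j in range(d)]

def hessian(X):
    """Hessian of f: (1/n) * X^T X, a d x d matrix (quadratic memory in d)."""
    n, d = len(X), len(X[0])
    return [[sum(X[i][a] * X[i][b] for i in range(n)) / n for b in range(d)]
            for a in range(d)]

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    d = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]  # augmented matrix [A | b]
    for col in range(d):
        pivot = max(range(col, d), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, d):
            factor = M[r][col] / M[col][col]
            for c in range(col, d + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * d
    for r in range(d - 1, -1, -1):  # back substitution
        x[r] = (M[r][d] - sum(M[r][c] * x[c] for c in range(r + 1, d))) / M[r][r]
    return x

def gradient_descent(X, y, step, iters):
    w = [0.0] * len(X[0])  # only the d-dimensional iterate persists across steps
    for _ in range(iters):
        g = gradient(X, y, w)
        w = [wj - step * gj for wj, gj in zip(w, g)]
    return w

def newton(X, y, iters):
    w = [0.0] * len(X[0])
    H = hessian(X)  # d x d matrix kept in memory
    for _ in range(iters):
        delta = solve(H, gradient(X, y, w))  # Newton step: H^{-1} * gradient
        w = [wj - dj for wj, dj in zip(w, delta)]
    return w

def loss(X, y, w):
    return sum((sum(a * b for a, b in zip(row, w)) - yi) ** 2
               for row, yi in zip(X, y)) / (2 * len(X))

# Hypothetical noiseless data generated from w_true = [2, -1].
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 2.0]]
w_true = [2.0, -1.0]
y = [sum(a * b for a, b in zip(row, w_true)) for row in X]

newton_w = newton(X, y, iters=1)
print("Newton, 1 step:", newton_w, loss(X, y, newton_w))
for k in (10, 200):
    gd_w = gradient_descent(X, y, step=0.5, iters=k)
    print(f"GD, {k} steps:", gd_w, loss(X, y, gd_w))
```

The step size 0.5 is roughly 1/L for this data (L being the largest Hessian eigenvalue); how many gradient steps are needed grows with the condition number of X^T X, which is the regime where extra memory for curvature information pays off.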

This is mostly based on joint work with Annie Marsden, Aaron Sidford, Gregory Valiant, Deqing Fu, Tian-Qi Chen and Robin Jia.