A meta-learning framework for predicting generalization

This post is a summary of the paper Neural Complexity Measures.

Say we have a neural network \(f_\theta: \mathcal{X} \rightarrow \mathcal{Y}\).
We want to predict and/or prevent overfitting, so we are often interested in measuring the complexity of the function \(f_\theta\).
**How should we measure the complexity of a neural network?**
Some common approaches are:

- number of parameters in \(\theta\)
- norm of parameter vector \(\Vert \theta \Vert_p\)
- distance to initialization \(\Vert \theta - \theta_0 \Vert_p\)
- flatness of minima \(\Vert \nabla^2 \mathcal{L} (\theta) \Vert\)
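
For concreteness, the first three of these parameter-space measures can be computed in a few lines. This is a minimal PyTorch sketch; the architecture and the `model` / `model_init` names are illustrative, not from the paper:

```python
import torch
import torch.nn as nn

def flat_params(model: nn.Module) -> torch.Tensor:
    """Concatenate all of a model's parameters into one vector."""
    return torch.cat([p.detach().flatten() for p in model.parameters()])

# Two copies of the same (hypothetical) architecture: current and init.
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
model_init = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
model_init.load_state_dict(model.state_dict())  # stand-in for theta_0

theta, theta0 = flat_params(model), flat_params(model_init)

num_params = theta.numel()                 # number of parameters
l2_norm = theta.norm(p=2)                  # ||theta||_2
dist_to_init = (theta - theta0).norm(p=2)  # ||theta - theta_0||_2
# Flatness would additionally require second-order (Hessian) information.
```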

Each of these hand-designed measures of complexity fails to capture the behavior of neural networks used in practice. A common feature, attributable to the fact that people designed them, is that they're simple equations defined in parameter space. I believe that defining a complexity measure in parameter space is the wrong approach. Parameter space is insanely messy and high-dimensional, and a simple equation like the ones above is unlikely to capture this intricacy. Another concern is calibration between models: changing the architecture or increasing the number of parameters alters the parameter space's geometry drastically, making comparisons between different models hard.

We took a different approach: (1) **we defined a complexity measure in function space**, and (2) **we learned this measure in a data-driven way**.
More concretely, for any given neural network, a meta-learned model predicts its generalization gap.

The generalization gap is a direct quantitative measure of the degree of overfitting. While most approaches attempt to find suitable proxies for this quantity, we adopt a meta-learning framework that treats the estimation of the generalization gap as a set-input regression problem.
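
In standard notation, writing \(\ell\) for the per-example loss, \(S\) for the training set, and \(\mathcal{D}\) for the data distribution, the generalization gap of \(f_\theta\) is the difference between test and training performance:

\[
\operatorname{gap}(f_\theta) = \mathbb{E}_{(x, y) \sim \mathcal{D}}\big[\ell(f_\theta(x), y)\big] - \frac{1}{|S|} \sum_{(x, y) \in S} \ell(f_\theta(x), y).
\]

NC is trained to regress this scalar directly.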

We call this meta-learned estimator a Neural Complexity (NC) measure. We train NC with the following meta-learning loop:

We continually use NC as a regularizer in new learning runs and store snapshots of these runs in a memory bank. We train NC on minibatches of snapshots sampled randomly from the memory bank.
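
The loop can be sketched in a few dozen lines. This toy version makes several simplifications (all names and hyperparameters are mine, not the paper's): tasks are random linear regressions, and NC regresses the gap from a single summary feature (the train loss) rather than from the full sets the real NC consumes:

```python
import random
import torch
import torch.nn as nn

torch.manual_seed(0)
random.seed(0)

# Toy task distribution: random linear-regression problems.
def sample_task(n=32, d=4):
    w = torch.randn(d, 1)
    noise = lambda m: 0.1 * torch.randn(m, 1)
    x_tr, x_te = torch.randn(n, d), torch.randn(n, d)
    return x_tr, x_tr @ w + noise(n), x_te, x_te @ w + noise(n)

# Simplified NC: predicts the gap from a 1-D snapshot feature.
nc = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
nc_opt = torch.optim.Adam(nc.parameters(), lr=1e-2)
memory_bank = []  # stores (feature, true gap) snapshots

for _ in range(10):  # meta-learning loop over tasks
    x_tr, y_tr, x_te, y_te = sample_task()
    learner = nn.Linear(4, 1)
    opt = torch.optim.SGD(learner.parameters(), lr=0.05)
    for _ in range(30):  # NC-regularized learning run
        opt.zero_grad()
        train_loss = ((learner(x_tr) - y_tr) ** 2).mean()
        predicted_gap = nc(train_loss.reshape(1, 1)).squeeze()
        (train_loss + predicted_gap).backward()  # NC as regularizer
        opt.step()
        with torch.no_grad():  # snapshot this step for NC training
            test_loss = ((learner(x_te) - y_te) ** 2).mean()
            memory_bank.append((train_loss.item(),
                                (test_loss - train_loss).item()))
    # Train NC on a random minibatch of stored snapshots.
    feats, gaps = zip(*random.sample(memory_bank, min(64, len(memory_bank))))
    nc_opt.zero_grad()
    pred = nc(torch.tensor(feats).reshape(-1, 1)).squeeze(-1)
    nc_loss = ((pred - torch.tensor(gaps)) ** 2).mean()
    nc_loss.backward()
    nc_opt.step()
```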

NC is a neural network that takes (data, outputs, labels) for the training data and (data, outputs) for held-out validation data, and produces a single scalar value.
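
The paper uses an attention-based set encoder for this; purely to illustrate the interface, here is a minimal DeepSets-style stand-in (layer sizes and names are made up):

```python
import torch
import torch.nn as nn

class TinyNC(nn.Module):
    """Set-input regressor: (train triples, val pairs) -> scalar.

    A DeepSets-style stand-in for the paper's attention-based NC:
    encode each set element, mean-pool each set, regress the gap.
    """
    def __init__(self, x_dim=4, y_dim=1, h=32):
        super().__init__()
        self.enc_tr = nn.Sequential(nn.Linear(x_dim + 2 * y_dim, h), nn.ReLU())
        self.enc_va = nn.Sequential(nn.Linear(x_dim + y_dim, h), nn.ReLU())
        self.head = nn.Sequential(nn.Linear(2 * h, h), nn.ReLU(), nn.Linear(h, 1))

    def forward(self, x_tr, yhat_tr, y_tr, x_va, yhat_va):
        # Train set: (data, outputs, labels); val set: (data, outputs).
        z_tr = self.enc_tr(torch.cat([x_tr, yhat_tr, y_tr], dim=-1)).mean(0)
        z_va = self.enc_va(torch.cat([x_va, yhat_va], dim=-1)).mean(0)
        return self.head(torch.cat([z_tr, z_va], dim=-1)).squeeze(-1)

nc = TinyNC()
gap_hat = nc(torch.randn(8, 4), torch.randn(8, 1), torch.randn(8, 1),
             torch.randn(5, 4), torch.randn(5, 1))  # scalar gap estimate
```

Mean-pooling makes the output invariant to the ordering of examples within each set, which is the key property any set-input architecture for this problem needs.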

In the paper, we present proof-of-concept experiments showing that an NC model can learn to predict the generalization gap in both synthetic regression and real-world image classification problems. The trained NC models show signs of generalization to out-of-distribution learner architectures. Interestingly, using NC as a regularizer resulted in lower test loss than train loss.

Neural Complexity (NC) is a meta-learning framework for predicting generalization, which has exciting implications for both meta-learning and understanding generalization in deep networks.

Meta-learning has typically been applied to small problems such as few-shot classification. Because they rely on feeding the entire dataset into the computation graph, previous meta-learning frameworks struggle to scale beyond such problems. NC interacts with the learner only through its inputs and outputs, which makes it applicable to larger-scale learning problems.

This work showed that a meta-learned model could predict the generalization gap reliably enough to be used as a regularizer. We also proposed a simple way of translating regression accuracy to a generalization bound. I think this shows the potential of meta-learning as a tool for understanding generalization in deep networks. Further improvements to the NC framework and integration with theoretical tools for understanding generalization could be a promising way forward on this problem.