A meta-learning framework for predicting generalization
This post is a summary of the paper Neural Complexity Measures.
Say we have a neural network \(f_\theta: \mathcal{X} \rightarrow \mathcal{Y}\). We want to predict and/or prevent overfitting, so we are often interested in measuring the complexity of the function \(f_\theta\). How should we measure the complexity of a neural network? Some common approaches are:

- parameter counting, or combinatorial capacity measures such as the VC dimension;
- norm-based measures, e.g. the \(\ell_2\) norm \(\|\theta\|_2\) or products of spectral norms of the weight matrices;
- sharpness-based measures, i.e. how sensitive the loss is to small perturbations of \(\theta\).
Each of these hand-designed measures of complexity fails to capture the behavior of neural networks used in practice. A common feature, attributable to the fact that people designed them, is that they’re simple equations defined in parameter space. I believe that defining a complexity measure in parameter space is the wrong approach. Parameter space is insanely messy and high-dimensional, and a simple equation like the ones above is unlikely to capture this intricacy. Another concern is calibration between models: changing the architecture or increasing the number of parameters alters the parameter space’s geometry drastically, making comparisons between different models hard.
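To make "simple equations defined in parameter space" concrete, here is a minimal sketch (assuming PyTorch) of two such measures. Note that neither looks at the data at all:

```python
import torch
import torch.nn as nn

def l2_norm(model: nn.Module) -> float:
    """Classic norm-based complexity measure: the l2 norm of all weights."""
    return torch.sqrt(sum(p.pow(2).sum() for p in model.parameters())).item()

def param_count(model: nn.Module) -> int:
    """Even simpler: complexity as the raw number of parameters."""
    return sum(p.numel() for p in model.parameters())

# Both depend only on the point theta in parameter space, not on what
# function f_theta actually computes on data.
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
print(l2_norm(model), param_count(model))
```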
We take a different approach: instead of designing a complexity measure by hand, we learn one. The generalization gap, the difference between test and training loss \(\mathcal{L}_{\text{test}}(\theta) - \mathcal{L}_{\text{train}}(\theta)\), is a direct quantitative measure of the degree of overfitting. While most approaches attempt to find suitable proxies for this quantity, we adopt a meta-learning framework that treats the estimation of the generalization gap itself as a set-input regression problem.
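Concretely, each regression example pairs a snapshot of a learner with the gap that snapshot exhibits. A minimal sketch of how that scalar target could be computed (assuming PyTorch-style data loaders; the helper names here are mine):

```python
import torch

@torch.no_grad()
def average_loss(model, loss_fn, loader) -> float:
    """Mean per-example loss of `model`, assuming `loss_fn` averages over its batch."""
    total, n = 0.0, 0
    for x, y in loader:
        total += loss_fn(model(x), y).item() * len(x)
        n += len(x)
    return total / n

@torch.no_grad()
def generalization_gap(model, loss_fn, train_loader, test_loader) -> float:
    """The scalar target NC is trained to regress: test loss minus train loss."""
    return (average_loss(model, loss_fn, test_loader)
            - average_loss(model, loss_fn, train_loader))
```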
We call this meta-learned estimator a Neural Complexity (NC) measure. We train NC with the following meta-learning loop:
We continually use NC as a regularizer for new task-learning runs and store snapshots of these runs in a memory bank. We then train NC on minibatches of snapshots sampled randomly from the memory bank.
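A self-contained toy version of this loop (assuming PyTorch; the sine-wave tasks, the snapshot features, and the mean-pooled stand-in for NC are my illustrative choices, not the paper's exact setup) might look like this:

```python
import random
import torch
import torch.nn as nn

# Toy 1-D regression tasks stand in for the paper's task distribution.
def make_task():
    a, b = torch.rand(2) * 2
    x_tr, x_te = torch.rand(32, 1), torch.rand(32, 1)
    return (x_tr, a * torch.sin(b * x_tr)), (x_te, a * torch.sin(b * x_te))

def snapshot(learner, x, y):
    """Snapshot of a learner on a task: a set of (input, output, label) rows."""
    return torch.cat([x, learner(x), y], dim=1)  # shape (n, 3)

mse = nn.MSELoss()
# Stand-in NC: encode each row independently, mean-pool to one scalar.
nc = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1))
nc_opt = torch.optim.Adam(nc.parameters(), lr=1e-3)
memory_bank, lam = [], 0.1  # lam: illustrative regularization weight

for episode in range(100):
    (x_tr, y_tr), (x_te, y_te) = make_task()
    learner = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
    opt = torch.optim.SGD(learner.parameters(), lr=1e-2)

    for step in range(50):
        # Task learner minimizes train loss + NC's predicted gap.
        train_loss = mse(learner(x_tr), y_tr)
        predicted_gap = nc(snapshot(learner, x_tr, y_tr)).mean()
        opt.zero_grad()
        (train_loss + lam * predicted_gap).backward()
        opt.step()

        # Store the snapshot with the gap actually observed on held-out data.
        with torch.no_grad():
            gap = mse(learner(x_te), y_te) - mse(learner(x_tr), y_tr)
        memory_bank.append((snapshot(learner, x_tr, y_tr).detach(), gap))

    # Train NC: regress predictions onto observed gaps over a random
    # minibatch of snapshots from the memory bank.
    feats, gaps = zip(*random.sample(memory_bank, min(64, len(memory_bank))))
    preds = torch.stack([nc(f).mean() for f in feats])
    nc_loss = mse(preds, torch.stack(gaps))
    nc_opt.zero_grad()
    nc_loss.backward()
    nc_opt.step()
```

The key property is that NC only ever sees randomly sampled snapshots, never an entire training trajectory or dataset at once.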
NC is a set-input neural network: it consumes (data, outputs, labels) tuples from the training set and (data, outputs) pairs from held-out validation data, and produces a single scalar value, the predicted generalization gap.
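The toy loop above collapsed NC to a pooled MLP over training tuples only. Here is a slightly fuller sketch of the set-input interface just described, again with mean pooling standing in for the paper's attention-based encoder:

```python
import torch
import torch.nn as nn

class NeuralComplexity(nn.Module):
    """Set-input regressor: (train tuples, val pairs) -> predicted gap."""
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.train_enc = nn.Sequential(nn.Linear(3, hidden), nn.ReLU())  # (x, f(x), y)
        self.val_enc = nn.Sequential(nn.Linear(2, hidden), nn.ReLU())    # (x, f(x))
        self.head = nn.Sequential(nn.Linear(2 * hidden, hidden),
                                  nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, train_set, val_set):
        # Mean pooling makes the output invariant to example ordering,
        # which is what makes this a *set*-input model.
        h_train = self.train_enc(train_set).mean(dim=0)
        h_val = self.val_enc(val_set).mean(dim=0)
        return self.head(torch.cat([h_train, h_val]))  # one scalar

nc = NeuralComplexity()
train_set = torch.randn(32, 3)  # rows: (x, f(x), y) for scalar x and y
val_set = torch.randn(16, 2)    # rows: (x, f(x)), no labels needed
print(nc(train_set, val_set))   # predicted generalization gap
```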
In the paper, we present proof-of-concept experiments showing that an NC model can learn to predict the generalization gap in both synthetic regression and real-world image classification problems. The trained NC models show signs of generalizing to out-of-distribution learner architectures. Interestingly, using NC as a regularizer resulted in lower test loss than train loss.
Neural Complexity (NC) is a meta-learning framework for predicting generalization, which has exciting implications for both meta-learning and understanding generalization in deep networks.
Meta-learning is typically applied to small problems such as few-shot classification. Because they rely on feeding the entire dataset into the computation graph, previous meta-learning frameworks struggle to scale beyond such small tasks. NC sidesteps this bottleneck: it interacts with each task only through randomly sampled minibatches of snapshots.
This work showed that a meta-learned model can predict the generalization gap reliably enough to be used as a regularizer. We also proposed a simple way of translating regression accuracy into a generalization bound (sketched below). I think this shows the potential of meta-learning as a tool for understanding generalization in deep networks. Further improvements to the NC framework, and integration with theoretical tools for understanding generalization, could be a promising way forward on this problem.
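The translation is, at heart, a quantile argument. In my paraphrase (not the paper's exact statement): if NC's absolute regression error is at most \(\epsilon\) with probability at least \(1 - \delta\) over tasks, then with the same probability the test loss is bounded by the train loss plus NC's prediction plus \(\epsilon\):

\[
\Pr\Big[\big|\big(\mathcal{L}_{\text{test}}(\theta) - \mathcal{L}_{\text{train}}(\theta)\big) - \mathrm{NC}(\theta)\big| \le \epsilon\Big] \ge 1 - \delta
\;\Longrightarrow\;
\Pr\Big[\mathcal{L}_{\text{test}}(\theta) \le \mathcal{L}_{\text{train}}(\theta) + \mathrm{NC}(\theta) + \epsilon\Big] \ge 1 - \delta.
\]

The implication holds simply because \(|a| \le \epsilon\) implies \(a \le \epsilon\); the nontrivial empirical question is how small \(\epsilon\) can be made for a given \(\delta\).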