- Project and Probe: Sample-Efficient Domain Adaptation by Interpolating Orthogonal Features
Annie S. Chen*, Yoonho Lee*, Amrith Setlur, Sergey Levine, Chelsea Finn
preprint [abstract] [paper]
TLDR: We propose Project and Probe, a lightweight, sample-efficient approach that learns a diverse set of predictive features and adapts to a target distribution by interpolating among them with a small target dataset.
Conventional approaches to robustness try to learn a model based on causal features. However, identifying maximally robust or causal features may be difficult in some scenarios, and in others, non-causal "shortcut" features may actually be more predictive. We propose a lightweight, sample-efficient approach that learns a diverse set of features and adapts to a target distribution by interpolating these features with a small target dataset. Our approach, Project and Probe (Pro2), first learns a linear projection that maps a pre-trained embedding onto orthogonal directions while being predictive of labels in the source dataset. The goal of this step is to learn a variety of predictive features, so that at least some of them remain useful after distribution shift. Pro2 then learns a linear classifier on top of these projected features using a small target dataset. We theoretically show that Pro2 learns a projection matrix that is optimal for classification in an information-theoretic sense, resulting in better generalization due to a favorable bias-variance tradeoff. Our experiments on four datasets, with multiple distribution shift settings for each, show that Pro2 improves performance by 5-15% when given limited target data compared to prior methods such as standard linear probing.
- DetectGPT: Zero-Shot Machine-Generated Text Detection using Probability Curvature
Eric Mitchell, Yoonho Lee, Alexander Khazatsky, Christopher D Manning, Chelsea Finn
preprint [abstract] [paper] [website]
TLDR: We develop a method that can detect if a passage is generated by a particular language model. Our method is based on the hypothesis that a passage is likely model-generated if it is near a local maximum in the model’s predictive probability space.
The fluency and factual knowledge of large language models (LLMs) heighten the need for corresponding systems to detect whether a piece of text is machine-written. For example, students may use LLMs to complete written assignments, leaving instructors unable to accurately assess student learning. In this paper, we first demonstrate that text sampled from an LLM tends to occupy negative curvature regions of the model’s log probability function. Leveraging this observation, we then define a new curvature-based criterion for judging if a passage is generated from a given LLM. This approach, which we call DetectGPT, does not require training a separate classifier, collecting a dataset of real or generated passages, or explicitly watermarking generated text. It uses only log probabilities computed by the model of interest and random perturbations of the passage from another generic pre-trained language model (e.g., T5). We find DetectGPT is more discriminative than existing zero-shot methods for model sample detection, notably improving detection of fake news articles generated by 20B parameter GPT-NeoX from 0.81 AUROC for the strongest zero-shot baseline to 0.95 AUROC for DetectGPT.
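One way to read the curvature criterion above is as a gap between the log probability of a passage and the average log probability of its perturbed versions. The sketch below is only a toy illustration of that idea, not the released implementation: `perturbation_discrepancy`, `peaked`, and `flat` are hypothetical names, and a number stands in for a passage so that the example runs without a language model.

```python
from statistics import mean
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def perturbation_discrepancy(
    log_prob: Callable[[T], float],
    perturbations: Sequence[T],
    passage: T,
) -> float:
    """Log probability of the passage minus the mean over its perturbed versions."""
    return log_prob(passage) - mean(log_prob(p) for p in perturbations)


# Toy stand-ins (assumptions): a scalar plays the role of a passage.
def peaked(x: float) -> float:
    """Log probability with a local maximum at x = 0 (negative curvature)."""
    return -(x ** 2)


def flat(x: float) -> float:
    """Log probability with no curvature."""
    return x


at_peak = perturbation_discrepancy(peaked, [-1.0, 1.0], 0.0)
on_slope = perturbation_discrepancy(flat, [4.0, 6.0], 5.0)
```

In this toy setting the gap is positive at the local maximum and zero where the function has no curvature. With a real model, `log_prob` would be the scoring LLM and `perturbations` would come from a mask-filling model such as T5.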
- Surgical Fine-Tuning Improves Adaptation to Distribution Shifts
Yoonho Lee*, Annie S. Chen*, Fahim Tajwar, Ananya Kumar, Huaxiu Yao, Percy Liang, Chelsea Finn
ICLR 2023
NeurIPS 2022 Workshop on Distribution Shifts
NeurIPS 2022 I Can’t Believe It’s Not Better Workshop [abstract] [paper]
TLDR: The best layer to fine-tune reflects the nature of the distribution shift.
A common approach to transfer learning under distribution shift is to fine-tune the last few layers of a pre-trained model, preserving learned features while also adapting to the new task. This paper shows that in such settings, selectively fine-tuning a subset of layers (which we term surgical fine-tuning) matches or outperforms commonly used fine-tuning approaches. Moreover, the type of distribution shift influences which subset is more effective to tune: for example, for image corruptions, fine-tuning only the first few layers works best. We validate our findings systematically across seven real-world data tasks spanning three types of distribution shifts. Theoretically, we prove that for two-layer neural networks in an idealized setting, first-layer tuning can outperform fine-tuning all layers. Intuitively, fine-tuning more parameters on a small target dataset can cause information learned during pre-training to be forgotten, and the relevant information depends on the type of shift.
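As a rough illustration of tuning only a chosen subset of layers, the sketch below applies one SGD step to the named layers and leaves the rest frozen. It is a plain-Python toy under stated assumptions: `surgical_sgd_step`, the layer names, and the gradient values are hypothetical, and a real setup would do the same thing by restricting which parameters the optimizer updates.

```python
from typing import Dict, List, Set

Params = Dict[str, List[float]]


def surgical_sgd_step(params: Params, grads: Params, tuned: Set[str], lr: float) -> Params:
    """Apply one SGD step to layers named in `tuned`; copy all other layers unchanged."""
    return {
        name: (
            [w - lr * g for w, g in zip(weights, grads[name])]
            if name in tuned
            else list(weights)
        )
        for name, weights in params.items()
    }


# Hypothetical three-layer model with made-up weights and gradients.
params = {"block1": [1.0, 2.0], "block2": [3.0], "head": [4.0]}
grads = {"block1": [0.5, -0.5], "block2": [1.0], "head": [2.0]}

# Tune only the first block, as the abstract suggests for image corruptions.
updated = surgical_sgd_step(params, grads, tuned={"block1"}, lr=0.1)
```

Changing `tuned` to `{"head"}` would give last-layer tuning on the same toy model, which makes it easy to compare choices of subset.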
- Diversify and Disambiguate: Out-of-Distribution Robustness via Disagreement
Yoonho Lee, Huaxiu Yao, Chelsea Finn
ICLR 2023
ICML 2022 Principles of Distribution Shift Workshop
ICML 2022 Spurious Correlations, Invariance, and Stability Workshop [abstract] [paper] [website] [code]
TLDR: Given underspecified data, (1) find a diverse set of solutions and (2) choose the best one.
Many datasets are underspecified: there exist multiple equally viable solutions to a given task. Underspecification can be problematic for methods that learn a single hypothesis because different functions that achieve low training loss can focus on different predictive features and thus produce widely varying predictions on out-of-distribution data. We propose DivDis, a simple two-stage framework that first learns a diverse collection of hypotheses for a task by leveraging unlabeled data from the test distribution. We then disambiguate by selecting one of the discovered hypotheses using minimal additional supervision, in the form of additional labels or inspection of function visualization. We demonstrate the ability of DivDis to find hypotheses that use robust features in image classification and natural language processing problems with underspecification.
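The disambiguation stage can be pictured as choosing, among already-learned hypotheses, the one that agrees best with a handful of labeled target examples. The sketch below covers only that second stage on toy data. The first stage (training hypotheses to disagree on unlabeled target data) is omitted, and `disambiguate`, the two hand-written hypotheses, and the examples are all hypothetical.

```python
from typing import Callable, Sequence, Tuple

Example = Tuple[int, int]  # toy input: two binary features
Hypothesis = Callable[[Example], int]


def disambiguate(
    hypotheses: Sequence[Hypothesis],
    labeled: Sequence[Tuple[Example, int]],
) -> int:
    """Return the index of the hypothesis with the highest accuracy on `labeled`."""

    def accuracy(h: Hypothesis) -> float:
        return sum(h(x) == y for x, y in labeled) / len(labeled)

    return max(range(len(hypotheses)), key=lambda i: accuracy(hypotheses[i]))


# Two hypotheses that would agree on source data where both features match,
# but disagree on target inputs where the features differ.
hypotheses = [lambda x: x[0], lambda x: x[1]]

# A few labeled target examples in which only feature 0 predicts the label.
target_labeled = [((1, 0), 1), ((0, 1), 0)]
chosen = disambiguate(hypotheses, target_labeled)
```

Two labels are enough here because the hypotheses were constructed to disagree on exactly these inputs, which is the point of the diversification stage.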
- Relaxing the Kolmogorov Structure Function for Realistic Computational Constraints
Yoonho Lee, Chelsea Finn, Stefano Ermon
NeurIPS 2022 Workshop on Information-Theoretic Principles in Cognitive Systems [abstract] [paper]
TLDR: An efficient relaxation of the Kolmogorov Structure Function that can leverage neural networks.
The degree to which a task is learnable given different computational constraints shows the amount of usable information at different scales. An instantiation of this idea is the Kolmogorov Structure Function (KSF), which shows how the fit of an optimal k-bit description of a given string improves for increasing values of k. While conceptually appealing, computing the KSF is infeasible in practice due to the exponentially large search space of all descriptions of a given length, in addition to the unbounded time complexity. This paper proposes the Constrained Structure Function (CSF), a generalization of the KSF that can be computed efficiently by taking into account realistic computational constraints. In addition to being feasible to compute, the CSF of a dataset can be expressed as the sum of datapoint-wise functions which reflect the degree to which each datapoint is typical in the context of the dataset. Empirically, we demonstrate that the CSF can be used for detecting individual datapoints with characteristics such as being easy, mislabeled, or belonging to a hidden subgroup.
- Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time
Huaxiu Yao*, Caroline Choi*, Bochuan Cao, Yoonho Lee, Pang Wei Koh, Chelsea Finn
NeurIPS 2022 Datasets & Benchmarks Track
ICML 2022 Shift Happens Workshop [abstract] [paper] [code]
TLDR: A benchmark of distribution shifts over time.
Distribution shifts occur when the test distribution differs from the training distribution, and can considerably degrade performance of machine learning models deployed in the real world. While recent works have studied robustness to distribution shifts, distribution shifts arising from the passage of time have the additional structure of timestamp metadata. Real-world examples of such shifts are underexplored, and it is unclear whether existing models can leverage trends in past distribution shifts to reliably extrapolate into the future. To address this gap, we curate Wild-Time, a benchmark of 7 datasets that reflect temporal distribution shifts arising in a variety of real-world applications. On these datasets, we systematically benchmark 9 approaches with various inductive biases. Our experiments demonstrate that existing methods are limited in tackling temporal distribution shift: across all settings, we observe an average performance drop of 21% from in-distribution to out-of-distribution data.
- On Divergence Measures for Bayesian Pseudocoresets
Balhae Kim, Jungwon Choi, Seanie Lee, Yoonho Lee, Jung-Woo Ha, Juho Lee
NeurIPS 2022 [abstract] [paper] [code]
TLDR: An exploration of the choice of divergence for learning a Bayesian pseudocoreset.
A Bayesian pseudocoreset is a small synthetic dataset for which the posterior over parameters approximates that of the original dataset. While promising, the scalability of Bayesian pseudocoresets is not yet validated in large-scale problems such as image classification with deep neural networks. On the other hand, dataset distillation methods similarly construct a small dataset such that optimization with the synthetic dataset converges to a solution similar to optimization with full data. Although dataset distillation has been empirically verified in large-scale settings, the framework is restricted to point estimates, and its adaptation to Bayesian inference has not been explored. This paper casts two representative dataset distillation algorithms as approximations to methods for constructing pseudocoresets by minimizing specific divergence measures: reverse KL divergence and Wasserstein distance. Furthermore, we provide a unifying view of such divergence measures in Bayesian pseudocoreset construction. Finally, we propose a novel Bayesian pseudocoreset algorithm based on minimizing forward KL divergence. Our empirical results demonstrate that the pseudocoresets constructed from these methods reflect the true posterior even in large-scale Bayesian inference problems.
- Discrete Infomax Codes for Supervised Representation Learning
Yoonho Lee, Wonjae Kim, Wonpyo Park, Seungjin Choi
Entropy Special Issue "Theory and Applications of Information Processing Algorithms" [abstract] [paper]
TLDR: Regularizing few-shot classification using compact discrete codes.
Learning compact discrete representations of data is a key task on its own or for facilitating subsequent processing of data. In this paper we present a model that produces Discrete InfoMax COdes (DIMCO); we learn a probabilistic encoder that yields k-way d-dimensional codes associated with input data. Our model’s learning objective is to maximize the mutual information between codes and labels with a regularization term that enforces entries of a codeword to be as independent as possible. We show that the infomax principle also justifies previous loss functions (e.g., cross-entropy) as its special cases. Our analysis also shows that using shorter codes, as DIMCO does, reduces overfitting in the context of few-shot classification. Through experiments in various domains, we observe this implicit meta-regularization effect of DIMCO. Furthermore, we show that the codes learned by DIMCO are efficient in terms of both memory and retrieval time compared to previous methods.
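To make the memory claim concrete: a k-way d-dimensional code has d entries that each take one of k values, so it can be stored in d·⌈log2 k⌉ bits. The short calculation below is only an illustration. The values k = 16 and d = 4 and the 64-dimensional float32 embedding used for comparison are assumptions chosen for the example, not figures from the paper.

```python
import math


def code_bits(k: int, d: int) -> int:
    """Bits needed to store one k-way, d-dimensional code, entry by entry."""
    return d * math.ceil(math.log2(k))


def num_codewords(k: int, d: int) -> int:
    """Number of distinct codewords a k-way, d-dimensional code can express."""
    return k ** d


bits = code_bits(k=16, d=4)            # discrete code
codewords = num_codewords(k=16, d=4)   # distinct values it can take
float_bits = 64 * 32                   # hypothetical 64-dim float32 embedding
```

Under these assumed sizes the discrete code is over a hundred times smaller than the continuous embedding, while still distinguishing tens of thousands of codewords.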
- Diversity Matters When Learning From Ensembles
Giung Nam*, Jongmin Yoon*, Yoonho Lee, Juho Lee
NeurIPS 2021 [abstract] [paper] [code]
TLDR: To distill from deep ensembles, use inputs that ensemble members disagree on.
Deep ensembles excel in large-scale image classification tasks both in terms of prediction accuracy and calibration. Despite being simple to train, the computation and memory cost of deep ensembles limits their practicability. While some recent works propose to distill an ensemble model into a single model to reduce such costs, there is still a performance gap between the ensemble and distilled models. We propose a simple approach for reducing this gap, i.e., making the distilled performance close to the full ensemble. Our key assumption is that a distilled model should absorb as much function diversity inside the ensemble as possible. We first empirically show that the typical distillation procedure does not effectively transfer such diversity, especially for complex models that achieve near-zero training error. To fix this, we propose an augmentation-based distillation strategy that reveals diversity by seeking inputs for which ensemble member outputs disagree. We empirically show that a model distilled with such augmented samples indeed exhibits enhanced diversity, leading to improved performance.
- Amortized Probabilistic Detection of Communities in Graphs
Yueqi Wang*, Yoonho Lee*, Pallab Basu, Juho Lee, Yee Whye Teh, Liam Paninski, Ari Pakman
[abstract] [paper] [code]
TLDR: An attention-based method for probabilistically detecting communities within graphs.
Learning community structures in graphs has broad applications across scientific domains. While graph neural networks (GNNs) have been successful in encoding graph structures, existing GNN-based methods for community detection are limited by requiring knowledge of the number of communities in advance, in addition to lacking a proper probabilistic formulation to handle uncertainty. We propose a simple framework for amortized community detection, which addresses both of these issues by combining the expressive power of GNNs with recent methods for amortized clustering. Our models consist of a graph representation backbone that extracts structural information and an amortized clustering network that naturally handles variable numbers of clusters. Both components combine into well-defined models of the posterior distribution of graph communities and are jointly optimized given labeled graphs. At inference time, the models yield parallel samples from the posterior of community labels, quantifying uncertainty in a principled way. We evaluate several models from our framework on synthetic and real datasets and demonstrate superior performance to previous methods. As a separate contribution, we extend recent amortized probabilistic clustering architectures by adding attention modules, which yield further improvements on community detection tasks.
- On the Distribution of Penultimate Activations of Classification Networks
Minkyo Seo*, Yoonho Lee*, Suha Kwak
UAI 2021 [abstract] [paper]
TLDR: Final FC layer weights contain information about class relations.
This paper studies the probability distributions of penultimate activations of classification networks. Specifically, we show that, when a classification network is trained with the cross-entropy loss, its final classification layer forms a Generative-Discriminative pair with a generative classifier based on a specific distribution of penultimate activations. More importantly, the distribution is parameterized by the weights of the final fully-connected layer, and can be considered as a generative model that synthesizes the penultimate activations without feeding input data. We empirically demonstrate that this generative model enables stable knowledge distillation in the presence of domain shift, and can also transfer knowledge from a classifier to variational autoencoders and generative adversarial networks for class-conditional image generation.
- Bootstrapping Neural Processes
Juho Lee*, Yoonho Lee*, Jungtaek Kim, Eunho Yang, Sung Ju Hwang, Yee Whye Teh
NeurIPS 2020 [abstract] [paper] [video] [code]
TLDR: Improved uncertainty estimates in Neural Processes using bootstrapping.
Unlike traditional statistical modeling, in which a user typically hand-specifies a prior, Neural Processes (NPs) implicitly define a broad class of stochastic processes with neural networks. Given a data stream, an NP learns a stochastic process that best describes the data. While this "data-driven" way of learning stochastic processes has proven to handle various types of data, NPs still rely on an assumption that uncertainty in stochastic processes is modeled by a single latent variable, which potentially limits flexibility. To this end, we propose the Bootstrapping Neural Process (BNP), a novel extension of the NP family using the bootstrap. The bootstrap is a classical data-driven technique for estimating uncertainty, which allows BNP to learn the stochasticity in NPs without assuming a particular form. We demonstrate the efficacy of BNP on various types of data and its robustness in the presence of model-data mismatch.
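For readers unfamiliar with the classical bootstrap mentioned above, the sketch below estimates the uncertainty of a sample mean by resampling the data with replacement. It illustrates only the generic statistical technique, not BNP itself. The function name `bootstrap_means`, the data values, and the seed are hypothetical.

```python
import random
from statistics import mean, stdev
from typing import List, Sequence


def bootstrap_means(data: Sequence[float], n_resamples: int, rng: random.Random) -> List[float]:
    """Means of datasets drawn from `data` with replacement, each of the original size."""
    n = len(data)
    return [mean(rng.choices(data, k=n)) for _ in range(n_resamples)]


# Made-up observations; a fixed seed keeps the run reproducible.
data = [2.1, 2.5, 1.9, 3.2, 2.8, 2.4]
means = bootstrap_means(data, n_resamples=1000, rng=random.Random(0))

# Spread of the resampled means: a bootstrap estimate of the standard error.
spread = stdev(means)
```

No distributional form is assumed anywhere in this procedure. The uncertainty estimate comes entirely from the observed data.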
- Neural Complexity Measures
Yoonho Lee, Juho Lee, Sung Ju Hwang, Eunho Yang, Seungjin Choi
NeurIPS 2020 [abstract] [paper] [blog] [video] [code]
TLDR: A meta-learning framework for predicting generalization.
While various complexity measures for deep neural networks exist, specifying an appropriate measure capable of predicting and explaining generalization in deep networks has proven challenging. We propose Neural Complexity (NC), a meta-learning framework for predicting generalization. Our model learns a scalar complexity measure through interactions with many heterogeneous tasks in a data-driven way. The trained NC model can be added to the standard training loss to regularize any task learner in a standard supervised learning scenario. We contrast NC’s approach against existing manually-designed complexity measures and other meta-learning models, and we validate NC’s performance on multiple regression and classification tasks.
- Deep Amortized Clustering
Juho Lee, Yoonho Lee, Yee Whye Teh
NeurIPS 2019 Sets and Parts Workshop (oral) [abstract] [paper]
TLDR: Learning to cluster by identifying one cluster at a time.
We propose Deep Amortized Clustering (DAC), a framework in which a neural network learns to cluster datasets efficiently using a few forward passes through a deep neural network. DAC implicitly learns what makes a cluster, how to group data points into clusters, and how to count the number of clusters in datasets. DAC is meta-learned in a data-driven way, using only clustered datasets and their partitions. This framework differs from traditional clustering algorithms, which usually require user-specified prior knowledge about the shape or structure of clusters. We empirically show on both synthetic and image data that DAC can efficiently and accurately cluster novel datasets.
- Learning Dynamics of Attention: Human Prior for Interpretable Machine Reasoning
Wonjae Kim, Yoonho Lee
NeurIPS 2019 [abstract] [paper] [code]
TLDR: Smooth and interpretable attention using Neural ODEs.
Without relevant human priors, neural networks may learn uninterpretable features. We propose Dynamics of Attention for Focus Transition (DAFT) as a human prior for machine reasoning. DAFT is a novel method that regularizes attention-based reasoning by modelling it as a continuous dynamical system using neural ordinary differential equations. As a proof of concept, we augment a state-of-the-art visual reasoning model with DAFT. Our experiments reveal that applying DAFT yields similar performance to the original model while using fewer reasoning steps, showing that it implicitly learns to skip unnecessary steps. We also propose a new metric, Total Length of Transition (TLT), which represents the effective reasoning step size by quantifying how much a given model’s focus drifts while reasoning about a question. We show that adding DAFT results in lower TLT, demonstrating that our method indeed obeys the human prior towards shorter reasoning paths in addition to producing more interpretable attention maps.
- Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks
Juho Lee, Yoonho Lee, Jungtaek Kim, Adam Kosiorek, Seungjin Choi, Yee Whye Teh
ICML 2019 [abstract] [paper] [code]
TLDR: Self-attention for sets using inducing points. O(N) feedforward complexity.
Many machine learning tasks such as multiple instance learning, 3D shape recognition, and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the order of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from sparse Gaussian process literature. It reduces the computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating state-of-the-art performance compared to recent methods for set-structured data.
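The inducing-point idea can be sketched as two attention passes: M inducing points first attend to the N set elements, then the N elements attend to those M summaries, giving O(N·M) score computations instead of O(N²). The plain-Python toy below is a simplified illustration, not the published module. It omits learned projections, multiple heads, residual connections, and feed-forward layers, and `attend`, `induced_attention`, and the fixed inducing points are hypothetical (in the actual model the inducing points are learned).

```python
import math
from typing import List

Vector = List[float]


def attend(queries: List[Vector], keys: List[Vector]) -> List[Vector]:
    """Scaled dot-product attention in which the keys also serve as values."""
    dim = len(keys[0])
    scale = math.sqrt(dim)
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * k[j] for w, k in zip(weights, keys)) / total for j in range(dim)])
    return out


def induced_attention(x: List[Vector], inducing: List[Vector]) -> List[Vector]:
    """Route attention through M inducing points instead of all N x N pairs."""
    summary = attend(inducing, x)  # M queries over N elements
    return attend(x, summary)      # N queries over M summaries


# Toy set of N = 3 elements and M = 2 fixed inducing points (made-up values).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
inducing = [[0.5, -0.5], [-1.0, 2.0]]
out = induced_attention(x, inducing)
```

Reordering the elements of `x` reorders the rows of `out` in the same way, so the block treats its input as a set rather than a sequence.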
- Gradient-based Meta-learning with Learned Layerwise Metric and Subspace
Yoonho Lee, Seungjin Choi
ICML 2018 [abstract] [paper] [video] [code]
TLDR: Improving MAML by fixing some weights during task adaptation.
Gradient-based meta-learning methods leverage gradient descent to learn the commonalities among various tasks. While previous such methods have been successful in meta-learning tasks, they resort to simple gradient descent during meta-testing. Our primary contribution is the MT-net, which enables the meta-learner to learn on each layer’s activation space a subspace that the task-specific learner performs gradient descent on. Additionally, a task-specific learner of an MT-net performs gradient descent with respect to a meta-learned distance metric, which warps the activation space to be more sensitive to task identity. We demonstrate that the dimension of this learned subspace reflects the complexity of the task-specific learner’s adaptation task, and also that our model is less sensitive to the choice of initial learning rates than previous gradient-based meta-learning methods. Our method achieves state-of-the-art or comparable performance on few-shot classification and regression tasks.