Last year, Meta’s Chief AI Scientist Yann LeCun proposed a new architecture intended to overcome key limitations of even the most advanced AI systems today. His vision is to create machines that can learn internal models of how the world works so that they can learn much more quickly, plan how to accomplish complex tasks, and readily adapt to unfamiliar situations.
We’re excited to introduce the first AI model based on a key component of LeCun’s vision. This model, the Image Joint Embedding Predictive Architecture (I-JEPA), learns by creating an internal model of the outside world, which compares abstract representations of images (rather than comparing the pixels themselves). I-JEPA delivers strong performance on multiple computer vision tasks, and it’s much more computationally efficient than other widely used computer vision models. The representations learned by I-JEPA can also be used for many different applications without needing extensive fine-tuning. For example, we train a 632M-parameter vision transformer model using 16 A100 GPUs in under 72 hours, and it achieves state-of-the-art performance for low-shot classification on ImageNet, with only 12 labeled examples per class. Other methods typically take two to ten times more GPU-hours and achieve worse error rates when trained with the same amount of data.
Our paper on I-JEPA will be presented at CVPR 2023 next week, and we’re also open-sourcing the training code and model checkpoints today.
Capturing common-sense knowledge with self-supervised learning
Our work on I-JEPA (and Joint Embedding Predictive Architecture (JEPA) models more generally) is grounded in the fact that humans learn an enormous amount of background knowledge about the world just by passively observing it. It has been hypothesized that this common-sense information is key to enabling intelligent behavior such as sample-efficient acquisition of new concepts, grounding, and planning.
AI researchers have tried to devise learning algorithms that capture common-sense background knowledge about the world and then encode it into a digital representation the algorithm can access later. To be effective, the system must learn these representations in a self-supervised manner – that is to say, directly from unlabeled data such as images or sounds, rather than from manually assembled labeled datasets.
At a high level, the JEPA aims to predict the representation of part of an input (such as an image or piece of text) from the representation of other parts of the same input. Because it does not involve collapsing representations from multiple views/augmentations of an image to a single point, the hope is for the JEPA to avoid the biases and issues associated with another widely used method called invariance-based pretraining.
At the same time, by predicting representations at a high level of abstraction rather than predicting pixel values directly, the hope is to learn directly useful representations that also avoid the limitations of generative approaches, which underlie the large language models that have generated so much recent excitement.
In contrast, generative architectures learn by removing or distorting portions of the input to the model – for example, erasing part of a photo or hiding some of the words in a text passage. They then try to predict the corrupted or missing pixels or words. One significant shortcoming of generative methods, however, is that the model tries to fill in every bit of missing information, even though the world is inherently unpredictable. As a result, generative methods may be prone to mistakes a person would never make because they focus too much on irrelevant details instead of capturing high-level predictable concepts. For example, it is notoriously difficult for generative models to generate human hands accurately. (They often add extra digits or make other glaring errors.)
Common architectures for self-supervised learning, in which the system learns to capture the relationships between its inputs. The objective is to assign a high energy to incompatible inputs, and to assign a low energy to compatible inputs. (a) Joint-Embedding (invariant) Architectures learn to output similar embeddings for compatible inputs x, y and dissimilar embeddings for incompatible inputs. (b) Generative Architectures learn to directly reconstruct a signal y from a compatible signal x, using a decoder network that is conditioned on additional (possibly latent) variables z to facilitate reconstruction. (c) Joint-Embedding Predictive Architectures learn to predict the embeddings of a signal y from a compatible signal x, using a predictor network that is conditioned on additional (possibly latent) variables z to facilitate prediction.
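To make the three energy formulations in the caption concrete, here is a minimal sketch in plain Python. It assumes the energy is a squared Euclidean distance, which is an illustrative choice and not something stated above; the function names and the toy `identity` and `add_z` stand-ins for the encoder, decoder, and predictor networks are hypothetical.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((u - v) ** 2 for u, v in zip(a, b))

def energy_joint_embedding(enc_x, enc_y, x, y):
    # (a) Compare the embedding of x with the embedding of y.
    return sq_dist(enc_x(x), enc_y(y))

def energy_generative(enc_x, dec, x, y, z):
    # (b) Reconstruct y itself from x (plus z) and compare in input space.
    return sq_dist(dec(enc_x(x), z), y)

def energy_jepa(enc_x, enc_y, pred, x, y, z):
    # (c) Predict the embedding of y from x (plus z) and compare in embedding space.
    return sq_dist(pred(enc_x(x), z), enc_y(y))

# Toy stand-ins for the networks (assumptions for illustration only).
identity = lambda v: list(v)
add_z = lambda s, z: [a + b for a, b in zip(s, z)]

x, y, z = [1.0, 2.0], [1.5, 2.5], [0.5, 0.5]
e_a = energy_joint_embedding(identity, identity, x, y)
e_b = energy_generative(identity, add_z, x, y, z)
e_c = energy_jepa(identity, identity, add_z, x, y, z)
print(e_a, e_b, e_c)  # 0.5 0.0 0.0
```

In this toy setup, the extra variable z lets the generative and predictive variants explain the difference between x and y, so their energy drops to zero while the plain joint-embedding energy does not. The only structural difference between (b) and (c) is the space in which the comparison happens.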
A first step toward a broadly capable joint-embedding predictive architecture
The idea behind I-JEPA is to predict missing information in an abstract representation that’s more akin to the general understanding people have. Compared to generative methods that predict in pixel/token space, I-JEPA uses abstract prediction targets for which unnecessary pixel-level details are potentially eliminated, thereby leading the model to learn more semantic features. Another core design choice to guide I-JEPA towards producing semantic representations is the proposed multi-block masking strategy. Specifically, we demonstrate the importance of predicting large blocks containing semantic information (with sufficiently large scale), using an informative (spatially distributed) context.
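One way to picture the multi-block masking strategy is as sampling on a grid of image patches. The sketch below is a simplified illustration rather than the released implementation: the grid size, the number of target blocks, and the scale and aspect-ratio ranges are all assumed values chosen for the example, and `sample_block` and `multi_block_mask` are hypothetical helper names.

```python
import math
import random

def sample_block(grid, scale_range, aspect_range, rng):
    """Return the set of (row, col) patch indices of one random rectangular block."""
    scale = rng.uniform(*scale_range)      # fraction of the grid area to cover
    aspect = rng.uniform(*aspect_range)    # height / width of the block
    area = scale * grid * grid
    h = max(1, min(grid, round(math.sqrt(area * aspect))))
    w = max(1, min(grid, round(math.sqrt(area / aspect))))
    top = rng.randint(0, grid - h)         # randint is inclusive at both ends
    left = rng.randint(0, grid - w)
    return {(r, c) for r in range(top, top + h) for c in range(left, left + w)}

def multi_block_mask(grid=14, num_targets=4, seed=0):
    """Sample several large target blocks and one spatially spread context block."""
    rng = random.Random(seed)
    # Assumed ranges: each target covers 15-20% of the image.
    targets = [sample_block(grid, (0.15, 0.2), (0.75, 1.5), rng)
               for _ in range(num_targets)]
    # Assumed range: the context starts as a large square covering 85-100%.
    context = sample_block(grid, (0.85, 1.0), (1.0, 1.0), rng)
    # Remove target patches so the context never sees what it must predict.
    for block in targets:
        context -= block
    return context, targets

context, targets = multi_block_mask()
print(len(context), [len(t) for t in targets])
```

Removing the target patches from the context is what makes the prediction task non-trivial: the context encoder only sees the remaining patches, while the targets stay large enough to contain semantic content.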
The Image-based Joint-Embedding Predictive Architecture (I-JEPA) uses a single context block to predict the representations of various target blocks originating from the same image. The context encoder is a Vision Transformer (ViT) that only processes the visible context patches. The predictor is a narrow ViT that takes the context encoder output and predicts the representations of a target block at a specific location, conditioned on positional tokens of the target (shown in color). The target representations correspond to the outputs of the target-encoder, the weights of which are updated at each iteration via an exponential moving average of the context encoder weights.
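The exponential moving average update of the target encoder, together with a representation-space prediction loss, can be sketched in a few lines. This is a toy illustration on flat lists of numbers, not the training code: the default momentum value and the use of a mean squared distance as the loss are assumptions made for the example, and `ema_update` and `prediction_loss` are hypothetical names.

```python
def ema_update(target_weights, context_weights, momentum=0.996):
    """Move each target-encoder weight a small step toward the context-encoder weight."""
    return [momentum * t + (1.0 - momentum) * c
            for t, c in zip(target_weights, context_weights)]

def prediction_loss(predicted, target):
    """Mean over patches of the squared distance between predicted and target vectors."""
    total = 0.0
    for p_vec, t_vec in zip(predicted, target):
        total += sum((p - t) ** 2 for p, t in zip(p_vec, t_vec))
    return total / len(predicted)

# Toy run: the target weights drift toward the context weights over iterations.
context_w = [1.0, 2.0]
target_w = [0.0, 0.0]
for _ in range(3):
    target_w = ema_update(target_w, context_w, momentum=0.5)
print(target_w)  # [0.875, 1.75]

# One predicted patch vector that is off by 2.0 in a single dimension.
loss = prediction_loss([[1.0, 2.0]], [[1.0, 4.0]])
print(loss)  # 4.0
```

Because the target encoder is updated only through this moving average and not by gradient descent, it changes slowly and gives the predictor a stable set of representations to aim at.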
The predictor in I-JEPA can be seen as a primitive (and restricted) world-model that’s able to model spatial uncertainty in a static image from a partially observable context. What’s more, this world model is semantic in the sense that it predicts high-level information about unseen regions in the image, rather than pixel-level details.
Illustrating how the predictor learns to model the semantics of the world. For each image, the portion outside of the blue box is encoded and given to the predictor as context. The predictor outputs a representation for what it expects to be in the region within the blue box. To visualize the prediction, we train a generative model that produces a sketch of the contents represented by the predictor output, and we show a sample output within the blue box. Clearly, the predictor recognizes the semantics of what parts should be filled in (the top of the dog’s head, the bird’s leg, the wolf’s legs, the other side of the building).
To understand what the model is capturing, we trained a stochastic decoder that maps the I-JEPA predicted representations back into pixel space, which shows the model’s outputs when probed to make predictions inside the blue box. This qualitative evaluation shows that the model correctly captures positional uncertainty and produces high-level object parts with the correct pose (e.g., dog’s head, wolf’s front legs). In short, I-JEPA is able to learn high-level representations of object parts without discarding their localized positional information in the image.
Greater efficiency and strong performance
I-JEPA pretraining is also computationally efficient. It doesn’t involve any overhead associated with applying more computationally intensive data augmentations to produce multiple views. Only one view of the image needs to be processed by the target encoder, and only the context blocks need to be processed by the context encoder.
Empirically, we find that I-JEPA learns strong off-the-shelf semantic representations without the use of hand-crafted view augmentations – see the figure below. It also outperforms pixel- and token-reconstruction methods on ImageNet-1K linear probing and semi-supervised evaluation.
Linear evaluation performance on ImageNet-1k as a function of GPU hours for pretraining.
On semantic tasks, I-JEPA is also competitive with previous pretraining approaches that rely on hand-crafted data augmentations. Compared to these methods, I-JEPA achieves better performance on low-level vision tasks such as object counting and depth prediction. By using a simpler model with a less rigid inductive bias, I-JEPA is applicable to a wider set of tasks.
Low Shot Classification Accuracy: Semi-supervised evaluation on ImageNet-1k with 1% of the labels (approximately 12 labeled images per class).
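Two of the figures quoted in this post can be checked with quick arithmetic. The sketch below assumes the standard ImageNet-1k training split size of 1,281,167 images, which is not stated in the post, and treats 16 GPUs running for 72 hours as an upper bound on the pretraining budget.

```python
# Assumption: size of the standard ImageNet-1k (ILSVRC-2012) training split.
IMAGENET_1K_TRAIN_IMAGES = 1_281_167
NUM_CLASSES = 1000

# 1% of the labels, spread evenly over the classes.
labeled_per_class = 0.01 * IMAGENET_1K_TRAIN_IMAGES / NUM_CLASSES
print(round(labeled_per_class, 1))  # 12.8

# 16 A100 GPUs for under 72 hours gives an upper bound in GPU-hours.
gpu_hours_upper_bound = 16 * 72
print(gpu_hours_upper_bound)  # 1152
```

Under that assumption, 1% of the labels works out to roughly 12 to 13 images per class, consistent with the caption, and the quoted training run stays under about 1,150 GPU-hours.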
A step closer to human-level intelligence in AI
I-JEPA demonstrates the potential of architectures for learning competitive off-the-shelf image representations without the need for extra knowledge encoded through hand-crafted image transformations. It would be particularly interesting to advance JEPAs to learn more general world-models from richer modalities, e.g., enabling one to make long-range spatial and temporal predictions about future events in a video from a short context, and conditioning these predictions on audio or textual prompts.
We look forward to working to extend the JEPA approach to other domains, like image-text paired data and video data. In the future, JEPA models could have exciting applications for tasks like video understanding. This is an important step towards applying and scaling self-supervised methods for learning a general model of the world.