Transformer Models: Not Just for Text Anymore
Zian (Andy) Wang
Since their debut in 2017 with the influential paper "Attention is All You Need," Transformers, or self-attention-based deep learning models, have made waves throughout the machine learning community. The world of natural language processing (NLP) has been particularly transformed, as models like GPT and BERT demonstrate remarkable proficiency in tasks ranging from language generation to sentiment analysis.
The recent excitement surrounding GPT-4 and Google’s Bard further emphasizes the impact of Transformer-based models. While these achievements are noteworthy, Transformers have also extended their reach into other machine learning domains.
Computer vision is another field where Transformers have made a significant impact. The innovative architecture, with its self-attention mechanisms and residual connections, has transcended the boundaries of NLP and brought about a new approach to analyzing and processing visual data. Vision Transformers (ViT) and DETR are now emerging as strong competitors in image classification, object detection, and segmentation tasks.
Furthermore, diffusion models bridge NLP and computer vision by generating images from human-written prompts. Models such as Midjourney and Stable Diffusion have made a considerable stir not only in the ML community but also among artists and photographers.
The adaptable nature of the Transformer architecture, owing to its ingenious self-attention and residual stream design, allows for its application in a multitude of tasks with promising performance and research potential. In this article, we will explore some lesser-known applications of Transformer models, which have been overshadowed by the NLP breakthroughs and the acclaim for ViT and diffusion models.
Transformers for Tabular Data
Compared to images, NLP inputs actually resemble tabular data more closely: each token is represented by a 1-D array, with no spatial patterns along a second or third dimension. There is also nothing inherently language-specific about self-attention; it is a generic algorithm that models relationships between a set of vectors. In a sense, self-attention “frees” NLP from its contiguous structure by relating information non-sequentially, since tokens later in the sequence can attend directly to earlier ones.
Therefore, we can treat each feature of a tabular dataset as a token, and let the token's embedding dimension hold the samples fed into the model. To adapt the model for tabular data, the number of tokens is fixed, unlike in Transformers for language, where each token is a single "row" of data and the sequence length varies; here the embedding dimension instead depends on the number of samples to be processed. We can then attach a classification, regression, or other linear layer to post-process the raw, sequence-shaped output to fit the dataset labels.
The method described would be the most basic application of Transformer models to tabular data.
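To make this concrete, below is a minimal PyTorch sketch of the "columns as tokens" idea. Rather than packing the batch of samples into the embedding dimension, this variant projects each scalar feature into a learned embedding, which is the more common practical formulation; the class name, layer sizes, and hyperparameters are illustrative assumptions.

```python
# A minimal "features as tokens" tabular Transformer (illustrative, not from a paper).
import torch
import torch.nn as nn

class BasicTabularTransformer(nn.Module):
    def __init__(self, num_features: int, d_model: int = 32, num_classes: int = 2):
        super().__init__()
        # Each scalar feature is projected into a d_model-dimensional "token".
        self.feature_proj = nn.Linear(1, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        # The token count is fixed, so the sequence output can be flattened into a head.
        self.head = nn.Linear(num_features * d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_features) -> (batch, num_features, 1) -> feature tokens
        tokens = self.feature_proj(x.unsqueeze(-1))
        encoded = self.encoder(tokens)          # (batch, num_features, d_model)
        return self.head(encoded.flatten(1))    # (batch, num_classes)

model = BasicTabularTransformer(num_features=10)
logits = model(torch.randn(8, 10))              # a toy batch of 8 samples
```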
The TabTransformer, proposed in the paper “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” from Xin Huang et al., employs a somewhat similar approach to applying Transformer-based models to tabular data without heavily modifying the original Transformer architecture and pipeline. Nonetheless, a slight modification is required in encoding categorical features.
The sparseness of such features can harm, or at least complicate, modeling, but that is not to say they cannot convey critical information. The authors of the TabTransformer overcome this by using column embeddings on the categorical data, much in the same way one-hot encoded language tokens are transformed into contextual embeddings.
Here are the nitty-gritty details of what's going on under the hood:
Each categorical feature in a data sample is embedded into a vector of shape (d, 1), where d is a predetermined embedding dimension. This produces an embedding matrix of shape (n, c, d), where n is the number of samples, c is the number of categorical features, and d is the embedding dimension. These embedding vectors are then fed into a stack of Transformer encoder layers.
Note that this is all done without the continuous features. The output of the Transformer is then concatenated with the continuous features and fed into an additional MLP to generate the final predictions.
In other words, categorical features are embedded into the latent feature space.
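Below is a hedged sketch of that pipeline, loosely following the description above rather than the authors' implementation: one embedding table per categorical column, a Transformer over those column tokens, and an MLP over the concatenation of the contextual output with the (layer-normalized) continuous features. The cardinalities and layer sizes are made up for illustration.

```python
# A TabTransformer-style sketch: categorical embeddings -> Transformer -> concat -> MLP.
import torch
import torch.nn as nn

class TabTransformerSketch(nn.Module):
    def __init__(self, cardinalities, num_continuous, d_model=32, num_classes=2):
        super().__init__()
        # One embedding table per categorical column ("column embeddings").
        self.embeddings = nn.ModuleList(nn.Embedding(card, d_model) for card in cardinalities)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        self.cont_norm = nn.LayerNorm(num_continuous)
        mlp_in = len(cardinalities) * d_model + num_continuous
        self.mlp = nn.Sequential(nn.Linear(mlp_in, 64), nn.ReLU(), nn.Linear(64, num_classes))

    def forward(self, x_cat, x_cont):
        # x_cat: (batch, num_categorical) of integer category ids
        tokens = torch.stack([emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)], dim=1)
        contextual = self.transformer(tokens).flatten(1)   # contextual categorical features
        return self.mlp(torch.cat([contextual, self.cont_norm(x_cont)], dim=1))

model = TabTransformerSketch(cardinalities=[5, 12, 3], num_continuous=4)
out = model(torch.randint(0, 3, (8, 3)), torch.randn(8, 4))   # toy categorical + continuous batch
```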
The Self-Attention and Intersample Attention Transformer (SAINT) adopts an embedding technique similar to the TabTransformer's. However, instead of embedding only the categorical features, both continuous and categorical features are embedded. SAINT also adds an intersample self-attention block after the normal self-attention is applied. Instead of finding relationships between the columns, intersample self-attention relies on the correlations between different data points, the "rows" of the dataset.
The intersample self-attention functions exactly how it sounds: it computes attention across the entire batch, treating each sample as a "token" and the number of samples in the batch as the sequence length. The authors observed modest improvements over self-attention-only models on multiple datasets.
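Below is a minimal sketch of that intersample block, assuming the input is already shaped (batch, features, embedding) by a feature-tokenizing step like the one above. Each sample's feature tokens are flattened so attention runs across the rows of the batch; this is a simplified illustration rather than SAINT's exact module.

```python
# Intersample (row-wise) self-attention: samples in the batch attend to each other.
import torch
import torch.nn as nn

class IntersampleAttention(nn.Module):
    def __init__(self, num_features: int, d_model: int, nhead: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(num_features * d_model, nhead, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        rows = x.reshape(1, b, n * d)          # one "sequence" whose tokens are the b samples
        out, _ = self.attn(rows, rows, rows)   # each row attends to every other row
        return out.reshape(b, n, d)

x = torch.randn(8, 10, 32)                     # 8 samples, 10 feature tokens, d_model = 32
mixed = IntersampleAttention(num_features=10, d_model=32)(x)
```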
Beyond models that use self-attention and the Transformer architecture straight out of the box, many papers focus on attention modules designed specifically for tabular data. One example is TabNet, which performs automatic feature selection over continuous and categorical features using an attention mechanism built on the sparsemax function, sketched below.
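Here is a hedged, minimal implementation of sparsemax itself, following the standard simplex-projection formulation. Unlike softmax, it can assign exactly zero weight to low-scoring inputs, which is what makes sparse feature selection possible.

```python
# Sparsemax: a softmax alternative that projects scores onto the simplex, yielding zeros.
import torch

def sparsemax(scores: torch.Tensor) -> torch.Tensor:
    # Operates over the last dimension; scores: (..., n).
    z, _ = torch.sort(scores, dim=-1, descending=True)
    k = torch.arange(1, scores.size(-1) + 1, device=scores.device, dtype=scores.dtype)
    cumsum = z.cumsum(dim=-1)
    # Support size: the largest k with 1 + k * z_(k) > sum of the top-k sorted scores.
    support = (1 + k * z > cumsum).to(scores.dtype)
    k_support = support.sum(dim=-1, keepdim=True)
    tau = (cumsum.gather(-1, k_support.long() - 1) - 1) / k_support
    return torch.clamp(scores - tau, min=0.0)

print(sparsemax(torch.tensor([[2.0, 1.0, -1.0, 0.1]])))   # tensor([[1., 0., 0., 0.]])
```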
TabNet became a popular choice in Kaggle competitions after its first success in the Mechanisms of Action (MoA) competition. ARM-Net, on the other hand, emphasizes cross-feature interactions by transforming features into an exponential space and learning an "interaction order" and "interaction weight" through a sparse attention mechanism. The potential of attention-based and Transformer-based tabular data modeling is fascinating: in a way, the Transformers built for NLP have inspired researchers to analyze and decode tabular data in a similar manner.
Transformers for Recommender Systems
As it turns out, Transformer models can be readily adapted to recommender systems, because their inputs are typically already sequential. For example, when modeling user behavior, a system can track cursor placement or the items visited on a website in chronological order. Traditional recommendation algorithms such as collaborative filtering typically ignore these hidden sequential dynamics of user interactions. Fortunately, as we know, Transformers are designed to capture exactly such sequential patterns.
One example of this adaptation is the BERT4Rec model proposed by Sun et al. (2019), which employs the BERT architecture to build a powerful sequential recommendation system. BERT4Rec processes user-item interaction sequences as input, leveraging the masked self-attention mechanism to capture both short-term and long-term patterns. During training, a percentage of items in the input sequence are randomly masked, and the model is trained to predict these masked items based on the surrounding context.
After pre-training, BERT4Rec is fine-tuned on the recommendation task, predicting the next item in the sequence given the user's interaction history. This allows the model to learn better item embeddings and understand the underlying sequential patterns, resulting in more accurate and personalized recommendations compared to traditional methods like collaborative filtering.
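A toy version of this masked-item training step is sketched below, using a generic Transformer encoder rather than the official BERT4Rec code; the catalog size, masking rate, and layer sizes are arbitrary assumptions.

```python
# BERT4Rec-style training sketch: mask random items and predict them from context.
import torch
import torch.nn as nn

NUM_ITEMS, MASK_ID, D_MODEL, SEQ_LEN = 1000, 1000, 64, 20   # MASK_ID is one extra token id

item_emb = nn.Embedding(NUM_ITEMS + 1, D_MODEL)             # +1 row for the [MASK] token
pos_emb = nn.Embedding(SEQ_LEN, D_MODEL)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=D_MODEL, nhead=4, batch_first=True), num_layers=2
)
head = nn.Linear(D_MODEL, NUM_ITEMS)                        # predict the original item id

seq = torch.randint(0, NUM_ITEMS, (8, SEQ_LEN))             # toy batch of interaction sequences
mask = torch.rand(seq.shape) < 0.15                         # mask roughly 15% of positions
masked_seq = seq.masked_fill(mask, MASK_ID)

hidden = encoder(item_emb(masked_seq) + pos_emb(torch.arange(SEQ_LEN)))
logits = head(hidden)                                       # (8, SEQ_LEN, NUM_ITEMS)
loss = nn.functional.cross_entropy(logits[mask], seq[mask]) # loss only on the masked items
```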
More recently, NVIDIA's Transformers4Rec library enables users to implement Transformer-based recommender systems with models available on Hugging Face. The library provides fully customizable architectures while keeping the implementation simple and straightforward.
"Classical" recommender system approaches lean more toward computation-based algorithms than "full-on" ML models. The emergence of Transformers is shaping up to be a pivotal change in the world of recommender systems, as the input space Transformers were originally designed for fits the problem space of most recommendation systems almost perfectly.
Transformers for Reinforcement Learning
Another surprising use of Transformer-based models has been in the field of reinforcement learning (RL), where researchers have been able to fully integrate them into the RL pipeline.
For those unfamiliar, reinforcement learning usually involves an agent (imagine a player in a video game) taking actions based on the state it is in (think of the attributes of the player at a specific time) while observing the environment around it. The environment then changes based on the actions the agent performs, and the cycle of action and response repeats until the game ends or some goal is reached.
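In code, that interaction loop looks roughly like the sketch below, written against the Gymnasium API (an assumption; any environment with a similar reset/step interface works), with a random policy standing in for the learned agent.

```python
# The basic agent-environment loop, with a random policy as a placeholder agent.
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
done = False
while not done:
    action = env.action_space.sample()                      # placeholder: act randomly
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated                          # episode ends at a goal or failure
env.close()
```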
Some of the early work incorporating Transformers into RL involves reasoning over observations that contain multiple entities. A complex reinforcement learning environment may contain many "things" in a single observation, where each individual thing (think of the different cars driving on a busy road) has its own unique properties. The paper "Relational Deep Reinforcement Learning" by Zambaldi et al. uses self-attention to capture relationships between such entities, an approach later applied to StarCraft II, a complex multi-agent real-time strategy game.
The sequence-to-sequence nature of Transformers also allows them to be applied directly to temporal sequences. In other words, the Transformer acts as the agent's "memory": every action taken by the agent is based on a series of past states, fed into the network as a chronological sequence. However, the paper "A Simple Neural Attentive Meta-Learner" found that the vanilla Transformer struggles to model temporal sequences in RL, and the Gated Transformer-XL was among the first Transformer architectures to capture them reliably.
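One simple way to realize this "Transformer as memory" idea is sketched below: a window of recent observations is fed in as a sequence, and the representation of the latest step is mapped to action logits. The dimensions and the policy head are illustrative assumptions, not a specific published architecture.

```python
# A Transformer policy that acts from a window of past observations.
import torch
import torch.nn as nn

class TransformerPolicy(nn.Module):
    def __init__(self, obs_dim: int, num_actions: int, d_model: int = 64, window: int = 16):
        super().__init__()
        self.obs_proj = nn.Linear(obs_dim, d_model)
        self.pos_emb = nn.Embedding(window, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.action_head = nn.Linear(d_model, num_actions)

    def forward(self, obs_window: torch.Tensor) -> torch.Tensor:
        # obs_window: (batch, window, obs_dim), ordered oldest to newest
        t = obs_window.size(1)
        h = self.obs_proj(obs_window) + self.pos_emb(torch.arange(t, device=obs_window.device))
        h = self.encoder(h)
        return self.action_head(h[:, -1])       # act from the most recent step's representation

policy = TransformerPolicy(obs_dim=4, num_actions=2)
logits = policy(torch.randn(1, 16, 4))          # one window of 16 past observations
```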
The Gated Transformer-XL modifies the Transformer-XL architecture by reordering layer normalization so that an identity path runs from the input to the output of each block, and by replacing the standard residual connections with gating layers. The authors showed that these modifications improve stability in the early stages of training.
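The gating itself can be sketched as a small module that combines the block input x (the identity stream) with the sublayer output y; a positive gate bias keeps the block close to an identity map early in training. This is a hedged reading of the GRU-style gating described in the paper, not the authors' code.

```python
# GRU-style gating in place of a plain residual connection (GTrXL-inspired sketch).
import torch
import torch.nn as nn

class GRUGate(nn.Module):
    def __init__(self, d_model: int, gate_bias: float = 2.0):
        super().__init__()
        self.Wr = nn.Linear(d_model, d_model, bias=False)
        self.Ur = nn.Linear(d_model, d_model, bias=False)
        self.Wz = nn.Linear(d_model, d_model, bias=False)
        self.Uz = nn.Linear(d_model, d_model, bias=False)
        self.Wg = nn.Linear(d_model, d_model, bias=False)
        self.Ug = nn.Linear(d_model, d_model, bias=False)
        self.gate_bias = gate_bias               # biases the gate toward passing x through

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r = torch.sigmoid(self.Wr(y) + self.Ur(x))
        z = torch.sigmoid(self.Wz(y) + self.Uz(x) - self.gate_bias)
        h_hat = torch.tanh(self.Wg(y) + self.Ug(r * x))
        return (1 - z) * x + z * h_hat           # near-identity when z is small

gate = GRUGate(d_model=64)
out = gate(torch.randn(8, 16, 64), torch.randn(8, 16, 64))   # (block input, sublayer output)
```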
TrMRL (Transformers for Meta-Reinforcement Learning), a meta-RL agent, uses the Transformer architecture to build an episodic memory by linking recent working memories, which helps contextualize the policy (a function that outputs actions based on states). Its self-attention mechanism computes a consensus representation at each layer, minimizing Bayes risk and improving action computation. TrMRL demonstrates performance, sample efficiency, and out-of-distribution generalization comparable or superior to many meta-RL baselines.
Using transformer models with reinforcement learning offers multiple advantages. They excel at handling long-term dependencies, making them suitable for reinforcement learning tasks where actions and states are interconnected over time. Their scalability and support for parallel computing, thanks to the self-attention mechanism, ensure efficient handling of large state-action spaces and faster learning.
Transfer learning, as exemplified by models like TrMRL, allows knowledge from one task to enhance performance on related tasks. Additionally, the attention weights of Transformers offer a degree of interpretability, giving some insight into the decision-making process of learning agents. In short, this combination leads to efficient, generalizable, strategic, and interpretable AI learning.
Self-Attention: Is It Really All You Need?
We’ve seen the potential and the current power of Transformer-based models, or models incorporating self-attention in general. With the introduction of GPT-4 and its impressive multi-modal ability to process images and text, research on Transformer-based models is growing by the day. However, it’s important to remember that self-attention is NOT all you need.
For example, in ViTs, attention is computed globally, over a 1-dimensional sequence of patches. In contrast, most CNNs perform computations locally: filters move across the image and only focus on a particular area at any given time, dictated by the size of the convolution kernel. This locality gives CNNs advantages over typical fully-connected neural networks in a number of respects. First, convolution layers are much cheaper computationally on images than fully-connected networks. Second, the two-dimensional nature of convolutions allows them to process input in a way that is closer to how we humans process visuals, whereas fully-connected networks fail to capture the same spatial connections.
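The contrast is easy to see in code: a ViT-style patch embedding flattens the image into a 1-D sequence of tokens over which attention is global, while a convolution mixes only a small local neighborhood at each position. The patch and channel sizes below are simply common ViT defaults.

```python
# Global 1-D tokens (ViT-style) versus local 2-D convolution.
import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)

# ViT-style: non-overlapping 16x16 patches become tokens (a conv with stride = kernel size).
patchify = nn.Conv2d(3, 768, kernel_size=16, stride=16)
tokens = patchify(img).flatten(2).transpose(1, 2)         # (1, 196, 768): attention sees all tokens

# CNN-style: a 3x3 kernel only looks at a 3x3 neighborhood around each output position.
local = nn.Conv2d(3, 64, kernel_size=3, padding=1)(img)   # (1, 64, 224, 224)
```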
In the paper “TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation”, the authors report that a hybrid CNN-Transformer encoder performs better than a pure Transformer encoder. This suggests that Transformers can benefit from the addition of convolution blocks for image segmentation.
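As a loose illustration of the hybrid pattern (not TransUNet's actual implementation), a small CNN can first extract local feature maps, which are then flattened into tokens for a Transformer encoder:

```python
# Hybrid CNN-Transformer encoder sketch: local convolutional features, global attention on top.
import torch
import torch.nn as nn

cnn = nn.Sequential(                                   # local feature extraction, 4x downsampling
    nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True), num_layers=2
)

img = torch.randn(1, 3, 64, 64)
feats = cnn(img)                                       # (1, 128, 16, 16)
tokens = feats.flatten(2).transpose(1, 2)              # (1, 256, 128): one token per location
encoded = encoder(tokens)                              # global reasoning over local features
```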
In a more direct example, the paper by Yihe Dong et al., “Attention is not all you need: pure attention loses rank doubly exponentially with depth”, directly analyzes NLP Transformers and presents the potential harm that self-attention can inflict on models. The authors discover that self-attention blocks carry a strong inductive bias toward producing uniform token outputs. Specifically, without the help of residual connections and MLP blocks, the output of stacked self-attention layers converges doubly exponentially to a rank-1 matrix as network depth increases. Essentially, attention layers alone may limit the expressive power of the model and result in a loss of information as the network becomes deeper.
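The effect is easy to illustrate numerically. The sketch below stacks randomly initialized attention layers with no residual connections or MLPs and tracks the ratio of the second-largest to the largest singular value of the token matrix; with random inputs, the ratio typically shrinks rapidly toward zero, i.e., toward a rank-1 matrix. The depth and dimensions are arbitrary.

```python
# Rank collapse with pure attention: no skip connections, no MLP blocks.
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_layers = 64, 12
layers = [nn.MultiheadAttention(d_model, num_heads=4, batch_first=True) for _ in range(num_layers)]

x = torch.randn(1, 32, d_model)                        # 32 random "tokens"
with torch.no_grad():
    for i, attn in enumerate(layers):
        x, _ = attn(x, x, x)                           # attention only, applied repeatedly
        s = torch.linalg.svdvals(x[0])
        print(f"layer {i + 1:2d}: sigma_2 / sigma_1 = {(s[1] / s[0]).item():.4f}")
```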
Transformers have made major contributions to machine learning, and we are only beginning to realize their full potential. They are not, however, the end-all-be-all. Although they can be adapted to a wide (and growing) variety of tasks, self-attention modules are no magic bullet that can be tacked onto any model. Let this article be a caution against losing ourselves in the hype of whatever is newest and flashiest. In short, self-attention is not all you need.