An Image Is Worth 1×16×16 Words (2024)

* Equal contribution. † Research supporting this publication conducted while authors were employed at Insitro.

Yujia Bao*†
Accenture
yujia.bao@accenture.com

Srinivasan Sivanandan*
Insitro
srinivasan@insitro.com

Theofanis Karaletsos†
Chan Zuckerberg Initiative
theofanis@karaletsos.com

Abstract

Vision Transformer (ViT) has emerged as a powerful architecture in the realm of modern computer vision. However, its application in certain imaging fields, such as microscopy and satellite imaging, presents unique challenges. In these domains, images often contain multiple channels, each carrying semantically distinct and independent information. Furthermore, the model must demonstrate robustness to sparsity in input channels, as they may not be densely available during training or testing. In this paper, we propose a modification to the ViT architecture that enhances reasoning across the input channels, and we introduce Hierarchical Channel Sampling (HCS) as an additional regularization technique to ensure robustness when only partial channels are present at test time. Our proposed model, ChannelViT, constructs patch tokens independently from each input channel and utilizes a learnable channel embedding that is added to the patch tokens, similar to positional embeddings. We evaluate the performance of ChannelViT on ImageNet, JUMP-CP (microscopy cell imaging), and So2Sat (satellite imaging). Our results show that ChannelViT outperforms ViT on classification tasks and generalizes well, even when a subset of input channels is used during testing. Across our experiments, HCS proves to be a powerful regularizer, independent of the architecture employed, suggesting itself as a straightforward technique for robust ViT training. Lastly, we find that ChannelViT generalizes effectively even when there is limited access to all channels during training, highlighting its potential for multi-channel imaging under real-world conditions with sparse sensors. Our code is available at https://github.com/insitro/ChannelViT.

1 Introduction

Vision Transformers (ViT) have emerged as a crucial architecture in contemporary computer vision, significantly enhancing image analysis. However, application to specific imaging domains, such as microscopy and satellite imaging, poses unique challenges. Images in these fields often comprise multiple channels, each carrying semantically distinct and independent information. The complexity is further compounded by the fact that these input channels may not always be densely available during training or testing, necessitating a model capable of handling such sparsity.

In response to these challenges, we propose a modification to the ViT architecture that bolsters reasoning across the input channels. Our proposed model, ChannelViT, constructs patch tokens independently from each input channel and incorporates a learnable channel embedding that is added to the patch tokens in addition to the location-specific positional embedding. This simple modification enables the model to reason across both locations and channels. Furthermore, by treating the channel dimension as the patch sequence dimension, ChannelViT can seamlessly handle inputs with varying sets of channels.

Despite these advancements, two main challenges persist. While ChannelViT can leverage existing efficient implementations of ViT with minimal modifications, the increase in sequence length introduces additional computational requirements. Moreover, if ChannelViT is consistently trained on the same set of channels, its ability to generalize to unseen channel combinations at test time may be compromised. To address these challenges, we introduce Hierarchical Channel Sampling (HCS), a new regularization technique designed to improve robustness. Unlike channel dropout, which drops out each input channel independently, HCS uses a two-step sampling procedure. It first samples the number of channels and then, based on this, it samples the specific channel configurations. While channel dropout tends to allocate more distribution to combinations with a specific number of channels, HCS assigns a uniform weight to the selection of any number of channels. HCS consistently improves robustness when different channels are utilized during testing in both ViT and ChannelViT. Notably, our evaluation on ImageNet shows that using only the red channel, HCS can increase the validation accuracy from 29.39 to 68.86.

We further evaluate ChannelViT on two real world multi-channel imaging applications: microscopy cell imaging (JUMP-CP) and satellite imaging (So2Sat). In these applications, different channels often correspond to independent information sources. ChannelViT significantly outperforms its ViT counterpart in these datasets, underscoring the importance of reasoning across different channels. Moreover, by treating different channels as distinct input tokens, we demonstrate that ChannelViT can effectively generalize even when there is limited access to all channels in the dataset during training. Lastly, we show that ChannelViT enables additional insights. The learned channel embeddings correspond to meaningful interpretations, and the attention visualization highlights relevant features across spatial and spectral resolution, enhancing interpretability. This highlights the potential of ChannelViT for wide-ranging applications in the field of multi-channel imaging.

[Figure 1: Visual overview of the ChannelViT model.]

2 Related work

Vision transformer and its applications to multi-channel imaging

Vision Transformer (ViT) has demonstrated state-of-the-art performance in various computer vision tasks (Dosovitskiy et al.; Touvron et al., 2021; Carion et al., 2020; Zhu et al., 2020b). Recently, researchers have started adopting ViT for multi-spectral imaging. For example, in satellite imaging, Kaselimi et al. (2022) showed that a ViT-based classifier outperforms CNN models, especially on imbalanced classes. Tarasiou et al. (2023) proposed acquisition-time-specific temporal positional encodings to model satellite images over time, while Cong et al. (2022) demonstrated the benefits of using distinct spectral positional encodings with ViT. More recently, Nguyen et al. (2023) modified the ViT architecture with variable tokenization and variable token aggregation to handle heterogeneous input data sources in climate and weather modeling. Moreover, Scheibenreif et al. (2022) found that ViT, when combined with self-supervised pre-training, performs on par with state-of-the-art benchmarks.

In the field of cell biology, Sivanandan et al. (2023) utilized ViT with self-supervised pre-training to learn representations of cells across multiple fluorescence channels. Furthermore, Hatamizadeh et al. (2022a; b) leveraged ViT for segmenting 3D MRI images. Hussein et al. (2022) proposed training multiple ViTs, one per input channel, for epileptic seizure prediction.

In contrast to previous work, we address a practical challenge in multi-channel imaging, where different datasets often have different available channels (for example, satellite imaging often involves multiple signals such as Sentinel-1 (SAR), Sentinel-2, and UAV imagery; see https://github.com/chrieke/awesome-satellite-imagery-datasets). To tackle this challenge, we propose ChannelViT, which creates image patches from each individual input channel. This simple modification unifies the modeling across data with different input channels and offers robust performance at test time, even when only a subset of the channels is available.

Robustness for Vision Transformer

Robustness can be defined in different ways. One aspect is the vulnerability to adversarial attacks. Mahmood et al. (2021) found that ViTs are as susceptible to white-box adversarial attacks as CNNs. To improve robustness, Robust ViT incorporates more robust components like global pooling (Mao et al., 2022). Additionally, Chefer et al. (2022) propose regularizing the relevancy map of ViT to enhance robustness. Zhou et al. (2022), Zhang et al. (2021), and Song et al. (2022) augment transformers with feature-wise attention to improve robustness and performance. Another line of work focuses on generalization under distribution shift (Sagawa et al., 2019; Liu et al., 2021). Bao & Karaletsos (2023) introduce a context token inferred from ViT's hidden layers to encode group-specific information.

In our work, we specifically focus on improving the generalization performance across different channel combinations, which is a common scenario in multi-channel imaging. We argue that the original ViT is sensitive to changes in input channels, as it computes a single patch token across all channels. In contrast, ChannelViT creates separate patch tokens for each channel, making it inherently more robust to variations in channel availability. To further enhance channel robustness, we introduce hierarchical channel sampling (HCS) during training. This methodology draws inspiration from prior studies on channel dropout (Srivastava et al., 2014; Tompson et al., 2015; Hou & Wang, 2019). However, instead of dropping out intermediate channels, our approach uses a two-stage sampling algorithm that selectively masks out the input channels.

3 Method

ChannelViT is a modification of the original Vision Transformer (ViT) architecture proposed by Dosovitskiy et al. Unlike the original architecture, which condenses each multi-channel image patch into a single ‘word’ token, ChannelViT segregates channel-specific information into multiple tokens. This simple yet effective modification yields three key advantages:

  1. ChannelViT facilitates reasoning across both positions and channels with the Transformer;

  2. By transforming the channel dimension into the sequence-length dimension, ChannelViT can seamlessly manage inputs with varying sets of channels;

  3. ChannelViT can utilize existing efficient implementations of ViT.

In the following paragraphs, we explore the architecture and implementation of ChannelViT in detail. Figure 1 provides a visual overview of the model.

3.1 Channel Vision Transformer (ChannelViT)

Patch embeddings

Consider an input image $x$ with dimensions $H \times W \times C$. Given a patch size of $P \times P$, this image can be reshaped into a sequence of non-overlapping patches

$$\left[\,x[c_1,p_1],\ \ldots,\ x[c_1,p_N],\quad x[c_2,p_1],\ \ldots,\ x[c_2,p_N],\quad \ldots\quad, x[c_C,p_1],\ \ldots,\ x[c_C,p_N]\,\right],$$

where $x[c_i,p_n]$ denotes the $n$-th $P \times P$ image patch of channel $c_i$ and $N = HW/P^2$. Since the Transformer encoder requires a sequence of one-dimensional vectors, each patch is flattened into a 1D vector. Unlike ViT, which generates a single token for a multi-channel image patch, ChannelViT produces one token from every single-channel image patch.

Tied image filters

We apply a learnable linear projection $W \in \mathbb{R}^{P^2 \times D}$ to the flattened patches. Note that in a regular ViT, each channel has its own weights in the linear projection layer. In ChannelViT, our preliminary experiments suggest that tying the image filters across channels offers superior performance compared to untied image filters (Appendix C.3). We therefore tie the learnable projection $W$ across channels. The intuition is that low-level image filters can be shared across channels (Ghiasi et al., 2022), and tying the parameters improves the model's robustness across channels.

Channel-aware and position-aware patch embeddings

Despite tying the linear filter across channels, it remains essential to preserve channel-specific information, given the distinct characteristics of different channels (Appendix C.4). We introduce learnable channel embeddings $[\texttt{chn}_1, \ldots, \texttt{chn}_C]$, where $\texttt{chn}_c \in \mathbb{R}^D$. In line with the original ViT, we also incorporate learnable positional embeddings $[\texttt{pos}_1, \ldots, \texttt{pos}_N]$, where $\texttt{pos}_n \in \mathbb{R}^D$, to maintain the positional information of each patch. These positional embeddings are shared across channels, enabling ChannelViT to recognize the same image patch across different channels. Finally, we prepend a learnable classifier token $\texttt{CLS} \in \mathbb{R}^D$ to the sequence to encode global image features. The resulting input sequence can be written as

$$\big[\texttt{CLS},\ \ \texttt{pos}_1 + \texttt{chn}_1 + Wx[c_1,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_1 + Wx[c_1,p_N],\ \ \ldots,\ \ \texttt{pos}_1 + \texttt{chn}_C + Wx[c_C,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_C + Wx[c_C,p_N]\big].$$
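To make this construction concrete, below is a minimal PyTorch sketch of the per-channel patch embedding. It is an illustration rather than the released implementation (see the linked repository for that); the module name `ChannelPatchEmbed` and the use of a single-input-channel `nn.Conv2d` to realize the tied projection $W$ are our own choices.

```python
import torch
import torch.nn as nn

class ChannelPatchEmbed(nn.Module):
    """Per-channel patch embedding with the projection W tied across channels.

    Produces the sequence [CLS, pos_n + chn_c + W x[c, p_n]] of length 1 + C*N,
    where N = (H / P) * (W / P).
    """

    def __init__(self, img_size=224, patch_size=16, num_channels=3, dim=384):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Tied projection W: a single-channel conv applied to every channel independently.
        self.proj = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, dim))  # shared across channels
        self.chn_embed = nn.Parameter(torch.zeros(1, num_channels, dim))      # one embedding per channel

    def forward(self, x, channel_ids=None):
        # x: (B, C, H, W); channel_ids picks the rows of chn_embed matching x's channels.
        B, C, H, W = x.shape
        if channel_ids is None:
            channel_ids = torch.arange(C, device=x.device)
        tokens = self.proj(x.reshape(B * C, 1, H, W))                  # (B*C, D, H/P, W/P)
        tokens = tokens.flatten(2).transpose(1, 2)                     # (B*C, N, D)
        tokens = tokens.reshape(B, C, self.num_patches, -1)
        tokens = tokens + self.pos_embed.unsqueeze(1)                  # positional embeddings, shared over channels
        tokens = tokens + self.chn_embed[:, channel_ids].unsqueeze(2)  # channel embeddings
        tokens = tokens.reshape(B, C * self.num_patches, -1)
        cls = self.cls_token.expand(B, -1, -1)
        return torch.cat([cls, tokens], dim=1)                         # (B, 1 + C*N, D)
```

Because channels map to tokens, subsampling channels at training or test time only changes the sequence length; the same module handles inputs with different channel sets.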

Transformer encoder

Following the original ViT, we feed the above input sequence into a Transformer encoder, which captures dependencies between image patches by embedding each patch based on its similarity to the others (Vaswani et al., 2017). Specifically, the Transformer encoder comprises alternating layers of multi-headed self-attention blocks and MLP blocks. Layer normalization (Ba et al., 2016) is applied before each block, and residual connections (He et al., 2016) are added after each block. We use the final-layer representation of the CLS token to represent the input image. For classification tasks, a linear classifier followed by a Softmax function predicts the corresponding label, and we use the standard cross-entropy loss as our training objective.
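Continuing the sketch above, the token sequence can be consumed by a pre-norm Transformer encoder and classified from the CLS token. Here `torch.nn.TransformerEncoder` is used as a stand-in for the DINO ViT encoder employed in the paper, so implementation details differ; the 160-way head mirrors the JUMP-CP task only as an example.

```python
encoder_layer = nn.TransformerEncoderLayer(
    d_model=384, nhead=6, dim_feedforward=1536,
    activation="gelu", batch_first=True, norm_first=True)  # pre-norm blocks with residual connections
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)
head = nn.Linear(384, 160)                   # e.g. 160-way perturbation classification

embed = ChannelPatchEmbed(img_size=224, patch_size=16, num_channels=8, dim=384)
x = torch.randn(2, 8, 224, 224)              # a batch of 8-channel images
tokens = encoder(embed(x))                   # (B, 1 + C*N, D)
logits = head(tokens[:, 0])                  # classify from the final-layer CLS token
loss = nn.functional.cross_entropy(logits, torch.randint(0, 160, (2,)))
```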

3.2 Hierarchical channel sampling (HCS)

Training ChannelViT directly presents two challenges: 1) the sequence length grows proportionally to the number of channels, leading to a quadratic increase in the cost of self-attention; 2) training exclusively on all channels may leave the model unprepared for partial channels at test time, hurting its generalization. To mitigate these issues, we apply hierarchical channel sampling (HCS) during training. Specifically, for an image $x$ with $C$ channels, we proceed as follows (a minimal code sketch follows the list):

  1. First, we sample a random variable $m$ uniformly from the set $\{1, 2, \ldots, C\}$. This $m$ is the number of channels that we will utilize during this training step;

  2. Next, we sample a channel combination $\mathcal{C}_m$ uniformly from all channel combinations that consist of $m$ channels;

  3. Finally, we return the image with only the sampled channels, $x[\mathcal{C}_m]$.
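The two-step procedure takes only a few lines; `hcs_sample` below is a minimal sketch, not the authors' exact implementation.

```python
import random

def hcs_sample(num_channels: int) -> list[int]:
    """Hierarchical channel sampling: first draw the number of channels m uniformly
    from {1, ..., C}, then draw a combination of m channels uniformly at random."""
    m = random.randint(1, num_channels)                   # step 1: how many channels
    return sorted(random.sample(range(num_channels), m))  # step 2: which channels

# At each training step (0-indexed channel ids): x = x[:, hcs_sample(x.shape[1])]
```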

HCS is similar to channel dropout (Tompson et al., 2015), but it differs in the prior distribution imposed on the sampled channels. In channel dropout, each channel is dropped independently with a given probability, so the probability of keeping exactly $m$ channels varies drastically across different $m$, which can negatively impact the final performance (Figure 3). In contrast, since $m$ is sampled uniformly over the total number of channels, HCS covers each $m$ equally. Finally, we note that HCS is only employed during training; at test time, ChannelViT has access to all input channels.

HCS can also be interpreted as simulating test-time distributions during training. Compared to group distributionally robust optimization (Sagawa et al., 2019), HCS minimizes the mean loss rather than the worst-case loss. This choice is sensible for channel robustness, since having more channels naturally improves performance; we do not want the model to over-focus on the worst-case loss, which typically corresponds to sampling very few channels.

[Figure 2: Correlation among the input channels (top) for ImageNet, JUMP-CP, and So2Sat.]

4 Experiments

We evaluate ChannelViT across three image classification benchmarks: ImageNet (Deng et al., 2009), JUMP-CP (Chandrasekaran et al., 2022), and So2Sat (Zhu et al., 2019). Figure 2 (top) illustrates the correlation among the input channels for each dataset. ImageNet exhibits a strong correlation among the three RGB channels. For JUMP-CP, while there is strong correlation within the fluorescence channels and within the brightfield channels, there is little to no correlation between the brightfield and fluorescence channels. A similar group structure among the channels is observed for So2Sat. Due to space constraints, our primary focus in the main paper is the comparison between ViT and ChannelViT. For additional comparisons with MultiViT (Hussein et al., 2022), please refer to Appendix B.1; comparisons with FANs (Zhou et al., 2022) can be found in Appendix B.2.

JUMP-CP

The JUMP-CP benchmark, established by the JUMP-Cell Painting Consortium, serves as a microscopy imaging standard. In alignment with the work of Chandrasekaran et al. (2022), we utilize the task of perturbation detection from cell images as a means to evaluate and compare the efficacy of various representation models. It is important to recognize that while perturbation detection is a valuable task, it is not the ultimate objective of cell imaging modeling; rather, it provides an interpretable metric for model assessment. The dataset includes a total of 160 perturbations. We focus on a compound perturbation plate ‘BR00116991’, which contains 127k training images, 45k validation images, and 45k testing images. Each cell image contains 8 channels, comprising both fluorescence information (first five channels) and brightfield information (last three channels).

So2Sat

This satellite imaging benchmark encompasses half a million image patches from Sentinel-1 and Sentinel-2 satellites, distributed across 42 global urban agglomerations. Each image patch incorporates 18 channels, with 8 originating from Sentinel-1 and the remaining 10 from Sentinel-2. The primary objective of this dataset is to facilitate the prediction of the climate zone for each respective image patch, with a total of 17 distinct climate zones being represented.

Implementation details

We utilize the Vision Transformer (ViT) implementation provided by Facebook Research (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py). During training, we minimize the cross-entropy loss. To ensure a fair comparison, both ViT and ChannelViT are subjected to identical optimization settings: the AdamW optimizer, a learning rate scheduler featuring linear warmup and cosine decay, and a cosine scheduler for the weight decay parameter. For a more detailed description of the hyper-parameter settings, we direct readers to the Appendix.

4.1 ImageNet

Table 1: ImageNet top-1 validation accuracy of ViT-S/16 and ChannelViT-S/16, evaluated on all three RGB channels and on each individual channel.

| Backbone | HCS during training? | Val Acc. on RGB | Val Acc. on R-only | Val Acc. on G-only | Val Acc. on B-only |
|---|---|---|---|---|---|
| *Models trained on three channels (RGB)* | | | | | |
| ViT-S/16 | No | 71.49 | 29.39 | 33.79 | 21.18 |
| ViT-S/16 | Yes | 73.01 | 68.86 | 69.78 | 67.59 |
| ChannelViT-S/16 | Yes | 74.64 | 69.90 | 70.30 | 68.48 |
| *Expert models trained on only one channel* | | | | | |
| ViT-S/16 (R-only) | N/A | – | 70.04 | – | – |
| ViT-S/16 (G-only) | N/A | – | – | 70.61 | – |
| ViT-S/16 (B-only) | N/A | – | – | – | 69.47 |

Table 1 showcases our results on ImageNet, using ViT-small as the representation backbone and a patch size of 16×16. Without hierarchical channel sampling, ViT-S/16 achieves a validation accuracy of 71.49 using all three channels but fails to generalize when only one channel is provided at test time. Simulating this test-time channel drop during training via hierarchical channel sampling (HCS) significantly improves performance; for instance, the validation accuracy using only the red channel improves from 29.39 to 68.86, demonstrating the effectiveness of HCS as a regularizer for enforcing channel robustness. Lastly, while there is limited room for improvement due to the strong correlations among the input RGB channels, ChannelViT still consistently outperforms the corresponding ViT baseline (by 1.2 on average), narrowing the gap (1.30 → 0.48) to the expert models that are trained using only one channel.

4.2 JUMP-CP: microscopy cell imaging

Table 2: Test accuracy on JUMP-CP (160-way classification), evaluated with different numbers of test-time channels.

| # channels for testing | ViT-S/16 (no HCS) | ChannelViT-S/16 (no HCS) | ViT-S/16 (HCS) | ChannelViT-S/16 (HCS) | ViT-S/8 (HCS) | ChannelViT-S/8 (HCS) |
|---|---|---|---|---|---|---|
| *Training on 5 fluorescence channels* | | | | | | |
| 5 channels | 48.41 | 53.41 | 55.51 | 56.78 | 60.29 | 60.03 |
| 4 channels | 0.85 | 15.13 | 43.59 | 45.94 | 48.80 | 49.34 |
| 3 channels | 1.89 | 5.12 | 33.14 | 35.45 | 37.13 | 38.15 |
| 2 channels | 1.46 | 1.22 | 25.24 | 26.57 | 27.40 | 27.99 |
| 1 channel | 0.54 | 1.25 | 20.49 | 21.43 | 21.30 | 21.58 |
| *Training on all 8 channels (5 fluorescence & 3 brightfield)* | | | | | | |
| 8 channels | 52.06 | 66.22 | 56.87 | 68.09 | 66.44 | 74.77 |
| 7 channels | 5.91 | 41.03 | 49.35 | 61.02 | 59.01 | 68.42 |
| 6 channels | 1.81 | 24.57 | 42.38 | 53.45 | 51.29 | 61.26 |
| 5 channels | 2.46 | 14.20 | 35.78 | 45.50 | 43.39 | 53.05 |
| 4 channels | 2.38 | 8.56 | 29.84 | 37.37 | 35.60 | 43.87 |
| 3 channels | 2.70 | 5.65 | 24.94 | 29.68 | 28.59 | 34.19 |
| 2 channels | 2.63 | 3.24 | 21.54 | 23.77 | 23.32 | 25.73 |
| 1 channel | 3.00 | 2.08 | 19.92 | 20.84 | 20.41 | 21.20 |
[Figure 3: Comparison between input channel dropout and hierarchical channel sampling on JUMP-CP.]

We present our results on the microscopy cell imaging benchmark, JUMP-CP, in Table 2. This benchmark involves a 160-way classification task. Due to computational constraints, we utilize ViT-S as our representation backbone, considering both the standard resolution with a patch size of 16×16 and a high-resolution model with a patch size of 8×8.

In the first part of our analysis, we train all models using only the five fluorescence channels and evaluate their performance on the test set under various input channel combinations. Our observations are as follows: 1) HCS significantly enhances the channel robustness for both ViT and ChannelViT; 2) High-resolution models consistently outperform their low-resolution counterparts; 3) With the exception of the 5-channel evaluation with a patch size of 8x8, ChannelViT consistently outperforms ViT.

In the latter part of our analysis, we utilize all available channels for training, which includes three additional brightfield channels for each image. For ViT, the high-resolution ViT-S/8 model improves from 60.29 to 66.44, demonstrating the importance of the additional brightfield information, while the improvement for ViT-S/16 is marginal (from 55.51 to 56.87). When focusing on ChannelViT, we observe a significant performance boost over its ViT counterpart. ChannelViT-S/16 outperforms ViT-S/16 by 11.22 (68.09 vs. 56.87) and ChannelViT-S/8 outperforms ViT-S/8 by 8.33 (74.77 vs. 66.44). These improvements are consistent across different channel combinations. As we have seen in Figure 2, fluorescence and brightfield channels provide distinct information. ChannelViT effectively reasons across channels, avoiding the need to collapse all information into a single token at the first layer, thereby enhancing performance.

Lastly, we compare input channel dropout with hierarchical channel sampling in Figure 3. The ViT model trained with HCS consistently surpasses those trained with input channel dropout across all channel combinations. Furthermore, the performance of models trained with input channel dropout correlates strongly with the probability distribution over the number of channels sampled during training.

Table 3: JUMP-CP test accuracy when combining fluorescence-only training data with 8-channel training data in varying proportions.

| % fluorescence-only data / % 8-channel data | 100% / 0% | 75% / 25% | 50% / 50% | 25% / 75% | 0% / 100% |
|---|---|---|---|---|---|
| *Evaluating on 5 fluorescence channels* | | | | | |
| ViT-S/16 | 55.51 | 52.55 ± 2.68 | 51.65 ± 2.14 | 49.53 ± 1.39 | 45.75 |
| ChannelViT-S/16 | 56.78 | 58.01 ± 1.77 | 58.19 ± 1.49 | 58.42 ± 1.37 | 57.60 |
| *Evaluating on all 8 channels* | | | | | |
| ViT-S/16 | – | 50.29 ± 1.93 | 52.47 ± 1.82 | 54.64 ± 1.01 | 56.87 |
| ChannelViT-S/16 | – | 57.97 ± 1.36 | 61.88 ± 0.91 | 64.80 ± 0.89 | 68.09 |

Data Efficiency

In the realm of microscopy imaging, we often encounter situations where not all channels are available for every cell due to varying experiment guidelines and procedures. Despite this, the goal remains to develop a universal model capable of operating on inputs with differing channels. ChannelViT addresses this issue by treating different channels as distinct input tokens, making it particularly useful in scenarios where not all channels are available for all data. Table 3 presents a scenario where varying proportions (0%, 25%, 50%, 75%, 100%) of the training data have access to all eight channels, with the remaining data only having access to the five fluorescence channels. The performance of ViT and ChannelViT is evaluated at test time using both the five fluorescence channels (top section) and all eight channels (bottom section).

Our observations are as follows: 1) When only a limited amount of 8-channel data (25%) is available, both ChannelViT and ViT show a decrease in performance when utilizing eight channels at test time compared to five channels; 2) As the availability of 8-channel data increases, the performance of the ViT baseline on the fluorescence evaluation steadily declines (from 55.51 to 45.75), while the performance of ChannelViT sees a slight improvement (from 56.78 to 57.60); 3) When evaluated on all eight channels, ChannelViT significantly outperforms ViT, with an average gap of 9.62.

Channel-specific attention visualization

Attention heatmaps, generated by Vision Transformers (ViTs), have emerged as a valuable tool for interpreting model decisions. For instance, Chefer et al. (2021) introduced a relevance computation method, which assigns local relevance based on the Deep Taylor Decomposition principle and subsequently propagates these relevance scores through the layers. However, a limitation of ViTs is their tendency to amalgamate information across different channels. In the realm of microscopy imaging, discerning the contribution of each fluorescence channel to the predictions is vital due to their distinct biological implications.

Figure 4 (right) presents the class-specific relevance visualizations for ViT-S/8 and ChannelViT-S/8. For the top cell labeled KRAS, ChannelViT appears to utilize information from the Mito channel. For the bottom cell labeled KCNH76, ChannelViT seems to utilize information from the ER and RNA channels for its prediction. Compared to ViT, ChannelViT facilitates the examination of contributions made by individual channels.

In Figure 4 (left), we further compute the maximum attention score (averaged over 100 cells) for each cell label (perturbed gene) and each input channel. Our observations indicate that ChannelViT focuses on different channels for different labels (corresponding to perturbed genes), with the Mito channel emerging as the most significant information source. This heatmap, which describes the discriminability of different labels over different channels, can also aid in better understanding the relationships between different gene perturbations.

[Figure 4: Maximum attention score per label and input channel (left) and channel-specific relevance visualizations for ViT-S/8 and ChannelViT-S/8 (right).]

Time efficiency

One limitation of ChannelViT is the additional computational cost incurred when expanding the channel dimension into the sequence length dimension. Implementing ChannelViT without HCS increases the training time from approximately 3 hours to 12 hours. With HCS, the training duration for ChannelViT is reduced to about 10 hours. During inference, ChannelViT requires approximately 1.6 times more time than its ViT counterpart. An interesting future direction would be to combine ChannelViT with more efficient attention mechanisms, such as Linformer (Wang et al., 2020) and LongNet (Ding et al., 2023), which scale linearly with sequence length. We direct the reader to Appendix C.1 for a comprehensive analysis of the running times.
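For a rough sense of where this overhead comes from, the snippet below compares sequence lengths only; it is a back-of-the-envelope illustration, and actual wall-clock cost also depends on the MLP blocks (linear in sequence length), on HCS shortening training sequences, and on hardware.

```python
def num_tokens(img_size=224, patch=16, channels=8, per_channel_tokens=True):
    """Sequence length including the CLS token."""
    n = (img_size // patch) ** 2
    return 1 + (n * channels if per_channel_tokens else n)

vit_len = num_tokens(per_channel_tokens=False)   # 197 tokens for ViT-S/16
cvit_len = num_tokens(per_channel_tokens=True)   # 1569 tokens for ChannelViT-S/16 on 8 channels
print(cvit_len / vit_len)                        # ~8x longer sequence; self-attention is quadratic in this length
```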

4.3 So2Sat: Satellite Imaging

Table 4: Test accuracy of 17-way local climate zone classification on So2Sat. We consider two official splits: random split and city split. Both ViT and ChannelViT are trained on all channels with hierarchical channel sampling. We evaluate their performance on all 18 channels (Sentinel-1 & 2) as well as partial channels (Sentinel-1 only).

| Split | Backbone | Sentinel-1 (channels 0–7) | Sentinel-1 & 2 (channels 0–17) |
|---|---|---|---|
| Random split (Zhu, 2021) | ViT-S/8 | 50.62 | 97.82 |
| | ChannelViT-S/8 | 59.75 | 99.10 |
| City split (Zhu et al., 2019) | ViT-S/8 | 41.07 | 62.48 |
| | ChannelViT-S/8 | 47.39 | 63.01 |

Our results on the So2Sat satellite imaging benchmark are presented in Table 4. We evaluate two official splits: random split and city split, training both ViT-S/8 and ChannelViT-S/8 models using hierarchical channel sampling across all channels (Sentinel-1 & 2).

When evaluated with all 18 channels, ChannelViT outperforms its ViT counterpart by 1.28 on the random split and 0.53 on the more challenging city split. In satellite imaging, Sentinel-1 channels are derived from a synthetic aperture radar operating on the C-band, while Sentinel-2 is a multispectral high-resolution imaging mission. Notably, Sentinel-2 data can be cloud-affected, underscoring the importance of models that operate robustly on partial signals using only Sentinel-1. In this Sentinel-1-only setting, ChannelViT significantly outperforms ViT on both splits (59.75 vs. 50.62 on the random split and 47.39 vs. 41.07 on the city split).

Lastly, we explore the efficiency of ChannelViT in combining satellite training data with different signals. As depicted in Figure 5, we consider varying proportions (10%, 25%, 50%, 75%, 100%) of the training data with access to both Sentinel-1 & 2 signals, while the remaining data only has access to Sentinel-1 signals. The models are evaluated using all Sentinel-1 & 2 signals. Our observations consistently show ChannelViT outperforming ViT.

Interpreting the channel embeddings learned by ChannelViT

Figure 2 presents the correlations between the input channels. It is noteworthy that the first four channels of Sentinel-1 correspond to: 1) the real part of the VH channel; 2) the imaginary part of the VH channel; 3) the real part of the VV channel; and 4) the imaginary part of the VV channel. These four input channels are uncorrelated, as evidenced by the bottom left corner of the So2Sat visualization heatmap. However, upon examining the correlations between the learned channel embeddings, we observe a high correlation between the real and imaginary parts of both VV and VH channels. This intuitively aligns with the fact that the real and imaginary parts are equivalent in terms of the information they provide. This demonstrates that ChannelViT learns meaningful channel embeddings, which can provide additional insights into the relationships between different input signals.

5 Conclusion

In conclusion, our proposed model, ChannelViT, effectively addresses the unique challenges of multi-channel imaging domains. By enhancing reasoning across input channels and seamlessly handling inputs with varying sets of channels, ChannelViT has consistently outperformed its ViT counterpart in our evaluations on ImageNet and diverse applications such as medical, microscopy cell, and satellite imaging. The introduction of Hierarchical Channel Sampling (HCS) further bolsters the model’s robustness when testing with different channel combinations. Moreover, ChannelViT not only improves data efficiency but also provides additional interpretability, underscoring its potential for broad applications in the field of multi-channel imaging.

References

  • Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Bao & Karaletsos (2023) Yujia Bao and Theofanis Karaletsos. Contextual vision transformers for robust representation learning. arXiv preprint arXiv:2305.19402, 2023.
  • Carion et al. (2020) Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko. End-to-end object detection with transformers. In European Conference on Computer Vision, pp. 213–229. Springer, 2020.
  • Caron et al. (2021) Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, and Armand Joulin. Emerging properties in self-supervised vision transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 9650–9660, 2021.
  • Chandrasekaran et al. (2022) Srinivas Niranj Chandrasekaran, Beth A. Cimini, Amy Goodale, Lisa Miller, Maria Kost-Alimova, Nasim Jamali, John Doench, Briana Fritchman, Adam Skepner, Michelle Melanson, et al. Three million images and morphological profiles of cells treated with matched chemical and genetic perturbations. bioRxiv, pp. 2022–01, 2022.
  • Chefer et al. (2021) Hila Chefer, Shir Gur, and Lior Wolf. Transformer interpretability beyond attention visualization, 2021.
  • Chefer et al. (2022) Hila Chefer, Idan Schwartz, and Lior Wolf. Optimizing relevance maps of vision transformers improves robustness. Advances in Neural Information Processing Systems, 35:33618–33632, 2022.
  • Cong et al. (2022) Yezhen Cong, Samar Khanna, Chenlin Meng, Patrick Liu, Erik Rozi, Yutong He, Marshall Burke, David Lobell, and Stefano Ermon. SatMAE: Pre-training transformers for temporal and multi-spectral satellite imagery. Advances in Neural Information Processing Systems, 35:197–211, 2022.
  • Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255. IEEE, 2009.
  • Ding et al. (2023) Jiayu Ding, Shuming Ma, Li Dong, Xingxing Zhang, Shaohan Huang, Wenhui Wang, and Furu Wei. LongNet: Scaling transformers to 1,000,000,000 tokens. arXiv preprint arXiv:2307.02486, 2023.
  • Dosovitskiy et al. Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations.
  • Gao et al. (2022) Irena Gao, Shiori Sagawa, Pang Wei Koh, Tatsunori Hashimoto, and Percy Liang. Out-of-distribution robustness via targeted augmentations. In NeurIPS 2022 Workshop on Distribution Shifts: Connecting Methods and Applications, 2022.
  • Ghiasi et al. (2022) Amin Ghiasi, Hamid Kazemi, Eitan Borgnia, Steven Reich, Manli Shu, Micah Goldblum, Andrew Gordon Wilson, and Tom Goldstein. What do vision transformers learn? A visual exploration. arXiv preprint arXiv:2212.06727, 2022.
  • Goyal et al. (2017) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: Training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.
  • Hatamizadeh et al. (2022a) Ali Hatamizadeh, Yucheng Tang, Vishwesh Nath, Dong Yang, Andriy Myronenko, Bennett Landman, Holger R. Roth, and Daguang Xu. UNETR: Transformers for 3D medical image segmentation. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 574–584, 2022a.
  • Hatamizadeh et al. (2022b) Ali Hatamizadeh, Ziyue Xu, Dong Yang, Wenqi Li, Holger Roth, and Daguang Xu. UNetFormer: A unified vision transformer model and pre-training framework for 3D medical image segmentation. arXiv preprint arXiv:2204.00631, 2022b.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
  • Hou & Wang (2019) Saihui Hou and Zilei Wang. Weighted channel dropout for regularization of deep convolutional neural network. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pp. 8425–8432, 2019.
  • Hussein et al. (2022) Ramy Hussein, Soojin Lee, and Rabab Ward. Multi-channel vision transformer for epileptic seizure prediction. Biomedicines, 10(7):1551, 2022.
  • Kaselimi et al. (2022) Maria Kaselimi, Athanasios Voulodimos, Ioannis Daskalopoulos, Nikolaos Doulamis, and Anastasios Doulamis. A vision transformer model for convolution-free multilabel classification of satellite imagery in deforestation monitoring. IEEE Transactions on Neural Networks and Learning Systems, 2022.
  • Liu et al. (2021) Evan Z. Liu, Behzad Haghgoo, Annie S. Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, and Chelsea Finn. Just train twice: Improving group robustness without training group information. In International Conference on Machine Learning, pp. 6781–6792. PMLR, 2021.
  • Loshchilov & Hutter (2019) Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=Bkg6RiCqY7.
  • Mahmood et al. (2021) Kaleel Mahmood, Rigel Mahmood, and Marten Van Dijk. On the robustness of vision transformers to adversarial examples. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 7838–7847, 2021.
  • Mao et al. (2022) Xiaofeng Mao, Gege Qi, Yuefeng Chen, Xiaodan Li, Ranjie Duan, Shaokai Ye, Yuan He, and Hui Xue. Towards robust vision transformer. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12042–12051, 2022.
  • Nguyen et al. (2023) Tung Nguyen, Johannes Brandstetter, Ashish Kapoor, Jayesh K. Gupta, and Aditya Grover. ClimaX: A foundation model for weather and climate, 2023.
  • Sagawa et al. (2019) Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, and Percy Liang. Distributionally robust neural networks. In International Conference on Learning Representations, 2019.
  • Scheibenreif et al. (2022) Linus Scheibenreif, Joëlle Hanna, Michael Mommert, and Damian Borth. Self-supervised vision transformers for land-cover segmentation and classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1422–1431, 2022.
  • Sivanandan et al. (2023) Srinivasan Sivanandan, Bobby Leitmann, Eric Lubeck, Mohammad Muneeb Sultan, Panagiotis Stanitsas, Navpreet Ranu, Alexis Ewer, Jordan E. Mancuso, Zachary F. Phillips, Albert Kim, et al. A pooled cell painting CRISPR screening platform enables de novo inference of gene function by self-supervised deep learning. bioRxiv, pp. 2023–08, 2023.
  • Song et al. (2022) Qi Song, Jie Li, Chenghong Li, Hao Guo, and Rui Huang. Fully attentional network for semantic segmentation. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 36, pp. 2280–2288, 2022.
  • Srivastava et al. (2014) Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
  • Tarasiou et al. (2023) Michail Tarasiou, Erik Chavez, and Stefanos Zafeiriou. ViTs for SITS: Vision transformers for satellite image time series. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10418–10428, 2023.
  • Tompson et al. (2015) Jonathan Tompson, Ross Goroshin, Arjun Jain, Yann LeCun, and Christoph Bregler. Efficient object localization using convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 648–656, 2015.
  • Touvron et al. (2021) Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. In International Conference on Machine Learning, pp. 10347–10357. PMLR, 2021.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
  • Wang et al. (2020) Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  • You et al. (2018) Yang You, Zhao Zhang, Cho-Jui Hsieh, James Demmel, and Kurt Keutzer. ImageNet training in minutes. In Proceedings of the 47th International Conference on Parallel Processing, pp. 1–10, 2018.
  • Zhang et al. (2021) Xin Zhang, Liangxiu Han, Tam Sobeih, Lewis Lappin, Mark Lee, Andew Howard, and Aron Kisdi. The channel-spatial attention-based vision transformer network for automated, accurate prediction of crop nitrogen status from UAV imagery. arXiv e-prints, pp. arXiv–2111, 2021.
  • Zhou et al. (2022) Daquan Zhou, Zhiding Yu, Enze Xie, Chaowei Xiao, Animashree Anandkumar, Jiashi Feng, and Jose M. Alvarez. Understanding the robustness in vision transformers. In International Conference on Machine Learning, pp. 27378–27394. PMLR, 2022.
  • Zhu et al. (2020a) Xiao Xiang Zhu, Jingliang Hu, Chunping Qiu, Yilei Shi, Jian Kang, Lichao Mou, Hossein Bagheri, Matthias Haberle, Yuansheng Hua, Rong Huang, Lloyd Hughes, Hao Li, Yao Sun, Guichen Zhang, Shiyao Han, Michael Schmitt, and Yuanyuan Wang. So2Sat LCZ42: A benchmark data set for the classification of global local climate zones [software and data sets]. IEEE Geoscience and Remote Sensing Magazine, 8(3):76–89, 2020a. doi: 10.1109/MGRS.2020.2964708.
  • Zhu (2021) Xiaoxiang Zhu. So2Sat LCZ42 3 splits, 2021.
  • Zhu et al. (2019) Xiaoxiang Zhu, Jingliang Hu, Chunping Qiu, Yilei Shi, Hossein Bagheri, Jian Kang, Hao Li, Lichao Mou, Guicheng Zhang, Matthias Häberle, Shiyao Han, Yuansheng Hua, Rong Huang, Lloyd Hughes, Yao Sun, Michael Schmitt, and Yuanyuan Wang. So2Sat LCZ42, 2019.
  • Zhu et al. (2020b) Xizhou Zhu, Weijie Su, Lewei Lu, Bin Li, Xiaogang Wang, and Jifeng Dai. Deformable DETR: Deformable transformers for end-to-end object detection. arXiv preprint arXiv:2010.04159, 2020b.

Appendix A Implementation Details

This section elucidates the specifics of our implementation and the settings of our hyper-parameters.

A.1 Hierarchical Channel Sampling

In Section 3.2, we outlined the channel sampling procedure of HCS. In this subsection, we offer a comprehensive example of HCS in conjunction with ChannelViT and ViT.

Hierarchical Channel Sampling for ChannelViT

Given a three-channel input $x$, following Section 3.1, the input sequence for the Transformer encoder can be expressed as

$$\big[\texttt{CLS},\ \ \texttt{pos}_1 + \texttt{chn}_1 + Wx[c_1,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_1 + Wx[c_1,p_N],\ \ \texttt{pos}_1 + \texttt{chn}_2 + Wx[c_2,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_2 + Wx[c_2,p_N],\ \ \texttt{pos}_1 + \texttt{chn}_3 + Wx[c_3,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_3 + Wx[c_3,p_N]\big].$$

Suppose the channel combination sampled by the HCS algorithm is $\{1, 3\}$. The corresponding input sequence for the Transformer encoder is then

$$\big[\texttt{CLS},\ \ \texttt{pos}_1 + \texttt{chn}_1 + Wx[c_1,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_1 + Wx[c_1,p_N],\ \ \texttt{pos}_1 + \texttt{chn}_3 + Wx[c_3,p_1],\ \ldots,\ \texttt{pos}_N + \texttt{chn}_3 + Wx[c_3,p_N]\big].$$

It’s important to note that reducing the number of channels only modifies the sequence length. Furthermore, since we sample the channel combinations for each training step, the channels utilized for each image can vary across different epochs.
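In code, this amounts to indexing the channels (and the corresponding channel embeddings) before tokenization. The sketch below reuses the hypothetical `ChannelPatchEmbed` and `hcs_sample` helpers from the earlier sketches; it is illustrative only.

```python
import torch

x = torch.randn(2, 3, 224, 224)                             # three-channel input
keep = hcs_sample(3)                                        # e.g. [0, 2], i.e. the combination {1, 3}
embed = ChannelPatchEmbed(num_channels=3, dim=384)
tokens = embed(x[:, keep], channel_ids=torch.tensor(keep))  # sequence length 1 + len(keep) * N
```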

Hierarchical Channel Sampling for ViT

Given the identical three-channel input $x$, the input sequence for the Transformer encoder can be written as

$$\big[\texttt{CLS},\ \ \texttt{pos}_1 + W_1 x[c_1,p_1] + W_2 x[c_2,p_1] + W_3 x[c_3,p_1] + b,\ \ \ldots,\ \ \texttt{pos}_N + W_1 x[c_1,p_N] + W_2 x[c_2,p_N] + W_3 x[c_3,p_N] + b\big].$$

Here $W_1, W_2, W_3$ are the weights associated with each input channel and $b$ is the bias term. Continuing with the assumption that the channel combination sampled by the HCS algorithm is again $\{1, 3\}$, we adjust the above input sequence as follows:

$$\big[\texttt{CLS},\ \ \texttt{pos}_1 + \tfrac{3}{2}\big(W_1 x[c_1,p_1] + W_3 x[c_3,p_1]\big) + b,\ \ \ldots,\ \ \texttt{pos}_N + \tfrac{3}{2}\big(W_1 x[c_1,p_N] + W_3 x[c_3,p_N]\big) + b\big].$$

Note that, in addition to masking the input from the second channel, we rescale the remaining channels by a factor of $3/2$. This mirrors the rescaling of Srivastava et al. (2014) and ensures that the output of the linear patch layer maintains the same scale despite the reduction in input channels.
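A minimal sketch of this masking-and-rescaling step for ViT, assuming the rescaling factor is the total channel count divided by the number of kept channels (matching the 3/2 example above); the helper name is ours.

```python
import torch

def mask_and_rescale(x: torch.Tensor, keep: list[int]) -> torch.Tensor:
    """Zero out the unselected channels of x (shape B x C x H x W) and rescale the
    kept ones by C / len(keep), so the linear patch projection keeps the same scale."""
    scale = torch.zeros(1, x.shape[1], 1, 1, device=x.device)
    scale[:, keep] = x.shape[1] / len(keep)
    return x * scale

x = torch.randn(2, 3, 224, 224)
x_masked = mask_and_rescale(x, keep=[0, 2])   # keep channels {1, 3}; kept channels scaled by 3/2
```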

A.2 Training with ViT and ChannelViT

Backbone

For the vision transformer backbone, we employ the PyTorch implementation provided by Facebook Research (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py). Due to computational constraints, we primarily utilize the ‘vit-small’ architecture, which has an embedding dimension of 384, a depth of 12, 6 attention heads, an MLP hidden dimension of 4 × 384 = 1536, and pre-layer normalization. We also briefly experiment with ‘vit-base’, which increases the embedding dimension to 768, the number of heads to 12, and the MLP hidden dimension to 4 × 768 = 3072. For ChannelViT, we retain the same Transformer encoder settings as its ViT counterpart. Note that ChannelViT has marginally fewer parameters, as the first linear projection layer is shared across channels.

Objective

We employ the standard cross-entropy loss for both ViT and ChannelViT across the four image classification benchmarks. Specifically, we utilize the Transformer encoder’s representation for the CLS token at the final layer, and append a linear layer, followed by a Softmax function, to predict the probability of each class.

Optimization

For optimization, we employ the AdamW optimizer (Loshchilov & Hutter, 2019). The learning rate is warmed up for the initial 10 epochs, peaking at 0.0005 (Goyal et al., 2017), after which it gradually decays to $10^{-6}$ following a cosine schedule. To mitigate overfitting, we apply weight decay to the weight parameters, excluding the bias and normalization terms. The weight decay starts at 0.04 and incrementally increases during training, following a cosine schedule, up to a maximum of 0.4. Each model is trained for 100 epochs with a batch size of 256. Training is conducted on an AWS p4d.24xlarge instance equipped with 8 A100 GPUs.
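A hedged sketch of this recipe follows; the function and variable names are ours, the placeholder model and steps-per-epoch count are assumptions, and the released code may differ in details such as per-step versus per-epoch scheduling.

```python
import math
import torch

def cosine_schedule(base, final, step, total_steps, warmup_steps=0):
    """Linear warmup from 0 to `base`, then cosine interpolation from `base` to `final`."""
    if step < warmup_steps:
        return base * step / max(1, warmup_steps)
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return final + 0.5 * (base - final) * (1.0 + math.cos(math.pi * t))

# Placeholder model standing in for the ViT/ChannelViT backbone plus classifier head.
model = torch.nn.Sequential(torch.nn.Linear(384, 384), torch.nn.LayerNorm(384), torch.nn.Linear(384, 160))
decay = [p for p in model.parameters() if p.ndim > 1]      # weight matrices: decayed
no_decay = [p for p in model.parameters() if p.ndim <= 1]  # biases and norm params: not decayed
optimizer = torch.optim.AdamW(
    [{"params": decay}, {"params": no_decay, "weight_decay": 0.0}], lr=5e-4)

steps_per_epoch = 500                                      # assumption: depends on dataset and batch size
total_steps, warmup_steps = 100 * steps_per_epoch, 10 * steps_per_epoch
for step in range(total_steps):
    lr = cosine_schedule(5e-4, 1e-6, step, total_steps, warmup_steps)  # warmup then cosine decay
    wd = cosine_schedule(0.04, 0.4, step, total_steps)                 # cosine increase of weight decay
    optimizer.param_groups[0]["lr"] = lr
    optimizer.param_groups[1]["lr"] = lr
    optimizer.param_groups[0]["weight_decay"] = wd
    # ... forward pass, cross-entropy loss, loss.backward(), optimizer.step(), optimizer.zero_grad()
```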

A.3 Training on datasets with varying channel availability

In Table 3 and Figure 5, we investigated scenarios where our training datasets exhibited varying channel availability. This section provides a detailed description of the training settings we employed and presents additional results for an alternative setting.

ChannelViT and ViT

Despite the different channel combinations in the training datasets, we utilize a consistent approach (as detailed in the Appendix) to encode the images for both ChannelViT and ViT. For ChannelViT, this entails varying sequence lengths for images with different numbers of channels. For ViT, this involves masking out the unavailable channels and rescaling the remaining ones.

Objective

We continue to use the cross-entropy loss. However, in this instance, there are two potential methods for data sampling.

  1. Sampling a random batch from each dataset and minimizing their average loss. This approach assigns more weight to the dataset with fewer examples. Mathematically, it optimizes
     \[
     \mathcal{L}_{\text{upsample}}=\frac{|D_{1}|+|D_{2}|}{2|D_{1}|}\mathcal{L}_{D_{1}}+\frac{|D_{1}|+|D_{2}|}{2|D_{2}|}\mathcal{L}_{D_{2}},
     \]
     where $D_{1}$ and $D_{2}$ are the two training datasets with different channels.

  2. Concatenating the two datasets and drawing batches from the combined dataset. This approach simply minimizes the average loss
     \[
     \mathcal{L}_{\text{average}}=\mathcal{L}_{D_{1}}+\mathcal{L}_{D_{2}}.
     \]

Our preliminary experiments indicate that the second method consistently outperforms the first. For instance, on JUMP-CP when training with 25% 8-channel data, ChannelViT-S/16 achieves 57.97% when trained with $\mathcal{L}_{\text{average}}$ but only reaches 45.52% when trained with $\mathcal{L}_{\text{upsample}}$. Similarly, ViT-S/16 achieves 50.29% with $\mathcal{L}_{\text{average}}$ but only 42.58% with $\mathcal{L}_{\text{upsample}}$. We hypothesize that models overfit when trained using the upsampling loss. We therefore report the numbers for the plain average loss $\mathcal{L}_{\text{average}}$ in Table 3 and Figure 5.
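The two sampling schemes can be contrasted with a short sketch. Everything below is a stand-in of our own (toy datasets, a linear model, and hypothetical function names); batches are kept channel-homogeneous by drawing each batch from a single dataset.

```python
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for two training splits of different sizes (features flattened for brevity).
d1 = TensorDataset(torch.randn(1000, 32), torch.randint(0, 160, (1000,)))
d2 = TensorDataset(torch.randn(200, 32), torch.randint(0, 160, (200,)))
model = nn.Linear(32, 160)
criterion = nn.CrossEntropyLoss()

loader1 = DataLoader(d1, batch_size=64, shuffle=True)
loader2 = DataLoader(d2, batch_size=64, shuffle=True)
it1, it2 = iter(loader1), iter(loader2)

def step_average():
    # L_average: every example is weighted equally, so the source dataset is drawn
    # proportionally to its size (equivalent in expectation to sampling from the
    # concatenated dataset, while keeping each batch channel-homogeneous).
    p1 = len(d1) / (len(d1) + len(d2))
    x, y = next(it1) if random.random() < p1 else next(it2)
    return criterion(model(x), y)

def step_upsample():
    # L_upsample: one batch from each dataset per step, losses averaged,
    # which implicitly upweights the smaller dataset.
    (x1, y1), (x2, y2) = next(it1), next(it2)
    return 0.5 * (criterion(model(x1), y1) + criterion(model(x2), y2))
```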

A.4 Evaluation across all channel combinations

To assess the channel robustness of the trained models, we enumerate all possible channel combinations and report the corresponding accuracy for each.

For instance, in Table 2 we considered two training scenarios: the top section pertains to training on the 5 fluorescence channels, while the bottom section pertains to training on all 8 channels. For the top section, we can evaluate the models on all subsets of the 5 fluorescence channels. This includes:

  • Combinations with 5 channels: there is only $C_{5}^{5}=1$ combination;

  • Combinations with 4 channels: there are $C_{5}^{4}=5$ combinations;

  • Combinations with 3 channels: there are $C_{5}^{3}=10$ combinations;

  • Combinations with 2 channels: there are $C_{5}^{2}=10$ combinations;

  • Combinations with 1 channel: there are $C_{5}^{1}=5$ combinations.

Consequently, we evaluate a total of $1+5+10+10+5=31$ channel combinations. Given a specific channel combination, we mask out the testing images accordingly (as described in Appendix A.1) and compute the corresponding testing accuracy. We then report the average accuracy over combinations that have the same number of channels. As one might intuitively expect, models tend to perform better when provided with more channels.
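A short sketch of this enumeration (ours; `evaluate` is a placeholder for masking the test set to a given subset and computing top-1 accuracy):

```python
from itertools import combinations

def evaluate(subset):
    """Placeholder: the real routine would mask the test images to `subset`
    and return the trained model's top-1 accuracy."""
    return 0.0

channels = ["AGP", "DNA", "ER", "Mito", "RNA"]
mean_acc = {}
for k in range(1, len(channels) + 1):
    accs = [evaluate(subset) for subset in combinations(channels, k)]
    mean_acc[k] = sum(accs) / len(accs)   # average over the C(5, k) subsets of size k

# In total 1 + 5 + 10 + 10 + 5 = 31 channel combinations are evaluated.
```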

A.5 Dataset details

In this section, we provide a detailed description of our datasets and their corresponding input channels.

JUMP-CP, 160-way classification

We use the processed version of JUMP-CP released by Bao & Karaletsos (2023) (https://github.com/insitro/ContextViT). Each image consists of a single masked cell and includes five fluorescence channels (AGP, DNA, ER, Mito, RNA) as well as three brightfield channels: HighZBF (Brightfield-1), LowZBF (Brightfield-2), and Brightfield (Brightfield-3). Each cell has been perturbed by a chemical compound, and the goal is to identify the gene target of the chemical perturbation.

So2Sat, 17-way classification

We use the processed version of So2Sat released by the original authors, Zhu et al. (2020a) (https://github.com/zhu-xlab/So2Sat-LCZ42). Each image patch consists of 8 channels from Sentinel-1:

  1. the real part of the unfiltered VH channel;

  2. the imaginary part of the unfiltered VH channel;

  3. the real part of the unfiltered VV channel;

  4. the imaginary part of the unfiltered VV channel;

  5. the intensity of the refined LEE filtered VH channel;

  6. the intensity of the refined LEE filtered VV channel;

  7. the real part of the refined LEE filtered covariance matrix off-diagonal element;

  8. the imaginary part of the refined LEE filtered covariance matrix off-diagonal element;

and 10 channels from Sentinel-2: Band B2, Band B3, Band B4, Band B5, Band B6, Band B7, Band B8, Band B8a, Band B11, and Band B12. The task is to predict the climate zone of each image patch, with a total of 17 distinct climate zones represented.

Appendix B Additional baselines

Table 5: Top-1 accuracy on JUMP-CP (training on all 8 channels) for ViT-S/16, MultiViT-S/16, and ChannelViT-S/16, without and with hierarchical channel sampling (HCS).

#channels for testing | ViT-S/16 (no HCS) | MultiViT-S/16 (no HCS) | ChannelViT-S/16 (no HCS) | ViT-S/16 (HCS) | MultiViT-S/16 (HCS) | ChannelViT-S/16 (HCS)
8 channels | 52.06 | 49.06 | 66.22 | 56.87 | 30.25 | 68.09
7 channels | 5.91 | 34.10 | 41.03 | 49.35 | 29.04 | 61.02
6 channels | 1.81 | 23.77 | 24.57 | 42.38 | 27.44 | 53.45
5 channels | 2.46 | 17.09 | 14.20 | 35.78 | 25.69 | 45.50
4 channels | 2.38 | 12.98 | 8.56 | 29.84 | 23.96 | 37.37
3 channels | 2.70 | 10.58 | 5.65 | 24.94 | 22.34 | 29.68
2 channels | 2.63 | 9.61 | 3.24 | 21.54 | 20.89 | 23.77
1 channel | 3.00 | 7.97 | 2.08 | 19.92 | 19.85 | 20.84

B.1 Baseline: Concatenating Features from Multiple Single-Channel ViTs

Hussein et al. (2022) utilized ViTs for epileptic seizure prediction, proposing to train multiple ViTs, one for each input channel. The final image representation is derived by aggregating the output CLS tokens across all single-channel ViTs, and an MLP is then attached to these aggregated features to predict the image label. In this section, we implement this baseline based on the paper, termed MultiViT, and evaluate its performance both with and without HCS.
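Our re-implementation follows the description above. The sketch below is our own reading of that description (aggregation shown as concatenation; all module names and the stand-in backbones are hypothetical), not the original authors' code.

```python
import torch
import torch.nn as nn

class MultiViT(nn.Module):
    """Sketch of the MultiViT baseline: one single-channel backbone per input
    channel, with CLS tokens concatenated and fed to an MLP classifier."""
    def __init__(self, backbones: nn.ModuleList, embed_dim: int, num_classes: int):
        super().__init__()
        self.backbones = backbones                     # one backbone per channel
        self.head = nn.Sequential(
            nn.Linear(len(backbones) * embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):                              # x: (B, C, H, W)
        cls_tokens = [vit(x[:, c:c + 1]) for c, vit in enumerate(self.backbones)]
        return self.head(torch.cat(cls_tokens, dim=-1))

# Usage with stand-in single-channel "backbones" mapping (B, 1, H, W) -> (B, embed_dim).
backbones = nn.ModuleList(
    nn.Sequential(nn.AdaptiveAvgPool2d(4), nn.Flatten(), nn.Linear(16, 384)) for _ in range(8)
)
model = MultiViT(backbones, embed_dim=384, num_classes=160)
logits = model(torch.randn(2, 8, 224, 224))
```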

Table 5 presents our results on JUMP-CP when training on all eight channels. Without HCS, MultiViT underperforms ViT when evaluated on all channels, despite having eight times more parameters. This underscores the importance of parameter sharing across channels to combat overfitting. However, when testing on a subset of channels, MultiViT outperforms ViT, as each of its ViTs operates on a single channel, improving robustness to changes in the input channels. Interestingly, MultiViT does not perform well with HCS. While its accuracy improves when testing on a subset of channels, the accuracy significantly decreases (from 49.06 to 30.25) when using all eight channels. We hypothesize that this is because the channel-wise feature aggregation is performed after the single-channel ViTs, preventing the model from conditioning its representation on the input channel availability.

We find that ChannelViT significantly outperforms MultiViT. There are three key differences between the two models:

  1. ChannelViT learns a single ViT across all channels, rather than one ViT per channel;

  2. ChannelViT is aware of the input channel availability at the input patch sequence, while the single-channel ViTs in MultiViT operate independently;

  3. ChannelViT allows cross-channel, cross-location attention, while MultiViT only permits cross-location attention.

Table 6: Top-1 accuracy on JUMP-CP for ViT-S/16, FAN-S/16 with convolutional patch embedding, FAN-S/16 with linear patch embedding, and ChannelViT-S/16, without and with HCS.

#channels for testing | ViT-S/16 | FAN-S/16 (conv patch) | FAN-S/16 (linear patch) | ChannelViT-S/16 | ViT-S/16 (HCS) | FAN-S/16 (conv patch, HCS) | FAN-S/16 (linear patch, HCS) | ChannelViT-S/16 (HCS)
8 channels | 52.06 | 65.13 | 65.42 | 66.22 | 56.87 | 3.49 | 20.31 | 68.09
7 channels | 5.91 | 1.24 | 3.63 | 41.03 | 49.35 | 3.88 | 20.52 | 61.02
6 channels | 1.81 | 0.64 | 4.82 | 24.57 | 42.38 | 3.96 | 17.46 | 53.45
5 channels | 2.46 | 2.11 | 6.62 | 14.20 | 35.78 | 3.15 | 15.17 | 45.50
4 channels | 2.38 | 3.80 | 6.68 | 8.56 | 29.84 | 3.92 | 11.74 | 37.37
3 channels | 2.70 | 5.03 | 6.03 | 5.65 | 24.94 | 4.54 | 9.42 | 29.68
2 channels | 2.63 | 4.36 | 5.97 | 3.24 | 21.54 | 2.21 | 6.65 | 23.77
1 channel | 3.00 | 2.68 | 2.92 | 2.08 | 19.92 | 2.90 | 2.52 | 20.84

B.2 Baseline: Fully Attentional Networks (FANs)

Zhou et al. (2022) introduced a family of Fully Attentional Networks (FANs) that combine channel-wise attention with the MLP in a transformer encoder layer. Notably, the channels in this context extend beyond the input channels: FANs aggregate feature channels with high correlation values across the transformer encoder layers and isolate outlier features with low correlation values.

We adopted the implementation provided at https://github.com/NVlabs/FAN/blob/master/models/fan.py and evaluated the FAN small with a patch size of 16 by 16. It’s worth noting that FAN, by default, employs four stacks of 3 by 3 convolution layers (each followed by GELU activations) to construct the input patch tokens, whereas ViT and ChannelViT use a single linear layer over the 16 by 16 input patches. We refer to this FAN baseline as FAN S/16 (conv patch). We also experimented with replacing these convolution layers with the same linear projection used in the regular ViT, terming this modified version of FAN as FAN S/16 (linear patch).

Table 6 presents our results on FANs. Without HCS, the default FAN-S/16 (conv patch) significantly outperforms ViT (65.13 vs. 52.06), demonstrating the effectiveness of cross-channel attention. However, it still falls short of ChannelViT (65.13 vs. 66.22). Furthermore, when evaluated on a subset of channels at test time, its performance declines sharply (1.24 vs. 41.03 on 7 channels). Interestingly, we observe that the FAN with a linear patch embedding layer performs slightly better than the default FAN with convolutional patch embeddings.

We also investigated training FANs with HCS. We found that FAN with convolutional patch embeddings struggled to learn a meaningful classifier. Replacing the convolution layers with a simple linear transformation improved performance, and FAN-S/16 (linear patch) trained with HCS outperforms its counterpart without HCS when evaluated on a subset of channels. However, its performance is still significantly lower than that of the regular ViT-S/16. We hypothesize that because FANs explicitly leverage the correlation between different hidden channels to build their representations, they become more sensitive to channel perturbations at test time.

In conclusion, we highlight the key differences between ChannelViT and FANs:

  1. ChannelViT performs cross-channel and cross-location attention jointly, meaning that each patch token can attend to a different channel at a different location.

  2. ChannelViT maintains the distinction between input channels throughout the transformer encoder and ties the transformer encoder across channels, which we argue enhances robustness to channel changes.

Table 7: Top-1 accuracy on JUMP-CP for ResNet50 and ResNet152 (trained without HCS) versus ViT-S/8 and ChannelViT-S/8 (trained with HCS).

Model | ResNet50 | ResNet152 | ViT-S/8 | ChannelViT-S/8
#parameters | 25M | 60M | 22M | 22M
8 channels | 65.96 | 66.54 | 66.44 | 74.77
7 channels | 2.39 | 3.05 | 59.01 | 68.42
6 channels | 2.22 | 4.29 | 51.29 | 61.26
5 channels | 1.57 | 5.35 | 43.39 | 53.05
4 channels | 1.19 | 5.91 | 35.60 | 43.87
3 channels | 0.78 | 5.69 | 28.59 | 34.19
2 channels | 0.56 | 4.29 | 23.32 | 25.73
1 channel | 0.51 | 2.66 | 20.41 | 21.20

B.3 Baseline: Convolutional neural networks (CNNs)

In this section, we further compare ViT and ChannelViT with conventional convolutional neural networks (CNNs). We use the widely adopted ResNet-50 and ResNet-152 as baseline models, as described by He et al. (2016). We present the number of parameters for each model and their corresponding performance on the JUMP-CP dataset in Table 7.

Initially, we trained the ResNet baselines using Hierarchical Channel Sampling (HCS), but this approach led to training instability, with the top-1 accuracies of both models converging to approximately 5% by the end of training. Without HCS, ResNet-50 and ResNet-152 exhibit performance comparable to the ViT-S/8 baseline. Despite having three times more parameters, ResNet-152 achieves only a slight improvement over ResNet-50. When compared with ChannelViT-S/8, there still remains a significant performance gap.

We hypothesize that the parameter sharing within ChannelViT enables the efficient and robust construction of channel-invariant filters. Moreover, the explicit cross-channel attention in ChannelViT facilitates the model's ability to infer relationships across related channels.


Appendix C Additional analysis

C.1 Running time analysis

Table 8: Number of parameters, training time, inference time, and 8-channel accuracy on JUMP-CP.

Model | #parameters | Training time | Inference time | 8-channel accuracy
ResNet50 | 25M | 3.9 hours | 65.8 sec | 65.96
ResNet152 | 60M | 4.4 hours | 81.8 sec | 66.54
ViT-S/16 | 22M | 2.8 hours | 54.5 sec | 56.87
ChannelViT-S/16 w/o HCS | 22M | 12.1 hours | 91.0 sec | 66.22
ChannelViT-S/16 w/ HCS | 22M | 10.2 hours | 90.7 sec | 68.09

Table 9: JUMP-CP accuracy of ChannelViT-S/16 with untied versus tied first-layer linear projection weights across channels.

#channels for testing | Untied projection weights | Tied projection weights
5 channels | 54.78 | 56.78
4 channels | 43.88 | 45.94
3 channels | 33.67 | 35.45
2 channels | 25.57 | 26.57
1 channel | 21.07 | 21.43

Our proposed ChannelViT model, which expands the channel dimension into the sequence length dimension, introduces an inherent increase in computational cost. As shown in Table 8, the training duration for ChannelViT-S/16 on the JUMP-CP dataset, utilizing all eight channels, is significantly longer without Hierarchical Channel Sampling (HCS). However, integrating HCS reduces training time by roughly 15%, from 12 hours and 6 minutes to 10 hours and 17 minutes. This demonstrates that HCS not only bolsters the model's robustness but also markedly enhances training efficiency.

In terms of inference cost, ChannelViT exhibits a 1.6 times increase in processing time compared to its ViT counterpart, yet it achieves an 11.22% higher accuracy (on an absolute scale). When measured against the better performing ResNet-152 baseline, ChannelViT's inference time is only 1.1 times longer.

In this paper, we have explored ChannelViT with the standard quadratic attention mechanism. Looking ahead, it would be intriguing to investigate integrating ChannelViT with more efficient attention algorithms, such as Linformer (Wang et al., 2020) and LongNet (Ding et al., 2023), which scale linearly with sequence length. Such combinations could potentially yield further improvements in both performance and computational efficiency.

C.2 Attention visualization for ImageNet

Figure 7 illustrates the attention heatmaps for ViT and ChannelViT models trained on ImageNet. For each image, we generate the rolled-out attention scores for two distinct classes (espresso and wine for the top image, elephant and zebra for the bottom image), following Chefer et al. (2021). We observe that ChannelViT precisely focuses its attention on the relevant channels, such as the red channel when predicting red wine. In scenarios where the contrast pattern, such as black and white for a zebra, is distributed across all channels, ChannelViT effectively utilizes all channels to inform its prediction.


C.3 Ablation: tied vs. untied image filters for ChannelViT


In the main paper, we introduced ChannelViT with a linear projection layer tied across channels. This section explores using separate (untied) weights for each channel (Figure 8). The input sequence to the Transformer encoder can then be represented as follows:

\[
\begin{aligned}
\big[\texttt{CLS},\quad & \texttt{pos}_{1}+\texttt{chn}_{1}+W_{1}x[c_{1},p_{1}],\ \ldots,\ \texttt{pos}_{N}+\texttt{chn}_{1}+W_{1}x[c_{1},p_{N}],\\
\ldots,\quad & \texttt{pos}_{1}+\texttt{chn}_{C}+W_{C}x[c_{C},p_{1}],\ \ldots,\ \texttt{pos}_{N}+\texttt{chn}_{C}+W_{C}x[c_{C},p_{N}]\big],
\end{aligned}
\]

where $W_{1},\ldots,W_{C}$ denote the linear transformations associated with the input channels. Table 9 shows our findings on JUMP-CP: ChannelViT trained with tied image filter weights consistently outperforms its untied counterpart. We hypothesize that the first-layer filters are generally shareable across channels, and tying the parameters prevents overfitting, thereby enhancing the model's robustness.
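A minimal sketch of the two variants (ours; module and argument names are hypothetical): with tied weights the same linear module is reused for every channel, so its parameters are shared, whereas the untied variant instantiates one projection per channel.

```python
import torch
import torch.nn as nn

class ChannelProjection(nn.Module):
    """First-layer patch projection with tied (shared) or untied (per-channel) weights."""
    def __init__(self, num_channels, patch_dim, embed_dim, tied=True):
        super().__init__()
        if tied:
            shared = nn.Linear(patch_dim, embed_dim)
            self.proj = nn.ModuleList([shared] * num_channels)   # same module reused
        else:
            self.proj = nn.ModuleList(
                nn.Linear(patch_dim, embed_dim) for _ in range(num_channels)
            )

    def forward(self, patches):          # patches: (B, C, N, patch_dim)
        # Apply W_c to channel c; with tied weights all W_c are the same tensor.
        return torch.stack(
            [self.proj[c](patches[:, c]) for c in range(patches.shape[1])], dim=1
        )                                # (B, C, N, embed_dim)

patches = torch.randn(2, 5, 196, 256)                            # 16x16 patches, 5 channels
tied = ChannelProjection(5, 256, 384, tied=True)(patches)        # shared W across channels
untied = ChannelProjection(5, 256, 384, tied=False)(patches)     # separate W_1, ..., W_C
```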

C.4 Ablation: shared vs. unshared channel embeddings

Table 10: Top-1 accuracy of ChannelViT-S/16 with a single shared channel embedding versus the default per-channel embeddings.

Channel embedding | JUMP-CP, fluorescence (5 channels) | JUMP-CP, fluorescence & brightfield (8 channels) | So2Sat, Sentinel-1 (8 channels) | So2Sat, Sentinel-1 & -2 (18 channels)
Shared across channels | 1.26 | 2.49 | 10.44 | 52.13
Per-channel (default) | 57.60 | 68.09 | 47.39 | 63.01

In this section, we conduct an ablation study to investigate the impact of channel embeddings on the performance of ChannelViT. Specifically, we consider the following simplification of ChannelViT, in which a single channel embedding is shared across all channels:

\[
\begin{aligned}
\big[\texttt{CLS},\quad & \texttt{pos}_{1}+\texttt{chn}+Wx[c_{1},p_{1}],\ \ldots,\ \texttt{pos}_{N}+\texttt{chn}+Wx[c_{1},p_{N}],\\
\ldots,\quad & \texttt{pos}_{1}+\texttt{chn}+Wx[c_{C},p_{1}],\ \ldots,\ \texttt{pos}_{N}+\texttt{chn}+Wx[c_{C},p_{N}]\big].
\end{aligned}
\]

We consider ChannelViTs trained on both JUMP-CP and So2Sat. A natural way to define $\texttt{chn}$ is to set it to the mean of the learned channel embeddings:

\[
\texttt{chn}=\frac{1}{C}\sum_{c}\texttt{chn}_{c}.
\]

We present this ablation in Table 10. We observe that the modification significantly harms performance, underscoring the importance of maintaining per-channel embeddings. Interestingly, ChannelViT is more sensitive to this alteration of the channel embedding on JUMP-CP than on So2Sat, suggesting that the characteristics of the dataset influence the model's reliance on channel embeddings.
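The ablation amounts to replacing each learned channel embedding with their mean; a two-line sketch (ours, using a random stand-in for the learned embeddings):

```python
import torch

# Stand-in for learned per-channel embeddings of shape (C, D).
chn_embed = torch.randn(8, 384)

# Ablation: replace every channel-specific embedding with the mean embedding,
# so all channels receive the identical additive embedding.
chn_shared = chn_embed.mean(dim=0, keepdim=True)       # (1, D)
chn_ablated = chn_shared.expand_as(chn_embed)          # broadcast back to (C, D)
```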

Table 11: JUMP-CP accuracy for different ways of informing the final classifier of the sampled channel combination, compared to the default ChannelViT.

Additional features besides last-layer [CLS] | Classifier on top of the features | Classifier shared across channel combinations? | Accuracy
None | Linear | No (one classifier per combination) | 26.56
Embeddings for each channel comb. | MLP | Yes | 61.98
One-hot encoding of each channel comb. | MLP | Yes | 66.86
ChannelViT: None | Linear | Yes | 68.09

C.5 Investigation: do we need a separate classifier for each channel combination?

The application of hierarchical channel sampling exposes the model to a variety of input channel combinations, leading to significant changes in the input distribution. This prompts the question of whether it is necessary to further condition the final classifier on the sampled channel combination. Table 11 presents our ablation analysis, where we consider three methods for incorporating information about the input channels into the final classifier:

  1. The first baseline learns a separate linear classifier on top of the ViT embeddings for each channel combination.

  2. The second baseline learns an embedding vector for each channel and constructs the representation of the sampled channel combination by summing the embeddings of the selected channels. This representation is then concatenated with the ViT representation and fed to a shared MLP with one hidden layer.

  3. The third method is similar to the second baseline, but uses a one-hot encoding as the representation of the sampled channel combination (see the sketch below).

Our observations indicate that all three methods underperform when compared to the basic ChannelViT, which uses a shared linear classifier across all channel combinations. We hypothesize that the shared linear classifier regularizes the ViT to embed inputs with different channel combinations into the same space. This bias appears to enhance robustness and performance.
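For concreteness, a sketch of the third variant (ours; the combination is encoded here as a multi-hot vector over the input channels, which is one plausible reading of the one-hot encoding above, and all names are hypothetical):

```python
import torch
import torch.nn as nn

num_channels, embed_dim, num_classes = 8, 384, 160
mlp = nn.Sequential(
    nn.Linear(embed_dim + num_channels, embed_dim),
    nn.GELU(),
    nn.Linear(embed_dim, num_classes),
)

def classify(cls_token, channels):
    # cls_token: (B, embed_dim); channels: indices of the sampled channel combination
    onehot = torch.zeros(cls_token.shape[0], num_channels)
    onehot[:, channels] = 1.0                    # encode which channels were sampled
    return mlp(torch.cat([cls_token, onehot], dim=-1))

logits = classify(torch.randn(4, embed_dim), channels=[0, 2, 5])
```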


C.6 Breaking down the performance gain on JUMP-CP for each gene target

In Figure 10, we compare the performance of ChannelViT-S/8 and ViT-S/8 for each cell label (gene target). ChannelViT surpasses ViT on 90% of the gene targets while underperforming on the remaining targets. Note that the gain is computed from a 160-way classification task, where the models are trained to optimize the average loss across all gene targets. If we reframed the problem with a multi-task learning objective, the distribution of gains per gene could differ, and we would expect the improvements of ChannelViT to be more consistent.

C.7 Backbone: Small vs. Base vs. Large

In the main body of the paper, we explored ViT and ChannelViT at different patch resolutions (16x16 and 8x8). To provide a more comprehensive analysis, we extend our investigation to various backbone sizes. Adhering to the conventions established by Dosovitskiy et al., we evaluate ViT and ChannelViT with small, base, and large backbones. The specific configurations of these models are detailed in Table 12.

Performance metrics for the different model sizes are presented in Table 13. We observe incremental performance improvements in both ViT and ChannelViT as the number of parameters increases, while the performance gap between ViT and ChannelViT remains consistent and significant.

Table 12: Backbone configurations.

Size | Embed dim | Depth | Num heads | MLP hidden dim
Small | 384 | 12 | 6 | 1536
Base | 768 | 12 | 12 | 3072
Large | 1024 | 24 | 16 | 4096

Table 13: Top-1 accuracy on JUMP-CP for ViT and ChannelViT (both trained with HCS) with small, base, and large backbones.

#channels for testing | ViT-Small/16 | ViT-Base/16 | ViT-Large/16 | ChannelViT-Small/16 | ChannelViT-Base/16 | ChannelViT-Large/16
8 channels | 56.87 | 57.85 | 57.96 | 68.09 | 68.53 | 68.87
7 channels | 49.35 | 50.35 | 50.93 | 61.02 | 61.56 | 62.01
6 channels | 42.38 | 43.98 | 43.98 | 53.45 | 53.53 | 53.59
5 channels | 35.78 | 37.26 | 37.26 | 45.50 | 45.96 | 46.54
4 channels | 29.84 | 30.82 | 30.82 | 37.37 | 37.90 | 38.09
3 channels | 24.94 | 25.37 | 25.37 | 29.68 | 29.96 | 30.06
2 channels | 21.54 | 21.73 | 21.73 | 23.77 | 23.78 | 23.62
1 channel | 19.92 | 20.04 | 20.04 | 20.84 | 21.61 | 21.06

C.8 Performance variations across different channel combinations

Table 14 presents an analysis of the standard deviation for both ViT and ChannelViT on the JUMP-CP dataset, considering all channel combinations. We report the mean accuracies for groups of combinations with the same number of channels. Note that, even with the number of channels held constant, the informational content of different channel combinations can differ markedly, which is reflected in the substantial standard deviations observed in the table.

To further dissect this variance, we examine the performance gain of ChannelViT over ViT for each individual channel combination. The mean improvements, along with their standard deviations, are presented in Table 15. Our analysis confirms that the performance gains of ChannelViT are not only consistent across combinations but also significant.

Table 14: JUMP-CP accuracy (mean ± standard deviation over channel combinations with the same number of channels) when training on all 8 channels (5 fluorescence & 3 brightfield). All models are trained with HCS.

#channels for testing | ViT-S/16 | ChannelViT-S/16 | ViT-S/8 | ChannelViT-S/8
8 channels ($C^{8}_{8}=1$) | 56.87 | 68.09 | 66.44 | 74.77
7 channels ($C^{8}_{7}=8$) | 49.35 ± 9.38 | 61.02 ± 9.78 | 59.01 ± 10.07 | 68.42 ± 9.11
6 channels ($C^{8}_{6}=28$) | 42.38 ± 10.64 | 53.45 ± 12.40 | 51.29 ± 12.47 | 61.26 ± 11.91
5 channels ($C^{8}_{5}=56$) | 35.78 ± 10.18 | 45.50 ± 13.23 | 43.39 ± 12.89 | 53.05 ± 13.41
4 channels ($C^{8}_{4}=70$) | 29.84 ± 8.32 | 37.37 ± 12.25 | 35.60 ± 11.55 | 43.87 ± 13.36
3 channels ($C^{8}_{3}=56$) | 24.94 ± 5.43 | 29.68 ± 9.22 | 28.59 ± 8.38 | 34.19 ± 11.10
2 channels ($C^{8}_{2}=28$) | 21.54 ± 2.37 | 23.77 ± 4.89 | 23.32 ± 4.27 | 25.73 ± 6.57
1 channel ($C^{8}_{1}=8$) | 19.92 ± 0.51 | 20.84 ± 1.64 | 20.41 ± 1.26 | 21.20 ± 2.17
Table 15: Improvement of ChannelViT over ViT on JUMP-CP (mean ± standard deviation over channel combinations with the same number of channels), training on all 8 channels (5 fluorescence & 3 brightfield).

#channels for testing | ChannelViT-S/16 over ViT-S/16 (w/o HCS) | ChannelViT-S/16 over ViT-S/16 (w/ HCS) | ChannelViT-S/8 over ViT-S/8 (w/ HCS)
8 channels ($C^{8}_{8}=1$) | 14.16 | 11.22 | 8.32
7 channels ($C^{8}_{7}=8$) | 35.13 ± 18.37 | 11.67 ± 1.17 | 9.41 ± 1.80
6 channels ($C^{8}_{6}=28$) | 22.76 ± 18.64 | 11.07 ± 2.22 | 9.96 ± 1.90
5 channels ($C^{8}_{5}=56$) | 11.75 ± 11.33 | 9.72 ± 3.46 | 9.66 ± 2.30
4 channels ($C^{8}_{4}=70$) | 6.18 ± 6.86 | 7.52 ± 4.19 | 8.27 ± 3.08
3 channels ($C^{8}_{3}=56$) | 2.95 ± 5.27 | 4.74 ± 3.96 | 5.60 ± 3.58
2 channels ($C^{8}_{2}=28$) | 0.61 ± 3.97 | 2.24 ± 2.62 | 2.41 ± 2.82
1 channel ($C^{8}_{1}=8$) | -0.92 ± 7.49 | 0.93 ± 1.15 | 0.79 ± 0.95

Appendix D Camelyon17-WILDS: Medical Imaging for Histopathology

Table 16: Accuracy on Camelyon17-WILDS for ViT and ChannelViT backbones, with untied versus tied first-layer patch embedding weights across channels.

Evaluation | ViT-S/8 (untied) | ViT-S/8 (tied) | ViT-B/8 (tied) | ChannelViT-S/8 (untied) | ChannelViT-S/8 (tied) | ChannelViT-B/8 (tied)
In-distribution hospitals, 3 channels | 99.14 | 98.46 | 98.28 | 98.98 | 98.99 | 99.13
In-distribution hospitals, 2 channels | 98.65 | 98.42 | 98.22 | 98.51 | 98.66 | 98.73
In-distribution hospitals, 1 channel | 97.59 | 98.24 | 97.98 | 97.71 | 98.14 | 98.11
Out-of-distribution hospitals, 3 channels | 83.02 | 89.14 | 88.57 | 89.96 | 92.67 | 91.39
Out-of-distribution hospitals, 2 channels | 85.12 | 88.78 | 88.32 | 88.11 | 88.25 | 87.17
Out-of-distribution hospitals, 1 channel | 87.97 | 87.19 | 86.93 | 87.04 | 88.30 | 87.60

In this section, we introduce another dataset, Camelyon17-WILDS, which was not included in the main paper due to space limitations.

D.1 Dataset

The Camelyon17-WILDS dataset encompasses 455k labeled images from five hospitals. The task involves predicting the presence of tumor tissue in the central region of an image. Although the dataset uses standard RGB channels, these are derived from the hematoxylin and eosin staining procedure, which can vary across hospitals. We adopt the processed version from the WILDS benchmark (https://wilds.stanford.edu).

D.2 Results

Table 16 presents our results for Camelyon17, a medical imaging benchmark for histopathology. Given the smaller image size (96 by 96), we employ a patch size of 8 by 8 for the ViT backbone.

Starting with the standard ViT-S/8 (first column), we note that it achieves an accuracy of 99.14 for the in-distribution hospitals. With HCS, it also attains an accuracy of over 97 when using only two or one channels for prediction. However, when evaluated on out-of-distribution hospitals, its 3-channel accuracy drops to 83.02. This is not only lower than its in-distribution performance, but also lower than the accuracy it achieves when using only one channel for evaluation in the out-of-distribution hospitals (87.97). We hypothesize that this discrepancy is due to the staining shift across hospitals (Gao et al., 2022): the mismatch in color distributions produces out-of-distribution inputs for the first linear patch embedding layer. To test this hypothesis, we experiment with tying the parameters of the first linear patch embedding layer across channels. As seen in the second column, ViT-S/8 with tied weights performs slightly worse in the in-distribution hospitals but significantly better in the out-of-distribution setting. We also explored ViT-B/8 but found that it exhibited overfitting.

By default, we share the first linear patch embedding layer across channels for ChannelViT. On the out-of-distribution hospitals, ChannelViT-S/8 significantly outperforms ViT-S/8 (92.67 vs. 89.14). We also observe that untying the weights across channels in ChannelViT degrades its generalization performance.

Table 17: ImageNet validation accuracy when evaluating on all RGB channels or on a single channel. The top block shows models trained on three channels (RGB); the bottom block shows expert DINO models pre-trained on a single channel.

Backbone | Val Acc. on RGB | Val Acc. on R-only | Val Acc. on G-only | Val Acc. on B-only
Models trained on three channels (RGB)
Supervised ViT-S/16 | 71.49 | 29.39 | 33.79 | 21.18
DINO + ViT-S/16 + LinearProb | 72.62 | 64.34 | 65.46 | 61.12
DINO + ChannelViT-S/16 + LinearProb | 74.38 | 67.44 | 67.85 | 65.97
Expert DINO models pre-trained on only one channel
DINO + ViT-S/16 (R-only) + LinearProb | - | 67.76 | - | -
DINO + ViT-S/16 (G-only) + LinearProb | - | - | 68.09 | -
DINO + ViT-S/16 (B-only) + LinearProb | - | - | - | 66.65

Appendix E Self-supervised pre-training with ChannelViT

This section delves into the integration of self-supervised learning with ChannelViT.

E.1 DINO

We use the DINO algorithm (Caron et al., 2021) for self-supervised learning. It involves a self-distillation process in which the student model, given local views of an input image, must match the teacher model, which receives global views of the same image.

We follow most of the configuration suggested by the DINO repository (https://github.com/facebookresearch/dino). Specifically, we pre-train DINO with ViT-S/16 and ChannelViT-S/16 for a total of 100 epochs on ImageNet with a batch size of 256. The AdamW optimizer (Loshchilov & Hutter, 2019) is employed, and the learning rate warm-up phase is set to the first 10 epochs. Given our batch size, the maximum learning rate is set to 0.0005, in line with the recommendations of You et al. (2018). The learning rate is subsequently decayed using a cosine scheduler, with a target learning rate of $10^{-6}$. Weight decay is applied to all parameters excluding the biases. The initial weight decay is set to 0.04 and is gradually increased to 0.4 using a cosine scheduler towards the end of training. The DINO projection head has 65536 dimensions, and batch normalization is not employed in the projection head. The output temperature of the teacher network is initially set to 0.04 and linearly increased to 0.07 within the first 30 epochs; the temperature is kept at 0.07 for the remainder of training. To enhance training stability, the parameters of the output layer are frozen during the first epoch.

E.2 Linear Probing

Upon completion of the pre-training phase, the parameters of both ViT and ChannelViT are frozen. In alignment with the methodology of Caron et al. (2021), the CLS representations of the final four layers are concatenated to represent the image, and a linear classifier is trained on this representation. The linear classifier is trained with SGD, using a learning rate of 0.005 and a momentum of 0.9; the learning rate is decayed with a cosine annealing scheduler. We train the linear classifier for 100 epochs on the ImageNet training split and report its Top-1 accuracy on the validation split.
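A sketch of this linear-probing setup (ours; `get_intermediate_cls` is a stand-in for hooking the frozen backbone's last four encoder blocks, and the values below are purely illustrative):

```python
import torch
import torch.nn as nn

embed_dim, num_classes = 384, 1000

def get_intermediate_cls(images, n=4):
    """Stand-in: a real implementation would return the frozen backbone's
    CLS tokens from its last `n` encoder blocks."""
    return [torch.randn(images.shape[0], embed_dim) for _ in range(n)]

classifier = nn.Linear(4 * embed_dim, num_classes)
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.005, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)  # stepped per epoch

def probe_step(images, labels):
    with torch.no_grad():                              # the backbone stays frozen
        cls_last4 = get_intermediate_cls(images)
    feats = torch.cat(cls_last4, dim=-1)               # (B, 4 * embed_dim)
    loss = nn.functional.cross_entropy(classifier(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

loss = probe_step(torch.randn(8, 3, 224, 224), torch.randint(0, num_classes, (8,)))
```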

E.3 Results

Table 17 showcases our results. Note that hierarchical channel sampling is not used during DINO pre-training, as it could introduce additional instability to the self-distillation objective. Nevertheless, we observe that the DINO-pretrained ViT inherently provides superior channel robustness: compared to the supervised ViT-S/16, it achieves 64.34 on the red-only evaluation, 34.95 points better than its supervised counterpart. Furthermore, integrating DINO pre-training with ChannelViT consistently enhances performance across all evaluations, bridging the gap towards the expert DINO models that are pre-trained on each individual channel.
