1 Introduction

In recent years, machine learning techniques have achieved remarkable success in a variety of tasks related to computer vision (Li et al., 2022; Zaidi et al., 2022) and natural language processing (Mridha et al., 2022). Despite these achievements, the widely-adopted assumption behind most existing methods, i.e., that the training and testing data are independently and identically distributed, may not always hold in practical applications (Zhou et al., 2022; Liu et al., 2022). In real-world scenarios, it is quite common for the training and testing distributions to differ owing to changed environments. For example, histopathological images of breast cancer acquired from different healthcare centers exhibit significant domain gaps (a.k.a. domain shift, see Fig. 1 for more detail) caused by differences in device vendors and staining methods, which may lead to catastrophic performance deterioration (Qi et al., 2020). To address this issue, domain generalization (DG) has been developed to learn a model from multiple related yet different domains (a.k.a. source domains) that generalizes well on an unseen testing domain (a.k.a. target domain).

Fig. 1

Histopathological image examples of breast cancer tissue from three different healthcare institutes, including NKI with 626 images, IHC with 645 images, and VGH with 1324 images. There are two different tissue types, including epithelium and stroma. Obvious domain gaps (e.g., the density of tissue and the staining color) can be observed

Recently, researchers have proposed several domain generalization approaches, such as data augmentation with randomization (Yue et al., 2019), data generation with stylization (Verma et al., 2019; Zhou et al., 2021), and meta learning-based training schemes (Li et al., 2018; Kim et al., 2021), among which representation learning-based methods are among the most popular. These representation learning-based methods (Balaji et al., 2019) aim to learn domain-invariant feature representations. Specifically, if the discrepancy between source domains in feature space can be minimized, the model is expected to generalize better on the unseen target domain, owing to the learned domain-invariant and transferable feature representation (Ben-David et al., 2006). For instance, the classical contrastive semantic alignment (CSA) loss proposed by Motiian et al. (2017) encourages positive sample pairs (with the same label) from different domains to move closer while pulling negative pairs (with different labels) apart. Dou et al. (2019) built on the CSA loss to jointly consider a local class alignment loss (for point-wise domain alignment) and a global class alignment loss (for distribution-wise alignment).

Despite the progress so far, a reliable contrastive semantic loss from a point-wise (or local) perspective usually requires sufficient samples from the source domains so that diverse sample-to-sample pairs can be constructed (Sohn, 2016; Khosla et al., 2020). For example, Khosla et al. (2020) proposed a supervised contrastive loss that relies on a considerably large batch size on large-scale datasets to guarantee decent performance. Yao et al. (2022) also emphasized the importance of the number of sample-to-sample pairs, which is governed by the data size, for contrastive-based losses in the DG problem. On the other hand, from the distribution-wise (a.k.a. global) perspective of alignment between domains (Dou et al., 2019), a consistent distribution measurement (e.g., Kullback–Leibler (KL) divergence) theoretically relies on sufficient samples for distribution estimation, as discussed by Bu et al. (2018). However, sufficient samples from multiple source domains may not always be available or accessible in the real world. For example, for medical imaging data, insufficient-sample scenarios exist either in all source domains (e.g., rare diseases inherently yield a small volume of data from all healthcare centers (Lee et al., 2022)) or in some source domains (e.g., some domains have significantly smaller sample sizes than others, resulting from differences in ethnicity (Johnson & Louis, 2022), demography (Gurdasani et al., 2019), and privacy-preserving regulations (Can & Ersoy, 2021)). It is therefore necessary to develop reliable and effective semantic alignment from both local and global perspectives in the context of insufficient samples (a.k.a. the small-data scenario) on the source domains, in order to achieve better domain-invariant representations.

In this paper, we propose to learn domain-invariant representations from multiple source domains to tackle the domain generalization problem in the context of insufficient samples. Instead of extracting latent embeddings (i.e., latent points) with deterministic models (e.g., convolutional neural networks, CNNs), we leverage a probabilistic framework endowed by variational Bayesian inference to map each data point into a probabilistic embedding (i.e., a latent distribution) for domain generalization. Specifically, following domain-invariant learning from the global (distribution-wise) perspective, we extend the empirical maximum mean discrepancy (MMD) to a novel probabilistic MMD (P-MMD) that can empirically measure the discrepancy between mixture distributions (a.k.a. distributions over distributions), consisting of a series of latent distributions rather than latent points. From a local perspective, instead of imposing the CSA loss on pairs of latent points, a novel probabilistic contrastive semantic alignment (P-CSA) loss with kernel mean embedding is proposed to encourage positive probabilistic embedding pairs to move closer while pulling other negative ones apart. Extensive experimental results on three challenging medical imaging tasks, including epithelium stroma classification on insufficient histopathological images, skin lesion classification, and spinal cord gray matter segmentation, show that our proposed method can achieve better cross-domain performance in the context of insufficient data compared with state-of-the-art methods.

2 Related Works

2.1 Domain Generalization with Medical Images

Existing DG methods can be broadly categorized into three streams, namely data augmentation/generation (Yue et al., 2019; Graves, 2011; Zhou et al., 2021), meta-learning (Li et al., 2018; Kim et al., 2021), and feature representation learning (Li et al., 2018; Gong et al., 2019; Xiao et al., 2021). Among them, feature representation learning, which aims to explore invariant feature information that can be shared across domains, has proven to be a widely adopted approach to the DG problem. Within this stream, Li et al. (2018) proposed to conduct multi-domain alignment in the latent space via a multi-domain MMD distance. Gong et al. (2019) leveraged adversarial training to eliminate the domain discrepancy such that domain-invariant representations can be learned in a manifold space. Owing to the varieties of imaging protocols (e.g., the choice of image resolution for MRI), device vendors (e.g., Philips or Siemens CT scanners), and patient populations (e.g., race and age group), imaging data acquired from different medical sites may exhibit a significant domain shift problem (Liu et al., 2021). Dou et al. (2019) proposed a meta-learning framework to perform local and global semantic alignment for medical image classification. A similar design was also adopted by Li et al. (2022) for tissue image classification. Qi et al. (2020) utilized a curriculum learning scheme to transfer knowledge for histopathological image classification. Li et al. (2020) combined data augmentation and domain alignment to achieve decent performance on multiple medical data classification tasks. However, these methods may not focus on learning domain-invariant representations from insufficient samples in the source domains.

2.2 Probabilistic Neural Networks

Compared with deterministic models, probabilistic neural networks learn a distribution over model parameters, which can integrate uncertainty into predictive modeling (Kingma et al., 2015; Gal & Ghahramani, 2016). When data are insufficient, probabilistic models can usually achieve better generalization performance due to their probabilistic nature (acting as an implicit regularization) (Blundell et al., 2015). In the context of insufficient samples, the Bayesian neural network (BNN) (Neal, 2012) with variational inference, a representative probabilistic model, not only improves predictive accuracy as a classifier (Wilson & Izmailov, 2020), but also builds up the quality of low-dimensional embeddings of insufficient data (Mallick et al., 2021), which is a crucial motivation for this paper. Meanwhile, modern analytical approximation techniques (e.g., variational inference (Blei et al., 2017) and empirical Bayes (Krishnan et al., 2020)) can efficiently infer the posterior distribution of model parameters with stochastic gradient descent, which allows BNNs to be integrated with deterministic DNNs conveniently.

In Xiao et al. (2021), the authors proposed to consider the uncertainty of a generalizable model based on a BNN, where the distances between positive probabilistic embedding pairs and class distributions are minimized via a KL measure. Despite its effectiveness, the dissimilar pairs (i.e., negative pairs) are ignored, which may not benefit feature representation learning. Moreover, they only focused on sample similarity while the distribution information is ignored. Instead, our proposed method comprehensively considers both positive and negative probabilistic embedding pairs via a novel distribution-based contrastive semantic loss. Last but not least, our proposed method highlights the benefit of the BNN for building up the quality of latent embeddings under insufficient-sample scenarios.

2.3 Probabilistic Embedding

Compared with deterministic point embeddings, probabilistic embeddings aim to characterize the data with a distribution. Due to their high robustness and effective representation (Nguyen et al., 2017), probabilistic representations have been applied in several fields, such as video representation learning (Park et al., 2022), image representation learning (Oh et al., 2018), face recognition (Shi & Jain, 2019; Chang et al., 2020), speaker diarization (Silnova et al., 2020), and human pose estimation (Sun et al., 2020). Recently, some researchers have further leveraged probabilistic embeddings to bridge the gap between data modalities (Chun et al., 2021; Neculai et al., 2022; Chun, 2023). For example, Chun et al. (2021) found that probabilistic representations can lead to a richer embedding space for the challenging relational reasoning between images and their captions. These probabilistic embedding-based approaches either inherit the inherent distributional property of the data (e.g., the multiple frames of a video) or tackle one-to-many correspondences through distributional representation. Instead, our proposed method imposes a Bayesian neural network to generate probabilistic embeddings, so that the representative capacity of the data in the small-data regime can be enhanced. Moreover, we devise a novel probabilistic MMD to measure the discrepancy between mixture probabilistic embeddings for domain-invariant learning.

Fig. 2

A visualized computational process for probabilistic MMD (P-MMD) on two source domains. The same color for samples in different domains denotes the same label

3 Methodology

Preliminary Assume that there are K domains collected from different environments. The samples in each domain can be represented as \(\textbf{X}_{l} = \{\textbf{x}_{l_{1}},\ldots ,\textbf{x}_{l_{n_{l}}}\}\), where \(l \in \mathbb {N}^{+}: \{1,\ldots ,K\}\) and \(\textbf{x}_{l_{i}} \in \mathbb {R}^{d \times 1}\) denotes a d-dimensional sample in the l-th domain. \(n_{l}\) is the total number of samples in the l-th domain. The corresponding labels of the samples \(\textbf{X}_{l}\) in each domain are denoted as \(\textbf{Y}_{l} = \{\textbf{y}_{l_{1}},\ldots ,\textbf{y}_{l_{n_{l}}}\}\), where \(\textbf{y}_{l_{i}} \in \mathbb {R}^{m \times 1}\) is a one-hot encoding with m classes in total. In the domain generalization setting, the source domain data, represented as \(\{\textbf{X}_{l}^{S},\textbf{Y}_{l}^{S}\}_{l=1}^{K}\), are available only in the training phase, whereas the target domain data, denoted by \(\textbf{X}^{T}\), are only seen in the test phase.

Overall We provide a framework that can learn better domain-invariant representations when the source domain data are insufficient. A probabilistic neural network is imposed to enable high-quality and powerful feature representation in the context of insufficient samples. To effectively perform global-perspective alignment, a novel probabilistic MMD is proposed to empirically measure the discrepancy between distributions over distributions in a reproducing kernel Hilbert space. We also propose a probabilistic contrastive semantic alignment to align probabilistic embeddings from a local perspective. The details of our proposed method are discussed below.

Probabilistic Embedding of Insufficient Data Compared with deterministic models, probabilistic models learn a distribution over model weights, which has shown a better capacity to represent latent embeddings under insufficient-sample scenarios (Mallick et al., 2021). In this work, a Bayesian neural network (BNN) (Blei et al., 2017) is utilized to extract low-dimensional embeddings from high-dimensional inputs. By feeding the inputs into the BNN with parameters \(\textbf{W} \sim p(\textbf{W})\), the samples \(\textbf{X}_{l} = \{\textbf{x}_{l_{1}},\ldots ,\textbf{x}_{l_{n_{l}}}\}\) of each domain can be represented by a set of probabilistic embeddings (i.e., latent distributions), i.e., \(p(\textbf{Z}|\textbf{X}_{l}) = \{p(\textbf{z}|\textbf{x}_{l_{1}},\textbf{W}),\ldots ,p(\textbf{z}|\textbf{x}_{l_{n_{l}}},\textbf{W})\}\), where \(\textbf{W} \sim p(\textbf{W})\) is sampled stochastically. Variational inference is used to approximate the posterior distribution of \(\textbf{W}\) with the evidence lower bound (ELBO) (more details can be found in “Appendix A.1”). By using Monte Carlo (MC) estimation with T stochastic sampling operations of \(\textbf{W}\), an unbiased approximation of the predictive distribution of each \(p(\textbf{z}|\textbf{x})\) can be obtained.
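To make this step concrete, the following is a minimal sketch (not the authors' released code) of a mean-field Gaussian Bayesian layer appended to a deterministic backbone: each forward pass samples weights via the reparameterization trick, so T passes yield the MC samples that represent \(p(\textbf{z}|\textbf{x})\). The names `BayesianLinear` and `probabilistic_embedding`, as well as the prior and initialization choices, are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    """Mean-field Gaussian linear layer: q(W) = N(mu, softplus(rho)^2)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.w_mu = nn.Parameter(torch.randn(out_dim, in_dim) * 0.01)
        self.w_rho = nn.Parameter(torch.full((out_dim, in_dim), -5.0))
        self.b_mu = nn.Parameter(torch.zeros(out_dim))
        self.b_rho = nn.Parameter(torch.full((out_dim,), -5.0))

    def forward(self, x):
        # Reparameterization trick: one stochastic weight sample per call.
        w_std, b_std = F.softplus(self.w_rho), F.softplus(self.b_rho)
        w = self.w_mu + w_std * torch.randn_like(w_std)
        b = self.b_mu + b_std * torch.randn_like(b_std)
        return F.linear(x, w, b)

    def kl(self):
        # Closed-form KL[q(W) || p(W)] against a standard normal prior.
        w_std, b_std = F.softplus(self.w_rho), F.softplus(self.b_rho)
        kl_w = 0.5 * (w_std ** 2 + self.w_mu ** 2 - 1.0 - 2.0 * torch.log(w_std)).sum()
        kl_b = 0.5 * (b_std ** 2 + self.b_mu ** 2 - 1.0 - 2.0 * torch.log(b_std)).sum()
        return kl_w + kl_b

def probabilistic_embedding(feats, bayes_layer, T=10):
    """T stochastic forward passes -> (T, batch, dim) MC samples representing p(z|x)."""
    return torch.stack([F.relu(bayes_layer(feats)) for _ in range(T)])
```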

3.1 Distribution Alignment via Probabilistic Maximum Mean Discrepancy

In this section, we introduce an approach to learning domain-invariant representations from a global perspective by minimizing the discrepancy among domains. Among the various distribution distance metrics, maximum mean discrepancy (MMD) is widely adopted (Long et al., 2017; Li et al., 2018); it measures the distance between two probability distributions in a non-parametric manner. Specifically, assume that latent embeddings \(\textbf{Z}_{l}\)=\(\{\textbf{z}_{l_{1}},\ldots ,\textbf{z}_{l_{n_{l}}}\}\) and \(\textbf{Z}_{t}\)=\(\{\textbf{z}_{t_{1}},\ldots ,\textbf{z}_{t_{n_{t}}}\}\) are drawn from two unknown distributions \(\mathbb {P}_{l}\) and \(\mathbb {P}_{t}\). A probability measure \(\mathbb {P}\) can be mapped into a reproducing kernel Hilbert space (RKHS) \(\mathcal {H}\) as an element by setting

$$\begin{aligned} \mu _{\mathbb {P}}:= \mathbb {E}_{\textbf{z}\sim \mathbb {P}}[\phi (\textbf{z})]= \int _{\mathcal {Z}}k(\textbf{z},\cdot )d\mathbb {P}= \mathbb {E}_{\textbf{z}\sim \mathbb {P}}[k(\textbf{z},\cdot )], \end{aligned}$$
(1)

where a reproducing kernel \(k: \mathcal {X} \times \mathcal {X} \rightarrow \mathbb {R}\) and the corresponding feature map \(\phi : \mathcal {X} \rightarrow \mathcal {H}\) are defined. Let the kernel k be characteristic, so that the map \(\mu :\mathbb {P}\rightarrow \mu _{\mathbb {P}}\) is injective. In this case the MMD can be defined as the distance \(\Vert \mu _{\mathbb {P}_l}-\mu _{\mathbb {P}_t}\Vert _{\mathcal {H}}\) between mean embeddings in \(\mathcal {H}\), and it can be used as a measure of the distance between the distributions \(\mathbb {P}_{l}\) and \(\mathbb {P}_{t}\) (Borgwardt et al., 2006; Gretton et al., 2012). The explicit computation of MMD can be derived from the unbiased empirical estimation of the mean map (Gretton et al., 2012), i.e.,

$$\begin{aligned} \begin{aligned}&{\text {MMD}}\left( \mathbb {P}_{l}, \mathbb {P}_{t}\right) ^{2}= \Vert \frac{1}{n_{l}} \sum _{i=1}^{n_{l}} \phi \left( \textbf{z}_{l_{i}}\right) -\frac{1}{n_{t}} \sum _{j=1}^{n_{t}} \phi \left( \textbf{z}_{t_{j}}\right) \Vert _{\mathcal {H}}^{2} \end{aligned} \end{aligned}$$
(2)

The idea of using MMD for domain generalization has been explored in several works (e.g., (Li et al., 2018; Hu et al., 2020)).
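As a point of reference for what follows, this is a minimal sketch of the empirical MMD in Eq. (2) with a Gaussian RBF kernel, operating on two sets of point embeddings of shape \((n_l, d)\) and \((n_t, d)\). The helper names and the bandwidth `gamma` are assumptions for illustration.

```python
import torch

def rbf_kernel(a, b, gamma=1.0):
    # Pairwise k(x, y) = exp(-gamma * ||x - y||^2).
    return torch.exp(-gamma * torch.cdist(a, b) ** 2)

def mmd2(z_l, z_t, gamma=1.0):
    # Empirical MMD^2 of Eq. (2): mean k(l,l) + mean k(t,t) - 2 * mean k(l,t).
    return (rbf_kernel(z_l, z_l, gamma).mean()
            + rbf_kernel(z_t, z_t, gamma).mean()
            - 2.0 * rbf_kernel(z_l, z_t, gamma).mean())
```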

In the probabilistic framework, instead of individual latent embeddings \(\textbf{z}_{l_{1}},\ldots \), we have latent probabilistic embeddings \(\Pi _{l_{1}}:= p(\textbf{z}|\textbf{x}_{l_{1}},\textbf{W}),\ldots \). For a source domain \(D_l\), we have the associated distribution over distributions \(\mathbb {P}_{l} = \{\Pi _{l_{1}},\ldots ,\Pi _{l_{n_{l}}}\}\). For this scenario, we propose to extend the existing point-based empirical MMD estimate to a distribution-based empirical probabilistic MMD (P-MMD) estimate. P-MMD utilizes empirical estimation with kernels on distributions to measure the discrepancy between the mixture distributions \(\mathbb {P}_{l}\) and \(\mathbb {P}_{t}\) under the probabilistic framework.

Specifically, we first represent latent probabilistic embeddings as elements in the RKHS \(\mathcal {H}_k\) using the kernel k, which we coin a level-1 kernel in the sequel, e.g., \(\mu _{\Pi _{l_{1}}}:= \mathbb {E}_{\textbf{z}\sim \Pi _{l_{1}}}[\phi (\textbf{z})]=\mathbb {E}_{\textbf{z}\sim \Pi _{l_{1}}}[k(\textbf{z},\cdot )]\), which is analogous to Eq. (1). The kernel mean embedding \(\mu _{\Pi _{l_{1}}}\) can be regarded as a new feature map for a variety of tasks (Yoshikawa et al., 2014). Here, to enable non-linear learning on distributions, we introduce a level-2 kernel K (Muandet et al., 2012). Consider a kernel \(\kappa \) on \(\mathcal {H}_k\) and its reproducing kernel Hilbert space (RKHS) \(\mathcal {H}_{\kappa }\). Define K as

$$\begin{aligned} K(\Pi _{l_{i}}, \Pi _{t_{j}}) = \kappa (\mu _{\Pi _{l_{i}}},\mu _{\Pi _{t_{j}}}) = \langle \psi (\mu _{\Pi _{l_{i}}}),\psi (\mu _{\Pi _{t_{j}}})\rangle _{\mathcal {H}_{\kappa }},\nonumber \\ \end{aligned}$$
(3)

where K and its explicit form on kernel mean embeddings \(\kappa \) are p.d. kernels (Berlinet & Thomas-Agnan, 2011). We define a novel probabilistic MMD (P-MMD) empirical estimation method using the level-2 kernel K:

$$\begin{aligned}&{\text {P-MMD}}(\mathbb {P}_{l}, \mathbb {P}_{t})^{2} = \Vert \frac{1}{n_{l}} \sum _{i=1}^{n_{l}} \psi (\mu _{\Pi _{l_{i}}})-\frac{1}{n_{t}} \sum _{j=1}^{n_{t}} \psi (\mu _{\Pi _{t_{j}}})\Vert _{\mathcal {H_{\kappa }}}^{2}\nonumber \\ {}&= \frac{1}{n_{l}^{2}} \sum _{i=1}^{n_{l}} \sum _{i^{\prime }=1}^{n_{l}} K(\Pi _{l_{i}}, \Pi _{l_{i}^{\prime }}) + \frac{1}{n_{t}^{2}} \sum _{j=1}^{n_{t}} \sum _{j^{\prime }=1}^{n_{t}} K(\Pi _{t_{j}}, \Pi _{t_{j}^{\prime }})\nonumber \\&\quad -\frac{2}{n_{l} n_{t}} \sum _{i=1}^{n_{l}} \sum _{j=1}^{n_{t}} K(\Pi _{l_{i}}, \Pi _{t_{j}}). \end{aligned}$$
(4)

In this work, the level-1 and level-2 kernels, k and K, are both Gaussian RBF kernel due to its impressive performance on a limited amount of distribution data (Muandet et al., 2012). Namely, \(K = K_{Gau}(\Pi _{l_{i}}, \Pi _{t_{j}}) = \kappa (\mu _{\Pi _{l_{i}}}, \mu _{\Pi _{t_{j}}})\) can be represented as

$$\begin{aligned} \kappa (\mu _{\Pi _{l_{i}}}, \mu _{\Pi _{t_{j}}})&= \exp \left( -\frac{\lambda }{2}\Vert \mu _{\Pi _{l_{i}}} - \mu _{\Pi _{t_{j}}}\Vert _{\mathcal {H}_{k}}^{2}\right) \nonumber \\&= \exp \left( -\frac{\lambda }{2}\left( \langle \mu _{\Pi _{l_{i}}},\mu _{\Pi _{l_{i}}}\rangle _{\mathcal {H}_{k}} - 2\langle \mu _{\Pi _{l_{i}}},\mu _{\Pi _{t_{j}}}\rangle _{\mathcal {H}_{k}} + \langle \mu _{\Pi _{t_{j}}},\mu _{\Pi _{t_{j}}}\rangle _{\mathcal {H}_{k}}\right) \right) \nonumber \\&= \exp \left( -\frac{\lambda }{2}\left( \frac{1}{m_{l}^{2}} \sum _{i=1}^{m_{l}} \sum _{i^{\prime }=1}^{m_{l}} k(\textbf{z}_{l_{i}},\textbf{z}_{l_{i^{\prime }}}) - \frac{2}{m_{l} m_{t}} \sum _{i=1}^{m_{l}} \sum _{j=1}^{m_{t}} k(\textbf{z}_{l_{i}},\textbf{z}_{t_{j}}) + \frac{1}{m_{t}^{2}} \sum _{j=1}^{m_{t}} \sum _{j^{\prime }=1}^{m_{t}} k(\textbf{z}_{t_{j}},\textbf{z}_{t_{j^{\prime }}})\right) \right) , \end{aligned}$$
(5)

where \(m_{l}\) and \(m_{t}\) are determined by the number of sampling times T. The kernel mean embedding using the level-1 kernel k maps each domain distribution \(\mathbb {P}_l\), \(l=1,\ldots ,N\), to a distribution \(\mu (\mathbb {P}_l)\) in the RKHS \(\mathcal {H}_k\), represented by the samples \(\{\mu _{\Pi _{l_1}},\ldots ,\mu _{\Pi _{l_{n_{l}}}}\}\). The underlying strategy of P-MMD is to apply the classic MMD to these distributions (with respect to the kernel \(\kappa \)). To assess the effect that the minimization of P-MMD has on the original latent probability distributions across different domains, we recall the following:

Theorem 1

Muandet et al. (2012) Let \(\mathbb {P}_1,\ldots ,\mathbb {P}_N\) be probability distributions and \(\hat{\mathbb {P}}:=\frac{1}{N}\sum _{i=1}^{N}\mathbb {P}_i\). Then the distributional variance given by \(\frac{1}{N}\sum _{i=1}^{N}\Vert \mu _{\mathbb {P}_i}-\mu _{\hat{\mathbb {P}}}\Vert \) is 0 iff \(\mathbb {P}_1=\mathbb {P}_2=\cdots =\mathbb {P}_N\).

Corollary 3.1

Li et al. (2018) The upper bound of the distributional variance can be written as

$$\begin{aligned} \frac{1}{K^{2}} \sum _{1 \le i,j \le K} {\text {MMD}}(\mathbb {P}_{i}, \mathbb {P}_{j})^{2}. \end{aligned}$$

In our setting Theorem 1 and Corollary 3.1 along with the fact that k is a characteristic kernel imply the following:

Corollary 3.2

\( \frac{1}{K^{2}} \sum _{1 \le i,j \le K} {\text {P-MMD}}(\mathbb {P}_{i}, \mathbb {P}_{j})^{2}=0 \) holds iff all moments of the latent distributions \(\Pi _{l}\) associated with points of domain \(D_l\), \(l=1,\ldots ,N\), are identically distributed across domains.

Following Corollary 3.2 we define the following loss function:

$$\begin{aligned} \mathcal {L}_{global} = \frac{1}{K^{2}} \sum _{1 \le i,j \le K} {\text {P-MMD}}(\mathbb {P}_{i}, \mathbb {P}_{j})^{2}. \end{aligned}$$
(6)

Corollary 3.2 implies that as Eq. (6) tends to 0, so does the distance between the distributions of the means, variances, and higher moments of the distributions \(\Pi _l\) associated with points of different domains.
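For concreteness, the sketch below (illustrative only; the tensor layout and the bandwidths `gamma`/`lam` are assumptions) estimates the level-2 Gram matrix of Eq. (5) from the T Monte Carlo samples of each probabilistic embedding, plugs it into the P-MMD of Eq. (4), and averages over domain pairs as in Eq. (6). The `rbf_kernel` helper from the earlier MMD sketch is repeated here for self-containedness.

```python
import torch

def rbf_kernel(a, b, gamma=1.0):
    return torch.exp(-gamma * torch.cdist(a, b) ** 2)

def level2_gram(embs_a, embs_b, gamma=1.0, lam=1.0):
    """Level-2 Gaussian kernel K(Pi_i, Pi_j) of Eq. (5).
    embs_a: (n_a, T, d), embs_b: (n_b, T, d) MC samples of probabilistic embeddings."""
    def inner(x, y):  # <mu_x, mu_y>_{H_k} = mean of level-1 kernel values over MC samples
        return rbf_kernel(x, y, gamma).mean()
    n_a, n_b = embs_a.shape[0], embs_b.shape[0]
    gram = torch.zeros(n_a, n_b)
    for i in range(n_a):
        for j in range(n_b):
            d2 = (inner(embs_a[i], embs_a[i])
                  - 2.0 * inner(embs_a[i], embs_b[j])
                  + inner(embs_b[j], embs_b[j]))      # ||mu_i - mu_j||^2 in the level-1 RKHS
            gram[i, j] = torch.exp(-0.5 * lam * d2)   # level-2 Gaussian kernel
    return gram

def p_mmd2(embs_l, embs_t, gamma=1.0, lam=1.0):
    """Empirical P-MMD^2 between two domains of probabilistic embeddings, Eq. (4)."""
    return (level2_gram(embs_l, embs_l, gamma, lam).mean()
            + level2_gram(embs_t, embs_t, gamma, lam).mean()
            - 2.0 * level2_gram(embs_l, embs_t, gamma, lam).mean())

def global_loss(domain_embs):
    """L_global of Eq. (6): average P-MMD^2 over all pairs of the K source domains."""
    K = len(domain_embs)
    total = sum(p_mmd2(domain_embs[i], domain_embs[j])
                for i in range(K) for j in range(K) if i != j)
    return total / (K ** 2)
```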

Remark 1

In Sect. 4.5, we compare the P-MMD approach to simply taking the mean (i.e., the first moment) of the latent probabilistic embeddings \(\Pi _l\), i.e., taking \({\Pi }_{l}\rightarrow \textbf{m}_{\Pi _{l}}=\mathbb {E}_{\textbf{x}\sim \Pi _{l}}[\textbf{x}]\), and then minimizing the associated “vanilla” MMD. Although this scheme is computationally more efficient than our proposed method, it discards most information about high-level statistics, as discussed by Muandet et al. (2017). We empirically verify that our approach has better performance across domains. The visualized computation of P-MMD is shown in Fig. 2.

Although we focus on the scenario of insufficient samples, the computational consumption of Eqs. (4) and (5) may still be prohibitive, as the calculation of the MMD distance between distributions scales at least quadratically with the sample size (especially for the image segmentation task), i.e., \(O(n^{2})\) in a domain. Here, following the linear-time statistic of MMD, an unbiased estimate can be derived by drawing pairs from the two domains with replacement, i.e., \({\text {P-MMD}}(\mathbb {P}_{l}, \mathbb {P}_{t})^{2} \approx \frac{2}{n_{l}}\sum _{i=1}^{n_{l}/2}[K(\Pi _{l_{2i-1}}, \Pi _{l_{2i}})+K(\Pi _{t_{2i-1}}, \Pi _{t_{2i}})-K(\Pi _{l_{2i-1}}, \Pi _{t_{2i}})-K(\Pi _{l_{2i}}, \Pi _{t_{2i-1}})]\), where \(n_{l}=n_{t}\) is assumed for simplicity. Borgwardt et al. (2006) prove the unbiasedness of the linear statistic of MMD and show that statistical power is not sacrificed too much.
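A small sketch of this linear-time estimate is given below: consecutive probabilistic embeddings are paired up and the h-statistic built from the level-2 kernel is averaged, giving O(n) rather than O(n^2) kernel evaluations. It reuses `level2_gram` from the P-MMD sketch above (here evaluated on single pairs) and assumes, as in the text, that both domains provide the same number of samples.

```python
def p_mmd2_linear(embs_l, embs_t):
    """Linear-time P-MMD^2 estimate: (2/n) * sum of h over consecutive pairs."""
    n = (min(embs_l.shape[0], embs_t.shape[0]) // 2) * 2
    total = 0.0
    for i in range(0, n, 2):
        a, b = embs_l[i:i + 1], embs_l[i + 1:i + 2]   # consecutive pair from domain l
        c, d = embs_t[i:i + 1], embs_t[i + 1:i + 2]   # consecutive pair from domain t
        h = (level2_gram(a, b) + level2_gram(c, d)
             - level2_gram(a, d) - level2_gram(b, c))
        total = total + h.squeeze()
    return 2.0 * total / n
```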

Table 1 Experiment results of epithelium stroma classification of histopathological images

3.2 Probabilistic Contrastive Semantic Alignment

To learn domain-invariant representations from a local perspective, a popular idea is to encourage positive pairs with the same label to move closer together, while pulling other negative pairs with different labels further apart (Motiian et al., 2017; Dou et al., 2019). These methods usually measure the Euclidean distance between samples in the embedding space. However, this scheme is not suited to our probabilistic framework, since each sample is represented by a probabilistic embedding rather than a latent point.

To this end, we propose a probabilistic contrastive semantic alignment (P-CSA) loss that can utilize the empirical MMD to measure the discrepancy between probabilistic embeddings. The proposed P-CSA loss \(\mathcal {L}_{local}\) consists of two components, including the positive probabilistic contrastive loss and negative probabilistic contrastive loss. The former aims to minimize the distance between the intra-class distributions from different domains, i.e.,

$$\begin{aligned} \begin{aligned}&\mathcal {L}_{local}^{pos} =\frac{1}{2}\Vert \frac{1}{T} \sum _{i=1}^{T} \phi \left( M_{\Theta }(\textbf{z}_{n_{i}})\right) -\frac{1}{T} \sum _{j=1}^{T} \phi \left( M_{\Theta }(\textbf{z}_{q_{j}})\right) \Vert _{\mathcal {H}}^{2}, \end{aligned}\nonumber \\ \end{aligned}$$
(7)

where \(M_{\Theta }(\cdot )\) denotes the embedding network for metric learning, which helps learn the distances between features better (Dou et al., 2019). Note that \(\textbf{y}_{n} = \textbf{y}_{q}\) needs to be satisfied. The negative probabilistic contrastive loss is then denoted by

$$\begin{aligned} \mathcal {L}_{local}^{neg}&= \frac{1}{2} \max [0, \xi - {\text {MMD}}(\Pi _{n},\Pi _{q})^{2}] = \frac{1}{2} \max [0, \xi \nonumber \\&\quad - \Vert \frac{1}{T} \sum _{i=1}^{T} \phi \left( M_{\Theta }(\textbf{z}_{n_{i}})\right) -\frac{1}{T} \sum _{j=1}^{T} \phi \left( M_{\Theta }(\textbf{z}_{q_{j}})\right) \Vert _{\mathcal {H}}^{2}], \end{aligned}$$
(8)

where \(\xi \) is a distance margin that can guarantee an appropriate repulsion range. Note that \(\textbf{y}_{n} \ne \textbf{y}_{q}\) needs to be satisfied.
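A minimal sketch of the two P-CSA terms in Eqs. (7)-(8) is given below. Here `z_n` and `z_q` hold the T Monte Carlo samples (shape (T, d)) of two probabilistic embeddings after the metric network \(M_{\Theta }\), `mmd2` is the RBF-MMD sketch given earlier, and `xi` is the distance margin; the function name is an illustrative assumption.

```python
import torch

def p_csa_pair(z_n, z_q, same_label, xi=1.0):
    d2 = mmd2(z_n, z_q)                              # MMD^2 between the two MC-sample sets
    if same_label:                                   # positive pair: Eq. (7)
        return 0.5 * d2
    return 0.5 * torch.clamp(xi - d2, min=0.0)       # negative pair with margin xi: Eq. (8)
```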

Model Training Our proposed framework consists of three modules: a BNN-based probabilistic extractor \(Q_{\phi }\), a BNN-based classifier \(C_{\omega }\), and a metric network \(M_{\Theta }(\cdot )\). For \(Q_{\phi }\), we only add a Bayesian layer with a ReLU layer on the bottom of a pretrained deterministic model (e.g., ResNet18 with the fully-connected layers removed), following Xiao et al. (2021). For \(C_{\omega }\), a Bayesian layer is also introduced to better adapt the classifier to insufficient samples. More implementation details of the BNN can be found in “Appendix A.1”. The structure of \(M_{\Theta }\) is the same as in Dou et al. (2019). The images \(\mathcal {X}=\{\textbf{x}_{l_{i}}\}\) undergo T stochastic forward passes through \(Q_{\phi }\) and \(C_{\omega }\) via MC sampling to obtain probabilistic predictions \(\{\hat{y}_{l_{i}}^{j}\}_{j=1}^{T}\), where the outputs (i.e., probabilistic embeddings) of \(Q_{\phi }\) serve as the inputs for the calculation of \(\mathcal {L}_{global}\) and \(\mathcal {L}_{local}\). The final prediction \(\hat{y}_{l_{i}}\) is the expectation of \(\{\hat{y}_{l_{i}}^{j}\}_{j=1}^{T}\). The total objective can be summarized as follows,

$$\begin{aligned} \mathcal {L}_{total}&= \sum _{l,i}\mathcal {L}_{c}(\hat{y}_{l_{i}},y_{l_{i}}) + {\text {KL}}[q_{\theta }(Q_{\phi })\Vert p(Q_{\phi })] \nonumber \\&\quad + {\text {KL}}[q_{\theta }(C_{\omega })\Vert p(C_{\omega })]+ \beta _{1}\mathcal {L}_{local} + \beta _{2}\mathcal {L} _{global}. \end{aligned}$$
(9)
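As a rough illustration of how Eq. (9) is assembled for one batch, the sketch below combines the classification loss, the KL terms of the two Bayesian modules, and the weighted local/global alignment losses. The module and argument names refer to the earlier sketches or are illustrative placeholders; \(\beta _1\) and \(\beta _2\) follow the paper's hyperparameters.

```python
import torch.nn.functional as F

def total_loss(logits, labels, q_phi_bayes, c_omega_bayes,
               l_local, l_global, beta1=0.1, beta2=0.7):
    ce = F.cross_entropy(logits, labels)           # classification term L_c
    kl = q_phi_bayes.kl() + c_omega_bayes.kl()     # KL[q || p] for Q_phi and C_omega
    return ce + kl + beta1 * l_local + beta2 * l_global
```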
Table 2 Domain generalization results on skin lesion classification

Discussion The rationale for why our proposed method can benefit DG performance in the small-data scenario comes from two aspects. First, BNNs adapt to insufficient data better than deterministic models (Graves, 2011; Mallick et al., 2021). In our method, the BNN is introduced to the DG problem in the context of insufficient data, where the BNN-based feature extractor and classification layers both bring consistent improvements (see the 2nd and 3rd columns in Table 4). More importantly, domain-invariant representation learning under this probabilistic framework from global and local perspectives contributes to more robust cross-domain performance (see the 4th and 5th columns in Table 4; Fig. 3 in Sect. 4.5 for the effectiveness of P-MMD).

4 Experiments

We evaluate our proposed method on three medical imaging tasks: (1) epithelium stroma classification, (2) skin lesion classification, and (3) spinal cord gray matter segmentation. The datasets used in these tasks are collected from different healthcare institutes and suffer from the domain shift problem in the context of insufficient samples, i.e., insufficient-sample scenarios exist either in all or in some source domains.

4.1 Epithelium Stroma Classification

Epithelium stroma classification is a fundamental step in the prognostic analysis of tumors. The public histopathological image datasets for binary classification (epithelium or stroma) are collected from three healthcare centers with different staining types and tissue densities: IHC, NKI, and VGH. After the patching operation, the IHC, NKI, and VGH datasets have 1342, 1230, and 1376 patches, respectively, which means that the insufficient-sample problem exists in all source domains compared with large-scale natural images. We randomly split the data of each source domain into a training set (80%) and a test set (20%) and adopt the leave-one-domain-out strategy for evaluation. A pretrained ResNet18 is introduced as the backbone. The Bayesian layer in \(Q_{\phi }\) is a fully-connected BNN layer of size \(512 \times 512\), and the Bayesian layer in \(C_{\omega }\) is a fully-connected BNN layer of size \(512 \times 2\). We utilize the Adam optimizer with a learning rate of \(5 \times 10^{-5}\) for training. The batch size is 32 for each source domain with 4000 iterations. The hyperparameters are selected over a wide range on the validation set; \(\beta _{1}\) and \(\beta _{2}\) are 0.1 and 0.7 for \(\mathcal {L}_{local}\) and \(\mathcal {L}_{global}\), respectively. For P-MMD, the level-1 and level-2 kernels are Gaussian RBF kernels (the kernel bandwidth is empirically set to 1 for all kernels), following Muandet et al. (2012). For the P-CSA loss, the distance margin \(\xi \) is set to 1. Balancing performance and computational efficiency, the number of MC samples in each Bayesian layer (a.k.a. T) is set to 10. We also discuss the influence of different T in Sect. 4.5. We report the average value and standard deviation in each target domain over five runs of the experiment.

Table 3 Domain generalization results on gray matter segmentation task

Results Table 1 shows the epithelium stroma classification results on different target domains. We compare with five different approaches. All methods are constructed on a SWAD-based (Cha et al., 2021) framework (with a pretrained ResNet18 as the backbone) and the same training schedule. DeepAll is a deterministic version of our proposed method that is trained on all source domains without any DG strategy. Some observations can be summarized as follows. First, both our proposed method and BDIL (Xiao et al., 2021) achieve promising results on the IHC and VGH tasks, which may benefit from the positive impact of the probabilistic framework on insufficient samples. However, BDIL has an obvious performance drop on the more challenging NKI domain (which has a lower average accuracy compared with other target domains) compared with our proposed method (i.e., 71.89% vs. 76.71%). This may be due to our introduction of global alignment (i.e., P-MMD), which is more powerful for learning domain-invariant representations than only using local negative pairs (as used by BDIL). Second, compared with the contrastive semantic alignment-based method (i.e., MASF (Dou et al., 2019)), our proposed method achieves a significantly better performance (80.45% vs. 88.82% on the IHC domain), owing to a more reliable distribution-based contrastive learning manner on insufficient samples from all source domains. Finally, our proposed method achieves the best average performance by a clear margin compared with the other approaches (i.e., LDDG (Li et al., 2020), KDDG (Wang et al., 2021), DNA (Chu et al., 2022), MIRO (Cha et al., 2022), and DSU (Chu et al., 2022)).

4.2 Skin Lesion Classification

Seven public skin lesion datasets covering seven classes of lesions are collected from various institutes using different dermatoscope types: HAM10000 with 10015 images, UDA with 601 images, SON with 9251 images, DMF with 1212 images, MSK with 3551 images, D7P with 1926 images, and PH2 with 200 images. We can observe that the insufficient-sample problem exists in some source domains, especially the PH2 and UDA domains. Following previous work (Li et al., 2020), each domain is randomly split into a 50% training set, a 30% test set, and a 20% validation set. As in Li et al. (2020), one domain from DMF, D7P, MSK, PH2, SON, and UDA serves as the target domain, and the remaining domains together with HAM10000 serve as source domains. A pretrained ResNet18 is introduced as the backbone. The Bayesian layer in \(Q_{\phi }\) is a fully-connected BNN layer of size \(512 \times 512\), and the Bayesian layer in \(C_{\omega }\) is a fully-connected BNN layer of size \(512 \times 7\). The Adam optimizer is employed with a learning rate of \(5 \times 10^{-5}\) for 2000 iterations. The hyperparameters are selected over a wide range on the validation set; \(\beta _{1}\) and \(\beta _{2}\) are 0.1 and 0.7 for \(\mathcal {L}_{local}\) and \(\mathcal {L}_{global}\), respectively. The batch size is 32 for each source domain. The remaining settings are the same as for epithelium stroma classification. The average value and standard deviation are reported over five runs.

Table 4 Ablation study on each component of our proposed method for spinal cord gray matter segmentation task (where “site2” is as the target domain)

Results Table 2 shows the skin lesion classification accuracies on different target domains. Six different methods are used for comparison. All approaches are implemented using the SWAD-based framework with a pretrained ResNet18 model as the backbone. Some observations are as follows. First, compared with the contrastive semantic alignment-based method (e.g., MASF), our proposed method provides more reliable distribution-based pairs, leading to better performance on insufficient samples from some source domains (MASF: 0.2692 vs. ours: 0.3781 on the DMF domain, where PH2 and UDA are parts of the source domains). Second, although BDIL slightly outperforms our proposed method on the D7P domain (0.6204 vs. 0.6120), our proposed method performs significantly better on the challenging DMF domain (0.2985 vs. 0.3781), on the other domains, and in the average results (0.7049 vs. 0.7328). Third, other baseline methods (e.g., SWAD, DNA, and DSU) adopt their own schemes to relieve the impact of insufficient samples from some source domains. For example, DSU achieves the best performance on the DMF domain, which may be due to the positive impact of straightforward domain randomization. Yet it is difficult for DSU to achieve consistently better results across multiple domains compared with our proposed method, owing to the lack of explicit domain alignment.

Fig. 3

The loss curves over iterations on the skin lesion and epithelial-stromal classification tasks. a Global alignment loss. b Local alignment loss

4.3 Spinal Cord Gray Matter Segmentation

Spinal cord gray matter (GM) segmentation (pixel-level classification) is an emerging task for predicting disability by evaluating the atrophy of the GM area. The acquired magnetic resonance imaging (MRI) data are collected from four healthcare centers: “site1” with 30 slices, “site2” with 113 slices, “site3” with 246 slices, and “site4” with 122 slices. One can observe that the insufficient-sample problem exists in some source domains, especially the “site1” domain. Following previous work (Li et al., 2020), we randomly split the data of each source domain into a training set (80%) and a test set (20%) and adopt the leave-one-domain-out strategy for evaluation. The hyperparameters are selected over a wide range on the validation set; \(\beta _{1}\) and \(\beta _{2}\) are 0.01 and 0.001 for \(\mathcal {L}_{local}\) and \(\mathcal {L}_{global}\), respectively. The 2D-Unet (Ronneberger et al., 2015) is used as the backbone for all methods. The structures of \(Q_{\phi }\) and \(C_{\omega }\), as well as more experimental details, can be found in “Appendix B”. Following Li et al. (2020), the average results in each target domain are reported over three runs.

Table 5 Domain generalization results on MSK dataset by randomly picking same proportion of samples from each source domain
Table 6 Domain generalization results on MSK dataset by randomly picking same number of samples from each class in each domain

Results Table 3 shows the spinal cord GM segmentation results on different target domains. The Dice Similarity Coefficient (DSC), Jaccard Index (JI), and Conformity Coefficient (CC) are used to measure the accuracy of the obtained segmentation results; the True Positive Rate (TPR) and Average Surface Distance (ASD) are introduced from statistical and distance-based perspectives. We compare with five different methods (all with effective segmentation performance). Some observations are as follows. First, our proposed method outperforms all baseline methods in terms of the average results of the five quantitative metrics. Second, compared with the contrastive semantic alignment-based method (e.g., MASF), the reliable pixel-level pairs constructed by our proposed method also contribute to improved performance in scenarios with insufficient samples from some source domains. Third, our proposed method and DSU achieve the best and second-best performance compared with the other baselines, which may be reasonable as they can benefit from modeling the data uncertainty in the small-data regime using a BNN or distribution modeling.

4.4 Ablation Analysis

The spinal cord gray matter segmentation task is utilized to explore the effectiveness of each component of our proposed method, owing to its various quantitative metrics. The results are shown in Table 4. First, we observe that better performance can be achieved by introducing a probabilistic layer compared with the results using the Unet, which reflects the superiority of probabilistic models. Second, we observe that introducing either local or global alignment for domain-invariant learning yields better performance than using the probabilistic layer alone, which shows the effectiveness of the introduced probabilistic feature regularization terms. Last but not least, imposing domain-invariant learning with both local and global views further improves the performance, which justifies the effectiveness of our proposed method in jointly considering local and global alignment.

Effectiveness of domain-invariant loss We are also interested in the impact of the domain-invariant losses on different tasks. Thus, we visualize the learning curves of the different tasks in Fig. 3. As we can observe, for the skin lesion (on DMF) and epithelium-stroma (on IHC) classification tasks, the loss curves over iterations show that the global discrepancy converges faster than the local discrepancy, while the more challenging cross-domain task converges more slowly on the global alignment.

Fig. 4

The performance comparison between mean embedding method and kernel mean embedding method with different Monte Carlo samples T. For each sub-figure, we use only one alignment operation. a Local alignment. Mean Embedding The mean embedding operation with Euclidean distance is utilized between probabilistic embedding pairs. Kernel Mean Embedding The kernel mean embedding with MMD distance is utilized between probabilistic embedding pairs. b Global alignment. Mean Embedding The mean embedding operation with MMD distance is utilized between domains (as distributions). Kernel Mean Embedding The kernel mean embedding with P-MMD distance is utilized between domains (as distributions over distributions)

4.5 Further Analyses of Our Proposed Method

Results on Different Fractions of Training Data We are interested in how different fractions of training samples influence the final performance in the small-data scenario. To this end, we adopt the skin lesion classification task for evaluation, since the number of samples in each domain is imbalanced (some domains such as HAM10000 contain sufficient data, while the number of samples in some other domains such as PH2 and UDA is insufficient). As such, we can better simulate the scenario in which the issue of insufficient samples exists either in all source domains or in some source domains. We choose the MSK dataset, with the second-highest number of samples, as the target domain and the remaining datasets as the source domains (including PH2 and UDA). Table 5 shows the accuracies. As we can see, compared with DNA (with the second-best performance on 100% of the samples), our proposed method and BDIL have better average results and lower performance attenuation as the number of samples decreases, owing to their probabilistic gain under small-data scenarios. Compared with BDIL, our proposed method shows more robust performance in this challenging small-data scenario, owing to the integration of local (P-CSA) and global (P-MMD) alignment.

Results on Different Numbers of Samples for Training Data We are also interested in how different numbers of samples per class influence the final performance in small-data scenarios. To this end, we again adopt the skin lesion classification task for evaluation. Specifically, we randomly draw T samples from each class in a source domain to represent this domain for training. Here, we set T to 20, 30, and 40, respectively, in different experiments.

The results can be found in Table 6. As we can see, our proposed method achieves the best performance in all settings compared with all baseline methods. Meanwhile, the Bayesian-based DG approaches (e.g., our proposed method and BDIL) have better performance than the other methods, which is reasonable as the BNN adapts well to the small-data scenario. In particular, our proposed method has around 5% improvement over the second-best method when T is set to the smallest value, i.e., 20.

Kernel Mean Embedding (i.e., P-MMD) v.s. Mean Embedding We explore the effect of different schemes for probabilistic embeddings. A straightforward method is to represent probabilistic embeddings with the expectation, which is called the “Mean Embedding”. Then, a probabilistic embedding can be considered as a latent point and the MMD can be used to measure the discrepancy between distributions consisting of latent points.

Table 7 Out-of-domain accuracies (%) on PACS based on ResNet50
Table 8 Out-of-domain accuracies (%) on OfficeHome based on ResNet50
Table 9 Out-of-domain accuracies (%) on VLCS based on ResNet50

For the mean embedding-based \(\mathcal {L}_{global}\), the computational process of this scheme for MMD distance can be formulated as

$$\begin{aligned} {\text {MMD}}(\mathbb {P}_{l}, \mathbb {P}_{t})^{2}= & {} \left\| \frac{1}{n_{l}} \sum _{i=1}^{n_{l}} \varphi (\mathbb {E}[\Pi _{l_{i}}])-\frac{1}{n_{t}} \sum _{j=1}^{n_{t}} \varphi (\mathbb {E}[\Pi _{t_{j}}])\right\| _{\mathcal {H}}^{2}. \qquad \end{aligned}$$
(10)

Equation (10) can then be used to construct a global alignment loss \(\mathcal {L}_{global}\). For the local alignment loss \(\mathcal {L}_{local}\), the Euclidean distance can be used to compute the distance between latent points, which is similar to the original CSA loss in Motiian et al. (2017). For the positive pairs with the same label, the mean embedding-based positive contrastive loss can be represented as

$$\begin{aligned} \mathcal {L}_{local}^{pos} = \frac{1}{2} \left\| \frac{1}{T} \sum _{i=1}^{T} \mathbb {E}\left[ M_{\Theta }(\textbf{z}_{n_{i}})\right] -\frac{1}{T} \sum _{j=1}^{T} \mathbb {E}\left[ M_{\Theta }(\textbf{z}_{q_{j}})\right] \right\| _{2}^{2}. \nonumber \\ \end{aligned}$$
(11)

\(M_{\Theta }(\cdot )\) denotes the embedding network of metric learning. For the negative pairs with different labels, the negative contrastive loss \(\mathcal {L}_{local}^{neg}\) is denoted by

$$\begin{aligned} \frac{1}{2} \max \left[ 0, \xi - \left\| \frac{1}{T} \sum _{i=1}^{T} \mathbb {E}\left[ M_{\Theta }(\textbf{z}_{n_{i}})\right] -\frac{1}{T} \sum _{j=1}^{T} \mathbb {E}\left[ M_{\Theta }(\textbf{z}_{q_{j}})\right] \right\| _{2}^{2}\right] .\nonumber \\ \end{aligned}$$
(12)

As a result, a mean embedding-based contrastive loss with the view of local alignment can be calculated as

$$\begin{aligned} \mathcal {L}_{local} = \mathcal {L}_{local}^{pos} + \mathcal {L} _{local}^{neg}. \end{aligned}$$
(13)
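For reference, the following is a small sketch of this mean-embedding baseline of Eqs. (10)-(13): each probabilistic embedding is collapsed to the mean of its T MC samples before applying the "vanilla" MMD (global) or the Euclidean contrastive loss (local). Shapes and helper names are illustrative assumptions.

```python
import torch

def mean_embed(embs):
    # (n, T, d) -> (n, d): keep only the first moment of each probabilistic embedding,
    # after which the vanilla MMD of Eq. (10) can be applied to the resulting points.
    return embs.mean(dim=1)

def mean_local_pair(z_n, z_q, same_label, xi=1.0):
    # z_n, z_q: (T, d) MC samples after M_Theta; Euclidean distance between their means.
    d2 = ((z_n.mean(dim=0) - z_q.mean(dim=0)) ** 2).sum()
    if same_label:                                  # Eq. (11)
        return 0.5 * d2
    return 0.5 * torch.clamp(xi - d2, min=0.0)      # Eq. (12)
```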

Instead, as can be observed from Fig. 2, our proposed method induces a level-2 kernel-based MMD with empirical estimation for probabilistic embeddings. Specifically, our proposed scheme can preserve the higher moments of a probabilistic embedding via the nonlinear level-1 kernel (see the fourth component in Fig. 2). Moreover, by introducing a level-2 kernel, the similarities between probabilistic embeddings can also be measured based on their own moment information (see the last component in Fig. 2). Benefiting from these virtues, the proposed probabilistic MMD can accurately capture the discrepancy between mixture distributions in an extended empirical MMD fashion.

Here, we validate the effectiveness of the different schemes on the NKI task of epithelium stroma classification for each alignment view. The experimental settings are similar for the different methods. The experimental results can be found in Fig. 4. As we can see, our proposed method achieves consistent improvements for each alignment method with different numbers of Monte Carlo samples, which may be reasonable as the kernel mean representation can preserve many statistical components owing to its injective property. Second, when the number of MC samples is 10, we can observe an obvious margin in the global alignment, which involves the computation between mixture distributions.

Fig. 5

The performance of our proposed model on the NKI task of Epithelium Stroma classification with different Monte Carlo samples T

Influence of Different Numbers of MC Samples It is important to balance the number of Monte Carlo samples against computational efficiency. On the one hand, the properties of the probabilistic embeddings are affected by the Monte Carlo sampling. On the other hand, too many Monte Carlo samples incur a heavy computational cost. Although Xiao et al. (2021) suggested that the distributional property and the computational cost are both acceptable for the computation of the KL divergence when the number of Monte Carlo samples is chosen appropriately, the practical performance of our proposed method still needs to be explored. We conduct experiments on the NKI task of epithelium-stromal classification with different numbers of Monte Carlo samples T.

The results are shown in Fig. 5. As we can see, if the number of Monte Carlo samples is too small, it is difficult to capture the distributional properties of the probabilistic embeddings. As T increases, there is an obvious improvement for our proposed method; interestingly, the performance gradually saturates. As a result, balancing the number of Monte Carlo samples and computational efficiency, the number of Monte Carlo samples T in each Bayesian layer is set to 10.

Table 10 Ablation study on epithelium stroma classification in histopathological images

4.6 Scalability to Benchmark Datasets

Here, we introduce three DG benchmarks, namely PACS (Art: 2048 images, Cartoon: 2344 images, Photo: 1670 images, Sketch: 3929 images), OfficeHome (15,588 samples with 65 classes from four domains), and VLCS (10,729 samples with 5 classes from four domains), for comparison. Compared with some large-scale benchmarks (e.g., DomainNet and Wilds), these three datasets are more appropriate for exploring the effectiveness of different DG models under the scenario of insufficient samples. We adopt a pretrained ResNet50 as the backbone for all benchmarks. The structure of the overall framework is similar to the model mentioned for skin lesion classification. Our proposed method as well as the baseline methods are all based on DomainBed, where the holdout fraction (the proportion of the validation set) is set to 0.2 for all methods. One domain serves as the target domain and the remaining domains are the source domains for training. Testing is performed on the overall data of the target domain.

Here, our proposed method is optimized by the Adam optimizer with a learning rate of \(5\times 10^{-5}\). The batch size for each source domain is 32. The training steps are set to 20000 for PACS and OfficeHome, and 2000 for VLCS. Following the SWAD framework, the training process is stopped for our proposed method when the validation loss increases significantly. The hyperparameters are selected over a wide range on the validation set. For \(\mathcal {L}_{local}\) and \(\mathcal {L}_{global}\), \(\beta _{1}\) and \(\beta _{2}\) are set to 0.1 and 1 for all benchmark datasets. Other hyperparameters, such as the kernel function, kernel bandwidth, and distance margin, are the same as the settings mentioned before.

The experimental results on PACS, OfficeHome, and VLCS are shown in Tables 7, 8, and 9. As we can see, compared with domain-invariant-based approaches (e.g., DANN), our proposed method has a significant improvement due to the introduction of the probabilistic framework. Our proposed method also outperforms the data augmentation-based approaches (e.g., MixStyle, Manifold Mixup, and CutMix). Although data generation methods (e.g., MixStyle) can effectively tackle the insufficient-sample problem via additional generative samples, the lack of effective domain-invariant learning may hamper further performance improvement. Our proposed method also achieves better performance compared with feature disentanglement-based methods (e.g., pAdaIN). Last but not least, compared with the state-of-the-art domain generalization method (e.g., MIRO), our proposed method achieves comparable performance on small-scale benchmark datasets, which demonstrates the scalability of our proposed method in a small-data regime.

5 Discussion

Limitation (1) MC sampling Our probabilistic embeddings derive from the distribution over the model weights. Similar to widely-used BNN-based models (Blundell et al., 2015; Mallick et al., 2021; Xiao et al., 2021), the predictive distribution of the probabilistic embeddings needs to be approximated by MC sampling. This may result in more computational consumption compared with the original deterministic model-based scheme. Although the heavy computational consumption can be alleviated in our setting (i.e., the DG problem in the context of small data), the computational cost needs to be further reduced for very large-scale datasets in the future. (2) Approximate posterior Although the mean-field variational inference (MFVI) we used is effective in rendering an approximate posterior for the BNN, the potential amortization gap due to the fully factorized Gaussian assumption of MFVI may limit further performance improvement on very challenging datasets (Cremer et al., 2018). This limitation will be explored by replacing MFVI with more expressive approximate posteriors (such as normalizing flows) in the future.

Tradeoff In this paper, we carefully make a tradeoff between effectiveness and computational complexity. Specifically, owing to the introduction of probabilistic embeddings for the DG problem in the context of insufficient data, the probabilistic MMD is proposed based on the level-2 kernel to capture more high-level statistics, which may be more complex than other baseline methods with the “vanilla” MMD. However, it shows improved effectiveness in addressing the problem of global semantic alignment between domains, each consisting of a series of probabilistic embeddings. Simpler methods that use just the first moment (i.e., the mean embedding) may not achieve the same level of performance. Moreover, we adopt some strategies, such as the unbiased estimate of MMD with linear complexity and the selection of an appropriate number of MC samples, to offset the extra computational complexity compared with other baseline methods.

Usage of our proposed method Compared with a deterministic framework, where a source domain as a distribution consists of a set of point embeddings, a source domain generated by our probabilistic framework includes a set of distributions, \(\textit{i.e.,}\) the so-called distribution over distributions or the probability of probability. The vanilla MMD may not directly cope with this situation. Instead, our proposed P-MMD leverages the level-2 kernel to measure the discrepancy between mixture distributions based on the empirical MMD framework and preserve most information about high-level statistics.

In this paper, we focus on the DG problem in the context of insufficient data. In particular, for medical imaging data, insufficient-sample scenarios exist either in all source domains or in some source domains. If the number of training samples from all source domains is sufficient, it is reasonable to choose another DG method. For instance, the DomainNet dataset for natural images has six source domains, and the source domain (“clipart”) with the smallest number of training samples still has 50,000 training samples. Moreover, our proposed method shows good scalability on benchmark datasets with a larger number of samples, although it is designed for the DG scenario of insufficient data.

Gains of Probabilistic Embeddings To further show the gain of introducing probabilistic embeddings for insufficient data, we conduct an ablation study on epithelium stroma classification and skin lesion classification. We only add a Bayesian layer (for probabilistic embeddings) with a ReLU layer on the bottom of a deterministic baseline model (pre-trained ResNet-18).

The results can be found in Table 10. As we can see, consistent improvements can be observed in these small-data tasks due to the introduction of probabilistic embeddings.

6 Conclusion

In this work, we address the DG problem in the context of insufficient data, which can occur in all or some source domains. To this end, we introduce a probabilistic framework into the DG problem to derive probabilistic embeddings (which adapt to insufficient samples better than deterministic models) for domain-invariant learning. Under this probabilistic framework, an extension of MMD called P-MMD is proposed for measuring the discrepancy between distributions over distributions. Moreover, a probabilistic CSA loss is proposed for local alignment. Extensive experiments on insufficient cross-domain medical imaging data show the effectiveness of our method.