Abstract— Identification of protein-protein interactions (PPIs) helps derive cellular mechanistic understanding, particularly in the context of complex conditions such as neurodegenerative disorders, metabolic syndromes, and cancer. Large Language [...] the groundwork for safer, more reliable, and more informative computational tools in precision medicine.

[...] have generated vast amounts of validated interaction data. These experimentally determined interactions have been systematically collected in comprehensive databases such as STRING [3], BioGRID [4], and IntAct [5], creating valuable [...]
[Figure 1: Uncertainty-aware LoRA]
Fig. 1. Illustration of our uncertainty-aware low-rank adaptation approach for pre-trained LLMs in protein-protein interaction prediction. This approach improves generalizability and model calibration by incorporating parameter uncertainty into predictions.
III. METHODOLOGY

A. LoRA Ensemble

We employ an ensemble of LoRA models, the LoRA Ensemble [50], [51], as an efficient strategy for uncertainty quantification in LLMs. Traditional deep ensembles yield better predictive performance and uncertainty estimation by training multiple models independently, but applying this directly to LLMs is often infeasible due to high memory and computational costs.

To circumvent these issues, each LoRA Ensemble member fine-tunes the same pre-trained backbone W_0 with a low-rank trainable modification ∆W_m = B_m A_m, where B_m ∈ R^{d_1 × r} and A_m ∈ R^{r × d_2} have significantly fewer parameters than the full model since r ≪ min(d_1, d_2). These adapters are trained independently and in parallel, ensuring diverse solutions {W_1, W_2, ..., W_M}. The ensemble prediction is computed by averaging outputs across the M ensemble members. For a given input x_new, if y_new^m denotes the prediction from the m-th ensemble member, the final ensemble output (for continuous outcomes) is given by:

p_ens(y_new | x_new) = (1/M) Σ_{m=1}^{M} p(y_new | x_new, W_m).

This approach retains the benefits of ensembling (improved accuracy, calibration, and robustness) while preserving efficiency by reusing the frozen backbone and training only lightweight LoRA adapters.
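As an illustrative sketch (not the exact code used in this work), the averaging step can be implemented with the PEFT library roughly as follows. The checkpoint name, adapter directories, and the two-label classification head are assumptions made for brevity, since the downstream PPI tasks are framed as True/False queries.

# Sketch: LoRA-ensemble prediction averaging over M adapters sharing one frozen backbone.
# Checkpoint name, adapter directories, and the classification head are illustrative assumptions.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel

BASE = "meta-llama/Meta-Llama-3-8B"                        # backbone W_0 (assumed checkpoint name)
ADAPTER_DIRS = ["lora_seed0", "lora_seed1", "lora_seed2"]  # hypothetical adapter directories

tokenizer = AutoTokenizer.from_pretrained(BASE)
base = AutoModelForSequenceClassification.from_pretrained(BASE, num_labels=2)

# Attach all M adapters to the shared frozen backbone.
model = PeftModel.from_pretrained(base, ADAPTER_DIRS[0], adapter_name="member_0")
for i, path in enumerate(ADAPTER_DIRS[1:], start=1):
    model.load_adapter(path, adapter_name=f"member_{i}")

def ensemble_predict(text: str) -> torch.Tensor:
    """Average p(y | x, W_m) over the M ensemble members."""
    inputs = tokenizer(text, return_tensors="pt")
    probs = []
    model.eval()
    with torch.no_grad():
        for i in range(len(ADAPTER_DIRS)):
            model.set_adapter(f"member_{i}")               # switch to W_m = W_0 + B_m A_m
            probs.append(torch.softmax(model(**inputs).logits, dim=-1))
    return torch.stack(probs).mean(dim=0)                  # p_ens(y | x)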
B. Bayesian Low-Rank Adaptation

Despite the availability of scalable posterior inference methods like variational inference [38], a fully Bayesian treatment of LLMs remains computationally prohibitive. Instead, limiting Bayesian inference to the LoRA parameters offers a more tractable means of capturing uncertainty in model predictions. However, even Markov chain Monte Carlo approaches can become excessively costly for inferring posteriors over the millions of LoRA parameters involved in large-scale models. As a practical compromise, Bayesian LoRA [46] employs the Laplace approximation to estimate the posterior over these low-rank parameters, centered at their maximum a posteriori (MAP) estimate with covariance given by the inverse of the Fisher information matrix [52].
To this end, let θ denote the trainable LoRA parameters with prior distribution N(0, λ^{-1} I). The Laplace approximation first computes the MAP estimate, which is equivalent to maximizing the log-joint log P(y, X, θ):

θ_MAP = argmax_θ log P(y, X, θ)
      = argmax_θ [ log P(y | X, θ) + log P(θ) ]
      = argmax_θ [ log P(y | X, θ) − (λ/2) ||θ||_2^2 ] + const,

where X represents the model inputs. The log-prior term acts as an L2 regularizer on the trainable parameters, which can be incorporated into frequentist model training as a weight decay term of strength λ/2. As a result, parameters from any previously trained model that used a reasonable weight decay setting (for example, via AdamW with its weight decay) can be directly reused.
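For intuition, the MAP step above amounts to a standard regularized fit. The following toy sketch uses a small logistic-regression stand-in rather than an LLM; the explicit (λ/2)||θ||² penalty mirrors the weight-decay equivalence described in the text and is not the paper's actual training code.

# Toy sketch of the MAP step: maximize log P(y|X, theta) - (lambda/2)||theta||^2,
# with a logistic-regression stand-in for the LoRA parameters (illustrative only).
import torch

torch.manual_seed(0)
X = torch.randn(64, 8)                          # toy inputs
y = (X[:, 0] > 0).long()                        # toy binary labels
theta = torch.zeros(8, 2, requires_grad=True)   # stand-in for the LoRA parameters
lam = 0.1                                       # prior precision lambda

opt = torch.optim.SGD([theta], lr=0.5)
for _ in range(200):
    opt.zero_grad()
    nll = torch.nn.functional.cross_entropy(X @ theta, y)   # -log P(y | X, theta)
    loss = nll + 0.5 * lam * theta.pow(2).sum()              # minus the log-prior
    loss.backward()
    opt.step()

theta_map = theta.detach()                      # MAP estimate used by the Laplace step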
Next, to obtain an approximate posterior around θ_MAP, the Laplace method proceeds with a second-order Taylor expansion of the log-joint L(D, θ) = log P(y, X, θ) around θ_MAP. Ignoring the higher-order terms, this yields

L(D, θ) ≈ L(D, θ_MAP) + (1/2) (θ − θ_MAP)^⊤ H (θ − θ_MAP),

where the first-order term vanishes because the gradient is zero at θ_MAP, and H = ∇²_θ L(D, θ)|_{θ_MAP} is the Hessian of the log-joint at θ_MAP. Under this quadratic approximation,

p(θ | D) ≈ N(θ | θ_MAP, H^{-1}).   (1)

Hence, the Laplace approximation is a post-hoc Bayesian inference method that only requires the additional step of computing H^{-1} at θ_MAP. In practice, computing the full Hessian H can be expensive, especially for large models, since its size grows quadratically with the number of model parameters. We use the positive semi-definite Fisher information matrix to circumvent the issue of a potentially indefinite Hessian, which arises when local convexity conditions fail to hold in large machine learning models. The Fisher information is defined as

F(θ) = Σ_{n=1}^{N} E_{ŷ ∼ P(y | f_θ(x_n))} [ G G^⊤ ],

where G = ∇_θ log P(ŷ | f_θ(x_n)) is the gradient of the log-likelihood and the expectation is taken over the model's output distribution.

Next, to estimate the Fisher information in a manner that is both tractable and memory-efficient, we employ a Kronecker-Factored Approximate Curvature (K-FAC) approach similar to [46]. In K-FAC, the Fisher is treated as a block-diagonal matrix with one block per linear layer, and each block is factorized into two smaller matrices. For the l-th linear layer, we compute the Fisher block F_l from that layer's input activations a_{l-1} and the log-likelihood gradients with respect to the layer's pre-activation output s_l, denoted G_{s_l} = ∇_{s_l} log P(y | X, θ):

F_l = Σ_{n=1}^{N} E_{P(y | f_θ(x_n))} [ a_{l-1} a_{l-1}^⊤ ] ⊗ E_{P(y | f_θ(x_n))} [ G_{s_l} G_{s_l}^⊤ ].   (2)

This approach avoids storing the full, dense Hessian, thereby reducing computational overhead. By applying K-FAC to the LoRA parameters, we maintain a compact representation of uncertainty while keeping the overhead similar to standard training. However, in Equation (2), the first expectation grows with the square of the layer's input width, while the second grows with the square of the output width. Because LoRA adapters alternate between wide-input, narrow-output configurations and vice versa, one of these two factors can become especially large. To address this, we use an incremental SVD to factorize the large factor into two new low-rank matrices, thereby saving memory. Further mathematical details are provided in Appendix E of [46].
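For illustration, the two Kronecker factors in Equation (2) can be estimated for a single linear layer as sketched below. The single Monte Carlo sample of y per input and the toy layer dimensions are assumptions made for brevity; the full recipe is given in [46].

# Sketch of the two K-FAC factors in Eq. (2) for one linear layer, using one
# Monte Carlo sample of y per input (an assumption for brevity).
import torch

def kfac_factors(layer: torch.nn.Linear, inputs: torch.Tensor):
    """Return (A, S): A ≈ E[a_{l-1} a_{l-1}^T], S ≈ E[G_s G_s^T] for `layer`."""
    a = inputs                                     # activations entering the layer
    s = layer(a)                                   # pre-activation outputs s_l
    s.retain_grad()
    probs = torch.softmax(s, dim=-1)
    y_sampled = torch.multinomial(probs, 1).squeeze(-1)     # y ~ P(y | f_theta(x))
    log_lik = torch.log_softmax(s, dim=-1).gather(1, y_sampled[:, None]).sum()
    log_lik.backward()
    g = s.grad                                     # G_s = d log P / d s_l
    n = a.shape[0]
    A = a.T @ a / n                                # input-side Kronecker factor
    S = g.T @ g / n                                # output-side Kronecker factor
    return A, S

layer = torch.nn.Linear(16, 2)                     # toy layer standing in for a LoRA factor
A, S = kfac_factors(layer, torch.randn(32, 16))
# The Fisher block F_l is approximated by the Kronecker product A ⊗ S.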
Once we have the approximate Gaussian posterior of Equation (1), we can linearize the model predictions around the MAP estimate θ_MAP [53]. For a test input x_new,

f_θ(x_new) ≈ f_{θ_MAP}(x_new) + ∇_θ f_θ(x_new)|_{θ_MAP}^⊤ (θ − θ_MAP).

Because this expression is linear in θ, integrating out the Gaussian posterior over θ yields a Gaussian predictive distribution for the logits:

f_θ(x_new) ∼ N( f_{θ_MAP}(x_new), Λ ),   where Λ = ∇_θ f_{θ_MAP}(x_new)^⊤ H^{-1} ∇_θ f_{θ_MAP}(x_new).

Finally, to sample efficiently from this predictive posterior, we use the Cholesky decomposition Λ = L L^⊤. Then

ŷ = f_θ(x_new) = f_{θ_MAP}(x_new) + L z,   z ∼ N(0, I).

This linearized predictive step, combined with the Gaussian approximate posterior, yields efficient uncertainty estimates for downstream tasks in the Bayesian LoRA approach.
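For illustration, the linearized predictive distribution and the Cholesky sampling step can be sketched as follows, with a small dense matrix standing in for the K-FAC-structured posterior covariance H^{-1}; all dimensions and values are arbitrary.

# Sketch of the linearized predictive step and Cholesky sampling of the logits.
# A small dense H^{-1} replaces the Kronecker-factored posterior for readability.
import torch

torch.manual_seed(0)
n_params, n_logits = 10, 2
f_map = torch.tensor([1.2, -0.7])               # f_{theta_MAP}(x_new): MAP logits
J = torch.randn(n_params, n_logits)             # grad of the logits w.r.t. theta at theta_MAP
H_inv = torch.eye(n_params) * 0.05              # stand-in posterior covariance H^{-1}

Lam = J.T @ H_inv @ J                           # logit covariance Lambda
L = torch.linalg.cholesky(Lam + 1e-6 * torch.eye(n_logits))   # jitter for stability

samples = f_map + (L @ torch.randn(n_logits, 100)).T          # 100 draws of the logits
probs = torch.softmax(samples, dim=-1).mean(0)                # Monte Carlo predictive probabilities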
IV. EXPERIMENTAL RESULTS

In this section, we assess the performance of two uncertainty-aware LoRA adaptations, LoRA Ensemble and Bayesian LoRA, applied to the LLaMA-3-8B and BioMedGPT-LM-7B models on publicly available protein-protein interaction datasets. As a baseline, we include a single LoRA model trained in a deterministic manner. All LoRA-based approaches were implemented using the PEFT library [54], with each configuration run three times using different random seeds. We evaluate model performance and robustness using accuracy (Acc), negative log-likelihood (NLL), and expected calibration error (ECE) on the test sets; additional details on the NLL and ECE metrics can be found in Appendix V-A. Furthermore, we report the Matthews correlation coefficient (MCC), specificity (Spec.), precision (Prec.), and F1-score, as well as the area under the receiver operating characteristic curve (AUROC), over the test sets for a comprehensive view of predictive capabilities. Final metrics are summarized by the mean and standard deviation across three independent runs.
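For reference, the NLL and ECE used here follow their standard definitions; the sketch below assumes 15 equal-width confidence bins for the ECE, whereas the exact setup is the one described in Appendix V-A.

# Sketch of the evaluation metrics: average negative log-likelihood and binned ECE.
# The 15 equal-width confidence bins are an assumption, not the paper's stated choice.
import numpy as np

def nll(probs: np.ndarray, labels: np.ndarray) -> float:
    """probs: (N, 2) predicted class probabilities; labels: (N,) in {0, 1}."""
    eps = 1e-12
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels] + eps)))

def ece(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    """Expected calibration error: |accuracy - confidence| weighted by bin size."""
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():
            acc = (pred[mask] == labels[mask]).mean()
            total += mask.mean() * abs(acc - conf[mask].mean())
    return float(total)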
PPI Datasets. The datasets analyzed here explore PPIs related to various diseases, providing valuable insights into their underlying mechanisms. The neurodegenerative diseases PPI (ND-PPI) dataset, sourced from the study [55], examines a network of 820 proteins forming 11,762 interactions, evenly split between positive and negative pairs. The dataset is structured to assess whether specific protein pairs interact in the presence of neurodegenerative conditions. Similarly, the metabolic disorders PPI (M-PPI) dataset, also from [55], investigates metabolic disorders, encompassing 1,063 proteins and a total of 10,262 interactions. The cancer PPI (C-PPI) dataset, derived from the study [56], consists of 933 positive and 1,308 negative interactions; to ensure balanced representation, it was curated into an equal-sized collection of 1,866 total interactions. These datasets have been evaluated by prompting models to assess whether proteins interact in the corresponding conditions, and they collectively contribute to advancing computational models for predicting PPIs across different disease contexts, enhancing our understanding of disease-specific interaction networks. These tasks are formalized as binary (True/False) classification problems, as illustrated in Fig. 1. Furthermore, each dataset is divided into 80% for training and 20% for testing, with all models evaluated on the fixed test set in each PPI prediction task. We refer readers to [31] for additional details and exploratory analyses of these datasets.
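The exact prompt template follows [31] and Fig. 1 and is not reproduced here; the snippet below is only a hypothetical illustration of how a protein pair and a disease context can be framed as a True/False query, with placeholder wording and protein names.

# Hypothetical illustration of framing a PPI pair as a True/False prompt.
# The template and protein names are placeholders, not the wording used in [31] or Fig. 1.
def build_prompt(protein_a: str, protein_b: str, condition: str) -> str:
    return (
        f"In the context of {condition}, do the proteins {protein_a} and "
        f"{protein_b} interact with each other? Answer True or False."
    )

print(build_prompt("APP", "PSEN1", "neurodegenerative diseases"))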
Implementation Details. In all experiments, we construct a LoRA ensemble using three individually fine-tuned LoRA learners. The LoRA matrices B are initialized to zero, while the entries of A follow a Kaiming uniform initialization [57]. Optimization is performed using the AdamW optimizer with a learning rate of 1 × 10^{-4}, default hyperparameters, and a total of four training epochs. The batch size is set to 4 for the ND-PPI and M-PPI cases and 16 for the C-PPI case, following [31]. For Bayesian LoRA, the prior precision λ is fixed at 0.1. Lastly, LoRA is applied to the query, value, and output layers across all methods, with hyperparameters set to r = 16, α = 32, a dropout rate of 0.05, and a maximum sequence length of 50.
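For concreteness, these hyperparameters map onto a PEFT configuration roughly as sketched below. The target_modules names are assumptions for a LLaMA-style architecture rather than the exact module names used in this work.

# Sketch of the LoRA setup described above using the PEFT library [54].
# target_modules names are assumed for a LLaMA-style model; the paper states
# that LoRA is applied to the query, value, and output layers.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

config = LoraConfig(
    r=16,                        # rank, as stated in the paper
    lora_alpha=32,               # alpha, as stated in the paper
    lora_dropout=0.05,           # dropout rate, as stated in the paper
    target_modules=["q_proj", "v_proj", "o_proj"],   # assumed module names
    task_type="CAUSAL_LM",
)

base = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = get_peft_model(base, config)
model.print_trainable_parameters()   # only the low-rank adapters are trainable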
Results. The results for the ND-PPI, M-PPI, and C-PPI tasks are summarized in Tables I, II, and III, respectively.

In the ND-PPI prediction task (Table I), the LoRA ensemble achieves the highest predictive accuracy among all models in both LLM settings and the lowest NLL in the LLaMA-3 fine-tuning case. Conversely, Bayesian LoRA demonstrates the best calibration in both scenarios, exhibiting the lowest ECE, and achieves the lowest NLL in the BioMedGPT fine-tuning case. Lastly, the LoRA ensemble reports the highest values for specificity, precision, F1-score, MCC, and AUROC among all models.

In the M-PPI prediction task (Table II), the LoRA ensemble achieves the highest predictive accuracy and lowest NLL in both LLM scenarios, while also attaining the lowest ECE in the LLaMA-3 case. Conversely, Bayesian LoRA achieves the best calibration in the BioMedGPT case and the highest specificity in both scenarios. Finally, the LoRA ensemble outperforms all other models, achieving the best precision, F1-score, MCC, and AUROC values.

In the C-PPI prediction task (Table III), the LoRA ensemble once again achieves the highest predictive accuracy and lowest NLL in both settings, while also attaining the lowest ECE in the BioMedGPT scenario. Bayesian LoRA matches the best predictive accuracy in the BioMedGPT case and achieves the lowest ECE in the LLaMA-3 case. In the LLaMA-3 setting, the LoRA ensemble reports the highest values for specificity, precision, F1-score, MCC, and AUROC among all models; it also achieves the best specificity in the BioMedGPT case. Notably, both Bayesian LoRA and the LoRA ensemble attain the best precision, F1-score, and MCC values in the BioMedGPT case. Lastly, all three models yield identical AUROC values in the BioMedGPT case.
TABLE I
ND-PPI Prediction: the best results among all compared methods for a given pre-trained LLM are highlighted in bold. All metrics are reported as mean ± standard deviation over three independent runs.

LLM Model | Methods | Acc (↑) | NLL (↓) | ECE (↓) | Spec. (↑) | Prec. (↑) | F1 (↑) | MCC (↑) | AUROC (↑)
Llama-3 | Single LoRA | 86.51±0.54 | 0.362±0.036 | 0.095±0.011 | 0.966±0.005 | 0.881±0.005 | 0.863±0.006 | 0.745±0.011 | 0.953±0.004
Llama-3 | LoRA Ensemble | 88.70±0.62 | 0.302±0.002 | 0.088±0.005 | 0.973±0.001 | 0.899±0.004 | 0.886±0.006 | 0.785±0.011 | 0.964±0.001
Llama-3 | Bayesian LoRA | 86.51±0.20 | 0.317±0.003 | 0.052±0.021 | 0.848±0.033 | 0.866±0.002 | 0.865±0.002 | 0.732±0.003 | 0.944±0.004
BioMedGPT | Single LoRA | 85.44±2.16 | 0.539±0.053 | 0.119±0.025 | 0.963±0.016 | 0.874±0.011 | 0.852±0.023 | 0.726±0.034 | 0.944±0.005
BioMedGPT | LoRA Ensemble | 88.00±1.19 | 0.363±0.049 | 0.087±0.022 | 0.965±0.008 | 0.892±0.007 | 0.879±0.013 | 0.771±0.019 | 0.956±0.001
BioMedGPT | Bayesian LoRA | 86.82±0.60 | 0.320±0.012 | 0.033±0.007 | 0.869±0.015 | 0.868±0.006 | 0.868±0.006 | 0.737±0.012 | 0.937±0.003
TABLE II
M-PPI Prediction: the best results among all compared methods for a given pre-trained LLM are highlighted in bold. All metrics are reported as mean ± standard deviation over three independent runs.

LLM Model | Methods | Acc (↑) | NLL (↓) | ECE (↓) | Spec. (↑) | Prec. (↑) | F1 (↑) | MCC (↑) | AUROC (↑)
Llama-3 | Single LoRA | 85.82±0.26 | 0.398±0.016 | 0.084±0.006 | 0.908±0.036 | 0.863±0.007 | 0.858±0.002 | 0.721±0.009 | 0.937±0.002
Llama-3 | LoRA Ensemble | 87.45±0.16 | 0.308±0.013 | 0.051±0.010 | 0.922±0.016 | 0.878±0.003 | 0.874±0.002 | 0.752±0.004 | 0.950±0.002
Llama-3 | Bayesian LoRA | 83.41±1.17 | 0.374±0.005 | 0.071±0.018 | 0.932±0.038 | 0.850±0.003 | 0.832±0.013 | 0.683±0.013 | 0.925±0.004
BioMedGPT | Single LoRA | 83.68±0.54 | 0.542±0.026 | 0.113±0.009 | 0.799±0.047 | 0.840±0.002 | 0.837±0.006 | 0.678±0.007 | 0.925±0.005
BioMedGPT | LoRA Ensemble | 87.14±1.39 | 0.354±0.028 | 0.062±0.010 | 0.864±0.036 | 0.872±0.013 | 0.871±0.014 | 0.744±0.027 | 0.941±0.005
BioMedGPT | Bayesian LoRA | 83.29±0.57 | 0.385±0.015 | 0.037±0.018 | 0.888±0.033 | 0.838±0.003 | 0.832±0.007 | 0.671±0.007 | 0.905±0.004
TABLE III
C-PPI Prediction: the best results among all compared methods for a given pre-trained LLM are highlighted in bold. All metrics are reported as mean ± standard deviation over three independent runs.

LLM Model | Methods | Acc (↑) | NLL (↓) | ECE (↓) | Spec. (↑) | Prec. (↑) | F1 (↑) | MCC (↑) | AUROC (↑)
Llama-3 | Single LoRA | 96.62±0.62 | 0.094±0.011 | 0.033±0.002 | 0.973±0.016 | 0.967±0.007 | 0.966±0.006 | 0.932±0.013 | 0.996±0.002
Llama-3 | LoRA Ensemble | 97.86±0.00 | 0.066±0.005 | 0.029±0.005 | 0.980±0.000 | 0.979±0.000 | 0.979±0.000 | 0.957±0.000 | 0.997±0.000
Llama-3 | Bayesian LoRA | 96.97±1.24 | 0.085±0.020 | 0.027±0.002 | 0.963±0.012 | 0.970±0.012 | 0.970±0.012 | 0.940±0.025 | 0.996±0.002
BioMedGPT | Single LoRA | 97.68±0.82 | 0.059±0.011 | 0.025±0.011 | 0.976±0.006 | 0.977±0.008 | 0.977±0.008 | 0.954±0.016 | 0.998±0.001
BioMedGPT | LoRA Ensemble | 98.40±0.54 | 0.052±0.000 | 0.021±0.002 | 0.980±0.000 | 0.984±0.005 | 0.984±0.005 | 0.968±0.011 | 0.998±0.000
BioMedGPT | Bayesian LoRA | 98.40±0.54 | 0.064±0.005 | 0.031±0.001 | 0.976±0.006 | 0.984±0.005 | 0.984±0.005 | 0.968±0.011 | 0.998±0.001