Post

Building a Protein Variant Classifier with ESM2 and Multi-GPU Training

Building a Protein Variant Classifier with ESM2 and Multi-GPU Training

Introduction

In the field of clinical genomics, accurately predicting whether a specific genetic variant is pathogenic (disease-causing) or benign is a critical challenge. Recently, I worked on a project to develop a deep learning model that classifies protein variants as either Gain-of-Function (GOF) or Loss-of-Function (LOF) using ESM2 (Evolutionary Scale Modeling), a state-of-the-art protein language model.

Challenge 1: Metric Selection for Clinical Use

Before diving into the model, I had to evaluate existing pathogenicity predictors. The dataset contains 107 patients, each with multiple variants where only a few are pathogenic (LABEL=1).

The Problem: Class Imbalance

The data is highly imbalanced—most variants are benign (LABEL=0), only a few are pathogenic (LABEL=1). This makes metric selection critical.

MetricFormulaProblem with Imbalanced Data
Accuracy(TP+TN) / TotalPredicting all as benign gives high accuracy
AUROCArea under TPR-FPR curveCan look good even with poor precision

Why AUROC Alone Is Not Enough

AUROC measures discrimination across all thresholds. A model with AUROC=0.94 sounds great, but:

  • At what threshold does it achieve good Precision and Recall?
  • In clinical diagnostics, False Negatives are dangerous (missing a pathogenic variant)

Metrics for Clinical Pathogenicity Prediction

For this problem, I focused on both classification metrics and ranking metrics:

Classification Metrics (Binary)

MetricFormulaClinical Importance
Recall (Sensitivity)$\frac{TP}{TP + FN}$Must be high: we cannot miss pathogenic variants
Precision (PPV)$\frac{TP}{TP + FP}$Reduces unnecessary follow-up tests
F1 Score$\frac{2 \times Precision \times Recall}{Precision + Recall}$Balances both for imbalanced data

Ranking Metric (Patient-Centric)

Since each patient has multiple variants and we want the pathogenic variant to be ranked high:

MetricFormulaClinical Importance
Top-K Recall$\frac{\text{# patients with causal variant in top K}}{\text{# total patients}}$Measures how often the pathogenic variant appears in the top K predictions

The formal definition:

\[\text{Top-K Recall} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}[\text{rank}(v_i) \leq K]\]

Where:

  • $N$ = number of patients
  • $v_i$ = the pathogenic variant for patient $i$
  • $\text{rank}(v_i)$ = position when variants are sorted by prediction score (descending)
  • $\mathbb{1}[\cdot]$ = indicator function (1 if true, 0 if false)

Why Recall Is Critical

In medical diagnostics, a False Negative (predicting benign when actually pathogenic) means:

  • Patient doesn’t receive treatment
  • Disease progresses undetected

Therefore, Recall must be prioritized, even at the cost of some False Positives.

Evaluation Framework

I evaluated each predictor (A, B, C) with both classification and ranking metrics:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from sklearn.metrics import precision_recall_curve, roc_auc_score
import numpy as np

def evaluate_predictor(y_true: np.ndarray, y_scores: np.ndarray) -> dict:
    """Evaluate predictor with classification metrics."""
    auroc = roc_auc_score(y_true, y_scores)
    
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_scores)
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    best_idx = np.argmax(f1_scores)
    
    return {
        "auroc": auroc,
        "best_f1": f1_scores[best_idx],
        "recall_at_best_f1": recalls[best_idx],
        "precision_at_best_f1": precisions[best_idx],
    }
# end def

def compute_top_k_recall(df: pd.DataFrame, score_col: str, k: int) -> float:
    """Compute Top-K Recall per patient."""
    hits = 0
    for patient_id, group in df.groupby("Patient_ID"):
        sorted_group = group.sort_values(score_col, ascending=False)
        top_k_labels = sorted_group.head(k)["LABEL"].values
        if 1 in top_k_labels:
            hits += 1
        # end if
    # end for
    return hits / df["Patient_ID"].nunique()
# end def

Results

Classification Metrics

PredictorAUROCBest F1Recall @ Best F1Precision @ Best F1
A0.940.420.650.31
B0.880.580.820.45
C0.910.510.710.40

Ranking Metrics (Patient-Centric)

PredictorTop-1 RecallTop-5 Recall
A12%35%
B24%52%
C18%41%

Key Findings:

  1. Predictor A had the highest AUROC but the worst F1, Recall, and Top-K metrics
  2. Predictor B achieved 82% Recall and 52% Top-5 Recall—meaning it catches more pathogenic variants both in classification and ranking

Decision: For clinical use, Predictor B is preferred because:

  1. Highest Recall (minimizes missed pathogenic variants)
  2. Best F1 Score (balanced performance on imbalanced data)
  3. Best Top-5 Recall (pathogenic variant is in top 5 for 52% of patients)

Lesson: In medical AI with class imbalance, evaluate using multiple metrics that reflect clinical consequences—not just AUROC.

Challenge 2: Modeling Protein Variants with ESM2

The core task was to classify variants using esm2_t33_650M_UR50D.

Existing vs. Proposed Approach

A standard approach in this domain often involves feeding the mutant sequence directly into the model to predict its property.

Baseline Architecture Figure 1: Standard Baseline Approach. The model only sees the mutant sequence, making it difficult to learn the specific impact of the mutation relative to the wild-type.

However, simply feeding the mutant sequence isn’t enough. The model needs to understand what changed. I designed the input to explicitly capture the difference:

1
Input = Concat(E_wt, E_mut, E_mut - E_wt)
  • E_wt: Embedding of the Wild-Type sequence
  • E_mut: Embedding of the Mutant sequence
  • Difference: The vector representing the direction of change (Mutant - Wild-Type)

This “Difference Vector” proved crucial for distinguishing between LOF (function loss) and GOF (function gain).

Model Architecture Figure 2: Our Proposed Architecture. By explicitly feeding the difference vector (Mutant - WT), the model can directly focus on the functional shift caused by the variant.

Code Snippet: Model Architecture

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class ESM2VariantClassifier(nn.Module):
    def __init__(self, model_name="facebook/esm2_t33_650M_UR50D"):
        super().__init__()
        self.esm = EsmModel.from_pretrained(model_name)
        # Freeze backbone for efficiency
        for param in self.esm.parameters():
            param.requires_grad = False
            
        hidden_size = self.esm.config.hidden_size
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 3, 512), # 3x input size due to concatenation
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 2)
        )

    def forward(self, wt_ids, wt_mask, mut_ids, mut_mask):
        wt_out = self.esm(input_ids=wt_ids, attention_mask=wt_mask)
        mut_out = self.esm(input_ids=mut_ids, attention_mask=mut_mask)
        
        wt_cls = wt_out.last_hidden_state[:, 0, :]
        mut_cls = mut_out.last_hidden_state[:, 0, :]
        
        diff = mut_cls - wt_cls
        combined = torch.cat((wt_cls, mut_cls, diff), dim=1)
        
        return self.classifier(combined)

Challenge 3: Extreme Class Imbalance

The dataset had a 9:1 imbalance (90% LOF, 10% GOF). A standard model would simply predict “LOF” for everything and achieve 90% accuracy, which is useless.

Solution: Weighted Loss

I used CrossEntropyLoss with class weights inversely proportional to the class frequencies.

1
2
3
# LOF (0): 90%, GOF (1): 10%
class_weights = torch.tensor([0.1, 0.9]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

This forces the model to pay 9x more attention to the minority GOF class, preventing it from being ignored.

Challenge 4: Distributed Training on A100s

To utilize 4x NVIDIA A100 GPUs, I used PyTorch’s DistributedDataParallel (DDP).

Key implementation details:

  1. DistributedSampler: Ensures each GPU gets a different slice of data.
  2. init_process_group: Sets up communication between GPUs.
  3. torchrun: The launcher utility to manage processes.

One interesting hurdle was verifying DDP logic on a single local GPU. I learned that you can use the gloo backend (CPU-based) with torchrun --nproc_per_node=1 to simulate the distributed environment locally before deploying to the expensive cluster.

1
2
# Local verification command
torchrun --nproc_per_node=1 train_script.py --backend gloo

Conclusion

This project reinforced the importance of domain-specific feature engineering (Difference Vector) and robust engineering practices (DDP, Weighted Loss) when working with biological data. By combining pre-trained PLMs with thoughtful architecture, we can build powerful tools for genomic analysis.

This post is licensed under CC BY 4.0 by the author.