Building a Protein Variant Classifier with ESM2 and Multi-GPU Training

Building a protein variant classifier with ESM2: clinical metric selection, a difference-vector architecture, class imbalance, and multi-GPU DDP.

Introduction

In clinical genomics, accurately predicting whether a genetic variant is pathogenic (disease-causing) or benign is a critical challenge. I recently worked on a project that first compared existing pathogenicity predictors and then trained a deep learning model to classify protein variants as Gain-of-Function (GOF) or Loss-of-Function (LOF) using ESM2 (Evolutionary Scale Modeling), a state-of-the-art protein language model.

This post covers two related but distinct tasks with different label spaces, so it is worth separating them up front to avoid conflating their labels:

  • Task A — Pathogenic-variant prioritization & metric selection. A binary pathogenic (LABEL=1) vs benign (LABEL=0) problem, used to evaluate and choose among existing pathogenicity predictors. (Covered in Challenge 1.)
  • Task B — GOF/LOF classifier training. A separate binary GOF vs LOF problem over the variants of interest, where we train our own ESM2-based model. (Covered in Challenges 2-4.)

The two tasks share a class-imbalance theme but do not share labels: a “positive” in Task A is a pathogenic variant, whereas a “positive” in Task B is the minority GOF class.

Challenge 1 (Task A): Metric Selection for Clinical Use

Before diving into the model, I had to evaluate existing pathogenicity predictors. The dataset contains 107 patients, each with multiple variants where only a few are pathogenic (LABEL=1).

The Problem: Class Imbalance

The data is highly imbalanced: most variants are benign (LABEL=0) and only a few are pathogenic (LABEL=1). This makes metric selection critical.

| Metric | Formula | Problem with Imbalanced Data |
|---|---|---|
| Accuracy | (TP+TN) / Total | Predicting all as benign gives high accuracy |
| AUROC | Area under TPR-FPR curve | Can look good even with poor precision |
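To make the accuracy row concrete, here is a minimal sketch with made-up counts (1,000 variants with 20 pathogenic labels are illustrative assumptions, not the project's data). A classifier that labels every variant benign reaches 98% accuracy while finding no pathogenic variant at all.

```python
# Hypothetical counts for illustration only (not the project's dataset).
n_variants = 1000
n_pathogenic = 20

# An "always benign" classifier: every prediction is 0.
y_true = [1] * n_pathogenic + [0] * (n_variants - n_pathogenic)
y_pred = [0] * n_variants

tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))

accuracy = (tp + tn) / n_variants   # dominated by the benign majority
recall = tp / (tp + fn)             # fraction of pathogenic variants found

print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")
```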

Why AUROC Alone Is Not Enough

AUROC measures discrimination across all thresholds. A model with AUROC=0.94 sounds great, but:

  • At what threshold does it achieve good Precision and Recall?
  • In clinical diagnostics, False Negatives are dangerous (missing a pathogenic variant)

Metrics for Clinical Pathogenicity Prediction

For this problem, I focused on both classification metrics and ranking metrics:

Classification Metrics (Binary)

| Metric | Formula | Clinical Importance |
|---|---|---|
| Recall (Sensitivity) | $\frac{TP}{TP + FN}$ | Must be high: we cannot miss pathogenic variants |
| Precision (PPV) | $\frac{TP}{TP + FP}$ | Reduces unnecessary follow-up tests |
| F1 Score | $\frac{2 \times Precision \times Recall}{Precision + Recall}$ | Balances both for imbalanced data |
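As a quick sanity check on the F1 formula, the sketch below recomputes F1 from the precision and recall values reported in the Results section of this post. The helper name is my own; the inputs are the rounded figures from the results table, so agreement is only expected to two decimals.

```python
def f1_from_precision_recall(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# (precision, recall) at the best-F1 threshold, as reported in the results table.
reported = {"A": (0.31, 0.65), "B": (0.45, 0.82), "C": (0.40, 0.71)}

recomputed = {
    name: round(f1_from_precision_recall(p, r), 2)
    for name, (p, r) in reported.items()
}
print(recomputed)
```

Because F1 is a harmonic mean, it is pulled toward the smaller of the two values, which is why a predictor with low precision cannot reach a high F1 on recall alone.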

Ranking Metric (Patient-Centric)

Since each patient has multiple variants and we want the pathogenic variant to be ranked high:

| Metric | Formula | Clinical Importance |
|---|---|---|
| Top-K Recall | $\frac{\text{\# patients with causal variant in top K}}{\text{\# total patients}}$ | Measures how often the pathogenic variant appears in the top K predictions |

The formal definition:

\[\text{Top-K Recall} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}[\text{rank}(v_i) \leq K]\]

Where:

  • $N$ = number of patients
  • $v_i$ = the pathogenic variant for patient $i$
  • $\text{rank}(v_i)$ = position when variants are sorted by prediction score (descending)
  • $\mathbb{1}[\cdot]$ = indicator function (1 if true, 0 if false)
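The definition can be traced by hand on a toy cohort. The four patients and their scores below are invented for illustration; each patient has one pathogenic variant at a known rank, so the indicator is easy to check for each K.

```python
def top_k_recall(patients: dict, k: int) -> float:
    """patients maps a patient id to a list of (score, label) pairs."""
    hits = 0
    for variants in patients.values():
        # Sort by prediction score, highest first, to obtain the ranking.
        ranked = sorted(variants, key=lambda v: v[0], reverse=True)
        # Indicator 1[rank(v_i) <= K]: is a pathogenic variant in the first K?
        if any(label == 1 for _, label in ranked[:k]):
            hits += 1
    return hits / len(patients)

# Hypothetical toy cohort (scores and labels are made up).
toy = {
    "P1": [(0.90, 1), (0.40, 0), (0.20, 0)],              # pathogenic at rank 1
    "P2": [(0.80, 0), (0.70, 0), (0.60, 1)],              # pathogenic at rank 3
    "P3": [(0.50, 0), (0.45, 1), (0.10, 0)],              # pathogenic at rank 2
    "P4": [(0.30, 0), (0.20, 0), (0.10, 0), (0.05, 1)],   # pathogenic at rank 4
}

for k in (1, 2, 3):
    print(k, top_k_recall(toy, k))
```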

Why Recall Is Critical

In medical diagnostics, a False Negative (predicting benign when actually pathogenic) means:

  • Patient doesn’t receive treatment
  • Disease progresses undetected

Therefore, Recall must be prioritized, even at the cost of some False Positives.
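One way to act on this is to fix a minimum recall and accept whatever precision results. The sketch below is an assumption about how that could be done, not the procedure used in this project: it scans thresholds from high to low and returns the first one whose recall meets the target. The scores and labels are invented.

```python
def threshold_for_recall(scores, labels, target_recall):
    """Highest threshold with recall >= target_recall.

    A variant is predicted pathogenic when score >= threshold.
    Returns (threshold, recall, precision).
    """
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("need at least one positive label")
    for thr in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= thr)
        fp = sum(1 for s, y in zip(scores, labels) if y == 0 and s >= thr)
        if tp / n_pos >= target_recall:
            return thr, tp / n_pos, tp / (tp + fp)
    raise ValueError("target_recall must be at most 1.0")

# Hypothetical scores: 3 pathogenic variants among 8.
scores = [0.95, 0.80, 0.70, 0.60, 0.40, 0.30, 0.20, 0.10]
labels = [1, 0, 1, 0, 0, 1, 0, 0]

print(threshold_for_recall(scores, labels, target_recall=0.6))
print(threshold_for_recall(scores, labels, target_recall=1.0))
```

Raising the recall target from 0.6 to 1.0 pushes the threshold down and lets in more false positives, which is the trade-off described above.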

Evaluation Framework

I evaluated each predictor (A, B, C) with both classification and ranking metrics:

from sklearn.metrics import precision_recall_curve, roc_auc_score
import numpy as np
import pandas as pd

def evaluate_predictor(y_true: np.ndarray, y_scores: np.ndarray) -> dict:
    """Evaluate a predictor with AUROC and the metrics at its best-F1 threshold."""
    auroc = roc_auc_score(y_true, y_scores)

    # precisions/recalls have one more entry than thresholds: the final
    # (precision=1, recall=0) point has no threshold and an F1 of 0.
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_scores)
    # The small epsilon avoids division by zero when precision + recall == 0.
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    best_idx = np.argmax(f1_scores)

    return {
        "auroc": auroc,
        "best_f1": f1_scores[best_idx],
        "recall_at_best_f1": recalls[best_idx],
        "precision_at_best_f1": precisions[best_idx],
    }

def compute_top_k_recall(df: pd.DataFrame, score_col: str, k: int) -> float:
    """Fraction of patients with a pathogenic variant among their top-K scores."""
    hits = 0
    for patient_id, group in df.groupby("Patient_ID"):
        # Rank this patient's variants by score, highest first.
        sorted_group = group.sort_values(score_col, ascending=False)
        top_k_labels = sorted_group.head(k)["LABEL"].values
        if 1 in top_k_labels:
            hits += 1
    # The denominator counts every patient, including any without a LABEL=1 variant.
    return hits / df["Patient_ID"].nunique()

Results

Classification Metrics

| Predictor | AUROC | Best F1 | Recall @ Best F1 | Precision @ Best F1 |
|---|---|---|---|---|
| A | 0.94 | 0.42 | 0.65 | 0.31 |
| B | 0.88 | 0.58 | 0.82 | 0.45 |
| C | 0.91 | 0.51 | 0.71 | 0.40 |

Ranking Metrics (Patient-Centric)

| Predictor | Top-1 Recall | Top-5 Recall |
|---|---|---|
| A | 12% | 35% |
| B | 24% | 52% |
| C | 18% | 41% |

Key Findings:

  1. Predictor A had the highest AUROC but the worst F1, Recall, and Top-K metrics
  2. Predictor B achieved 82% Recall and 52% Top-5 Recall—meaning it catches more pathogenic variants both in classification and ranking

Decision: For clinical use, Predictor B is preferred because:

  1. Highest Recall (minimizes missed pathogenic variants)
  2. Best F1 Score (balanced performance on imbalanced data)
  3. Best Top-5 Recall (pathogenic variant is in top 5 for 52% of patients)

Lesson: In medical AI with class imbalance, evaluate using multiple metrics that reflect clinical consequences—not just AUROC.

Challenge 2 (Task B): Modeling Protein Variants with ESM2

The core task was to classify variants using esm2_t33_650M_UR50D.

Existing vs. Proposed Approach

A common approach in this domain is to feed the mutant sequence directly into the model and predict its property.

Figure 1: Standard Baseline Approach. The model only sees the mutant sequence, making it difficult to learn the specific impact of the mutation relative to the wild-type.

However, simply feeding the mutant sequence isn’t enough. The model needs to understand what changed. I designed the input to explicitly capture the difference:

Input = Concat(E_wt, E_mut, E_mut - E_wt)

  • E_wt: Embedding of the Wild-Type sequence
  • E_mut: Embedding of the Mutant sequence
  • E_mut - E_wt (Difference): The vector representing the direction of change (Mutant - Wild-Type)

This “Difference Vector” was a key design choice in my experiments for distinguishing between LOF (function loss) and GOF (function gain).

Figure 2: Our Proposed Architecture. By explicitly feeding the difference vector (Mutant - WT), the model can directly focus on the functional shift caused by the variant.

Code Snippet: Model Architecture

import torch
import torch.nn as nn
from transformers import EsmModel

class ESM2VariantClassifier(nn.Module):
    """Frozen ESM2 backbone with an MLP head over [E_wt, E_mut, E_mut - E_wt]."""

    def __init__(self, model_name="facebook/esm2_t33_650M_UR50D"):
        super().__init__()
        self.esm = EsmModel.from_pretrained(model_name)
        # Freeze backbone for efficiency: only the classifier head is trained
        for param in self.esm.parameters():
            param.requires_grad = False

        hidden_size = self.esm.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 3, 512),  # 3x input size due to concatenation
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 2),  # logits for LOF (0) and GOF (1)
        )

    def forward(self, wt_ids, wt_mask, mut_ids, mut_mask):
        # The backbone is frozen, so there is no need to build its autograd graph
        with torch.no_grad():
            wt_out = self.esm(input_ids=wt_ids, attention_mask=wt_mask)
            mut_out = self.esm(input_ids=mut_ids, attention_mask=mut_mask)

        # CLS-token pooling: take the first position of each sequence
        wt_cls = wt_out.last_hidden_state[:, 0, :]
        mut_cls = mut_out.last_hidden_state[:, 0, :]

        # Difference vector: direction of change from wild-type to mutant
        diff = mut_cls - wt_cls
        combined = torch.cat((wt_cls, mut_cls, diff), dim=1)

        return self.classifier(combined)

Here I pool each sequence using the CLS token (last_hidden_state[:, 0, :]); a common alternative is mean-pooling the hidden states over the non-special (non-CLS/EOS/padding) tokens, which can yield a more stable whole-sequence representation when the CLS token is not specifically trained as a summary.
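For comparison, the mean-pooling alternative could look roughly like the sketch below. It was not part of my experiments. The function name is hypothetical, and `special_tokens_mask` is assumed to come from the tokenizer (for example by passing `return_special_tokens_mask=True`), with 1 marking CLS/EOS/padding positions.

```python
import torch

def masked_mean_pool(hidden, attention_mask, special_tokens_mask=None):
    """Average hidden states over kept positions.

    hidden:              (batch, seq_len, hidden_size) float tensor
    attention_mask:      (batch, seq_len), 1 for real tokens, 0 for padding
    special_tokens_mask: optional (batch, seq_len), 1 for special tokens
    """
    keep = attention_mask.bool()
    if special_tokens_mask is not None:
        # Additionally drop CLS/EOS (and any other special) positions.
        keep = keep & ~special_tokens_mask.bool()
    keep = keep.unsqueeze(-1).to(hidden.dtype)   # (batch, seq_len, 1)
    summed = (hidden * keep).sum(dim=1)          # zero out dropped positions
    counts = keep.sum(dim=1).clamp(min=1.0)      # avoid division by zero
    return summed / counts

# Toy sequence: CLS, two residues, EOS, one padding position.
hidden = torch.tensor([[[10.0, 10.0], [2.0, 4.0], [4.0, 6.0], [20.0, 20.0], [99.0, 99.0]]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
special_tokens_mask = torch.tensor([[1, 0, 0, 1, 1]])

pooled = masked_mean_pool(hidden, attention_mask, special_tokens_mask)
print(pooled)  # averages only the two residue positions
```

In the classifier above, this would replace the two `[:, 0, :]` lines, with the rest of the forward pass unchanged.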

Challenge 3 (Task B): Extreme Class Imbalance

This imbalance is on Task B’s GOF/LOF label space (distinct from the pathogenic/benign labels of Task A). The dataset had a 9:1 imbalance (90% LOF, 10% GOF). A standard model would simply predict “LOF” for everything and achieve 90% accuracy, which is useless.

Solution: Weighted Loss

I used CrossEntropyLoss with class weights inversely proportional to the class frequencies.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LOF (0): 90%, GOF (1): 10%
class_weights = torch.tensor([0.1, 0.9]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

With these weights, each GOF example's loss term is scaled 9x relative to an LOF example's (0.9 vs 0.1), so the minority GOF class is not ignored during training.
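The hard-coded `[0.1, 0.9]` can also be derived from class counts rather than typed in. The sketch below is a minimal version of that calculation; the helper name and the 900/100 counts are assumptions that simply match the 9:1 ratio described above.

```python
def inverse_frequency_weights(class_counts):
    """Weights proportional to 1 / class frequency, normalized to sum to 1."""
    total = sum(class_counts)
    inverse = [total / count for count in class_counts]  # 1 / frequency
    norm = sum(inverse)
    return [w / norm for w in inverse]

# Hypothetical counts in a 9:1 ratio: index 0 = LOF, index 1 = GOF.
weights = inverse_frequency_weights([900, 100])
print([round(w, 3) for w in weights])
```

The resulting list can be passed to `torch.tensor(...)` and then to `nn.CrossEntropyLoss(weight=...)` as in the snippet above. Only the ratio between the weights matters for the relative emphasis on each class.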

Challenge 4 (Task B): Distributed Training on A100s

To utilize 4x NVIDIA A100 GPUs, I used PyTorch’s DistributedDataParallel (DDP).

Key implementation details:

  1. DistributedSampler: Ensures each GPU gets a different slice of data.
  2. init_process_group: Sets up communication between GPUs.
  3. torchrun: The launcher utility to manage processes.
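To show what "a different slice of data" means for item 1, here is a pure-Python sketch that mimics the index arithmetic of `DistributedSampler` as I understand its defaults, with shuffling turned off and no dropped samples. It is an illustration, not PyTorch's implementation, and the function name is my own.

```python
import math

def shard_indices(dataset_len: int, world_size: int, rank: int) -> list:
    """Indices one rank would see, without shuffling."""
    # Every rank gets the same number of samples, so round up.
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    # Pad by wrapping around to the start of the dataset.
    indices = [i % dataset_len for i in range(total_size)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total_size:world_size]

# 10 samples across 4 processes, as on a 4-GPU node.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

With 10 samples and 4 ranks, two samples are repeated so that every rank runs the same number of steps. In real training the sampler also reshuffles each epoch when `set_epoch` is called.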

One interesting hurdle was verifying DDP logic on a single local GPU. I learned that you can use the gloo backend with torchrun --nproc_per_node=1 to simulate the distributed environment locally before deploying to the expensive cluster.

A caveat on backends: PyTorch recommends nccl as the default backend for distributed GPU training, while gloo is intended for CPU (or a local single-process smoke test). So the local gloo command below is purely for logic verification, not the real 4xA100 run—which should use nccl.

# Local verification command
torchrun --nproc_per_node=1 train_script.py --backend gloo

What Didn’t Work / Limitations

This was a small-scale study, so the results are a proof of concept rather than a validated clinical tool:

  • Tiny evaluation set. Task A’s predictor comparison uses 107 patients with only a few pathogenic variants each, so the AUROC/F1/Top-K gaps between predictors A–C carry wide uncertainty — no confidence intervals or significance tests are reported.
  • Frozen backbone. ESM2 is used purely as a feature extractor (backbone frozen, only the head trained), which caps how much variant-specific signal the model can capture; fine-tuning or LoRA was not compared.
  • Static class weighting only. The 9:1 GOF/LOF imbalance is handled with fixed inverse-frequency weights; resampling, focal loss, and threshold calibration were not benchmarked against it.
  • Pooling not ablated. CLS-token pooling is used; as noted above, mean-pooling may give a more stable representation, but the two were not compared head-to-head.
  • No external validation. Generalization to other cohorts and a leakage-safe held-out split (so variants from the same patient don’t span train/test) are not established here.
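On the first limitation, a patient-level bootstrap is one way the missing confidence intervals could be estimated. The sketch below was not run on the project data. The hit vector is hypothetical: 56 hits out of 107 patients is one count consistent with the reported 52% Top-5 Recall, not a figure taken from the results.

```python
import random

def bootstrap_ci(per_patient_hits, n_boot=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for a per-patient hit rate (e.g. Top-K Recall)."""
    rng = random.Random(seed)
    n = len(per_patient_hits)
    stats = []
    for _ in range(n_boot):
        # Resample patients with replacement and recompute the hit rate.
        sample = [per_patient_hits[rng.randrange(n)] for _ in range(n)]
        stats.append(sum(sample) / n)
    stats.sort()
    lower = stats[int((alpha / 2) * n_boot)]
    upper = stats[int((1 - alpha / 2) * n_boot) - 1]
    return sum(per_patient_hits) / n, lower, upper

# Hypothetical per-patient outcomes: 1 = pathogenic variant in the top 5.
hits = [1] * 56 + [0] * 51
point, lower, upper = bootstrap_ci(hits)
print(f"Top-5 Recall = {point:.3f}, 95% CI = [{lower:.3f}, {upper:.3f}]")
```

With about 107 patients, the normal approximation puts the standard error near 5 percentage points, so differences of a few points between predictors should be read with caution.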

Conclusion

This project reinforced the importance of domain-specific feature engineering (Difference Vector) and robust engineering practices (DDP, Weighted Loss) when working with biological data. By combining pre-trained PLMs with thoughtful architecture, we can build powerful tools for genomic analysis.

Setup (for reproducibility). Model: facebook/esm2_t33_650M_UR50D (HuggingFace transformers, backbone frozen). Hardware: 4× NVIDIA A100, PyTorch DistributedDataParallel (nccl) launched with torchrun. Data: 107 patients (Task A); ~9:1 LOF/GOF split (Task B). This was a private coursework project, so the code isn’t public — the stack above and the in-post snippets are the reproduction cues.

This post is licensed under CC BY 4.0 by the author.