Assignment 1 / Text / Model Backbone

Model Backbone

Three models built on pre-trained BERT — a fully fine-tuned BERT classifier, a BiLSTM with frozen BERT embeddings, and a learned ensemble meta-learner.


Backbone Visualizations

The diagrams below illustrate the two main feature extractors used in the text pipeline.

BERT Backbone Architecture

Diagram of the BERT backbone architecture
The BERT classifier uses bert-base-uncased with WordPiece tokenization, input embeddings, and 12 stacked Transformer encoder layers. The pooled [CLS] representation is passed to a task-specific classifier head.

BiLSTM + Frozen BERT Embeddings

Diagram of the BiLSTM with frozen BERT embeddings
This branch freezes the BERT embedding module, sends the 768-dimensional token features through a two-layer bidirectional LSTM, applies learned attention pooling over time, and finishes with a compact classifier head.

Architecture Overview

Advanced

BERT Classifier

  • Backbone: bert-base-uncased — 12-layer Transformer, d_model=768, 12 heads
  • Loading: AutoModelForSequenceClassification.from_pretrained() — full fine-tuning
  • Classifier head: BERT pooler output → Dropout(0.3) → Linear(768→20)
  • Tokenization: WordPiece, max_length=256, [CLS]/[SEP]/[PAD] tokens
  • All 109.5 M parameters are updated during fine-tuning (no freezing)
  • AMP (FP16): enabled with torch.amp.GradScaler for RTX GPU training
  • Gradient accumulation: 2 steps → effective batch = 32
109,497,620 parameters (fully trainable)
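The head on top of the pooled [CLS] output can be sketched in a few lines. This is an illustrative stand-in, not the assignment's actual code: in the real pipeline the 768-dim pooled vector comes from `AutoModelForSequenceClassification` with `bert-base-uncased`, whereas here a random tensor plays that role.

```python
import torch
import torch.nn as nn

class ClassifierHead(nn.Module):
    """Pooler output (768) -> Dropout(0.3) -> Linear(768 -> 20), as listed above."""

    def __init__(self, d_model: int = 768, num_classes: int = 20, p_drop: float = 0.3):
        super().__init__()
        self.dropout = nn.Dropout(p_drop)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        return self.fc(self.dropout(pooled))

head = ClassifierHead()
pooled = torch.randn(16, 768)   # stand-in for the BERT pooler output
logits = head(pooled)
print(logits.shape)             # torch.Size([16, 20])
```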
Baseline

BiLSTM + Frozen BERT Embeddings

  • Embedding: full BERT embedding module (bert.embeddings) — 768-dim, completely frozen
  • Frozen params: 23,837,184 (word + position + token-type embeddings + LayerNorm)
  • Recurrent layer: 2-layer bidirectional nn.LSTM, hidden=384 → 768-dim output
  • Variable-length input: pack_padded_sequence respects attention mask
  • Attention pooling: learned scalar attention over LSTM timesteps → softmax-weighted sum
  • Classifier: LayerNorm → Linear(768→384) → GELU → Dropout(0.3) → Linear(384→20)
  • AMP off — BERT LayerNorm uses eps=1e-12 (too small for FP16); trained in FP32
7,395,477 trainable · 23,837,184 frozen
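A minimal sketch of this branch, assuming the spec above (2-layer BiLSTM with hidden=384, packed variable-length input, learned scalar attention pooling, and the listed classifier head). Random features stand in for the frozen 768-dim BERT embeddings, and all class and attribute names here are illustrative:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class BiLSTMClassifier(nn.Module):
    def __init__(self, d_in: int = 768, hidden: int = 384, num_classes: int = 20):
        super().__init__()
        self.lstm = nn.LSTM(d_in, hidden, num_layers=2,
                            bidirectional=True, batch_first=True)
        self.attn = nn.Linear(2 * hidden, 1)   # learned scalar attention score per timestep
        self.head = nn.Sequential(
            nn.LayerNorm(2 * hidden),
            nn.Linear(2 * hidden, hidden), nn.GELU(),
            nn.Dropout(0.3), nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        # pack so the LSTM never sees padded timesteps
        packed = pack_padded_sequence(x, lengths, batch_first=True,
                                      enforce_sorted=False)
        out, _ = self.lstm(packed)
        out, _ = pad_packed_sequence(out, batch_first=True)   # (B, T, 768)
        scores = self.attn(out).squeeze(-1)                   # (B, T)
        # mask padded positions before the softmax
        mask = torch.arange(out.size(1))[None, :] >= lengths[:, None]
        scores = scores.masked_fill(mask, float("-inf"))
        weights = scores.softmax(dim=-1)                      # (B, T)
        pooled = (weights.unsqueeze(-1) * out).sum(dim=1)     # (B, 768)
        return self.head(pooled)

model = BiLSTMClassifier()
x = torch.randn(4, 10, 768)                 # stand-in for frozen BERT embeddings
lengths = torch.tensor([10, 7, 5, 3])       # derived from the attention mask
print(model(x, lengths).shape)              # torch.Size([4, 20])
```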
Ensemble

LearnedEnsemble (meta-MLP)

  • Backbones: frozen BERT classifier + frozen BiLSTM — logits pre-computed, no gradient flows into them
  • Meta-input: softmax-normalized logits from both models, concatenated → shape (B, 40)
  • Meta-MLP: Linear(40→80) → GELU → Dropout(0.15) → Linear(80→40) → GELU → Dropout(0.10) → Linear(40→20)
  • Training on cached logits: both backbone logits pre-computed on GPU and stored on CPU — each ensemble epoch takes < 1 s
  • Trainable params: 7,340 (meta-MLP only) — checkpoint is < 100 KB
7,340 trainable (meta only)
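The meta-MLP is small enough to write out directly. This sketch follows the layer spec above, with random probabilities standing in for the cached backbone outputs; the arithmetic comment shows where the 7,340 trainable-parameter figure comes from.

```python
import torch
import torch.nn as nn

# Meta-MLP over the concatenated 20+20 = 40 per-class outputs of the two backbones.
meta_mlp = nn.Sequential(
    nn.Linear(40, 80), nn.GELU(), nn.Dropout(0.15),
    nn.Linear(80, 40), nn.GELU(), nn.Dropout(0.10),
    nn.Linear(40, 20),
)

# (40*80 + 80) + (80*40 + 40) + (40*20 + 20) = 3280 + 3240 + 820 = 7340
n_params = sum(p.numel() for p in meta_mlp.parameters())
print(n_params)  # 7340

bert_probs = torch.randn(8, 20).softmax(dim=-1)  # stand-in for cached BERT outputs
lstm_probs = torch.randn(8, 20).softmax(dim=-1)  # stand-in for cached BiLSTM outputs
fused = meta_mlp(torch.cat([bert_probs, lstm_probs], dim=-1))
print(fused.shape)  # torch.Size([8, 20])
```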

Architecture Comparison

| Property | BERT Classifier | BiLSTM | LearnedEnsemble |
| --- | --- | --- | --- |
| Total parameters | 109,497,620 | 31,232,661 | 140,730,281 |
| Trainable params (own stage) | 109,497,620 | 7,395,477 | 7,340 |
| Pre-trained backbone | bert-base-uncased (full) | bert-base-uncased (embeddings only, frozen) | both of the above, frozen |
| Sequence modelling | 12-layer self-attention (parallel) | 2-layer BiLSTM (sequential) | none (meta-MLP on cached logits) |
| Pooling strategy | BERT [CLS] pooler output | Learned attention-weighted sum | Direct logit fusion |
| Mixed precision | AMP FP16 ✓ | FP32 (LayerNorm eps stability) | FP32 |
| Training speed (per epoch) | ~160 s (RTX 3060 6 GB) | ~74 s (RTX 3060 6 GB) | < 1 s (cached logits) |
| Best test accuracy | 70.23% | 68.00% | 69.48% |

Why These Architectures?

  • Fine-tuned BERT is the natural "advanced" choice — pre-trained contextual representations and a shared tokenizer give a massive head-start over random initialisation.
  • BiLSTM with frozen BERT embeddings is the "baseline" — it reuses high-quality 768-dim token representations without the cost of fine-tuning the full Transformer stack; only the LSTM and classifier are learned.
  • Ensemble meta-MLP tests whether logit-level complementarity between BERT's attention-based view and the BiLSTM's sequential view can be exploited with minimal extra compute.

Key Engineering Decisions

  • AMP disabled for BiLSTM — BERT's embedding LayerNorm uses eps=1e-12, which underflows FP16 (~6e-8), producing NaN gradients. Training in FP32 is safe and fast enough within 6 GB VRAM.
  • Gradient accumulation (2 steps) for BERT keeps effective batch at 32 while fitting the 16-sample per-step budget within VRAM.
  • Pre-computed logits for ensemble — both backbones are moved to GPU one at a time (BERT then BiLSTM) to collect logits, then moved back to CPU. The meta-MLP trains on CPU-cached tensors — no VRAM needed for ensemble training.
  • Separate weight-decay groups for BERT — bias/LayerNorm params are excluded from L2 penalty (standard BERT fine-tuning practice).
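The FP16 underflow behind the first decision is easy to verify. FP16's smallest positive subnormal is 2⁻²⁴ ≈ 6e-8, so an eps of 1e-12 flushes to exactly zero when cast to half precision; with eps = 0, the LayerNorm denominator √(var + eps) can vanish for near-constant inputs, producing inf/NaN activations and gradients:

```python
import torch

# BERT's embedding LayerNorm uses eps = 1e-12.
eps = torch.tensor(1e-12)
print(eps.half().item())                       # 0.0 -- underflows in FP16

# The smallest positive FP16 subnormal (~5.96e-8) is still representable.
print(torch.tensor(6e-8).half().item() > 0)    # True
```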