LLM Knowledge Distillation¶
~3 minute read
Prerequisites: basic familiarity with quantization and loss functions
Why It Matters¶
GPT-4 knows everything, but it costs $30/M tokens and needs a GPU cluster. You need a model for a single task -- say, ticket classification. Distillation lets you "pour" knowledge from a giant model into a small one that runs ~10x cheaper and ~5x faster while keeping 90-95% of the quality.
Analogy: the teacher model knows the entire physics course. The student model does not copy the textbook -- it learns from the teacher's explanations. When the teacher says "probably answer A, but B is also plausible", that carries more information than a bare "answer A". These nuances ("dark knowledge") are the essence of distillation.
Teacher-Student Framework¶
graph LR
T["Teacher<br/>(Large Model)"] -->|"Soft Labels<br/>p=[0.92, 0.05, 0.03]"| S["Student<br/>(Small Model)"]
GT["Ground Truth"] -->|"Hard Labels<br/>y=[1, 0, 0]"| S
S -->|"Combined Loss"| L["L = alpha*CE + (1-alpha)*KL"]
style T fill:#e8eaf6,stroke:#3f51b5
style S fill:#e8f5e9,stroke:#4caf50
style L fill:#fff3e0,stroke:#ef6c00
Key insight: when the teacher predicts "Paris" at 92%, "Lyon" at 5%, "France" at 3%, those 5% on "Lyon" teach the student about plausible alternatives. The hard label "Paris" alone conveys none of that.
KD Loss Function¶
\(\mathcal{L}_{KD} = \alpha \cdot \mathcal{L}_{CE} + (1 - \alpha) \cdot T^2 \cdot \mathcal{L}_{KL}\)
- \(\mathcal{L}_{CE}\) -- cross-entropy with ground truth (hard labels)
- \(\mathcal{L}_{KL}\) -- KL divergence between teacher and student (soft labels)
- \(\alpha\) -- balance between hard and soft targets (typically 0.1-0.5)
- \(p_t, p_s\) -- softened probability distributions
Temperature Scaling¶
| Temperature | Effect | Use Case |
|---|---|---|
| T = 1 | Standard softmax | Baseline |
| T = 2-5 | Softer distribution | Typical KD |
| T = 10+ | Very soft | Capturing dark knowledge |
Hinton's original paper used temperatures up to T = 20.
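A minimal sketch of what the table above means in practice: the same logits pushed through softmax at different temperatures (the logits here are made up for illustration):
import torch
import torch.nn.functional as F

logits = torch.tensor([5.0, 2.5, 1.0])  # e.g. Paris / Lyon / France

for T in (1, 4, 20):
    probs = F.softmax(logits / T, dim=-1)
    print(T, [round(p, 3) for p in probs.tolist()])

# T=1  -> [0.909, 0.075, 0.017]  (peaked: nearly a hard label)
# T=4  -> [0.525, 0.281, 0.193]  (dark knowledge becomes visible)
# T=20 -> [0.370, 0.327, 0.303]  (nearly uniform: the signal starts to wash out)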
5-Step Process¶
- Select Teacher Model -- pre-trained large model, stays frozen
- Design Student Architecture -- notably smaller, retains learning capacity
- Generate Soft Labels -- run calibration data through the teacher (see the sketch after this list)
- Train with Combined Loss -- \(\alpha \cdot L_{CE} + (1-\alpha) \cdot L_{KL}\)
- Validate Performance -- compare: teacher (upper bound), student baseline, student + KD
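A sketch of step 3 -- caching teacher logits offline, so training the student does not require keeping the teacher in GPU memory (`calibration_loader` and the output path are placeholders):
import torch

@torch.no_grad()
def cache_teacher_logits(teacher, calibration_loader, path="teacher_logits.pt"):
    teacher.eval()
    cached = []
    for input_ids, _ in calibration_loader:
        # Full logits are stored here for simplicity; top-k truncation is a common
        # space-saving variation
        cached.append(teacher(input_ids).logits.cpu())
    torch.save(cached, path)
    return path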
Types of Knowledge Distillation¶
A. Response-based (Vanilla KD)¶
Focus: the teacher's output logits.
graph LR
TL["Teacher Logits"] -->|"T=4"| SL["Soft Labels"]
SL --> CL["Combined Loss"]
HL["Hard Labels"] --> CL
CL --> ST["Student Update"]
style SL fill:#e8eaf6,stroke:#3f51b5
style HL fill:#fff3e0,stroke:#ef6c00
Pros: simple and effective. Cons: discards the teacher's intermediate representations.
B. Feature-based (Layer-wise)¶
Focus: intermediate hidden states, matched with an MSE loss:
\(\mathcal{L}_{feat} = \mathrm{MSE}(W h_s,\ h_t)\)
where \(W\) is a projection matrix for dimension alignment (teacher: 4096, student: 2048).
graph TD
TL["Teacher Layer i<br/>h_t = activations"] --> MSE["MSE Loss"]
SL["Student Layer j<br/>h_s = activations"] -->|"Projection W"| MSE
MSE --> TOTAL["Total Loss =<br/>L_KD + lambda * SUM(L_layer)"]
style TL fill:#e8eaf6,stroke:#3f51b5
style SL fill:#e8f5e9,stroke:#4caf50
style MSE fill:#fff3e0,stroke:#ef6c00
Pros: preserves internal representations. Cons: requires handling the dimension mismatch.
C. Relation-based¶
Focus: relationships between samples.
Pros: captures inter-sample relationships. Cons: computational overhead.
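A minimal sketch of the idea: instead of matching individual outputs, match the pairwise similarity structure of a batch (the cosine-similarity formulation below is one common choice, not any specific paper's exact loss):
import torch
import torch.nn.functional as F

def relation_kd_loss(student_emb: torch.Tensor, teacher_emb: torch.Tensor) -> torch.Tensor:
    # student_emb: [batch, d_s], teacher_emb: [batch, d_t] -- dimensions may differ,
    # because only the batch x batch relation matrices are compared
    s = F.normalize(student_emb, dim=-1)
    t = F.normalize(teacher_emb, dim=-1)
    rel_s = s @ s.T  # sample-to-sample similarity as the student sees it
    rel_t = t @ t.T  # the same relations as the teacher sees them
    return F.mse_loss(rel_s, rel_t)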
D. Attention Transfer¶
Match attention matrices per layer. Captures how teacher "attends" to input tokens.
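A sketch of per-layer attention matching; it assumes both models expose attention maps (e.g. `output_attentions=True` in HuggingFace) and averages over heads to sidestep head-count mismatch -- that averaging is an assumption, not a fixed recipe:
import torch.nn.functional as F

def attention_transfer_loss(student_attns, teacher_attns, layer_mapping):
    # student_attns / teacher_attns: tuples of [batch, heads, seq, seq] maps per layer
    loss = 0.0
    for s_layer, t_layer in layer_mapping.items():
        a_s = student_attns[s_layer].mean(dim=1)  # average over heads -> [batch, seq, seq]
        a_t = teacher_attns[t_layer].mean(dim=1)
        loss = loss + F.mse_loss(a_s, a_t)
    return loss / len(layer_mapping)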
LLM-Specific Methods¶
1. Feature Dynamics Alignment (ACL 2025)¶
The feature transition is defined as \(\Delta h = h_{layer+1} - h_{layer}\); the loss aligns teacher and student transitions rather than raw activations.
Key insight: match not just features, but how features change across layers.
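An illustrative sketch of that insight (a schematic form, not the paper's exact objective): compute layer-to-layer transitions for both models and match them, projecting the student to the teacher's width:
import torch
import torch.nn.functional as F

def feature_dynamics_loss(student_hiddens, teacher_hiddens, proj):
    # hiddens: lists of [batch, seq, dim] activations per layer;
    # proj: trainable nn.Linear mapping student dim -> teacher dim
    delta_s = [h2 - h1 for h1, h2 in zip(student_hiddens[:-1], student_hiddens[1:])]
    delta_t = [h2 - h1 for h1, h2 in zip(teacher_hiddens[:-1], teacher_hiddens[1:])]
    # A 1:1 layer pairing is assumed here; in practice layers are mapped as in
    # feature-based KD (zip truncates to the shorter list)
    losses = [F.mse_loss(proj(ds), dt) for ds, dt in zip(delta_s, delta_t)]
    return torch.stack(losses).mean()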
2. TAID (Temporally Adaptive Interpolated Distillation, 2025)¶
def taid_loss(student_logits, teacher_logits, labels, t, schedule):
    # Schematic: kl_divergence, cross_entropy and schedule stand in for concrete implementations
    kd_loss = kl_divergence(student_logits, teacher_logits, T=4)  # soft-label term
    ce_loss = cross_entropy(student_logits, labels)               # hard-label term
    alpha = schedule(t)  # time-dependent weight: shifts over the course of training
    return alpha * ce_loss + (1 - alpha) * kd_loss
Adaptively adjusts distillation weight over training.
3. Progressive Distillation¶
Phase 1: Distill from full teacher
Phase 2: Fine-tune student
Phase 3: Optional iterative refinement
4. LightPAFF (Nature, 2025)¶
Pruning and KD combined into a single training objective.
"Efficient self-attention with smart pruning for sustainable LLMs" (22 citations).
5. Regressor-free Intermediate Distillation (OpenReview 2025)¶
Problem: dimension mismatch in intermediate layers. Solution: prune the teacher down to the student's dimensions (instead of using projection layers).
6. Membership-aware KD (arXiv 2508.07054)¶
Finding: KD can leak training data membership. Mitigation: regularization during distillation.
Pruning + Distillation Pipeline¶
NVIDIA NeMo Framework¶
graph TD
ORIG["Original Model"] --> P["Step 1: Pruning<br/>Remove unimportant weights/heads"]
P --> KD["Step 2: Fine-tuning with KD<br/>Original = Teacher, Pruned = Student"]
KD --> RES["Result: Smaller model<br/>with similar performance"]
style ORIG fill:#e8eaf6,stroke:#3f51b5
style P fill:#fce4ec,stroke:#c62828
style KD fill:#e8f5e9,stroke:#4caf50
style RES fill:#fff3e0,stroke:#ef6c00
Pruning Types¶
| Type | What Removed | Recovery with KD |
|---|---|---|
| Magnitude pruning | Individual weights | 85-90% |
| Structured pruning | Neurons/heads/layers | 90-95% |
| Attention head pruning | Per head | 92-97% |
| Layer dropping | Entire layers | 88-93% |
| Depth pruning | Entire layers | KD essential |
| Width pruning | Neurons, heads, channels | KD + fine-tuning |
Teacher-Guided Pruning (2025)¶
| Approach | Process | Result |
|---|---|---|
| Standard | Prune -> Accuracy Drop -> KD Recovery | Good |
| Guided | Prune + Teacher -> Minimal Drop -> KD Polish | Better |
Guided pruning uses the teacher during pruning, not only afterwards.
KD Recovery Rates¶
| Pruning Level | Without KD | With KD |
|---|---|---|
| 25% | 92% accuracy | 98% accuracy |
| 50% | 78% accuracy | 94% accuracy |
| 75% | 55% accuracy | 85% accuracy |
P-KD-Q Sequence (Optimal Compression Order)¶
2025 Research Finding (arXiv 2511.19495):
Pruning -> Distillation -> Quantization (P-KD-Q) performed best among tested sequences, yielding the best balance for compression with preserved capabilities.
| Step | Purpose |
|---|---|
| Pruning (P) | Remove redundant parameters, establish structural foundation |
| Distillation (D) | Retrain to recover capabilities, optimize remaining parameters |
| Quantization (Q) | Apply to already-optimized architecture |
Critical Finding: sequences quantizing before distillation saw perplexity jump by an order of magnitude.
| Sequence | Performance |
|---|---|
| P-KD-Q | Best balance |
| KD-P-Q | Good |
| P-Q-KD | Poor (perplexity spike) |
| Q-KD-P | Worst (quality collapse) |
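The ordering, as a schematic pipeline (the `structured_prune`, `distill`, and `quantize` helpers are hypothetical placeholders for whatever toolchain you use, e.g. NeMo for pruning/KD and a PTQ library for the last step):
def compress_p_kd_q(model, calib_data):
    teacher = model                                   # the original model acts as teacher
    pruned = structured_prune(model, sparsity=0.5)    # P: fix the final architecture first
    student = distill(teacher=teacher, student=pruned, data=calib_data)  # KD: recover quality
    return quantize(student, dtype="int8")            # Q: quantize the already-optimized weights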
Four Compression Techniques (Overview)¶
| Technique | What It Does | Speedup | Quality Impact |
|---|---|---|---|
| Quantization | Reduce precision (FP32 -> INT4) | Multi-fold | Modest |
| Pruning | Remove unnecessary weights | Varies | Minimal with proper retraining |
| Distillation | Teacher -> Student transfer | 2-3x | 97% retention |
| LoRA | Freeze base, train adapters | Minimal | 90-95% of full FT |
LoRA (Low-Rank Adaptation)¶
\(W' = W_0 + \Delta W = W_0 + BA\), where \(B \in \mathbb{R}^{d \times r}\), \(A \in \mathbb{R}^{r \times k}\), rank \(r \ll d\); only \(A\) and \(B\) are trained, \(W_0\) stays frozen.
from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # "c_attn" below is GPT-2's attention projection

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8, lora_alpha=32, lora_dropout=0.1,
    target_modules=["c_attn"],
)
lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()
# Result: >99% reduction in trainable parameters
TinyLlama Case Study¶
Architecture¶
| Component | LLaMA-2-7B (Teacher) | TinyLlama-1.1B (Student) |
|---|---|---|
| Layers | 32 | 22 |
| Hidden dim | 4096 | 2048 |
| Attention heads | 32 | 32 |
| Parameters | 7B | 1.1B |
| Compression ratio | -- | 6.4x |
| Training data | -- | 3T tokens |
Results¶
| Metric | LLaMA-2-7B | TinyLlama | TinyLlama + KD |
|---|---|---|---|
| MMLU | 45.3 | 35.1 | 38.2 |
| HellaSwag | 78.6 | 68.2 | 71.5 |
| ARC-C | 49.2 | 38.5 | 41.8 |
Key Techniques¶
- Pre-training with teacher guidance
- Intermediate layer matching
- Attention transfer
- Large-scale data augmentation
Deployment Scenarios¶
By Platform¶
| Platform | Recommended Approach |
|---|---|
| Cloud | Quantization + semantic caching |
| Edge/Mobile | Distillation + quantization |
| IoT/Embedded | P-KD-Q full stack |
Real-World Applications¶
| Use Case | Requirements | Solution |
|---|---|---|
| Real-time chat | Low latency | Distilled model |
| Autocomplete | Sub-100ms | Tiny model + caching |
| Document processing | High throughput | Quantized model |
| On-device AI | Memory constrained | P-KD-Q stack |
Infrastructure Optimization (Three-Layer Stack)¶
- Semantic Caching: reduce redundant LLM calls
- Vector Search: RAG context retrieval
- Distilled Models: fast inference on cache misses
Redis Benchmark: 90% precision, 200ms median latency, top 100 nearest neighbors, 50 concurrent queries (billion-scale).
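A toy sketch of the first layer of that stack: cache lookup by embedding similarity, falling back to a distilled model on a miss. `embed_fn` and `distilled_model` are hypothetical stand-ins; a production setup would typically use Redis or another vector store instead of an in-memory list:
import torch
import torch.nn.functional as F

class SemanticCache:
    def __init__(self, embed_fn, threshold: float = 0.9):
        self.embed = embed_fn            # hypothetical: text -> 1-D torch.Tensor embedding
        self.threshold = threshold
        self.keys, self.values = [], []

    def get(self, query: str):
        if not self.keys:
            return None
        q = F.normalize(self.embed(query), dim=-1)
        sims = torch.stack([q @ k for k in self.keys])   # cosine similarity to cached queries
        best = int(sims.argmax())
        return self.values[best] if sims[best] >= self.threshold else None

    def put(self, query: str, answer: str):
        self.keys.append(F.normalize(self.embed(query), dim=-1))
        self.values.append(answer)

def answer(query, cache, distilled_model):
    cached = cache.get(query)
    if cached is not None:
        return cached                    # cache hit: no model call at all
    result = distilled_model(query)      # cache miss: fast distilled model, not the teacher
    cache.put(query, result)
    return result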
Implementation¶
Basic KD Training Loop (PyTorch)¶
import torch
import torch.nn as nn
import torch.nn.functional as F
def distillation_loss(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
temperature: float = 4.0,
alpha: float = 0.5,
) -> torch.Tensor:
# Soft labels (KD loss)
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
kd_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean')
kd_loss *= temperature ** 2 # Scale by T^2
# Hard labels (CE loss)
ce_loss = F.cross_entropy(student_logits, labels)
return alpha * ce_loss + (1 - alpha) * kd_loss
def train_with_distillation(teacher, student, dataloader, optimizer, T=4.0, alpha=0.5):
    teacher.eval()
    student.train()
    for batch in dataloader:
        input_ids, labels = batch
        with torch.no_grad():
            teacher_logits = teacher(input_ids).logits
        student_logits = student(input_ids).logits
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) so the token-level losses above apply
        vocab = student_logits.size(-1)
        loss = distillation_loss(
            student_logits.view(-1, vocab),
            teacher_logits.view(-1, vocab),
            labels.view(-1),
            temperature=T,
            alpha=alpha,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Layer-wise Distillation¶
def layerwise_distillation_loss(
    student_hiddens: list,
    teacher_hiddens: list,
    layer_mapping: dict,          # {student_layer: teacher_layer}
    projections: nn.ModuleDict,   # trainable projections keyed by student layer, created once
) -> torch.Tensor:
    total_loss = 0.0
    for s_layer, t_layer in layer_mapping.items():
        s_hidden = student_hiddens[s_layer]
        t_hidden = teacher_hiddens[t_layer]
        if s_hidden.size(-1) != t_hidden.size(-1):
            # Project student hidden states to the teacher's width; the projection must be
            # defined once and trained with the student (re-creating nn.Linear per step
            # would give random, unlearned weights)
            s_hidden = projections[str(s_layer)](s_hidden)
        total_loss = total_loss + F.mse_loss(s_hidden, t_hidden)
    return total_loss / len(layer_mapping)
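Usage sketch for the function above, reusing `student`, `teacher` and `input_ids` from the training loop: projections are built once (here student 2048 -> teacher 4096, the dimensions from the feature-based section) and optimized together with the student.
# Illustrative student -> teacher layer pairs; the mapping itself is a design choice
layer_mapping = {5: 8, 11: 16, 21: 31}
projections = nn.ModuleDict({str(s): nn.Linear(2048, 4096) for s in layer_mapping})

student_out = student(input_ids, output_hidden_states=True)
teacher_out = teacher(input_ids, output_hidden_states=True)

feat_loss = layerwise_distillation_loss(
    student_out.hidden_states, teacher_out.hidden_states, layer_mapping, projections
)
# projections.parameters() must be passed to the optimizer along with the student's parameters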
HuggingFace Integration¶
import torch
from transformers import AutoModelForCausalLM, Trainer

teacher = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
student = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

class KDTrainer(Trainer):
    def __init__(self, teacher_model=None, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher_model.eval()  # frozen teacher; move/shard it to the training device as needed

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        student_logits = outputs.logits
        with torch.no_grad():
            teacher_logits = self.teacher(**inputs).logits
        # Reuses distillation_loss from above, flattened to token level (causal shift omitted for brevity)
        vocab = student_logits.size(-1)
        loss_kd = distillation_loss(
            student_logits.view(-1, vocab),
            teacher_logits.view(-1, vocab),
            labels.view(-1),
        )
        return (loss_kd, outputs) if return_outputs else loss_kd
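Wiring it up (a sketch; the dataset and hyperparameters are placeholders):
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="kd-student",
    per_device_train_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=1,
)
trainer = KDTrainer(
    teacher_model=teacher,
    model=student,
    args=args,
    train_dataset=train_dataset,  # placeholder: tokenized dataset with input_ids / labels
)
trainer.train()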
Best Practices¶
Hyperparameters¶
| Parameter | Recommended | Notes |
|---|---|---|
| Temperature (T) | 2-5 | Higher for more soft knowledge |
| Alpha | 0.1-0.5 | Lower = more KD influence |
| Learning rate | 1e-5 to 1e-4 | Lower than pre-training |
| Batch size | As large as possible | More stable gradients |
| Training steps | 10-50K | Depends on data size |
Teacher-Student Size Guidelines¶
| Teacher | Student | Quality Retention |
|---|---|---|
| 7B | 1.1B | 85-90% |
| 13B | 3B | 90-93% |
| 70B | 7B | 92-95% |
| 70B | 1.1B | 75-80% |
Size Ratio vs Retention¶
| Size Ratio | Expected Retention |
|---|---|
| 2:1 | 95-97% |
| 4:1 | 90-95% |
| 8:1 | 85-90% |
| 16:1 | 75-85% |
Common Pitfalls¶
| Pitfall | Solution |
|---|---|
| Teacher-student gap too large | Use an intermediate (assistant) teacher |
| Temperature too high | Start with T=2, increase gradually |
| Too much KD (alpha too low) | Balance with hard labels |
| Layer mismatch | Use projection layers |
For the Interview¶
Q: "What is knowledge distillation?"¶
Transferring knowledge from a large teacher model to a smaller student model. The teacher produces soft labels (probability distributions); the student learns from soft labels (KL divergence) and hard labels (CE) simultaneously. Loss = alpha * CE + (1-alpha) * KL. Temperature T > 1 softens the distributions, passing on "dark knowledge" -- the alternative predictions.
Q: "What types of KD exist?"¶
(1) Response-based: match output logits (simplest). (2) Feature-based: match intermediate hidden states via MSE plus a projection for dimension alignment. (3) Relation-based: match inter-sample relationships. (4) Attention Transfer: match attention matrices per layer. For LLMs: Feature Dynamics Alignment (ACL 2025) -- match not the features themselves but how they change between layers.
Q: "TinyLlama approach?"¶
Teacher: Llama 2 7B. Student: 1.1B (22 layers, 2048 hidden, compression 6.4x). Pre-train on 3T tokens, fine-tune with KD (intermediate layer matching + attention transfer). MMLU: 45.3 -> 38.2 (84% retention). HellaSwag: 78.6 -> 71.5 (91% retention).
Q: "Optimal compression order?"¶
P-KD-Q (Pruning -> Distillation -> Quantization). Pruning removes redundant params. Distillation recovers quality. Quantization applied last to optimized model. Quantizing before distillation: perplexity jumps by order of magnitude. arXiv 2511.19495.
Q: "Pruning + KD combination?"¶
At 50% pruning: without KD = 78% accuracy, with KD = 94%. Teacher-Guided Pruning (2025): use the teacher during pruning (not just after) for a minimal initial drop, then polish with KD. NVIDIA NeMo: prune -> distill -> fine-tune pipeline.
Q: "Design pipeline: compress 70B to 7B"¶
(1) Structured pruning: remove 50% attention heads + FFN neurons. (2) KD: 70B teacher, T=4, alpha=0.3, 10K-50K steps. Expected: 92-95% retention. (3) Quantize to INT8/FP8. (4) Benchmark: MMLU, HumanEval, HellaSwag. (5) Deploy with vLLM/TensorRT-LLM. Alternatively if budget allows: skip pruning, train student from scratch with KD on 1-3T tokens.
Key Numbers¶
| Fact | Value |
|---|---|
| DistilBERT size reduction | 40% smaller, 60% faster, 97% accuracy |
| TinyBERT-4 reduction | 86.7% smaller, 89.4% faster |
| LoRA trainable params | >99% reduction |
| P-KD-Q optimal order | arXiv 2511.19495 |
| KD from 7B to 1.1B | ~500 GPU-hours |
| KD from 70B to 7B | ~5000 GPU-hours |
| 50% pruning + KD | 94% accuracy retention |
| 4:1 size ratio | 90-95% quality retention |
| Temperature for KD | T = 2-5 typical |
| Distillation with <3% data | NeurIPS 2024 |
| Relative cost (distilled + quantized) | ~0.1x of full model |
Formulas Summary¶
KD Loss¶
\(\mathcal{L}_{KD} = \alpha \cdot \mathcal{L}_{CE}(y, p_s) + (1 - \alpha) \cdot T^2 \cdot \mathrm{KL}\!\left(p_t^{(T)} \,\|\, p_s^{(T)}\right)\)
Soft Labels¶
\(p_i^{(T)} = \dfrac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}\)
Feature Distillation¶
\(\mathcal{L}_{feat} = \mathrm{MSE}(W h_s,\ h_t)\), with \(W\) projecting the student's hidden size to the teacher's
Attention Transfer¶
\(\mathcal{L}_{attn} = \sum_{l} \mathrm{MSE}\!\left(A_t^{(l)},\ A_s^{(l)}\right)\)
Feature Dynamics (ACL 2025)¶
\(\mathcal{L}_{dyn} = \sum_{l} \left\| \Delta h_t^{(l)} - W\, \Delta h_s^{(l)} \right\|^2\), \(\Delta h^{(l)} = h^{(l+1)} - h^{(l)}\) (schematic form)
LoRA¶
\(W' = W_0 + BA\), \(B \in \mathbb{R}^{d \times r}\), \(A \in \mathbb{R}^{r \times k}\), \(r \ll d\)
Interview Questions¶
Conceptual:
- "Зачем нужна температура T в knowledge distillation?" -- T=1 (стандартный softmax) даёт peaked distribution: модель уверена в top-1 классе. T>1 "размягчает" distribution, раскрывая dark knowledge -- отношения между классами. Кошка чуть похожа на тигра, почти не похожа на грузовик. При T=20 эта информация видна student'у.
- "Чем logit distillation отличается от feature distillation?" -- Logit: копируем output distribution (что teacher думает). Feature: копируем intermediate representations (как teacher думает). Feature distillation глубже, но требует alignment размерностей (проекционные слои).
- "Оптимальный порядок compression: P→KD→Q. Почему не KD→P→Q?" -- Pruning сначала: убирает избыточные веса, студент учится на "чистой" структуре. KD потом: передаёт знания компактной модели. Quantization последним: минимальная потеря на уже оптимизированной модели. Обратный порядок теряет 2-5% accuracy.
Practical:
- "DistilBERT vs TinyBERT: когда что?" -- DistilBERT: 66M params (40% reduction), простая distillation, 97% accuracy BERT. TinyBERT: 14.5M params (93% reduction), data augmentation + task-specific distillation, 96.8% accuracy. TinyBERT для edge/mobile, DistilBERT для серверов с budget constraints.
Common Mistakes
"Knowledge distillation = just training on soft labels" -- The full loss is \(L = \alpha \cdot L_{hard} + (1-\alpha) \cdot T^2 \cdot L_{soft}\). The \(T^2\) factor compensates for the gradient scale at high temperature (see the gradient note below); without it the student undertrains.
"Higher temperature = better distillation" -- Too high a T (>20) makes the distribution nearly uniform, and the dark knowledge dissolves into noise. The optimum is usually T=3-10, depending on the number of classes.
"The student must have the same architecture as the teacher" -- No. Cross-architecture distillation works: Transformer teacher → CNN student, GPT → LSTM. The key is matching output distributions, not architectures.
Sources¶
- Hinton et al. -- "Distilling the Knowledge in a Neural Network" (2015)
- arXiv:1910.01108 -- DistilBERT
- arXiv:1909.10351 -- TinyBERT
- arXiv:2501.16937 -- TAID: Temporally Adaptive Interpolated Distillation
- arXiv:2511.19495 -- Compression Ordering Study (P-KD-Q)
- ACL 2025 -- "Aligning Feature Dynamics for Knowledge Distillation"
- Nature 2025 -- LightPAFF: "Efficient self-attention with smart pruning"
- OpenReview 2025 -- "Regressor-free Intermediate Distillation via Teacher Pruning"
- arXiv:2508.07054 -- "Membership and Memorization in LLM KD"
- NVIDIA NeMo -- "LLM Model Pruning and Knowledge Distillation"
- PyTorch Torchtune -- Llama Distillation Guide
- Redis Blog -- "Model Distillation for LLMs: Cut Costs & Boost Speed in 2026"
- Sebastian Raschka -- "Understanding Knowledge Distillation"
- TinyLlama Project -- Training recipe and benchmarks
- Analytics Vidhya -- "4 LLM Compression Techniques" (2025)
- ICML 2025 -- Agent-style benchmark quantization study
A distilled model inherits the teacher's biases and errors
Distillation copies not only knowledge but also mistakes. If GPT-4 (the teacher) hallucinates on medical questions, the student model will hallucinate the SAME way -- but without the disclaimers. arXiv 2508.07054 shows that student models memorize training data present in teacher outputs. Rule of thumb: (1) curate teacher outputs (filter hallucinations before distillation), (2) add safety fine-tuning AFTER distillation, (3) evaluate the student SEPARATELY from the teacher on edge cases.
See Also¶
- LLM Quantization -- the alternative: shrink an existing model instead of training a small one; the optimal pipeline is P-KD-Q
- LLM Pruning -- removing weights as the step preceding distillation (prune -> distill -> quantize)
- LoRA Fine-tuning Variants -- parameter-efficient fine-tuning of the student model after distillation
- Alignment Methods -- RLHF/DPO for safety after distillation, so the student does not inherit bias
- Inference Optimization -- distillation as part of the overall optimization strategy