
Knowledge Distillation for LLMs

~3 min read

Prerequisites: basic familiarity with quantization and loss functions


Why it matters

GPT-4 knows everything, but it costs $30/M tokens and requires a GPU cluster. You need a model for a single task -- say, ticket classification. Distillation lets you "pour" the knowledge of a giant model into a small one that runs roughly 10x cheaper and 5x faster while retaining 90-95% of the quality.

Analogy: the teacher (teacher model) knows the entire physics course. The student (student model) does not copy the textbook -- it learns from the teacher's explanations. When the teacher says "the answer is probably A, but B is also possible", that conveys more information than a bare "the answer is A". These nuances ("dark knowledge") are the essence of distillation.

Teacher-Student Framework

graph LR
    T["Teacher<br/>(Large Model)"] -->|"Soft Labels<br/>p=[0.92, 0.05, 0.03]"| S["Student<br/>(Small Model)"]
    GT["Ground Truth"] -->|"Hard Labels<br/>y=[1, 0, 0]"| S
    S -->|"Combined Loss"| L["L = alpha*CE + (1-alpha)*KL"]
    style T fill:#e8eaf6,stroke:#3f51b5
    style S fill:#e8f5e9,stroke:#4caf50
    style L fill:#fff3e0,stroke:#ef6c00

Key insight: when the teacher predicts "Paris" at 92%, "Lyon" at 5%, "France" at 3%, those 5% on "Lyon" pass knowledge about the alternatives on to the student. The hard label "Paris" carries none of that.

KD Loss Function

\[ \mathcal{L}_{KD} = \alpha \cdot \mathcal{L}_{CE}(y, y_s) + (1-\alpha) \cdot \mathcal{L}_{KL}(p_t, p_s) \]
  • \(\mathcal{L}_{CE}\) -- cross-entropy с ground truth (hard labels)
  • \(\mathcal{L}_{KL}\) -- KL divergence между teacher и student (soft labels)
  • \(\alpha\) -- balance between hard and soft targets (typically 0.1-0.5)
  • \(p_t, p_s\) -- softened probability distributions
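
A tiny worked example of this loss on made-up 3-class logits (the values are illustrative; the T^2 factor on the soft term matches the implementation further below):

import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[5.0, 2.0, 1.0]])   # illustrative values
student_logits = torch.tensor([[3.0, 2.5, 1.0]])
labels = torch.tensor([0])                          # ground-truth class
T, alpha = 4.0, 0.3

soft_t = F.softmax(teacher_logits / T, dim=-1)
log_soft_s = F.log_softmax(student_logits / T, dim=-1)
kl = F.kl_div(log_soft_s, soft_t, reduction="batchmean") * T**2  # soft-label term
ce = F.cross_entropy(student_logits, labels)                     # hard-label term
loss = alpha * ce + (1 - alpha) * kl                             # scalar combined loss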

Temperature Scaling

\[ p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \]
| Temperature | Effect | Use Case |
|---|---|---|
| T = 1 | Standard softmax | Baseline |
| T = 2-5 | Softer distribution | Typical KD |
| T = 10+ | Very soft | Capturing dark knowledge |

Hinton's original paper experimented with temperatures as high as T = 20.
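
A minimal sketch of the temperature effect on a toy 3-class logit vector (the logit values are made up for illustration):

import torch
import torch.nn.functional as F

logits = torch.tensor([5.0, 2.0, 1.0])  # e.g. "Paris", "Lyon", "France"
for T in (1.0, 4.0, 10.0):
    print(T, F.softmax(logits / T, dim=-1))
# T=1  -> [0.94, 0.05, 0.02]  (peaked: almost a hard label)
# T=4  -> [0.54, 0.26, 0.20]  (softer: ranking preserved, alternatives visible)
# T=10 -> [0.41, 0.31, 0.28]  (very soft: approaching uniform)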

5-Step Process

  1. Select Teacher Model -- pre-trained large model, stays frozen
  2. Design Student Architecture -- notably smaller, retains learning capacity
  3. Generate Soft Labels -- run calibration data through teacher
  4. Train with Combined Loss -- \(\alpha \cdot L_{CE} + (1-\alpha) \cdot L_{KL}\)
  5. Validate Performance -- compare: teacher (upper bound), student baseline, student + KD

Types of Knowledge Distillation

A. Response-based (Vanilla KD)

Focus: the teacher's output logits.

\[ \mathcal{L}_{response} = \| p_t(x) - p_s(x) \|^2 \]
graph LR
    TL["Teacher Logits"] -->|"T=4"| SL["Soft Labels"]
    SL --> CL["Combined Loss"]
    HL["Hard Labels"] --> CL
    CL --> ST["Student Update"]
    style SL fill:#e8eaf6,stroke:#3f51b5
    style HL fill:#fff3e0,stroke:#ef6c00

Pros: simple and effective. Cons: loses the intermediate representations.

B. Feature-based (Layer-wise)

Focus: intermediate hidden states.

\[ \mathcal{L}_{feature} = \| h_t^l - W \cdot h_s^l \|^2 \]

where \(W\) is a projection matrix for dimension alignment (e.g., teacher: 4096, student: 2048).

graph TD
    TL["Teacher Layer i<br/>h_t = activations"] --> MSE["MSE Loss"]
    SL["Student Layer j<br/>h_s = activations"] -->|"Projection W"| MSE
    MSE --> TOTAL["Total Loss =<br/>L_KD + lambda * SUM(L_layer)"]
    style TL fill:#e8eaf6,stroke:#3f51b5
    style SL fill:#e8f5e9,stroke:#4caf50
    style MSE fill:#fff3e0,stroke:#ef6c00

Pros: preserves representations. Cons: dimension mismatch must be handled.

C. Relation-based

Focus: relationships between samples.

\[ \mathcal{L}_{relation} = \| \psi(x_i, x_j)_t - \psi(x_i, x_j)_s \|^2 \]

Pros: captures inter-sample relationships. Cons: computational overhead.
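
A minimal sketch, assuming the relation function \(\psi\) is pairwise cosine similarity between per-sample embeddings within a batch (one common choice; relation-based KD admits other relation functions):

import torch
import torch.nn.functional as F

def relation_kd_loss(teacher_emb: torch.Tensor, student_emb: torch.Tensor) -> torch.Tensor:
    # teacher_emb: [B, d_t], student_emb: [B, d_s] -- per-sample embeddings (dims may differ)
    t = F.normalize(teacher_emb, dim=-1)
    s = F.normalize(student_emb, dim=-1)
    rel_t = t @ t.T  # [B, B] teacher relation matrix (pairwise cosine similarities)
    rel_s = s @ s.T  # [B, B] student relation matrix
    return F.mse_loss(rel_s, rel_t)

Because the relation matrices are B x B, no projection is needed even when teacher and student embedding dimensions differ.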

D. Attention Transfer

\[ \mathcal{L}_{attn} = \sum_l \| A_t^l - A_s^l \|_F^2 \]

Match attention matrices per layer. Captures how teacher "attends" to input tokens.
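
A minimal sketch, assuming the attention matrices have already been extracted (e.g., with output_attentions=True in HuggingFace models) and that heads are averaged to sidestep head-count mismatch -- a simplification on my part:

import torch
import torch.nn.functional as F

def attention_transfer_loss(teacher_attns, student_attns, layer_map):
    # teacher_attns / student_attns: lists of [batch, heads, seq, seq] attention matrices
    # layer_map: {student_layer_idx: teacher_layer_idx}
    loss = 0.0
    for s_l, t_l in layer_map.items():
        a_s = student_attns[s_l].mean(dim=1)  # average over heads -> [batch, seq, seq]
        a_t = teacher_attns[t_l].mean(dim=1)
        loss = loss + F.mse_loss(a_s, a_t)    # squared Frobenius norm per layer (up to a constant)
    return loss / len(layer_map)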


LLM-Specific Methods

1. Feature Dynamics Alignment (ACL 2025)

\[ L_{delta} = \|\Delta h_t - \Delta h_s\|^2 \]

where \(\Delta h = h_{layer+1} - h_{layer}\) is the feature transition.

Key insight: match not just features, but how features change across layers.
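
A minimal sketch of the idea; the layer pairing and the single projection layer are my assumptions, and the actual ACL 2025 method is more involved:

import torch.nn.functional as F

def feature_dynamics_loss(student_hiddens, teacher_hiddens, layer_pairs, proj):
    # layer_pairs: [((s_l, s_l1), (t_l, t_l1)), ...] -- which student transition maps to which teacher transition
    # proj: nn.Linear(student_dim, teacher_dim), created once and trained jointly with the student
    loss = 0.0
    for (s_a, s_b), (t_a, t_b) in layer_pairs:
        delta_s = proj(student_hiddens[s_b] - student_hiddens[s_a])  # student feature transition
        delta_t = teacher_hiddens[t_b] - teacher_hiddens[t_a]        # teacher feature transition
        loss = loss + F.mse_loss(delta_s, delta_t)
    return loss / len(layer_pairs)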

2. TAID (Temporally Adaptive Interpolated Distillation, 2025)

import torch.nn.functional as F

def taid_loss(student_logits, teacher_logits, labels, t, schedule, T=4.0):
    # schedule(t) maps training progress t in [0, 1] to the hard/soft weight
    soft_t = F.softmax(teacher_logits / T, dim=-1)
    log_soft_s = F.log_softmax(student_logits / T, dim=-1)
    kd_loss = F.kl_div(log_soft_s, soft_t, reduction="batchmean") * T**2
    ce_loss = F.cross_entropy(student_logits, labels)
    alpha = schedule(t)  # changes over training
    return alpha * ce_loss + (1 - alpha) * kd_loss

Adaptively adjusts distillation weight over training.

3. Progressive Distillation

Phase 1: Distill from full teacher
Phase 2: Fine-tune student
Phase 3: Optional iterative refinement

4. LightPAFF (Nature, 2025)

Pruning + KD combined into a single objective:

\[ \mathcal{L}_{total} = \mathcal{L}_{task} + \lambda_{KD} \mathcal{L}_{KD} + \lambda_{prune} \mathcal{L}_{prune} \]

"Efficient self-attention with smart pruning for sustainable LLMs" (22 citations).

5. Regressor-free Intermediate Distillation (OpenReview 2025)

Problem: dimension mismatch in intermediate layers. Solution: prune the teacher to match the student's dimensions (instead of using projection layers).

6. Membership-aware KD (arXiv 2508.07054)

Finding: KD can leak training data membership. Mitigation: regularization during distillation.


Pruning + Distillation Pipeline

NVIDIA NeMo Framework

graph TD
    ORIG["Original Model"] --> P["Step 1: Pruning<br/>Remove unimportant weights/heads"]
    P --> KD["Step 2: Fine-tuning with KD<br/>Original = Teacher, Pruned = Student"]
    KD --> RES["Result: Smaller model<br/>with similar performance"]
    style ORIG fill:#e8eaf6,stroke:#3f51b5
    style P fill:#fce4ec,stroke:#c62828
    style KD fill:#e8f5e9,stroke:#4caf50
    style RES fill:#fff3e0,stroke:#ef6c00

Pruning Types

| Type | What Is Removed | Recovery with KD |
|---|---|---|
| Magnitude pruning | Individual weights | 85-90% |
| Structured pruning | Neurons/heads/layers | 90-95% |
| Attention head pruning | Per head | 92-97% |
| Layer dropping | Entire layers | 88-93% |
| Depth pruning | Entire layers | KD essential |
| Width pruning | Neurons, heads, channels | KD + fine-tuning |
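
A toy illustration of the first two rows using torch.nn.utils.prune on a single linear layer (the layer sizes are illustrative; production LLM pruning works at the head/layer level and is followed by the KD recovery described here):

import torch.nn as nn
import torch.nn.utils.prune as prune

ffn = nn.Linear(4096, 11008)  # illustrative FFN projection

# Magnitude (unstructured) pruning: zero the 30% smallest-magnitude weights
prune.l1_unstructured(ffn, name="weight", amount=0.3)

# Structured pruning: drop 50% of output neurons (whole rows) ranked by L2 norm
prune.ln_structured(ffn, name="weight", amount=0.5, n=2, dim=0)

prune.remove(ffn, "weight")  # bake the combined mask into the weights before KD fine-tuning
print(float((ffn.weight == 0).float().mean()))  # fraction of zeroed parameters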

Teacher-Guided Pruning (2025)

| Approach | Process | Result |
|---|---|---|
| Standard | Prune -> Accuracy Drop -> KD Recovery | Good |
| Guided | Prune + Teacher -> Minimal Drop -> KD Polish | Better |

Guided pruning uses the teacher during pruning, not only afterwards.

KD Recovery Rates

| Pruning Level | Without KD | With KD |
|---|---|---|
| 25% | 92% accuracy | 98% accuracy |
| 50% | 78% accuracy | 94% accuracy |
| 75% | 55% accuracy | 85% accuracy |

P-KD-Q Sequence (Optimal Compression Order)

2025 Research Finding (arXiv 2511.19495):

Pruning -> Distillation -> Quantization (P-KD-Q) performed best among tested sequences, yielding the best balance for compression with preserved capabilities.

| Step | Purpose |
|---|---|
| Pruning (P) | Remove redundant parameters, establish structural foundation |
| Distillation (D) | Retrain to recover capabilities, optimize remaining parameters |
| Quantization (Q) | Apply to already-optimized architecture |

Critical Finding: sequences quantizing before distillation saw perplexity jump by an order of magnitude.

| Sequence | Performance |
|---|---|
| P-KD-Q | Best balance |
| KD-P-Q | Good |
| P-Q-KD | Poor (perplexity spike) |
| Q-KD-P | Worst (quality collapse) |
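
A skeleton of the ordering under discussion; student_model is a placeholder, the P and KD steps refer to the examples elsewhere on this page, and dynamic INT8 quantization stands in for whatever quantization scheme you actually deploy:

import torch
import torch.nn as nn

# Step 1 (P): structured pruning, e.g. with torch.nn.utils.prune as sketched above
# Step 2 (KD): recover quality, e.g. with train_with_distillation() from the Implementation section below
# Step 3 (Q): quantize LAST, on the already pruned + distilled student
quantized_student = torch.quantization.quantize_dynamic(
    student_model, {nn.Linear}, dtype=torch.qint8
)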

Four Compression Techniques (Overview)

| Technique | What It Does | Speedup | Quality Impact |
|---|---|---|---|
| Quantization | Reduce precision (FP32 -> INT4) | Multi-fold | Modest |
| Pruning | Remove unnecessary weights | Varies | Minimal with proper retraining |
| Distillation | Teacher -> Student transfer | 2-3x | 97% retention |
| LoRA | Freeze base, train adapters | Minimal | 90-95% of full FT |

LoRA (Low-Rank Adaptation)

\[ W' = W + \frac{\alpha}{r} BA \]

\(B \in \mathbb{R}^{d \times r}\), \(A \in \mathbb{R}^{r \times k}\), rank \(r \ll d\).

from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # "c_attn" targets GPT-2-style attention

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8, lora_alpha=32, lora_dropout=0.1,
    target_modules=["c_attn"]
)
lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()  # Result: >99% reduction in trainable parameters

TinyLlama Case Study

Architecture

| Component | LLaMA-2-7B (Teacher) | TinyLlama-1.1B (Student) |
|---|---|---|
| Layers | 32 | 22 |
| Hidden dim | 4096 | 2048 |
| Attention heads | 32 | 32 |
| Parameters | 7B | 1.1B |
| Compression ratio | -- | 6.4x |
| Training data | -- | 3T tokens |

Results

| Metric | LLaMA-2-7B | TinyLlama | TinyLlama + KD |
|---|---|---|---|
| MMLU | 45.3 | 35.1 | 38.2 |
| HellaSwag | 78.6 | 68.2 | 71.5 |
| ARC-C | 49.2 | 38.5 | 41.8 |

Key Techniques

  1. Pre-training with teacher guidance
  2. Intermediate layer matching
  3. Attention transfer
  4. Large-scale data augmentation

Deployment Scenarios

By Platform

| Platform | Recommended Approach |
|---|---|
| Cloud | Quantization + semantic caching |
| Edge/Mobile | Distillation + quantization |
| IoT/Embedded | P-KD-Q full stack |

Real-World Applications

| Use Case | Requirements | Solution |
|---|---|---|
| Real-time chat | Low latency | Distilled model |
| Autocomplete | Sub-100ms | Tiny model + caching |
| Document processing | High throughput | Quantized model |
| On-device AI | Memory constrained | P-KD-Q stack |

Infrastructure Optimization (Three-Layer Stack)

  1. Semantic Caching: reduce redundant LLM calls
  2. Vector Search: RAG context retrieval
  3. Distilled Models: fast inference on cache misses

Redis Benchmark: 90% precision, 200ms median latency, top 100 nearest neighbors, 50 concurrent queries (billion-scale).


Implementation

Basic KD Training Loop (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 4.0,
    alpha: float = 0.5,
) -> torch.Tensor:
    # Soft labels (KD loss)
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    kd_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean')
    kd_loss *= temperature ** 2  # Scale by T^2

    # Hard labels (CE loss)
    ce_loss = F.cross_entropy(student_logits, labels)

    return alpha * ce_loss + (1 - alpha) * kd_loss


def train_with_distillation(teacher, student, dataloader, optimizer, T=4.0, alpha=0.5):
    teacher.eval()  # teacher stays frozen

    for batch in dataloader:
        input_ids, labels = batch

        with torch.no_grad():
            teacher_logits = teacher(input_ids).logits

        student_logits = student(input_ids).logits

        # Causal-LM logits are [batch, seq, vocab]; flatten so cross_entropy / kl_div get 2D inputs
        loss = distillation_loss(
            student_logits.view(-1, student_logits.size(-1)),
            teacher_logits.view(-1, teacher_logits.size(-1)),
            labels.view(-1),
            T, alpha,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Layer-wise Distillation

def layerwise_distillation_loss(
    student_hiddens: list,
    teacher_hiddens: list,
    layer_mapping: dict,   # {student_layer: teacher_layer}
    projections: dict,     # {student_layer: nn.Linear(student_dim, teacher_dim)}
) -> torch.Tensor:
    total_loss = 0.0

    for s_layer, t_layer in layer_mapping.items():
        s_hidden = student_hiddens[s_layer]
        t_hidden = teacher_hiddens[t_layer]

        # Dimension mismatch is handled by persistent projection layers created once
        # and trained jointly with the student (a fresh nn.Linear built inside the loss
        # would be re-initialized on every step and never learn anything).
        if s_hidden.size(-1) != t_hidden.size(-1):
            s_hidden = projections[s_layer](s_hidden)

        total_loss += F.mse_loss(s_hidden, t_hidden)

    return total_loss / len(layer_mapping)

HuggingFace Integration

from transformers import AutoModelForCausalLM, Trainer

teacher = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
student = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

class KDTrainer(Trainer):
    def __init__(self, *args, teacher=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher.eval()  # frozen teacher, only used for forward passes

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        student_logits = outputs.logits

        with torch.no_grad():
            teacher_logits = self.teacher(**inputs).logits

        # Reuse distillation_loss() from above; flatten [batch, seq, vocab] -> [batch*seq, vocab]
        loss_kd = distillation_loss(
            student_logits.view(-1, student_logits.size(-1)),
            teacher_logits.view(-1, teacher_logits.size(-1)),
            labels.view(-1),
        )
        return (loss_kd, outputs) if return_outputs else loss_kd
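
A hypothetical usage sketch: train_dataset is assumed to be a tokenized dataset with input_ids, attention_mask, and labels, and the training arguments are illustrative.

from transformers import TrainingArguments

args = TrainingArguments(output_dir="kd-student", per_device_train_batch_size=4, num_train_epochs=1)
trainer = KDTrainer(model=student, args=args, train_dataset=train_dataset, teacher=teacher)
trainer.train()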

Best Practices

Hyperparameters

| Parameter | Recommended | Notes |
|---|---|---|
| Temperature (T) | 2-5 | Higher for more soft knowledge |
| Alpha | 0.1-0.5 | Lower = more KD influence |
| Learning rate | 1e-5 to 1e-4 | Lower than pre-training |
| Batch size | As large as possible | More stable gradients |
| Training steps | 10-50K | Depends on data size |

Teacher-Student Size Guidelines

| Teacher | Student | Quality Retention |
|---|---|---|
| 7B | 1.1B | 85-90% |
| 13B | 3B | 90-93% |
| 70B | 7B | 92-95% |
| 70B | 1.1B | 75-80% |

Size Ratio vs Retention

| Size Ratio | Expected Retention |
|---|---|
| 2:1 | 95-97% |
| 4:1 | 90-95% |
| 8:1 | 85-90% |
| 16:1 | 75-85% |

Common Pitfalls

| Pitfall | Solution |
|---|---|
| Teacher-student gap too large | Use an intermediate teacher |
| Temperature too high | Start with T=2, increase gradually |
| Too much KD (alpha too low) | Balance with hard labels |
| Layer mismatch | Use projection layers |

For the interview

Q: "Что такое knowledge distillation?"

Knowledge transfer from a large teacher model to a smaller student model. The teacher generates soft labels (probability distributions); the student learns from soft labels (KL divergence) and hard labels (CE) simultaneously. Loss = alpha * CE + (1-alpha) * KL. Temperature T > 1 softens the distributions, passing on "dark knowledge" -- the alternative predictions.

Q: "Какие типы KD существуют?"

(1) Response-based: match output logits (the simplest). (2) Feature-based: match intermediate hidden states via MSE + a projection for dimension alignment. (3) Relation-based: match inter-sample relationships. (4) Attention Transfer: match attention matrices per layer. For LLMs: Feature Dynamics Alignment (ACL 2025) -- match not the features themselves but their transitions between layers.

Q: "TinyLlama approach?"

Teacher: Llama 2 7B. Student: 1.1B (22 layers, 2048 hidden, compression 6.4x). Pre-train on 3T tokens, fine-tune with KD (intermediate layer matching + attention transfer). MMLU: 45.3 -> 38.2 (84% retention). HellaSwag: 78.6 -> 71.5 (91% retention).

Q: "Optimal compression order?"

P-KD-Q (Pruning -> Distillation -> Quantization). Pruning removes redundant params. Distillation recovers quality. Quantization applied last to optimized model. Quantizing before distillation: perplexity jumps by order of magnitude. arXiv 2511.19495.

Q: "Pruning + KD combination?"

At 50% pruning: without KD = 78% accuracy, with KD = 94% accuracy. Teacher-Guided Pruning (2025): use the teacher during pruning (not just after) for a minimal initial drop, then polish with KD. NVIDIA NeMo: prune -> distill -> fine-tune pipeline.

Q: "Design pipeline: compress 70B to 7B"

(1) Structured pruning: remove 50% attention heads + FFN neurons. (2) KD: 70B teacher, T=4, alpha=0.3, 10K-50K steps. Expected: 92-95% retention. (3) Quantize to INT8/FP8. (4) Benchmark: MMLU, HumanEval, HellaSwag. (5) Deploy with vLLM/TensorRT-LLM. Alternatively if budget allows: skip pruning, train student from scratch with KD on 1-3T tokens.

Key numbers

| Fact | Value |
|---|---|
| DistilBERT size reduction | 40% smaller, 60% faster, 97% accuracy |
| TinyBERT-4 reduction | 86.7% smaller, 89.4% faster |
| LoRA trainable params | >99% reduction |
| P-KD-Q optimal order | arXiv 2511.19495 |
| KD from 7B to 1.1B | ~500 GPU-hours |
| KD from 70B to 7B | ~5000 GPU-hours |
| 50% pruning + KD | 94% accuracy retention |
| 4:1 size ratio | 90-95% quality retention |
| Temperature for KD | T = 2-5 typical |
| Distillation with <3% data | NeurIPS 2024 |
| Relative cost (distilled + quantized) | ~0.1x of full model |

Formulas Summary

KD Loss

\[\mathcal{L}_{KD} = \alpha \cdot \mathcal{L}_{CE}(y, y_s) + (1-\alpha) \cdot \mathcal{L}_{KL}(p_t, p_s)\]

Soft Labels

\[p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}\]

Feature Distillation

\[\mathcal{L}_{feature} = \| h_t^l - W \cdot h_s^l \|^2\]

Attention Transfer

\[\mathcal{L}_{attn} = \sum_l \| A_t^l - A_s^l \|_F^2\]

Feature Dynamics (ACL 2025)

\[L_{delta} = \|\Delta h_t - \Delta h_s\|^2 \quad \text{where } \Delta h = h_{l+1} - h_l\]

LoRA

\[W' = W + \frac{\alpha}{r} BA\]

Interview Questions

Conceptual:

  1. "Зачем нужна температура T в knowledge distillation?" -- T=1 (стандартный softmax) даёт peaked distribution: модель уверена в top-1 классе. T>1 "размягчает" distribution, раскрывая dark knowledge -- отношения между классами. Кошка чуть похожа на тигра, почти не похожа на грузовик. При T=20 эта информация видна student'у.
  2. "Чем logit distillation отличается от feature distillation?" -- Logit: копируем output distribution (что teacher думает). Feature: копируем intermediate representations (как teacher думает). Feature distillation глубже, но требует alignment размерностей (проекционные слои).
  3. "Оптимальный порядок compression: P→KD→Q. Почему не KD→P→Q?" -- Pruning сначала: убирает избыточные веса, студент учится на "чистой" структуре. KD потом: передаёт знания компактной модели. Quantization последним: минимальная потеря на уже оптимизированной модели. Обратный порядок теряет 2-5% accuracy.

Practical:

  1. "DistilBERT vs TinyBERT: когда что?" -- DistilBERT: 66M params (40% reduction), простая distillation, 97% accuracy BERT. TinyBERT: 14.5M params (93% reduction), data augmentation + task-specific distillation, 96.8% accuracy. TinyBERT для edge/mobile, DistilBERT для серверов с budget constraints.

Common mistakes

"Knowledge distillation = просто train на soft labels" -- Полная loss функция: \(L = \alpha \cdot L_{hard} + (1-\alpha) \cdot T^2 \cdot L_{soft}\). Множитель \(T^2\) компенсирует масштаб градиентов при высокой температуре. Без него student недообучается.

"Больше температура = лучше distillation" -- T слишком высокая (>20) делает distribution uniform -- dark knowledge размывается в шуме. Оптимум обычно T=3-10 в зависимости от числа классов.

"Student должен быть той же архитектуры что teacher" -- Нет. Cross-architecture distillation работает: Transformer teacher → CNN student, GPT → LSTM. Ключ -- matching output distributions, не архитектуры.


Sources

  1. Hinton et al. -- "Distilling the Knowledge in a Neural Network" (2015)
  2. arXiv:1910.01108 -- DistilBERT
  3. arXiv:1909.10351 -- TinyBERT
  4. arXiv:2501.16937 -- TAID: Temporally Adaptive Interpolated Distillation
  5. arXiv:2511.19495 -- Compression Ordering Study (P-KD-Q)
  6. ACL 2025 -- "Aligning Feature Dynamics for Knowledge Distillation"
  7. Nature 2025 -- LightPAFF: "Efficient self-attention with smart pruning"
  8. OpenReview 2025 -- "Regressor-free Intermediate Distillation via Teacher Pruning"
  9. arXiv:2508.07054 -- "Membership and Memorization in LLM KD"
  10. NVIDIA NeMo -- "LLM Model Pruning and Knowledge Distillation"
  11. PyTorch Torchtune -- Llama Distillation Guide
  12. Redis Blog -- "Model Distillation for LLMs: Cut Costs & Boost Speed in 2026"
  13. Sebastian Raschka -- "Understanding Knowledge Distillation"
  14. TinyLlama Project -- Training recipe and benchmarks
  15. Analytics Vidhya -- "4 LLM Compression Techniques" (2025)
  16. ICML 2025 -- Agent-style benchmark quantization study

A distilled model inherits the teacher's biases and errors

Distillation copies not only the knowledge but also the mistakes. If GPT-4 (the teacher) hallucinates on medical questions, the student will hallucinate the SAME way -- just without the disclaimers. arXiv 2508.07054 shows that student models memorize training data from teacher outputs. Rules of thumb: (1) curate teacher outputs (filter hallucinations before distillation), (2) add safety fine-tuning AFTER distillation, (3) evaluate the student SEPARATELY from the teacher on edge cases.


See Also

  • LLM quantization -- the alternative: shrink an existing model instead of training a small one; the optimal pipeline is P-KD-Q
  • LLM pruning -- removing weights as the step that precedes distillation (prune -> distill -> quantize)
  • LoRA fine-tuning variants -- parameter-efficient fine-tuning of the student model after distillation
  • Alignment methods -- RLHF/DPO for safety after distillation, so the student does not inherit the teacher's bias
  • Inference optimization -- distillation as part of the overall optimization strategy