Перейти к содержанию

Непрерывное обучение и катастрофическое забывание

~5 минут чтения

Предварительно: Техники файнтюнинга LLM | LoRA-варианты файнтюнинга

Fine-tune GPT на задаче A, затем на задаче B -- и accuracy на A падает на 40-60%. Это catastrophic forgetting, и это не теоретическая проблема: каждая компания, обновляющая LLM на новых данных, сталкивается с ней. Исследования 2025-2026 показывают, что даже 6.2% интерливинга старых данных в новый батч предотвращает катастрофическое забывание, а self-distillation снижает forgetting до <5%. Цена игнорирования -- полное переобучение с нуля вместо инкрементального апдейта, что для 70B модели стоит $100K+ за run.


Обзор

Catastrophic Forgetting — потеря ранее выученных знаний при обучении новым задачам.

The Stability-Plasticity Dilemma

\[ \text{Challenge: } \underbrace{\text{Learn new tasks}}_{\text{Plasticity}} \leftrightarrow \underbrace{\text{Retain old knowledge}}_{\text{Stability}} \]

1. Why Catastrophic Forgetting Happens

Root Causes

Cause Description
Weight interference New updates overwrite useful weights
Gradient conflict Task gradients point in different directions
Representational drift Hidden representations shift
Capacity saturation Limited model capacity

Mathematical Formulation

For sequential tasks \(\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_n\):

\[ \theta^* = \arg\min_\theta \sum_{i=1}^n \mathcal{L}_i(\theta) \]

Problem: Optimizing \(\mathcal{L}_n\) may increase \(\mathcal{L}_1, \ldots, \mathcal{L}_{n-1}\)


2. Mitigation Strategies

A. Replay-Based Methods

Core idea: Rehearse old data while learning new tasks

\[ \mathcal{L}_{total} = \mathcal{L}_{new} + \lambda \cdot \mathcal{L}_{replay} \]

Experience Replay

# Store exemplars from old tasks
buffer = MemoryBuffer(capacity=1000)

# During training
batch_new = sample_new_data()
batch_old = buffer.sample(batch_size // 2)
loss = loss_fn(model, batch_new) + loss_fn(model, batch_old)

Generative Replay

  • Train generative model on old data
  • Generate pseudo-samples for replay
  • No storage needed

B. Regularization Methods

Elastic Weight Consolidation (EWC)

\[ \mathcal{L} = \mathcal{L}_{new} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2 \]

where \(F_i\) = Fisher information for parameter \(i\)

Intuition: Important parameters get penalized more

Synaptic Intelligence (SI)

\[ \mathcal{L} = \mathcal{L}_{new} + \sum_i c_i (\theta_i - \theta_i^*)^2 \]

where \(c_i\) = importance measure accumulated during training

C. Architecture Methods

Progressive Networks

  • Add new columns for new tasks
  • No forgetting (lateral connections)
  • Drawback: Linear growth

PackNet

  • Prune and freeze weights per task
  • Fixed capacity
  • Requires task identity

Sparse LLM Continual Learning (GCL, 2026)

\[ \text{Mask}_i = \text{Prune}(\theta, \text{task}_i) \]

Challenge: Mask accumulation leads to interference

D. Parameter Isolation

Core idea: Different parameters for different tasks

Method Approach Memory
Adapters Task-specific modules Low
LoRA Low-rank per task Very low
Prefix tuning Soft prompts Minimal

3. LLM-Specific Continual Learning

Fine-tuning Challenge

Pre-trained LLM → Fine-tune on Task A → Fine-tune on Task B
                        ↓                        ↓
                  Good on A                 Forgets A!

LLM Continual Learning Approaches (2025-2026)

1. Self-Distillation (arXiv 2601.19897, Jan 2026)

Paper: "Self-Distillation Enables Continual Learning"

Method: Use model's own outputs as targets

\[ \mathcal{L} = \mathcal{L}_{task} + \lambda \cdot D_{KL}(p_{old} \| p_{new}) \]

Result: Reduces catastrophic forgetting in on-policy RL updates

2. Nested Learning (Google Research, Nov 2025)

Key insight: Hierarchical task organization

Level 0: General knowledge (frozen)
Level 1: Domain knowledge (partially frozen)
Level 2: Task-specific (fully trainable)

Formula: $$ \theta = \theta_{frozen} \cup \theta_{domain} \cup \theta_{task} $$

3. Mixed Training (2025 Research)

Key finding: Simple data mixing prevents forgetting

\[ \text{Batch} = \alpha \cdot \text{new\_data} + (1-\alpha) \cdot \text{old\_data} \]

Result: 1:1 ratio = zero forgetting

Even 6.2% interleaved old data prevents catastrophic forgetting!

4. RL-Based Continual Learning

Paper: "Continual Learning with RL for LLMs" (Cameron Wolfe)

Key insight: On-policy RL naturally mitigates forgetting

  • PPO updates preserve KL to reference
  • Inherent regularization
  • Suitable for LLM alignment

4. Evaluation Metrics

Forgetting Measure

\[ F_j = \max_{i < j} (a_{i,i} - a_{i,j}) \]

where \(a_{i,j}\) = accuracy on task \(i\) after training task \(j\)

Backward Transfer (BWT)

\[ \text{BWT} = \frac{1}{n-1} \sum_{i=1}^{n-1} (a_{i,n} - a_{i,i}) \]

Negative BWT = forgetting

Forward Transfer (FWT)

\[ \text{FWT} = \frac{1}{n-1} \sum_{i=2}^{n} (a_{i,i} - b_i) \]

where \(b_i\) = random baseline


5. Benchmarks & Results

Continual Learning Benchmarks (2025)

Benchmark Tasks Focus
CLiMB 15 Vision-language
CGLB 12 Graph learning
LAMBADA 10 Language tasks
STREAM 100+ Large-scale

LLM Continual Learning Results

Method Forgetting Performance
Naive fine-tuning 40-60% High (new task)
EWC 20-30% Medium
Replay (10%) 5-10% High
Mixed training (50%) ~0% High
Self-distillation <5% High

6. Practical Implementation

Continual Fine-tuning with Replay

class ContinualLLM:
    def __init__(self, model, buffer_size=1000):
        self.model = model
        self.buffer = ReplayBuffer(buffer_size)

    def learn_task(self, new_data, epochs=3, replay_ratio=0.5):
        for epoch in range(epochs):
            for batch in new_data:
                # Mix new and old data
                if len(self.buffer) > 0 and random.random() < replay_ratio:
                    old_batch = self.buffer.sample()
                    batch = mix_batches(batch, old_batch)

                # Forward + backward
                loss = self.model(batch)
                loss.backward()
                optimizer.step()

            # Store exemplars
            self.buffer.add(new_data.sample(100))

LoRA-Based Continual Learning

class ContinualLoRA:
    def __init__(self, model):
        self.adapters = {}  # task_id -> LoRA weights

    def add_task(self, task_id, rank=8):
        # Add new LoRA adapter for task
        self.adapters[task_id] = create_lora(rank)

    def forward(self, x, task_id):
        # Route to appropriate adapter
        return self.base_model(x) + self.adapters[task_id](x)

7. Interview Questions

Basic

  1. "What is catastrophic forgetting and why does it occur?"
  2. "Explain the stability-plasticity dilemma"
  3. "What are the main approaches to mitigate forgetting?"

Advanced

  1. "Compare replay-based vs regularization-based methods"
  2. "How does EWC calculate parameter importance?"
  3. "Explain why mixed training prevents catastrophic forgetting"
  4. "How can LoRA adapters be used for continual learning?"

System Design

  1. "Design a system for continuously updating a customer service LLM"
  2. "How would you implement continual learning for a multi-tenant platform?"
  3. "Design evaluation pipeline for continual learning systems"


Частые заблуждения

Заблуждение: EWC решает catastrophic forgetting

EWC (Elastic Weight Consolidation) снижает forgetting с 40-60% до 20-30%, но это далеко от решения. Причина: Fisher information matrix -- грубая аппроксимация parameter importance, и она не учитывает взаимодействие между параметрами. На практике простой replay (10% старых данных) дает 5-10% forgetting -- лучше EWC при меньшей сложности. Mixed training (50:50) дает ~0% forgetting.

Заблуждение: LoRA-адаптеры автоматически решают проблему continual learning

LoRA-per-task изолирует параметры -- да, base model не забывает. Но (1) нужно знать task identity во время инференса (routing), (2) количество адаптеров растет линейно с числом задач, (3) нет transfer между задачами (каждый адаптер учится с нуля). Это workaround, а не continual learning -- модель не "учится лучше со временем", а хранит отдельные снимки.

Заблуждение: replay buffer в 1000 примеров достаточен для LLM

Для табличных моделей 1000 примеров -- разумный buffer. Для LLM с триллионами токенов pre-training данных 1000 примеров -- капля в море. Исследования показывают, что нужно минимум 6.2% от объема новых данных для предотвращения forgetting. При fine-tune на 100K примеров -- это 6.2K exemplars. Качество выборки важнее размера: используйте stratified sampling по задачам и difficulty.


Вопросы для собеседования

Вы fine-tune LLM на новом домене, и модель забывает общие знания. Как решить?

❌ "Увеличим learning rate, чтобы модель быстрее выучила новое" -- высокий LR усиливает forgetting, а не решает его.

✅ Сильный ответ: Несколько стратегий от простого к сложному: (1) Mixed training -- интерливинг 50% новых и 50% старых данных, дает ~0% forgetting. Если нет доступа к старым данным -- (2) Self-distillation: KL-divergence между output'ами новой и старой модели как регуляризация (L = L_task + lambda * D_KL(p_old || p_new)). (3) LoRA per task -- если задачи четко разделены. (4) Маленький learning rate (1e-5 вместо 1e-4) + cosine schedule -- уменьшает magnitude обновлений. Выбор зависит от: доступность старых данных, число задач, compute budget.

Сравните replay-based и regularization-based подходы к continual learning.

❌ "Replay лучше, потому что использует реальные данные" -- однобокий ответ без trade-offs.

✅ Сильный ответ: Replay-based (Experience Replay, Mixed Training): хранит и переигрывает старые примеры. Плюсы -- простота, высокое качество (5-10% forgetting при 10% replay, ~0% при 50%). Минусы -- storage (нужно хранить данные), privacy (нельзя хранить чувствительные данные), copyright. Regularization-based (EWC, SI): штрафует изменение важных весов. Плюсы -- не нужны старые данные, privacy-friendly. Минусы -- worse retention (20-30%), вычислительно дороже (Fisher matrix), хрупкость при большом числе задач. На практике: если есть доступ к данным -- replay (mixed training). Если нет -- self-distillation (не нужны данные, только старая модель) или LoRA isolation.

Как оценить степень catastrophic forgetting после обновления модели?

❌ "Проверим accuracy на тестовом наборе" -- какой именно? Недостаточно конкретно.

✅ Сильный ответ: Три метрики: (1) Forgetting Measure F_j = max(a_{i,i} - a_{i,j}) -- максимальная потеря по каждой старой задаче. (2) Backward Transfer BWT = (1/(n-1)) * sum(a_{i,n} - a_{i,i}) -- средняя деградация на старых задачах, negative BWT = forgetting. (3) Forward Transfer FWT -- улучшение на новых задачах благодаря старым. Практически: (a) держим held-out sets для каждой задачи; (b) после каждого апдейта оцениваем все задачи; © строим матрицу a_{i,j} (performance на задаче i после обучения задаче j); (d) alert при BWT < -5%.


Cross-references

См. также: прогресс-rlhf — RL-based continual learning (PPO regularization)


8. Formulas Summary

EWC Loss

\[\mathcal{L} = \mathcal{L}_{new} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2\]

Forgetting Measure

\[F_j = \max_{i < j} (a_{i,i} - a_{i,j})\]

Backward Transfer

\[\text{BWT} = \frac{1}{n-1} \sum_{i=1}^{n-1} (a_{i,n} - a_{i,i})\]

Mixed Training

\[\mathcal{L} = \alpha \cdot \mathcal{L}_{new} + (1-\alpha) \cdot \mathcal{L}_{old}\]

9. Sources & Further Reading

Papers (2025-2026)

  1. [Self-Distillation] "Self-Distillation Enables Continual Learning" (arXiv 2601.19897, Jan 2026)
  2. [Nested Learning] Google Research Blog (Nov 2025)
  3. [GCL] "Group-shared Continual Learning for Sparse LLMs" (Neurocomputing, 2026)
  4. [Survey] "Understanding Catastrophic Forgetting" (ResearchGate, 2025)
  5. [LLM CL] "Continual Learning of Large Language Models" (ACM, 2025)

Blogs

  • The Neuron: "Four Breakthroughs Reshaping LLM Learning" (Jan 2026)
  • Cameron Wolfe: "Continual Learning with RL for LLMs"