Перейти к содержанию

Шпаргалка: выбор модели

~7 минут чтения

Предварительно: Метрики | sklearn

Выбор модели -- это не "какой алгоритм модный", а decision framework по свойствам данных и задачи. Два главных фактора: (1) тип данных (табличные, текст, изображения, временные ряды) и (2) размер датасета. На табличных данных gradient boosting (XGBoost/LightGBM) почти всегда лучше нейросетей, на изображениях -- наоборот.

Дерево выбора модели

graph TD
    START["Какую модель выбрать?"] --> TASK{"Тип задачи?"}

    TASK --> CLS["Классификация"]
    TASK --> REG["Регрессия"]
    TASK --> CLUST["Кластеризация"]
    TASK --> DIM["Снижение размерности"]
    TASK --> ANOM["Детекция аномалий"]

    CLS --> BIN{"Бинарная?"}
    BIN -->|"< 1000 данных"| SVM_KNN["SVM, KNN"]
    BIN -->|"Интерпретируемость"| LOGREG["LogReg, DecisionTree"]
    BIN -->|"Макс. качество"| BOOST["XGBoost, LightGBM"]
    CLS --> MULTI{"Многоклассовая?"}
    MULTI -->|"< 10 классов"| SAME["Те же что для бинарной"]
    MULTI -->|"100+ классов"| NN["Neural Networks"]

    REG --> LIN{"Линейная?"}
    LIN -->|"Да"| LINREG["Ridge, Lasso"]
    LIN -->|"Нет"| NLIN["RF, XGBoost, NN"]
    REG --> OUTLIERS["Выбросы: Huber, Robust"]

    CLUST --> KNOW{"Знаем K?"}
    KNOW -->|"Да"| KMEANS["KMeans, GMM"]
    KNOW -->|"Нет"| DBSCAN["DBSCAN, HDBSCAN"]

    DIM --> LINPCA{"Линейное?"}
    LINPCA -->|"Да"| PCA["PCA, TruncatedSVD"]
    LINPCA -->|"Визуализация"| TSNE["t-SNE, UMAP"]

    ANOM --> ISOLFOR["Isolation Forest, One-Class SVM"]

    style START fill:#e8eaf6,stroke:#3f51b5
    style CLS fill:#e8f5e9,stroke:#4caf50
    style REG fill:#fff3e0,stroke:#ef6c00
    style CLUST fill:#f3e5f5,stroke:#9c27b0
    style DIM fill:#e8eaf6,stroke:#3f51b5
    style ANOM fill:#fce4ec,stroke:#c62828
    style BOOST fill:#e8f5e9,stroke:#4caf50

Classification

Модель Плюсы Минусы Когда использовать
LogisticRegression Быстрая, интерпретируемая, вероятности Только линейные границы Baseline, линейно разделимые данные
DecisionTree Интерпретируемая, feature importance Легко переобучается Когда нужны объяснения
RandomForest Устойчив к переобучению, параллелизм Медленный на больших данных Универсальный выбор
XGBoost/LightGBM Лучшее качество на табличных данных Много гиперпараметров Kaggle, продакшен
SVM Хорош для малых данных, kernel trick Медленный на больших данных Малые/средние данные, нелинейные границы
KNN Простой, нет обучения Медленный inference, curse of dimensionality Малые данные, как baseline
NaiveBayes Очень быстрый, хорош для текста Предположение независимости Текст, спам-фильтры
Neural Networks Любые зависимости Много данных и compute Изображения, текст, большие данные

Regression

Модель Плюсы Минусы Когда использовать
LinearRegression Быстрая, интерпретируемая Только линейные зависимости Baseline, простые зависимости
Ridge Устойчива к мультиколлинеарности Не делает feature selection Много коррелированных признаков
Lasso Feature selection, sparse Нестабильна при коррелированных признаках Когда нужно отобрать признаки
ElasticNet Баланс Ridge + Lasso Два гиперпараметра Много признаков, часть избыточных
RandomForest Нелинейные зависимости Не экстраполирует Нелинейные данные
XGBoost/LightGBM Лучшее качество Не экстраполирует Табличные данные
SVR Kernel trick для нелинейности Медленный, много параметров Малые данные

По размеру данных

Размер данных Рекомендуемые модели Советы
< 1,000 (мало) SVM (RBF kernel), KNN, LogReg/Ridge с регуляризацией, деревья (max_depth=3-5) CV (k=5-10), сильная регуляризация, feature engineering важнее модели, аугментация если возможно
1K -- 100K (средние) Random Forest, XGBoost / LightGBM / CatBoost, SVM, небольшие NN Ensemble методы работают хорошо, можно подбирать гиперпараметры, CV обязательна
> 100K (большие) LightGBM (быстрее XGBoost), Neural Networks, SGDClassifier (online), Spark MLlib (distributed) Holdout вместо CV, сэмплирование для подбора гиперпараметров, feature selection для ускорения

По типу признаков

Тип признаков Модели Нюансы
Числовые Любые модели Нормализация обязательна для SVM, KNN, NN. Log-transform для скошенных распределений
Категориальные CatBoost (нативно), LightGBM (categorical_feature), RF (после encoding) One-hot для малого числа категорий, Target encoding для большого, Label encoding для деревьев
Текст TF-IDF + LogReg/SVM, BoW + NaiveBayes, BERT/RoBERTa (лучшее качество) Deep Learning: BERT > CNN > LSTM/GRU. Традиционные: TF-IDF + SVM -- сильный baseline
Изображения CNN (ResNet, EfficientNet, ConvNeXt), ViT Только Deep Learning. Transfer learning с ImageNet весами обязателен
Временные ряды ARIMA/SARIMA, Prophet, XGBoost + lag features, LSTM/GRU, Temporal Fusion Transformer Традиционные хороши для одномерных; ML -- для многомерных с экзогенными переменными

По требованиям

Требование Рекомендуемые модели Избегать
Интерпретируемость LogReg (коэффициенты), DecisionTree (правила), Linear + SHAP. Для любой модели: SHAP, LIME, Permutation importance Глубокие NN, сложные ансамбли без explainability
Скорость inference LogReg, DecisionTree (неглубокое), Quantized NN, Distilled models KNN, SVM (RBF) на больших данных, Large Ensembles
Калиброванные вероятности LogReg (изначально), CalibratedClassifierCV, NN + Temperature Scaling RF, XGBoost, SVM -- некалиброваны, нужен CalibratedClassifierCV
Устойчивость к выбросам Tree-based (разбиение по порогам), RobustScaler + модель, Huber loss Linear Regression, KNN, PCA -- чувствительны к выбросам

Baseline стратегия

Шаг Classification Regression
1. Baseline (начни просто) LogisticRegression Ridge
2. Усложнение RandomForest -> XGBoost RandomForest -> XGBoost
3. Fine-tune Подбор гиперпараметров лучшей модели Подбор гиперпараметров лучшей модели
4. Ensemble (опционально) Stacking / averaging лучших Stacking / averaging лучших

Сравнительная таблица

Критерий LogReg RF XGBoost SVM NN
Скорость обучения ★★★★★ ★★★☆☆ ★★☆☆☆ ★★☆☆☆ ★☆☆☆☆
Скорость inference ★★★★★ ★★★☆☆ ★★★☆☆ ★★☆☆☆ ★★★★☆
Качество на таблицах ★★★☆☆ ★★★★☆ ★★★★★ ★★★☆☆ ★★★☆☆
Качество на изображениях ★☆☆☆☆ ★☆☆☆☆ ★☆☆☆☆ ★★☆☆☆ ★★★★★
Интерпретируемость ★★★★★ ★★★☆☆ ★★☆☆☆ ★★☆☆☆ ★☆☆☆☆
Малые данные ★★★★☆ ★★★☆☆ ★★☆☆☆ ★★★★★ ★☆☆☆☆
Большие данные ★★★★★ ★★★☆☆ ★★★★☆ ★☆☆☆☆ ★★★★★
Категории ★★☆☆☆ ★★★★☆ ★★★★★ ★★☆☆☆ ★★★☆☆

Типичные ошибки

Нейросети на табличных данных

Начинать с нейросетей на табличных данных -- ошибка. XGBoost/LightGBM почти всегда лучше на структурированных данных. Исключения: очень большие данные (миллионы строк) или специфичные задачи (entity embeddings). Benchmark: TabNet бьет NN на табличных, но проигрывает XGBoost.

Accuracy на несбалансированных данных

При 95% негативных примеров модель-константа ("всегда нет") дает accuracy 95%. Используй F1-score (баланс precision/recall), ROC-AUC (порог-инвариантно), или PR-AUC (при сильном дисбалансе). На собеседовании: "какую метрику выбрать?" -- это проверка на понимание дисбаланса.

Подбор гиперпараметров на тесте

Если подбираешь гиперпараметры, глядя на тестовый скор -- это data leakage. Тест "видел" модель через гиперпараметры. Используй train/val/test split или nested cross-validation: внешний loop для оценки, внутренний для подбора.

Другие частые ошибки:

  • Забыть нормализацию для SVM/KNN -- эти модели чувствительны к масштабу признаков
  • Забыть про feature engineering -- хорошие фичи важнее сложной модели
  • LinearRegression без проверки допущений (линейность, гомоскедастичность, нормальность остатков)
  • Переусложнение: простая модель с хорошими фичами > сложная модель с сырыми данными

Вопросы для собеседования

Табличные данные, 50K строк, 200 признаков, задача бинарной классификации. Какую модель выберете и почему?

❌ «Нейросеть, потому что много признаков» -- на табличных данных NN почти всегда проигрывает boosting.

✅ Baseline: LogisticRegression -- быстрый, интерпретируемый, покажет linear separability. Основная модель: XGBoost или LightGBM -- лучшее качество на табличных данных (подтверждено бенчмарками: Grinsztajn et al. 2022, "Why do tree-based models still outperform deep learning on tabular data?"). Порядок: (1) LogReg baseline, (2) RF для sanity check, (3) XGBoost/LightGBM с подбором гиперпараметров. 200 признаков -- возможно нужен feature selection (Lasso, permutation importance). 50K строк -- достаточно для cross-validation (5-fold stratified).

Когда нейросети лучше gradient boosting? Приведите 3 примера.

❌ «Нейросети всегда лучше если достаточно данных» -- не верно для табличных данных.

✅ (1) Изображения -- CNN/ViT используют spatial structure (convolutions, patch embeddings), деревья не могут. (2) Текст -- BERT и трансформеры используют contextual embeddings и attention, TF-IDF + XGBoost проигрывает на сложном NLU. (3) Multimodal -- когда нужно объединить текст + изображения + таблицы в одной модели (нейросеть может иметь разные ветки). Дополнительно: sequence-to-sequence задачи (перевод, суммаризация), reinforcement learning. На табличных данных нейросети проигрывают из-за: inductive bias деревьев лучше подходит для нерегулярных функций, деревья не требуют feature scaling, boosting менее чувствителен к гиперпараметрам.

Клиент требует объяснимость модели. Как совместить качество и интерпретируемость?

❌ «Используйте DecisionTree -- он интерпретируемый» -- жертвует качеством без необходимости.

✅ Подход зависит от уровня объяснимости: (1) Глобальная (какие факторы важны): обучить XGBoost для качества + SHAP summary plot для объяснения feature importance. (2) Локальная (почему конкретное решение): SHAP waterfall / LIME для объяснения каждого prediction. (3) Regulatory (полная прозрачность): EBM (Explainable Boosting Machine от InterpretML) -- качество ~XGBoost, полная интерпретируемость. Или Lasso + feature engineering. Антипаттерн: учить простую модель "для объяснения" и сложную "для решений" -- модели могут давать разные ответы.

Самопроверка

  1. Выбор модели: Задача -- предсказание оттока клиентов. 10K строк, 30 признаков (числовые + категориальные), дисбаланс 5% churned. Выберите модель, обоснуйте, назовите 3 ключевых гиперпараметра для тюнинга.

  2. Обоснование baseline: Почему baseline для классификации -- LogisticRegression, а не Random Forest? Назовите 3 причины почему простой baseline критически важен.

  3. Trade-off: Вам предлагают выбор: (A) XGBoost с ROC-AUC 0.92, но inference 50ms; (B) LogReg с ROC-AUC 0.87, но inference 1ms. Система обрабатывает 10K запросов/сек. Какую модель выберете и при каких условиях измените решение?


Источники

  1. Grinsztajn et al. -- "Why do tree-based models still outperform deep learning on tabular data?" (NeurIPS 2022)
  2. scikit-learn -- Choosing the right estimator
  3. Nori et al. -- "InterpretML: A Unified Framework for Machine Learning Interpretability" (arXiv:1909.09223)
  4. Fernandez-Delgado et al. -- "Do we Need Hundreds of Classifiers?" (JMLR 2014) -- сравнение 179 классификаторов на 121 датасете

See Also