본문 바로가기
Machine Learning/Model

7. CatBoost

by 베짱이28호 2024. 10. 2.

7. CatBoost


1. CatBoost

CatBoost는 Yandex에서 개발한 그래디언트 부스팅 라이브러리로, 범주형 변수를 효과적으로 처리하고 높은 성능을 제공하는 알고리즘이다.

 

장점

  • 범주형 변수의 자동 처리
  • 과적합에 강한 내성
  • 기본 파라미터로도 좋은 성능
  • GPU 학습 지원
  • 결측치 자동 처리

단점

  • 다른 부스팅 알고리즘에 비해 학습 속도가 느림
  • 메모리 사용량이 큼
  • 대규모 데이터셋에서 시간이 많이 소요

2. CatBoost의 주요 특징

  • 순서형 부스팅(Ordered Boosting): 데이터 순서에 따른 편향을 줄이는 알고리즘
  • 대칭 트리(Symmetric Trees): 같은 분할 조건을 가진 균형 잡힌 트리 생성
  • 범주형 변수 인코딩: 타겟 통계량 기반의 자동 인코딩
  • GPU 지원: 병렬 학습을 통한 성능 향상

3. 코드 실습

import numpy as np
import pandas as pd
from catboost import CatBoostClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# 데이터 생성
X, y = make_classification(n_samples=1000, n_features=20,
                         n_informative=15, n_redundant=5,
                         random_state=42)

# 학습/테스트 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.2,
                                                    random_state=42)

# CatBoost 모델 생성
cat_model = CatBoostClassifier(
    iterations=100,
    learning_rate=0.1,
    depth=5,
    l2_leaf_reg=3,
    random_seed=42,
    verbose=False
)

# 모델 학습
cat_model.fit(X_train, y_train)

# 특성 중요도 시각화
plt.figure(figsize=(10, 6))
feature_importance = cat_model.feature_importances_
feature_names = [f'feature_{i}' for i in range(X.shape)]
plt.bar(feature_names, feature_importance)
plt.xticks(rotation=45)
plt.title('CatBoost Feature Importance')
plt.tight_layout()
plt.show()

# 예측 및 성능 평가
from sklearn.metrics import accuracy_score, classification_report

y_pred = cat_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"모델 정확도: {accuracy:.4f}")
print("\n분류 보고서:")
print(classification_report(y_test, y_pred))

'Machine Learning > Model' 카테고리의 다른 글

6. LightGBM  (0) 2024.04.16
5. XGBoost  (0) 2024.02.07
4. Random Forest  (0) 2024.01.29
3. SVM  (0) 2024.01.09
2. Logistic Regression  (0) 2023.12.25
1. K-NN (K-Nearest Neighbor) 알고리즘  (0) 2023.12.18

댓글