본문 바로가기
프로그래밍/Scikit-Learn

[Scikit-Learn] 30. Isolation Forest (feat. IsolationForest)

by 부자 꽁냥이 2023. 5. 21.

Isolation Forest는 이진 탐색 나무를 이용하여 데이터의 이상치 여부를 판단하게 해주는 알고리즘이다. Scikit-Learn(sklearn)에서는 IsolationForest 클래스를 이용하면 Isolation Forest 알고리즘을 수행할 수 있다. 이번 포스팅에서는 IsolationForest의 사용법을 알아본다.

 

Isolation Forest에 대한 개념은 아래 포스팅을 참고하기 바란다.

 

44. Isolation Forest에 대해서 알아보자.

 

44. Isolation Forest에 대해서 알아보자.

이번 포스팅에서는 모델 기반 이상치 탐지 방법인 Isolation Forest에 대해서 알아보고자 한다. - 목차 - 1. Isolation Forest이란 무엇인가? 2. Isolation Forest 알고리즘 3. 예제 3. 장단점 1. Isolation Forest이란

zephyrus1111.tistory.com


   IsolationForest

먼저 시뮬레이션용 데이터를 만들어준다. 원래 Isolation Forest는 label이 없어도 동작하는 알고리즘이지만 여기서는 이상치를 잘 잡아내는지 확인하기 위하여 label을 만들어준 것이다.

 

import matplotlib.pyplot as plt
plt.rcParams['axes.unicode_minus'] = False
import numpy as np

from sklearn.model_selection import train_test_split

## 데이터 생성
n_samples, n_outliers = 120, 40
rng = np.random.RandomState(0)
covariance = np.array([[0.5, -0.1], [0.7, 0.4]])
cluster_1 = 0.4 * rng.randn(n_samples, 2) @ covariance + np.array([2, 2])  
cluster_2 = 0.3 * rng.randn(n_samples, 2) + np.array([-2, -2])
outliers = rng.uniform(low=-4, high=4, size=(n_outliers, 2))

## X 데이터
X = np.concatenate([cluster_1, cluster_2, outliers]) 
## 라벨 1: 정상, -1: 이상치
labels = np.concatenate([np.ones((2 * n_samples), dtype=int), -np.ones((n_outliers), dtype=int)])

 

데이터를 시각화해 보자.

 

fig = plt.figure(figsize=(8,8))
fig.set_facecolor('white')
ax = fig.add_subplot()
ax.scatter(X[:, 0], X[:, 1], c=labels)
plt.show()

 

이제 IsolationForest의 사용방법을 알아보자. 굉장히 간단하다. 나무 개수 n_estimators, 샘플링 사이즈 max_samples를 설정하고 fit 메서드에 데이터를 넣어주면 끝난다.

 

from sklearn.ensemble import IsolationForest

i_forest = IsolationForest(
                n_estimators=100, ## 트리 개수
                max_samples=256 ## 샘플링 사이즈
            ).fit(X)

 

이제 예측 성능을 살펴보자. 예측은 predict 메서드를 이용하며 이상치 점수는 score_samples 메서드를 이용한다. 이때 score_samples는 음수 부호가 붙어서 나오므로 이를 양수로 바꿔줘야 한다.

 

print('정확도 :', np.mean(labels==i_forest.predict(X)))
print('이상치 점수 상위 5개: ', -np.sort(i_forest.score_samples(X))[:5])

 

 

정확도가 98.5%가 나왔고 상위 이상치 점수가 0.7보다 크다는 것을 알 수 있다.


댓글


맨 위로