본문 바로가기
AI기초

sklearn 라이브러리 설치, iris DesicionTree Classification 실습

by AI독학 2024. 1. 6.

 

오늘은 iris data를 활용하여, 분류를 해보겠다. 

iris data는 scikit-learn에서 가져올 수 있다. 

 

1
2
3
4
5
6
7
8
import sklearn
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import pandas as pd 
 
iris = load_iris()
print(type(iris))
cs

 

 

다음처럼 Dictionary 형태로 되어 있어서, 아래쪽 스크롤하면 분류 결과도 있다. 

data는 numpy array로 되어 있다. 

 

1
2
3
4
iris_data = iris.data #x data set
iris_label = iris.target
print("iris target 명 :", iris.target_names)
print('iris target값 :', iris_label)
cs

 

 

 

1
2
3
4
5
6
7
x_train, x_test, y_train, y_test = train_test_split(iris_data, iris_label, test_size = 0.2, random_state = 1)
print("DecisionTreeClassifier 생성")
dc_tree = DecisionTreeClassifier(random_state=1)
dc_tree.fit(x_train, y_train) # 학습 수행
pred = dc_tree.predict(x_test)
print("예측값",pred)
print("실제값",y_test)
cs

 

결과 값 

 

DecisionTreeClassifier 생성
예측값 [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 2 0 2 1 0 0 1 2]
실제값 [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2]

 

한개 빼고는 다 맞췄는 것을 볼 수 있다. 

예측 정확도를 뽑아 보자 

정확도는 맞춘 개수 / 전체 개수를 하면 확인할 수 있다. 

전체 개수가 30개이기 때문에 29/30일 하면 된다. 

1
2
from sklearn.metrics import accuracy_score
print("예측 정확도 :", accuracy_score(y_test, pred))
cs

 

예측 정확도 일치 한다.