오늘은 iris data를 활용하여, 분류를 해보겠다.
iris data는 scikit-learn에서 가져올 수 있다.
1
2
3
4
5
6
7
8
|
import sklearn
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import pandas as pd
iris = load_iris()
print(type(iris))
|
cs |
다음처럼 Dictionary 형태로 되어 있어서, 아래쪽 스크롤하면 분류 결과도 있다.
data는 numpy array로 되어 있다.
1
2
3
4
|
iris_data = iris.data #x data set
iris_label = iris.target
print("iris target 명 :", iris.target_names)
print('iris target값 :', iris_label)
|
cs |
1
2
3
4
5
6
7
|
x_train, x_test, y_train, y_test = train_test_split(iris_data, iris_label, test_size = 0.2, random_state = 1)
print("DecisionTreeClassifier 생성")
dc_tree = DecisionTreeClassifier(random_state=1)
dc_tree.fit(x_train, y_train) # 학습 수행
pred = dc_tree.predict(x_test)
print("예측값",pred)
print("실제값",y_test)
|
cs |
결과 값
DecisionTreeClassifier 생성
예측값 [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 2 0 2 1 0 0 1 2]
실제값 [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2]
한개 빼고는 다 맞췄는 것을 볼 수 있다.
예측 정확도를 뽑아 보자
정확도는 맞춘 개수 / 전체 개수를 하면 확인할 수 있다.
전체 개수가 30개이기 때문에 29/30일 하면 된다.
1
2
|
from sklearn.metrics import accuracy_score
print("예측 정확도 :", accuracy_score(y_test, pred))
|
cs |
예측 정확도 일치 한다.
'AI기초' 카테고리의 다른 글
[데이터 전처리]데이터 인코딩(레이블 /원핫 인코딩 ) (1) | 2024.01.07 |
---|