R 예제 코드 - KNN / k-NN / k-Nearest Neighber / k-최근접 이웃
k-NN을 iris 데이터에 적용해서 Species를 분류하는 R 코드를 만들어 보자.
▼ k-NN 알고리즘에 대한 이론적인 설명이 궁금하다면? ▼
2017/03/14 - [Analysis/ALGORITHM] - KNN / k-NN / k-Nearest Neighber / k-최근접 이웃 알고리즘
1. 데이터 준비하기
iris 데이터를 Sepal.Length / Sepal.Width / Species 3가지 변수만 있는 데이터로 단순한 데이터 data 로 바꿔보자.
data <- iris[, c("Sepal.Length", "Sepal.Width", "Species")]
2. 데이터 나누기 (train set, validation set, test set)
data 를 train / valid / test 로 나누자.
# 재현성을 위한 seed 설정
set.seed(123)
# idx 설정
idx <- sample(x = c("train", "valid", "test"),
size = nrow(data),
replace = TRUE,
prob = c(3, 1, 1))
# idx에 따라 데이터 나누기
train <- data[idx == "train", ]
valid <- data[idx == "valid", ]
test <- data[idx == "test", ]
3. 데이터 분포 확인하기
산점도를 그려서 train, valid, test 데이터의 분포를 확인해보자.
# 색상 투명도 설정을 위한 패키지 설치 및 라이브러리 불러오기
install.packages("scales")
library(scales)
# 이제 alpha 함수를 사용할 수 있다!
# train 산점도 그리기
plot(formula = Sepal.Length ~ Sepal.Width,
data = train,
col = alpha(c("purple", "blue", "green"), 0.7)[train$Species],
main = "train - Classification Species")
# valid 표시하기
points(formula = Sepal.Length ~ Sepal.Width,
data = valid,
pch = 17,
cex = 1.2,
col = "red")
# test 표시하기
points(formula = Sepal.Length ~ Sepal.Width,
data = test,
pch = 15,
cex = 1.2,
col = "orange")
# 범례 그리기
legend("topright",
c(levels(data$Species), "valid", "test"),
pch = c(1, 1, 1, 17, 15),
col = c(alpha(c("purple", "blue", "green"), 0.7), "red", "orange"),
cex = 0.9)
4. x와 y로 나누기
train 을 train_x, train_y로,
valid를 valid_x, valid_y로,
test를 test_x, test_y로 나누자.
# x는 3번째 열을 제외한다는 의미로 -3
train_x <- train[, -3]
valid_x <- valid[, -3]
test_x <- test[, -3]
# y는 3번째 열만 필터링한다는 의미로 3
train_y <- train[, 3]
valid_y <- valid[, 3]
test_y <- test[, 3]
5. knn 적용하기
먼저 k가 1일 때를 해본다. 다수결이 동점인 경우 (ties break) 랜덤성이 있으므로 재현성을 위해 seed를 설정했다.
# knn 함수를 사용하기 위해 class 패키지를 설치하고 라이브러리 불러오기
install.packages("class")
library(class)
# k = 1 일 때
set.seed(1234)
knn_1 <- knn(train = train_x,
test = valid_x,
cl = train_y,
k = 1)
# train 산점도 그리기
plot(formula = Sepal.Length ~ Sepal.Width,
data = train,
col = alpha(c("purple", "blue", "green"), 0.7)[train$Species],
main = "KNN (k = 1)")
# knn valid 결과 표시하기
points(formula = Sepal.Length ~ Sepal.Width,
data = valid,
pch = 17,
cex = 1.2,
col = alpha(c("purple", "blue", "green"), 0.7)[knn_1])
# 범례 그리기
legend("topright",
c(paste("train", levels(train$Species)), paste("valid", levels(valid$Species))),
pch = c(rep(1, 3), rep(17, 3)),
col = c(rep(alpha(c("purple", "blue", "green"), 0.7), 2)),
cex = 0.9)
# 분류 정확도 계산하기
accuracy_1 <- sum(knn_1 == valid_y) / length(valid_y) ; accuracy_1
아래와 같이 산점도가 그려지는데, setosa는 잘 분류되는 반면 versicolor와 virginica는 분류 오류가 있음을 알 수 있다.
이번에는 k가 21일 때를 해본다.
# k = 21 일 때
set.seed(1234)
knn_21 <- knn(train = train_x,
test = valid_x,
cl = train_y,
k = 21)
plot(formula = Sepal.Length ~ Sepal.Width,
data = train,
col = alpha(c("purple", "blue", "green"), 0.7)[train$Species],
main = "KNN (k = 21)")
# knn valid 결과 표시하기
points(formula = Sepal.Length ~ Sepal.Width,
data = valid,
pch = 17,
cex = 1.2,
col = alpha(c("purple", "blue", "green"), 0.7)[knn_21])
# 범례 그리기
legend("topright",
c(paste("train", levels(train$Species)), paste("valid", levels(valid$Species))),
pch = c(rep(1, 3), rep(17, 3)),
col = c(rep(alpha(c("purple", "blue", "green"), 0.7), 2)),
cex = 0.9)
# 분류 정확도 계산하기
accuracy_21 <- sum(knn_21 == valid_y) / length(valid_y) ; accuracy_21
6. k가 1부터 train 행 수까지 변화할 때 분류 정확도 구하기
반복문 for 를 이용하여 k가 1부터 89 (train 행 수) 까지 변화할 때 분류 정확도가 몇 % 되는지 그래프를 그려보고 최적의 k를 확인해 보자.
# 분류 정확도 사전 할당
accuracy_k <- NULL
# kk가 1부터 train 행 수까지 증가할 때 (반복문)
for(kk in c(1:nrow(train_x))){
# k가 kk일 때 knn 적용하기
set.seed(1234)
knn_k <- knn(train = train_x,
test = valid_x,
cl = train_y,
k = kk)
# 분류 정확도 계산하기
accuracy_k <- c(accuracy_k, sum(knn_k == valid_y) / length(valid_y))
}
# k에 따른 분류 정확도 데이터 생성
valid_k <- data.frame(k = c(1:nrow(train_x)), accuracy = accuracy_k)
# k에 따른 분류 정확도 그래프 그리기
plot(formula = accuracy ~ k,
data = valid_k,
type = "o",
pch = 20,
main = "validation - optimal k")
# 그래프에 k 라벨링 하기
with(valid_k, text(accuracy ~ k, labels = rownames(valid_k), pos = 1, cex = 0.7))
# 분류 정확도가 가장 높으면서 가장 작은 k는?
min(valid_k[valid_k$accuracy %in% max(accuracy_k), "k"])
아래 그래프를 보면 알 수 있듯이 k = 21일 때 정확도가 가장 높으면서 k가 가장 작기 때문에 우리는 21-NN 모델을 선택하게 된다.
그럼 이제 21-NN 모델이 얼마나 분류가 잘 되는지 test 데이터를 이용해서 표현해보자.
7. 21-NN에 test 데이터 적용하기
k = 21인 knn 모델에 test 데이터를 넣어 분류 결과를 분석해보자.
# 21-NN에 test 데이터 적용하기
set.seed(1234)
knn_21_test <- knn(train = train_x,
test = test_x,
cl = train_y,
k = 21)
# Confusion Matrix 틀 만들기
result <- matrix(NA, nrow = 3, ncol = 3)
rownames(result) <- paste0("real_", levels(train_y))
colnames(result) <- paste0("clsf_", levels(train_y))
# Confusion Matrix 값 입력하기
result[1, 1] <- sum(ifelse(test_y == "setosa" & knn_21_test == "setosa", 1, 0))
result[2, 1] <- sum(ifelse(test_y == "versicolor" & knn_21_test == "setosa", 1, 0))
result[3, 1] <- sum(ifelse(test_y == "virginica" & knn_21_test == "setosa", 1, 0))
result[1, 2] <- sum(ifelse(test_y == "setosa" & knn_21_test == "versicolor", 1, 0))
result[2, 2] <- sum(ifelse(test_y == "versicolor" & knn_21_test == "versicolor", 1, 0))
result[3, 2] <- sum(ifelse(test_y == "virginica" & knn_21_test == "versicolor", 1, 0))
result[1, 3] <- sum(ifelse(test_y == "setosa" & knn_21_test == "virginica", 1, 0))
result[2, 3] <- sum(ifelse(test_y == "versicolor" & knn_21_test == "virginica", 1, 0))
result[3, 3] <- sum(ifelse(test_y == "virginica" & knn_21_test == "virginica", 1, 0))
# Confusion Matrix 출력하기
result
# 최종 정확도 계산하기
sum(knn_21_test == test_y) / sum(result)
최종적으로 아래와 같이 Confusion Matrix와 최종 정확도가 출력된다.
> result
clsf_setosa clsf_versicolor clsf_virginica
real_setosa 9 0 0
real_versicolor 1 8 4
real_virginica 0 2 8
> sum(knn_21_test == test_y) / length(test_y)
[1] 0.78125
해석해보자면, 실제 setosa인데 제대로 경우가 9건이고 다른 Species로 분류된 경우는 0건이다.
그리고 실제 versicolor인데 setosa로 분류된 경우가 1건, 제대로 분류된 경우가 8건, virginica로 분류된 경우가 4건이다.
또 실제 virginica인데 setosa로 분류된 경우는 0건, versicolor로 분류된 경우가 2건, 제대로 분류된 경우가 8건이다.
최종적으로 이 데이터에 대한 21-NN 모델의 분류 정확도는 78.1% 라고 결론 지을 수 있다.
'회사생활 > R' 카테고리의 다른 글
[R 예제 코드] Logistic Regression / 로지스틱 회귀분석 (19) | 2017.03.22 |
---|---|
한눈에 정리하는 ggplot2를 이용한 R 시각화 기초 1 (2) | 2017.03.17 |
Train vs. Validation vs. Test Data (0) | 2017.03.14 |
R 시각화 - 산점도 (Basic Scatter Plot) (0) | 2017.03.13 |
R Studio 옵션 설정하기 (Global Options) (0) | 2017.03.10 |