Machine Learning R

Tree Based Models in R

tutorial สอนสร้างและจูนโมเดล tree based ใน R พร้อมตัวอย่างโค้ด ใช้งานได้จริงสำหรับ decision tree และ random forest

วันนี้เราจะมาสอนสร้าง tree-based models ง่ายๆใน R

เรา assume ว่าเพื่อนๆรู้เรื่อง machine learning นิดหน่อย เช่น ทำไมต้อง train/ test split และทำไมต้องจูนค่า hyperparameter เป็นต้น ถ้าใครยังไม่รู้ว่า ML คืออะไร? ลองอ่านบทความแนะนำของ dataquest ได้ในลิ้งด้านล่าง

tutorial วันนี้แบ่งเป็น 5 พาร์ท ใช้เวลาอ่านและทำตามประมาณ 20 นาที

  1. Getting to Know R – Optional สำหรับเพื่อนๆที่ยังไม่เคยเขียน R มาก่อนเลย
  2. Prepare Data
  3. Decision Tree
  4. Random Forest
  5. Summary

Ready to fly? เปิด RStudio ขึ้นมาแล้วลอง copy โค้ดด้านล่างไปรันใน console ได้เลย 😛


Getting to Know R

5 นาที – อธิบายการเขียน R เบื้องต้น ถ้าใครเขียน R เป็นบ้างแล้ว สามารถ skip ไปที่หัวข้อ Prepare Data ได้เลย

  • line 3-6 ลองสร้าง vector ง่ายๆด้วยฟังชั่น c()
  • line 9-12 เราใช้ฟังชั่น data.frame() เพื่อสร้าง dataframe ซึ่งเป็นหัวใจสำคัญของการทำ data analysis และ machine learning ใน R
  • line 15-21 คือฟังชั่นที่เราใช้บ่อยๆกับ dataframe เช่น str, head, tail และ summary
# R crash course
# create vectors
customer_name <- c("David", "Mary", "John", "Jack", "Daniel")
customer_age <- c(30, 32, 28, 20, 31)
customer_gender <- c("M", "F", "M", "M", "M")
customer_purchased <- c(TRUE, FALSE, FALSE, TRUE, TRUE)
# create a dataframe
customer_dataframe <- data.frame(customer_name,
customer_age,
customer_gender,
customer_purchased)
# first look at dataframe
print(customer_dataframe)
# useful functions to work with dataframe
str(customer_dataframe)
head(customer_dataframe)
tail(customer_dataframe)
summary(customer_dataframe)
view raw .R hosted with ❤ by GitHub

จบพาร์ทแรก ตอนนี้เพื่อนๆสามารถเขียนโค้ด R ง่ายๆเพื่อสร้าง vector และ dataframe รวมถึงการใช้งานฟังชั่นเบื้องต้น session ต่อไป เราจะสอนโหลด dataset สำหรับ tutorial นี้


Prepare Data

5 นาที – โหลด Breast Cancer dataset เข้าสู่ RStudio จัดการกับ missing values และแบ่งข้อมูลเป็น 70% train และ 30% test

Breast Cancer เป็นข้อมูลสำหรับปัญหา binary classification มีทั้งหมด 11 columns 699 observations โดยตัวแปรที่เราต้องการ predict คือ Class {benign, malignant}

  • line 2-3 ดาวน์โหลดและติดตั้ง package mlbench ใน R ซึ่งเป็น package ที่รวบรวม datasets สำหรับงาน machine learning จากเว็บ UCI ML repository
  • line 6 ใช้ฟังชั่น data() เพื่อโหลด BreastCancer dataframe เข้าไปใน R
  • line 21-24 เป็น standard code ที่ใช้ split data เป็นสองส่วน – 70% train และ 30% test ถ้าใครอยากลองเปลี่ยน ratio สามารถเปลี่ยนได้ที่เลข 0.7 (เป็น 0.8 หรือ 0.6 ก็ได้)
# install new package
install.packages("mlbench")
library(mlbench)
# load data into R
data(BreastCancer)
# review dataset
str(BreastCancer)
head(BreastCancer)
summary(BreastCancer)
# remove ID column
BreastCancer$Id <- NULL
# remove missing values
BreastCancer <- BreastCancer[complete.cases(BreastCancer),]
# prepare dataset
# split data into 70% train and 30% test sets
set.seed(123)
idx <- sample(nrow(BreastCancer), 0.7*nrow(BreastCancer))
train_df <- BreastCancer[idx, ]
test_df <- BreastCancer[-idx, ]
view raw .R hosted with ❤ by GitHub

Good to know – ในชีวิตจริงเรานิยม split data เป็นสามส่วน {train, validate, test} หรือใช้ k-fold cross validation ในการ train model


Decision Tree

5 นาที – train decision tree และการจูนค่า complexity parameter เพื่อให้ได้ accuracy ที่สูงขึ้นและลดการ overfit ของโมเดล

หน้าตาของ decision tree เวลาเรา visualize จะออกมาแบบรูปด้านล่าง

โดยตัวแปรแรกที่ถูกใช้ split data ที่ root node (ด้านบนสุดของ tree) คือ Cell.size = 1,2 ถ้าตอบ yes จะวิ่งไปทางซ้าย ตอบ no จะไปทางขวา จนลงมาถึง terminal node (ด้านล่างสุดของ tree) ที่กระบวนการ split data หยุดตรงนี้ มาลองอ่าน diagram กัน

  • ถ้า case นี้มี Cell.size = 1,2 โมเดลจะ predict ว่า benign (เซลล์ดี)
  • ถ้า case นี้มี Cell.size > 2 และ Cell.shape > 2 โมเดลจะ predict ว่า malignant (เซลล์ร้าย)
decision tree in R
visualize decision tree using rpart.plot() function

สำหรับ package หลักที่เราใช้ build และ visualize tree ใน R คือ rpart และ rpart.plot

  • line 8 คือการเขียนโค้ดเพื่อสร้าง decision tree model ด้วยฟังชั่น rpart()
    • Class ~ . คือ formula ใน R อ่านว่า “ตัวแปร Class เป็นฟังชั่นของตัวแปร x ทั้งหมดใน train dataset”
  • line 11 เราเรียกดูค่า complexity parameter (เรียกสั้นๆว่า cp) ที่ทำให้ค่า xerror ต่ำที่สุด
    • cp สูงเกินไป – tree ของเราจะ underfit ทำให้ได้ accuracy ต่ำ ↔ xerror สูง
    • cp ต่ำเกินไป – tree ของเราจะ overfit ไม่สามารถนำโมเดลไปใช้กับ new data ได้
    • เราต้องเลือกค่า cp ที่เหมาะสมเพื่อให้ได้ optimal final model
  • line 19 เราใช้ฟังชั่น prune() เพื่อสร้าง decision tree ด้วยค่า cp ที่ดีที่สุดของเรา ในตัวอย่างเราเลือกใช้ cp = 0.01 สำหรับ final model
  • line 22 visualize final model ของเราด้วยฟังชั่น rpart.plot()
# install.packages("rpart")
# install.packages("rpart.plot")
library(rpart)
library(rpart.plot)
# train a decision tree
set.seed(123)
dt_model <- rpart(Class ~ ., data = train_df, method = "class")
# find the best cp hyperparameter
dt_model$cptable
# CP nsplit rel error xerror xstd
# 1 0.78787879 0 1.0000000 1.0000000 0.06299647
# 2 0.05454545 1 0.2121212 0.2242424 0.03540970
# 3 0.01000000 2 0.1575758 0.1696970 0.03111626
# prune our model for higher accuracy
dt_model_final <- prune(dt_model, cp = 0.01000000)
# plot model
rpart.plot(dt_model_final)
# prediction
p <- predict(dt_model_final, newdata=test_df, type="class")
table(test_df$Class, p)
# benign malignant
# benign 124 7
# malignant 4 70
view raw .R hosted with ❤ by GitHub

พอเรา prune จนได้ final model แล้ว เราจะใช้มันทำนาย test dataset ด้วยฟังชั่น predict() และสร้าง confusion matrix เพื่อวัด accuracy ของโมเดลด้วยฟังชั่น table() ใน line 25-26

เราสามารถคำนวณ accuracy จาก confusion matrix ด้านบน ด้วยการหาผลรวมเส้นทแยงมุม หารด้วยจำนวน testing cases ทั้งหมดที่เราทดสอบ (124 + 70) / (124 + 7 + 4 + 70) = 0.9463415

decision tree ที่เราสร้างขึ้นมาใช้ cp = 0.10 และได้ test accuracy = 94.63%


Random Forest

5 นาที – train random forest ด้วยฟังชั่น randomForest()

cats

concept ของ random forest คือการสร้าง decision tree หลายๆต้น โดยค่า default ของ number of trees – ntree ในฟังชั่น randomForest() จะอยู่ที่ 500 ต้น แล้วค่อยเอาผล predictions ของทั้ง 500 ต้นมาโหวตกันว่า new case นั้นจะเป็น benign หรือ malignant

เช่น จากทั้งหมด 500 ต้น – 400 ต้นทำนายว่า benign และอีก 100 ต้นทำนายว่า malignant เราจะยึดผลโหวตส่วนใหญ่ final prediction เท่ากับ benign

random forest เป็นโมเดลประเภท Ensemble Learning ที่เกิดจากการสร้างและรวมหลายๆโมเดลเข้าด้วยกันเพื่อให้ได้ model performance ที่ดีขึ้น and it works !!

# install.packages("randomForest")
library(randomForest)
# build random forest model
set.seed(123)
rf_model <- randomForest(Class ~ ., data = train_df)
# print model
print(rf_model)
# predict test data and compute accuracy
p <- predict(rf_model, newdata = test_df)
table(test_df$Class, p)
# p
# benign malignant
# benign 127 4
# malignant 1 73
view raw .R hosted with ❤ by GitHub

วิธีการคำนวณ accuracy จาก confusion matrix จะเหมือนกับของ decision tree เลย (127 + 73) / (127 + 4 + 1 + 73) = 0.9756098

random forest ที่เราสร้างขึ้นมาใช้ ntree = 500 และได้ test accuracy = 97.56% สูงกว่า decision tree ประมาณ 3% แต่แลกมากับเวลาในการ train ที่นานขึ้น (จะเห็นความแตกต่างเรื่องเวลาชัดมาก ถ้า dataset เราใหญ่)


Summary

ใน tutorial นี้ เราเรียน concept ของ tree-based models เบื้องต้นใน R ซึ่งสองตัวที่เราใช้กันเยอะมากในหลายๆ applications คือ decision tree และ random forest

  • decision tree – train เร็ว accuracy ปานกลาง ค่อนข้าง overfit ถ้าเราไม่ prune มันก่อน แต่ข้อดีคืออธิบายง่ายมาก
  • random forest – train ช้า accuracy สูง แต่อธิบายยากเพราะกระบวนการสร้าง trees 500 ต้นเกิดขึ้นแบบ random อย่างที่ชื่อ algorithm implies

decision tree ปกติจะเป็นโมเดลแรกๆที่เราลองสร้าง (baseline model) แล้วค่อยพยายาม improve performance ด้วยการจูน hyperparameter หรือเปลี่ยนไปใช้ tree-based แบบอื่นๆอย่าง random forest (bagging algorithm) หรือ xgboost (boosting algorithm)

อยากเรียน algorithm อะไรอีก? คอมเม้นบอกเราใต้บล๊อกวันนี้ได้เลย 😎

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Google photo

You are commenting using your Google account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s

This site uses Akismet to reduce spam. Learn how your comment data is processed.