Machine Learning R

สร้างโมเดล Tree Based ง่ายๆด้วย R

tutorial สอนสร้างและจูนโมเดล tree based ใน R พร้อมตัวอย่างโค้ด ใช้งานได้จริงสำหรับ decision tree และ random forest

วันนี้เราจะมาสอนสร้าง tree-based models ง่ายๆใน R

เรา assume ว่าเพื่อนๆรู้เรื่อง machine learning นิดหน่อย เช่น ทำไมต้อง train/ test split และทำไมต้องจูนค่า hyperparameter เป็นต้น ถ้าใครยังไม่รู้ว่า ML คืออะไร? ลองอ่านบทความแนะนำของ dataquest ได้ในลิ้งด้านล่าง

tutorial วันนี้แบ่งเป็น 5 พาร์ท ใช้เวลาอ่านและทำตามประมาณ 20 นาที

  1. Getting to Know R – Optional สำหรับเพื่อนๆที่ยังไม่เคยเขียน R มาก่อนเลย
  2. Prepare Data
  3. Decision Tree
  4. Random Forest
  5. Summary

Ready to fly? เปิด RStudio ขึ้นมาแล้วลอง copy โค้ดด้านล่างไปรันใน console ได้เลย 😛


Getting to Know R

5 นาที – อธิบายการเขียน R เบื้องต้น ถ้าใครเขียน R เป็นบ้างแล้ว สามารถ skip ไปที่หัวข้อ Prepare Data ได้เลย

  • line 3-6 ลองสร้าง vector ง่ายๆด้วยฟังชั่น c()
  • line 9-12 เราใช้ฟังชั่น data.frame() เพื่อสร้าง dataframe ซึ่งเป็นหัวใจสำคัญของการทำ data analysis และ machine learning ใน R
  • line 15-21 คือฟังชั่นที่เราใช้บ่อยๆกับ dataframe เช่น str, head, tail และ summary

# R crash course
# create vectors
customer_name <- c("David", "Mary", "John", "Jack", "Daniel")
customer_age <- c(30, 32, 28, 20, 31)
customer_gender <- c("M", "F", "M", "M", "M")
customer_purchased <- c(TRUE, FALSE, FALSE, TRUE, TRUE)
# create a dataframe
customer_dataframe <- data.frame(customer_name,
customer_age,
customer_gender,
customer_purchased)
# first look at dataframe
print(customer_dataframe)
# useful functions to work with dataframe
str(customer_dataframe)
head(customer_dataframe)
tail(customer_dataframe)
summary(customer_dataframe)
view raw .R hosted with ❤ by GitHub

จบพาร์ทแรก ตอนนี้เพื่อนๆสามารถเขียนโค้ด R ง่ายๆเพื่อสร้าง vector และ dataframe รวมถึงการใช้งานฟังชั่นเบื้องต้น session ต่อไป เราจะสอนโหลด dataset สำหรับ tutorial นี้


Prepare Data

5 นาที – โหลด Breast Cancer dataset เข้าสู่ RStudio จัดการกับ missing values และแบ่งข้อมูลเป็น 70% train และ 30% test

Breast Cancer เป็นข้อมูลสำหรับปัญหา binary classification มีทั้งหมด 11 columns 699 observations โดยตัวแปรที่เราต้องการ predict คือ Class {benign, malignant}

  • line 2-3 ดาวน์โหลดและติดตั้ง package mlbench ใน R ซึ่งเป็น package ที่รวบรวม datasets สำหรับงาน machine learning จากเว็บ UCI ML repository
  • line 6 ใช้ฟังชั่น data() เพื่อโหลด BreastCancer dataframe เข้าไปใน R
  • line 21-24 เป็น standard code ที่ใช้ split data เป็นสองส่วน – 70% train และ 30% test ถ้าใครอยากลองเปลี่ยน ratio สามารถเปลี่ยนได้ที่เลข 0.7 (เป็น 0.8 หรือ 0.6 ก็ได้)

# install new package
install.packages("mlbench")
library(mlbench)
# load data into R
data(BreastCancer)
# review dataset
str(BreastCancer)
head(BreastCancer)
summary(BreastCancer)
# remove ID column
BreastCancer$Id <- NULL
# remove missing values
BreastCancer <- BreastCancer[complete.cases(BreastCancer),]
# prepare dataset
# split data into 70% train and 30% test sets
set.seed(123)
idx <- sample(nrow(BreastCancer), 0.7*nrow(BreastCancer))
train_df <- BreastCancer[idx, ]
test_df <- BreastCancer[-idx, ]
view raw .R hosted with ❤ by GitHub

Good to know – ในชีวิตจริงเรานิยม split data เป็นสามส่วน {train, validate, test} หรือใช้ k-fold cross validation ในการ train model


Decision Tree

5 นาที – train decision tree และการจูนค่า complexity parameter เพื่อให้ได้ accuracy ที่สูงขึ้นและลดการ overfit ของโมเดล

หน้าตาของ decision tree เวลาเรา visualize จะออกมาแบบรูปด้านล่าง

โดยตัวแปรแรกที่ถูกใช้ split data ที่ root node (ด้านบนสุดของ tree) คือ Cell.size = 1,2 ถ้าตอบ yes จะวิ่งไปทางซ้าย ตอบ no จะไปทางขวา จนลงมาถึง terminal node (ด้านล่างสุดของ tree) ที่กระบวนการ split data หยุดตรงนี้ มาลองอ่าน diagram กัน

  • ถ้า case นี้มี Cell.size = 1,2 โมเดลจะ predict ว่า benign (เซลล์ดี)
  • ถ้า case นี้มี Cell.size > 2 และ Cell.shape > 2 โมเดลจะ predict ว่า malignant (เซลล์ร้าย)
decision tree in R
visualize decision tree using rpart.plot() function

สำหรับ package หลักที่เราใช้ build และ visualize tree ใน R คือ rpart และ rpart.plot

  • line 8 คือการเขียนโค้ดเพื่อสร้าง decision tree model ด้วยฟังชั่น rpart()
    • Class ~ . คือ formula ใน R อ่านว่า “ตัวแปร Class เป็นฟังชั่นของตัวแปร x ทั้งหมดใน train dataset”
  • line 11 เราเรียกดูค่า complexity parameter (เรียกสั้นๆว่า cp) ที่ทำให้ค่า xerror ต่ำที่สุด
    • cp สูงเกินไป – tree ของเราจะ underfit ทำให้ได้ accuracy ต่ำ ↔ xerror สูง
    • cp ต่ำเกินไป – tree ของเราจะ overfit ไม่สามารถนำโมเดลไปใช้กับ new data ได้
    • เราต้องเลือกค่า cp ที่เหมาะสมเพื่อให้ได้ optimal final model
  • line 19 เราใช้ฟังชั่น prune() เพื่อสร้าง decision tree ด้วยค่า cp ที่ดีที่สุดของเรา ในตัวอย่างเราเลือกใช้ cp = 0.01 สำหรับ final model
  • line 22 visualize final model ของเราด้วยฟังชั่น rpart.plot()

# install.packages("rpart")
# install.packages("rpart.plot")
library(rpart)
library(rpart.plot)
# train a decision tree
set.seed(123)
dt_model <- rpart(Class ~ ., data = train_df, method = "class")
# find the best cp hyperparameter
dt_model$cptable
# CP nsplit rel error xerror xstd
# 1 0.78787879 0 1.0000000 1.0000000 0.06299647
# 2 0.05454545 1 0.2121212 0.2242424 0.03540970
# 3 0.01000000 2 0.1575758 0.1696970 0.03111626
# prune our model for higher accuracy
dt_model_final <- prune(dt_model, cp = 0.01000000)
# plot model
rpart.plot(dt_model_final)
# prediction
p <- predict(dt_model_final, newdata=test_df, type="class")
table(test_df$Class, p)
# benign malignant
# benign 124 7
# malignant 4 70
view raw .R hosted with ❤ by GitHub

พอเรา prune จนได้ final model แล้ว เราจะใช้มันทำนาย test dataset ด้วยฟังชั่น predict() และสร้าง confusion matrix เพื่อวัด accuracy ของโมเดลด้วยฟังชั่น table() ใน line 25-26

เราสามารถคำนวณ accuracy จาก confusion matrix ด้านบน ด้วยการหาผลรวมเส้นทแยงมุม หารด้วยจำนวน testing cases ทั้งหมดที่เราทดสอบ (124 + 70) / (124 + 7 + 4 + 70) = 0.9463415

decision tree ที่เราสร้างขึ้นมาใช้ cp = 0.10 และได้ test accuracy = 94.63%


Random Forest

5 นาที – train random forest ด้วยฟังชั่น randomForest()

cats

concept ของ random forest คือการสร้าง decision tree หลายๆต้น โดยค่า default ของ number of trees – ntree ในฟังชั่น randomForest() จะอยู่ที่ 500 ต้น แล้วค่อยเอาผล predictions ของทั้ง 500 ต้นมาโหวตกันว่า new case นั้นจะเป็น benign หรือ malignant

เช่น จากทั้งหมด 500 ต้น – 400 ต้นทำนายว่า benign และอีก 100 ต้นทำนายว่า malignant เราจะยึดผลโหวตส่วนใหญ่ final prediction เท่ากับ benign

random forest เป็นโมเดลประเภท Ensemble Learning ที่เกิดจากการสร้างและรวมหลายๆโมเดลเข้าด้วยกันเพื่อให้ได้ model performance ที่ดีขึ้น and it works !!

# install.packages("randomForest")
library(randomForest)
# build random forest model
set.seed(123)
rf_model <- randomForest(Class ~ ., data = train_df)
# print model
print(rf_model)
# predict test data and compute accuracy
p <- predict(rf_model, newdata = test_df)
table(test_df$Class, p)
# p
# benign malignant
# benign 127 4
# malignant 1 73
view raw .R hosted with ❤ by GitHub

วิธีการคำนวณ accuracy จาก confusion matrix จะเหมือนกับของ decision tree เลย (127 + 73) / (127 + 4 + 1 + 73) = 0.9756098

random forest ที่เราสร้างขึ้นมาใช้ ntree = 500 และได้ test accuracy = 97.56% สูงกว่า decision tree ประมาณ 3% แต่แลกมากับเวลาในการ train ที่นานขึ้น (จะเห็นความแตกต่างเรื่องเวลาชัดมาก ถ้า dataset เราใหญ่)


Summary

ใน tutorial นี้ เราเรียน concept ของ tree-based models เบื้องต้นใน R ซึ่งสองตัวที่เราใช้กันเยอะมากในหลายๆ applications คือ decision tree และ random forest

  • decision tree – train เร็ว accuracy ปานกลาง ค่อนข้าง overfit ถ้าเราไม่ prune มันก่อน แต่ข้อดีคืออธิบายง่ายมาก
  • random forest – train ช้า accuracy สูง แต่อธิบายยากเพราะกระบวนการสร้าง trees 500 ต้นเกิดขึ้นแบบ random อย่างที่ชื่อ algorithm implies

decision tree ปกติจะเป็นโมเดลแรกๆที่เราลองสร้าง (baseline model) แล้วค่อยพยายาม improve performance ด้วยการจูน hyperparameter หรือเปลี่ยนไปใช้ tree-based แบบอื่นๆอย่าง random forest (bagging algorithm) หรือ xgboost (boosting algorithm)

อยากเรียน algorithm อะไรอีก? คอมเม้นบอกเราใต้บล๊อกวันนี้ได้เลย 😎

10 comments

  1. rpart.plot

    เกิดข้อผิดพลาดขึ้น แจ้งว่าในไลบรารี ไม่มีแพ็คเกจที่เรียกว่า “rpart.plot” ครับ
    ผมต้องทำอย่างไร? หรือว่าไม่จำเป็นต้องใช้ครับ ตามตัวอย่าง

  2. ในขั้นตอนสุดท้าย อยากให้ plot tree ของ random forest เพิ่มด้วยอ่ะครับ เพื่อการแปลผลครับ ไม่ทราบว่า เจ้าของblog พอมี code หรือป่าวครับ ถ้ามี ทำให้ดูหน่อยได้ไหมครับ ขอบคุณครับ

    1. เราไม่ plot tree ของ random forest ครับ เพราะว่ามันมีเป็น 100+ trees เลย ปกติเราดูพวก variable importance, lift chart, precision-recall พวกนี้แทนครับ

  3. หากมีการใช้ 10-fold CV มาสร้าง decision tree เราจะสามารถกำหนด final decision tree ได้อย่างไรครับ เพราะโครงสร้างของ decision tree ใน แต่ละ fold ก็จะแตกต่างกัน

      1. ซึ่งคือ การปรับค่า cp ใช่ไหมค่ะ

      2. parameters มีหลายตัวเลยครับของ decision tree เช่น ความลึกของต้นไม้, จำนวน n ในแต่ละ node ฯลฯ ลองดูชื่อ parameters ต่างๆได้ที่นี่ครับ

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.