Machine Learning R

สร้างโมเดล Tree Based ง่ายๆด้วย R

tutorial สอนสร้างและจูนโมเดล tree based ใน R พร้อมตัวอย่างโค้ด ใช้งานได้จริงสำหรับ decision tree และ random forest

วันนี้เราจะมาสอนสร้าง tree-based models ง่ายๆใน R

เรา assume ว่าเพื่อนๆรู้เรื่อง machine learning นิดหน่อย เช่น ทำไมต้อง train/ test split และทำไมต้องจูนค่า hyperparameter เป็นต้น ถ้าใครยังไม่รู้ว่า ML คืออะไร? ลองอ่านบทความแนะนำของ dataquest ได้ในลิ้งด้านล่าง

tutorial วันนี้แบ่งเป็น 5 พาร์ท ใช้เวลาอ่านและทำตามประมาณ 20 นาที

  1. Getting to Know R – Optional สำหรับเพื่อนๆที่ยังไม่เคยเขียน R มาก่อนเลย
  2. Prepare Data
  3. Decision Tree
  4. Random Forest
  5. Summary

Ready to fly? เปิด RStudio ขึ้นมาแล้วลอง copy โค้ดด้านล่างไปรันใน console ได้เลย 😛


Getting to Know R

5 นาที – อธิบายการเขียน R เบื้องต้น ถ้าใครเขียน R เป็นบ้างแล้ว สามารถ skip ไปที่หัวข้อ Prepare Data ได้เลย

  • line 3-6 ลองสร้าง vector ง่ายๆด้วยฟังชั่น c()
  • line 9-12 เราใช้ฟังชั่น data.frame() เพื่อสร้าง dataframe ซึ่งเป็นหัวใจสำคัญของการทำ data analysis และ machine learning ใน R
  • line 15-21 คือฟังชั่นที่เราใช้บ่อยๆกับ dataframe เช่น str, head, tail และ summary

จบพาร์ทแรก ตอนนี้เพื่อนๆสามารถเขียนโค้ด R ง่ายๆเพื่อสร้าง vector และ dataframe รวมถึงการใช้งานฟังชั่นเบื้องต้น session ต่อไป เราจะสอนโหลด dataset สำหรับ tutorial นี้


Prepare Data

5 นาที – โหลด Breast Cancer dataset เข้าสู่ RStudio จัดการกับ missing values และแบ่งข้อมูลเป็น 70% train และ 30% test

Breast Cancer เป็นข้อมูลสำหรับปัญหา binary classification มีทั้งหมด 11 columns 699 observations โดยตัวแปรที่เราต้องการ predict คือ Class {benign, malignant}

  • line 2-3 ดาวน์โหลดและติดตั้ง package mlbench ใน R ซึ่งเป็น package ที่รวบรวม datasets สำหรับงาน machine learning จากเว็บ UCI ML repository
  • line 6 ใช้ฟังชั่น data() เพื่อโหลด BreastCancer dataframe เข้าไปใน R
  • line 21-24 เป็น standard code ที่ใช้ split data เป็นสองส่วน – 70% train และ 30% test ถ้าใครอยากลองเปลี่ยน ratio สามารถเปลี่ยนได้ที่เลข 0.7 (เป็น 0.8 หรือ 0.6 ก็ได้)

Good to know – ในชีวิตจริงเรานิยม split data เป็นสามส่วน {train, validate, test} หรือใช้ k-fold cross validation ในการ train model


Decision Tree

5 นาที – train decision tree และการจูนค่า complexity parameter เพื่อให้ได้ accuracy ที่สูงขึ้นและลดการ overfit ของโมเดล

หน้าตาของ decision tree เวลาเรา visualize จะออกมาแบบรูปด้านล่าง

โดยตัวแปรแรกที่ถูกใช้ split data ที่ root node (ด้านบนสุดของ tree) คือ Cell.size = 1,2 ถ้าตอบ yes จะวิ่งไปทางซ้าย ตอบ no จะไปทางขวา จนลงมาถึง terminal node (ด้านล่างสุดของ tree) ที่กระบวนการ split data หยุดตรงนี้ มาลองอ่าน diagram กัน

  • ถ้า case นี้มี Cell.size = 1,2 โมเดลจะ predict ว่า benign (เซลล์ดี)
  • ถ้า case นี้มี Cell.size > 2 และ Cell.shape > 2 โมเดลจะ predict ว่า malignant (เซลล์ร้าย)
decision tree in R
visualize decision tree using rpart.plot() function

สำหรับ package หลักที่เราใช้ build และ visualize tree ใน R คือ rpart และ rpart.plot

  • line 8 คือการเขียนโค้ดเพื่อสร้าง decision tree model ด้วยฟังชั่น rpart()
    • Class ~ . คือ formula ใน R อ่านว่า “ตัวแปร Class เป็นฟังชั่นของตัวแปร x ทั้งหมดใน train dataset”
  • line 11 เราเรียกดูค่า complexity parameter (เรียกสั้นๆว่า cp) ที่ทำให้ค่า xerror ต่ำที่สุด
    • cp สูงเกินไป – tree ของเราจะ underfit ทำให้ได้ accuracy ต่ำ ↔ xerror สูง
    • cp ต่ำเกินไป – tree ของเราจะ overfit ไม่สามารถนำโมเดลไปใช้กับ new data ได้
    • เราต้องเลือกค่า cp ที่เหมาะสมเพื่อให้ได้ optimal final model
  • line 19 เราใช้ฟังชั่น prune() เพื่อสร้าง decision tree ด้วยค่า cp ที่ดีที่สุดของเรา ในตัวอย่างเราเลือกใช้ cp = 0.01 สำหรับ final model
  • line 22 visualize final model ของเราด้วยฟังชั่น rpart.plot()

พอเรา prune จนได้ final model แล้ว เราจะใช้มันทำนาย test dataset ด้วยฟังชั่น predict() และสร้าง confusion matrix เพื่อวัด accuracy ของโมเดลด้วยฟังชั่น table() ใน line 25-26

เราสามารถคำนวณ accuracy จาก confusion matrix ด้านบน ด้วยการหาผลรวมเส้นทแยงมุม หารด้วยจำนวน testing cases ทั้งหมดที่เราทดสอบ (124 + 70) / (124 + 7 + 4 + 70) = 0.9463415

decision tree ที่เราสร้างขึ้นมาใช้ cp = 0.10 และได้ test accuracy = 94.63%


Random Forest

5 นาที – train random forest ด้วยฟังชั่น randomForest()

cats

concept ของ random forest คือการสร้าง decision tree หลายๆต้น โดยค่า default ของ number of trees – ntree ในฟังชั่น randomForest() จะอยู่ที่ 500 ต้น แล้วค่อยเอาผล predictions ของทั้ง 500 ต้นมาโหวตกันว่า new case นั้นจะเป็น benign หรือ malignant

เช่น จากทั้งหมด 500 ต้น – 400 ต้นทำนายว่า benign และอีก 100 ต้นทำนายว่า malignant เราจะยึดผลโหวตส่วนใหญ่ final prediction เท่ากับ benign

random forest เป็นโมเดลประเภท Ensemble Learning ที่เกิดจากการสร้างและรวมหลายๆโมเดลเข้าด้วยกันเพื่อให้ได้ model performance ที่ดีขึ้น and it works !!

วิธีการคำนวณ accuracy จาก confusion matrix จะเหมือนกับของ decision tree เลย (127 + 73) / (127 + 4 + 1 + 73) = 0.9756098

random forest ที่เราสร้างขึ้นมาใช้ ntree = 500 และได้ test accuracy = 97.56% สูงกว่า decision tree ประมาณ 3% แต่แลกมากับเวลาในการ train ที่นานขึ้น (จะเห็นความแตกต่างเรื่องเวลาชัดมาก ถ้า dataset เราใหญ่)


Summary

ใน tutorial นี้ เราเรียน concept ของ tree-based models เบื้องต้นใน R ซึ่งสองตัวที่เราใช้กันเยอะมากในหลายๆ applications คือ decision tree และ random forest

  • decision tree – train เร็ว accuracy ปานกลาง ค่อนข้าง overfit ถ้าเราไม่ prune มันก่อน แต่ข้อดีคืออธิบายง่ายมาก
  • random forest – train ช้า accuracy สูง แต่อธิบายยากเพราะกระบวนการสร้าง trees 500 ต้นเกิดขึ้นแบบ random อย่างที่ชื่อ algorithm implies

decision tree ปกติจะเป็นโมเดลแรกๆที่เราลองสร้าง (baseline model) แล้วค่อยพยายาม improve performance ด้วยการจูน hyperparameter หรือเปลี่ยนไปใช้ tree-based แบบอื่นๆอย่าง random forest (bagging algorithm) หรือ xgboost (boosting algorithm)

อยากเรียน algorithm อะไรอีก? คอมเม้นบอกเราใต้บล๊อกวันนี้ได้เลย 😎

5 comments

  1. rpart.plot

    เกิดข้อผิดพลาดขึ้น แจ้งว่าในไลบรารี ไม่มีแพ็คเกจที่เรียกว่า “rpart.plot” ครับ
    ผมต้องทำอย่างไร? หรือว่าไม่จำเป็นต้องใช้ครับ ตามตัวอย่าง

  2. ในขั้นตอนสุดท้าย อยากให้ plot tree ของ random forest เพิ่มด้วยอ่ะครับ เพื่อการแปลผลครับ ไม่ทราบว่า เจ้าของblog พอมี code หรือป่าวครับ ถ้ามี ทำให้ดูหน่อยได้ไหมครับ ขอบคุณครับ

    1. เราไม่ plot tree ของ random forest ครับ เพราะว่ามันมีเป็น 100+ trees เลย ปกติเราดูพวก variable importance, lift chart, precision-recall พวกนี้แทนครับ

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.