diff --git a/README.md b/README.md index 6050d3a..45b46f4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -# Assorted TensorFlow tutorials +# Assorted tutorials -Disclaimer: This is a personal repo and not an official Google product. - -License Apache 2.0 +General disclaimer, this is my personal repo and not an official Google product. If you'd like to use this code, say, to build a mission critical component of your giant space laser, you should know there's no warranty, etc. diff --git a/decision_tree.ipynb b/decision_tree.ipynb new file mode 100644 index 0000000..07da773 --- /dev/null +++ b/decision_tree.ipynb @@ -0,0 +1,1040 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Code to accompany Machine Learning Recipes #8. We'll write a Decision Tree Classifier, in pure Python. Below each of the methods, I've written a little demo to help explain what it does." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# For Python 2 / 3 compatability\n", + "from __future__ import print_function" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Toy dataset.\n", + "# Format: each row is an example.\n", + "# The last column is the label.\n", + "# The first two columns are features.\n", + "# Feel free to play with it by adding more features & examples.\n", + "# Interesting note: I've written this so the 2nd and 5th examples\n", + "# have the same features, but different labels - so we can see how the\n", + "# tree handles this case.\n", + "training_data = [\n", + " ['Green', 3, 'Apple'],\n", + " ['Yellow', 3, 'Apple'],\n", + " ['Red', 1, 'Grape'],\n", + " ['Red', 1, 'Grape'],\n", + " ['Yellow', 3, 'Lemon'],\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Column labels.\n", + "# These are used only to print the tree.\n", + "header = [\"color\", \"diameter\", \"label\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def unique_vals(rows, col):\n", + " \"\"\"Find the unique values for a column in a dataset.\"\"\"\n", + " return set([row[col] for row in rows])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Green', 'Red', 'Yellow'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "unique_vals(training_data, 0)\n", + "# unique_vals(training_data, 1)\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def class_counts(rows):\n", + " \"\"\"Counts the number of each type of example in a dataset.\"\"\"\n", + " counts = {} # a dictionary of label -> count.\n", + " for row in rows:\n", + " # in our dataset format, the label is always the last column\n", + " label = row[-1]\n", + " if label not in counts:\n", + " counts[label] = 0\n", + " counts[label] += 1\n", + " return counts" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Apple': 2, 'Grape': 2, 'Lemon': 1}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "class_counts(training_data)\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def is_numeric(value):\n", + " \"\"\"Test if a value is numeric.\"\"\"\n", + " return isinstance(value, int) or isinstance(value, float)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "is_numeric(7)\n", + "# is_numeric(\"Red\")\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "class Question:\n", + " \"\"\"A Question is used to partition a dataset.\n", + "\n", + " This class just records a 'column number' (e.g., 0 for Color) and a\n", + " 'column value' (e.g., Green). The 'match' method is used to compare\n", + " the feature value in an example to the feature value stored in the\n", + " question. See the demo below.\n", + " \"\"\"\n", + "\n", + " def __init__(self, column, value):\n", + " self.column = column\n", + " self.value = value\n", + "\n", + " def match(self, example):\n", + " # Compare the feature value in an example to the\n", + " # feature value in this question.\n", + " val = example[self.column]\n", + " if is_numeric(val):\n", + " return val >= self.value\n", + " else:\n", + " return val == self.value\n", + "\n", + " def __repr__(self):\n", + " # This is just a helper method to print\n", + " # the question in a readable format.\n", + " condition = \"==\"\n", + " if is_numeric(self.value):\n", + " condition = \">=\"\n", + " return \"Is %s %s %s?\" % (\n", + " header[self.column], condition, str(self.value))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Is diameter >= 3?" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# Let's write a question for a numeric attribute\n", + "Question(1, 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Is color == Green?" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# How about one for a categorical attribute\n", + "q = Question(0, 'Green')\n", + "q" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Let's pick an example from the training set...\n", + "example = training_data[0]\n", + "# ... and see if it matches the question\n", + "q.match(example) # this will be true, since the first example is Green.\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def partition(rows, question):\n", + " \"\"\"Partitions a dataset.\n", + "\n", + " For each row in the dataset, check if it matches the question. If\n", + " so, add it to 'true rows', otherwise, add it to 'false rows'.\n", + " \"\"\"\n", + " true_rows, false_rows = [], []\n", + " for row in rows:\n", + " if question.match(row):\n", + " true_rows.append(row)\n", + " else:\n", + " false_rows.append(row)\n", + " return true_rows, false_rows" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# Let's partition the training data based on whether rows are Red.\n", + "true_rows, false_rows = partition(training_data, Question(0, 'Red'))\n", + "# This will contain all the 'Red' rows.\n", + "true_rows" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This will contain everything else.\n", + "false_rows\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def gini(rows):\n", + " \"\"\"Calculate the Gini Impurity for a list of rows.\n", + "\n", + " There are a few different ways to do this, I thought this one was\n", + " the most concise. See:\n", + " https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity\n", + " \"\"\"\n", + " counts = class_counts(rows)\n", + " impurity = 1\n", + " for lbl in counts:\n", + " prob_of_lbl = counts[lbl] / float(len(rows))\n", + " impurity -= prob_of_lbl**2\n", + " return impurity" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.0" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# Let's look at some example to understand how Gini Impurity works.\n", + "#\n", + "# First, we'll look at a dataset with no mixing.\n", + "no_mixing = [['Apple'],\n", + " ['Apple']]\n", + "# this will return 0\n", + "gini(no_mixing)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now, we'll look at dataset with a 50:50 apples:oranges ratio\n", + "some_mixing = [['Apple'],\n", + " ['Orange']]\n", + "# this will return 0.5 - meaning, there's a 50% chance of misclassifying\n", + "# a random example we draw from the dataset.\n", + "gini(some_mixing)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7999999999999998" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now, we'll look at a dataset with many different labels\n", + "lots_of_mixing = [['Apple'],\n", + " ['Orange'],\n", + " ['Grape'],\n", + " ['Grapefruit'],\n", + " ['Blueberry']]\n", + "# This will return 0.8\n", + "gini(lots_of_mixing)\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def info_gain(left, right, current_uncertainty):\n", + " \"\"\"Information Gain.\n", + "\n", + " The uncertainty of the starting node, minus the weighted impurity of\n", + " two child nodes.\n", + " \"\"\"\n", + " p = float(len(left)) / (len(left) + len(right))\n", + " return current_uncertainty - p * gini(left) - (1 - p) * gini(right)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6399999999999999" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# Calculate the uncertainy of our training data.\n", + "current_uncertainty = gini(training_data)\n", + "current_uncertainty" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.1399999999999999" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# How much information do we gain by partioning on 'Green'?\n", + "true_rows, false_rows = partition(training_data, Question(0, 'Green'))\n", + "info_gain(true_rows, false_rows, current_uncertainty)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.37333333333333324" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# What about if we partioned on 'Red' instead?\n", + "true_rows, false_rows = partition(training_data, Question(0,'Red'))\n", + "info_gain(true_rows, false_rows, current_uncertainty)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# It looks like we learned more using 'Red' (0.37), than 'Green' (0.14).\n", + "# Why? Look at the different splits that result, and see which one\n", + "# looks more 'unmixed' to you.\n", + "true_rows, false_rows = partition(training_data, Question(0,'Red'))\n", + "\n", + "# Here, the true_rows contain only 'Grapes'.\n", + "true_rows" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# And the false rows contain two types of fruit. Not too bad.\n", + "false_rows" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Green', 3, 'Apple']]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# On the other hand, partitioning by Green doesn't help so much.\n", + "true_rows, false_rows = partition(training_data, Question(0,'Green'))\n", + "\n", + "# We've isolated one apple in the true rows.\n", + "true_rows" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Yellow', 3, 'Apple'],\n", + " ['Red', 1, 'Grape'],\n", + " ['Red', 1, 'Grape'],\n", + " ['Yellow', 3, 'Lemon']]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# But, the false-rows are badly mixed up.\n", + "false_rows\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def find_best_split(rows):\n", + " \"\"\"Find the best question to ask by iterating over every feature / value\n", + " and calculating the information gain.\"\"\"\n", + " best_gain = 0 # keep track of the best information gain\n", + " best_question = None # keep train of the feature / value that produced it\n", + " current_uncertainty = gini(rows)\n", + " n_features = len(rows[0]) - 1 # number of columns\n", + "\n", + " for col in range(n_features): # for each feature\n", + "\n", + " values = set([row[col] for row in rows]) # unique values in the column\n", + "\n", + " for val in values: # for each value\n", + "\n", + " question = Question(col, val)\n", + "\n", + " # try splitting the dataset\n", + " true_rows, false_rows = partition(rows, question)\n", + "\n", + " # Skip this split if it doesn't divide the\n", + " # dataset.\n", + " if len(true_rows) == 0 or len(false_rows) == 0:\n", + " continue\n", + "\n", + " # Calculate the information gain from this split\n", + " gain = info_gain(true_rows, false_rows, current_uncertainty)\n", + "\n", + " # You actually can use '>' instead of '>=' here\n", + " # but I wanted the tree to look a certain way for our\n", + " # toy dataset.\n", + " if gain >= best_gain:\n", + " best_gain, best_question = gain, question\n", + "\n", + " return best_gain, best_question" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Is diameter >= 3?" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# Find the best question to ask first for our toy dataset.\n", + "best_gain, best_question = find_best_split(training_data)\n", + "best_question\n", + "# FYI: is color == Red is just as good. See the note in the code above\n", + "# where I used '>='.\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "class Leaf:\n", + " \"\"\"A Leaf node classifies data.\n", + "\n", + " This holds a dictionary of class (e.g., \"Apple\") -> number of times\n", + " it appears in the rows from the training data that reach this leaf.\n", + " \"\"\"\n", + "\n", + " def __init__(self, rows):\n", + " self.predictions = class_counts(rows)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "class Decision_Node:\n", + " \"\"\"A Decision Node asks a question.\n", + "\n", + " This holds a reference to the question, and to the two child nodes.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " question,\n", + " true_branch,\n", + " false_branch):\n", + " self.question = question\n", + " self.true_branch = true_branch\n", + " self.false_branch = false_branch" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def build_tree(rows):\n", + " \"\"\"Builds the tree.\n", + "\n", + " Rules of recursion: 1) Believe that it works. 2) Start by checking\n", + " for the base case (no further information gain). 3) Prepare for\n", + " giant stack traces.\n", + " \"\"\"\n", + "\n", + " # Try partitioing the dataset on each of the unique attribute,\n", + " # calculate the information gain,\n", + " # and return the question that produces the highest gain.\n", + " gain, question = find_best_split(rows)\n", + "\n", + " # Base case: no further info gain\n", + " # Since we can ask no further questions,\n", + " # we'll return a leaf.\n", + " if gain == 0:\n", + " return Leaf(rows)\n", + "\n", + " # If we reach here, we have found a useful feature / value\n", + " # to partition on.\n", + " true_rows, false_rows = partition(rows, question)\n", + "\n", + " # Recursively build the true branch.\n", + " true_branch = build_tree(true_rows)\n", + "\n", + " # Recursively build the false branch.\n", + " false_branch = build_tree(false_rows)\n", + "\n", + " # Return a Question node.\n", + " # This records the best feature / value to ask at this point,\n", + " # as well as the branches to follow\n", + " # dependingo on the answer.\n", + " return Decision_Node(question, true_branch, false_branch)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def print_tree(node, spacing=\"\"):\n", + " \"\"\"World's most elegant tree printing function.\"\"\"\n", + "\n", + " # Base case: we've reached a leaf\n", + " if isinstance(node, Leaf):\n", + " print (spacing + \"Predict\", node.predictions)\n", + " return\n", + "\n", + " # Print the question at this node\n", + " print (spacing + str(node.question))\n", + "\n", + " # Call this function recursively on the true branch\n", + " print (spacing + '--> True:')\n", + " print_tree(node.true_branch, spacing + \" \")\n", + "\n", + " # Call this function recursively on the false branch\n", + " print (spacing + '--> False:')\n", + " print_tree(node.false_branch, spacing + \" \")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "my_tree = build_tree(training_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is diameter >= 3?\n", + "--> True:\n", + " Is color == Yellow?\n", + " --> True:\n", + " Predict {'Lemon': 1, 'Apple': 1}\n", + " --> False:\n", + " Predict {'Apple': 1}\n", + "--> False:\n", + " Predict {'Grape': 2}\n" + ] + } + ], + "source": [ + "print_tree(my_tree)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def classify(row, node):\n", + " \"\"\"See the 'rules of recursion' above.\"\"\"\n", + "\n", + " # Base case: we've reached a leaf\n", + " if isinstance(node, Leaf):\n", + " return node.predictions\n", + "\n", + " # Decide whether to follow the true-branch or the false-branch.\n", + " # Compare the feature / value stored in the node,\n", + " # to the example we're considering.\n", + " if node.question.match(row):\n", + " return classify(row, node.true_branch)\n", + " else:\n", + " return classify(row, node.false_branch)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Apple': 1}" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# The tree predicts the 1st row of our\n", + "# training data is an apple with confidence 1.\n", + "classify(training_data[0], my_tree)\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def print_leaf(counts):\n", + " \"\"\"A nicer way to print the predictions at a leaf.\"\"\"\n", + " total = sum(counts.values()) * 1.0\n", + " probs = {}\n", + " for lbl in counts.keys():\n", + " probs[lbl] = str(int(counts[lbl] / total * 100)) + \"%\"\n", + " return probs" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Apple': '100%'}" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# Printing that a bit nicer\n", + "print_leaf(classify(training_data[0], my_tree))\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Apple': '50%', 'Lemon': '50%'}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#######\n", + "# Demo:\n", + "# On the second example, the confidence is lower\n", + "print_leaf(classify(training_data[1], my_tree))\n", + "#######" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Evaluate\n", + "testing_data = [\n", + " ['Green', 3, 'Apple'],\n", + " ['Yellow', 4, 'Apple'],\n", + " ['Red', 2, 'Grape'],\n", + " ['Red', 1, 'Grape'],\n", + " ['Yellow', 3, 'Lemon'],\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Actual: Apple. Predicted: {'Apple': '100%'}\n", + "Actual: Apple. Predicted: {'Lemon': '50%', 'Apple': '50%'}\n", + "Actual: Grape. Predicted: {'Grape': '100%'}\n", + "Actual: Grape. Predicted: {'Grape': '100%'}\n", + "Actual: Lemon. Predicted: {'Lemon': '50%', 'Apple': '50%'}\n" + ] + } + ], + "source": [ + "for row in testing_data:\n", + " print (\"Actual: %s. Predicted: %s\" %\n", + " (row[-1], print_leaf(classify(row, my_tree))))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/decision_tree.py b/decision_tree.py new file mode 100644 index 0000000..495ac0a --- /dev/null +++ b/decision_tree.py @@ -0,0 +1,416 @@ +"""Code to accompany Machine Learning Recipes #8. + +We'll write a Decision Tree Classifier, in pure Python. +""" + +# For Python 2 / 3 compatability +from __future__ import print_function + +# Toy dataset. +# Format: each row is an example. +# The last column is the label. +# The first two columns are features. +# Feel free to play with it by adding more features & examples. +# Interesting note: I've written this so the 2nd and 5th examples +# have the same features, but different labels - so we can see how the +# tree handles this case. +training_data = [ + ['Green', 3, 'Apple'], + ['Yellow', 3, 'Apple'], + ['Red', 1, 'Grape'], + ['Red', 1, 'Grape'], + ['Yellow', 3, 'Lemon'], +] + +# Column labels. +# These are used only to print the tree. +header = ["color", "diameter", "label"] + + +def unique_vals(rows, col): + """Find the unique values for a column in a dataset.""" + return set([row[col] for row in rows]) + +####### +# Demo: +# unique_vals(training_data, 0) +# unique_vals(training_data, 1) +####### + + +def class_counts(rows): + """Counts the number of each type of example in a dataset.""" + counts = {} # a dictionary of label -> count. + for row in rows: + # in our dataset format, the label is always the last column + label = row[-1] + if label not in counts: + counts[label] = 0 + counts[label] += 1 + return counts + +####### +# Demo: +# class_counts(training_data) +####### + + +def is_numeric(value): + """Test if a value is numeric.""" + return isinstance(value, int) or isinstance(value, float) + +####### +# Demo: +# is_numeric(7) +# is_numeric("Red") +####### + + +class Question: + """A Question is used to partition a dataset. + + This class just records a 'column number' (e.g., 0 for Color) and a + 'column value' (e.g., Green). The 'match' method is used to compare + the feature value in an example to the feature value stored in the + question. See the demo below. + """ + + def __init__(self, column, value): + self.column = column + self.value = value + + def match(self, example): + # Compare the feature value in an example to the + # feature value in this question. + val = example[self.column] + if is_numeric(val): + return val >= self.value + else: + return val == self.value + + def __repr__(self): + # This is just a helper method to print + # the question in a readable format. + condition = "==" + if is_numeric(self.value): + condition = ">=" + return "Is %s %s %s?" % ( + header[self.column], condition, str(self.value)) + +####### +# Demo: +# Let's write a question for a numeric attribute +# Question(1, 3) +# How about one for a categorical attribute +# q = Question(0, 'Green') +# Let's pick an example from the training set... +# example = training_data[0] +# ... and see if it matches the question +# q.match(example) +####### + + +def partition(rows, question): + """Partitions a dataset. + + For each row in the dataset, check if it matches the question. If + so, add it to 'true rows', otherwise, add it to 'false rows'. + """ + true_rows, false_rows = [], [] + for row in rows: + if question.match(row): + true_rows.append(row) + else: + false_rows.append(row) + return true_rows, false_rows + + +####### +# Demo: +# Let's partition the training data based on whether rows are Red. +# true_rows, false_rows = partition(training_data, Question(0, 'Red')) +# This will contain all the 'Red' rows. +# true_rows +# This will contain everything else. +# false_rows +####### + +def gini(rows): + """Calculate the Gini Impurity for a list of rows. + + There are a few different ways to do this, I thought this one was + the most concise. See: + https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity + """ + counts = class_counts(rows) + impurity = 1 + for lbl in counts: + prob_of_lbl = counts[lbl] / float(len(rows)) + impurity -= prob_of_lbl**2 + return impurity + + +####### +# Demo: +# Let's look at some example to understand how Gini Impurity works. +# +# First, we'll look at a dataset with no mixing. +# no_mixing = [['Apple'], +# ['Apple']] +# this will return 0 +# gini(no_mixing) +# +# Now, we'll look at dataset with a 50:50 apples:oranges ratio +# some_mixing = [['Apple'], +# ['Orange']] +# this will return 0.5 - meaning, there's a 50% chance of misclassifying +# a random example we draw from the dataset. +# gini(some_mixing) +# +# Now, we'll look at a dataset with many different labels +# lots_of_mixing = [['Apple'], +# ['Orange'], +# ['Grape'], +# ['Grapefruit'], +# ['Blueberry']] +# This will return 0.8 +# gini(lots_of_mixing) +####### + +def info_gain(left, right, current_uncertainty): + """Information Gain. + + The uncertainty of the starting node, minus the weighted impurity of + two child nodes. + """ + p = float(len(left)) / (len(left) + len(right)) + return current_uncertainty - p * gini(left) - (1 - p) * gini(right) + +####### +# Demo: +# Calculate the uncertainy of our training data. +# current_uncertainty = gini(training_data) +# +# How much information do we gain by partioning on 'Green'? +# true_rows, false_rows = partition(training_data, Question(0, 'Green')) +# info_gain(true_rows, false_rows, current_uncertainty) +# +# What about if we partioned on 'Red' instead? +# true_rows, false_rows = partition(training_data, Question(0,'Red')) +# info_gain(true_rows, false_rows, current_uncertainty) +# +# It looks like we learned more using 'Red' (0.37), than 'Green' (0.14). +# Why? Look at the different splits that result, and see which one +# looks more 'unmixed' to you. +# true_rows, false_rows = partition(training_data, Question(0,'Red')) +# +# Here, the true_rows contain only 'Grapes'. +# true_rows +# +# And the false rows contain two types of fruit. Not too bad. +# false_rows +# +# On the other hand, partitioning by Green doesn't help so much. +# true_rows, false_rows = partition(training_data, Question(0,'Green')) +# +# We've isolated one apple in the true rows. +# true_rows +# +# But, the false-rows are badly mixed up. +# false_rows +####### + + +def find_best_split(rows): + """Find the best question to ask by iterating over every feature / value + and calculating the information gain.""" + best_gain = 0 # keep track of the best information gain + best_question = None # keep train of the feature / value that produced it + current_uncertainty = gini(rows) + n_features = len(rows[0]) - 1 # number of columns + + for col in range(n_features): # for each feature + + values = set([row[col] for row in rows]) # unique values in the column + + for val in values: # for each value + + question = Question(col, val) + + # try splitting the dataset + true_rows, false_rows = partition(rows, question) + + # Skip this split if it doesn't divide the + # dataset. + if len(true_rows) == 0 or len(false_rows) == 0: + continue + + # Calculate the information gain from this split + gain = info_gain(true_rows, false_rows, current_uncertainty) + + # You actually can use '>' instead of '>=' here + # but I wanted the tree to look a certain way for our + # toy dataset. + if gain >= best_gain: + best_gain, best_question = gain, question + + return best_gain, best_question + +####### +# Demo: +# Find the best question to ask first for our toy dataset. +# best_gain, best_question = find_best_split(training_data) +# FYI: is color == Red is just as good. See the note in the code above +# where I used '>='. +####### + +class Leaf: + """A Leaf node classifies data. + + This holds a dictionary of class (e.g., "Apple") -> number of times + it appears in the rows from the training data that reach this leaf. + """ + + def __init__(self, rows): + self.predictions = class_counts(rows) + + +class Decision_Node: + """A Decision Node asks a question. + + This holds a reference to the question, and to the two child nodes. + """ + + def __init__(self, + question, + true_branch, + false_branch): + self.question = question + self.true_branch = true_branch + self.false_branch = false_branch + + +def build_tree(rows): + """Builds the tree. + + Rules of recursion: 1) Believe that it works. 2) Start by checking + for the base case (no further information gain). 3) Prepare for + giant stack traces. + """ + + # Try partitioing the dataset on each of the unique attribute, + # calculate the information gain, + # and return the question that produces the highest gain. + gain, question = find_best_split(rows) + + # Base case: no further info gain + # Since we can ask no further questions, + # we'll return a leaf. + if gain == 0: + return Leaf(rows) + + # If we reach here, we have found a useful feature / value + # to partition on. + true_rows, false_rows = partition(rows, question) + + # Recursively build the true branch. + true_branch = build_tree(true_rows) + + # Recursively build the false branch. + false_branch = build_tree(false_rows) + + # Return a Question node. + # This records the best feature / value to ask at this point, + # as well as the branches to follow + # dependingo on the answer. + return Decision_Node(question, true_branch, false_branch) + + +def print_tree(node, spacing=""): + """World's most elegant tree printing function.""" + + # Base case: we've reached a leaf + if isinstance(node, Leaf): + print (spacing + "Predict", node.predictions) + return + + # Print the question at this node + print (spacing + str(node.question)) + + # Call this function recursively on the true branch + print (spacing + '--> True:') + print_tree(node.true_branch, spacing + " ") + + # Call this function recursively on the false branch + print (spacing + '--> False:') + print_tree(node.false_branch, spacing + " ") + + +def classify(row, node): + """See the 'rules of recursion' above.""" + + # Base case: we've reached a leaf + if isinstance(node, Leaf): + return node.predictions + + # Decide whether to follow the true-branch or the false-branch. + # Compare the feature / value stored in the node, + # to the example we're considering. + if node.question.match(row): + return classify(row, node.true_branch) + else: + return classify(row, node.false_branch) + + +####### +# Demo: +# The tree predicts the 1st row of our +# training data is an apple with confidence 1. +# my_tree = build_tree(training_data) +# classify(training_data[0], my_tree) +####### + +def print_leaf(counts): + """A nicer way to print the predictions at a leaf.""" + total = sum(counts.values()) * 1.0 + probs = {} + for lbl in counts.keys(): + probs[lbl] = str(int(counts[lbl] / total * 100)) + "%" + return probs + + +####### +# Demo: +# Printing that a bit nicer +# print_leaf(classify(training_data[0], my_tree)) +####### + +####### +# Demo: +# On the second example, the confidence is lower +# print_leaf(classify(training_data[1], my_tree)) +####### + +if __name__ == '__main__': + + my_tree = build_tree(training_data) + + print_tree(my_tree) + + # Evaluate + testing_data = [ + ['Green', 3, 'Apple'], + ['Yellow', 4, 'Apple'], + ['Red', 2, 'Grape'], + ['Red', 1, 'Grape'], + ['Yellow', 3, 'Lemon'], + ] + + for row in testing_data: + print ("Actual: %s. Predicted: %s" % + (row[-1], print_leaf(classify(row, my_tree)))) + +# Next steps +# - add support for missing (or unseen) attributes +# - prune the tree to prevent overfitting +# - add support for regression