决策树是一种非常经典的分类器,它的作用原理有点类似于我们玩的猜谜游戏。比如猜一个动物:
问:这个动物是陆生动物吗?
答:是的。
问:这个动物有鳃吗?
答:没有。
这样的两个问题顺序就有些颠倒,因为一般来说陆生动物是没有鳃的(记得应该是这样的,如有错误欢迎指正)。所以玩这种游戏,提问的顺序很重要,争取每次都能够获得尽可能多的信息量。
AllElectronics顾客数据库标记类的训练元组
RID | age | income | student | credit_rating | Class: buys_computer |
1 | youth | high | no | fair | no |
2 | youth | high | no | excellent | no |
3 | middle_aged | high | no | fair | yes |
4 | senior | medium | no | fair | yes |
5 | senior | low | yes | fair | yes |
6 | senior | low | yes | excellent | no |
7 | middle_aged | low | yes | excellent | yes |
8 | youth | medium | no | fair | no |
9 | youth | low | yes | fair | yes |
10 | senior | medium | yes | fair | yes |
11 | youth | medium | yes | excellent | yes |
12 | middle_aged | medium | no | excellent | yes |
13 | middle_aged | high | yes | fair | yes |
14 | senior | medium | no | excellent | no |
以AllElectronics顾客数据库标记类的训练元组为例。我们想要以这些样本为训练集,训练我们的决策树模型,以此来挖掘出顾客是否会购买电脑的决策模式。
在决策树ID3算法中,按属性A划分后,计算期望信息(熵)的公式如下:
$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{|D|} \times Info(D_j)$$
计算信息增益的公式如下:
$$Gain(A) = Info(D) - Info_A(D)$$
按照公式,在要进行分类的类别变量中,有5个“no”和9个“yes”,因此期望信息为:
$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$
首先计算特征age的期望信息:
$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})=0.694$$
因此,如果按照age进行划分,则获得的信息增益为:
$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$
依次计算以income、student和credit_rating来分裂的信息增益,由此选择能够带来最大信息增益的变量,在当
前结点选择以该变量的取值进行分裂。递归地执行即可生成决策树。更加详细的内容可以参考:
https://en.wikipedia.org/wiki/Decision_tree
C#代码的实现如下:
using System;
using System.Collections.Generic;
using System.Linq;

namespace MachineLearning.DecisionTree
{
    /// <summary>
    /// ID3 decision-tree classifier. Training data is an attribute matrix whose
    /// LAST column holds the class label; at each node the attribute with the
    /// largest information gain is chosen as the split.
    /// </summary>
    /// <typeparam name="T">Attribute/label value type (must support equality).</typeparam>
    public class DecisionTreeID3<T> where T : IEquatable<T>
    {
        T[,] Data;                // training tuples, one row per sample
        string[] Names;           // attribute names, parallel to the columns of Data
        int Category;             // index of the class-label column (always the last column)
        T[] CategoryLabels;       // all distinct class labels, e.g. { "yes", "no" }
        DecisionTreeNode<T> Root; // sentinel root of the learned tree

        /// <param name="data">Sample matrix; the class label must be the last column.</param>
        /// <param name="names">Attribute names for each column of <paramref name="data"/>.</param>
        /// <param name="categoryLabels">Every possible class-label value.</param>
        public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
        {
            Data = data;
            Names = names;
            Category = data.GetLength(1) - 1; // 类别变量需要放在最后一列 (class label must be the last column)
            CategoryLabels = categoryLabels;
        }

        /// <summary>Builds the tree from the whole data set and prints it to the console.</summary>
        public void Learn()
        {
            int nRows = Data.GetLength(0);
            int nCols = Data.GetLength(1);
            int[] rows = new int[nRows];
            int[] cols = new int[nCols];
            for (int i = 0; i < nRows; i++) rows[i] = i;
            for (int i = 0; i < nCols; i++) cols[i] = i;
            // Label == -1 marks the sentinel root, which DisplayNode never prints.
            Root = new DecisionTreeNode<T>(-1, default(T));
            Learn(rows, cols, Root);
            DisplayNode(Root);
        }

        /// <summary>Pretty-prints the subtree rooted at <paramref name="Node"/>, indented by depth.</summary>
        public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
        {
            if (Node.Label != -1)
                Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
            foreach (var item in Node.Children)
                DisplayNode(item, depth + 1);
        }

        // Recursively grows the tree over the sample rows pnRows, considering only
        // the still-unused columns pnCols, attaching children to Root.
        private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root)
        {
            var categoryValues = GetAttribute(Data, Category, pnRows);
            var categoryCount = categoryValues.Distinct().Count();
            if (categoryCount == 1)
            {
                // Pure partition: every remaining sample shares one class -> leaf.
                var node = new DecisionTreeNode<T>(Category, categoryValues.First());
                Root.Children.Add(node);
            }
            else
            {
                if (pnRows.Length == 0) return;
                else if (pnCols.Length == 1)
                {
                    // Attributes exhausted: label the leaf by majority vote (多数票表决制).
                    // BUG FIX: the original used OrderBy(g => g.Count()).First(), which
                    // sorts ascending and therefore picked the LEAST common class;
                    // OrderByDescending yields the true majority.
                    var vote = categoryValues.GroupBy(i => i)
                                             .OrderByDescending(g => g.Count())
                                             .First();
                    var node = new DecisionTreeNode<T>(Category, vote.Key);
                    Root.Children.Add(node);
                }
                else
                {
                    // Split on the attribute with the largest information gain.
                    var maxCol = MaxEntropy(pnRows, pnCols);
                    var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
                    foreach (var attr in attributes)
                    {
                        int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                        int[] cols = pnCols.Where(i => i != maxCol).ToArray();
                        var node = new DecisionTreeNode<T>(maxCol, attr);
                        Root.Children.Add(node);
                        Learn(rows, cols, node); // 递归生成决策树 (recurse on the partition)
                    }
                }
            }
        }

        /// <summary>
        /// Info_A(D): the expected information (entropy) of the class label after
        /// partitioning the rows <paramref name="pnRows"/> by the attribute in
        /// column <paramref name="attrCol"/>.
        /// </summary>
        public double AttributeInfo(int attrCol, int[] pnRows)
        {
            var tuples = AttributeCount(attrCol, pnRows);
            var sum = (double)pnRows.Length;
            double entropy = 0.0;
            foreach (var tuple in tuples)
            {
                // Class-label histogram within the partition Data[*, attrCol] == tuple.Item1.
                int[] count = new int[CategoryLabels.Length];
                foreach (var irow in pnRows)
                    if (Data[irow, attrCol].Equals(tuple.Item1))
                    {
                        int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                        count[index]++; // assumes every observed label appears in CategoryLabels
                    }
                // Entropy of the class label inside this partition.
                double k = 0.0;
                for (int i = 0; i < count.Length; i++)
                {
                    double frequency = count[i] / (double)tuple.Item2;
                    k += -frequency * Log2(frequency);
                }
                // Weight by the partition's share of the rows: |Dj| / |D|.
                entropy += tuple.Item2 / sum * k;
            }
            return entropy;
        }

        /// <summary>
        /// Info(D): the entropy of the class label over the rows <paramref name="pnRows"/>.
        /// </summary>
        public double CategoryInfo(int[] pnRows)
        {
            var tuples = AttributeCount(Category, pnRows);
            var sum = (double)pnRows.Length;
            double entropy = 0.0;
            foreach (var tuple in tuples)
            {
                double frequency = tuple.Item2 / sum;
                entropy += -frequency * Log2(frequency);
            }
            return entropy;
        }

        // Lazily yields column col of data restricted to the rows in pnRows.
        private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
        {
            foreach (var irow in pnRows)
                yield return data[irow, col];
        }

        // Base-2 log with the entropy convention 0 * log2(0) == 0.
        private static double Log2(double x)
        {
            return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
        }

        /// <summary>
        /// Returns the column in <paramref name="pnCols"/> (excluding the class column)
        /// with the largest information gain Gain(A) = Info(D) - Info_A(D).
        /// </summary>
        public int MaxEntropy(int[] pnRows, int[] pnCols)
        {
            double cateEntropy = CategoryInfo(pnRows);
            int maxAttr = 0;
            double max = double.MinValue;
            foreach (var icol in pnCols)
                if (icol != Category)
                {
                    double gain = cateEntropy - AttributeInfo(icol, pnRows);
                    if (max < gain)
                    {
                        max = gain;
                        maxAttr = icol;
                    }
                }
            return maxAttr;
        }

        /// <summary>
        /// Counts occurrences of each distinct value of column <paramref name="col"/>
        /// among the rows <paramref name="pnRows"/>, as (value, count) tuples.
        /// </summary>
        public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
        {
            var tuples = from n in GetAttribute(Data, col, pnRows)
                         group n by n into i
                         select Tuple.Create(i.Key, i.Count());
            return tuples;
        }
    }
}
决策树结点的构造:
1 using System.Collections.Generic; 2 3 namespace MachineLearning.DecisionTree 4 { 5 public sealed class DecisionTreeNode<T> 6 { 7 public int Label { get; set; } 8 public T Value { get; set; } 9 public List<DecisionTreeNode<T>> Children { get; set; } 10 public DecisionTreeNode(int label, T value) 11 { 12 Label = label; 13 Value = value; 14 Children = new List<DecisionTreeNode<T>>(); 15 } 16 } 17 }
调用方法如下:
1 using System; 2 using System.Collections.Generic; 3 using System.Linq; 4 using System.Text; 5 using System.Threading.Tasks; 6 using MachineLearning.DecisionTree; 7 namespace MachineLearning 8 { 9 class Program 10 { 11 static void Main(string[] args) 12 { 13 var da = new string[,] 14 { 15 {"youth","high","no","fair","no"}, 16 {"youth","high","no","excellent","no"}, 17 {"middle_aged","high","no","fair","yes"}, 18 {"senior","medium","no","fair","yes"}, 19 {"senior","low","yes","fair","yes"}, 20 {"senior","low","yes","excellent","no"}, 21 {"middle_aged","low","yes","excellent","yes"}, 22 {"youth","medium","no","fair","no"}, 23 {"youth","low","yes","fair","yes"}, 24 {"senior","medium","yes","fair","yes"}, 25 {"youth","medium","yes","excellent","yes"}, 26 {"middle_aged","medium","no","excellent","yes"}, 27 {"middle_aged","high","yes","fair","yes"}, 28 {"senior","medium","no","excellent","no"} 29 }; 30 var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" }; 31 var tree = new DecisionTreeID3<string>(da, names, new string[] { "yes", "no" }); 32 tree.Learn(); 33 Console.ReadKey(); 34 } 35 } 36 }
运行结果:
注:作者本人也在学习中,能力有限,如有错漏还请不吝指正。转载请注明作者。
所有评论(0)