决策树是一种非常经典的分类器,它的作用原理有点类似于我们玩的猜谜游戏。比如猜一个动物:

问:这个动物是陆生动物吗?

答:是的。

问:这个动物有鳃吗?

答:没有。

这样的两个问题顺序就有些颠倒,因为一般来说陆生动物是没有鳃的(记得应该是这样的,如有错误欢迎指正)。所以玩这种游戏,提问的顺序很重要,争取每次都能够获得尽可能多的信息量。

AllElectronics顾客数据库标记类的训练元组
RIDageincomestudentcredit_ratingClass: buys_computer
1youthhighnofairno
2youthhighnoexcellentno
3middle_agedhighnofairyes
4seniormediumnofairyes
5seniorlowyesfairyes
6seniorlowyesexcellentno
7middle_agedlowyesexcellentyes
8youthmediumnofairno
9youthlowyesfairyes
10seniormediumyesfairyes
11youthmediumyesexcellentyes
12middle_agedmediumnoexcellentyes
13middle_agedhighyesfairyes
14seniormediumnoexcellentno

AllElectronics顾客数据库标记类的训练元组为例。我们想要以这些样本为训练集,训练我们的决策树模型,以此来挖掘出顾客是否会购买电脑的决策模式。

在决策树ID3算法中,计算信息度的公式如下:

$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{D} \times Info(D_j)$$

计算信息增益的公式如下:

$$Gain(A) = Info(D) - Info_A(D)$$

按照公式,在要进行分类的类别变量中,有5个“no”9个“yes”,因此期望信息为:

$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$

首先计算特征age的期望信息:

$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})$$

因此,如果按照age进行划分,则获得的信息增益为:

$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$

依次计算以incomestudentcredit_rating来分裂的信息增益,由此选择能够带来最大信息增益的变量,在当

前结点选择以以该变量的取值进行分裂。递归地进行执行即可生成决策树。更加详细的内容可以参考:

https://en.wikipedia.org/wiki/Decision_tree

C#代码的实现如下:

  1 using System;
  2 using System.Collections.Generic;
  3 using System.Linq;
  4 namespace MachineLearning.DecisionTree
  5 {
  6     public class DecisionTreeID3<T> where T : IEquatable<T>
  7     {
  8         T[,] Data;
  9         string[] Names;
 10         int Category;
 11         T[] CategoryLabels;
 12         DecisionTreeNode<T> Root;
 13         public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
 14         {
 15             Data = data;
 16             Names = names;
 17             Category = data.GetLength(1) - 1;//类别变量需要放在最后一列
 18             CategoryLabels = categoryLabels;
 19         }
 20         public void Learn()
 21         {
 22             int nRows = Data.GetLength(0);
 23             int nCols = Data.GetLength(1);
 24             int[] rows = new int[nRows];
 25             int[] cols = new int[nCols];
 26             for (int i = 0; i < nRows; i++) rows[i] = i;
 27             for (int i = 0; i < nCols; i++) cols[i] = i;
 28             Root = new DecisionTreeNode<T>(-1, default(T));
 29             Learn(rows, cols, Root);
 30             DisplayNode(Root);
 31         }
 32         public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
 33         {
 34             if (Node.Label != -1)
 35                 Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
 36             foreach (var item in Node.Children)
 37                 DisplayNode(item, depth + 1);
 38         }
 39         private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root)
 40         {
 41             var categoryValues = GetAttribute(Data, Category, pnRows);
 42             var categoryCount = categoryValues.Distinct().Count();
 43             if (categoryCount == 1)
 44             {
 45                 var node = new DecisionTreeNode<T>(Category, categoryValues.First());
 46                 Root.Children.Add(node);
 47             }
 48             else
 49             {
 50                 if (pnRows.Length == 0) return;
 51                 else if (pnCols.Length == 1)
 52                 {
 53                     //投票~
 54                     //多数票表决制
 55                     var Vote = categoryValues.GroupBy(i => i).OrderBy(i => i.Count()).First();
 56                     var node = new DecisionTreeNode<T>(Category, Vote.First());
 57                     Root.Children.Add(node);
 58                 }
 59                 else
 60                 {
 61                     var maxCol = MaxEntropy(pnRows, pnCols);
 62                     var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
 63                     string currentPrefix = Names[maxCol];
 64                     foreach (var attr in attributes)
 65                     {
 66                         int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
 67                         int[] cols = pnCols.Where(i => i != maxCol).ToArray();
 68                         var node = new DecisionTreeNode<T>(maxCol, attr);
 69                         Root.Children.Add(node);
 70                         Learn(rows, cols, node);//递归生成决策树
 71                     }
 72                 }
 73             }
 74         }
 75         public double AttributeInfo(int attrCol, int[] pnRows)
 76         {
 77             var tuples = AttributeCount(attrCol, pnRows);
 78             var sum = (double)pnRows.Length;
 79             double Entropy = 0.0;
 80             foreach (var tuple in tuples)
 81             {
 82                 int[] count = new int[CategoryLabels.Length];
 83                 foreach (var irow in pnRows)
 84                     if (Data[irow, attrCol].Equals(tuple.Item1))
 85                     {
 86                         int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
 87                         count[index]++;//目前仅支持类别变量在最后一列
 88                     }
 89                 double k = 0.0;
 90                 for (int i = 0; i < count.Length; i++)
 91                 {
 92                     double frequency = count[i] / (double)tuple.Item2;
 93                     double t = -frequency * Log2(frequency);
 94                     k += t;
 95                 }
 96                 double freq = tuple.Item2 / sum;
 97                 Entropy += freq * k;
 98             }
 99             return Entropy;
100         }
101         public double CategoryInfo(int[] pnRows)
102         {
103             var tuples = AttributeCount(Category, pnRows);
104             var sum = (double)pnRows.Length;
105             double Entropy = 0.0;
106             foreach (var tuple in tuples)
107             {
108                 double frequency = tuple.Item2 / sum;
109                 double t = -frequency * Log2(frequency);
110                 Entropy += t;
111             }
112             return Entropy;
113         }
114         private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
115         {
116             foreach (var irow in pnRows)
117                 yield return data[irow, col];
118         }
119         private static double Log2(double x)
120         {
121             return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
122         }
123         public int MaxEntropy(int[] pnRows, int[] pnCols)
124         {
125             double cateEntropy = CategoryInfo(pnRows);
126             int maxAttr = 0;
127             double max = double.MinValue;
128             foreach (var icol in pnCols)
129                 if (icol != Category)
130                 {
131                     double Gain = cateEntropy - AttributeInfo(icol, pnRows);
132                     if (max < Gain)
133                     {
134                         max = Gain;
135                         maxAttr = icol;
136                     }
137                 }
138             return maxAttr;
139         }
140         public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
141         {
142             var tuples = from n in GetAttribute(Data, col, pnRows)
143                          group n by n into i
144                          select Tuple.Create(i.First(), i.Count());
145             return tuples;
146         }
147     }
148 }

决策树结点的构造:

 1 using System.Collections.Generic;
 2 
 3 namespace MachineLearning.DecisionTree
 4 {
 5     public sealed class DecisionTreeNode<T>
 6     {
 7         public int Label { get; set; }
 8         public T Value { get; set; }
 9         public List<DecisionTreeNode<T>> Children { get; set; }
10         public DecisionTreeNode(int label, T value)
11         {
12             Label = label;
13             Value = value;
14             Children = new List<DecisionTreeNode<T>>();
15         }
16     }
17 }

 

调用方法如下:

 1 using System;
 2 using System.Collections.Generic;
 3 using System.Linq;
 4 using System.Text;
 5 using System.Threading.Tasks;
 6 using MachineLearning.DecisionTree;
 7 namespace MachineLearning
 8 {
 9     class Program
10     {
11         static void Main(string[] args)
12         {
13             var da = new string[,]
14             {
15                 {"youth","high","no","fair","no"},
16                 {"youth","high","no","excellent","no"},
17                 {"middle_aged","high","no","fair","yes"},
18                 {"senior","medium","no","fair","yes"},
19                 {"senior","low","yes","fair","yes"},
20                 {"senior","low","yes","excellent","no"},
21                 {"middle_aged","low","yes","excellent","yes"},
22                 {"youth","medium","no","fair","no"},
23                 {"youth","low","yes","fair","yes"},
24                 {"senior","medium","yes","fair","yes"},
25                 {"youth","medium","yes","excellent","yes"},
26                 {"middle_aged","medium","no","excellent","yes"},
27                 {"middle_aged","high","yes","fair","yes"},
28                 {"senior","medium","no","excellent","no"}
29             };
30             var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
31             var tree = new DecisionTreeID3<string>(da, names, new string[] { "yes", "no" });
32             tree.Learn();
33             Console.ReadKey();
34         }
35     }
36 }

 

运行结果:


 

注:作者本人也在学习中,能力有限,如有错漏还请不吝指正。转载请注明作者。

转载于:https://www.cnblogs.com/HeYanjie/p/5787361.html

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐