《電子技術(shù)應(yīng)用》
您所在的位置:首頁 > 可編程邏輯 > 業(yè)界動(dòng)態(tài) > 介紹機(jī)器學(xué)習(xí)分類算法——決策樹

介紹機(jī)器學(xué)習(xí)分類算法——決策樹

2018-07-27
關(guān)鍵詞: 決策樹 最大熵原則 分類算法

  今天,我們介紹機(jī)器學(xué)習(xí)里比較常用的一種分類算法,決策樹。決策樹是對(duì)人類認(rèn)知識(shí)別的一種模擬,給你一堆看似雜亂無章的數(shù)據(jù),如何用盡可能少的特征,對(duì)這些數(shù)據(jù)進(jìn)行有效的分類。

  決策樹借助了一種層級(jí)分類的概念,每一次都選擇一個(gè)區(qū)分性最好的特征進(jìn)行分類,對(duì)于可以直接給出標(biāo)簽 label 的數(shù)據(jù),可能最初選擇的幾個(gè)特征就能很好地進(jìn)行區(qū)分,有些數(shù)據(jù)可能需要更多的特征,所以決策樹的深度也就表示了你需要選擇的幾種特征。

  在進(jìn)行特征選擇的時(shí)候,常常需要借助信息論的概念,利用最大熵原則

  決策樹一般是用來對(duì)離散數(shù)據(jù)進(jìn)行分類的,對(duì)于連續(xù)數(shù)據(jù),可以事先對(duì)其離散化。

  在介紹決策樹之前,我們先簡(jiǎn)單的介紹一下信息熵,我們知道,熵的定義為:

1.jpg

  我們先構(gòu)造一些簡(jiǎn)單的數(shù)據(jù):

  from sklearn import datasets

  import numpy as np

  import matplotlib.pyplot as plt

  import math

  import operator

  def Create_data():

  dataset = [[1, 1, 'yes'],

  [1, 1, 'yes'],

  [1, 0, 'no'],

  [0, 1, 'no'],

  [0, 1, 'no'],

  [3, 0, 'maybe']]

  feat_name = ['no surfacing', 'flippers']

  return dataset, feat_name

  然后定義一個(gè)計(jì)算熵的函數(shù):

  def Cal_entrpy(dataset):

  n_sample = len(dataset)

  n_label = {}

  for featvec in dataset:

  current_label = featvec[-1]

  if current_label not in n_label.keys():

  n_label[current_label] = 0

  n_label[current_label] += 1

  shannonEnt = 0.0

  for key in n_label:

  prob = float(n_label[key]) / n_sample

  shannonEnt -= prob * math.log(prob, 2)

  return shannonEnt

  要注意的是,熵越大,說明數(shù)據(jù)的類別越分散,越呈現(xiàn)某種無序的狀態(tài)。

  下面再定義一個(gè)拆分?jǐn)?shù)據(jù)集的函數(shù):

  def Split_dataset(dataset, axis, value):

  retDataSet = []

  for featVec in dataset:

  if featVec[axis] == value:

  reducedFeatVec = featVec[:axis]

  reducedFeatVec.extend(featVec[axis+1 :])

  retDataSet.append(reducedFeatVec)

  return retDataSet

  結(jié)合前面的幾個(gè)函數(shù),我們可以構(gòu)造一個(gè)特征選擇的函數(shù):

  def Choose_feature(dataset):

  num_sample = len(dataset)

  num_feature = len(dataset[0]) - 1

  baseEntrpy = Cal_entrpy(dataset)

  best_Infogain = 0.0

  bestFeat = -1

  for i in range (num_feature):

  featlist = [example[i] for example in dataset]

  uniquValus = set(featlist)

  newEntrpy = 0.0

  for value in uniquValus:

  subData = Split_dataset(dataset, i, value)

  prob = len(subData) / float(num_sample)

  newEntrpy += prob * Cal_entrpy(subData)

  info_gain = baseEntrpy - newEntrpy

  if (info_gain > best_Infogain):

  best_Infogain = info_gain

  bestFeat = i

  return bestFeat

  然后再構(gòu)造一個(gè)投票及計(jì)票的函數(shù)

  def Major_cnt(classlist):

  class_num = {}

  for vote in classlist:

  if vote not in class_num.keys():

  class_num[vote] = 0

  class_num[vote] += 1

  Sort_K = sorted(class_num.iteritems(),

  key = operator.itemgetter(1), reverse=True)

  return Sort_K[0][0]

  有了這些,就可以構(gòu)造我們需要的決策樹了:

  def Create_tree(dataset, featName):

  classlist = [example[-1] for example in dataset]

  if classlist.count(classlist[0]) == len(classlist):

  return classlist[0]

  if len(dataset[0]) == 1:

  return Major_cnt(classlist)

  bestFeat = Choose_feature(dataset)

  bestFeatName = featName[bestFeat]

  myTree = {bestFeatName: {}}

  del(featName[bestFeat])

  featValues = [example[bestFeat] for example in dataset]

  uniqueVals = set(featValues)

  for value in uniqueVals:

  subLabels = featName[:]

  myTree[bestFeatName][value] = Create_tree(Split_dataset

  (dataset, bestFeat, value), subLabels)

  return myTree

  def Get_numleafs(myTree):

  numLeafs = 0

  firstStr = myTree.keys()[0]

  secondDict = myTree[firstStr]

  for key in secondDict.keys():

  if type(secondDict[key]).__name__ == 'dict' :

  numLeafs += Get_numleafs(secondDict[key])

  else:

  numLeafs += 1

  return numLeafs

  def Get_treedepth(myTree):

  max_depth = 0

  firstStr = myTree.keys()[0]

  secondDict = myTree[firstStr]

  for key in secondDict.keys():

  if type(secondDict[key]).__name__ == 'dict' :

  this_depth = 1 + Get_treedepth(secondDict[key])

  else:

  this_depth = 1

  if this_depth > max_depth:

  max_depth = this_depth

  return max_depth

  我們也可以把決策樹繪制出來:

  def Plot_node(nodeTxt, centerPt, parentPt, nodeType):

  Create_plot.ax1.annotate(nodeTxt, xy=parentPt,

  xycoords='axes fraction',

  xytext=centerPt, textcoords='axes fraction',

  va=center, ha=center, bbox=nodeType, arrowprops=arrow_args)

  def Plot_tree(myTree, parentPt, nodeTxt):

  numLeafs = Get_numleafs(myTree)

  Get_treedepth(myTree)

  firstStr = myTree.keys()[0]

  cntrPt = (Plot_tree.xOff + (1.0 + float(numLeafs))/2.0/Plot_tree.totalW,

  Plot_tree.yOff)

  Plot_midtext(cntrPt, parentPt, nodeTxt)

  Plot_node(firstStr, cntrPt, parentPt, decisionNode)

  secondDict = myTree[firstStr]

  Plot_tree.yOff = Plot_tree.yOff - 1.0/Plot_tree.totalD

  for key in secondDict.keys():

  if type(secondDict[key]).__name__=='dict':

  Plot_tree(secondDict[key],cntrPt,str(key))

  else:

  Plot_tree.xOff = Plot_tree.xOff + 1.0/Plot_tree.totalW

  Plot_node(secondDict[key], (Plot_tree.xOff, Plot_tree.yOff),

  cntrPt, leafNode)

  Plot_midtext((Plot_tree.xOff, Plot_tree.yOff), cntrPt, str(key))

  Plot_tree.yOff = Plot_tree.yOff + 1.0/Plot_tree.totalD

  def Create_plot (myTree):

  fig = plt.figure(1, facecolor = 'white')

  fig.clf()

  axprops = dict(xticks=[], yticks=[])

  Create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)

  Plot_tree.totalW = float(Get_numleafs(myTree))

  Plot_tree.totalD = float(Get_treedepth(myTree))

  Plot_tree.xOff = -0.5/Plot_tree.totalW; Plot_tree.yOff = 1.0;

  Plot_tree(myTree, (0.5,1.0), '')

  plt.show()

  def Plot_midtext(cntrPt, parentPt, txtString):

  xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]

  yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]

  Create_plot.ax1.text(xMid, yMid, txtString)

  def Classify(myTree, featLabels, testVec):

  firstStr = myTree.keys()[0]

  secondDict = myTree[firstStr]

  featIndex = featLabels.index(firstStr)

  for key in secondDict.keys():

  if testVec[featIndex] == key:

  if type(secondDict[key]).__name__ == 'dict' :

  classLabel = Classify(secondDict[key],featLabels,testVec)

  else:

  classLabel = secondDict[key]

  return classLabel

  最后,可以測(cè)試我們的構(gòu)造的決策樹分類器:

  decisionNode = dict(boxstyle=sawtooth, fc=0.8)

  leafNode = dict(boxstyle=round4, fc=0.8)

  arrow_args = dict(arrowstyle=-)

  myData, featName = Create_data()

  S_entrpy = Cal_entrpy(myData)

  new_data = Split_dataset(myData, 0, 1)

  best_feat = Choose_feature(myData)

  myTree = Create_tree(myData, featName[:])

  num_leafs = Get_numleafs(myTree)

  depth = Get_treedepth(myTree)

  Create_plot(myTree)

  predict_label = Classify(myTree, featName, [1, 0])

  print(the predict label is: , predict_label)

  print(the decision tree is: , myTree)

  print(the best feature index is: , best_feat)

  print(the new dataset: , new_data)

  print(the original dataset: , myData)

  print(the feature names are: , featName)

  print(the entrpy is:, S_entrpy)

  print(the number of leafs is: , num_leafs)

  print(the dpeth is: , depth)

  print(All is well.)

  構(gòu)造的決策樹最后如下所示:

2.jpg


本站內(nèi)容除特別聲明的原創(chuàng)文章之外,轉(zhuǎn)載內(nèi)容只為傳遞更多信息,并不代表本網(wǎng)站贊同其觀點(diǎn)。轉(zhuǎn)載的所有的文章、圖片、音/視頻文件等資料的版權(quán)歸版權(quán)所有權(quán)人所有。本站采用的非本站原創(chuàng)文章及圖片等內(nèi)容無法一一聯(lián)系確認(rèn)版權(quán)者。如涉及作品內(nèi)容、版權(quán)和其它問題,請(qǐng)及時(shí)通過電子郵件或電話通知我們,以便迅速采取適當(dāng)措施,避免給雙方造成不必要的經(jīng)濟(jì)損失。聯(lián)系電話:010-82306118;郵箱:aet@chinaaet.com。