必威体育Betway必威体育官网
当前位置:首页 > IT技术

决策树——CART和模型树

时间:2019-08-10 20:45:27来源:IT技术作者:seo实验室小编阅读:75次「手机版」
 

模型树

CART树

理解:      如果CART树处理离散型数据,叫做分类决策树,那么,引入基尼指数作为寻找最好的数据划分的依据,基尼指数越小,说明数据的“纯度越高”,随机森林的代码里边就运用到了基尼指数。如果CART树处理连续型数据时,叫做回归决策树,那么,引入了平方误差,首先,它使用二元切分来处理数据,得到两个子集,计算误差,找到最小误差,确定最佳切分的特征编号和特征值,然后进行建树。

构建回归树,需要给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外,该函数还要确定什么时候停止划分,一旦停止划分会生成一个叶节点。这里引入reLeaf(),regErr()分别得到叶节点和总方差。叶节点的模型是目标变量的 均值,var()是均方差,所以需要乘以数据集的样本个数。

划分数据集时,如果找不到一个‘好’的二元切分,该函数返回None值并产生叶节点,叶节点的值也为None。

from numpy import *
#载入数据
def loaddataset(fileName) : 
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines() :
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))#将后面的数据集映射为浮点型
        dataMat.APPend(fltLine)
    return dataMat

#切分数据集为两个子集
# dataSet: 数据集合
# feature: 待切分的特征
# value: 该特征的某个值   
#nonzero():得到数组非零元素的位置(数组索引) 
def binSplitDataSet(dataSet, feature, value) :
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0, mat1


# 负责生成叶节点,当chooseBestSplit()函数确定不再对数据进行切分时,
# 将调用该regLeaf()函数来得到叶节点的模型,在回归树中,该模型其实就是目标变量的均值
def regLeaf(dataSet) :
    return mean(dataSet[:, -1])


# 误差估计函数,该函数在给定的数据上计算目标变量的平方误差,这里直接调用均方差函数var
# 因为这里需要返回的是总方差,所以要用均方差乘以数据集中样本的个数  
def regErr(dataSet) :
    return var(dataSet[:, -1]) * shape(dataSet)[0]

# dataSet: 数据集合
# leafType: 给出建立叶节点的函数
# errType: 误差计算函数
# ops: 包含树构建所需其他参数的元组
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)) :
    # 将数据集分成两个部分,若满足停止条件,chooseBestSplit将返回None和某类模型的值
    # 若构建的是回归树,该模型是个常数。若是模型树,其模型是一个线性方程。
    # 若不满足停止条件,chooseBestSplit()将创建一个新的Python字典,并将数据集分成两份,
    # 在这两份数据集上将分别继续递归调用createTree()函数
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None : return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

# 回归树的切分函数,构建回归树的核心函数。目的:找出数据的最佳二元切分方式。如果找不到
# 一个“好”的二元切分,该函数返回None并同时调用createTree()方法来产生叶节点,叶节点的
# 值也将返回None。
# 如果找到一个“好”的切分方式,则返回特征编号和切分特征值。
# 最佳切分就是使得切分后能达到最低误差的切分。
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)) :
    # tolS是容许的误差下降值
    # tolN是切分的最小样本数
    tolS = ops[0]; tolN = ops[1]
    # 如果剩余特征值的数目为1,那么就不再切分而返回
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1 :
        return None, leafType(dataSet)
    # 当前数据集的大小
    m,n = shape(dataSet)
    # 当前数据集的误差
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1) :
#        for splitVal in set(dataSet[:, featIndex]) :
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN) : continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS :
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # 如果切分数据集后效果提升不够大,那么就不应该进行切分操作而直接创建叶节点
    if (S - bestS) < tolS :
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # 检查切分后的子集大小,如果某个子集的大小小于用户定义的参数tolN,那么也不应切分。
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN) :
        return None, leafType(dataSet)
    # 如果前面的这些终止条件都不满足,那么就返回切分特征和特征值。
    return bestIndex, bestValue

通过降低决策树的复杂度来避免过拟合的过程叫剪枝,预剪枝和后剪枝的单个效果可能是不好的,一般来说,我们可以同时采用这两种剪枝方法。

模型树:

理解:模型树和回归树的区别就是回归树的叶节点是一个常数值,而模型树的叶节点是分段线性函数,分段线性模型就是我们对数据集的一部分数据以某个线性模型建模,而另一份数据以另一个线性模型建模。

#模型树
# 主要功能:将数据格式化成目标变量Y和自变量X。X、Y用于执行简单的线性规划。
def linearSolve(dataSet) :
    m,n = shape(dataSet) 
    X = mat(ones((m,n))); Y = mat(ones((m,1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
    xTx = X.T*X
    # 矩阵的逆不存在时会造成程序异常
    if linalg.det(xTx) == 0.0 :
        raise NameERROR('This matrix is singular, cannot do inverse, \n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

# 与regLeaf()类似,当数据不需要切分时,它负责生成叶节点的模型。
def modelLeaf(dataSet) :
    ws, X, Y = linearSolve(dataSet)
    return ws

# 在给定的数据集上计算误差。与regErr()类似,会被chooseBestSplit()调用来找到最佳切分。
def modelErr(dataSet) :
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y-yHat, 2))


# 为了和modeTreeEval()保持一致,保留两个输入参数
def regTreeEval(model, inDat) :
    return float(model)

# 对输入数据进行格式化处理,在原数据矩阵上增加第0列,元素的值都是1
def modelTreeEval(model, inDat) :
    n = shape(inDat)[1]
    X = mat(ones((1, n+1)))
    X[:, 1:n+1] = inDat
    return float(X*model)

def isTree(obj):  
    return (type(obj).__name__=='dict')
# 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。
# modeEval是对叶节点进行预测的函数引用,指定树的类型,以便在叶节点上调用合适的模型。
# 此函数自顶向下遍历整棵树,直到命中叶节点为止,一旦到达叶节点,它就会在输入数据上
# 调用modelEval()函数,该函数的默认值为regTreeEval()
def treeForeCast(tree, inData, modelEval=regTreeEval) :
    if not isTree(tree) : return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal'] :
        if isTree(tree['left']) :
            return treeForeCast(tree['left'], inData, modelEval)
        else : 
            return modelEval(tree['left'], inData)
    else :
        if isTree(tree['right']) :
            return treeForeCast(tree['right'], inData, modelEval)
        else :
            return modelEval(tree['right'], inData)

# 多次调用treeForeCast()函数,以向量形式返回预测值,在整个测试集进行预测非常有用
def createForeCast(tree, testData, modelEval=regTreeEval) :
    m = len(testData)
    yHat = mat(zeros((m,1)))
    for i in range(m) :
        yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

使用Tkinter工具构建图形用户界面

from numpy import *
from tkinter import *
import regTrees as regTrees

import matplotlib
matplotlib.use('TkAgg')  #设置后端TkAgg
#将TkAgg和matplotlib链接起来
from matplotlib.backends.backend_tkagg import figureCanvasTkAgg
from matplotlib.figure import Figure

#
def reDraw(tolS, tolN) :
    reDraw.f.clf()   #清空之前的图像
    reDraw.a = reDraw.f.add_subplot(111)   #重新添加新图
    if chkBtnVar.get() :  #检查选框model tree是否被选中
        if tolN < 2 : tolN = 2
        myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, 
                                     regTrees.modelErr, (tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval)
    else :
        myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat)
    # reDraw.rawDat[:,0].A,需要将矩阵转换成数组
    reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5) # 绘制真实值
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) # 绘制预测值
    reDraw.canvas.show()

#
def getInputs() :#获取输入
    try : tolN = int(tolNentry.get())  #期望输入是整数
    except :  #清楚错误用默认值替换
        tolN = 10
        print ("enter integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try : tolS = float(tolSentry.get())
    except :   #期望输入是浮点数
        tolS = 1.0
        print ("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

# 
def drawNewTree() :
    # 取得输入框的值
    tolN, tolS = getInputs()   # 从输入文本框中获取参数
    # 利用tolN,tolS,调用reDraw生成漂亮的图
    reDraw(tolS, tolN)  #绘制图
    
#布局GUI
root = Tk()
# 创建画布
Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)

Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text='tolS').grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
# 点击“ReDraw”按钮后,调用drawNewTree()函数
Button(root, text='ReDraw', command=drawNewTree).grid(row=1, column=2, rowspan=3)

chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text='Model Tree', variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

reDraw.rawDat = mat(regTrees.loadDataSet('ex00.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)

reDraw(1.0, 10)

root.mainloop()

相关阅读

机器学习_线性回归模型

1.线性回归 1.1模型 1.1.1目标函数(损失函数、正则) a.无正则:最小二乘线性回归(OLS) b.L2正则:岭回归(Ridge Regression) c.L1正则:Lass

Matlab数模笔记(10)--线性规划、非线性规划 与 0/1规划

@1、背景举例:Lingo软件:非线性规划 lingo :lingo code:solve:0/1 规划 线性:举例:分析:模型:lingo code:结果:

电商产品“信任模型:引入社区互动元素

文章作者基于自己的思考和理解,为大家分享下网上购物(电商)产品最核心的设计原则-信任模型。最近唯品会与京东、腾讯成为战略合作伙

详解Cassandra数据模型中的primary key

Primary key的基本使用方法 Primary key的基本使用方法同关系型数据库中的primary key基本相同,既用来作为某一行数据的主键。我

客户流失预测模型,如何进行效果评估

一、一个重要指标:提升度用来评估客户流失预测模型预测效果好坏的一个重要指标,就是提升度。所谓提升度,简单来说,使用模型预测客户流

分享到:

栏目导航

推荐阅读

热门阅读