|
本帖最后由 guojiasheng 于 2016-10-6 19:59 编辑
GBDT
这边暂时就不对gbdt进行理论性说明了。这个在工业界用的比较多,大家可以多google看papaer
github:
https://github.com/guojiasheng/gbdt
补充说一下理论,不对请指出:
集成学习就是通过学习多个普通的基分类器,然后将这些基分类器进行某种加权组合,比如简单投票,概率投票等。
(1)bagging: 就是采用有放回的抽样,美名其约 bootstrap。每次抽出的训练集大小和原数据集一样,但是有重复的和遗漏的样本。
比如抽取了k个训练集,然后训练这k个,然后组合投票。
(2)RF:其实很简单,但是效果非常好,就是在bagging的基础上做了一丢丢改动。首先它是森林,所以基分类器就是树。改进的就是树的
节点属性每次是随机选取k个,然后从k个里面找个最好的属性出来,其它和bagging一样。
(3)boosting:提升方法,每一轮训练都提升分类器性能,简单明了。拿adaboost来看,每一次迭代都是提高错分样本权重,这样每一次迭代都会性能有所提高。
(4)gdbt:就是梯度下降树,归到boosting里面,每一次学习都是为了减少上一次学习的残差,为了消除这个残差或者减小这个残差,我们就可以在残差减小的梯度上建立一个新模型,以这个残差为loss function,然后梯度下降法,降低残差。、
这边我首先介绍使用python的的。
目录:
/Bioinformatics_Machine_Learning/Machine_Learning/classifier/GBDT
(1)基于scikit-learn机器学习库的。
/scikit-learn 可以参考网站:http://scikit-learn.org/stable/m ... _svmlight_file.html
这边我手写了一个方便大家使用,如有问题请多多指教:
1.交叉验证看效果:
- Usage: gbdt.py [-ls (loss: deviance,exponential),-lr(learning_rate 0.1),-ns(n_estimators 100),-md(max_depth 3),-sub(subsample 1),-cv (10)] dataset
复制代码
参数说明:
ls: loss默认是deviance,包括(deviance、exponential),deviance 其实就是logistic regression,exponential 指数损失函数,其实这时候就gbdt就变成了adaboost算法了。
lr: learning_rate,代表每个tree的学习速率,默认是0.1,具体可以自己试验看效果
ns : n_estimators,代表 boosting的次数,默认是100,次数越多效果肯定越好,而且也不用担心over-fitting的问题,因为gbdt很robust。其实这部boosting的次数,我们可以认为就是RF里面tree的个数。当然这里面其实不太一样。
md: max_depth, 代表每一棵树的深度, 默认是3.
sub: subsample,每次随机选取训练数据集的比率,默认是1.就是全部用来训练。如果sub小于1的话,就变成stochastic gradient boosting.
cv: cross Validation,交叉验证的大小,这个k-fold要设置.
p: 预测文件路径
dataset : 训练数据集,注意这边目前只输入libsvm格式的数据
只要dataset放最后面,其他参数顺序无所谓!效果不好,就ns设高点~
command:
(1) python gbdt.py -cv 10 heart_scale
(2) python gbdt.py -ns 100 -md 5 -cv 10 heart_scale
output:
(1)会输出交叉验证的Acc
(2)为了应对不平衡数据,还输出了Auc。
(3)输出混淆矩阵,虽然样式没那么好看,但是还是能ok的。
(4)输出feature的重要性打分,这个还是挺有用的,分值越高表示这个feature的在这个过程起到的重要性越大,(是不是感觉这样论文就有的分析了,经过分析哪些特征起到了重要作用...).这边我输出的是个数组,按顺序输出每一维特征的重要程度)
2.对文件进行predict:
当我们有测试文件 test 和 训练文化train的时候,我们用train训练model,用这个model来预测test的文件,输出文件为当前目录下的predict文件。
python gbdt.py -p testFile trainFile
输出的文件为predict,包括原始lable 预测lable 以及不同lable的概率打分
guojiasheng@iZ22bwiomoaZ:/Bioinformatics_Machine_Learning/Machine_Learning/classifier/GBDT/scikit-learn$ python gbdt.py -cv 10 heart_scale
10 cross validation result
ACC:0.8
AUC:0.875
confusion_matrix
[[124 26]
[ 28 92]]
The feature importances (the higher, the more important the feature)
[ 0.15695588 0.02693622 0.0746598 0.1304533 0.20291259 0.00553532
0.01781097 0.10505819 0.0114092 0.09309421 0.02433416 0.08534509
0.06549508]
|
|