机器学习和生物信息学实验室联盟
标题:
关于小改libsvm的功能,方便一键网格、输出各项指标、matlab接口修改等
[打印本页]
作者:
xingpengwei
时间:
2017-6-24 14:23
标题:
关于小改libsvm的功能,方便一键网格、输出各项指标、matlab接口修改等
本帖最后由 xingpengwei 于 2017-7-20 10:15 编辑
为了方便使用对libsvm包进行了小改,过程如下:
之前见guojiasheng学长改了libsvm,输出F1等指标,我check了一下:
svm_cross_validation(&prob,¶m,nr_fold,target);
//F1 准确率和召回率的调和均值 :这里只能支持二分类,并且正类为的lable应该转为1 负类为-1
int tp = 0 ,fp = 0, fn = 0 , tn = 0;
for(i=0;i<prob.l;i++){
if(prob.y[i] == 1 && target[i] == 1) tp++; //正类判定为正类
if(prob.y[i] == 1 && target[i] < 1) fn++; //正类判定为负类
if(prob.y[i] < 1 && target[i] == 1) fp++; //负类判定为正类
if(prob.y[i] < 1 && target[i] < 1) tn++; //负类判定为负类
}
//precision tp/(tp+fp)
double precision = 1.0*tp/(tp+fp);
//recall tp/(tp+fn)
double recall = 1.0*tp/(tp+fn);
//F1
double F1 = (2.0*tp)/(2.0*tp+fp+fn);
printf("Cross Validation precision = %g%%\n",100.0*precision);
printf("Cross Validation recall = %g%%\n",100.0*recall);
printf("Cross Validation F1 = %g%%\n",100.0*F1);
复制代码
值得注意的是根据学长的代码输入的二分类label应该为+1和-1,如果你用weka的arff转的libsvm文件的话,label一般为0和1(0正1负)这样根据代码的计算tp、fn、fp、tn的顺序是会正好相反的,所以使用前请保证label为1和-1表示正负类,这点新生值得注意。
后来我又改了tool里的easy.py,另见了一个文件myeasy_F1.py,主要的功能是使用F1作为优化结果的标准,因为之前有发现效果会变好一点。
还改了输出,增加了一些有需求的输出,具体叫以下输出例子
./grid_F1.py -v 10 -svmtrain "../svm-train" "mix_npps.libsvm"
Best c=512.0, g=0.001953125 CV F1=80.2937
../svm-train -c 512.0 -g 0.001953125 -v 10 "mix_npps.libsvm" "mix_npps.libsvm.model"
Training...
precision,recall,F1,Accuracy=81.1085,79.495,80.2937,80.4897
2614
Sn,Sp,Acc,MCC = 0.794950267789,0.814843152257,0.804896710023,0.609914111683
Output model: mix_npps.libsvm.model
复制代码
输出了你运行的命令,做好的参数,最好的F1,precision,recall,F1,Accuracy,和总样例个数,Sn,Sp,Acc,MCC和模型
后来,因为想在matlab里直接使用libsvm的十折验证,不至于每次都在转到libsvm包文件夹里去使用,并把混淆矩阵输了出来,做了如下优化:
我修改了matlab文件夹下的svmtrain的c文件:
在int * do_cross_validation()函数中添加了tp,fn,fp,tn的计算(参考了guojiasheng学长的思路):
int tp = 0 ,fp = 0, fn = 0 , tn = 0;
for(i=0;i<prob.l;i++){
if(prob.y[i] == 1 && target[i] == 1) tp++; //正类判定为正类
if(prob.y[i] == 1 && target[i] < 1) fn++; //正类判定为负类
if(prob.y[i] < 1 && target[i] == 1) fp++; //负类判定为正类
if(prob.y[i] < 1 && target[i] < 1) tn++; //负类判定为负类
}
printf("%d %d %d %d\n",tp,fn,fp,tn);
int confu[4] = {tp,fn, fp, tn};
复制代码
又修改了对外输出的接口:
if(cross_validation)
{
int *ptr;
plhs[0] = mxCreateNumericMatrix(1, 4, mxINT32_CLASS, mxREAL);
ptr = (int *)mxGetData(plhs[0]);
// double *ptr;
// plhs[0] = mxCreateDoubleMatrix(1, 4, mxREAL);
// ptr = mxGetPr(plhs[0]);
ptr[0] = do_cross_validation()[0];
ptr[1] = do_cross_validation()[1];
ptr[2] = do_cross_validation()[2];
ptr[3] = do_cross_validation()[3];
}
复制代码
使用方法:
confu=svmtrain(label,feat,'-v 10');
confu=double(confu);
TP=confu(1);
FN=confu(2);
FP=confu(3);
TN=confu(4);
Sn=TP/(TP+FN);
Sp=TN/(TN+FP);
Acc=(TP+TN)/(TP+FN+FP+TN);
MCC=(TP*TN-FN*FP)/sqrt((TP+FP)*(TN+FN)*(TP+FN)*(TN+FP));
复制代码
(和原来输入参数一样,只是输出由只输出一个准确率,变成输出混淆矩阵,后边的代码为如何利用混淆矩阵计算各项指标)
注意:使用matlab2015b编译C,(就别装VS了{:254:} )最好用4.9的mingw编译器。 在2015版本及以上可在add-ones上添加mingw,这样安装最方便。或者直接下载
minggw.mlpkginstall
,双击matlab自动关联打开,安装相应编译器。如果用matlab使用mingw编译c++文件的话,make.m文件是需要把make.m文件下的CFLAGS 替换成COMPFLAGS。
编译好的文件的话,附件下载。
作者:
xingpengwei
时间:
2017-6-24 15:26
文件大于16m,无法上传附件,回头找个地方上传吧
作者:
zouquan
时间:
2017-6-29 10:11
赞啊,放到99服务器上,写好试用说明,最好傻瓜式的
欢迎光临 机器学习和生物信息学实验室联盟 (http://123.57.240.48/)
Powered by Discuz! X3.2