|
本帖最后由 xingpengwei 于 2017-7-20 10:15 编辑
为了方便使用对libsvm包进行了小改,过程如下:
之前见guojiasheng学长改了libsvm,输出F1等指标,我check了一下:
- svm_cross_validation(&prob,¶m,nr_fold,target);
- //F1 准确率和召回率的调和均值 :这里只能支持二分类,并且正类为的lable应该转为1 负类为-1
- int tp = 0 ,fp = 0, fn = 0 , tn = 0;
- for(i=0;i<prob.l;i++){
- if(prob.y[i] == 1 && target[i] == 1) tp++; //正类判定为正类
- if(prob.y[i] == 1 && target[i] < 1) fn++; //正类判定为负类
- if(prob.y[i] < 1 && target[i] == 1) fp++; //负类判定为正类
- if(prob.y[i] < 1 && target[i] < 1) tn++; //负类判定为负类
- }
- //precision tp/(tp+fp)
- double precision = 1.0*tp/(tp+fp);
- //recall tp/(tp+fn)
- double recall = 1.0*tp/(tp+fn);
- //F1
- double F1 = (2.0*tp)/(2.0*tp+fp+fn);
- printf("Cross Validation precision = %g%%\n",100.0*precision);
- printf("Cross Validation recall = %g%%\n",100.0*recall);
- printf("Cross Validation F1 = %g%%\n",100.0*F1);
复制代码
值得注意的是根据学长的代码输入的二分类label应该为+1和-1,如果你用weka的arff转的libsvm文件的话,label一般为0和1(0正1负)这样根据代码的计算tp、fn、fp、tn的顺序是会正好相反的,所以使用前请保证label为1和-1表示正负类,这点新生值得注意。
后来我又改了tool里的easy.py,另见了一个文件myeasy_F1.py,主要的功能是使用F1作为优化结果的标准,因为之前有发现效果会变好一点。
还改了输出,增加了一些有需求的输出,具体叫以下输出例子
- ./grid_F1.py -v 10 -svmtrain "../svm-train" "mix_npps.libsvm"
- Best c=512.0, g=0.001953125 CV F1=80.2937
- ../svm-train -c 512.0 -g 0.001953125 -v 10 "mix_npps.libsvm" "mix_npps.libsvm.model"
- Training...
- precision,recall,F1,Accuracy=81.1085,79.495,80.2937,80.4897
- 2614
- Sn,Sp,Acc,MCC = 0.794950267789,0.814843152257,0.804896710023,0.609914111683
- Output model: mix_npps.libsvm.model
复制代码
输出了你运行的命令,做好的参数,最好的F1,precision,recall,F1,Accuracy,和总样例个数,Sn,Sp,Acc,MCC和模型
后来,因为想在matlab里直接使用libsvm的十折验证,不至于每次都在转到libsvm包文件夹里去使用,并把混淆矩阵输了出来,做了如下优化:
我修改了matlab文件夹下的svmtrain的c文件:
在int * do_cross_validation()函数中添加了tp,fn,fp,tn的计算(参考了guojiasheng学长的思路):
- int tp = 0 ,fp = 0, fn = 0 , tn = 0;
- for(i=0;i<prob.l;i++){
- if(prob.y[i] == 1 && target[i] == 1) tp++; //正类判定为正类
- if(prob.y[i] == 1 && target[i] < 1) fn++; //正类判定为负类
- if(prob.y[i] < 1 && target[i] == 1) fp++; //负类判定为正类
- if(prob.y[i] < 1 && target[i] < 1) tn++; //负类判定为负类
- }
-
- printf("%d %d %d %d\n",tp,fn,fp,tn);
- int confu[4] = {tp,fn, fp, tn};
复制代码
又修改了对外输出的接口:
- if(cross_validation)
- {
- int *ptr;
- plhs[0] = mxCreateNumericMatrix(1, 4, mxINT32_CLASS, mxREAL);
- ptr = (int *)mxGetData(plhs[0]);
-
- // double *ptr;
- // plhs[0] = mxCreateDoubleMatrix(1, 4, mxREAL);
- // ptr = mxGetPr(plhs[0]);
- ptr[0] = do_cross_validation()[0];
- ptr[1] = do_cross_validation()[1];
- ptr[2] = do_cross_validation()[2];
- ptr[3] = do_cross_validation()[3];
- }
复制代码
使用方法:
- confu=svmtrain(label,feat,'-v 10');
- confu=double(confu);
- TP=confu(1);
- FN=confu(2);
- FP=confu(3);
- TN=confu(4);
-
- Sn=TP/(TP+FN);
- Sp=TN/(TN+FP);
- Acc=(TP+TN)/(TP+FN+FP+TN);
- MCC=(TP*TN-FN*FP)/sqrt((TP+FP)*(TN+FN)*(TP+FN)*(TN+FP));
复制代码
(和原来输入参数一样,只是输出由只输出一个准确率,变成输出混淆矩阵,后边的代码为如何利用混淆矩阵计算各项指标)
注意:使用matlab2015b编译C,(就别装VS了{:254:} )最好用4.9的mingw编译器。 在2015版本及以上可在add-ones上添加mingw,这样安装最方便。或者直接下载minggw.mlpkginstall,双击matlab自动关联打开,安装相应编译器。如果用matlab使用mingw编译c++文件的话,make.m文件是需要把make.m文件下的CFLAGS 替换成COMPFLAGS。
编译好的文件的话,附件下载。 |
|