机器学习和生物信息学实验室联盟

 找回密码
 注册

QQ登录

只需一步,快速开始

搜索
查看: 6573|回复: 0
打印 上一主题 下一主题

【新生指南】Matlab StackAutoEncoder和Deep Belief Networks的使用调参和批处理

[复制链接]
跳转到指定楼层
楼主
发表于 2017-6-24 13:56:49 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
本帖最后由 xingpengwei 于 2017-6-24 13:54 编辑

        本帖只针对新生尽快熟悉工作做参考,我会写得啰嗦些,大神绕道{:243:} 。

        生物信息主要处理一维字符序列,如果想转化成矩阵处理,可使用onehot编码,快速理解就是01编码每个字符类别,如果序列只有AUGC四种字符,那么每个字符用0001,0010,0100,1000编码代替。这样序列就变成了Nx4的矩阵。
        StackAutoEncoder,即Stack自动编码机,一种最简单的无监督学习深度网络,首先我们来理解一下无监督学习:
        (引自http://blog.csdn.net/zouxy09/article/details/8775524
       
        如图,输入的样本是有标签的,即(input, target),我们根据当前输出和target(label)之间的差去改变前面各层的参数,直到收敛。
        那么无标签怎么得到误差呢?
       
        如上图,我们将input输入一个encoder编码器,就会得到一个code,这个code也就是输入的一个表示,那么我们怎么知道这个code表示的就是input呢?我们加一个decoder解码器,这时候decoder就会输出一个信息,那么如果输出的这个信息和一开始的输入信号input是很像的(理想情况下就是一样的),那很明显,我们就有理由相信这个code是靠谱的。所以,我们就通过调整encoder和decoder的参数,使得重构误差最小,这时候我们就得到了输入input信号的第一个表示了,也就是编码code了。因为是无标签数据,所以误差的来源就是直接重构后与原输入相比得到。
        Stack自动编码机的话,就是堆叠上述层,上一层的输入作为下一层的输入。
        看到这里,我们知道了StackAutoEncoder就是(encoder->decoder)-(encoder->decoder)-(encoder->decoder)...

        当然还有一些变形和其他知识,这里主要讲应用,有不看原理心忐忑的伙伴有时间后可以参考以下教程。
        斯坦福的UFLDL教程算是根正苗红的深度学习基础教程,或者是Andrew Ng的视频
        那么,我们来上代码:
        matlab自带有StackAutoEncoder的例子,首先我们来实现onehot编码:
  1. [tits, seqs]=fastaread('p1_n1_H.txt');
  2. num=length(tits);
  3. len=length(char(seqs(1)));

  4. Y=zeros(2,num);
  5. X=cell(1,num);
  6. %onehot
  7. for i=1:num
  8.    
  9.     str=char(tits(i));
  10.     temp=strsplit(str,'|');
  11.     label=char(temp(2));
  12.     if label=='1'
  13.         Y(:,i)=[1,0];
  14.     else
  15.         Y(:,i)=[0,1];
  16.     end
  17.    
  18.    
  19.     img=zeros(len,4);
  20.     for j=1:len
  21.         seq=char(seqs(1,i));
  22.         switch seq(j)
  23.             case 'A'
  24.                 img(j,:)=[0,0,0,1];
  25.             case 'C'
  26.                 img(j,:)=[0,0,1,0];
  27.             case 'G'
  28.                 img(j,:)=[0,1,0,0];
  29.             case 'U'
  30.                 img(j,:)=[1,0,0,0];
  31.             case 'T'
  32.                 img(j,:)=[1,0,0,0];
  33.         end
  34.     end
  35.      X{i}=img;
复制代码

        以上代码基本就是01编码代替序列字符。
        主体StackAutoEncoder代码参考https://cn.mathworks.com/help/nn ... classification.html
        你只需要替换相关输入,改图片长宽等。

        我们主要讲Deep Belief Networks深度信念网络:
        深度信念网络也是stack层,只是stack的是RBM,Restricted Boltzmann Machine (RBM)限制波尔兹曼机。
        理论参考http://blog.csdn.net/zouxy09/article/details/8781396 讲得挺好。
        这里我们使用matlab第三方的工具包:DeeBNet (Deep Belief Networks) toolbox in MATLAB and Octave
        工具包里边有很多例子脚本,我们主要使用test_getFeatureMNIST.m其他也行:
        首先我们要做的依然是把我们的数据结构化成这个脚本接受的形式:
  1. [tits, seqs]=fastaread('+SCM6A/p1_n1_H.txt');
  2. num=length(tits);
  3. len=length(char(seqs(1)));

  4. Y=zeros(num,1);
  5. X=zeros(num,len*4);
  6. %onehot
  7. for i=1:num
  8.    
  9.     str=char(tits(i));
  10.     temp=strsplit(str,'|');
  11.     label=char(temp(2));
  12.     if label=='1'
  13.         Y(i)=1;
  14.     else
  15.         Y(i)=2;
  16.     end
  17.    
  18.    
  19.     img=zeros(len,4);
  20.     for j=1:len
  21.         seq=char(seqs(1,i));
  22.         switch seq(j)
  23.             case 'A'
  24.                 img(j,:)=[0,0,0,1];
  25.             case 'C'
  26.                 img(j,:)=[0,0,1,0];
  27.             case 'G'
  28.                 img(j,:)=[0,1,0,0];
  29.             case 'U'
  30.                 img(j,:)=[1,0,0,0];
  31.             case 'T'
  32.                 img(j,:)=[1,0,0,0];
  33.         end
  34.     end
  35.      X(i,:)=img(:);
  36.    
  37. end

  38. inputs = X;
  39. targets = Y;
  40. %custom your split_point by multiplier
  41. split_point=round(size(inputs,1)*0.7);
  42. ranindex=randperm(size(inputs,1));

  43. data=inputs(ranindex(1:split_point),:);
  44. labels=targets(ranindex(1:split_point),:);

  45. testdata=inputs(ranindex(split_point+1:end),:);
  46. testlabels=targets(ranindex(split_point+1:end),:);

  47. save('SCM6A.mat','data','labels','testdata','testlabels');
复制代码

        后边的代码是生成一个随机种子然后分割训练测试集。

        下边是主体脚本:
        如果不包装网格调参的设定的话,代码是这样的:
  1. dbn=DBN();
  2. dbn.dbnType='autoEncoder';
  3. % RBM1
  4. rbmParams=RbmParameters(1000,ValueType.binary);
  5. rbmParams.maxEpoch=50;
  6. rbmParams.samplingMethodType=SamplingClasses.SamplingMethodType.CD;
  7. rbmParams.performanceMethod='reconstruction';
  8. dbn.addRBM(rbmParams);
  9. % RBM2
  10. rbmParams=RbmParameters(500,ValueType.binary);
  11. rbmParams.maxEpoch=50;
  12. rbmParams.samplingMethodType=SamplingClasses.SamplingMethodType.CD;
  13. rbmParams.performanceMethod='reconstruction';
  14. dbn.addRBM(rbmParams);
  15. % RBM3
  16. rbmParams=RbmParameters(250,ValueType.binary);
  17. rbmParams.maxEpoch=50;
  18. rbmParams.samplingMethodType=SamplingClasses.SamplingMethodType.CD;
  19. rbmParams.performanceMethod='reconstruction';
  20. dbn.addRBM(rbmParams);
  21. RBM4
  22. rbmParams=RbmParameters(200,ValueType.gaussian);
  23. rbmParams.maxEpoch=50;
  24. rbmParams.samplingMethodType=SamplingClasses.SamplingMethodType.CD;
  25. rbmParams.performanceMethod='reconstruction';
  26. dbn.addRBM(rbmParams);

  27. dbn.train(data);
  28. dbn.backpropagation(data,'yes');
复制代码

        Stack了四层RBM,每一层设置了一些参数,设置好后addRBM,然后训练,然后fine turn。
       
        为了能网格调参,我们需要把这部分的参数设置成变量的形式,然后包裹在几层for循环中(简单粗暴法),我们先设定一堆参数,如需修改参数范围,修改这部分就可以       
  1. RbmNode1=[250,100];
  2. RbmNode2=[500,250,100];
  3. RbmNode3=[1000,500,250,100];
  4. GridSearch_nodes={RbmNode1,RbmNode2,RbmNode3};
  5. GridSearch_maxEpoch=[200];
  6. GridSearch_learningRate=[0.1,0.01,0.001];
  7. GridSearch_BPorNot=[1,0];
复制代码

        然后把单层包裹着for循环中:
       
  1. fid = fopen('out.txt','w');
  2. for k=1:length(GridSearch_maxEpoch);
  3.     for l=1:4%GridSearch_SamplingMethodType
  4.         for m=1:length(GridSearch_learningRate)
  5.             for n=1:length(GridSearch_BPorNot)
  6.                
  7.                 for i=1:length(GridSearch_nodes)
  8.                     for times=1:3
  9.                         dbn=DBN();
  10.                         dbn.dbnType='autoEncoder';
  11.                         for j=1:length( GridSearch_nodes{i})
  12.                            
  13.                             rbmParams=RbmParameters(GridSearch_nodes{i}(j),ValueType.binary);
  14.                             rbmParams.maxEpoch=GridSearch_maxEpoch(k);
  15.                             rbmParams.samplingMethodType=l;
  16.                             rbmParams.performanceMethod='reconstruction';
  17.                             rbmParams.learningRate=GridSearch_learningRate(m);
  18.                             dbn.addRBM(rbmParams);
  19.                         end
  20.                         dbn.train(data);
  21.                         if GridSearch_BPorNot(n)==1
  22.                             dbn.backpropagation(data,'yes');
  23.                         end
  24.                         
  25.                         load('SCM6A_nosh.mat');
  26.                         feat=dbn.getFeature(inputs);
  27.                         label=targets-1;
  28.                         label(label==1)=-1;
  29.                         label(label==0)=1;
  30.                         
  31.                         confu=svmtrain(label,feat,'-v 10');
  32.                         confu=double(confu);
  33.                         TP=confu(1);
  34.                         FN=confu(2);
  35.                         FP=confu(3);
  36.                         TN=confu(4);
  37.                         
  38.                         Sn=TP/(TP+FN);
  39.                         Sp=TN/(TN+FP);
  40.                         Acc=(TP+TN)/(TP+FN+FP+TN);
  41.                         MCC=(TP*TN-FN*FP)/sqrt((TP+FP)*(TN+FN)*(TP+FN)*(TN+FP));
  42.                         
  43.                         fprintf(fid, '%s %s %s %s %s %s %s %s %s %s %s %s %s %s\n','times','GridSearch_nodes','maxEpoch','SamplingMethodType','learningRate','BPorNot','Sn','Sp','Acc','MCC','TP','FN','FP','TN');
  44.                         fprintf(fid, '%d %s %d %d %d %d %.4f %.4f %.4f %.4f %d %d %d %d\n',times,mat2str(GridSearch_nodes{i}(:)),GridSearch_maxEpoch(k),l,GridSearch_learningRate(m),GridSearch_BPorNot(n),Sn,Sp,Acc,MCC,TP,FN,FP,TN);
  45.                         
  46.                         
  47.                         
  48.                         feat_label=[feat,targets-1];
  49.                         csvwrite('feat_temp.csv',feat_label);
  50.                         
  51.                         S = fileread('feat_temp.csv');
  52.                         feat_dim=100;
  53.                         attribute='';
  54.                         for index=1:feat_dim
  55.                             attribute=[attribute,'@attribute Fea',int2str(index),' numeric',char(10)];
  56.                         end
  57.                         header=['@relation m6a_dbn',char(10),attribute,'@attribute class {0,1}',char(10),'@data',char(10)];
  58.                         
  59.                         S = [header,  S];
  60.                         name=sprintf('%s-%d-%d-%d-%d-%.4f-%.4f-%.4f-%.4f-%d-%d-%d-%d-%d',mat2str(GridSearch_nodes{i}(:)),GridSearch_maxEpoch(k),l,GridSearch_learningRate(m),GridSearch_BPorNot(n),Sn,Sp,Acc,MCC,TP,FN,FP,TN,times);
  61.                         FID = fopen([name,'.arff'], 'w');
  62.                         if FID == -1, error('Cannot open file'); end
  63.                         fwrite(FID, S, 'char');
  64.                         fclose(FID);
  65.                     end
  66.                 end
  67.             end
  68.         end
  69.     end
  70. end
  71. fclose(fid);
复制代码

        当然,我们也利用这个训练批量地
        使用修改编译后的libsvm十折交叉验证:
  1. confu=svmtrain(label,feat,'-v 10');
复制代码

        这一行直接可以输出混淆矩阵,默认的libsvm包是十折的话只输出一个准确度,这是蛋疼,
        具体参考关于libsvm包优化的那一篇

        计算了准确率等论文可以直接使用的指标:
  1. confu=double(confu);
  2.                         TP=confu(1);
  3.                         FN=confu(2);
  4.                         FP=confu(3);
  5.                         TN=confu(4);
  6.                         
  7.                         Sn=TP/(TP+FN);
  8.                         Sp=TN/(TN+FP);
  9.                         Acc=(TP+TN)/(TP+FN+FP+TN);
  10.                         MCC=(TP*TN-FN*FP)/sqrt((TP+FP)*(TN+FN)*(TP+FN)*(TN+FP));
复制代码

        生成保存了arff文件以[/code]便后续使用:
               
  1. feat_label=[feat,targets-1];
  2.                         csvwrite('feat_temp.csv',feat_label);
  3.                         
  4.                         S = fileread('feat_temp.csv');
  5.                         feat_dim=100;
  6.                         attribute='';
  7.                         for index=1:feat_dim
  8.                             attribute=[attribute,'@attribute Fea',int2str(index),' numeric',char(10)];
  9.                         end
  10.                         header=['@relation m6a_dbn',char(10),attribute,'@attribute class {0,1}',char(10),'@data',char(10)];
  11.                         
  12.                         S = [header,  S];
  13.                         name=sprintf('%s-%d-%d-%d-%d-%.4f-%.4f-%.4f-%.4f-%d-%d-%d-%d-%d',mat2str(GridSearch_nodes{i}(:)),GridSearch_maxEpoch(k),l,GridSearch_learningRate(m),GridSearch_BPorNot(n),Sn,Sp,Acc,MCC,TP,FN,FP,TN,times);
  14.                         FID = fopen([name,'.arff'], 'w');
  15.                         if FID == -1, error('Cannot open file'); end
  16.                         fwrite(FID, S, 'char');
  17.                         fclose(FID);
复制代码

        先就酱。
分享到:  QQ好友和群QQ好友和群 QQ空间QQ空间 腾讯微博腾讯微博 腾讯朋友腾讯朋友
收藏收藏 转播转播 分享分享
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 注册

本版积分规则

机器学习和生物信息学实验室联盟  

GMT+8, 2024-5-17 09:56 , Processed in 0.068338 second(s), 22 queries .

Powered by Discuz! X3.2

© 2001-2013 Comsenz Inc.

快速回复 返回顶部 返回列表