本帖最后由 lyqmath 于 2019-10-11 15:09 编辑
Alexnet是经典的CNN网络结构,广泛应用于图像分类相关领域,这里我们基于deepNetworkDesigner进行网络结构查看、分析、编辑,形成自定义的网络结构。
在R2019环境的command窗口输入如下命令,加载网络。
- >> net = alexnet;
- >> net
- net =
- SeriesNetwork - 属性:
- Layers: [25×1 nnet.cnn.layer.Layer]
- >>
复制代码
如果还未配置该工具箱,将提示Add-ons操作,根据相关步骤指引安装即可。
输入如下命令,打开网络编辑对话框。
- >> deepNetworkDesigner
- >>
复制代码
点击Import按钮,读取已加载的网络结构变量,具体如下所示。
可以发现,输入的图片维数为227*227*3,输出的维数为1000*1
我们这里采用最简单的网络编辑策略,将输出类别进行修改,将其设置为5*1的输出,也就是对应于“五类”的图片分类应用。
可以发现,我们这里是对全连接层、输出层进行的修改,这样能够确保输出的分类数与预设的分类数能对应起来。
最后,我们点击智能分析按钮,对网络进行自动检查。
这样,我们就得到了25层的自定义AlexNet,对应于设置的五类别分类应用,点击执行导出按钮,可以将其导出到变量、文件等形式,这里我们将其生成代码,便于自定义修改。
将自动生成的代码进行组合,编辑到m文件,具体如下所示。
- clc; clear all; close all;
- layers = [
- imageInputLayer([227 227 3],"Name","data")
- convolution2dLayer([11 11],96,"Name","conv1","BiasLearnRateFactor",2,"Stride",[4 4])
- reluLayer("Name","relu1")
- crossChannelNormalizationLayer(5,"Name","norm1","K",1)
- maxPooling2dLayer([3 3],"Name","pool1","Stride",[2 2])
- groupedConvolution2dLayer([5 5],128,2,"Name","conv2","BiasLearnRateFactor",2,"Padding",[2 2 2 2])
- reluLayer("Name","relu2")
- crossChannelNormalizationLayer(5,"Name","norm2","K",1)
- maxPooling2dLayer([3 3],"Name","pool2","Stride",[2 2])
- convolution2dLayer([3 3],384,"Name","conv3","BiasLearnRateFactor",2,"Padding",[1 1 1 1])
- reluLayer("Name","relu3")
- groupedConvolution2dLayer([3 3],192,2,"Name","conv4","BiasLearnRateFactor",2,"Padding",[1 1 1 1])
- reluLayer("Name","relu4")
- groupedConvolution2dLayer([3 3],128,2,"Name","conv5","BiasLearnRateFactor",2,"Padding",[1 1 1 1])
- reluLayer("Name","relu5")
- maxPooling2dLayer([3 3],"Name","pool5","Stride",[2 2])
- fullyConnectedLayer(4096,"Name","fc6","BiasLearnRateFactor",2)
- reluLayer("Name","relu6")
- dropoutLayer(0.5,"Name","drop6")
- fullyConnectedLayer(4096,"Name","fc7","BiasLearnRateFactor",2)
- reluLayer("Name","relu7")
- dropoutLayer(0.5,"Name","drop7")
- fullyConnectedLayer(5,"Name","fc_lyq","BiasLearnRateFactor",2)
- softmaxLayer("Name","prob")
- classificationLayer("Name","classoutput_lyq")];
- plot(layerGraph(layers));
复制代码
这就是本帖子要网络的结构示意图,下面的帖子我们将结合图片库样本进行训练,测试识别情况。
|