搜索
您的当前位置:首页正文

caffe for windows的python接口学习(1):分析caffenet.py

来源:赴品旅游

首先我尝试在spyder中写caffe的配置文件,结果出现很多”看不懂“的错误。提示说在examples/pycaffe/caffenet.py中提供了一种直接在Python中编写网络的方法。

所以我决定先尝试理解一下这个实例文件。

首先是import:

from __future__ import print_function
from caffe import layers as L, params as P, to_proto
from caffe.proto import caffe_pb2

然后是一些辅助函数:其中L和P有其自带的一些函数调用。

def conv_relu(bottom, ks, nout, stride=1, pad=0, group=1):
    conv = L.Convolution(bottom, kernel_size=ks, stride=stride, num_output=nout, pad=pad, group=group)
    return conv, L.ReLU(conv, in_place=True)

def fc_relu(bottom, nout):
    fc = L.InnerProduct(bottom, num_output=nout)
    return fc, L.ReLU(fc, in_place=True)


def max_pool(bottom, ks, stride=1):
    return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)


def caffenet(lmdb, batch_size=256, include_acc=False):
    data, label = L.Data(source=lmdb, backend=P.Data.LMDB, batch_size=batch_size, ntop=2, transform_param=dict(crop_size=227, mean_value=[104, 117, 123], mirror=True))


然后是定义这个网络本身:

    conv1, relu1 = conv_relu(data, 11, 96, stride=4)
    pool1 = max_pool(relu1, 3, stride=2)
    norm1 = L.LRN(pool1, local_size=5, alpha=1e-4, beta=0.75)
    conv2, relu2 = conv_relu(norm1, 5, 256, pad=2, group=2)
    pool2 = max_pool(relu2, 3, stride=2)
    norm2 = L.LRN(pool2, local_size=5, alpha=1e-4, beta=0.75)
    conv3, relu3 = conv_relu(norm2, 3, 384, pad=1)
    conv4, relu4 = conv_relu(relu3, 3, 384, pad=1, group=2)
    conv5, relu5 = conv_relu(relu4, 3, 256, pad=1, group=2)
    pool5 = max_pool(relu5, 3, stride=2)
    fc6, relu6 = fc_relu(pool5, 4096)
    drop6 = L.Dropout(relu6, in_place=True)
    fc7, relu7 = fc_relu(drop6, 4096)
    drop7 = L.Dropout(relu7, in_place=True)
    fc8 = L.InnerProduct(drop7, num_output=1000)
    loss = L.SoftmaxWithLoss(fc8, label)

    if include_acc:
        acc = L.Accuracy(fc8, label)
        return to_proto(loss, acc)
    else:
        return to_proto(loss)

定义一个生成网络的函数:

def make_net():
    with open('train.prototxt', 'w') as f:
        print(caffenet('/path/to/caffe-train-lmdb'), file=f)


    with open('test.prototxt', 'w') as f:
        print(caffenet('/path/to/caffe-val-lmdb', batch_size=50, include_acc=True), file=f)


if __name__ == '__main__':
    make_net()

在path路径中是创建了一个自己的数据文件,放入前面转换过的两份caffe-train-lmdb和caffe-val-lmdb。运行过后开始找不到文件,提示一些警告。

比较绝望的时候,在保存.py文件的位置找到了生成的两个配置文件。


因篇幅问题不能全部显示,请点此查看更多更全内容

Top