caffe2 tutorial: CIFAR-10: Part 1


CIFARとはCanadian Institute for Advanced Researchの略で、機械学習の世界では有名な画像データセットらしい。caffe2のチュートリアルは非常に少なく、基本的には本家サイトのチュートリアルしかないと言っても決して過言ではない。tensorflowとかだとネット上にtutorialがゴロゴロ転がっているのに、caffe2のものはほとんどない。アメリカの大学のサイトを見てもkeras + tensorflowがde facto standardになっている感じなので、時代の潮流はkeras + tensorflowなのかもしれない。




from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import os
import lmdb
import shutil
from imageio import imread
import caffe2.python.predictor.predictor_exporter as pe
from caffe2.proto import caffe2_pb2
from caffe2.python.predictor import mobile_exporter
# Caffe2 modules used by this tutorial series. Importing net_drawer prints
# "net_drawer will not run correctly..." when its optional pydot dependency
# is missing; the import itself still succeeds.
from caffe2.python import (
    brew,
    core,
    model_helper,
    net_drawer,
    optimizer,
    visualize,
    workspace,
)

# If you would like to see some really detailed initializations,
# you can change --caffe2_log_level=0 to --caffe2_log_level=-1
core.GlobalInit(['caffe2', '--caffe2_log_level=0'])
print("Necessities imported!")
net_drawer will not run correctly. Please install the correct dependencies.
Necessities imported!

net_drawer will not run correctly. ← これはエラーではなく警告メッセージで(インポート自体は成功し "Necessities imported!" が表示されている)、pydotをインストールすれば解消できるらしい。

CIFAR-10 .pngデータのダウンロード


import requests
import tarfile

# Set paths and variables
# data_folder is where the data is downloaded and unpacked
data_folder = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10')
# root_folder is where checkpoint files and .pb model definition files will be outputted
root_folder = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_files', 'tutorial_cifar10')

# NOTE(review): the URL was lost from this copy of the notebook; this is the
# cifar.tgz source used by the upstream Caffe2 CIFAR-10 tutorial -- confirm it
# is still reachable before relying on it.
url = "http://pjreddie.com/media/files/cifar.tgz"   # url to data
filename = url.split("/")[-1]                       # download file name
download_path = os.path.join(data_folder, filename) # path the archive is saved to
# Directory the archive unpacks into (data_folder/cifar). Note that
# str.strip('.tgz') would remove any of the characters '.', 't', 'g', 'z'
# from both ends instead of the extension, so splitext is used.
extracted_path = os.path.splitext(download_path)[0]

# Create data_folder if not already there
if not os.path.isdir(data_folder):
    os.makedirs(data_folder)

# If data does not already exist, download and extract
if not os.path.exists(extracted_path):
    # Download data, streaming it to disk in chunks rather than holding the
    # whole archive in memory.
    print("Downloading... {} to {}".format(url, download_path))
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # don't save an HTTP error page as a .tgz
        with open(download_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    print("Finished downloading...")

    # Unpack images from tgz file.
    # NOTE: extractall trusts the member paths stored in the archive; only
    # use it on archives from a trusted source.
    print('Extracting images from tarball...')
    with tarfile.open(download_path, 'r') as tar:
        tar.extractall(data_folder)
    print("Completed download and extraction!")
else:
    print("Image directory already exists. Moving on...")
# This cell repeats the download step. Because it first checks for the
# extracted directory, running it after a successful download is a no-op
# apart from the "already exists" message.
import requests
import tarfile

# Set paths and variables
# data_folder is where the data is downloaded and unpacked
data_folder = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10')
# root_folder is where checkpoint files and .pb model definition files will be outputted
root_folder = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_files', 'tutorial_cifar10')

# NOTE(review): the URL was lost from this copy of the notebook; this is the
# cifar.tgz source used by the upstream Caffe2 CIFAR-10 tutorial -- confirm it
# is still reachable before relying on it.
url = "http://pjreddie.com/media/files/cifar.tgz"   # url to data
filename = url.split("/")[-1]                       # download file name
download_path = os.path.join(data_folder, filename) # path the archive is saved to
# Directory the archive unpacks into (data_folder/cifar). str.strip('.tgz')
# would strip a character set, not the extension, so splitext is used.
extracted_path = os.path.splitext(download_path)[0]

# Create data_folder if not already there
if not os.path.isdir(data_folder):
    os.makedirs(data_folder)

# If data does not already exist, download and extract
if not os.path.exists(extracted_path):
    # Download data, streaming it to disk in chunks.
    print("Downloading... {} to {}".format(url, download_path))
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # don't save an HTTP error page as a .tgz
        with open(download_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    print("Finished downloading...")

    # Unpack images from tgz file (archive must come from a trusted source:
    # extractall honours the member paths stored inside it).
    print('Extracting images from tarball...')
    with tarfile.open(download_path, 'r') as tar:
        tar.extractall(data_folder)
    print("Completed download and extraction!")
else:
    print("Image directory already exists. Moving on...")
Downloading... to /root/caffe2_notebooks/tutorial_data/cifar10/cifar.tgz
Finished downloading...
Extracting images from tarball...
Completed download and extraction!


import glob

# Grab 5 image paths from training set to display
sample_imgs = glob.glob(os.path.join(data_folder, "cifar", "train") + '/*.png')[:5]

# Plot images, each titled with the label encoded at the end of its file
# name ("..._<label>.png"), the same convention the label files rely on.
f, ax = plt.subplots(1, 5, figsize=(10,10))
plt.tight_layout()
for axis in ax:
    axis.axis('off')  # also leaves unused panels blank if < 5 images were found
for axis, img_path in zip(ax, sample_imgs):
    axis.set_title(os.path.basename(img_path).split('_')[-1].split('.')[0])
    axis.imshow(imread(img_path).astype(np.uint8))


training(トレーニング), validation(バリデーション), testing(テスティング)用のLMDBが必要。ラベルファイルは、この3つの工程で使う画像データを振り分けるために必要になる。CIFAR-10は10ラベル(車や鳥や航空機等)に分類されている。画像はトレーニング用に5万枚、テスト用に1万枚が別々のディレクトリ(train/test)に分けられている。まず、LMDBを書く前にラベルファイルを作成する。最初にパスを設定し、LMDBで使えるようにstring labelをinteger labelに対応付けるためのclassesディクショナリを作成する。

# Every CIFAR-10 artifact lives under ~/caffe2_notebooks/tutorial_data/cifar10
cifar10_dir = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10')

# Paths to train and test directories
training_dir_path = os.path.join(cifar10_dir, 'cifar', 'train')
testing_dir_path = os.path.join(cifar10_dir, 'cifar', 'test')

# Paths to label files
training_labels_path = os.path.join(cifar10_dir, 'training_dictionary.txt')
validation_labels_path = os.path.join(cifar10_dir, 'validation_dictionary.txt')
testing_labels_path = os.path.join(cifar10_dir, 'testing_dictionary.txt')

# Paths to LMDBs
training_lmdb_path = os.path.join(cifar10_dir, 'training_lmdb')
validation_lmdb_path = os.path.join(cifar10_dir, 'validation_lmdb')
testing_lmdb_path = os.path.join(cifar10_dir, 'testing_lmdb')

# Path to labels.txt (one class name per line)
labels_path = os.path.join(cifar10_dir, 'cifar', 'labels.txt')

# Create classes dictionary to map string labels to integer labels.
# Names are sorted so the mapping is deterministic; blank lines are skipped so
# a stray empty line cannot take index 0 and shift every real class.
with open(labels_path, "r") as labels_handler:
    class_names = sorted(line.rstrip() for line in labels_handler if line.strip())
classes = {name: idx for idx, name in enumerate(class_names)}

print("classes:", classes)
classes: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


  • training: 44,000 images (73%)
  • validation: 6,000 images (10%)
  • testing: 10,000 images (17%)
from random import shuffle


def _label_line(img_path):
    """Return the label-file line "<img_path> <integer label>\\n" for one image.

    The class name is taken from the file name's last "_"-separated part
    (minus the extension) and mapped through the ``classes`` dictionary.
    """
    label_name = os.path.basename(img_path).split('_')[-1].split('.')[0]
    return img_path + ' ' + str(classes[label_name]) + '\n'


# Split the training images: the first `validation_count` (after shuffling)
# become the validation set, all remaining ones the training set.
validation_count = 6000
imgs = glob.glob(training_dir_path + '/*.png')  # read all training images into array
shuffle(imgs)  # shuffle array

# `with` guarantees the files are flushed and closed before the LMDB step
# reads them; unclosed handles can leave the label files empty on disk.
with open(training_labels_path, "w") as training_labels_handler, \
        open(validation_labels_path, "w") as validation_labels_handler:
    for i, img in enumerate(imgs):
        if i < validation_count:
            validation_labels_handler.write(_label_line(img))
        else:
            training_labels_handler.write(_label_line(img))
print("Finished writing training and validation label files")

# Write our testing label files using the testing images
with open(testing_labels_path, "w") as testing_labels_handler:
    for img in glob.glob(testing_dir_path + '/*.png'):
        testing_labels_handler.write(_label_line(img))
print("Finished writing testing label files")
Finished writing training and validation label files
Finished writing testing label files



def write_lmdb(labels_file_path, lmdb_path):
    """Serialize every image listed in a label file into a new LMDB.

    Each non-blank line of ``labels_file_path`` is "<image path> <integer label>".
    For every line the image is read, converted RGB -> BGR and HWC -> CHW,
    and stored with its label as a ``caffe2_pb2.TensorProtos`` holding a
    float image tensor followed by an int32 label tensor. Records are keyed
    by their 0-based insertion count as an ASCII string.

    Args:
        labels_file_path: text file listing images and their integer labels.
        lmdb_path: directory in which the LMDB environment is created.
    """
    # Write to lmdb
    print(">>> Write database...")
    LMDB_MAP_SIZE = 1 << 40  # maximum size the database may grow to (1 TiB)
    print("LMDB_MAP_SIZE", LMDB_MAP_SIZE)
    env = lmdb.open(lmdb_path, map_size=LMDB_MAP_SIZE)

    count = 0
    try:
        # The write transaction commits when the block exits normally.
        with open(labels_file_path, "r") as labels_handler, \
                env.begin(write=True) as txn:
            for line in labels_handler:
                if not line.strip():
                    continue  # tolerate blank lines
                # rsplit keeps image paths that contain spaces intact
                im_path, im_label = line.rsplit(None, 1)
                im_label = int(im_label)

                # read in image (as RGB); assumes 3-channel images
                img_data = imread(im_path).astype(np.float32)
                # convert to BGR
                img_data = img_data[:, :, (2, 1, 0)]
                # HWC -> CHW (N gets added in AddInput function)
                img_data = np.transpose(img_data, (2, 0, 1))

                # Create TensorProtos: [image tensor, label tensor]
                tensor_protos = caffe2_pb2.TensorProtos()
                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = caffe2_pb2.TensorProto.FLOAT
                img_tensor.float_data.extend(img_data.reshape(-1))
                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = caffe2_pb2.TensorProto.INT32
                label_tensor.int32_data.append(im_label)

                txn.put(
                    '{}'.format(count).encode('ascii'),
                    tensor_protos.SerializeToString()
                )
                if count % 1000 == 0:
                    print("Inserted {} rows".format(count))
                count += 1
    finally:
        env.close()

    print("Inserted {} rows".format(count))
    print("\nLMDB saved at " + lmdb_path + "\n\n")
# Call function to write our LMDBs.
# An existing LMDB directory is treated as finished and skipped; delete it to
# force a rebuild (e.g. after a run that reported "Inserted 0 rows").
for split_name, split_labels_path, split_lmdb_path in (
        ("training", training_labels_path, training_lmdb_path),
        ("validation", validation_labels_path, validation_lmdb_path),
        ("testing", testing_labels_path, testing_lmdb_path),
):
    if not os.path.exists(split_lmdb_path):
        print("Writing {} LMDB".format(split_name))
        write_lmdb(split_labels_path, split_lmdb_path)
    else:
        print(split_lmdb_path, "already exists!")
Writing training LMDB
>>> Write database...
LMDB_MAP_SIZE 1099511627776
Inserted 0 rows
LMDB saved at /root/caffe2_notebooks/tutorial_data/cifar10/training_lmdb

Writing validation LMDB
>>> Write database...
LMDB_MAP_SIZE 1099511627776
Inserted 0 rows
LMDB saved at /root/caffe2_notebooks/tutorial_data/cifar10/validation_lmdb

Writing testing LMDB
>>> Write database...
LMDB_MAP_SIZE 1099511627776
Inserted 0 rows
LMDB saved at /root/caffe2_notebooks/tutorial_data/cifar10/testing_lmdb

