caffe2 tutorial: CIFAR-10: Part 1


CIFAR stands for the Canadian Institute for Advanced Research, and CIFAR-10 is apparently one of the best-known image datasets in machine learning. Tutorials for caffe2 are very scarce; it is no exaggeration to say the ones on the official site are essentially all there is. With tensorflow, tutorials are scattered all over the net, but for caffe2 there are almost none. Looking at American university sites, keras + tensorflow seems to be the de facto standard, so keras + tensorflow may well be where the times are heading.


Importing the required modules

Note: on closer inspection, this site lists the required installs.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import os
import lmdb
import shutil
from imageio import imread
import caffe2.python.predictor.predictor_exporter as pe
from caffe2.proto import caffe2_pb2
from caffe2.python.predictor import mobile_exporter
from caffe2.python import (
    brew,
    core,
    model_helper,
    net_drawer,
    optimizer,
    visualize,
    workspace,
)

# If you would like to see some really detailed initializations,
# you can change --caffe2_log_level=0 to --caffe2_log_level=-1
core.GlobalInit(['caffe2', '--caffe2_log_level=0'])
print("Necessities imported!")
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-1-16324720062f> in <module>()
      8 import numpy as np
      9 import os
---> 10 import lmdb
     11 import shutil
     12 from imageio import imread

ModuleNotFoundError: No module named 'lmdb'

For conda, the package is python-lmdb, as written here.

!conda install -y python-lmdb
Solving environment: done

## Package Plan ##

  environment location: /root/.pyenv/versions/miniconda3-4.3.30/envs/caffe2

  added / updated specs: 
    - python-lmdb


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    python-lmdb-0.94           |   py36h14c3975_0         125 KB

The following NEW packages will be INSTALLED:

    python-lmdb: 0.94-py36h14c3975_0


Downloading and Extracting Packages
python-lmdb-0.94     |  125 KB | ####################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
Re-running the same import cell after installing lmdb surfaces the next missing module:
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-6-16324720062f> in <module>()
     10 import lmdb
     11 import shutil
---> 12 from imageio import imread
     13 import caffe2.python.predictor.predictor_exporter as pe
     14 from caffe2.proto import caffe2_pb2

ModuleNotFoundError: No module named 'imageio'
!conda install -y imageio
Solving environment: done

## Package Plan ##

  environment location: /root/.pyenv/versions/miniconda3-4.3.30/envs/caffe2

  added / updated specs: 
    - imageio


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    imageio-2.3.0              |           py36_0         3.3 MB

The following NEW packages will be INSTALLED:

    imageio: 2.3.0-py36_0


Downloading and Extracting Packages
imageio-2.3.0        |  3.3 MB | ####################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
With imageio installed, the import cell finally runs, leaving only a net_drawer warning:
net_drawer will not run correctly. Please install the correct dependencies.
Necessities imported!

To clear the "net_drawer will not run correctly." warning, installing pydot is reportedly the fix.

!conda install -y flask \
graphviz \
hypothesis \
matplotlib \
pydot \
pyyaml \
scikit-image \
setuptools \
tornado
Solving environment: done

## Package Plan ##

  environment location: /root/.pyenv/versions/miniconda3-4.3.30/envs/caffe2

  added / updated specs: 
    - flask
    - graphviz
    - hypothesis
    - matplotlib
    - pydot
    - pyyaml
    - scikit-image
    - setuptools
    - tornado


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    locket-0.2.0               |   py36h787c0ad_1           8 KB
    distributed-1.21.8         |           py36_0         777 KB
    tblib-1.3.2                |   py36h34cf8b6_0          16 KB
    yaml-0.1.7                 |       had09818_2          85 KB
    scikit-image-0.13.1        |   py36h14c3975_1        23.0 MB
    libtool-2.4.6              |       h544aabb_3         517 KB
    partd-0.3.8                |   py36h36fd896_0          31 KB
    bokeh-0.12.16              |           py36_0         4.1 MB
    pyyaml-3.12                |   py36hafb9ca4_1         159 KB
    cloudpickle-0.5.3          |           py36_0          26 KB
    networkx-2.1               |           py36_0         1.8 MB
    packaging-17.1             |           py36_0          33 KB
    coverage-4.5.1             |   py36h14c3975_0         211 KB
    attrs-18.1.0               |           py36_0          43 KB
    psutil-5.4.5               |   py36h14c3975_0         303 KB
    zict-0.1.3                 |   py36h3a3bf81_0          18 KB
    hypothesis-3.57.0          |   py36h24bf2e0_0         260 KB
    flask-1.0.2                |           py36_1         119 KB
    pango-1.41.0               |       hd475d92_0         530 KB
    graphviz-2.40.1            |       h25d223c_0         6.9 MB
    dask-core-0.17.5           |           py36_0         1.0 MB
    heapdict-1.0.0             |           py36_2           7 KB
    click-6.7                  |   py36h5253387_0         104 KB
    pydot-1.2.4                |           py36_0          37 KB
    dask-0.17.5                |           py36_0           3 KB
    toolz-0.9.0                |           py36_0          91 KB
    cytoolz-0.9.0.1            |   py36h14c3975_0         419 KB
    msgpack-python-0.5.6       |   py36h6bb024c_0          96 KB
    itsdangerous-0.24          |   py36h93cc618_1          20 KB
    pywavelets-0.5.2           |   py36he602eb0_0         4.0 MB
    sortedcontainers-2.0.3     |           py36_0          42 KB
    ------------------------------------------------------------
                                           Total:        44.8 MB

The following NEW packages will be INSTALLED:

    attrs:            18.1.0-py36_0         
    bokeh:            0.12.16-py36_0        
    click:            6.7-py36h5253387_0    
    cloudpickle:      0.5.3-py36_0          
    coverage:         4.5.1-py36h14c3975_0  
    cytoolz:          0.9.0.1-py36h14c3975_0
    dask:             0.17.5-py36_0         
    dask-core:        0.17.5-py36_0         
    distributed:      1.21.8-py36_0         
    flask:            1.0.2-py36_1          
    graphviz:         2.40.1-h25d223c_0     
    heapdict:         1.0.0-py36_2          
    hypothesis:       3.57.0-py36h24bf2e0_0 
    itsdangerous:     0.24-py36h93cc618_1   
    libtool:          2.4.6-h544aabb_3      
    locket:           0.2.0-py36h787c0ad_1  
    msgpack-python:   0.5.6-py36h6bb024c_0  
    networkx:         2.1-py36_0            
    packaging:        17.1-py36_0           
    pango:            1.41.0-hd475d92_0     
    partd:            0.3.8-py36h36fd896_0  
    psutil:           5.4.5-py36h14c3975_0  
    pydot:            1.2.4-py36_0          
    pywavelets:       0.5.2-py36he602eb0_0  
    pyyaml:           3.12-py36hafb9ca4_1   
    scikit-image:     0.13.1-py36h14c3975_1 
    sortedcontainers: 2.0.3-py36_0          
    tblib:            1.3.2-py36h34cf8b6_0  
    toolz:            0.9.0-py36_0          
    yaml:             0.1.7-had09818_2      
    zict:             0.1.3-py36h3a3bf81_0  


Downloading and Extracting Packages
locket-0.2.0         |    8 KB | ####################################### | 100% 
distributed-1.21.8   |  777 KB | ####################################### | 100% 
tblib-1.3.2          |   16 KB | ####################################### | 100% 
yaml-0.1.7           |   85 KB | ####################################### | 100% 
scikit-image-0.13.1  | 23.0 MB | ####################################### | 100% 
libtool-2.4.6        |  517 KB | ####################################### | 100% 
partd-0.3.8          |   31 KB | ####################################### | 100% 
bokeh-0.12.16        |  4.1 MB | ####################################### | 100% 
pyyaml-3.12          |  159 KB | ####################################### | 100% 
cloudpickle-0.5.3    |   26 KB | ####################################### | 100% 
networkx-2.1         |  1.8 MB | ####################################### | 100% 
packaging-17.1       |   33 KB | ####################################### | 100% 
coverage-4.5.1       |  211 KB | ####################################### | 100% 
attrs-18.1.0         |   43 KB | ####################################### | 100% 
psutil-5.4.5         |  303 KB | ####################################### | 100% 
zict-0.1.3           |   18 KB | ####################################### | 100% 
hypothesis-3.57.0    |  260 KB | ####################################### | 100% 
flask-1.0.2          |  119 KB | ####################################### | 100% 
pango-1.41.0         |  530 KB | ####################################### | 100% 
graphviz-2.40.1      |  6.9 MB | ####################################### | 100% 
dask-core-0.17.5     |  1.0 MB | ####################################### | 100% 
heapdict-1.0.0       |    7 KB | ####################################### | 100% 
click-6.7            |  104 KB | ####################################### | 100% 
pydot-1.2.4          |   37 KB | ####################################### | 100% 
dask-0.17.5          |    3 KB | ####################################### | 100% 
toolz-0.9.0          |   91 KB | ####################################### | 100% 
cytoolz-0.9.0.1      |  419 KB | ####################################### | 100% 
msgpack-python-0.5.6 |   96 KB | ####################################### | 100% 
itsdangerous-0.24    |   20 KB | ####################################### | 100% 
pywavelets-0.5.2     |  4.0 MB | ####################################### | 100% 
sortedcontainers-2.0 |   42 KB | ####################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
!conda install -y -c conda-forge python-nvd3
Solving environment: done

## Package Plan ##

  environment location: /root/.pyenv/versions/miniconda3-4.3.30/envs/caffe2

  added / updated specs: 
    - python-nvd3


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    unidecode-1.0.22           |           py36_0         228 KB  conda-forge
    certifi-2018.4.16          |           py36_0         142 KB  conda-forge
    python-slugify-1.2.5       |           py36_0           9 KB  conda-forge
    python-nvd3-0.15.0         |           py36_0          33 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         411 KB

The following NEW packages will be INSTALLED:

    python-nvd3:     0.15.0-py36_0     conda-forge
    python-slugify:  1.2.5-py36_0      conda-forge
    unidecode:       1.0.22-py36_0     conda-forge

The following packages will be UPDATED:

    ca-certificates: 2018.03.07-0                  --> 2018.4.16-0      conda-forge
    certifi:         2018.4.16-py36_0              --> 2018.4.16-py36_0 conda-forge
    openssl:         1.0.2o-h20670df_0             --> 1.0.2o-0         conda-forge


Downloading and Extracting Packages
unidecode-1.0.22     |  228 KB | ####################################### | 100% 
certifi-2018.4.16    |  142 KB | ####################################### | 100% 
python-slugify-1.2.5 |    9 KB | ####################################### | 100% 
python-nvd3-0.15.0   |   33 KB | ####################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done

Downloading the CIFAR-10 .png data

The data to download is close to 1 GB, so the download and extraction take a while.

import requests
import tarfile

# Set paths and variables
# data_folder is where the data is downloaded and unpacked
data_folder = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10')
# root_folder is where checkpoint files and .pb model definition files will be outputted
root_folder = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_files', 'tutorial_cifar10')

url = "http://pjreddie.com/media/files/cifar.tgz"   # url to data
filename = url.split("/")[-1]                       # download file name
download_path = os.path.join(data_folder, filename) # path to extract data to

# Create data_folder if not already there
if not os.path.isdir(data_folder):
    os.makedirs(data_folder)

# If data does not already exist, download and extract
# (slice the '.tgz' suffix off; str.strip('.tgz') would treat it as a character set)
if not os.path.exists(download_path[:-len('.tgz')]):
    # Download data
    r = requests.get(url, stream=True)
    print("Downloading... {} to {}".format(url, download_path))
    with open(download_path, 'wb') as f:
        f.write(r.content)
    print("Finished downloading...")

    # Unpack images from tgz file
    print('Extracting images from tarball...')
    with tarfile.open(download_path, 'r') as tar:
        tar.extractall(data_folder)
    print("Completed download and extraction!")
    
else:
    print("Image directory already exists. Moving on...")
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-11b5893ecfde> in <module>()
----> 1 import requests
      2 import tarfile
      3 
      4 # Set paths and variables
      5 # data_folder is where the data is downloaded and unpacked

ModuleNotFoundError: No module named 'requests'
!conda install -y requests
Solving environment: done

## Package Plan ##

  environment location: /root/.pyenv/versions/miniconda3-4.3.30/envs/caffe2

  added / updated specs: 
    - requests


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    cryptography-2.2.2         |   py36h14c3975_0         599 KB
    pysocks-1.6.8              |           py36_0          22 KB
    blas-1.0                   |              mkl           6 KB
    pyopenssl-18.0.0           |           py36_0          82 KB
    asn1crypto-0.24.0          |           py36_0         155 KB
    numpy-base-1.14.3          |   py36hdbf6ddf_2         4.1 MB
    ------------------------------------------------------------
                                           Total:         4.9 MB

The following NEW packages will be INSTALLED:

    asn1crypto:   0.24.0-py36_0        
    chardet:      3.0.4-py36h0f667ec_1 
    cryptography: 2.2.2-py36h14c3975_0 
    idna:         2.6-py36h82fb2a8_1   
    pyopenssl:    18.0.0-py36_0        
    pysocks:      1.6.8-py36_0         
    requests:     2.18.4-py36he2e5f8d_1
    urllib3:      1.22-py36hbe7ace6_0  

The following packages will be UPDATED:

    numpy-base:   1.14.3-py36h2b20989_2 --> 1.14.3-py36hdbf6ddf_2

The following packages will be DOWNGRADED:

    blas:         1.0-openblas          --> 1.0-mkl              


Downloading and Extracting Packages
cryptography-2.2.2   |  599 KB | ####################################### | 100% 
pysocks-1.6.8        |   22 KB | ####################################### | 100% 
blas-1.0             |    6 KB | ####################################### | 100% 
pyopenssl-18.0.0     |   82 KB | ####################################### | 100% 
asn1crypto-0.24.0    |  155 KB | ####################################### | 100% 
numpy-base-1.14.3    |  4.1 MB | ####################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
With requests installed, re-running the download cell now works:
Downloading... http://pjreddie.com/media/files/cifar.tgz to /root/caffe2_notebooks/tutorial_data/cifar10/cifar.tgz
Finished downloading...
Extracting images from tarball...
Completed download and extraction!

For future reference, let's take a quick look at the downloaded data.

import glob

# Grab 5 image paths from training set to display
sample_imgs = glob.glob(os.path.join(data_folder, "cifar", "train") + '/*.png')[:5]

# Plot images
f, ax = plt.subplots(1, 5, figsize=(10,10))
plt.tight_layout()
for i in range(5):
    ax[i].set_title(sample_imgs[i].split("_")[-1].split(".")[0])
    ax[i].axis('off')
    ax[i].imshow(imread(sample_imgs[i]).astype(np.uint8)) 

Creating the label files and writing the LMDBs

We need separate LMDBs for training, validation, and testing, and label files to sort the images used in each of the three stages. CIFAR-10 is organized into 10 labels (automobile, bird, airplane, and so on), with the images split into directories of 50,000 for training and 10,000 for testing. Before writing the LMDBs, we first create the labels: set up the paths, then build a classes dictionary that maps string labels to integer labels so they can be used in the LMDBs.

# Paths to train and test directories
training_dir_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'cifar', 'train')
testing_dir_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'cifar', 'test')

# Paths to label files
training_labels_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'training_dictionary.txt')
validation_labels_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'validation_dictionary.txt')
testing_labels_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'testing_dictionary.txt')

# Paths to LMDBs
training_lmdb_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'training_lmdb')
validation_lmdb_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'validation_lmdb')
testing_lmdb_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'testing_lmdb')

# Path to labels.txt
labels_path = os.path.join(os.path.expanduser('~'), 'caffe2_notebooks', 'tutorial_data', 'cifar10', 'cifar', 'labels.txt')

# Open label file handler
labels_handler = open(labels_path, "r")

# Create classes dictionary to map string labels to integer labels
classes = {}
i = 0
lines = labels_handler.readlines()
for line in sorted(lines):
    line = line.rstrip()
    classes[line] = i
    i += 1
labels_handler.close()

print("classes:", classes)
classes: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
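
The .png filenames in this dataset embed the class name after the last underscore, which is what lets the label-writing code below recover each image's integer label from its path. A minimal sketch of that lookup (the filename here is an illustrative example of the <index>_<class>.png pattern, not a guaranteed file):

fname = "10008_automobile.png"                  # hypothetical example filename
label_str = fname.split('_')[-1].split('.')[0]  # -> "automobile"
print(label_str, classes[label_str])            # -> automobile 1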

Now that the string→integer class dictionary exists, we can create the training, validation, and testing label files. The images used in the three stages are split as follows:

  • training: 44,000 images (73%)
  • validation: 6,000 images (10%)
  • testing: 10,000 images (17%)

We carve 6,000 images out of the training set for validation, so that training accuracy can be checked before testing. The data is shuffled first so that the images assigned to training and validation are evenly distributed.
from random import shuffle

# Open file handlers
training_labels_handler = open(training_labels_path, "w")
validation_labels_handler = open(validation_labels_path, "w")
testing_labels_handler = open(testing_labels_path, "w")

# Create training, validation, and testing label files
i = 0
validation_count = 6000
imgs = glob.glob(training_dir_path + '/*.png')  # read all training images into array
shuffle(imgs)  # shuffle array
for img in imgs:
    # Write first 6,000 image paths, followed by their integer label, to the validation label files
    if i < validation_count:
        validation_labels_handler.write(img + ' ' + str(classes[img.split('_')[-1].split('.')[0]]) + '\n')
    # Write the remaining to the training label files
    else:
        training_labels_handler.write(img + ' ' + str(classes[img.split('_')[-1].split('.')[0]]) + '\n')
    i += 1
print("Finished writing training and validation label files")

# Write our testing label files using the testing images
for img in glob.glob(testing_dir_path + '/*.png'):
    testing_labels_handler.write(img + ' ' + str(classes[img.split('_')[-1].split('.')[0]]) + '\n')
print("Finished writing testing label files")

# Close file handlers
training_labels_handler.close()
validation_labels_handler.close()
testing_labels_handler.close()
Finished writing training and validation label files
Finished writing testing label files
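
To confirm the split sizes, a quick line count of the three label files (a minimal sketch using the paths defined above):

for p in (training_labels_path, validation_labels_path, testing_labels_path):
    with open(p) as f:
        print(os.path.basename(p), sum(1 for _ in f))
# expected counts: 44000, 6000, and 10000 respectively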

Feeding the image data into the LMDBs

Next, we use the label files just created to build the LMDBs. Before feeding the images in, they have to be converted from RGB to BGR and from HWC to CHW, as was done here. With Caffe2 there is additionally N, the number of images in a batch, so the data ultimately has to end up in NCHW order with BGR channels.
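
As a quick, self-contained illustration of the two per-image conversions (N is left out here; it is added later, at batch time):

import numpy as np

dummy = np.zeros((32, 32, 3), dtype=np.float32)  # HWC, channels ordered R, G, B
dummy_bgr = dummy[:, :, (2, 1, 0)]               # reverse the channel axis: RGB -> BGR
dummy_chw = np.transpose(dummy_bgr, (2, 0, 1))   # HWC -> CHW
print(dummy_chw.shape)                           # (3, 32, 32)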

def write_lmdb(labels_file_path, lmdb_path):
    labels_handler = open(labels_file_path, "r")
    # Write to lmdb
    print(">>> Write database...")
    LMDB_MAP_SIZE = 1 << 40   # maximum size of the LMDB memory map: 1 TiB
    print("LMDB_MAP_SIZE", LMDB_MAP_SIZE)
    env = lmdb.open(lmdb_path, map_size=LMDB_MAP_SIZE)

    with env.begin(write=True) as txn:
        count = 0
        for line in labels_handler.readlines():
            line = line.rstrip()
            im_path = line.split()[0]
            im_label = int(line.split()[1])
            
            # read in image (as RGB)
            img_data = imread(im_path).astype(np.float32)
            
            # convert to BGR
            img_data = img_data[:, :, (2, 1, 0)]
            
            # HWC -> CHW (N gets added in AddInput function)
            img_data = np.transpose(img_data, (2,0,1))
            
            # Create TensorProtos
            tensor_protos = caffe2_pb2.TensorProtos()
            img_tensor = tensor_protos.protos.add()
            img_tensor.dims.extend(img_data.shape)
            img_tensor.data_type = 1   # 1 = FLOAT in caffe2_pb2.TensorProto.DataType
            flatten_img = img_data.reshape(np.prod(img_data.shape))
            img_tensor.float_data.extend(flatten_img)
            label_tensor = tensor_protos.protos.add()
            label_tensor.data_type = 2  # 2 = INT32
            label_tensor.int32_data.append(im_label)
            txn.put(
                '{}'.format(count).encode('ascii'),
                tensor_protos.SerializeToString()
            )
            if count % 1000 == 0:
                print("Inserted {} rows".format(count))
            count = count + 1

    print("Inserted {} rows".format(count))
    print("\nLMDB saved at " + lmdb_path + "\n\n")
    labels_handler.close()
    
# Call function to write our LMDBs
if not os.path.exists(training_lmdb_path):
    print("Writing training LMDB")
    write_lmdb(training_labels_path, training_lmdb_path)
else:
    print(training_lmdb_path, "already exists!")
if not os.path.exists(validation_lmdb_path):
    print("Writing validation LMDB")
    write_lmdb(validation_labels_path, validation_lmdb_path)
else:
    print(validation_lmdb_path, "already exists!")
if not os.path.exists(testing_lmdb_path):
    print("Writing testing LMDB")
    write_lmdb(testing_labels_path, testing_lmdb_path)
else:
    print(testing_lmdb_path, "already exists!")
Writing training LMDB
>>> Write database...
LMDB_MAP_SIZE 1099511627776
Inserted 0 rows
Inserted 1000 rows
Inserted 2000 rows
Inserted 3000 rows
Inserted 4000 rows
Inserted 5000 rows
Inserted 6000 rows
Inserted 7000 rows
Inserted 8000 rows
Inserted 9000 rows
Inserted 10000 rows
Inserted 11000 rows
Inserted 12000 rows
Inserted 13000 rows
Inserted 14000 rows
Inserted 15000 rows
Inserted 16000 rows
Inserted 17000 rows
Inserted 18000 rows
Inserted 19000 rows
Inserted 20000 rows
Inserted 21000 rows
Inserted 22000 rows
Inserted 23000 rows
Inserted 24000 rows
Inserted 25000 rows
Inserted 26000 rows
Inserted 27000 rows
Inserted 28000 rows
Inserted 29000 rows
Inserted 30000 rows
Inserted 31000 rows
Inserted 32000 rows
Inserted 33000 rows
Inserted 34000 rows
Inserted 35000 rows
Inserted 36000 rows
Inserted 37000 rows
Inserted 38000 rows
Inserted 39000 rows
Inserted 40000 rows
Inserted 41000 rows
Inserted 42000 rows
Inserted 43000 rows
Inserted 44000 rows

LMDB saved at /root/caffe2_notebooks/tutorial_data/cifar10/training_lmdb


Writing validation LMDB
>>> Write database...
LMDB_MAP_SIZE 1099511627776
Inserted 0 rows
Inserted 1000 rows
Inserted 2000 rows
Inserted 3000 rows
Inserted 4000 rows
Inserted 5000 rows
Inserted 6000 rows

LMDB saved at /root/caffe2_notebooks/tutorial_data/cifar10/validation_lmdb


Writing testing LMDB
>>> Write database...
LMDB_MAP_SIZE 1099511627776
Inserted 0 rows
Inserted 1000 rows
Inserted 2000 rows
Inserted 3000 rows
Inserted 4000 rows
Inserted 5000 rows
Inserted 6000 rows
Inserted 7000 rows
Inserted 8000 rows
Inserted 9000 rows
Inserted 10000 rows

LMDB saved at /root/caffe2_notebooks/tutorial_data/cifar10/testing_lmdb
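
As a sanity check, here is one way to read a record back out of an LMDB and confirm its shape; this is a minimal sketch assuming the same lmdb and caffe2_pb2 imports as above:

env = lmdb.open(testing_lmdb_path, readonly=True)
with env.begin() as txn:
    serialized = txn.get('0'.encode('ascii'))  # key written by write_lmdb for the first image
protos = caffe2_pb2.TensorProtos()
protos.ParseFromString(serialized)
print(protos.protos[0].dims)        # expected: [3, 32, 32] (CHW)
print(protos.protos[1].int32_data)  # the image's integer label
env.close()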


Reference: https://github.com/
