手書き数字のデータを扱う!Pythonでmnistを使う方法【初心者向け】

初心者向けにPythonでmnistを使う方法について解説しています。これは機械学習の入門として使われるデータセットのひとつで、手書き数字の画像データを集めたものです。導入の方法と基本の使い方についてサンプルプログラムを見ながら学びましょう。

TechAcademyマガジンはオンラインのプログラミングスクールTechAcademy [テックアカデミー]が運営する教育×テクノロジーのWebメディアです。初心者でもすぐ勉強できる記事が2,000以上あります。

Pythonでmnistを使う方法について解説します。

Pythonについてそもそもよく分からないという方は、Pythonとは何なのか解説した記事をまずご覧ください。

 

なお本記事は、TechAcademyのPythonオンライン講座の内容をもとにしています。

 

田島悠介

今回は、Pythonに関する内容だね!

大石ゆかり

どういう内容でしょうか?

田島悠介

mnistの使い方について詳しく説明していくね!

大石ゆかり

お願いします!

 

mnistとは

mnistとは、手書き数字の画像のデータのセットです。機械学習やディープラーニングを学ぶ際のデータセットとして良く用いられます。画像は全部で7万枚あり、トレーニング用データ6万枚とテスト用データ1万枚で構成されています。

データは、画像データとラベルで構成されています。ラベルとは画像データが表す数字です。

1つ1つの画像はグレースケールで、大きさが縦28ピクセル・横28ピクセルです。各ピクセルには0〜255の値が格納されています。

ちなみにmnistとは Mixed National Institute of Standards and Technology database の略です。

 

mnistの使い方

mnistを使うには、以下の方法があります。

THE MNIST DATABASE of handwritten digits からダウンロードする

こちらが本家です。Yann LeCun さんのサイトからダウンロードできます。

http://yann.lecun.com/exdb/mnist/

 

scikit-learn を使い mldata.org からダウンロードする

mldata.orgは機械学習用データを集めたサイトです。以下のように記述することで、 mnist をダウンロードできます。初回ダウンロードには時間がかかりますが、次回以降はダウンロード済のデータを読み込んで利用できます。

ただし、 mldata.org は、しばしばサーバがダウンしており、ダウンロードできない場合があります。なお、scikit-learnには、 load_digits というメソッドで手書き数字のデータセットを取得できます。これは mnist を加工して作成した、 縦8ピクセル・横8ピクセル、1800枚の小さなデータセットです。 mnist とは大きさも枚数も異なりますので注意してください。

from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original', data_home=".")

各種機械学習のライブラリを使う

最もおすすめの方法です。 TensorFlow や Keras などの機械学習のライブラリには、あらかじめ mnist をダウンロードするメソッドが用意されています。

 

[PR] Pythonで挫折しない学習方法を動画で公開中

実際に書いてみよう

今回のサンプルプログラムでは、機械学習ライブラリの Keras を使い、 mnist のダウンロードと表示を行います。なお事前に必要なライブラリのインストールが必要です。

pip install keras
pip install matplotlib

サンプルプログラムは以下となります。

# 必要なライブラリのインポート
from keras import backend as K
from keras.datasets import mnist
import matplotlib.pyplot as plt

# mnist データをダウンロード
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 画像データとラベルの要素数を表示
print("画像データの要素数", train_images.shape)
print("ラベルデータの要素数", train_labels.shape)

# ラベルと画像データを表示
for i in range(0,10):
 print("ラベル", train_labels[i])
 plt.imshow(train_images[i].reshape(28, 28), cmap='Greys')
 plt.show()

実行結果は以下のようになります。

 

この記事を監修してくれた方

太田和樹(おおたかずき)
ITベンチャー企業のPM兼エンジニア

普段は主に、Web系アプリケーション開発のプロジェクトマネージャーとプログラミング講師を行っている。守備範囲はフロントエンド、モバイル、サーバサイド、データサイエンティストと幅広い。その幅広い知見を生かして、複数の領域を組み合わせた新しい提案をするのが得意。

開発実績:画像認識技術を活用した駐車場混雑状況把握(実証実験)、音声認識を活用したヘルプデスク支援システム、Pepperを遠隔操作するアプリの開発、大規模基幹系システムの開発・導入マネジメント

地方在住。仕事のほとんどをリモートオフィスで行う。通勤で消耗する代わりに趣味のDIYや家庭菜園、家族との時間を楽しんでいる。

 

大石ゆかり

内容分かりやすくて良かったです!

田島悠介

ゆかりちゃんも分からないことがあったら質問してね!

大石ゆかり

分かりました。ありがとうございます!

オンラインのプログラミングスクールTechAcademyではPythonを使って機械学習の基礎を学ぶPythonオンライン講座を開催しています。

初心者向けの書籍を使って人工知能(AI)や機械学習について学ぶことができます。

現役エンジニアがパーソナルメンターとして受講生に1人ずつつき、マンツーマンのメンタリングで学習をサポートし、最短4週間で習得することが可能です。

また、現役エンジニアから学べる無料のプログラミング体験会も実施しているので、ぜひ参加してみてください。