2019年1月1日 星期二

深度學習 - 認識MNIST數據集 "Hello World"

深度學習 - 認識MNIST數據 "Hello World"

在深度學習一開始,一定會接觸到深度學習的“Hello World"
那就是MNIST 數據National Institute of Standards and Technology(美國國家標準語與技術研究院) 

1. 先來看看數據集的樣子

from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

先導入keras函式庫將mnist數據集的訓練資料(train_images)與測試資料(test_images)讀取出來。
深度學習的目的,就是要將機器透過訓練資料訓練過後,再由測試資料驗證訓練成果。

這邊先將數據印出來看看長什麼樣子
import matplotlib.pyplot as plt
for i in range(5) :
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.show()

註:
imshow 畫灰度圖
cmap=plt.cm.binary 將灰度圖可視化,只顯示黑白圖

透過 matplotlib 將數據顯示出來看看,這邊先印5張數據出來看看







再來看看他們對應的答案train_labels


for i in range(5):
    print ('%2s' % train_labels[i] ,end="")
    








再看一下數據庫形狀
>>> train_images.shape       #.shape可以看出資料的形狀,在訓練過程,資料要符合網絡  
(60000, 28, 28)                     的形狀
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
>>> test_images.shape 
(10000, 28, 28)
>>> len(test_labels) 
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8) 

到此,已經大概知道數據庫的長相。
其實就是60000張圖的訓練集和10000張圖的測試集,
其中每張圖是由28*28像素,其像素值介於0~255之間(uint8)
可以print出第一張圖來看看。
print(train_images[0])

大概就長得將這樣,上面透過matplotlib 畫出來就是一個手寫數字。

沒有留言:

張貼留言