DataLoaderの中身の確認: for文を使う方法とiterとnextを 使う方法

記事内に広告が含まれています。

DataloaderはPyTorchでデータをバッチ処理してくれる便利なクラスです。

このクラスから各バッチを取り出すためにfor文がよく使われますが、他の方法としてiterとnextを使う方法もあります。

#結論
Iter = iter(dataloader)
xdata, ydata = next(Iter) # dataloaderから生成したバッチ単位の教師データとラベルデータ

iterとnextを用いる方法のデモ

例として、scikit-learnのdigitデータ(手書きの数字画像データ)を例にdataloaderのバッチを確認してみる。

from sklearn import datasets
digits = datasets.load_digits() #データの取得
X  = digits.data
Y = digits.target

Dataloaderの用意

import torch
from torch.utils.data import Dataset
#DataSetの定義
class DigitDataSet(Dataset):
    #__init__, __len__, __getitem__の3つの関数を
    #定義することが必要
    def __init__(self, X, Y):
        self.training_data = X
        self.label_data = Y
    
    def __len__(self):
        return len(self.training_data)
    
    def __getitem__(self,idx):
        return self.training_data[idx], self.label_data[idx]

dataset = DigitDataSet(X, Y)
#バッチサイズ3で実験
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3) 

各バッチの確認.

Iter = iter(dataloader)
xdata, ydata = next(Iter) #教師データ、ラベルデータ
print(xdata.shape, type(xdata)) 
#torch.Size([3, 64]) <class 'torch.Tensor'>

print( ydata.shape, ydata)
#torch.Size([3]) tensor([0, 1, 2])

[バッチサイズ, ベクトルの次元] の形状になっており、きちんとバッチごとに分けられていることがわかります。

for文を使う方法

併せて、for文の方法も述べておきます。

#先頭から3つのバッチを確認
cnt = 0
for data in dataloader:
    print(data[0].shape, type(data[0]))
    print(data[1].shape, type(data[1]))
    print("=========")
    cnt += 1
    if cnt==3:
        break

出力結果

batch: 0
torch.Size([3, 64]) <class 'torch.Tensor'>
torch.Size([3]) <class 'torch.Tensor'>
=========
batch: 1
torch.Size([3, 64]) <class 'torch.Tensor'>
torch.Size([3]) <class 'torch.Tensor'>
=========
batch: 2
torch.Size([3, 64]) <class 'torch.Tensor'>
torch.Size([3]) <class 'torch.Tensor'>
=========

まとめ・参考

PyTorchのDataLoaderで各バッチの形状を確認する方法についてまとめました。各バッチ内のデータの形状などを確認したいときやデバッグなどに iter next を使う方法は便利なのではないかと思います。

本記事を書く際、以下のサイトを参考にしました。

PytorchのDataloaderとSamplerの使い方 - Qiita
Dataloaderとは datasetsからバッチごとに取り出すことを目的に使われます。 基本的にtorch.utils.data.DataLoaderを使います。 イメージとしてはdatasetsはデータすべてのリスト、Da...

また、以下の書籍の第三章にもiterとnextを使う方法が紹介されており、参考にしました。

新納 浩幸, 『PyTorch自然言語処理プログラミング word2vec/LSTM/seq2seq/BERTで日本語テキスト解析!』, インプレスブックス, 2021年3月

ちなみにこの本は自然言語処理の実装とPyTorchの使い方を同時に学習でき、私自身とても勉強になっています。

Python
スポンサーリンク

コメント