PyTorchメモ: データの長さを揃える処理 pad_sequenceおよびDataLoaderとの組み合わせ方

記事内に広告が含まれています。

深層学習モデルを学習させる際、学習データのshape は全て揃っている必要があります。

しかし、RNNやLSTMなどの自然言語処理モデルの場合、学習データの長さが全て揃っていることは稀で、パディングなどの処理を施して人工的に揃える必要があります。

そのパディング処理を各バッチに対して実行してくれるのがpad_sequence です。本記事では、この関数をDataloaderとともに使う方法をまとめます。

用いるライブラリ

import torch 
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

pad_sequenceの使い方

以下のように長さが各々異なるデータに対し、一番長いデータ長に揃えてミニバッチにまとめてくれます。引数 batch_first=True を指定することで、 [バッチサイズ, データ長] の形式に整形して返していくれます。

#長さが異なるデータを用意(xは3, yは7,zは4の長さ)
x = torch.LongTensor([100,201,303])
y = torch.LongTensor([1,2,3,4,5,6,7])
z = torch.LongTensor([-1,-2,-3,-4])
print(pad_sequence([x,y,z], batch_first=True))

#tensor([[100, 201, 303,   0,   0,   0,   0],
#        [  1,   2,   3,   4,   5,   6,   7],
#        [ -1,  -2,  -3,  -4,   0,   0,   0]])

Dataloaderと組み合わせて使う方法

pad_sequence 関数をDataloaderと組み合わせることでミニバッチ内のデータの長さを統一することができます。

## DataLoaderと一緒に使う方法
class DataSet(Dataset):
    #__init__, __len__, __getitem__の3つの関数を
    #定義することが必要
    def __init__(self, X, Y):
        self.training_data = X
        self.label_data = Y
    
    def __len__(self):
        return len(self.training_data)
    
    def __getitem__(self,idx):
        return self.training_data[idx], self.label_data[idx]

#ミニバッチを取り出して長さを揃える関数
def My_collate_func(batch):
    xs, ys = [], []
    for x,y in batch:
        xs. append(torch.LongTensor(x))
        ys.append(torch.LongTensor(y))
    
    #データ長を揃える処理
    xs = pad_sequence(xs, batch_first=True)
    ys = pad_sequence(ys, batch_first=True, padding_value=-1.0)
    return xs, ys

以下のように、DataLoder関数の引数にMy_collate_funcを入力することでデータ長を揃えることができます。

#X: training dataのミニバッチ, Y:label dataのミニバッチ
dataset = DataSet(X,Y)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=My_collate_func)

Dataloaderと組み合わせて使う方法のテスト

簡単なテストのため、6つのtraining dataとlabel dataを用意します。X[i] のlabel dataがY[i]という形式です。(データは適当に作ったものです)

#X[i]のlabel dataがY[i]
X = [array([527, 493, 584]), 
   array([433, 607, 587, 725,  47]), 
   array([772, 938, 875,  51, 359]), 
   array([302, 454, 351]), 
   array([564, 238, 972]), 
   array([648, 434, 493, 506, 271])
   ]
Y = [array([6, 2, 8]), 
     array([4, 4, 5, 7, 3]), 
     array([6, 1, 3, 5, 8]), 
     array([3, 9, 2]), 
     array([2, 4, 1]), 
     array([9, 8, 7, 1, 6])
]

このX,Yを用いてDataloaderを作成します。バッチサイズはbatch_size=2とします。

dataset = DataSet(X,Y)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_func)

各バッチの中身を確認します

for i, (xs, ys) in enumerate(dataloader):
    print(f"{i}th batach")
    print(xs) #training dataのミニバッチ
    print(ys) #label dataのミニバッチ
    print("=====")

出力結果.

0th batach
tensor([[527, 493, 584,   0,   0],
        [433, 607, 587, 725,  47]])
tensor([[ 6,  2,  8, -1, -1],
        [ 4,  4,  5,  7,  3]])
=====
1th batach
tensor([[772, 938, 875,  51, 359],
        [302, 454, 351,   0,   0]])
tensor([[ 6,  1,  3,  5,  8],
        [ 3,  9,  2, -1, -1]])
=====
2th batach
tensor([[564, 238, 972,   0,   0],
        [648, 434, 493, 506, 271]])
tensor([[ 2,  4,  1, -1, -1],
        [ 9,  8,  7,  1,  6]])
=====

3つのミニバッチに分けられています。各ミニバッチ内のデータをみると、一番長いデータ長に揃えるよう0が埋められています。また、label dataのミニバッチの場合は-1で埋められていることが確認できます。

まとめと参考

pad_sequenceの使い方とDataloaderと組み合わせる方法をメモしました。

以下のサイト・文献を参考にしました。

torch.nn.utils.rnn.pad_sequence — PyTorch 2.3 documentation

新納 浩幸, 『PyTorch自然言語処理プログラミング word2vec/LSTM/seq2seq/BERTで日本語テキスト解析!』, インプレスブックス, 2021年3月

Python
スポンサーリンク

コメント