[PyTorch+Numpy] torch.tensor と numpy.ndarray の交互の変換メモ

Python
記事内に広告が含まれています。
スポンサーリンク

torchで計算する際の torch.tensornumpy.ndarray 同士を変換する際の方法をメモします。

結論は以下の通りです。

  • tensor → np.ndarray であれば (tensor型配列).numpy()
  • np.ndarray → tensor であれば torch.from_numpy(numpy配列)

スポンサーリンク

tensor → np.ndarray の場合

import torch #1.11.0+cu113
import numpy as np #1.21.6

デモ

tensor = torch.tensor([0,1,2])
tensor_to_arr = tensor.numpy()
print(type(tensor_to_arr))
#<class 'numpy.ndarray'>

np.ndarray → tensor の場合

デモ

arr = np.array([1,2,3])
print(type(arr)) #<class 'numpy.ndarray'>

arr_to_tensor = torch.from_numpy(arr)
print(type(arr_to_tensor)) #<class 'torch.Tensor'>

ちなみに、torch.tensor() の引数である require_gradTrue にしておくと、numpy.ndarray に変換することはできない。

tensor2 = torch.tensor([1,2,3], dtype=torch.float64,  requires_grad=True)
print(type(tensor2)) #<class 'torch.Tensor'>

arr_to_tensor2 = tensor2.numpy()

#---------------------------------------------------------------------------
#RuntimeError                              Traceback (most recent call last)
#<ipython-input-5-2f66952bd392> in <module>()
#      2 print(type(tensor2))
#      3 
#----> 4 arr_to_tensor2 = tensor2.numpy()
#
#RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

この場合、numpy.ndarray に変換するには、detach() メソッドを利用する必要がある。

#detach()を使うとnumpy.ndarrayに変換できる

tensor2 = torch.tensor([1,2,3], dtype=torch.float64,  requires_grad=True)
arr = tensor2.detach().numpy()
print(arr.dtype) #float64

参考

torch.Tensor.numpy — PyTorch 2.0 documentation
torch.from_numpy — PyTorch 2.0 documentation

Python
スポンサーリンク
アウトプット雑記

コメント