【極めればこのテンソル操作 】tensor.unsqueeze(0)と array[None] の違い
今日は、 unsqueeze(0) の解説しつつ、私たちがよく直面する「あるある」な問題についてもちょこっと話してみたいと思います。
「value.unsqueeze(0)」と「value[None]」 の見分けついていますか?
はい、前者は主に PyTorch、後者は NumPyでの操作の違いです。
でもどちらも、ぱっとみは、先頭に新しく次元を追加する操作なので、コードをちらっとみただけではわからないことがありますよね。
なぜかというと、ディープラーニング系のプログラミングでは PyTorchのテンソルと、NumPyの配列操作がかなり入り混じるからです。
そう、今日の話題はPyTorchとNumPyのコードが入り乱れて、どっちの配列(テンソル)を扱っているのわけワカメになる問題です。
ちなみに、話題のテーマをブラさないように PyTorchでは 先頭に新しい次元を追加するときに unsqueeze(0) だけでなく [None] も使えてしまいますが、いったん[None]は NumPy で主に使用する操作という前提で説明させてくださいませ。^^;
これに対する当社なりの処方箋は、また別投稿をしたいとおもいますが、両者が無邪気に入り混じらないように、PyTorchとNumPyのコードをなるべく分離するようにしています。例えば「同一関数、メソッド内はPyTorchかNumPyに寄せる」、や、「GPU投入寸前までPyTorchテンソル化をガマンしてNumPyでがんばる」など、(涙ぐましい?)現場の工夫をしています^^
NumPy系の変数名には「なんちゃら_numpy」「なんちゃら_tensor」のようにするなど、あまりにも紛らわしいときには、行っていますが、型宣言のゆるいPythonコーディングの慣例上、同一変数名なのにNumPyからPyTorchにいつのまにか変わっていた、なんていう外部コードも大量にあり、なかなか難しいですね。
PyTorchとNumPyが入り乱れる世界 ~機械学習プロジェクトを進めていると、こんな経験ありませんか?
- データの前処理はNumPyで行っていたのに、モデルに入力するときにはPyTorchのテンソルに変換しなければならない。
- モデルから出力されたPyTorchのテンソルを、可視化のためにNumPy配列に戻す。
- そして気づいたら、コード内でNumPyとPyTorchの関数が混在している...
これって、まるでプログラミング言語のバベルの塔ですよね。
今回は、PyTorchの.unsqueeze(0)
メソッドとNumPyの[None]
インデックスの違いについて詳しく見ていきましょう。一見似ているこれらの操作ですが、実は重要な違いがあります。
1. 基本的な違い
まず、最も基本的な違いは、冒頭でふれたとおり、
.unsqueeze(0)
: PyTorchのテンソルに使用されるメソッドです。[None]
: NumPy配列やPythonのリストに使用されるインデックス操作です。
(コラムに書きましたが、実はPyTorchでも使えちゃいますが、頭に次元追加する操作は PyTorchでは unsqueeze(0)、おしりに次元追加する操作はunsqueeze(-1)でやるのが可読性や操作意図のわかりやすからオススメです)
2. 動作の詳細
.unsqueeze(0)
PyTorchの.unsqueeze(0)
メソッドは、テンソルの0次元目(先頭)に新しい次元を追加します。これは、バッチ処理のためにデータを準備する際によく使用されます。1件だけのデータを学習モデルに突っ込みたいときも、「バッチ次元」を求められることが常なので unsqueeze(0) は頻発するコードだとおもいます。
import torch
x = torch.tensor([1, 2, 3])
print(x.shape) # torch.Size([3])
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape) # torch.Size([1, 3])
[None]
NumPyの[None]
インデックスは、配列に新しい軸を追加します。これも実質的に次元を1つ増やすことになります。
例:
import numpy as np
y = np.array([1, 2, 3])
print(y.shape) # (3,)
y_expanded = y[None]
print(y_expanded.shape) # (1, 3)
3. 柔軟性の違い
.unsqueeze(n)
メソッドは、引数n
を変えることで任意の位置に次元を追加できる柔軟性があります。
例:
import torch
z = torch.tensor([[1, 2], [3, 4]])
print(z.shape) # torch.Size([2, 2])
z_unsqueezed_0 = z.unsqueeze(0)
print(z_unsqueezed_0.shape) # torch.Size([1, 2, 2])
z_unsqueezed_1 = z.unsqueeze(1)
print(z_unsqueezed_1.shape) # torch.Size([2, 1, 2])
一方、[None]
は常に新しい軸を先頭(axis 0)に追加します。ただし、NumPyにはnp.expand_dims()
関数があり、これを使用すると任意の位置に次元を追加できます。
import numpy as np
w = np.array([[1, 2], [3, 4]])
print(w.shape) # (2, 2)
w_expanded_0 = np.expand_dims(w, axis=0)
print(w_expanded_0.shape) # (1, 2, 2)
w_expanded_1 = np.expand_dims(w, axis=1)
print(w_expanded_1.shape) # (2, 1, 2)
4. パフォーマンスの考慮
一般的に、.unsqueeze()
と[None]
(またはnp.expand_dims()
)の間にパフォーマンスの大きな差はありません。しかし、大規模なデータセットや複雑なモデルを扱う場合、わずかな違いが積み重なって影響を与える可能性があります。
PyTorchを使用している場合は.unsqueeze()
を、NumPyを使用している場合は[None]
やnp.expand_dims()
を使用するのが自然で効率的です。
まとめ ~.unsqueeze(0)
と[None]
の実践的理解~
今回は、.unsqueeze(0)
と[None]
の用法について詳しく解説しました。
問題の本質は、PyTorchとNumPyの混在にありますが、コードを書く上では、どちらの「世界」にいるのかを常に意識することが大切ですね。
コードを読む際には、.unsqueeze(0)
が登場したら「ここからPyTorchでの次元追加だな」と考え、[None]
を見たら「まだNumPyの領域にいるな」と理解するとよいでしょう。
使用シーンの違いも重要なポイントです。.unsqueeze(0)
は多くの場合、1件データのモデル投入の直前に「緊急的な」次元追加として用いられます。そのため、モデル投入直前でよく目にすることになります。一方、[None]
による次元追加は、通常モデル投入よりもずっと前の段階、つまりまだNumPy操作のフェーズで行われることが多いです。その後、モデル投入直前でPyTorchテンソルへの変換とGPUへの送り込みが行われるというパターンもよく見かけます。
これらの操作を見かけたら、まずは「バッチ次元追加かな?」と推測してみるのが良いでしょう。バッチ処理のニーズで使われることが多いためです。ただし、必ずしもバッチ次元の追加だけでなく、例えば画像処理ではチャンネル次元の追加に使われることもあるので、コンテキストをよく確認することが大切です。
結論として、.unsqueeze(0)
と[None]
の違いを理解し、適切に使い分けることで、より明確で効率的なコードを書くことができます。また、これらの操作を見かけたときは「バッチ次元の追加かもしれない」と考えつつ、常にコンテキストを確認する習慣をつけることで、コードの意図をより深く理解できるようになるでしょう。