首頁 > 軟體

PyTorch常用函數torch.cat()中dim引數使用說明

2023-09-12 18:01:17

Part 1: 簡介

在PyTorch中,torch.cat()是一個被廣泛使用的函數。它可以讓我們在某個維度上把多個張量組合在一起。對於那些想要深入瞭解使用PyTorch進行資料分析和建模的開發者來說,理解torch.cat()函數的dim引數是非常重要的。

在PyTorch中,幾乎所有與神經網路有關的操作都涉及到張量(Tensor)操作。因此,在PyTorch中,將多個相同形狀的張量沿某個軸/維度連線起來的過程非常重要。這就是 torch.cat() 函數的作用。torch.cat() 的最基本用法如下:

torch.cat(tensors, dim=0, out=None) -> Tensor

其中tensors表示要拼接的張量列表,dim表示我們希望在哪個維度上連線,預設是0,即在第一維上連線。out是輸出張量,可不傳入,當傳入此引數時其大小必須能容納在cat操作後的輸出tensor中。

Part 2: dim引數的說明

dim引數指示拼接發生的軸或維度。在拼接多個張量時,我們必須指定在哪個維度上拼接它們。dim引數可以是正數、負數或None(預設為0),具體來說,dim引數可以有以下三種常見用法:

正數

最常見的方式是使用正整數來指定要連線的維度/軸的索引值。例如,在將兩個大小為 3x5x7 的張量沿第2個維度拼接在一起時,這些張量變成一個形狀為 3x10x7 的張量。

# 定義兩個大小都為[3, 5, 7]的隨機Tensor
tensor1 = torch.randn(3, 5, 7)
tensor2 = torch.randn(3, 5, 7)
# 在第二維度上(索引1)進行合併 
cat_tensor = torch.cat((tensor1, tensor2), dim=1)
print(cat_tensor.shape) # 輸出: torch.Size([3, 10, 7])

負數

我們也可以使用負整數來表示要連線的軸/維度。當dim引數被設定為負整數時,它代表距離張量最後一個軸的間隔數。例如,將一個大小為3x5x7 和一個大小為3x6x7的張量沿著最後一個維度進行拼接,即 concatenate 第三個維度:

# 定義兩個大小分別為 [3, 5, 7], [3, 6, 7] 的隨機Tensor
tensor1 = torch.randn(3, 5, 7)
tensor2 = torch.randn(3, 6, 7)
# 在最後一個維度上(-1表示)進行合併 
cat_tensor = torch.cat((tensor1, tensor2), dim=-1)
print(cat_tensor.shape) # 輸出: torch.Size([3, 5, 14])

None

如果 dim 引數的值為 None,則會將所有輸入張量沿著前面的維度全部展開。這通常會在神經網路模型中使用,例如線上性層之間堆疊各個特徵向量時。

# 定義兩個大小分別為 [3, 5, 7], [4, 6, 8] 的隨機Tensor
tensor1 = torch.randn(3, 5, 7)
tensor2 = torch.randn(4, 6, 8)
# 將每個張量reshape為1D向量 
resized_t1 = tensor1.view(-1)
resized_t2 = tensor2.view(-1)
# 按行連線兩個1D張量  
cat_tensor = torch.cat((resized_t1, resized_t2), dim=None)
print(cat_tensor.shape) # 輸出: torch.Size([315])

Part 3: 總結

torch.cat() 函數是PyTorch非常有用的函數之一,它可以在某個維度上將多個張量組合成一個大張量。理解dim引數的含義和使用方法對於深入學習PyTorch和構建神經網路非常重要。通過在 dim 引數上增加或減少索引來改變連線選定的張量的方式,我們可以讓torch.cat()函數在資料處理、模型設計和深度學習中發揮重要作用。

以上就是PyTorch常用函數torch.cat()中dim引數使用說明的詳細內容,更多關於PyTorch torch.cat() dim的資料請關注it145.com其它相關文章!


IT145.com E-mail:sddin#qq.com