首頁 > 軟體

詳解利用Pytorch實現ResNet網路之評估訓練模型

2023-11-26 14:00:42

正文

每個 batch 前清空梯度,否則會將不同 batch 的梯度累加在一塊,導致模型引數錯誤。

然後我們將輸入和目標張量都移動到所需的裝置上,並將模型的梯度設定為零。我們呼叫model(inputs)來計算模型的輸出,並使用損失函數(在此處為交叉熵)來計算輸出和目標之間的誤差。然後我們通過呼叫loss.backward()來計算梯度,最後呼叫optimizer.step()來更新模型的引數。

在訓練過程中,我們還計算了準確率和平均損失。我們將這些值返回並使用它們來跟蹤訓練進度。

評估模型

我們還需要一個測試函數,用於評估模型在測試資料集上的效能。

以下是該函數的程式碼:

def test(model, criterion, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100 * correct / total
    avg_loss = test_loss / len(test_loader)
    return acc, avg_loss

在測試函數中,我們定義了一個with torch.no_grad()區塊。這是因為我們希望在測試集上進行前向傳遞時不計算梯度,從而加快模型的執行速度並節約記憶體。

輸入和目標也要移動到所需的裝置上。我們計算模型的輸出,並使用損失函數(在此處為交叉熵)來計算輸出和目標之間的誤差。我們通過累加損失,然後計算準確率和平均損失來評估模型的效能。

訓練 ResNet50 模型

接下來,我們需要訓練 ResNet50 模型。將資料載入器傳遞到訓練迴圈,以及一些其他引數,例如訓練週期數和學習率。

以下是完整的訓練程式碼:

num_epochs = 10
learning_rate = 0.001
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(num_classes=1000).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, num_epochs + 1):
    train_acc, train_loss = train(model, optimizer, criterion, train_loader, device)
    test_acc, test_loss = test(model, criterion, test_loader, device)
    print(f"Epoch {epoch}  Train Accuracy: {train_acc:.2f}%  Train Loss: {train_loss:.5f}  Test Accuracy: {test_acc:.2f}%  Test Loss: {test_loss:.5f}")
    # 儲存模型
    if epoch == num_epochs or epoch % 5 == 0:
        torch.save(model.state_dict(), f"resnet-epoch-{epoch}.ckpt")

在上面的程式碼中,我們首先定義了num_epochslearning_rate。我們使用了兩個資料載入器,一個用於訓練集,另一個用於測試集。然後我們移動模型到所需的裝置,並定義了損失函數和優化器。

在迴圈中,我們一次訓練模型,並在 train 和 test 資料集上計算準確率和平均損失。然後將這些值列印出來,並可選地每五次週期儲存模型引數。

您可以嘗試使用 ResNet50 模型對自己的影象資料進行訓練,並通過增加學習率、增加訓練週期等方式進一步提高模型精度。也可以調整 ResNet 的架構並進行效能比較,例如使用 ResNet101 和 ResNet152 等更深的網路。

以上就是詳解利用Pytorch實現ResNet網路的詳細內容,更多關於Pytorch ResNet網路的資料請關注it145.com其它相關文章!


IT145.com E-mail:sddin#qq.com