當前位置:成語大全網 - 新華字典 - Pytorch模型保存與加載,並在加載的模型基礎上繼續訓練

Pytorch模型保存與加載,並在加載的模型基礎上繼續訓練

pytorch保存模型非常簡單,主要有兩種方法:

壹般地,采用壹條語句即可保存參數:

其中model指定義的模型 實例變量 ,如 model=vgg16( ), path是保存參數的路徑,如 path='./model.pth' , path='./model.tar', path='./model.pkl', 保存參數的文件壹定要有後綴擴展名。

特別地,如果還想保存某壹次訓練采用的優化器、epochs等信息,可將這些信息組合起來構成壹個字典,然後將字典保存起來:

針對上述第壹種情況,也只需要壹句即可加載模型:

針對上述第二種以字典形式保存的方法,加載方式如下:

需要註意的是,只保存參數的方法在加載的時候要事先定義好跟原模型壹致的模型,並在該模型的實例對象(假設名為model)上進行加載,即在使用上述加載語句前已經有定義了壹個和原模型壹樣的Net, 並且進行了實例化 model=Net( ) 。

另外,如果每壹個epoch或每n個epoch都要保存壹次參數,可設置不同的path,如 path='./model' + str(epoch) +'.pth',這樣,不同epoch的參數就能保存在不同的文件中,選擇保存識別率最大的模型參數也壹樣,只需在保存模型語句前加個if判斷語句即可。

下面給出壹個具體的例子程序,該程序只保存最新的參數:

在訓練模型的時候可能會因為壹些問題導致程序中斷,或者常常需要觀察訓練情況的變化來更改學習率等參數,這時候就需要加載中斷前保存的模型,並在此基礎上繼續訓練,這時候只需要對上例中的 main() 函數做相應的修改即可,修改後的 main() 函數如下:

以上方法,如果想在命令行進行操作執行,都只需加入argpase模塊參數即可,相關方法可參考我的 博客

用法可參照上例。

這篇博客是壹個快速上手指南,想深入了解PyTorch保存和加載模型中的相關函數和方法,請移步我的這篇博客: PyTorch模型保存深入理解