PyTorchの個人的まとめ

pytorchの使い方

torchvision.modelsという定義モデルの使い方.

次のモデルが定義されている

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
<pre class="”line-numbers”"><code class="”language-python”">
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
</code><code></code></pre>

という書き方。

事前学習モデルの使用の仕方は以下のように

pretrained=Trueにする.

データの読み込み

Class torchvision.datasets.ImageFolder

(roottransform=Nonetarget_transform=Noneloader=<function default_loader>)

<pre class="”line-numbers”"><code class="”language-python”">
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
</code><code></code></pre>

データがこのようになってる時に使う

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

モデルをセーブするとき

torch.save(model.state_dict(), PATH)

PATHはファイル名を含んだパス.

例えば,

torch.save(model.state_dict(), “~/Desktop/model/first.model”)とか

モデルをロードするとき

model = models.inception_v3(num_classes=10) # クラス数は任意に(データに合わせて)

決めないとエラー出る.

 param = torch.load(“~/Desktop/model/first.model”)
 model.load_state_dict(param)
 model.eval()

torchvision.transforms.functional.normalize(tensor, mean, std)

はテンソル画像を平均と標準偏差で正規化するメソッドです。

torchvision.transforms.functional.resize(imgsizeinterpolation=2)

PIL 画像を与えられたサイズでリサイズします。

コメントを残す

メールアドレスが公開されることはありません。

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください