必威体育Betway必威体育官网
当前位置:首页 > IT技术

requires_grad和volatile

时间:2019-07-25 15:11:17来源:IT技术作者:seo实验室小编阅读:61次「手机版」
 

requires

每个Tensor都有两个标志:requires_grad和volatile。它们都允许从梯度计算中精细地排除子图,并可以提高效率。

requires_grad

如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。

x = Variable(torch.randn(5, 5))

y = Variable(torch.randn(5, 5))

z = Variable(torch.randn(5, 5), requires_grad=True)

a = x + y

print(a.requires_grad)

b = a + z

print(b.requires_grad)

Result:

False

True

这个标志特别有用,当您想要冻结部分模型时,或者您事先知道不会使用某些参数的梯度。例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。

model = torchvision.models.resnet18(pretrained=True)

for param in model.parameters():

param.requires_grad = False

#Replace the last fully-connected layer

#Parameters of newly constructed modules have requires_grad=True by default

model.fc = nn.Linear(512, 100)

#Optimize only the classifier

optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

volatile

纯粹的inference模式下推荐使用volatile,当你确定你甚至不会调用.backward()时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile也决定了require_grad is False。

volatile不同于require_grad的传递。如果一个操作甚至只有一个volatile的输入,它的输出也将是volatile。Volatility比“不需要梯度”更容易传递——只需要一个volatile的输入即可得到一个volatile的输出,相对的,需要所有的输入“不需要梯度”才能得到不需要梯度的输出。使用volatile标志,您不需要更改模型参数的任何设置来用于inference。创建一个volatile的输入就够了,这将保证不会保存中间状态。

regular_input = Variable(torch.randn(5, 5))

volatile_input = Variable(torch.randn(5, 5), volatile=True)

model = torchvision.models.resnet18(pretrained=True)

print(model(regular_input).requires_grad)

output: True

print(model(volatile_input).requires_grad)

output: False

print(model(volatile_input).volatile)

output: True

print(model(volatile_input).creator is None)

output: True

引用:https://pytorch-cn.readthedocs.io/zh/latest/notes/autograd/

相关阅读

分享到:

栏目导航

推荐阅读

热门阅读