在PyTorch中,torch.utils.checkpoint
模块提供了实现梯度检查点(也称为checkpointing)的功能。这个技术主要用于训练时内存优化,它允许我们以计算时间为代价,减少训练深度网络时的内存占用。
原理
梯度检查点技术的基本原理是,在前向传播的过程中,并不保存所有的中间激活值。相反,它只保存一部分关键的激活值。在反向传播时,根据保留的激活值重新计算丢弃的中间激活值。因此内存的使用量会下降,但计算量会增加,因为需要重新计算一些前向传播的部分。
用法
torch.utils.checkpoint
中主要的函数是 checkpoint。checkpoint 函数可以用来封装模型的一部分或者一个复杂的运算,这部分会使用梯度检查点。它的一般用法是:
import torch
from torch.utils.checkpoint import checkpoint
# 定义一个前向传播函数
def custom_forward(*inputs):
# 定义你的前向传播逻辑
# 例如: x, y = inputs; result = x + y
服务器托管网 ...
return result
# 在训练的前向传播过程中使用梯度检查点
model_output = checkpoint(custom_forward, *model_inputs)
在每次调用 custom_forward 函数时,它都会返回正常的前向传播结果。不过,checkpoint 函数会确保仅保留必须的激活值(即 custom_forward 的输出)。其他激活值不会保存在内存中,需要在反向传播时重新计算。
下面是一个具体的示例,演示了如何在一个简单的模型中使用 checkpoint 函数:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class SomeModel(nn.Module):
def __init__(self):
super(SomeModel, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 50, 5)
def forward(self, x):
# 使用checkpoint来减少第二层卷积的内存使用量
x = self.conv1(x)
x = checkpoint(self.conv2, x)
return x
model = SomeModel()
input = torch.randn(1, 1, 28, 28)
output = model(input)
loss = output.sum()
loss.b服务器托管网ackward()
在上面的例子中,conv2的前向计算是通过 checkpoint 封装的,这意味着在 conv1 的输出和 conv2 的输出之间的激活值不会被完全存储。在反向传播时,这些丢失的激活值会通过再次前向传递 conv2 来重新计算。
使用梯度检查点技术可以在训练大型模型时减少显存的占用,但由于在反向传播时额外的重新计算,它会增加一些计算成本。
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
在 C# 中,& 和 && 都是逻辑与运算符,用于判断两个条件是否同时为真。 它们之间的区别如下: &: 会对两个条件进行求值,无论第一服务器托管网个条件的结果是 true 还是 false ,都会对第二个条件进行求值。 如果两…