共轭梯度法已经在前文中给出介绍:
python版本的“共轭梯度法”算法代码
=======================================
使用共轭梯度法时,如果系数矩阵为Hessian矩阵,那么我们可以使用Pearlmutter trick技术来减少计算过程中的内存消耗,加速计算。
使用Pearlmutter trick的共轭梯度解法源自论文:
Fast Exact Multiplication by the Hessian
论文地址:
https://www.bcl.hamilton.ie/~barak/papers/nc-hessian.pdf
由于原论文中内容较多,所以我们介绍Pearlmutter trick技术建议看的资料为其他网文blog:
https://justindomke.wordpress.com/2009/01/17/hessian-vector-products/
======================================
依照上面内容解释一下:
Hession 矩阵 H(x) 为 f(x) 的二阶导数矩阵,因此有:
我们可以将 g(x) 按照泰勒公式展开为一阶导形式:
v 为一个向量vector,根据上面的公式我们可以得到:
其中 γ 为标量系数,该系数极小,因此γv可以看做为Δx
因此我们可以得到下面形式的公式:
——————————————————–
因为 H(x) 必然为正定对称矩阵,因此我们对 H(x) * y = b 形式的求解式可以使用共轭梯度法,而共轭梯度法在计算过程中需要重复的计算 H(x)*p, 其中 p 为计算过程中的迭代向量,p是在迭代过程中不断变化的,而H(x) 是系数矩阵在迭代过程中是不变的。
Pearlmutter trick 这个技术给出的是近似解,这里知道这个技术即可,实际感觉好像用处也不多,有可能是自己理解的不深。关于H(x)*b与共轭梯度的结合代码这里就省略掉了。
import torch
# 计算目标为: H*b, H为函数s关于变量w的hessian矩阵
# 变量w
w = torch.randn(4, requires_grad=True)
# 关于变量w的函数s
data = torch.randn(1000*4).reshape((-1, 4))
label = torch.randn(1000)
s = torch.mean( torch.square(label - torch.matmul(data, w)) )
# 计算目标中的b
b = torch.randn(4)
# s对w的一阶导
first_grad = torch.autograd.grad(s, w, create_graph=True)[0]
# 使用标准公式计算 H*b
second_grad = []
for grad in first_grad:
second_grad.append(torch.autograd.grad(grad, w, retain_graph=True)[0][None, :])
H = torch.concatenate(second_grad, axis=0)
print("Hessian method:")
print(torch.matmul(H, b))
# 在目标函数s进行一阶导后点乘向量b, 然后再对w进行一次求导, 也就是 dot(first_grad, b)后再次求导
# 这样计算可以减少内存占用,因为该种计算方式不会在内存中对整个hessian矩阵进行展开
# 如果要重复计算 H*b,而b又每次迭代都变化的情况,此种方式的缺点是每次都需要再次求导,但是总计算量应该变化不大
tmp = torch.dot(first_grad, b)
print(torch.autograd.grad(tmp, w)[0])
# paper, Pearlmutter trick, 没太感觉出优势, 或许理解的不对
r = 0.0001
new_w = w+r*b
new_s = torch.mean( torch.square(label - torch.matmul(data, new_w)) )
new_first_grad = torch.autograd.grad(new_s, w)[0]
print( (new_first_grad - first_grad)/r )
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
机房租用,北京机房租用,IDC机房托管, http://www.fwqtg.net
相关推荐: 创建nodejs项目并接入mysql,完成用户相关的增删改查的详细操作
本文为博主原创,转载请注明出处: 1.使用npm进行初始化 在本地创建项目的文件夹名称,如 node_test,并在该文件夹下进行黑窗口执行初始化命令 2. 安装 expres包和myslq依赖包 npm i express@4.17.1 mysql2@…