论文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]
本文记录了我在学习 VAE 过程中的一些公式推导和思考。如果你希望从头开始学习 VAE,建议先看一下苏剑林的博客(本文末尾有链接)。
VAE 的整体框架
VAE 认为,随机变量 (boldsymbol{x} sim p(boldsymbol{x})) 由两个随机过程得到:
- 根据先验分布 (p(boldsymbol{z})) 生成隐变量 (boldsymbol{z})。
- 根据条件分布 (p(boldsymbol{x} | boldsymbol{z})) 由 (boldsymbol{z}) 得到 (boldsymbol{x})。
于是 (p(boldsymbol{x}, boldsymbol{z}) = p(boldsymbol{z})p(boldsymbol{x} | boldsymbol{z})) 就是我们所需要的生成模型。
一种朴素的想法是:先用随机数生成器生成隐变量 (boldsymbol{z}),然后用 (p(boldsymbol{x} | boldsymbol{z})) 从 (boldsymbol{z}) 中生成出(或者说重构出) (boldsymbol{x}),通过最小化重构损失来训练模型。这个想法的问题在于:我们无法找到生成的样本与原始样本之间的对应关系,重构损失算不了,无法训练。
VAE 的做法是引入后验分布 (p(boldsymbol{z} | boldsymbol{x})),训练过程变为:
- 采样一批原始样本 (boldsymbol{x})。
- 用 (p(boldsymbol{z} | boldsymbol{x})) 获得每个样本 (boldsymbol{x}) 对应的隐变量 (boldsymbol{z})。
- 用 (p(boldsymbol{x} | boldsymbol{z})) 从隐变量 (boldsymbol{z}) 中重构出 (boldsymbol{x}),通过最小化重构损失来训练模型。
从这个角度来看,(p(boldsymbol{z} | boldsymbol{x})) 相当于编码器,(p(boldsymbol{x} | boldsymbol{z})) 相当于解码器,训练结束后只需要保留解码器 (p(boldsymbol{x} | boldsymbol{z})) 即可。
除了重构损失以外,VAE 还有一项 KL 散度损失,希望近似的后验分布 (q(boldsymbol{z} | boldsymbol{x})) 尽量接近先验分布 (p(boldsymbol{z})),即最小化二者的 KL 散度。
变分下界的推导
现有 (N) 个由分布 (P(boldsymbol{x}; boldsymbol{theta})) 生成的样本 (boldsymbol{x}^{(1)}, ldots, boldsymbol{x}^{(N)}),我们可以使用极大似然估计从这些样本中估计出分布的参数 (boldsymbol{theta}),即
boldsymbol{theta}
& = operatorname*{argmax}_{boldsymbol{theta}} p(boldsymbol{x}^{(1)}; boldsymbol{theta}) cdots p(boldsymbol{x}^{(N)}; boldsymbol{theta})
& = operatorname*{argmax}_{boldsymbol{theta}} ln(p(boldsymbol{x}^{(1)}; boldsymbol{theta}) cdots p(boldsymbol{x}^{(N)}; boldsymbol{theta}))
& = operatorname*{argmax}_{boldsymbol{theta}} sum_{i=1}^n ln p(boldsymbol{x}^{(i)}; boldsymbol{theta}).
end{aligned}
]
后验分布 (p(boldsymbol{z} | boldsymbol{x}) = frac{p(boldsymbol{z})p(boldsymbol{x} | boldsymbol{z})}{p(boldsymbol{x})} = frac{p(boldsymbol{z})p(boldsymbol{x} | boldsymbol{z})}{int_{boldsymbol{z}} p(boldsymbol{x}, boldsymbol{z}) mathrm{d}boldsymbol{z}}) 是 intractable 的,因为分母处的边缘分布 (p(boldsymbol{x})) 积不出来。具体来说,联合分布 (p(boldsymbol{x}, boldsymbol{z}) = p(boldsymbol{z})p(boldsymbol{x} | boldsymbol{z})) 的表达式非常复杂,(int_{boldsymbol{z}} p(boldsymbol{x}, boldsymbol{z}) mathrm{d}boldsymbol{z}) 这个积分找不到解析解。
需要使用变分推断解决后验分布无法计算的问题。我们使用一个形式已知的分布 (q(boldsymbol{z}|boldsymbol{x}^{(i)}; boldsymbol{phi})) 来近似后验分布 (p(boldsymbol{z}|boldsymbol{x}^{(i)}; boldsymbol{theta})),于是有
log p(boldsymbol{x}^{(i)})
& = int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[log q(boldsymbol{z}|boldsymbol{x}^{(i)}) – log p(boldsymbol{z}|boldsymbol{x}^{(i)})] mathrm{d}boldsymbol{z} + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{z}|boldsymbol{x}^{(i)})] mathrm{d}boldsymbol{z} + log p(boldsymbol{x}^{(i)}) cdot 1
& = int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})logfrac{q(boldsymbol{z}|boldsymbol{x}^{(i)})}{p(boldsymbol{z}|boldsymbol{x}^{(i)})} mathrm{d}boldsymbol{z} + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{z}|boldsymbol{x}^{(i)})] mathrm{d}boldsymbol{z} + log p(boldsymbol{x}^{(i)}) cdot int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})mathrm{d}boldsymbol{z}
& = mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z}|boldsymbol{x}^{(i)})] + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{z}|boldsymbol{x}^{(i)})] mathrm{d}boldsymbol{z} + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})log p(boldsymbol{x}^{(i)}) mathrm{d}boldsymbol{z}
& = mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z}|boldsymbol{x}^{(i)})] + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{x}^{(i)})] mathrm{d}boldsymbol{z}
& = mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z}|boldsymbol{x}^{(i)})] + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log (p(boldsymbol{z}|boldsymbol{x}^{(i)})p(boldsymbol{x}^{(i)}))] mathrm{d}boldsymbol{z}
& = mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z}|boldsymbol{x}^{(i)})] + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{x}^{(i)}, boldsymbol{z})] mathrm{d}boldsymbol{z}
& = mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z}|boldsymbol{x}^{(i)})] + mathbb{E}_{boldsymbol{z} sim q(boldsymbol{z}|boldsymbol{x}^{(i)})}[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{x}^{(i)}, boldsymbol{z})]
& = mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z}|boldsymbol{x}^{(i)})] + L(boldsymbol{theta}, boldsymbol{phi}; boldsymbol{x}^{(i)})
& geq L(boldsymbol{theta}, boldsymbol{phi}; boldsymbol{x}^{(i)}).
end{aligned}
]
利用 KL 散度大于等于 0 这一特性,我们得到了对数似然 (log p(boldsymbol{x}^{(i)})) 的一个下界 (L(boldsymbol{theta}, boldsymbol{phi}; boldsymbol{x}^{(i)})),于是可以将最大化对数似然改为最大化这个下界。
这个下界可以进一步写成
L(boldsymbol{theta}, boldsymbol{phi}; boldsymbol{x}^{(i)})
& = int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{x}^{(i)}, boldsymbol{z})] mathrm{d}boldsymbol{z}
& = int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log (p(boldsymbol{z})p(boldsymbol{x}^{(i)}|boldsymbol{z}))] mathrm{d}boldsymbol{z}
& = int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[-log q(boldsymbol{z}|boldsymbol{x}^{(i)}) + log p(boldsymbol{z}) + log p(boldsymbol{x}^{(i)}|boldsymbol{z})] mathrm{d}boldsymbol{z}
& = -int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})[log q(boldsymbol{z}|boldsymbol{x}^{(i)}) – log p(boldsymbol{z})] mathrm{d}boldsymbol{z} + int_{boldsymbol{z}} q(boldsymbol{z}|boldsymbol{x}^{(i)})log p(boldsymbol{x}^{(i)}|boldsymbol{z})] mathrm{d}boldsymbol{z}
& = -mathrm{KL}[q(boldsymbol{z}|boldsymbol{x}^{(i)}), p(boldsymbol{z})] + mathbb{E}_{boldsymbol{z} sim q(boldsymbol{z}|boldsymbol{x}^{(i)})}[log p(boldsymbol{x}^{(i)}|boldsymbol{z})].
end{aligned}
]
其中的第一项是 KL 散度损失,第二项是重构损失。
KL 散度损失
使用标准正态分布作为先验分布,即 (p(boldsymbol{z}) = N(boldsymbol{z}; boldsymbol{0}, boldsymbol{I}))。
使用一个由 MLP 的输出来参数化的正态分布作为近似后验分布,即 (q(boldsymbol{z}|boldsymbol{x}^{(i)}; boldsymbol{phi}) = N(boldsymbol{z}; boldsymbol{mu}(boldsymbol{x}^{(i)}; boldsymbol{phi}), boldsymbol{sigma}^2(boldsymbol{x}^{(i)}; boldsymbol{phi})boldsymbol{I}))。
选择正态分布的好处在于 KL 散度的这个积分可以写出解析解,训练时直接按照公式计算即可,无需通过采样的方式来算积分。
由于我们选择的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可:
mathrm{KL}[N(z; mu, sigma^2), N(z; 0, 1)]
& = int_z N(z; mu, sigma^2)logfrac{N(z; mu, sigma^2)}{N(z; 0, 1)} mathrm{d}z
& = int_z N(z; mu, sigma^2) logfrac{frac{1}{sqrt{2pi}sigma}expleft(-frac{(z – mu)^2}{2sigma^2}right)}{frac{1}{sqrt{2pi}}expleft(-frac{z^2}{2}right)} mathrm{d}z
& = int_z N(z; mu, sigma^2) logleft(frac{1}{sqrt{sigma^2}}expleft(frac{1}{2}left(-frac{(z – mu^2)^2}{sigma^2} + z^2right)right)right) mathrm{d}z
& = frac{1}{2}int_z N(z; mu, sigma^2) left(-logsigma^2 – frac{(z – mu)^2}{sigma^2} + z^2right)mathrm{d}z
& = frac{1}{2}left(-logsigma^2int_z N(z; mu, sigma^2) mathrm{d}z – frac{1}{sigma^2}int_z N(z; mu, sigma^2)(z – mu)^2mathrm{d}z + int_z N(z; mu, sigma^2)z^2mathrm{d}zright)
& = frac{1}{2}left(-logsigma^2 cdot 1 – frac{1}{sigma^2} cdot sigma^2 + mu^2 + sigma^2right)
& = frac{1}{2}(-logsigma^2 – 1 + mu^2 + sigma^2).
end{aligned}
]
解释一下倒数第三行的三个积分:
- (int_z N(z; mu, sigma^2) mathrm{d}z) 是概率密度函数的积分,也就是 1。
- (int_z N(z; mu, sigma^2)(z – mu)^2mathrm{d}z) 是方差的定义,也就是 (sigma^2)。
- (int_z N(z; mu, sigma^2)z^2mathrm{d}z) 是正态分布的二阶矩,结果为 (mu^2 + sigma^2)。
重构损失
伯努利分布模型
当 (boldsymbol{x}) 是二值向量时,可以用伯努利分布(两点分布)来建模 (p(boldsymbol{x}|boldsymbol{z})),即认为向量 (boldsymbol{x}) 的每个维度都服从对应的相互独立的伯努利分布。使用一个 MLP 来计算各维度所对应的伯努利分布的参数,第 (i) 维伯努利分布的参数为 (y_i = boldsymbol{y}(boldsymbol{z})_i),于是有
]
]
其中 (D) 表示向量 (boldsymbol{x}) 的维度。可见此时最大化 (log p(boldsymbol{x}|boldsymbol{z})) 等价于最小化交叉熵损失。
正态分布模型
当 (boldsymbol{x}) 是实值向量时,可以用正态分布来建模 (p(boldsymbol{x}|boldsymbol{z}))。使用一个 MLP 来计算正态分布的参数,于是有
p(boldsymbol{x}|boldsymbol{z})
& = N(boldsymbol{x}; boldsymbol{mu}, boldsymbol{sigma}^2boldsymbol{I})
& = prod_{i=1}^D N(x_i; mu_i, sigma_i^2)
& = left(prod_{i=1}^Dfrac{1}{sqrt{2pi}sigma_i}right)expleft(sum_{i=1}^D-frac{(x_i – mu_i)^2}{2sigma_i^2}right),
end{aligned}
]
]
很多时候我们会假设 (sigma_i^2) 是一个常数,于是 MLP 只需要输出均值参数 (boldsymbol{mu}) 即可。此时有
]
可见此时最大化 (log p(boldsymbol{x}|boldsymbol{z})) 等价于最小化 MSE 损失。
重参数化技巧
需要使用重参数化技巧解决采样 (z) 时不可导的问题。解决的思路是先从无参数分布中采样一个 (varepsilon),再通过变换得到 (z)。
从 (N(mu, sigma^2)) 中采样一个 (z),相当于先从 (N(0, 1)) 中采样一个 (varepsilon),然后令 (z = mu + varepsiloncdotsigma)。
相关知识
技巧,通过取对数把乘除变成加减:
]
随机变量的函数的期望:
]
利用此公式可以将积分改写成期望的形式,这样就可以用采样的方式计算积分了(蒙特卡罗积分法)。
条件概率密度的定义:
]
此处的 (p) 并不是概率而是概率密度函数,但是这个公式在形式上跟条件概率公式是一样的。
参考资料
苏剑林的 VAE 系列博客:
- 变分自编码器(一):原来是这么一回事 – 科学空间
- 变分自编码器(二):从贝叶斯观点出发 – 科学空间
- 变分自编码器(三):这样做为什么能成? – 科学空间
15 分钟了解变分推理:
- 【15分钟】了解变分推理 – 哔哩哔哩
- 【15分钟】了解变分自编码器 – 哔哩哔哩
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
机房租用,北京机房租用,IDC机房托管, http://www.fwqtg.net
目前信息通信行业是社会最为热门的认证之一,很多人为了进入这一行都会通过考证书来给自己背书,华为云作为现在最有热度的云厂商之一,其旗下的认证是相当具有含金量的,下面小编就热门的云计算和数通介绍一下,有需要的可以在认证大使上详细了解。 华为云云计算概述 培训与认证…