我们知道,用tensor数据很多一个原因就是它可以利用GPU进行并行运算,大大提高运算效率。tensor数据的广播特性可以看做是一种“索引”,从效率极低的循环中解脱出来。我们先来看一段程序。
a = torch.randint(0, 2, (8, 256, 256)) b = torch.randint(0, 15, (8, 256, 256)) c = torch.randint(0, 15, (8, 256, 256)) d = torch.randn(8, 2, 256, 256) def func(a, b, c, d): e = torch.zeros(8, 256, 256) # 遍历每个批次中的图像 for batch_idx in range(8): # 遍历图像e的每个像素点 for i in range(256): for j in range(256): # 获取a中的值确定通道 channel_idx = int(a[batch_idx, i, j].item()) # 获取b和c中的值确定在d中的位置 row_idx = int(b[batch_idx, i, j].item()) col_idx = int(c[batch_idx, i, j].item()) # 获取d中的值并赋给e中对应位置 e[batch_idx, i, j] = d[batch_idx, channel_idx, row_idx, col_idx] return e
e=func(a,b,c,d)
这段程序利用多层循环实现的数据索引,旨在构建一个新的注意力map,但是这样效率非常的低,因此,我们考虑用tensor数据的广播特性进行代替。
import torch def func(a, b, c, d): # 获取通道、行、列索引 channel_idx = a.long() row_idx = b.long() col_idx = c.long() # 通过索引获取d中的值,这里要求了括号中的各个tensor数据的shape要相同 e = d[torch.arange(8).unsqueeze(1).unsqueeze(1), channel_idx, row_idx, col_idx] return e # 示例用法 # 创建示例数据 a = torch.randint(0, 2, (8, 256, 256)) b = torch.randint(0, 15, (8, 25服务器托管网6, 256)) c = torch.randint(0, 15, (8, 256, 256)) d = torch.randn(8, 2, 256, 256) # 调用函数 e = func(a, b, c, d) print(e)
根据给定的代码和张量尺寸,我们可以逐步分析最终张量e的尺寸和形成过程:
1. 首先,我们通过torch.arange(8).unsqueeze(1).unsqueeze(1)生成了一个形状为[8, 1, 1]的张量。
2. 然后,我们使用这个张量、以及channel_idx、row_idx、col_idx作为索引,从d中提取元素。
根据广播规则,张量的维度会被扩展以匹配其他张量的维度,以便进行索引操作。
3. 提取的元素被放置在一个新的张量e中。
现在让我们逐步计算e的尺寸:
原始张量d服务器托管网的尺寸为[8, 2, 256, 256]。
生成的索引张量的尺寸为[8, 1, 1]。
其他索引张量channel_idx、row_idx、col_idx的尺寸为[8, 256, 256]。
因为我们使用了四个索引张量,每个张量的维度都会被广播扩展成匹配最大的张量维度,也就是[8, 2, 256, 256]。所以:
第一个索引张量被扩展为[8, 2, 256, 256]。
channel_idx、row_idx、col_idx各自被扩展为[8, 2, 256, 256]。
最终,从d中提取的元素将形成一个与索引张量相同的形状,即[8, 2, 256, 256]。
因此,最终张量e的尺寸将是[8, 256, 256],这是由于在索引操作时,第一个索引张量(形状为[8, 1, 1])的维度被消除了,而提取的元素会按照剩余索引张量的形状进行排列。
理解广播特性旨在理解维度的扩张和索引的概念,与高中所学索引排序有许多相似之处。
以此为记。
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
1.打一个Development包 2.打开cmd,CD到sdk的Platform-tools下 3.连接安卓设备 ①连接安卓手机,需要开启开发者模式, 不同的手机开启方式有所不同,比如华为的手机需要在:设置-关于手机-版本号,连续点击7次版本号 ②打开USB…