
恒源云 (GpuShare) _ Medical Image Segmentation: MT-UNet

Author: 恒源云 (GpuShare)
  • March 9, 2022



We have a new technical-sharing member in our community 🎉🎉🎉

A warm welcome 👏

As a dutiful content porter, I have to do something to show my excitement: repost, repost, repost right away~


Source | 恒源云 (GpuShare) community


Original post | A New Mixed Transformer Module (MTM)


Original author | 咚咚



Abstract

Method



As shown in Figure 1, the network is based on an encoder-decoder structure:


  1. To reduce the computational cost, MTMs are only used in the deeper layers, where the spatial size of the feature maps is small.

  2. The shallow layers still use classic convolution operations, because shallow layers mainly focus on local information and contain more high-resolution detail (a minimal sketch of this hybrid layout follows the list).
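
As a rough illustration of this split (not the authors' exact network definition: the stage count, channel widths, strides, and the mtm_block argument below are placeholders), the encoder can be sketched as a few convolutional stages followed by MTM stages on the smallest feature map:

import torch
import torch.nn as nn

class HybridEncoder(nn.Module):
    # Shallow stages: cheap convolutions on high-resolution maps; deep stages: MTM blocks (sketch only).
    def __init__(self, mtm_block, chs=(32, 64, 128, 256)):
        super().__init__()
        in_chs = (3,) + chs[:-1]
        self.conv_stages = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                          nn.BatchNorm2d(c_out),
                          nn.ReLU(inplace=True))
            for c_in, c_out in zip(in_chs, chs)
        ])
        # MTMs only see the downsampled map, so their attention cost stays manageable.
        self.mtm_stages = nn.ModuleList([mtm_block(chs[-1]) for _ in range(2)])

    def forward(self, x):
        skips = []
        for conv in self.conv_stages:            # classic convolutions in the shallow layers
            x = conv(x)
            skips.append(x)                      # kept for the decoder's skip connections
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)    # (b, h*w, c): token layout expected by the MTM code
        for mtm in self.mtm_stages:              # MTMs in the deep, low-resolution layers
            tokens = mtm(tokens)
        x = tokens.transpose(1, 2).view(b, c, h, w)
        return x, skips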

MTM

As shown in Figure 2, the MTM consists mainly of LGG-SA and EA.


LGG-SA is used to model short- and long-range dependencies at different granularities, while EA is used to mine correlations between samples.


This module is designed to replace the original Transformer encoder, improving its performance on vision tasks while reducing its time complexity.
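
Figure 2 shows how the two parts are wired together inside one MTM. As a minimal sketch (the pre-norm/residual layout here is an assumption, and CSAttention / MEAttention refer to the LGG-SA and EA classes listed later in this post):

import torch.nn as nn

class MTMBlock(nn.Module):
    # One Mixed Transformer Module: LGG-SA then EA, each wrapped in norm + residual (sketch).
    def __init__(self, dim, configs):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.lgg_sa = CSAttention(dim, configs)   # intra-sample: local-global Gaussian-weighted SA
        self.norm2 = nn.LayerNorm(dim)
        self.ea = MEAttention(dim, configs)       # inter-sample: external attention with shared memories

    def forward(self, x):                         # x: (b, n, c) tokens
        x = x + self.lgg_sa(self.norm1(x))        # short- and long-range dependencies within a sample
        x = x + self.ea(self.norm2(x))            # correlations across samples
        return x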

LGG-SA (Local-Global Gaussian-Weighted Self-Attention)

Unlike a traditional SA module, which gives the same attention to all tokens, LGG-SA uses local-global self-attention and a Gaussian mask so that it can focus more on neighboring regions. Experiments show that this improves model performance while saving computational resources. The detailed design of the module is shown in Figure 3.



Local-global self-attention


In computer vision, correlations between neighboring regions are usually more important than those between distant regions, so there is no need to spend the same cost on far-away regions when computing the attention map.


Therefore, local-global self-attention is proposed:


  1. In stage 1 of the figure above, each local window contains four tokens, and local SA computes the intrinsic affinities within each window.

  2. The tokens in each window are then aggregated into a single global token that represents the window's main information. For the aggregation function, Lightweight Dynamic Convolution (LDConv) performs best.

  3. Once the downsampled feature map is obtained, global SA can be performed at a much lower cost (stage 2 in the figure above).





The local window self-attention code for stage 1 is as follows:


import math
import numpy as np
import torch
import torch.nn as nn


class WinAttention(nn.Module):
    def __init__(self, configs, dim):
        super(WinAttention, self).__init__()
        self.window_size = configs["win_size"]
        self.attention = Attention(dim, configs)

    def forward(self, x):
        b, n, c = x.shape
        h, w = int(np.sqrt(n)), int(np.sqrt(n))
        x = x.permute(0, 2, 1).contiguous().view(b, c, h, w)
        if h % self.window_size != 0:
            # pad the feature map so that it divides into whole windows
            right_size = h + self.window_size - h % self.window_size
            new_x = torch.zeros((b, c, right_size, right_size),
                                device=x.device, dtype=x.dtype)
            new_x[:, :, 0:x.shape[2], 0:x.shape[3]] = x[:]
            new_x[:, :, x.shape[2]:, x.shape[3]:] = x[:, :, (x.shape[2] - right_size):,
                                                      (x.shape[3] - right_size):]
            x = new_x
            b, c, h, w = x.shape
        # partition into (window_size x window_size) local windows
        x = x.view(b, c, h // self.window_size, self.window_size,
                   w // self.window_size, self.window_size)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(
            b, h // self.window_size, w // self.window_size,
            self.window_size * self.window_size, c).cuda()
        x = self.attention(x)  # (b, p, p, win, c): self-attention over the tokens inside each local window
        return x


The code for the aggregation function is as follows:


class DlightConv(nn.Module):
    def __init__(self, dim, configs):
        super(DlightConv, self).__init__()
        self.linear = nn.Linear(dim, configs["win_size"] * configs["win_size"])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):  # (b, p, p, win, c)
        h = x
        avg_x = torch.mean(x, dim=-2)  # (b, p, p, c)
        x_prob = self.softmax(self.linear(avg_x))  # (b, p, p, win)

        x = torch.mul(h, x_prob.unsqueeze(-1))  # (b, p, p, win, c)
        x = torch.sum(x, dim=-2)  # (b, p, p, c)
        return x
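
A quick shape check of steps 1 and 2 above (the config values are made up for this example; the Attention class used inside WinAttention is defined later in this post, and the code calls .cuda(), so a GPU is assumed):

configs = {"win_size": 4, "head": 8}                     # illustrative settings
dim = 64
x = torch.randn(2, 28 * 28, dim).cuda()                  # (b, n, c): tokens on a 28x28 grid
local = WinAttention(configs, dim).cuda()(x)             # (2, 7, 7, 16, 64): local SA inside each 4x4 window
global_tokens = DlightConv(dim, configs).cuda()(local)   # (2, 7, 7, 64): one aggregated token per window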


Gaussian-Weighted Axial Attention


Unlike the LSA, which uses the original SA, Gaussian-Weighted Axial Attention (GWAA) is proposed here. GWAA enhances the weight given to neighboring regions through a learnable Gaussian matrix, while reducing time complexity thanks to its axial attention.


  1. The feature at the third row and third column of the stage-2 feature map in the figure above is linearly projected to obtain the query.

  2. All features in that row and in that column are each linearly projected to obtain the keys and values.

  3. The Euclidean distance between this feature point and each key/value position is defined as the distance term used in the Gaussian weighting.


The final Gaussian-weighted axial-attention output combines these axial attention scores with the Gaussian distance weighting; after simplification, the weighting amounts to adding a learnable function of the squared distance to the attention logits before the softmax.
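
A reconstruction of the simplified form, inferred from the GaussianTrans implementation below rather than copied from the paper: for the query at position (i, j),

$$
z_{i,j} = \mathrm{softmax}\big(q_{i,j} K_{i,:}^{\top} - (\alpha D^{\text{row}}_{i,j} + \beta)\big)\, V_{i,:}
        + \mathrm{softmax}\big(q_{i,j} K_{:,j}^{\top} - (\alpha D^{\text{col}}_{i,j} + \beta)\big)\, V_{:,j},
$$

where $D^{\text{row}}_{i,j}[k] = (k - j)^2$ and $D^{\text{col}}_{i,j}[k] = (k - i)^2$ are the squared distances along the row and column axes, and $\alpha$ (shift) and $\beta$ (bias) are learnable scalars.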



The axial attention code is as follows:


class Attention(nn.Module):
    def __init__(self, dim, configs, axial=False):
        super(Attention, self).__init__()
        self.axial = axial
        self.dim = dim
        self.num_head = configs["head"]
        self.attention_head_size = int(self.dim / configs["head"])
        self.all_head_size = self.num_head * self.attention_head_size

        self.query_layer = nn.Linear(self.dim, self.all_head_size)
        self.key_layer = nn.Linear(self.dim, self.all_head_size)
        self.value_layer = nn.Linear(self.dim, self.all_head_size)

        self.out = nn.Linear(self.dim, self.dim)
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_head, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x

    def forward(self, x):
        # first row and col attention
        if self.axial:
            # x: (b, p, p, c)
            # row attention (single head attention)
            b, h, w, c = x.shape
            mixed_query_layer = self.query_layer(x)
            mixed_key_layer = self.key_layer(x)
            mixed_value_layer = self.value_layer(x)

            query_layer_x = mixed_query_layer.view(b * h, w, -1)
            key_layer_x = mixed_key_layer.view(b * h, w, -1).transpose(-1, -2)  # (b*h, -1, w)
            attention_scores_x = torch.matmul(query_layer_x, key_layer_x)  # (b*h, w, w)
            attention_scores_x = attention_scores_x.view(b, -1, w, w)  # (b, h, w, w)

            # col attention (single head attention)
            query_layer_y = mixed_query_layer.permute(0, 2, 1, 3).contiguous().view(b * w, h, -1)
            key_layer_y = mixed_key_layer.permute(
                0, 2, 1, 3).contiguous().view(b * w, h, -1).transpose(-1, -2)  # (b*w, -1, h)
            attention_scores_y = torch.matmul(query_layer_y, key_layer_y)  # (b*w, h, h)
            attention_scores_y = attention_scores_y.view(b, -1, h, h)  # (b, w, h, h)

            return attention_scores_x, attention_scores_y, mixed_value_layer

        else:
            mixed_query_layer = self.query_layer(x)
            mixed_key_layer = self.key_layer(x)
            mixed_value_layer = self.value_layer(x)

            query_layer = self.transpose_for_scores(mixed_query_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()  # (b, p, p, head, n, c)
            key_layer = self.transpose_for_scores(mixed_key_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()
            value_layer = self.transpose_for_scores(mixed_value_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()

            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            atten_probs = self.softmax(attention_scores)

            context_layer = torch.matmul(atten_probs, value_layer)  # (b, p, p, head, win, h)
            context_layer = context_layer.permute(0, 1, 2, 4, 3, 5).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
            context_layer = context_layer.view(*new_context_layer_shape)
            attention_output = self.out(context_layer)

            return attention_output


The Gaussian weighting code is as follows:


class GaussianTrans(nn.Module):
    def __init__(self):
        super(GaussianTrans, self).__init__()
        self.bias = nn.Parameter(-torch.abs(torch.randn(1)))
        self.shift = nn.Parameter(torch.abs(torch.randn(1)))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x (b, h, w, c), atten_x_full (b, h, w, w), atten_y_full (b, w, h, h), value_full (b, h, w, c)
        x, atten_x_full, atten_y_full, value_full = x
        new_value_full = torch.zeros_like(value_full)

        for r in range(x.shape[1]):  # row
            for c in range(x.shape[2]):  # col
                atten_x = atten_x_full[:, r, c, :]  # (b, w)
                atten_y = atten_y_full[:, c, r, :]  # (b, h)

                dis_x = torch.tensor([(h - c) ** 2 for h in range(x.shape[2])]).cuda()  # (w,), broadcast against (b, w)
                dis_y = torch.tensor([(w - r) ** 2 for w in range(x.shape[1])]).cuda()  # (h,), broadcast against (b, h)

                dis_x = -(self.shift * dis_x + self.bias).cuda()
                dis_y = -(self.shift * dis_y + self.bias).cuda()

                atten_x = self.softmax(dis_x + atten_x)  # Gaussian distance term added to the attention logits
                atten_y = self.softmax(dis_y + atten_y)

                new_value_full[:, r, c, :] = torch.sum(
                    atten_x.unsqueeze(dim=-1) * value_full[:, r, :, :] +
                    atten_y.unsqueeze(dim=-1) * value_full[:, :, c, :],
                    dim=-2)
        return new_value_full


The complete local-global self-attention code is as follows:


class CSAttention(nn.Module):
    def __init__(self, dim, configs):
        super(CSAttention, self).__init__()
        self.win_atten = WinAttention(configs, dim)
        self.dlightconv = DlightConv(dim, configs)
        self.global_atten = Attention(dim, configs, axial=True)
        self.gaussiantrans = GaussianTrans()
        # self.conv = nn.Conv2d(dim, dim, 3, padding=1)
        # self.maxpool = nn.MaxPool2d(2)
        self.up = nn.UpsamplingBilinear2d(scale_factor=4)
        self.queeze = nn.Conv2d(2 * dim, dim, 1)

    def forward(self, x):
        '''
        :param x: size(b, n, c)
        :return:
        '''
        origin_size = x.shape
        _, origin_h, origin_w, _ = origin_size[0], int(np.sqrt(
            origin_size[1])), int(np.sqrt(origin_size[1])), origin_size[2]
        x = self.win_atten(x)  # (b, p, p, win, c)
        b, p, p, win, c = x.shape
        h = x.view(b, p, p, int(np.sqrt(win)), int(np.sqrt(win)),
                   c).permute(0, 1, 3, 2, 4, 5).contiguous()
        h = h.view(b, p * int(np.sqrt(win)), p * int(np.sqrt(win)),
                   c).permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)

        x = self.dlightconv(x)  # (b, p, p, c)
        atten_x, atten_y, mixed_value = self.global_atten(
            x)  # (b, h, w, w) (b, w, h, h) (b, h, w, c); here h and w are both p
        gaussian_input = (x, atten_x, atten_y, mixed_value)
        x = self.gaussiantrans(gaussian_input)  # (b, h, w, c)
        x = x.permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)

        x = self.up(x)
        x = self.queeze(torch.cat((x, h), dim=1)).permute(0, 2, 3, 1).contiguous()
        x = x[:, :origin_h, :origin_w, :].contiguous()
        x = x.view(b, -1, c)

        return x
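
Putting the pieces together, a quick end-to-end check of LGG-SA (the config values are illustrative, and a CUDA device is assumed because the code above moves tensors to the GPU internally):

configs = {"win_size": 4, "head": 8}           # illustrative settings, not the paper's exact config
lgg_sa = CSAttention(dim=64, configs=configs).cuda()
tokens = torch.randn(2, 28 * 28, 64).cuda()    # (b, n, c) with n a perfect square that splits into 4x4 windows
out = lgg_sa(tokens)                           # (2, 784, 64): same (b, n, c) layout as the input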

EA

External Attention (EA) is used to address the problem that SA cannot exploit relationships between different input data samples.


Unlike self-attention, which computes attention scores from each sample's own linear transformations, EA shares two memory units M_K and M_V across all data samples (as shown in Figure 2); these memories describe the most important information of the whole dataset.
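
Concretely (matching the double normalization used in the MEAttention code below), the attention for an input token matrix $F \in \mathbb{R}^{N \times d}$ is computed against the shared memories rather than against $F$ itself:

$$
A = \mathrm{Norm}\big(\mathrm{softmax}(F M_K^{\top})\big), \qquad F_{\text{out}} = A\, M_V,
$$

where the softmax is taken over the token dimension, Norm is an $\ell_1$ normalization over the memory dimension, and $M_K$, $M_V$ correspond to the linear_0 and linear_1 layers in the code.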


The EA code is as follows:


class MEAttention(nn.Module):
    def __init__(self, dim, configs):
        super(MEAttention, self).__init__()
        self.num_heads = configs["head"]
        self.coef = 4
        self.query_liner = nn.Linear(dim, dim * self.coef)
        self.num_heads = self.coef * self.num_heads
        self.k = 256 // self.coef
        self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)
        self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)

        self.proj = nn.Linear(dim * self.coef, dim)

    def forward(self, x):
        B, N, C = x.shape
        x = self.query_liner(x)  # (b, n, 4c)
        x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # (b, h, n, 4c/h)

        attn = self.linear_0(x)  # (b, h, n, 256/4)

        attn = attn.softmax(dim=-2)  # (b, h, n, 256/4): softmax over the token dimension
        attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))  # (b, h, n, 256/4): l1-normalize over the memory dimension

        x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)

        x = self.proj(x)

        return x
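
A minimal check of EA on its own (the head count is illustrative; dim * coef must be divisible by the resulting number of heads for the reshape to work):

configs = {"head": 8}              # illustrative; with coef = 4 this gives 32 heads over a 256-dim projection
ea = MEAttention(dim=64, configs=configs)
tokens = torch.randn(2, 49, 64)    # (b, n, c), e.g. 7x7 tokens from the deepest stage
out = ea(tokens)                   # (2, 49, 64): per-token output computed against the shared memory units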

EXPERIMENTS




