AA-CrossViT—— 模型搭建

本文模型采用基于轴向分块和交叉注意力融合策略的双分支视觉 Transformer 作为锂离子电池早期寿命预测模型。具体而言,分别基于图形特征的电压轴向和周期轴向进行分块,而后利用卷积层将各分块编码到向量空间,随后添加反映分块信息和位置的可学习参数矩阵分类 token xclsx_{cls} 和位置编码矩阵 XpeX_{pe} 。Transformer 编码层用以提取特征和捕捉分块间的依赖关系,随后提取双流分支的分类 token 和分块 token 输出作为交叉注意力机制的输入,以融合双轴向分支的信息。

导入相关依赖库。

Related Module
import torch
import torch.nn as nn
from torch import Tensor
from einops.layers.torch import Rearrange, Reduce
from einops import repeat, rearrange
from torchsummary import summary
  • Rearrange/Reducerearrange/reduce 的区别:

前者是网络层,后者是数据处理函数。

# Input Size

本文输入为锂离子电池的图形特征输入,尺寸为3×100×1003\times 100\times 100

Input
x = torch.randn(3,100,100)   # 单样本尺寸为 3*100*100
print(x.shape)   # (3,100,100)
x_batch = torch.randn(16,3,100,100)  # 单个 batch 尺寸为 16*3*100*100
print(x_batch.shape)

# Patch Embedding

视觉 Transformer 模型的第一步需要将图片划分为多个分块(Patches),并且将其映射到向量。具体而言,处于效率考虑,先利用卷积层将每个分块映射到向量空间维度dkd_k,卷积核的大小与步长均为 patch 的尺寸,卷积核个数等于编码维度dkd_k,而后利用 Rearrange 函数改变维度顺序。

in_channels, patch_size = 3, (1,100)
projection = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=emb_size, kernel_size=patch_size, stride=(1,1)),
            Rearrange("b e h w -> b (h w) e"),
        )

Rearrange/Reduce 参数解析:输入参数格式为字符串

Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, input_size:int=100, in_channels:int=3, patch_size:tuple=(1,100), emb_size:int=100):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=emb_size, kernel_size=patch_size, stride=(1,1)),
            Rearrange("b e h w -> b (h w) e"),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1,emb_size)/emb_size)
        self.positions = nn.Parameter(torch.randn((input_size**2 // (patch_size[0]*patch_size[1]) + 1,emb_size)) / emb_size)
    def forward(self,x:Tensor):
        b, c, h, w = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token,"() n e -> b n e",b=b)
        x = torch.cat([cls_tokens, x], dim=1)
        x += self.positions    # 这里区别于 class token 直接用广播机制相加,原因是位置编码应该在 batch 内的不同样本也保持一致
        # print(x.shape)
        return x

# Class Token

# Positional Encoding

# Transformer Encoder

#

更新于 阅读次数