AA-CrossViT—— 模型搭建
本文模型采用基于轴向分块和交叉注意力融合策略的双分支视觉 Transformer 作为锂离子电池早期寿命预测模型。具体而言,分别基于图形特征的电压轴向和周期轴向进行分块,而后利用卷积层将各分块编码到向量空间,随后添加反映分块信息和位置的可学习参数矩阵分类 token 和位置编码矩阵 。Transformer 编码层用以提取特征和捕捉分块间的依赖关系,随后提取双流分支的分类 token 和分块 token 输出作为交叉注意力机制的输入,以融合双轴向分支的信息。
# Related Module
导入相关依赖库。
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from einops.layers.torch import Rearrange, Reduce | |
from einops import repeat, rearrange | |
from torchsummary import summary |
# Input Size
本文输入为锂离子电池的图形特征输入,尺寸为。
x = torch.randn(3,100,100) # 单样本尺寸为 3*100*100 | |
print(x.shape) # (3,100,100) | |
x_batch = torch.randn(16,3,100,100) # 单个 batch 尺寸为 16*3*100*100 | |
print(x_batch.shape) |
# Patch Embedding
视觉 Transformer 模型的第一步需要将图片划分为多个分块(Patches),并且将其映射到向量。具体而言,处于效率考虑,先利用卷积层将每个分块映射到向量空间维度,卷积核的大小与步长均为 patch 的尺寸,卷积核个数等于编码维度,而后利用 Rearrange 函数改变维度顺序。
in_channels, patch_size = 3, (1,100) | |
projection = nn.Sequential( | |
nn.Conv2d(in_channels=in_channels, out_channels=emb_size, kernel_size=patch_size, stride=(1,1)), | |
Rearrange("b e h w -> b (h w) e"), | |
) |
class PatchEmbedding(nn.Module): | |
def __init__(self, input_size:int=100, in_channels:int=3, patch_size:tuple=(1,100), emb_size:int=100): | |
super().__init__() | |
self.projection = nn.Sequential( | |
nn.Conv2d(in_channels=in_channels, out_channels=emb_size, kernel_size=patch_size, stride=(1,1)), | |
Rearrange("b e h w -> b (h w) e"), | |
) | |
self.cls_token = nn.Parameter(torch.randn(1,1,emb_size)/emb_size) | |
self.positions = nn.Parameter(torch.randn((input_size**2 // (patch_size[0]*patch_size[1]) + 1,emb_size)) / emb_size) | |
def forward(self,x:Tensor): | |
b, c, h, w = x.shape | |
x = self.projection(x) | |
cls_tokens = repeat(self.cls_token,"() n e -> b n e",b=b) | |
x = torch.cat([cls_tokens, x], dim=1) | |
x += self.positions # 这里区别于 class token 直接用广播机制相加,原因是位置编码应该在 batch 内的不同样本也保持一致 | |
# print(x.shape) | |
return x |