前言
本文提出了用於低分辨率圖像分割的MaskAttn - UNet框架,並將其核心的掩碼注意力機制集成到YOLOv11中。傳統U - Net類模型難以捕捉全局關聯,Transformer類模型計算量大,而掩碼注意力機制通過可學習的掩碼,讓模型選擇性關注重要區域,融合了卷積的局部效率和注意力的全局視野。其工作流程包括特徵適配、掩碼生成、定向注意力計算和特徵融合。我們將掩碼注意力機制代碼集成到YOLOv11中。
介紹
摘要
低分辨率圖像分割在機器人技術、增強現實和大規模場景理解等實際應用中至關重要。在這些場景中,由於計算資源限制,高分辨率數據往往難以獲取。為解決這一挑戰,我們提出了一種新穎的分割框架MaskAttn-UNet,它通過掩碼注意力機制對傳統UNet架構進行了優化。該模型能夠選擇性地突出重要區域,同時抑制無關背景,從而在雜亂複雜場景中提升分割精度。與傳統UNet變體不同,MaskAttn-UNet有效平衡了局部特徵提取與全局上下文感知能力,使其特別適用於低分辨率輸入場景。我們在三個基準數據集上對該方法進行了評估,所有輸入圖像均調整為128×128分辨率,結果表明其在語義分割、實例分割和全景分割任務中均展現出具有競爭力的性能。實驗結果顯示,MaskAttn-UNet的精度可與當前最先進方法媲美,且計算成本遠低於基於Transformer的模型,為資源受限場景下的低分辨率圖像分割提供了高效且可擴展的解決方案。
基本原理
掩碼注意力模塊是MaskAttn-UNet模型的核心創新組件,核心目標是在低分辨率圖像分割場景中,高效平衡“局部細節捕捉”與“全局關聯建模”,同時避免傳統注意力機制的算力浪費,其原理可從核心設計邏輯、工作流程、關鍵特性三方面展開:
一、核心設計邏輯
該模塊的核心思路是“選擇性關注”——不像傳統自注意力機制那樣對圖像中所有像素進行無差別全局計算,也不像純卷積那樣侷限於局部區域,而是通過一個“可學習的掩碼”(類似智能篩選器),讓模型自動聚焦於對分割任務有用的區域(如物體輪廓、關鍵結構、前景目標),同時抑制無意義的背景噪音或冗餘信息。
其設計初衷是解決兩大痛點:
- 傳統U-Net類模型:依賴卷積的局部性,難以捕捉圖像中遠距離物體的關聯(如重疊物體、分散目標的整體特徵),導致複雜場景分割模糊;
- Transformer類模型:全局自注意力計算量大(像素間兩兩匹配),內存和算力消耗極高,不適合低分辨率、資源受限的實際場景。
因此,掩碼注意力模塊本質是“卷積的局部效率”與“注意力的全局視野”的融合——用掩碼篩選關鍵區域,只在有用區域內進行注意力計算,實現“精準且高效”的特徵提取。
二、完整工作流程
模塊的工作過程可拆解為4個關鍵步驟,全程圍繞“篩選-計算-融合-優化”展開:
- 特徵格式適配:先接收來自編碼器或解碼器的特徵圖(包含圖像的局部細節和初步語義信息),並調整其格式,使其適配後續注意力計算的需求;
- 掩碼生成與篩選:自動學習一個二進制掩碼(可理解為一張“關注地圖”),地圖上的“高亮區域”對應圖像中需要重點關注的部分(如物體邊緣、前景目標),“暗區”對應無關背景。這個掩碼是動態學習的,會根據不同圖像、不同場景自適應調整,而非固定規則;
- 定向注意力計算:採用多頭注意力機制(共4個注意力頭,相當於從多個角度捕捉特徵),但僅在掩碼篩選後的“高亮區域”內計算像素間的關聯。比如,對於低分辨率圖像中的小物體,掩碼會聚焦於該物體的像素範圍,讓這些像素間相互傳遞信息,從而強化物體的整體特徵,同時忽略背景像素的無效關聯;
- 特徵融合與優化:將注意力計算後的特徵,與原始輸入的特徵通過“殘差連接”融合(保留初始的局部細節),再經過兩層前饋網絡進一步優化特徵質量,最終輸出“既包含局部精準細節,又融入全局關鍵關聯”的增強特徵。
- 魯棒性強:掩碼能有效抑制背景噪音,在複雜場景(如 clutter 雜亂環境、重疊物體、光線變化)中,仍能精準區分前景目標與背景,提升分割的穩定性。
核心代碼
class Mask2FormerAttention(nn.Module):
def __init__(self, channels, size):
super(Mask2FormerAttention, self).__init__()
self.channels = channels
self.size = size
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.mask = None
self.norm = nn.LayerNorm([channels])
def forward(self, x):
batch_size, channels, height, width = x.size()
if channels != self.channels:
raise ValueError("Input channel size does not match initialized channel size.")
x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
Q = self.query(x)
K = self.key(x)
V = self.value(x)
scores = torch.matmul(Q, K.transpose(-2, -1))
scores = scores / (self.channels ** 0.5)
if self.mask is None or self.mask.size(-1) != height * width:
binary_mask = torch.randint(0, 2, (batch_size, height, width), device=x.device)
binary_mask = binary_mask.view(batch_size, -1)
processed_mask = torch.where(binary_mask > 0.5, torch.tensor(0.0, device=x.device), torch.tensor(-float('inf'), device=x.device))
self.mask = processed_mask.unsqueeze(1).expand(-1, height * width, -1)
scores = scores + self.mask
attention_weights = F.softmax(scores, dim=-1)
attention_output = torch.matmul(attention_weights, V)
attention_output = attention_output + x
attention_output = self.norm(attention_output)
return attention_output.view(batch_size, channels, height, width)