掩码生成 (Mask Generation)¶
掩码生成是指为图像生成具有语义意义的掩码的任务。这项任务与图像分割非常相似,但有许多不同之处。图像分割模型是在标注数据集上训练的,因此它们仅限于在训练过程中见过的类别;给定一张图像,它们会返回一组掩码及其对应的类别。
掩码生成模型则是在大量数据上训练的,并且以两种模式运行:
- 提示模式:在这种模式下,模型接收一张图像和一个提示,提示可以是图像中某个对象内的二维点位置(XY 坐标)或围绕对象的边界框。在提示模式下,模型只返回提示指向的对象的掩码。
- 分割一切模式:在分割一切模式下,给定一张图像,模型会生成图像中的每个掩码。为此,会在图像上生成一个点网格并进行推理。
掩码生成任务由 Segment Anything Model (SAM) 支持。这是一个强大的模型,包括基于视觉 Transformers 的图像编码器、提示编码器和双向 Transformers 掩码解码器。图像和提示被编码,解码器接收这些嵌入并生成有效的掩码。

SAM 是一个强大的基础模型,因为它覆盖了大量数据。它是在 SA-1B 数据集上训练的,该数据集包含 100 万张图像和 11 亿个掩码。
在本指南中,您将学习如何:
- 使用批量处理进行分割一切模式的推理,
- 进行点提示模式的推理,
- 进行框提示模式的推理。
首先,让我们安装 transformers:
pip install -q transformers
掩码生成管道 (Mask Generation Pipeline)¶
使用 mask-generation 管道是最简单的进行掩码生成模型推理的方法。
from transformers import pipeline
checkpoint = "facebook/sam-vit-base"
mask_generator = pipeline(model=checkpoint, task="mask-generation")
让我们看一下图像。
from PIL import Image
import requests
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

让我们进行分割一切。points-per-batch 选项启用在分割一切模式下的并行推理,这可以加快推理解析速度,但会消耗更多内存。此外,SAM 只能在点上进行批处理,而不能在图像上进行批处理。pred_iou_thresh 是置信度阈值,只有高于该阈值的掩码才会被返回。
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
masks 的内容如下:
{
'masks': [
array([
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]
]),
array([
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...
),
'scores': tensor([
0.9972, 0.9917,
...,
])
}
我们可以这样可视化它们:
import matplotlib.pyplot as plt
plt.imshow(image, cmap='gray')
for i, mask in enumerate(masks["masks"]):
plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)
plt.axis('off')
plt.show()
from transformers import SamModel, SamProcessor
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
要进行点提示,将输入点传递给处理器,然后将处理器的输出传递给模型进行推理。为了处理模型的输出,需要将处理器的初始输出中的 original_sizes 和 reshaped_input_sizes 传递进来,因为处理器会调整图像大小,输出需要进行外推。
input_points = [[[2592, 1728]]] # 蜜蜂的位置
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
我们可以可视化 masks 输出中的三个掩码。
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(1, 4, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title('原始图像')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]
for i, mask in enumerate(mask_list, start=1):
overlayed_image = np.array(image).copy()
overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
axes[i].imshow(overlayed_image)
axes[i].set_title(f'掩码 {i}')
for ax in axes:
ax.axis('off')
plt.show()

框提示 (Box Prompting)¶
您也可以像点提示一样进行框提示。只需将输入框以 [x_min, y_min, x_max, y_max] 格式与图像一起传递给 processor。将处理器的输出直接传递给模型,然后再次处理输出。
# 围绕蜜蜂的边界框
box = [2350, 1600, 2850, 2100]
inputs = processor(
image,
input_boxes=[[[box]]],
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
mask = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
您可以像下面这样可视化蜜蜂周围的边界框。
import matplotlib.patches as patches
fig, ax = plt.subplots()
ax.imshow(image)
rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()

您可以查看以下推理输出。
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)
ax.axis("off")
plt.show()

