如何修改 Transformer 模型¶
🤗 Transformers 库提供了一系列预训练模型和工具,用于自然语言处理、计算机视觉等多领域的应用。虽然这些模型覆盖了广泛的应用场景,但你可能会遇到一些不支持的用例。自定义模型可以解锁新的可能性,例如添加新层、修改架构或优化注意力机制。本指南将向你展示如何修改现有的Transformer模型以满足特定需求。好消息是,你不必离开Transformers框架就能进行这些修改。你可以直接在Transformers中修改模型,并仍然利用Trainer API、PreTrainedModel和高效的微调工具如PEFT等功能。
在本指南中,我们将逐步介绍如何自定义现有的Transformer模型,以满足你的需求,同时不失去生态系统的优势。
你将学习如何:
- 通过修改注意力机制来调整模型架构。
- 将低秩适应(LoRA)等技术应用于特定模型组件。
我们鼓励你贡献自己的“修改”方法,并与社区分享。
示例:修改Segment Anything Model (SAM)的注意力机制¶
Segment Anything Model (SAM) 是一个最先进的图像分割模型。在默认实现中,SAM在注意力机制中使用了一个组合的查询-键-值(qkv)投影。然而,你可能希望仅对注意力机制的特定组件进行微调,例如查询(q)和值(v)投影,以减少可训练参数的数量和计算资源的需求。
动机¶
通过将组合的qkv投影拆分为单独的q、k和v投影,你可以将LoRA(低秩适应)等技术应用于q和v投影。这种方法可以让你:
- 减少可训练参数的数量,降低计算开销。
- 通过专注于特定组件,潜在地提高性能。
- 在注意力机制中实验不同的适应策略。
实现¶
第一步:创建自定义注意力类¶
接下来,派生原始的SamVisionAttention类并修改它,使其具有单独的q、k和v投影。
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
def __init__(self, config, window_size):
super().__init__(config, window_size)
del self.qkv
# 创建单独的 q, k, v 投影
self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
def split_q_k_v_load_hook(self, state_dict, prefix, *args):
keys_to_delete = []
for key in list(state_dict.keys()):
if "qkv." in key:
# 将 q, k, v 从组合投影中分离
q, k, v = state_dict[key].chunk(3, dim=0)
# 用单独的 q, k, v 投影替换
state_dict[key.replace("qkv.", "q.")] = q
state_dict[key.replace("qkv.", "k.")] = k
state_dict[key.replace("qkv.", "v.")] = v
# 标记旧的 qkv 键以便删除
keys_to_delete.append(key)
# 删除旧的 qkv 键
for key in keys_to_delete:
del state_dict[key]
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
query = self.q(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
key = self.k(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
value = self.v(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
解释:
- 单独投影: 移除了组合的
qkv投影,创建了单独的q、k和v线性层。 - 权重加载钩子:
_split_qkv_load_hook方法在加载模型时将预训练的qkv权重拆分为单独的q、k和v权重,确保与任何预训练模型的兼容性。 - 前向传播: 分别计算查询、键和值,注意力机制按常规方式进行。
第二步:替换原始的注意力类¶
将原始的SamVisionAttention类替换为自定义类,以便模型使用修改后的注意力机制。
from transformers import SamModel
from transformers.models.sam import modeling_sam
# 替换 modeling_sam 模块中的注意力类
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
# 加载预训练的 SAM 模型
model = SamModel.from_pretrained("facebook/sam-vit-base")
解释:
- 类替换: 通过将自定义类赋值给
modeling_sam.SamVisionAttention,模型中的所有SamVisionAttention实例都将使用修改后的版本。因此,当你调用SamModel时,它将使用新定义的SamVisionAttentionSplit。 - 模型加载: 使用
from_pretrained加载模型,并集成自定义的注意力机制。
第三步:将LoRA应用于特定投影¶
现在有了单独的q、k和v投影,你可以将LoRA应用于特定组件,例如q和v投影。
from peft import LoraConfig, get_peft_model
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q", "v"], # 将 LoRA 应用于 q 和 v 投影
lora_dropout=0.1,
task_type="mask-generation"
)
# 将 LoRA 应用于模型
model = get_peft_model(model, config)
解释:
- LoRA配置:
LoraConfig指定了秩r、缩放因子lora_alpha、目标模块("q"和"v")、dropout和任务类型。 - 应用LoRA:
get_peft_model函数将LoRA应用于模型中指定的模块。 - 参数减少: 通过专注于
q和v,减少了可训练参数的数量,从而加快训练速度并降低内存使用量。
第四步:验证可训练参数的数量¶
验证可训练参数的数量,看看你的修改带来了什么影响。
model.print_trainable_parameters()
预期输出:
贡献你自己的“修改”方法¶
修改预训练模型可以为研究和应用开辟新的途径。通过理解和调整像SAM这样的模型的内部机制,你可以根据需要定制它们,优化性能并尝试新的想法。
如果你已经为Transformers模型开发了自己的“修改”方法,并希望分享它们,请考虑为本文档贡献代码。
- 发起Pull Request: 直接在仓库中分享你的代码更改和改进。
- 编写文档: 提供清晰的解释和示例,说明你的修改。
- 参与社区: 通过打开问题与开发者和研究人员讨论你的想法并获得反馈。