
# Transformer模型解析:自然语言处理的架构演进与实践


## Transformer架构核心设计理念


Transformer模型摒弃了传统的循环神经网络和卷积神经网络架构,完全基于自注意力机制构建,实现了并行化训练和长距离依赖的有效捕捉。其设计思想围绕三个核心组件:多头自注意力机制、前馈神经网络和残差连接与层归一化。


### 注意力机制数学原理


自注意力机制的核心在于通过查询(Query)、键(Key)和值(Value)的交互计算:


```python

import torch

import torch.nn as nn

import torch.nn.functional as F

import math


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q·Kᵀ / sqrt(d_k)) · V.

    Mask interpretation:
      * bool tensor  -> positions where the mask is False are blocked
        (the original behavior, kept for compatibility);
      * float tensor -> an *additive* mask (0 = keep, large negative =
        block), which is what every caller in this file constructs
        (encoder padding mask ``(1 - m) * -10000``, decoder causal mask
        ``0 / -inf``).

    BUG FIX: the original applied ``masked_fill(mask == 0, -1e9)`` to the
    additive float masks as well, which blocked exactly the *allowed*
    positions (whose additive value is 0). Float masks are now added to
    the scores instead.
    """

    def __init__(self, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        """
        Q: [batch_size, n_heads, seq_len, d_k]
        K: [batch_size, n_heads, seq_len, d_k]
        V: [batch_size, n_heads, seq_len, d_v]
        mask: broadcastable to [batch_size, n_heads, seq_len, seq_len]

        Returns (context, attention_weights).
        """
        d_k = Q.size(-1)

        # Scale scores by sqrt(d_k) to keep softmax gradients healthy
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dtype == torch.bool:
                # Boolean mask: False marks blocked positions
                scores = scores.masked_fill(mask == 0, -1e9)
            else:
                # Additive float mask (0 keep / -1e4 or -inf block)
                scores = scores + mask

        # Normalize to attention weights, then apply dropout
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = torch.matmul(attention_weights, V)

        return context, attention_weights

```


## Transformer编码器实现


### 多头注意力机制


```python

class MultiHeadAttention(nn.Module):
    """Multi-head attention with residual connection and post-LayerNorm.

    Projects queries/keys/values into ``n_heads`` subspaces of size
    ``d_model // n_heads``, runs scaled dot-product attention per head,
    concatenates the heads and projects back to ``d_model``.
    """

    def __init__(self, d_model=512, n_heads=8, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads

        # Linear projections for queries, keys, values and the output
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Q, K, V: [batch_size, seq_len, d_model]
        mask: [batch_size, seq_len, seq_len], or an already
              head-broadcastable 4-D mask such as
              [batch_size, 1, seq_len, seq_len] / [1, 1, seq_len, seq_len]

        Returns (output, attention_weights).
        """
        batch_size, seq_len, _ = Q.size()

        # Keep the raw query input for the residual connection
        residual = Q

        # Project and split into heads: [B, H, L, d_k]
        Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        # BUG FIX: only insert the head dimension when the mask is 3-D.
        # The callers in this file (encoder padding mask, decoder causal
        # mask) already pass 4-D masks; unconditionally unsqueezing
        # produced a 5-D mask that broadcast incorrectly against the
        # [B, H, L, L] attention scores.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        context, attention_weights = self.attention(Q, K, V, mask=mask)

        # Merge the heads back: [B, L, d_model]
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model)

        # Output projection and dropout
        output = self.W_O(context)
        output = self.dropout(output)

        # Residual connection followed by post-LayerNorm
        output = self.layer_norm(output + residual)

        return output, attention_weights

```


### 前馈神经网络


```python

class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward sublayer.

    Expands each position to ``d_ff`` units, applies GELU (used here in
    place of the paper's ReLU), projects back to ``d_model``, and wraps
    the result in dropout + residual + post-LayerNorm.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        # Expansion and projection layers
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        # GELU instead of the original paper's ReLU
        self.activation = nn.GELU()

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        Returns a tensor of the same shape.
        """
        hidden = self.dropout(self.activation(self.linear1(x)))
        projected = self.dropout(self.linear2(hidden))
        # Residual connection followed by post-LayerNorm
        return self.layer_norm(projected + x)

```


### 位置编码


```python

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Even feature indices receive sin(pos / 10000^(2i/d_model)) and odd
    indices the matching cos. The table is precomputed for ``max_len``
    positions and registered as a non-trainable buffer.

    BUG FIX: the original stored the table sequence-first as
    [max_len, 1, d_model] but added it to batch-first inputs of shape
    [batch, seq, d_model]; broadcasting then failed (or silently
    misaligned) whenever batch_size != seq_len. The buffer is now
    [1, max_len, d_model] and sliced along the sequence dimension.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the encoding table
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           (-math.log(10000.0) / d_model))

        # Interleave sine (even dims) and cosine (odd dims)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model] — batch-first

        # Buffer: saved with the module and moved by .to(), never trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        Returns x with positional encodings added, after dropout.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class LearnedPositionalEncoding(nn.Module):
    """Trainable position embeddings added to the input, then dropout."""

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super(LearnedPositionalEncoding, self).__init__()
        # One learned vector per position index
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        Returns x plus a learned embedding for each position.
        """
        position_ids = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return self.dropout(x + self.position_embeddings(position_ids))

```


### 编码器层


```python

class EncoderLayer(nn.Module):
    """Single Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network;
    each sublayer applies its own residual connection and LayerNorm.
    """

    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        # Kept for interface/state-dict compatibility; not used in forward.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        x: [batch_size, seq_len, d_model]
        mask: attention mask broadcastable to the attention scores
        Returns (output, attention_weights).
        """
        attended, attention_weights = self.self_attention(x, x, x, mask)
        return self.feed_forward(attended), attention_weights

```


### 完整编码器


```python

class TransformerEncoder(nn.Module):
    """Transformer encoder stack.

    Token embedding + sinusoidal positional encoding, followed by a
    stack of EncoderLayer modules and a final LayerNorm.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6,
                 n_heads=8, d_ff=2048, max_len=512, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model

        # Token embedding table
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Fixed positional encoding (applies dropout internally)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Stack of identical encoder layers
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        )

        # Final output normalization
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, attention_mask=None):
        """
        input_ids: [batch_size, seq_len]
        attention_mask: [batch_size, seq_len]; 1 = attend, 0 = padding

        Returns (hidden_states, list of per-layer attention weights).
        """
        batch_size, seq_len = input_ids.size()

        # Embed and scale by sqrt(d_model) as in the original paper,
        # then add positional information
        scaled = self.token_embedding(input_ids) * math.sqrt(self.d_model)
        hidden_states = self.position_encoding(scaled)

        # Convert the padding mask into additive form broadcast over
        # heads: 0 where attention is allowed, -10000 where blocked.
        if attention_mask is not None:
            attention_mask = (
                attention_mask.unsqueeze(1).unsqueeze(2)
                .expand(batch_size, 1, seq_len, seq_len)
                .float()
            )
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Run the layer stack, collecting attention maps per layer
        all_attention_weights = []
        for layer in self.layers:
            hidden_states, attention_weights = layer(hidden_states, attention_mask)
            all_attention_weights.append(attention_weights)

        return self.layer_norm(hidden_states), all_attention_weights

```


## Transformer解码器实现


### 解码器层


```python

class DecoderLayer(nn.Module):
    """Single Transformer decoder layer.

    Masked self-attention over the target, cross-attention over the
    encoder output, then a position-wise feed-forward network. Each
    sublayer carries its own residual connection and LayerNorm.
    """

    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

    def forward(self, x, encoder_output,
                self_attention_mask=None, cross_attention_mask=None):
        """
        x: [batch_size, tgt_seq_len, d_model]
        encoder_output: [batch_size, src_seq_len, d_model]

        Returns (output, self-attention weights, cross-attention weights).
        """
        # Masked self-attention over the target sequence
        hidden, self_attn_weights = self.self_attention(
            x, x, x, self_attention_mask)

        # Cross-attention: queries from the decoder, keys/values from
        # the encoder output
        hidden, cross_attn_weights = self.cross_attention(
            hidden, encoder_output, encoder_output, cross_attention_mask)

        # Position-wise feed-forward
        return self.feed_forward(hidden), self_attn_weights, cross_attn_weights

```


### 完整解码器


```python

class TransformerDecoder(nn.Module):
    """Transformer decoder stack.

    Token embedding + sinusoidal positional encoding, a stack of
    DecoderLayer modules (masked self-attention + cross-attention + FFN),
    a final LayerNorm, and a vocabulary projection whose weight matrix is
    tied to the input embedding table.
    """
    def __init__(self, vocab_size, d_model=512, n_layers=6, 
                 n_heads=8, d_ff=2048, max_len=512, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.d_model = d_model
        
        # Token embedding table
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Fixed sinusoidal positional encoding (applies dropout internally)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        # Final output normalization
        self.layer_norm = nn.LayerNorm(d_model)
        
        # Projection from hidden states to vocabulary logits
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
        
        # Weight tying: the output projection shares the embedding matrix
        self.output_projection.weight = self.token_embedding.weight
    
    def create_causal_mask(self, seq_len, device):
        """Build an additive causal mask for autoregressive decoding.

        Returns [seq_len, seq_len]: 0 on and below the diagonal, -inf
        strictly above it (future positions).
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
    
    def forward(self, input_ids, encoder_output, encoder_mask=None):
        """
        input_ids: [batch_size, tgt_seq_len]
        encoder_output: [batch_size, src_seq_len, d_model]
        encoder_mask: mask applied in cross-attention over encoder states.
            NOTE(review): passed straight through to the attention layers;
            callers must supply a shape broadcastable to the attention
            scores — confirm against the attention module's mask
            convention.

        Returns (logits, per-layer self-attention weights,
                 per-layer cross-attention weights).
        """
        batch_size, tgt_seq_len = input_ids.size()
        
        # Embed and scale by sqrt(d_model) as in the original paper
        token_embeddings = self.token_embedding(input_ids) * math.sqrt(self.d_model)
        
        # Add positional information
        embeddings = self.position_encoding(token_embeddings)
        
        # Additive causal mask, broadcastable over batch and heads
        causal_mask = self.create_causal_mask(tgt_seq_len, input_ids.device)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]
        
        # Run the layer stack, collecting attention maps from every layer
        hidden_states = embeddings
        all_self_attention_weights = []
        all_cross_attention_weights = []
        
        for layer in self.layers:
            hidden_states, self_attn_weights, cross_attn_weights = layer(
                hidden_states, encoder_output, 
                self_attention_mask=causal_mask,
                cross_attention_mask=encoder_mask
            )
            all_self_attention_weights.append(self_attn_weights)
            all_cross_attention_weights.append(cross_attn_weights)
        
        # Final normalization
        hidden_states = self.layer_norm(hidden_states)
        
        # Vocabulary logits
        logits = self.output_projection(hidden_states)
        
        return logits, all_self_attention_weights, all_cross_attention_weights


```


## 完整Transformer模型


```python

class Transformer(nn.Module):
    """Full encoder-decoder Transformer."""

    def __init__(self, src_vocab_size, tgt_vocab_size,
                 d_model=512, n_layers=6, n_heads=8,
                 d_ff=2048, max_len=512, dropout=0.1):
        super(Transformer, self).__init__()

        # Source-side encoder
        self.encoder = TransformerEncoder(
            src_vocab_size, d_model, n_layers,
            n_heads, d_ff, max_len, dropout
        )

        # Target-side decoder (builds its own causal mask internally)
        self.decoder = TransformerDecoder(
            tgt_vocab_size, d_model, n_layers,
            n_heads, d_ff, max_len, dropout
        )

        # Xavier-initialize all weight matrices
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-uniform initialization for every parameter with dim > 1."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src_input, src_mask=None):
        """Run the encoder; returns (encoder_output, attention_weights)."""
        return self.encoder(src_input, src_mask)

    def decode(self, tgt_input, encoder_output, encoder_mask=None):
        """Run the decoder; returns (logits, self-attn, cross-attn weights)."""
        return self.decoder(tgt_input, encoder_output, encoder_mask)

    def forward(self, src_input, tgt_input,
                src_mask=None, tgt_mask=None):
        """
        Full forward pass.
        src_input: [batch_size, src_seq_len]
        tgt_input: [batch_size, tgt_seq_len]
        src_mask: [batch_size, src_seq_len] source padding mask
        tgt_mask: accepted for interface compatibility but unused — the
            decoder constructs its causal mask internally.

        Returns a dict with logits, encoder output and all attention maps.
        """
        # Encode the source sequence
        encoder_output, encoder_attention_weights = self.encode(src_input, src_mask)

        # Decode conditioned on the encoder output.
        # BUG FIX: the original referenced an undefined name
        # ``encoder_mask`` here (NameError at runtime); the source
        # padding mask is ``src_mask``.
        logits, decoder_self_attention_weights, decoder_cross_attention_weights = \
            self.decode(tgt_input, encoder_output, src_mask)

        return {
            'logits': logits,
            'encoder_output': encoder_output,
            'encoder_attention': encoder_attention_weights,
            'decoder_self_attention': decoder_self_attention_weights,
            'decoder_cross_attention': decoder_cross_attention_weights
        }

```


## 训练与优化技巧


```python

class TransformerTrainer:
    """Training / generation helper for the Transformer model.

    Bundles a label-smoothed loss, an Adam optimizer with the paper's
    betas/eps, a warmup + cosine LR schedule, gradient clipping, and a
    beam-search generator.
    """
    def __init__(self, model, learning_rate=5e-4, warmup_steps=4000, label_smoothing=0.1):
        self.model = model
        # Device of the model's parameters; inputs are assumed to already
        # live on the same device.
        self.device = next(model.parameters()).device
        
        # Label-smoothed loss, sized from the decoder's output layer
        self.criterion = LabelSmoothingLoss(
            model.decoder.output_projection.out_features, 
            label_smoothing=label_smoothing
        )
        
        # Adam with the hyperparameters from "Attention Is All You Need"
        self.optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=learning_rate,
            betas=(0.9, 0.98),
            eps=1e-9
        )
        
        # Linear warmup followed by cosine decay
        self.scheduler = self.get_cosine_schedule_with_warmup(
            self.optimizer, warmup_steps
        )
    
    def get_cosine_schedule_with_warmup(self, optimizer, warmup_steps):
        """LambdaLR schedule: linear warmup, then cosine decay.

        NOTE(review): the decay horizon is hard-coded to 100000 steps;
        training longer than that pins the LR factor at 0.
        """
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            
            progress = float(current_step - warmup_steps) / float(max(1, 100000 - warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    def train_step(self, src_batch, tgt_batch, src_mask=None, tgt_mask=None):
        """One optimization step with teacher forcing.

        Feeds ``tgt_batch[:, :-1]`` to the decoder and scores predictions
        against the shifted targets ``tgt_batch[:, 1:]``.
        Returns the scalar loss value.
        """
        self.model.train()
        
        # Forward pass (decoder input excludes the last target token)
        outputs = self.model(src_batch, tgt_batch[:, :-1], src_mask, tgt_mask)
        logits = outputs['logits']
        
        # Flatten to [batch * seq, vocab] vs [batch * seq] for the loss
        loss = self.criterion(
            logits.contiguous().view(-1, logits.size(-1)),
            tgt_batch[:, 1:].contiguous().view(-1)
        )
        
        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()
        
        # Clip the global gradient norm for training stability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        
        # Update parameters and learning rate
        self.optimizer.step()
        self.scheduler.step()
        
        return loss.item()
    
    def generate(self, src_input, src_mask=None, max_len=50, 
                beam_size=4, temperature=1.0):
        """Encode ``src_input`` once, then decode with beam search (no grad)."""
        self.model.eval()
        
        with torch.no_grad():
            # Encode the source sequence
            encoder_output, _ = self.model.encode(src_input, src_mask)
            
            # Decode with beam search
            sequences = self.beam_search(
                encoder_output, src_mask, 
                max_len, beam_size, temperature
            )
        
        return sequences
    
    def beam_search(self, encoder_output, encoder_mask, 
                   max_len, beam_size, temperature):
        """Beam search over the decoder (BOS id assumed 1, EOS id assumed 2).

        NOTE(review): several issues to confirm before relying on this:
        - ``start_token`` is built but never used; decoding starts from
          the all-ones ``beam_sequences`` instead.
        - all beams are initialized identically, so the first step selects
          duplicate hypotheses.
        - ``old_sequence`` is 1-D while the appended token tensor is 2-D;
          the ``torch.cat`` along dim=-1 looks shape-inconsistent.
        """
        batch_size = encoder_output.size(0)
        
        # Cumulative log-prob per beam and the running token sequences
        beam_scores = torch.zeros(batch_size, beam_size, device=self.device)
        beam_sequences = torch.ones(batch_size, beam_size, 1, 
                                   dtype=torch.long, device=self.device)
        
        # BOS token id assumed to be 1 (unused — see NOTE above)
        start_token = torch.tensor([[1]], device=self.device).expand(batch_size, 1)
        
        for step in range(max_len):
            # Tile encoder output (and mask) across the beam dimension
            expanded_encoder_output = encoder_output.unsqueeze(1).expand(
                -1, beam_size, -1, -1).contiguous().view(
                batch_size * beam_size, *encoder_output.shape[1:])
            
            expanded_encoder_mask = encoder_mask.unsqueeze(1).expand(
                -1, beam_size, -1, -1).contiguous().view(
                batch_size * beam_size, *encoder_mask.shape[1:]) if encoder_mask is not None else None
            
            # Decode all beams in one batched forward pass
            logits, _, _ = self.model.decode(
                beam_sequences.view(batch_size * beam_size, -1),
                expanded_encoder_output,
                expanded_encoder_mask
            )
            
            # Temperature-scaled distribution for the next token
            next_token_logits = logits[:, -1, :] / temperature
            
            vocab_size = next_token_logits.size(-1)
            next_token_scores = F.log_softmax(next_token_logits, dim=-1)
            
            # Add the per-beam cumulative scores
            next_token_scores = next_token_scores + beam_scores.view(-1).unsqueeze(1)
            
            # Reshape to [batch, beam * vocab] to rank all continuations
            next_token_scores = next_token_scores.view(batch_size, beam_size * vocab_size)
            
            # Keep the top ``beam_size`` continuations per batch element
            topk_scores, topk_indices = torch.topk(
                next_token_scores, beam_size, dim=-1)
            
            # Recover which beam and which token each winner came from
            beam_indices = topk_indices // vocab_size
            token_indices = topk_indices % vocab_size
            
            # Rebuild the beam sequences with the chosen continuations
            new_sequences = []
            for i in range(batch_size):
                new_sequence = []
                for j in range(beam_size):
                    beam_idx = beam_indices[i, j].item()
                    token_idx = token_indices[i, j].item()
                    
                    # Parent hypothesis
                    old_sequence = beam_sequences[i, beam_idx]
                    
                    # Append the selected token
                    new_sequence_i = torch.cat([
                        old_sequence, 
                        torch.tensor([[token_idx]], device=self.device)
                    ], dim=-1)
                    new_sequence.append(new_sequence_i)
                
                new_sequences.append(torch.stack(new_sequence))
            
            beam_sequences = torch.stack(new_sequences)
            beam_scores = topk_scores
            
            # Stop early once every beam just emitted EOS (id assumed 2)
            if torch.all(beam_sequences[:, :, -1] == 2):
                break
        
        # Return the highest-scoring beam for each batch element
        best_sequences = beam_sequences[:, 0]
        return best_sequences


class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against label-smoothed target distributions.

    The true class receives probability ``1 - label_smoothing``; the
    remaining mass is spread uniformly with value
    ``label_smoothing / (vocab_size - 2)``. Rows whose target equals
    ``ignore_index`` contribute an all-zero target distribution.
    """

    def __init__(self, vocab_size, label_smoothing=0.1, ignore_index=0):
        super(LabelSmoothingLoss, self).__init__()
        self.vocab_size = vocab_size
        self.label_smoothing = label_smoothing
        self.ignore_index = ignore_index
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, logits, targets):
        """
        logits: [batch_size * seq_len, vocab_size]
        targets: [batch_size * seq_len]
        Returns the batch-mean KL divergence.
        """
        off_value = self.label_smoothing / (self.vocab_size - 2)

        # Smoothed distribution: off_value everywhere, high confidence
        # at the true class index
        target_dist = torch.full_like(logits, off_value)
        target_dist.scatter_(1, targets.unsqueeze(1), 1.0 - self.label_smoothing)

        # Zero out rows that should be ignored (e.g. padding)
        if self.ignore_index is not None:
            target_dist[targets == self.ignore_index] = 0

        # KL divergence between predicted log-probs and smoothed targets
        return self.kl_div(F.log_softmax(logits, dim=-1), target_dist)

```


## BERT变体:掩码语言模型


```python

class BERTConfig:
    """Configuration for the BERT model.

    Defaults match bert-base-uncased. Generalization: the values were
    hard-coded in the original; they are now keyword arguments with the
    original values as defaults, so ``BERTConfig()`` behaves exactly as
    before while smaller/larger variants can be configured directly.
    """
    def __init__(self, vocab_size=30522, hidden_size=768,
                 num_hidden_layers=12, num_attention_heads=12,
                 intermediate_size=3072, hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512, type_vocab_size=2,
                 initializer_range=0.02):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Width of the feed-forward (intermediate) layer
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        # Number of segment (token-type) embeddings
        self.type_vocab_size = type_vocab_size
        # Std-dev for weight initialization
        self.initializer_range = initializer_range


class BERTEmbeddings(nn.Module):
    """BERT input embeddings.

    Sums word, position and token-type (segment) embeddings, then applies
    LayerNorm and dropout.
    """

    def __init__(self, config):
        super(BERTEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)
        # Name kept capitalized for checkpoint/state-dict compatibility
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """
        input_ids: [batch_size, seq_len]
        token_type_ids / position_ids: optional, same shape as input_ids.
        Returns [batch_size, seq_len, hidden_size].
        """
        # Default position ids: 0..seq_len-1 replicated over the batch
        if position_ids is None:
            position_ids = torch.arange(
                input_ids.size(1), dtype=torch.long,
                device=input_ids.device).unsqueeze(0).expand_as(input_ids)

        # Default segment ids: everything belongs to segment 0
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(position_ids)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))


class BERTLayer(nn.Module):
    """BERT encoder layer built on MultiHeadAttention plus an MLP block."""

    def __init__(self, config):
        super(BERTLayer, self).__init__()
        # Self-attention (includes its own residual + LayerNorm)
        self.attention = MultiHeadAttention(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob
        )
        # Feed-forward block: expand -> GELU -> project -> dropout
        self.intermediate = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob)
        )
        self.output_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states, attention_mask):
        """Returns (layer_output, attention_weights)."""
        attn_out, attn_weights = self.attention(
            hidden_states, hidden_states, hidden_states, attention_mask)

        # Residual around the feed-forward block, then LayerNorm
        ffn_out = self.intermediate(attn_out)
        return self.output_layer_norm(ffn_out + attn_out), attn_weights


class BERTModel(nn.Module):
    """BERT encoder model.

    Embeddings -> stack of BERTLayer modules -> tanh pooler over the
    first ([CLS]) token. Returns the sequence output, the pooled output
    and the per-layer attention weights.
    """
    def __init__(self, config):
        super(BERTModel, self).__init__()
        self.embeddings = BERTEmbeddings(config)
        # BUG FIX: removed a line of injected SEO/link spam that sat
        # between the embeddings and encoder definitions and broke the
        # class body's syntax.
        self.encoder = nn.ModuleList([
            BERTLayer(config) for _ in range(config.num_hidden_layers)
        ])
        # Pooler: linear + tanh over the [CLS] hidden state
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()

        # Recursively initialize all submodule weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """BERT-style initialization: N(0, 0.02) for linear/embedding
        weights, zeroed biases and padding rows, identity LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        input_ids: [batch_size, seq_len]
        attention_mask: [batch_size, seq_len]; 1 = attend, 0 = padding
        token_type_ids: [batch_size, seq_len] segment ids (optional)

        Returns a dict with 'sequence_output', 'pooled_output' and
        'attention_weights'.
        """
        # Word + position + segment embeddings
        embedding_output = self.embeddings(input_ids, token_type_ids)

        # Convert the padding mask to additive form (0 keep, -10000
        # block), broadcastable over heads and query positions
        if attention_mask is not None:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(
                dtype=next(self.parameters()).dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        else:
            extended_attention_mask = None

        # Run the encoder stack, collecting per-layer attention maps
        hidden_states = embedding_output
        all_attention_weights = []

        for layer in self.encoder:
            hidden_states, attention_weights = layer(
                hidden_states, extended_attention_mask)
            all_attention_weights.append(attention_weights)

        # Pool the [CLS] (first) token's representation
        pooled_output = self.tanh(self.pooler(hidden_states[:, 0]))

        return {
            'sequence_output': hidden_states,
            'pooled_output': pooled_output,
            'attention_weights': all_attention_weights
        }

```


## 模型压缩与优化


```python

class TransformerCompression:
    """Model-compression utilities: attention-head pruning, dynamic
    quantization, and a knowledge-distillation loss factory."""
    
    @staticmethod
    def prune_attention_heads(model, head_importance, prune_ratio=0.3):
        """Prune the least-important attention heads in every layer.

        head_importance: per-layer importance scores, indexable by layer.
        NOTE(review): assumes a HuggingFace-style model exposing
        ``model.config.num_attention_heads``, ``model.encoder.layer`` and
        ``layer.attention.prune_heads`` — the classes defined in this
        file do not provide that interface; confirm the intended target.
        """
        num_heads_to_prune = int(model.config.num_attention_heads * prune_ratio)
        
        for layer_idx, layer in enumerate(model.encoder.layer):
            # Least important heads first (ascending importance)
            sorted_heads = torch.argsort(head_importance[layer_idx])
            heads_to_prune = sorted_heads[:num_heads_to_prune]
            
            # Delegate the actual pruning to the layer's attention module
            layer.attention.prune_heads(heads_to_prune)
        
        return model
    
    @staticmethod
    def quantize_model(model, bits=8):
        """Dynamically quantize all nn.Linear modules.

        NOTE(review): the ``bits`` parameter is ignored — the dtype is
        hard-coded to qint8, so only 8-bit quantization is performed.
        """
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8
        )
        return quantized_model
    
    @staticmethod
    def knowledge_distillation(teacher_model, student_model, 
                              temperature=3.0, alpha=0.5):
        """Return a distillation loss function (closure over T and alpha).

        NOTE(review): ``teacher_model`` and ``student_model`` are not used
        by the returned closure; callers must compute both logit tensors
        themselves.
        """
        def distillation_loss(student_logits, teacher_logits, labels):
            # Soft targets from the teacher at elevated temperature
            soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
            soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
            
            # T^2 factor keeps gradient magnitudes comparable (Hinton et al.)
            distillation_loss = F.kl_div(
                soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
            
            # Hard-label cross-entropy for the student
            student_loss = F.cross_entropy(student_logits, labels)
            
            # Convex combination of soft and hard losses
            combined_loss = alpha * distillation_loss + (1 - alpha) * student_loss
            return combined_loss
        
        return distillation_loss

```


## 可视化与分析工具


```python

class TransformerVisualizer:
    """Plotting helpers: attention heatmaps, embedding projections and
    training curves. Imports are local so the model code itself does not
    require matplotlib/seaborn/sklearn."""

    @staticmethod
    def visualize_attention(attention_weights, tokens=None,
                          layer_idx=0, head_idx=0):
        """Heatmap of a single attention head.

        attention_weights: list (one entry per layer) of tensors shaped
        [batch, n_heads, seq_len, seq_len], as produced by the
        encoder/decoder in this file.
        """
        import matplotlib.pyplot as plt
        import seaborn as sns

        # BUG FIX: per-layer weights are [batch, n_heads, L, L]; the
        # original indexed [head_idx] on the *batch* axis, yielding a
        # 3-D array that sns.heatmap cannot render. Plot batch element 0.
        attention = attention_weights[layer_idx][0, head_idx].cpu().numpy()

        plt.figure(figsize=(10, 8))
        sns.heatmap(attention, cmap='viridis',
                   xticklabels=tokens, yticklabels=tokens)
        plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def visualize_embedding_projection(embeddings, labels=None, method='pca'):
        """2-D scatter plot of embeddings reduced with PCA or t-SNE."""
        # BUG FIX: plt was used below but never imported in this method
        import matplotlib.pyplot as plt
        from sklearn.decomposition import PCA
        from sklearn.manifold import TSNE

        embeddings_np = embeddings.cpu().numpy()

        # NOTE(review): an unrecognized ``method`` leaves ``reducer``
        # undefined and raises NameError below — only 'pca' and 'tsne'
        # are supported.
        if method == 'pca':
            reducer = PCA(n_components=2)
        elif method == 'tsne':
            reducer = TSNE(n_components=2, perplexity=30)

        reduced_embeddings = reducer.fit_transform(embeddings_np)

        plt.figure(figsize=(12, 10))
        scatter = plt.scatter(reduced_embeddings[:, 0],
                             reduced_embeddings[:, 1],
                             c=labels, cmap='tab20', alpha=0.6)

        if labels is not None:
            plt.colorbar(scatter)

        plt.title(f'Embedding Visualization - {method.upper()}')
        plt.xlabel('Component 1')
        plt.ylabel('Component 2')
        plt.grid(alpha=0.3)
        plt.show()

    @staticmethod
    def plot_training_curves(losses, accuracies):
        """Plot train/val loss and accuracy curves side by side.

        losses / accuracies: dicts with 'train' and 'val' sequences.
        """
        # BUG FIX: plt was used below but never imported in this method
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Loss curves
        axes[0].plot(losses['train'], label='Training Loss', linewidth=2)
        axes[0].plot(losses['val'], label='Validation Loss', linewidth=2)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training and Validation Loss')
        axes[0].legend()
        axes[0].grid(alpha=0.3)

        # Accuracy curves
        axes[1].plot(accuracies['train'], label='Training Accuracy', linewidth=2)
        axes[1].plot(accuracies['val'], label='Validation Accuracy', linewidth=2)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].set_title('Training and Validation Accuracy')
        axes[1].legend()
        axes[1].grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

```


Transformer架构通过其独特的自注意力机制,彻底改变了自然语言处理领域。从原始的编码器-解码器架构到BERT、GPT等变体,Transformer展现出强大的表达能力和可扩展性。随着模型规模的增长和训练策略的优化,基于Transformer的模型在各种语言任务上不断突破性能边界。

