PGN Model Architecture with a Coverage Mechanism

By @小森 · Published 2025-01-23 11:24:07

When generating summaries, a model often ends up repeating certain words or phrases. The coverage mechanism is designed to fix exactly this: it keeps a record of which parts of the source text have already been attended to, so the model can avoid attending to (and generating from) them again.
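
To make this precise: in the standard coverage formulation (See et al., 2017, "Get To The Point"), the coverage vector at decoding step t is simply the sum of the attention distributions from all previous steps, and the coverage loss penalizes the overlap between the current attention and that accumulated history:

$$c^{t} = \sum_{t'=0}^{t-1} a^{t'}, \qquad \mathrm{covloss}_{t} = \sum_{i} \min\!\left(a_{i}^{t},\ c_{i}^{t}\right)$$

These two quantities correspond directly to coverage_vector and cov_loss in the code below.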

First, let's look at the overall architecture of the model:

class PGN(nn.Module):
    def __init__(self, v):
        super(PGN, self).__init__()
        self.v = v  # vocabulary
        self.DEVICE = d10_config.DEVICE
        
        # Core components
        self.attention = Attention(d10_config.hidden_size)
        self.encoder = Encoder(len(v), d10_config.embed_size, d10_config.hidden_size)
        self.decoder = Decoder(len(v), d10_config.embed_size, d10_config.hidden_size)
        self.reduce_state = ReduceState()

This architecture has four main components: the encoder, the attention mechanism, the decoder, and a state-reduction component (ReduceState). Each component has its own specific responsibility.
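
Before diving into each component, here is a minimal, self-contained shape walk-through of how their outputs fit together at one decoding step. The sizes are purely illustrative and not taken from d10_config:

import torch

# Illustrative sizes only (not from d10_config)
B, L, h = 2, 6, 256   # batch size, source length, hidden size

encoder_output = torch.randn(B, L, 2 * h)   # Encoder (BiLSTM): one 2h vector per source token
decoder_state  = torch.randn(B, 2 * h)      # Decoder state after ReduceState ([h; c] concatenated)
coverage       = torch.zeros(B, L)          # Coverage: one accumulator per source position

# Attention produces one weight per source token, then a context vector
attention_weights = torch.softmax(torch.randn(B, L), dim=1)
context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_output).squeeze(1)

print(attention_weights.shape, context_vector.shape)   # torch.Size([2, 6]) torch.Size([2, 512])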

The core of the coverage mechanism is implemented in the attention module:

class Attention(nn.Module):
    def __init__(self, hidden_units):
        # ...
        # Linear layer that projects the coverage vector into the attention feature space
        self.wc = nn.Linear(1, 2 * hidden_units, bias=False)

    def forward(self, decoder_states, encoder_output, x_padding_masks, coverage_vector):
        # ...
        if d10_config.coverage:
            coverage_features = self.wc(coverage_vector.unsqueeze(2))
            attn_inputs = attn_inputs + coverage_features

The key points here are:

  • coverage_vector records the attention accumulated over time for each input position
  • the self.wc layer projects the coverage information to the right dimensionality (see the small shape check below)
  • the coverage features are added into the attention computation
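
To see what the self.wc projection does, here is a tiny shape check with purely illustrative sizes:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_units = 2, 5, 8            # illustrative sizes only

coverage_vector = torch.zeros(batch_size, seq_len)      # accumulated attention, all zeros at the start
wc = nn.Linear(1, 2 * hidden_units, bias=False)         # same shape as self.wc above

coverage_features = wc(coverage_vector.unsqueeze(2))    # [batch, seq_len, 1] -> [batch, seq_len, 2*hidden_units]
print(coverage_features.shape)                          # torch.Size([2, 5, 16])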

In the PGN forward pass, coverage is handled as follows:

def forward(self, x, x_len, y, len_oovs, batch, num_batches, teacher_forcing):
    # Initialize the coverage vector to all zeros
    coverage_vector = torch.zeros(x.size(), dtype=torch.float32).to(self.DEVICE)
    
    for t in range(y.shape[1] - 1):
        # Pass coverage_vector into the attention computation
        context_vector, attention_weights, next_coverage_vector = self.attention(
            decoder_states, encoder_output, x_padding_masks, coverage_vector)
        
        # Compute the coverage loss
        if d10_config.coverage:
            ct_min = torch.min(attention_weights, coverage_vector)
            cov_loss = torch.sum(ct_min, dim=1)
            loss = loss + d10_config.LAMBDA * cov_loss
            coverage_vector = next_coverage_vector

  • coverage_vector is initialized to all zeros, meaning that no position has been attended to yet
  • coverage_vector is updated at every decoding step
  • the coverage loss is computed from the element-wise minimum of the attention weights and the coverage vector
  • the LAMBDA hyperparameter balances the main loss against the coverage loss

Let's look at each step in detail:

ct_min = torch.min(attention_weights, coverage_vector)

This line compares the attention weights of the current time step (attention_weights) with the historically accumulated attention (coverage_vector).

  • attention_weights is the model's attention distribution over the source positions at the current time step
  • coverage_vector records the attention accumulated over previous steps for each position
  • torch.min() takes the smaller of the two at each position, which isolates exactly the attention that is being repeated

cov_loss = torch.sum(ct_min, dim=1)

Summing these minimum values gives the coverage loss, which reflects how much attention is being repeated:

  • if a position is attended to repeatedly, both attention_weights and coverage_vector will be large at that position
  • summing the position-wise minimums therefore measures the overall amount of repeated attention (see the small numeric example below)
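
A tiny numeric sketch (the numbers are invented purely for illustration) makes this concrete:

import torch

attention_weights = torch.tensor([[0.70, 0.10, 0.10, 0.10]])   # current step's attention over 4 source positions
coverage_vector   = torch.tensor([[0.90, 0.05, 0.03, 0.02]])   # attention accumulated in earlier steps

ct_min   = torch.min(attention_weights, coverage_vector)       # overlap per position
cov_loss = torch.sum(ct_min, dim=1)                            # total amount of repeated attention

print(ct_min)     # tensor([[0.7000, 0.0500, 0.0300, 0.0200]])
print(cov_loss)   # tensor([0.8000]) -- large, because position 0 is being attended to again
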
loss = loss + d10_config.LAMBDA * cov_loss

A weighting coefficient LAMBDA is used here to balance the original loss against the coverage loss:

  • the original loss is mainly concerned with generation quality
  • the coverage loss penalizes repeated attention
  • LAMBDA controls the relative importance of the two losses

coverage_vector = next_coverage_vector

Finally, the history is updated with the new coverage vector:

  • next_coverage_vector already includes the attention information from the current time step
  • this update ensures that we keep tracking the accumulated attention distribution

Let me illustrate with a concrete example.

Suppose the source sentence is "小明喜欢吃苹果" ("Xiao Ming likes eating apples"):

  1. At the start, coverage_vector is all zeros, meaning no position has been attended to yet
  2. When generating the first word, the model may attend mainly to "小明", so attention_weights is large at the corresponding positions
  3. These values are accumulated into coverage_vector
  4. If the model later tries to attend to "小明" again, coverage_vector already has a large value at that position, so a large coverage loss is produced
  5. This suppresses the model's tendency to repeatedly attend to, and generate, the same content (simulated in the short sketch below)
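
Here is a small, self-contained simulation of that scenario. The attention values are invented just to show how the coverage loss reacts when "小明" is attended to a second time:

import torch

# Simulated attention over the four tokens [小明, 喜欢, 吃, 苹果] at three decoding steps
steps = [
    torch.tensor([0.80, 0.10, 0.05, 0.05]),   # step 1: mostly "小明"
    torch.tensor([0.10, 0.60, 0.20, 0.10]),   # step 2: mostly "喜欢"
    torch.tensor([0.70, 0.10, 0.10, 0.10]),   # step 3: looks back at "小明" again
]

coverage = torch.zeros(4)
for t, attn in enumerate(steps, start=1):
    cov_loss = torch.sum(torch.min(attn, coverage))   # computed before updating coverage
    print(f"step {t}: cov_loss = {cov_loss.item():.2f}")
    coverage = coverage + attn

The printed loss is 0.00 at step 1, 0.30 at step 2, and jumps to 1.00 at step 3, exactly when the model re-attends to "小明".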

The elegance of this mechanism lies in the following:

  • it tracks the information that has already been used by accumulating attention
  • the minimum operation captures precisely the degree of repeated attention
  • the extra loss term steers the model away from repetition

Complete code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Config:
    """Configuration class to store model parameters"""
    def __init__(self):
        self.vocab_size = 50000
        self.embed_size = 256
        self.hidden_size = 256
        self.max_enc_steps = 400
        self.max_dec_steps = 100
        self.batch_size = 8
        self.beam_size = 4
        self.min_dec_steps = 35
        self.coverage = True  # Enable coverage mechanism
        self.LAMBDA = 1.0    # Coverage loss weight
        self.eps = 1e-12     # Small constant for numerical stability
        # Special-token indices (assumed placeholder values; adjust to your vocabulary)
        self.PAD_idx = 0
        self.SOS_idx = 1
        self.EOS_idx = 2
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()

class Encoder(nn.Module):
    """Bidirectional LSTM encoder"""
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=0.0):
        super(Encoder, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            embed_size, 
            hidden_size, 
            num_layers=1,
            batch_first=True,
            bidirectional=True,
            dropout=dropout
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len]
        Returns:
            encoder_outputs: Output features from the encoder [batch_size, seq_len, hidden_size*2]
            encoder_hidden: Final hidden states (h_n, c_n)
        """
        embedded = self.embedding(x)  # [batch_size, seq_len, embed_size]
        encoder_outputs, encoder_hidden = self.lstm(embedded)
        return encoder_outputs, encoder_hidden

class Attention(nn.Module):
    """Attention module with coverage mechanism"""
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        
        # Transform encoder output features
        self.W_h = nn.Linear(hidden_size * 2, hidden_size * 2, bias=False)
        # Transform decoder state
        self.W_s = nn.Linear(hidden_size * 2, hidden_size * 2)
        # Transform coverage vector
        self.W_c = nn.Linear(1, hidden_size * 2, bias=False)
        # Convert attention inputs to scores
        self.v = nn.Linear(hidden_size * 2, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs, encoder_mask, coverage):
        """
        Args:
            decoder_state: Current decoder hidden state [batch_size, hidden_size*2]
            encoder_outputs: Encoder output features [batch_size, seq_len, hidden_size*2]
            encoder_mask: Padding mask [batch_size, seq_len]; 1 at padded positions, 0 at real tokens
            coverage: Coverage vector [batch_size, seq_len]
        Returns:
            context_vector: Weighted sum of encoder outputs
            attention_weights: Attention distribution
            coverage: Updated coverage vector
        """
        batch_size, seq_len, _ = encoder_outputs.size()

        # Expand decoder state for attention calculation
        decoder_state = decoder_state.unsqueeze(1).expand_as(encoder_outputs)

        # Calculate attention features
        encoder_features = self.W_h(encoder_outputs)
        decoder_features = self.W_s(decoder_state)
        
        # Add coverage features if enabled
        if config.coverage:
            coverage_features = self.W_c(coverage.unsqueeze(2))
            attention_features = encoder_features + decoder_features + coverage_features
        else:
            attention_features = encoder_features + decoder_features

        # Calculate attention scores
        attention_scores = self.v(torch.tanh(attention_features)).squeeze(2)

        # Mask padding tokens
        attention_scores = attention_scores.masked_fill(
            encoder_mask.bool(), float('-inf')
        )

        # Calculate attention weights
        attention_weights = F.softmax(attention_scores, dim=1)

        # Update coverage vector (the returned coverage already includes this step's attention)
        if config.coverage:
            coverage = coverage + attention_weights

        # Calculate context vector
        context_vector = torch.bmm(
            attention_weights.unsqueeze(1), 
            encoder_outputs
        ).squeeze(1)

        return context_vector, attention_weights, coverage

class Decoder(nn.Module):
    """LSTM decoder with attention and copy mechanism"""
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(Decoder, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            embed_size + hidden_size * 2,  # Input: embedding + context vector
            hidden_size,
            num_layers=1,
            batch_first=True
        )
        
        # Attention module with coverage (shared across all decoding steps)
        self.attention = Attention(hidden_size)
        
        # Output projection over [decoder output; context vector]
        self.out = nn.Linear(hidden_size * 3, vocab_size)
        
        # Pointer-generator probability from [context vector; decoder state (h;c); input embedding]
        self.p_gen_linear = nn.Linear(hidden_size * 4 + embed_size, 1)

    def forward(self, input_token, last_hidden, last_context, encoder_outputs, 
                encoder_mask, coverage):
        """Single step decoding"""
        # Word embedding
        embedded = self.embedding(input_token)  # [batch_size, 1, embed_size]
        
        # Concatenate with last context vector
        lstm_input = torch.cat([embedded, last_context.unsqueeze(1)], dim=2)
        
        # LSTM step
        lstm_output, hidden = self.lstm(lstm_input, last_hidden)
        
        # Decoder state for attention / p_gen: concatenate h and c -> [batch_size, hidden_size*2]
        decoder_state = torch.cat([hidden[0].squeeze(0), hidden[1].squeeze(0)], dim=1)
        
        # Calculate attention with the shared attention module
        context_vector, attention_weights, coverage = self.attention(
            decoder_state,
            encoder_outputs,
            encoder_mask,
            coverage
        )
        
        # Calculate vocabulary distribution
        concat_output = torch.cat([
            lstm_output.squeeze(1),
            context_vector
        ], dim=1)
        
        vocab_dist = F.softmax(self.out(concat_output), dim=1)
        
        # Calculate p_gen from [context (2h); decoder state (2h); input embedding (e)]
        p_gen_input = torch.cat([
            context_vector,
            decoder_state,
            embedded.squeeze(1)
        ], dim=1)
        
        p_gen = torch.sigmoid(self.p_gen_linear(p_gen_input))
        
        return vocab_dist, hidden, context_vector, attention_weights, p_gen, coverage

class PGN(nn.Module):
    """Complete Pointer-Generator Network with Coverage"""
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(PGN, self).__init__()
        
        self.encoder = Encoder(vocab_size, embed_size, hidden_size)
        self.decoder = Decoder(vocab_size, embed_size, hidden_size)
        
    def forward(self, src, tgt, src_mask, max_output_length=None):
        """
        Forward pass for training
        Args:
            src: Source text tensor [batch_size, src_len]
            tgt: Target text tensor [batch_size, tgt_len]
            src_mask: Source padding mask [batch_size, src_len]
            max_output_length: Maximum output sequence length
        """
        batch_size = src.size(0)
        device = src.device
        
        if max_output_length is None:
            max_output_length = tgt.size(1)
        
        # Initialize coverage vector (one accumulator per source position)
        coverage = torch.zeros_like(src_mask, dtype=torch.float)
        
        # Encode source sequence
        encoder_outputs, encoder_hidden = self.encoder(src)
        
        # Reduce the bidirectional encoder states (2 directions) to the decoder's single layer
        # by summing over the direction dimension (playing the role of a ReduceState module)
        decoder_hidden = tuple(h.sum(dim=0, keepdim=True) for h in encoder_hidden)
        
        # Initialize decoder input with SOS token
        decoder_input = torch.full((batch_size, 1), config.SOS_idx, dtype=torch.long, device=device)
        last_context = torch.zeros(batch_size, config.hidden_size * 2, device=device)
        
        # Store outputs for loss calculation
        outputs = []
        attentions = []
        coverages = []
        
        # Teacher forcing: use target as next input
        for t in range(max_output_length):
            # Record the coverage accumulated so far (before this step),
            # which is what the coverage loss min(a_t, c_t) should use
            coverages.append(coverage)
            
            vocab_dist, decoder_hidden, context_vector, attention_weights, p_gen, coverage = \
                self.decoder(
                    decoder_input,
                    decoder_hidden,
                    last_context,
                    encoder_outputs,
                    src_mask,
                    coverage
                )
                
            outputs.append(vocab_dist)
            attentions.append(attention_weights)
            
            # Next input is current target token (teacher forcing)
            decoder_input = tgt[:, t:t+1]
            last_context = context_vector
            
        # Stack all outputs
        outputs = torch.stack(outputs, dim=1)
        attentions = torch.stack(attentions, dim=1)
        coverages = torch.stack(coverages, dim=1)
        
        return outputs, attentions, coverages
        
    def generate(self, src, src_mask, max_length=None):
        """Generate output sequence (inference)"""
        if max_length is None:
            max_length = config.max_dec_steps
            
        batch_size = src.size(0)
        device = src.device
        
        # Initialize coverage vector
        coverage = torch.zeros_like(src_mask, dtype=torch.float)
        
        # Encode source sequence and reduce the bidirectional states as in forward()
        encoder_outputs, encoder_hidden = self.encoder(src)
        decoder_hidden = tuple(h.sum(dim=0, keepdim=True) for h in encoder_hidden)
        
        # Initialize decoder
        decoder_input = torch.full((batch_size, 1), config.SOS_idx, dtype=torch.long, device=device)
        last_context = torch.zeros(batch_size, config.hidden_size * 2, device=device)
        
        generated_tokens = []
        attention_weights_list = []
        
        for _ in range(max_length):
            vocab_dist, decoder_hidden, context_vector, attention_weights, p_gen, coverage = \
                self.decoder(
                    decoder_input,
                    decoder_hidden,
                    last_context,
                    encoder_outputs,
                    src_mask,
                    coverage
                )
                
            # Select next token (greedy decoding)
            decoder_input = vocab_dist.argmax(dim=1).unsqueeze(1)
            last_context = context_vector
            
            generated_tokens.append(decoder_input.squeeze(1))
            attention_weights_list.append(attention_weights)
            
            # Early stopping if EOS token is generated (only meaningful for batch_size == 1)
            if batch_size == 1 and decoder_input.item() == config.EOS_idx:
                break
                
        return torch.stack(generated_tokens, dim=1), torch.stack(attention_weights_list, dim=1)

def train_step(model, optimizer, src_batch, tgt_batch, src_mask):
    """Single training step"""
    optimizer.zero_grad()
    
    # Forward pass
    outputs, attentions, coverages = model(src_batch, tgt_batch, src_mask)
    
    # Calculate NLL loss (outputs are softmax probabilities, so take the log first)
    nll_loss = F.nll_loss(torch.log(outputs.view(-1, config.vocab_size) + config.eps),
                          tgt_batch.view(-1),
                          ignore_index=config.PAD_idx)
    
    # Calculate coverage loss
    if config.coverage:
        coverage_loss = torch.sum(torch.min(attentions, coverages)) / src_batch.size(0)
        total_loss = nll_loss + config.LAMBDA * coverage_loss
    else:
        total_loss = nll_loss
    
    # Backward pass
    total_loss.backward()
    
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
    
    optimizer.step()
    
    return total_loss.item()

# Example usage:
def main():
    # Create model instance
    model = PGN(
        vocab_size=config.vocab_size,
        embed_size=config.embed_size,
        hidden_size=config.hidden_size
    ).to(config.device)
    
    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters())
    
    # Training loop would go here
    # ...

if __name__ == "__main__":
    main()
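
For a quick sanity check, the classes above can be wired together with random toy tensors. This is a hypothetical smoke test, not part of the original article, and it assumes the code above has already been run:

# Hypothetical smoke test with random toy data
model = PGN(config.vocab_size, config.embed_size, config.hidden_size).to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

src = torch.randint(3, config.vocab_size, (2, 10), device=config.device)   # 2 source sequences, length 10
tgt = torch.randint(3, config.vocab_size, (2, 8), device=config.device)    # 2 target sequences, length 8
src_mask = torch.zeros(2, 10, device=config.device)                        # no padding (mask is 1 at padded positions)

loss = train_step(model, optimizer, src, tgt, src_mask)
print(f"toy training-step loss: {loss:.4f}")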

Original statement: This article was published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

For infringement concerns, please contact cloudcommunity@tencent.com for removal.
