When generating summaries, the model may repeatedly produce the same words or phrases. The coverage mechanism is designed to address exactly this problem: it keeps a record of which parts of the source text have already been attended to, so the model can avoid attending to, and generating from, the same positions again and again.
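Before looking at the model code, here is the intuition in a few lines of PyTorch (a toy sketch with made-up attention values, independent of the real model below):

import torch

# Toy attention distributions over a 4-token source at two decoding steps.
attn_t0 = torch.tensor([0.7, 0.2, 0.05, 0.05])   # step 0 focuses on token 0
attn_t1 = torch.tensor([0.6, 0.2, 0.1, 0.1])     # step 1 focuses on token 0 again

coverage = torch.zeros(4)          # nothing has been attended to yet
coverage = coverage + attn_t0      # after step 0, token 0 is largely "covered"

# Overlap between the new attention and the accumulated coverage.
overlap = torch.min(attn_t1, coverage)
print(overlap.sum())               # large value -> the model is re-attending token 0

The sum of that overlap is what the coverage loss penalizes during training.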
Let's first look at the overall architecture of the model:
class PGN(nn.Module):
    def __init__(self, v):
        super(PGN, self).__init__()
        self.v = v  # vocabulary
        self.DEVICE = d10_config.DEVICE

        # Core components
        self.attention = Attention(d10_config.hidden_size)
        self.encoder = Encoder(len(v), d10_config.embed_size, d10_config.hidden_size)
        self.decoder = Decoder(len(v), d10_config.embed_size, d10_config.hidden_size)
        self.reduce_state = ReduceState()
This architecture consists of four main components: the encoder, the attention mechanism, the decoder, and a state-reduction component. Each has its own specific responsibility.
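Of these, `ReduceState` is the only component whose code is not shown later in this post. A typical implementation (a sketch under the assumption that its job is to squeeze the bidirectional encoder's final states down to the single direction the decoder expects; the author's exact code may differ) looks like this:

import torch
import torch.nn as nn

class ReduceState(nn.Module):
    """Reduce the bidirectional encoder's final (h, c) states.

    The encoder LSTM is bidirectional, so h and c have shape
    [2, batch, hidden]; the unidirectional decoder expects [1, batch, hidden].
    Summing over the direction dimension is one simple way to reduce them.
    """
    def __init__(self):
        super(ReduceState, self).__init__()

    def forward(self, hidden):
        h, c = hidden
        h_reduced = torch.sum(h, dim=0, keepdim=True)  # [1, batch, hidden]
        c_reduced = torch.sum(c, dim=0, keepdim=True)  # [1, batch, hidden]
        return (h_reduced, c_reduced)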
The core of the coverage mechanism lives in the attention module:
class Attention(nn.Module):
    def __init__(self, hidden_units):
        # ...
        # Linear layer that projects the coverage vector
        self.wc = nn.Linear(1, 2 * hidden_units, bias=False)

    def forward(self, decoder_states, encoder_output, x_padding_masks, coverage_vector):
        # ...
        if d10_config.coverage:
            coverage_features = self.wc(coverage_vector.unsqueeze(2))
            attn_inputs = attn_inputs + coverage_features
The key points here are:

- `coverage_vector` has shape [batch_size, seq_len]; `unsqueeze(2)` turns it into [batch_size, seq_len, 1] so that the linear layer `wc` can project each position's accumulated attention into the same 2 * hidden_units feature space as the other attention inputs.
- The projected coverage features are added to the attention inputs before the scores are computed, so the attention score of every source position is informed by how much attention that position has already received.

A quick shape check of this projection follows.
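The sketch below is standalone, with made-up sizes, and simply mirrors the `wc` projection above:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_units = 2, 5, 8
wc = nn.Linear(1, 2 * hidden_units, bias=False)

coverage_vector = torch.zeros(batch_size, seq_len)      # [2, 5]
coverage_features = wc(coverage_vector.unsqueeze(2))    # [2, 5, 1] -> [2, 5, 16]

attn_inputs = torch.randn(batch_size, seq_len, 2 * hidden_units)
attn_inputs = attn_inputs + coverage_features           # shapes match: [2, 5, 16]
print(coverage_features.shape, attn_inputs.shape)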
In the PGN's forward pass, coverage is handled as follows:
def forward(self, x, x_len, y, len_oovs, batch, num_batches, teacher_forcing):
    # Initialize the coverage vector with zeros
    coverage_vector = torch.zeros(x.size(), dtype=torch.float32).to(self.DEVICE)

    for t in range(y.shape[1] - 1):
        # Pass coverage_vector into the attention computation
        context_vector, attention_weights, next_coverage_vector = self.attention(
            decoder_states, encoder_output, x_padding_masks, coverage_vector)

        # Compute the coverage loss
        if d10_config.coverage:
            ct_min = torch.min(attention_weights, coverage_vector)
            cov_loss = torch.sum(ct_min, dim=1)
            loss = loss + d10_config.LAMBDA * cov_loss

        coverage_vector = next_coverage_vector
Now let's analyze the key lines in detail:
ct_min = torch.min(attention_weights, coverage_vector)
This line compares the attention weights at the current time step (`attention_weights`) with the historically accumulated attention (`coverage_vector`):

- `attention_weights` is the model's attention distribution over the source positions at the current time step.
- `coverage_vector` records the accumulated attention each position has received so far.
- `torch.min()` takes the element-wise minimum of the two; the point is to isolate the overlap, i.e. the portion of attention that is being repeated.
- `cov_loss = torch.sum(ct_min, dim=1)` sums these minima to obtain the coverage loss. This loss reflects the degree of repeated attention: the more `attention_weights` overlaps with `coverage_vector`, the larger it becomes.
loss = loss + d10_config.LAMBDA * cov_loss
Here a weight coefficient `LAMBDA` is used to balance the original loss against the coverage loss: `LAMBDA` controls the relative importance of the two terms.
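Putting the pieces together, the quantity implemented above matches the standard coverage loss of See et al. (2017):

$$c^{t} = \sum_{t'=0}^{t-1} a^{t'}, \qquad \mathrm{covloss}_t = \sum_i \min\left(a_i^{t},\, c_i^{t}\right), \qquad \mathrm{loss}_t \leftarrow \mathrm{loss}_t + \lambda \cdot \mathrm{covloss}_t$$

where $a^{t}$ is the attention distribution at decoding step $t$, $c^{t}$ is the coverage vector before that step, and $\lambda$ corresponds to `LAMBDA`.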
After the loss is computed, `coverage_vector = next_coverage_vector` updates the history with the new coverage vector: `next_coverage_vector` is returned by the attention module and already includes the attention of the current time step.

Let me walk through a concrete example.
Suppose the source sentence is "小明喜欢吃苹果" ("Xiao Ming likes eating apples"):

- At the first decoding step, `coverage_vector` is all zeros, so no position has been covered yet.
- Suppose `attention_weights` then concentrates on "小明"; after this step `coverage_vector` records a high accumulated value at that position.
- If a later step tries to focus on "小明" again, the element-wise minimum with `coverage_vector` is large, the coverage loss rises, and the model is penalized.
- As decoding proceeds, `coverage_vector` keeps accumulating, which pushes attention toward positions such as "喜欢", "吃" and "苹果" that have not been covered yet.
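The same walk-through with toy numbers (a standalone sketch; the attention values below are invented purely for illustration, one value per source token 小明 / 喜欢 / 吃 / 苹果):

import torch

LAMBDA = 1.0
coverage_vector = torch.zeros(1, 4)                       # step 0: nothing covered yet

# Step 0: attention focuses on "小明"
attention_weights = torch.tensor([[0.7, 0.1, 0.1, 0.1]])
cov_loss = torch.sum(torch.min(attention_weights, coverage_vector), dim=1)
print(cov_loss)                                           # tensor([0.]) -> no repetition yet
coverage_vector = coverage_vector + attention_weights

# Step 1: attention tries to focus on "小明" again
attention_weights = torch.tensor([[0.6, 0.2, 0.1, 0.1]])
cov_loss = torch.sum(torch.min(attention_weights, coverage_vector), dim=1)
print(LAMBDA * cov_loss)                                  # tensor([0.9000]) -> repetition is penalized
coverage_vector = coverage_vector + attention_weights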
The elegance of this mechanism lies in a few points: the element-wise minimum penalizes only the overlap between the current attention and the accumulated attention; the penalty is a soft constraint added to the training loss rather than a hard decoding rule; and the coverage features fed into the attention scores give the model direct information about which source positions have already been used.
Full code:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Config:
    """Configuration class to store model parameters"""
    def __init__(self):
        self.vocab_size = 50000
        self.embed_size = 256
        self.hidden_size = 256
        self.max_enc_steps = 400
        self.max_dec_steps = 100
        self.batch_size = 8
        self.beam_size = 4
        self.min_dec_steps = 35
        self.coverage = True      # Enable coverage mechanism
        self.LAMBDA = 1.0         # Coverage loss weight
        self.eps = 1e-12          # Small constant for numerical stability
        self.PAD_idx = 0          # Special token ids (used by the decoder and the loss)
        self.SOS_idx = 1
        self.EOS_idx = 2
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


config = Config()
class Encoder(nn.Module):
    """Bidirectional LSTM encoder"""
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=0.0):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
            dropout=dropout
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len]
        Returns:
            encoder_outputs: Output features from the encoder [batch_size, seq_len, hidden_size*2]
            encoder_hidden: Final hidden states (h_n, c_n)
        """
        embedded = self.embedding(x)  # [batch_size, seq_len, embed_size]
        encoder_outputs, encoder_hidden = self.lstm(embedded)
        return encoder_outputs, encoder_hidden
class Attention(nn.Module):
    """Attention module with coverage mechanism"""
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        # Transform encoder output features
        self.W_h = nn.Linear(hidden_size * 2, hidden_size * 2, bias=False)
        # Transform decoder state
        self.W_s = nn.Linear(hidden_size * 2, hidden_size * 2)
        # Transform coverage vector
        self.W_c = nn.Linear(1, hidden_size * 2, bias=False)
        # Convert attention inputs to scores
        self.v = nn.Linear(hidden_size * 2, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs, encoder_mask, coverage):
        """
        Args:
            decoder_state: Current decoder hidden state [batch_size, hidden_size*2]
            encoder_outputs: Encoder output features [batch_size, seq_len, hidden_size*2]
            encoder_mask: Padding mask, 1 at padding positions [batch_size, seq_len]
            coverage: Coverage vector [batch_size, seq_len]
        Returns:
            context_vector: Weighted sum of encoder outputs
            attention_weights: Attention distribution
            coverage: Updated coverage vector
        """
        batch_size, seq_len, _ = encoder_outputs.size()

        # Expand decoder state for attention calculation
        decoder_state = decoder_state.unsqueeze(1).expand_as(encoder_outputs)

        # Calculate attention features
        encoder_features = self.W_h(encoder_outputs)
        decoder_features = self.W_s(decoder_state)

        # Add coverage features if enabled
        if config.coverage:
            coverage_features = self.W_c(coverage.unsqueeze(2))
            attention_features = encoder_features + decoder_features + coverage_features
        else:
            attention_features = encoder_features + decoder_features

        # Calculate attention scores
        attention_scores = self.v(torch.tanh(attention_features)).squeeze(2)

        # Mask padding tokens
        attention_scores = attention_scores.masked_fill(
            encoder_mask.bool(), float('-inf')
        )

        # Calculate attention weights
        attention_weights = F.softmax(attention_scores, dim=1)

        # Update coverage vector
        if config.coverage:
            coverage = coverage + attention_weights

        # Calculate context vector
        context_vector = torch.bmm(
            attention_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)

        return context_vector, attention_weights, coverage
class Decoder(nn.Module):
    """LSTM decoder with attention and copy mechanism"""
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            embed_size + hidden_size * 2,  # Input: embedding + context vector
            hidden_size,
            num_layers=1,
            batch_first=True
        )
        # Attention module (created once, reused at every decoding step)
        self.attention = Attention(hidden_size)
        # Output projection
        self.out = nn.Linear(hidden_size * 3, vocab_size)
        # Pointer generator probability
        self.p_gen_linear = nn.Linear(hidden_size * 4 + embed_size, 1)

    def forward(self, input_token, last_hidden, last_context, encoder_outputs,
                encoder_mask, coverage):
        """Single step decoding"""
        # Word embedding
        embedded = self.embedding(input_token)  # [batch_size, 1, embed_size]

        # Concatenate with last context vector
        lstm_input = torch.cat([embedded, last_context.unsqueeze(1)], dim=2)

        # LSTM step
        lstm_output, hidden = self.lstm(lstm_input, last_hidden)

        # Decoder state for attention: concatenation of h and c -> [batch_size, hidden_size*2]
        h_dec, c_dec = hidden
        decoder_state = torch.cat([h_dec[-1], c_dec[-1]], dim=1)

        # Calculate attention
        context_vector, attention_weights, coverage = self.attention(
            decoder_state,
            encoder_outputs,
            encoder_mask,
            coverage
        )

        # Calculate vocabulary distribution
        concat_output = torch.cat([
            lstm_output.squeeze(1),
            context_vector
        ], dim=1)
        vocab_dist = F.softmax(self.out(concat_output), dim=1)

        # Calculate p_gen from context vector, decoder state and input embedding
        p_gen_input = torch.cat([
            context_vector,
            decoder_state,
            embedded.squeeze(1)
        ], dim=1)
        p_gen = torch.sigmoid(self.p_gen_linear(p_gen_input))

        return vocab_dist, hidden, context_vector, attention_weights, p_gen, coverage
class PGN(nn.Module):
    """Complete Pointer-Generator Network with Coverage"""
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(PGN, self).__init__()
        self.encoder = Encoder(vocab_size, embed_size, hidden_size)
        self.decoder = Decoder(vocab_size, embed_size, hidden_size)

    @staticmethod
    def reduce_state(encoder_hidden):
        """Reduce the bidirectional encoder states [2, batch, hidden]
        to the shape the unidirectional decoder expects [1, batch, hidden]."""
        h, c = encoder_hidden
        return (torch.sum(h, dim=0, keepdim=True),
                torch.sum(c, dim=0, keepdim=True))

    def forward(self, src, tgt, src_mask, max_output_length=None):
        """
        Forward pass for training
        Args:
            src: Source text tensor [batch_size, src_len]
            tgt: Target text tensor [batch_size, tgt_len]
            src_mask: Source padding mask, 1 at padding positions [batch_size, src_len]
            max_output_length: Maximum output sequence length (defaults to tgt length)
        """
        batch_size = src.size(0)
        if max_output_length is None:
            max_output_length = tgt.size(1)

        # Initialize coverage vector
        coverage = torch.zeros_like(src_mask, dtype=torch.float)

        # Encode source sequence
        encoder_outputs, encoder_hidden = self.encoder(src)

        # Initialize decoder input with SOS token
        decoder_input = torch.full((batch_size, 1), config.SOS_idx,
                                   dtype=torch.long, device=src.device)
        decoder_hidden = self.reduce_state(encoder_hidden)
        last_context = torch.zeros(batch_size, config.hidden_size * 2, device=src.device)

        # Store outputs for loss calculation
        outputs = []
        attentions = []
        coverages = []

        # Teacher forcing: use target as next input
        for t in range(max_output_length):
            # Keep the coverage *before* this step's attention is added,
            # so the coverage loss compares a^t against c^t (not c^t + a^t)
            coverages.append(coverage)

            vocab_dist, decoder_hidden, context_vector, attention_weights, p_gen, coverage = \
                self.decoder(
                    decoder_input,
                    decoder_hidden,
                    last_context,
                    encoder_outputs,
                    src_mask,
                    coverage
                )
            outputs.append(vocab_dist)
            attentions.append(attention_weights)

            # Next input is current target token (teacher forcing)
            decoder_input = tgt[:, t:t+1]
            last_context = context_vector

        # Stack all outputs
        outputs = torch.stack(outputs, dim=1)
        attentions = torch.stack(attentions, dim=1)
        coverages = torch.stack(coverages, dim=1)

        return outputs, attentions, coverages
    def generate(self, src, src_mask, max_length=None):
        """Generate output sequence (greedy decoding; assumes batch_size == 1)"""
        if max_length is None:
            max_length = config.max_dec_steps
        batch_size = src.size(0)

        # Initialize coverage vector
        coverage = torch.zeros_like(src_mask, dtype=torch.float)

        # Encode source sequence
        encoder_outputs, encoder_hidden = self.encoder(src)

        # Initialize decoder
        decoder_input = torch.full((batch_size, 1), config.SOS_idx,
                                   dtype=torch.long, device=src.device)
        decoder_hidden = self.reduce_state(encoder_hidden)
        last_context = torch.zeros(batch_size, config.hidden_size * 2, device=src.device)

        generated_tokens = []
        attention_weights_list = []

        for _ in range(max_length):
            vocab_dist, decoder_hidden, context_vector, attention_weights, p_gen, coverage = \
                self.decoder(
                    decoder_input,
                    decoder_hidden,
                    last_context,
                    encoder_outputs,
                    src_mask,
                    coverage
                )

            # Select next token (greedy decoding)
            decoder_input = vocab_dist.argmax(dim=1).unsqueeze(1)
            last_context = context_vector
            generated_tokens.append(decoder_input.squeeze(1))
            attention_weights_list.append(attention_weights)

            # Early stopping if EOS token is generated (batch_size == 1)
            if decoder_input.item() == config.EOS_idx:
                break

        return torch.stack(generated_tokens, dim=1), torch.stack(attention_weights_list, dim=1)
def train_step(model, optimizer, src_batch, tgt_batch, src_mask):
    """Single training step"""
    optimizer.zero_grad()

    # Forward pass
    outputs, attentions, coverages = model(src_batch, tgt_batch, src_mask)

    # Calculate NLL loss (outputs are probabilities, so take the log first)
    nll_loss = F.nll_loss(torch.log(outputs + config.eps).view(-1, config.vocab_size),
                          tgt_batch.view(-1),
                          ignore_index=config.PAD_idx)

    # Calculate coverage loss
    if config.coverage:
        coverage_loss = torch.sum(torch.min(attentions, coverages)) / src_batch.size(0)
        total_loss = nll_loss + config.LAMBDA * coverage_loss
    else:
        total_loss = nll_loss

    # Backward pass
    total_loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
    optimizer.step()

    return total_loss.item()
# Example usage:
def main():
    # Create model instance
    model = PGN(
        vocab_size=config.vocab_size,
        embed_size=config.embed_size,
        hidden_size=config.hidden_size
    ).to(config.device)

    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Training loop would go here
    # ...


if __name__ == "__main__":
    main()
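Finally, a quick smoke test of the full code with random tensors can help verify the shapes (a sketch; the token ids, sequence lengths, and the choice of 1 = padding in the mask are assumptions made purely for illustration, consistent with how the mask is used in Attention above):

def smoke_test():
    torch.manual_seed(0)
    batch_size, src_len, tgt_len = 2, 10, 6

    model = PGN(config.vocab_size, config.embed_size, config.hidden_size).to(config.device)
    optimizer = torch.optim.Adam(model.parameters())

    # Random token ids; the last two source positions are treated as padding
    src = torch.randint(3, config.vocab_size, (batch_size, src_len), device=config.device)
    tgt = torch.randint(3, config.vocab_size, (batch_size, tgt_len), device=config.device)
    src_mask = torch.zeros(batch_size, src_len, dtype=torch.long, device=config.device)
    src_mask[:, -2:] = 1  # 1 marks padding positions

    loss = train_step(model, optimizer, src, tgt, src_mask)
    print(f"smoke-test loss: {loss:.4f}")

# smoke_test()  # uncomment to run one training step on random data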
Originality statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
In case of infringement, please contact cloudcommunity@tencent.com for removal.