本文对论文进行复现:Event Causality Extraction with Event Argument Correlations 事件因果识别(ECI)旨在检测两个给定文本事件之间是否存在因果关系,这对于理解事件因果关系至关重要。然而,ECI任务忽略了关键的事件结构和因果关系组件信息,导致在下游应用中存在困难。因此,论文提出了一种新颖的任务,名为事件因果提取(ECE),旨在从纯文本中提取因果事件对及其结构化事件信息。ECE任务更具挑战性,因为每个事件可能包含多个事件参数,需要考虑事件之间的细粒度关联来确定因果事件对。因此,论文提出了一种采用双网格标注方案的方法,以捕捉ECE的事件内部和事件间参数之间的关联。此外,他们设计了一种事件类型增强的模型架构,以实现双网格标注方案。实验证明了该方法的有效性,并进行了广泛的分析,指出了ECE的若干未来研究方向。 本次代码复现在原有代码的基础上已经下载好bert模型参数,直接训练就好
本文所涉及的所有资源的获取方式:这里
事件因果提取(ECE)旨在从纯文本中推导出因果事件对。在这里,一个因果事件对包含一个因果组件和一个结果组件,每个组件表示具有特定事件类型及其事件参数和事件角色的事件。给定一段文本,事件因果提取系统需要预测出其中所有的因果事件对,如图1所示.
该层派生了句子中单词的上下文表示和事件类型。为了方便后续的事件参数预测,我们打算进行事件类型感知的编码,将文本表示与事件类型信息相结合。具体来说,我们将事件类型连接到句子前面,并使用BERT进行编码,这是因为它具有深度自注意力架构。输入序列的组织形式如下所示
每个网格中的每个条目分别模拟了一个标记与一个事件类型之间的关系,用于事件参数推导。对于连接第j个事件类型ej和句子中第i个标记的条目,其表示gji可以通过融合函数得到,通过整合ti和ej的语义来获得。直观地说,可以通过各种语义融合方式实现,包括连接或加法。考虑到相同的事件参数跨度在不同的事件类型中可能扮演不同的角色,因此事件参数的决策应该取决于事件类型。因此,应该表明事件类型和标记之间的条件依赖关系。因此,我们采用了条件层归一化(CLN)来实现。CLN主要基于层归一化,但是它根据先前的条件动态计算增益和偏差,而不是直接将它们作为可学习参数部署在神经网络中。给定事件类型表示ej作为条件和标记表示hi,通过CLN实现融合函数如下:
采用两个语义融合函数,c 和r,分别为因果和效果网格表派生条目表示。每个语义融合函数都由一层CLN实现,因此条目表示为:
2.2.3 训练和推断 由于每个表中的可以同时分配多个标记,我们对条目表示进行多标签分类。具体来说,一个全连接网络预测了每个标记的概率:
在中国知识图谱和语义计算会议2021(CCKS2021)发布的语料库上进行实验。该语料库来自公共新闻和报道,包含7000个句子,平均长度为104个标记。它标注了15,816个事件,其中包含7908个因果事件对,涵盖了39种事件类型和3种事件角色,即产品、地区和行业。为了适应这个语料库的ECE任务,根据因果事件类型将其分成训练/验证/测试集。具体来说,CCKS2021被划分为训练/验证/测试集,比例为8:1:1。将拆分的数据集命名为ECE-CCKS。
step1:安装环境依赖
torch 1.7.1+cu110
transformers 4.5.1
step2:创建名为"log"的目录,并切换到名为"src"的目录中
mkdir log
cd ./src/
step3:训练 python train.py --task_name ece_task --training 1 --debug 0 --hidden_size 768
step4:推理 python train.py --task_name ece_task --training 0 --debug 0 --hidden_size 768 --model_name model_name
# start
class Model(nn.Module):
def __init__(self, args):
super(Model, self).__init__()
self.hidden_size = args.hidden_size
self.bert_embedding = BertModel.from_pretrained(args.bert_path)
self.tokenizer = BertTokenizer.from_pretrained(args.bert_path)
# Grid Representation and Classification for the Cause Table
self.gridmodel = GridModel(hidden_size=self.hidden_size, type_num=int(len(tt_map) // 2), dropout=args.dropout)
# Grid Representation and Classification for the Effect Table
self.gridmodel2 = GridModel(hidden_size=self.hidden_size, type_num=int(len(tt_map) // 2), dropout=args.dropout)
self.dp = nn.Dropout( args.dropout )
self.thresh = args.thresh
def encoding(self, input_ids, segment_ids, input_masks):
out = self.bert_embedding(input_ids=input_ids, attention_mask=input_masks, token_type_ids=segment_ids)
input_embs = out.last_hidden_state
input_embs = self.dp( input_embs )
return input_embs
def obtain_embs(self, input_embs_, label_indexs ):
etype_embs = torch.index_select(input_embs_, dim = 1, index = label_indexs[0])
input_embs = input_embs_[:, -max_seq_len: , :]
return input_embs, etype_embs
def run(self, input_ids, segment_ids, input_maks, label_indexs):
batch_size = input_ids.shape[0]
input_embs_ = self.encoding(input_ids, segment_ids, input_maks)
input_embs, etype_embs = self.obtain_embs(input_embs_, label_indexs)
# input_embs: [batch, seq, dim]
# etype_embs: [batch, type_num, dim]
tt_outputs_1 = self.gridmodel(input_embs, etype_embs)
tt_outputs_2 = self.gridmodel2(input_embs, etype_embs)
# Concat the output of two tables to derive the final loss
tt_outputs = torch.sigmoid(torch.cat([tt_outputs_1, tt_outputs_2], dim = -1))
return tt_outputs
def forward(self, input_ids, segment_ids, input_maks, label_indexs):
tt_outputs = self.run(input_ids, segment_ids, input_maks, label_indexs)
return tt_outputs
def inference(self, text_ids, input_ids, segment_ids, input_maks, label_indexs):
result = {'text_id': text_ids[0], 'result': []}
batch_size = input_ids.shape[0]
tt_outputs_ = self.run(input_ids, segment_ids, input_maks, label_indexs)
tt_outputs = tt_outputs_.squeeze(0).detach().cpu().numpy() # [type_num, seq, type_nums]
# Decode ent
input_ids = input_ids.squeeze(0)[-max_seq_len:]
sent_len = torch.sum(input_maks.squeeze(0)[-max_seq_len:]).item()
heads, tails, iids = np.where(tt_outputs > self.thresh)
ent_dict, event_ent_dict = {}, {}
reason_dict, result_dict = {}, {}
for (etype_id, token_id, iid) in list(zip(heads, tails, iids)):
# etype_id: index along the column, namely event types
# token_id: index along the row, namely the position of token
# iid: index of the predefined tags
etype = etype_id2type[etype_id]
tag_type = tt_id2type[iid]
tag, ent_type, ent_pos = tag_type.split('-')
### Step1: Argument Span Decoding ###
# In the Cause Table
if (tt_map['Rea2Rea-product-H'] <= iid <= tt_map['Rea2Rea-industry-T']) or (tt_map['Rea2Res-product-H'] <= iid <= tt_map['Rea2Res-industry-T']):
if etype not in reason_dict:
reason_dict[etype] = { 'reason':{'product': {'H': [], 'T': []}, 'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}},
'result':{'product': {'H': [], 'T': []}, 'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}}}
if tag == 'Rea2Rea':
reason_dict[etype]['reason'][ent_type][ent_pos].append(token_id)
elif tag == 'Rea2Res':
reason_dict[etype]['result'][ent_type][ent_pos].append(token_id)
# In the Effect Table
elif (tt_map['Res2Res-product-H'] <= iid <= tt_map['Res2Res-industry-T']) or (tt_map['Res2Rea-product-H'] <= iid <= tt_map['Res2Rea-industry-T']):
if etype not in result_dict:
result_dict[etype] = {'reason':{'product': {'H': [], 'T': []}, 'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}},
'result':{'product': {'H': [], 'T': []}, 'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}}}
if tag == 'Res2Rea':
result_dict[etype]['reason'][ent_type][ent_pos].append(token_id)
elif tag == 'Res2Res':
# pdb.set_trace()
result_dict[etype]['result'][ent_type][ent_pos].append(token_id)
reason_ent_dict, result_ent_dict = {}, {}
# In the cause table
for etype in reason_dict:
reason_ent_dict[etype] = {'reason': [], 'result': []}
for tag in reason_dict[etype]: # tag: reason(Intra) / result(Inter)
for key in reason_dict[etype][tag]: # key: product / region / industry
for ent_hid in reason_dict[etype][tag][key]['H']:
ent_tid_list = [ii for ii in reason_dict[etype][tag][key]['T'] if ii >= ent_hid]
if len(ent_tid_list) > 0:
ent_tid = min(ent_tid_list)
if max(ent_hid, ent_tid) < sent_len:
ent_text = "".join( self.tokenizer.convert_ids_to_tokens( input_ids[ent_hid: ent_tid + 1] ) )
reason_ent_dict[etype][tag].append((ent_text, key))
# In the effect table
for etype in result_dict:
result_ent_dict[etype] = {'reason': [], 'result': []}
for tag in result_dict[etype]: # tag: reason(Inter) / result(Intra)
for key in result_dict[etype][tag]: # key: product / region / industry
for ent_hid in result_dict[etype][tag][key]['H']:
ent_tid_list = [ii for ii in result_dict[etype][tag][key]['T'] if ii >= ent_hid]
if len(ent_tid_list) > 0:
ent_tid = min(ent_tid_list)
if max(ent_hid, ent_tid) < sent_len:
ent_text = "".join( self.tokenizer.convert_ids_to_tokens( input_ids[ent_hid: ent_tid + 1] ) )
result_ent_dict[etype][tag].append((ent_text, key))
### Step3: Decode event pair ###
for reason_type in reason_ent_dict:
for result_type in result_ent_dict:
reason_args = [item for item in reason_ent_dict[reason_type]['reason'] if item in result_ent_dict[result_type]['reason']]
result_args = [item for item in result_ent_dict[result_type]['result'] if item in reason_ent_dict[reason_type]['result']]
if max( len(reason_args), len(result_args)) != 0:
rr_pair = {'reason_type': reason_type, 'result_type': result_type,
'reason_product': set(), 'reason_region': set(), 'reason_industry': set(),
'result_product': set(), 'result_region': set(), 'result_industry': set()}
for item in reason_args:
ent_text, ent_type = item[-2], 'reason_' + item[-1]
rr_pair[ent_type].add(ent_text)
for item in result_args:
ent_text, ent_type = item[-2], 'result_' + item[-1]
rr_pair[ent_type].add(ent_text)
for key in ['reason_product', 'reason_region', 'reason_industry', 'result_product', 'result_region', 'result_industry']:
rr_pair[key] = ",".join(list(rr_pair[key])) if len(rr_pair[key]) != 0 else ""
result['result'].append(rr_pair)
return result
编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!