Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >如何根据thucnews中的海量文章数据集训练一个根据文章生成题目的seq2seq模型

如何根据thucnews中的海量文章数据集训练一个根据文章生成题目的seq2seq模型

原创
作者头像
用户1750490
修改于 2020-01-13 04:25:59
修改于 2020-01-13 04:25:59
1.3K00
代码可运行
举报
文章被收录于专栏:钛问题钛问题
运行总次数:0
代码可运行

声明本文代码方案来自苏剑林老师的bert4keras,代码来源链接

https://github.com/bojone/bert4keras

首先安装bert4keras pip install git+https://www.github.com/bojone/bert4keras.git 基于苏剑林老师的bert4keras进行小幅度改动

https://www.github.com/bojone/bert4keras.git

特别感谢腾讯钛提供的免费的32GB显存的机器。希望腾讯钛能一直给我提供机器。对应的我会给腾讯钛写好多好多的技术博客的呦。 下载 thucnews数据集 thucnews文件需要自己申请才可以下载的呦,非商业用途仅为了技术交流哦。 #! -*- coding: utf-8 -*- # albert做Seq2Seq任务,采用UNILM方案

苏剑林老师的原文如下。 # 介绍链接:https://kexue.fm/archives/6933

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from __future__ import print_function

import codecs
import glob
import json
import os

import numpy as np
from tqdm import tqdm

from bert4keras.backend import keras, K
from bert4keras.bert import build_bert_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import DataGenerator
from bert4keras.snippets import parallel_apply, sequence_padding
from bert4keras.tokenizer import Tokenizer, load_vocab

seq2seq_config = 'seq2seq_config.json'
min_count = 64
max_len = 128
batch_size = 16
steps_per_epoch = 1000
epochs = 10000

# bert配置
config_path = 'albert_small_zh_google/albert_config_small_google.json'
checkpoint_path = 'albert_small_zh_google/albert_model.ckpt'
dict_path = 'albert_small_zh_google/vocab.txt'

# 训练样本。THUCNews数据集,每个样本保存为一个txt。
txts = glob.glob('thuctc/THUCNews/*/*.txt')


_token_dict = load_vocab(dict_path)  # 读取词典
_tokenizer = Tokenizer(_token_dict, do_lower_case=True)  # 建立临时分词器

if os.path.exists(seq2seq_config):

    tokens = json.load(open(seq2seq_config))

else:

    def _batch_texts():
        texts = []
        for txt in txts:
            text = codecs.open(txt, encoding='utf-8').read()
            texts.append(text)
            if len(texts) == 100:
                yield texts
                texts = []
        if texts:
            yield texts

    def _tokenize_and_count(texts):
        _tokens = {}
        for text in texts:
            for token in _tokenizer.tokenize(text):
                _tokens[token] = _tokens.get(token, 0) + 1
        return _tokens

    tokens = {}

    def _total_count(result):
        for k, v in result.items():
            tokens[k] = tokens.get(k, 0) + v

    # 10进程来完成词频统计
    parallel_apply(
        func=_tokenize_and_count,
        iterable=tqdm(_batch_texts(), desc=u'构建词汇表中'),
        workers=10,
        max_queue_size=100,
        callback=_total_count,
    )

    tokens = [(i, j) for i, j in tokens.items() if j >= min_count]
    tokens = sorted(tokens, key=lambda t: -t[1])
    tokens = [t[0] for t in tokens]
    json.dump(tokens,
              codecs.open(seq2seq_config, 'w', encoding='utf-8'),
              indent=4,
              ensure_ascii=False)

token_dict, keep_words = {}, []  # keep_words是在bert中保留的字表

for t in ['[PAD]', '[UNK]', '[CLS]', '[SEP]']:
    token_dict[t] = len(token_dict)
    keep_words.append(_token_dict[t])

for t in tokens:
    if t in _token_dict and t not in token_dict:
        token_dict[t] = len(token_dict)
        keep_words.append(_token_dict[t])

tokenizer = Tokenizer(token_dict, do_lower_case=True)  # 建立分词器


class data_generator(DataGenerator):
    """数据生成器
    """
    def __iter__(self, random=False):
        idxs = list(range(len(self.data)))
        if random:
            np.random.shuffle(idxs)
        batch_token_ids, batch_segment_ids = [], []
        for i in idxs:
            txt = self.data[i]
            text = codecs.open(txt, encoding='utf-8').read()
            text = text.split('\n')
            if len(text) > 1:
                title = text[0]
                content = '\n'.join(text[1:])
                token_ids, segment_ids = tokenizer.encode(content,
                                                          title,
                                                          max_length=max_len)
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
            if len(batch_token_ids) == self.batch_size or i == idxs[-1]:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                yield [batch_token_ids, batch_segment_ids], None
                batch_token_ids, batch_segment_ids = [], []


model = build_bert_model(
    config_path,
    checkpoint_path,
    application='seq2seq',
    model='albert',
    keep_words=keep_words,  # 只保留keep_words中的字,精简原字表
)

model.summary()

# 交叉熵作为loss,并mask掉输入部分的预测
y_in = model.input[0][:, 1:]  # 目标tokens
y_mask = model.input[1][:, 1:]
y = model.output[:, :-1]  # 预测tokens,预测与目标错开一位
cross_entropy = K.sparse_categorical_crossentropy(y_in, y)
cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)

model.add_loss(cross_entropy)
model.compile(optimizer=Adam(1e-5))


def gen_sent(s, topk=2, title_max_len=32):
    """beam search解码
    每次只保留topk个最优候选结果;如果topk=1,那么就是贪心搜索
    """
    content_max_len = max_len - title_max_len
    token_ids, segment_ids = tokenizer.encode(s, max_length=content_max_len)
    target_ids = [[] for _ in range(topk)]  # 候选答案id
    target_scores = [0] * topk  # 候选答案分数
    for i in range(title_max_len):  # 强制要求输出不超过title_max_len字
        _target_ids = [token_ids + t for t in target_ids]
        _segment_ids = [segment_ids + [1] * len(t) for t in target_ids]
        _probas = model.predict([_target_ids, _segment_ids
                                 ])[:, -1, 3:]  # 直接忽略[PAD], [UNK], [CLS]
        _log_probas = np.log(_probas + 1e-6)  # 取对数,方便计算
        _topk_arg = _log_probas.argsort(axis=1)[:, -topk:]  # 每一项选出topk
        _candidate_ids, _candidate_scores = [], []
        for j, (ids, sco) in enumerate(zip(target_ids, target_scores)):
            # 预测第一个字的时候,输入的topk事实上都是同一个,
            # 所以只需要看第一个,不需要遍历后面的。
            if i == 0 and j > 0:
                continue
            for k in _topk_arg[j]:
                _candidate_ids.append(ids + [k + 3])
                _candidate_scores.append(sco + _log_probas[j][k])
        _topk_arg = np.argsort(_candidate_scores)[-topk:]  # 从中选出新的topk
        target_ids = [_candidate_ids[k] for k in _topk_arg]
        target_scores = [_candidate_scores[k] for k in _topk_arg]
        best_one = np.argmax(target_scores)
        if target_ids[best_one][-1] == 3:
            return tokenizer.decode(target_ids[best_one])
    # 如果title_max_len字都找不到结束符,直接返回
    return tokenizer.decode(target_ids[np.argmax(target_scores)])


def just_show():
    s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及 时 就医 。'
    s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10' \
         u'余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5' \
         u'亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午 ,华 住集 ' \
         u'团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。 '
    for s in [s1, s2]:
        print(u'生成标题:', gen_sent(s))
    print()


class Evaluate(keras.callbacks.Callback):
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 保存最优
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('best_model.weights')
        # 演示效果
        just_show()


if __name__ == '__main__':

    evaluator = Evaluate()
    train_generator = data_generator(txts, batch_size)

    model.fit_generator(train_generator.forfit(),
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        callbacks=[evaluator])

else:

    model.load_weights('best_model.weights')

文章首发于知乎,欢迎转载。

代码语言:javascript
代码运行次数:0
运行
复制

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
albert做Seq2Seq任务 采用UNILM方案
#! -*- coding: utf-8 -*- # albert做Seq2Seq任务,采用UNILM方案 # 介绍链接:https://kexue.fm/archives/6933 from __future__ import print_function import codecs import glob import json import os import numpy as np from tqdm import tqdm from bert4keras.backend import ke
用户1750490
2020/01/01
1.1K0
利用bert系列预训练模型在非结构化数据抽取数据
本文代码来源苏剑林老师bert4keras example中的例子。 https://github.com/bojone/bert4keras 中文数据中有一个数据是从非结构化文本中找到演艺圈相关实
用户1750490
2020/01/03
2.2K0
使用TensorFlow 2.0的简单BERT
这篇文章展示了使用TensorFlow 2.0的BERT [1]嵌入的简单用法。由于TensorFlow 2.0最近已发布,该模块旨在使用基于高级Keras API的简单易用的模型。在一本很长的NoteBook中描述了BERT的先前用法,该NoteBook实现了电影评论预测。在这篇文章中,将看到一个使用Keras和最新的TensorFlow和TensorFlow Hub模块的简单BERT嵌入生成器。所有代码都可以在Google Colab上找到。
代码医生工作室
2019/11/12
8.5K0
“瘦身成功”的ALBERT,能取代BERT吗?
模型的创新点集中在了预训练过程,采用Masked LM和Next Sentence Prediction两种方法,分别捕捉词语和句子级别的表示。
量子位
2020/03/31
9640
“瘦身成功”的ALBERT,能取代BERT吗?
Bert+seq2seq 周公解梦,看AI如何解析你的梦境?
作者:saiwaiyanyu 链接:https://juejin.im/post/5dd9e07b51882572f00c4523
Ai学习的老章
2019/12/05
7300
用深度学习做命名实体识别(五)-模型使用
注意,在cpu上使用模型的时间大概在2到3秒,而如果项目部署在搭载了支持深度学习的GPU的电脑上,接口的返回会快很多很多,当然不要忘记将tensorflow改为安装tensorflow-gpu。
AI粉嫩特工队
2019/09/23
9070
用深度学习做命名实体识别(五)-模型使用
用深度学习做命名实体识别(五)-模型使用
注意,在cpu上使用模型的时间大概在2到3秒,而如果项目部署在搭载了支持深度学习的GPU的电脑上,接口的返回会快很多很多,当然不要忘记将tensorflow改为安装tensorflow-gpu。
AI粉嫩特工队
2019/09/29
1.3K0
用深度学习做命名实体识别(五)-模型使用
山东算法赛网格事件智能分类topline
基于网格事件数据,对网格中的事件内容进行提取分析,对事件的类别进行划分,具体为根据提供的事件描述,对事件所属政务类型进行划分。
致Great
2022/01/06
5550
山东算法赛网格事件智能分类topline
TensorFlow2学习:RNN生成古诗词
https://blog.csdn.net/aaronjny/article/details/103806954
AI科技大本营
2020/03/24
1.7K0
TensorFlow2学习:RNN生成古诗词
BERT的PyTorch实现
本文主要介绍一下如何使用 PyTorch 复现BERT。请先花上 10 分钟阅读我的这篇文章 BERT详解(附带ELMo、GPT介绍),再来看本文,方能达到醍醐灌顶,事半功倍的效果
mathor
2020/07/27
9200
手把手教你搭建Bert文本分类模型,快点看过来吧!
企业自主填报安全生产隐患,对于将风险消除在事故萌芽阶段具有重要意义。企业在填报隐患时,往往存在不认真填报的情况,“虚报、假报”隐患内容,增大了企业监管的难度。采用大数据手段分析隐患内容,找出不切实履行主体责任的企业,向监管部门进行推送,实现精准执法,能够提高监管手段的有效性,增强企业安全责任意识。
致Great
2021/07/14
9030
手把手教你搭建Bert文本分类模型,快点看过来吧!
广告行业中那些趣事系列:详解BERT中分类器源码
摘要:BERT是近几年NLP领域中具有里程碑意义的存在。因为效果好和应用范围广所以被广泛应用于科学研究和工程项目中。广告系列中前几篇文章有从理论的方面讲过BERT的原理,也有从实战的方面讲过使用BERT构建分类模型。本篇从源码的角度从整体到局部分析BERT模型中分类器部分的源码。
guichen1013
2020/12/08
5010
广告行业中那些趣事系列:详解BERT中分类器源码
基于bert命名实体识别(一)数据处理
要使用官方的tensorflow版本的bert微调进行自己的命名实体识别,需要处理数据成bert相应的格式,主要是在run_classifier.py中,比如说:
西西嘛呦
2020/11/24
1.1K0
基于bert命名实体识别(一)数据处理
BERT详解
BERT(Bidirectional Encoder Representations from Transformers) 是一个语言表示模型(language representation model)。它的主要模型结构是trasnformer的encoder堆叠而成,它其实是一个2阶段的框架,分别是pretraining,以及在各个具体任务上进行finetuning。
Don.huang
2020/09/22
4.8K1
BERT+PET方式模型训练
@小森
2024/06/08
1520
BERT+PET方式模型训练
你还弄不清xxxForCausalLM和xxxForConditionalGeneration吗?
大语言模型目前一发不可收拾,在使用的时候经常会看到transformers库的踪影,其中xxxCausalLM和xxxForConditionalGeneration会经常出现在我们的视野中,接下来我们就来聊聊transformers库中的一些基本任务。
西西嘛呦
2023/04/27
1.5K0
BERT源码分析PART II
BERT的使用可以分为两个步骤:pre-training和fine-tuning。pre-training的话可以很好地适用于自己特定的任务,但是训练成本很高(four days on 4 to 16 Cloud TPUs),对于大对数从业者而言不太好实现从零开始(from scratch)。不过Google已经发布了各种预训练好的模型可供选择,只需要进行对特定任务的Fine-tuning即可。
AINLP
2019/07/23
9360
BERT源码分析PART II
实践演练Pytorch Bert模型转ONNX模型及预测
在之前的文章 《GPU服务器初体验:从零搭建Pytorch GPU开发环境》 中,我通过Github上一个给新闻标题做分类的Bert项目,演示了Pytorch模型训练与预测的过程。我其实也不是机器学习的专业人士,对于模型的结构、训练细节所知有限,但作为后台开发而非算法工程师,我更关注的是模型部署的过程。
果冻虾仁
2022/11/12
3.2K0
实践演练Pytorch Bert模型转ONNX模型及预测
使用Python实现深度学习模型:序列到序列模型(Seq2Seq)
序列到序列(Seq2Seq)模型是一种深度学习模型,广泛应用于机器翻译、文本生成和对话系统等自然语言处理任务。它的核心思想是将一个序列(如一句话)映射到另一个序列。本文将详细介绍 Seq2Seq 模型的原理,并使用 Python 和 TensorFlow/Keras 实现一个简单的 Seq2Seq 模型。
Echo_Wish
2024/06/06
4590
事件因果提取论文复现
本文对论文进行复现:Event Causality Extraction with Event Argument Correlations 事件因果识别(ECI)旨在检测两个给定文本事件之间是否存在因果关系,这对于理解事件因果关系至关重要。然而,ECI任务忽略了关键的事件结构和因果关系组件信息,导致在下游应用中存在困难。因此,论文提出了一种新颖的任务,名为事件因果提取(ECE),旨在从纯文本中提取因果事件对及其结构化事件信息。ECE任务更具挑战性,因为每个事件可能包含多个事件参数,需要考虑事件之间的细粒度关联来确定因果事件对。因此,论文提出了一种采用双网格标注方案的方法,以捕捉ECE的事件内部和事件间参数之间的关联。此外,他们设计了一种事件类型增强的模型架构,以实现双网格标注方案。实验证明了该方法的有效性,并进行了广泛的分析,指出了ECE的若干未来研究方向。 本次代码复现在原有代码的基础上已经下载好bert模型参数,直接训练就好
Srlua
2024/11/27
1900
事件因果提取论文复现
相关推荐
albert做Seq2Seq任务 采用UNILM方案
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档