前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >如何在腾讯钛中训练基于bert预训练语言模型的文本分类模型

如何在腾讯钛中训练基于bert预训练语言模型的文本分类模型

原创
作者头像
用户1750490
修改于 2019-12-10 10:40:44
修改于 2019-12-10 10:40:44
1.5K00
代码可运行
举报
文章被收录于专栏:钛问题钛问题
运行总次数:0
代码可运行
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import codecs
import os

import keras
import numpy as np
import pandas as pd
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
from keras_radam import RAdam
max_len = 96
config_path = 'roberta/bert_config_large.json'
checkpoint_path = 'roberta/roberta_zh_large_model.ckpt'
dict_path = 'roberta/vocab.txt'

token_dict = {}

with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)


class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]')  # space类用未经训练的[unused1]表示
            else:
                R.append('[UNK]')  # 剩余的字符是[UNK]
        return R


tokenizer = OurTokenizer(token_dict)

neg = pd.read_csv('data/enhance_data_result.csv', header=None)

data = []

for d, label in zip(neg[1], neg[2]):
    if label in [2, 0, 1]:
        if isinstance(d, str):
            data.append((d, label))

# 按照9:1的比例划分训练集和验证集
random_order = list(range(len(data)))
np.random.shuffle(random_order)
train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0]
valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0]


def seq_padding(X, padding=0):
    L = [len(x) for x in X]
    ML = max(L)
    return np.array([
        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
    ])


class data_generator:
    def __init__(self, data, batch_size=2):
        self.data = data
        self.batch_size = batch_size
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            np.random.shuffle(idxs)
            X1, X2, Y = [], [], []
            for i in idxs:
                d = self.data[i]
                text = d[0][:max_len]
                x1, x2 = tokenizer.encode(first=text)
                y = d[1]
                X1.append(x1)
                X2.append(x2)
                Y.append([y])
                if len(X1) == self.batch_size or i == idxs[-1]:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    Y = seq_padding(Y)
                    yield [X1, X2], Y
                    [X1, X2, Y] = [], [], []


from keras.layers import *
from keras.models import Model

bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)

for l in bert_model.layers:
    l.trainable = True

x1_in = Input(shape=(None,))
x2_in = Input(shape=(None,))

x = bert_model([x1_in, x2_in])
x = Lambda(lambda x: x[:, 0])(x)
x = Dropout(0.8)(x)
p = Dense(3, activation='softmax')(x)

model = Model([x1_in, x2_in], p)
save = ModelCheckpoint(
    os.path.join('bert.h5'),
    monitor='val_acc',
    verbose=1,
    save_best_only=True,
    mode='auto'
)
early_stopping = EarlyStopping(
    monitor='val_acc',
    min_delta=0,
    patience=8,
    verbose=1,
    mode='auto'
)
callbacks = [save, early_stopping]
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(1e-5),  # 用足够小的学习率
    metrics=['accuracy']
)
model.summary()

train_D = data_generator(train_data)
valid_D = data_generator(valid_data)

model.fit_generator(
    train_D.__iter__(),
    steps_per_epoch=1000,
    epochs=5000,
    validation_data=valid_D.__iter__(),
    validation_steps=1000,
    callbacks=callbacks,

)

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Java常见日期格式及日期的计算工具类
import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Calendar; import java.util.Date; import java.util.GregorianCalendar; import java.util.List;
一头小山猪
2020/06/15
4.4K0
类查看方法
​ 可以采用(int)(Math.random()*n)来获取【0,n)之间的随机整数值
秋落雨微凉
2022/10/25
6880
java 转为Calendar_java Calendar和Date()的转化
public static void main(String args[]){
全栈程序员站长
2022/06/26
9100
Java基础-常用类(二)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
cwl_java
2019/11/12
3580
JAVA零基础小白学习教程之StringBuilder类和包装类.Arrays类.Math类
B友https://www.bilibili.com/video/BV1QG4y1J76q?p=39
张哥编程
2024/12/13
940
Java 小记 - 时间的处理与探究
时间的处理与日期的格式转换几乎是所有应用的基础职能之一,几乎所有的语言都会为其提供基础类库。作为曾经 .NET 的重度使用者,赖其优雅的语法,特别是可扩展方法这个神级特性的存在,我几乎没有特意关注过这些个基础类库,他们如同空气一般,你呼吸着,却不用感受其所在何处。煽情结束,入坑 Java 后甚烦其时间处理方式,在此做个总结与备忘。
捷义
2018/07/18
7140
java中经常使用的日期格式化(全)「建议收藏」
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/115600.html原文链接:https://javaforall.cn
全栈程序员站长
2022/07/10
2K0
Java 相关时间工具类
package com.cavytech.wear2.util; import android.text.TextUtils; import android.text.format.Time; import android.util.Log; import com.cavytech.wear2.entity.BandSleepStepBean; import com.cavytech.wear2.entity.GetSleepBean; import java.text.DateFormat;
先知先觉
2019/01/21
1.5K0
Java 日期时间处理
java.util.Date对象表示一个精确到毫秒的瞬间; 但由于Date从JDK1.0起就开始存在了,历史悠久,而且功能强大(既包含日期,也包含时间),所以他的大部分构造器/方法都已Deprecated,因此就不再推荐使用(如果贸然使用的话,可能会出现性能/安全方面的问题);下面我仅介绍它还剩下的为数不多的几个方法(这些方法的共同点是Date与毫秒值的转换):
哲洛不闹
2018/09/19
3.5K0
Java 日期时间处理
Java 时间类-Calendar、Date、LocalDate/LocalTime
1、Date 类 java.util.Date是一个“万能接口”,它包含日期、时间,还有毫秒数,如果你只想用java.util.Date存储日期,或者只存储时间,那么,只有你知道哪些部分的数据是有用
九灵
2018/03/09
2K0
Java 时间类-Calendar、Date、LocalDate/LocalTime
java获取当前日期和时间(各种方法对比)
System.currentTimeMillis()产生一个当前的毫秒,这个毫秒其实就是自1970年1月1日0时起的毫秒数,类型为long; Date:
ha_lydms
2023/08/09
3.1K0
java获取当前日期和时间(各种方法对比)
深入理解Java常用类-----时间日期
本文主要介绍了Java中的日期和时间操作,包括使用Date、Calendar、DateFormat、Time以及线程安全的Calendar等类。同时,还介绍了Joda-Time库,该库提供了更先进的日期和时间操作类。通过这些技术,我们可以方便地进行日期和时间的读取、计算、操作和格式化等操作。
Single
2018/01/04
1.3K0
深入理解Java常用类-----时间日期
java calendar 设置小时_Java Calendar类的时间操作[通俗易懂]
Java Calendar 类时间操作,这也许是创建日历和管理最简单的一个方案,示范代码很简单,演示了获取时间,日期时间的累加和累减,以及比较。
全栈程序员站长
2022/08/12
1.6K0
java calendar 设置小时_Java Calendar类的时间操作[通俗易懂]
Java日期格式化
文章目录 1. 日期格式化 1.1. 前言 1.2. Date 1.2.1. 构造方法 1.2.2. 常用的方法 1.2.3. 实例 1.3. SimpleDateFormat 1.3.1. 构造方法 1.3.2. 常用的方法 1.3.3. 常用的日期格式化的模板 1.3.4. 实例 1.4. Calendar 1.4.1. 创建对象 1.4.2. 常用方法 1.4.3. 实例 日期格式化 前言 更多文章请看本人博客https://chenjiabing666.github.io/ 版权所有,如需转
爱撒谎的男孩
2019/12/31
2.7K0
java获取当前年月日时间戳_现在的年月日怎么来的
两种方法,通过Date类或者通过Calendar类。Date类比较简单,但是要得到细致的字段的话Calendar类比较方便。
全栈程序员站长
2022/10/29
1.7K0
java获取当前年月日时间戳_现在的年月日怎么来的
Java日期计算常用方法《详细版》
🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文并茂🦖生动形象🐅简单易学!欢迎大家来踩踩~🌺 🌊 《IDEA开发秘籍专栏》 🐾 学会IDEA常用操作,工作效率翻倍~💐 🌊 《100天精通Golang(基础入门篇)》 🐅 学会Golang语言,畅玩云原生,走遍大小厂~
猫头虎
2024/04/07
2510
LocalDateTime、Date时间工具类
参考:Date、LocalTime、LocalDate、LocalDate-时间操作工具类_Hatsune_Miku_的博客-CSDN博客
CBeann
2023/12/25
2750
java 日期格式化工具类
六月的雨在Tencent
2024/03/28
1090
Java Date 和 Calendar 实例
当前日期:  2012-03-07 2012-03-07 12:30:11 2012-3-7 12:30:11.101 计算周:  -3 3/5/12 12:30 PM 3/10/12 12:30 PM 3/12/12 12:30 PM 3/3/12 12:30 PM 计算月:  2012-03-01 2012-03-31 2012-02-01 2012-02-29 2012-04-01 2012-04-30 计算年:  2012-01-01 2012-12-31 2011-01-01 2011-12-31 2013-01-01 2013-12-31 366 in 2012 计算季度:  2012-3-7 in [ 2012-1 : 2012-3 ] 31 in [ 2012-3-7 ] true 日期格式转换与计算:  Wed Jun 20 00:00:00 CST 2012 Wednesday 2012-06-02 -> 2012-06-12间隔天数:10
阳光岛主
2019/02/19
3.1K0
Java 时间格式化(java中如何格式化一个日期)
1、通过MessageFormat转化 String dateTime = MessageFormat.format(“{0,date,yyyy-MM-dd-HH-mm:ss:ms}” , new Object[] { new java.sql.Date(System.currentTimeMillis()) }); 说明: yyyy-MM-dd-HH-mm:ss:ms 年yyyy 月MM 日dd 时(大写为24进制,小写为12进制) 分mm 秒ss 微妙ms
全栈程序员站长
2022/08/01
6.5K0
Java 时间格式化(java中如何格式化一个日期)
相关推荐
Java常见日期格式及日期的计算工具类
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验