前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)

[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)

作者头像
Michael阿明
发布2021-02-19 14:57:15
发布2021-02-19 14:57:15
74700
代码可运行
举报
运行总次数:0
代码可运行

文章目录

练习地址:https://www.kaggle.com/c/ds100fa19

相关博文

[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(spacy)

[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(BERT)

1. 读入数据

  • 读取数据,test集没有标签
代码语言:javascript
代码运行次数:0
复制
import pandas as pd
import numpy as np
train = pd.read_csv("train.csv")
test = pd.read_csv("test.csv")
train.head()
  • 数据有无效的单元
代码语言:javascript
代码运行次数:0
复制
print(np.sum(np.array(train.isnull()==True), axis=0))
print(np.sum(np.array(test.isnull()==True), axis=0))

存在 Na 单元格

代码语言:javascript
代码运行次数:0
复制
[0 6 0 0]
[0 1 0]
  • fillna 填充处理
代码语言:javascript
代码运行次数:0
复制
train = train.fillna(" ")
test = test.fillna(" ")
print(np.sum(np.array(train.isnull()==True), axis=0))
print(np.sum(np.array(test.isnull()==True), axis=0))

填充完成,显示 sum = 0

代码语言:javascript
代码运行次数:0
复制
[0 0 0 0]
[0 0 0]
  • y 标签 只有 0 不是垃圾邮件, 1 是垃圾邮件
代码语言:javascript
代码运行次数:0
复制
print(train['spam'].unique())
[0 1]

2. 文本处理

  • 邮件内容和主题合并为一个特征
代码语言:javascript
代码运行次数:0
复制
X_train = train['subject'] + ' ' + train['email']
y_train = train['spam']
X_test = test['subject'] + ' ' + test['email']
  • 文本转成 tokens ids 序列
代码语言:javascript
代码运行次数:0
复制
from keras.preprocessing.text import Tokenizer
max_words = 300
tokenizer = Tokenizer(num_words=max_words, lower=True, split=' ')
# 只给频率最高的300个词分配 id,其他的忽略
tokenizer.fit_on_texts(list(X_train)+list(X_test)) # tokenizer 训练
X_train_tokens = tokenizer.texts_to_sequences(X_train)
X_test_tokens = tokenizer.texts_to_sequences(X_test)
  • pad ids 序列,使之长度一样
代码语言:javascript
代码运行次数:0
复制
# 样本 tokens 的长度不一样,pad
maxlen = 100
from keras.preprocessing import sequence
X_train_tokens_pad = sequence.pad_sequences(X_train_tokens, maxlen=maxlen,padding='post')
X_test_tokens_pad = sequence.pad_sequences(X_test_tokens, maxlen=maxlen,padding='post')

3. 建模

代码语言:javascript
代码运行次数:0
复制
embeddings_dim = 30 # 词嵌入向量维度
from keras.models import Model, Sequential
from keras.layers import Embedding, LSTM, GRU, SimpleRNN, Dense
model = Sequential()
model.add(Embedding(input_dim=max_words, # Size of the vocabulary
                    output_dim=embeddings_dim, # 词嵌入的维度
                    input_length=maxlen))
model.add(GRU(units=64)) # 可以改为 SimpleRNN , LSTM
model.add(Dense(units=1, activation='sigmoid'))
model.summary()

模型结构:

代码语言:javascript
代码运行次数:0
复制
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, 100, 30)           9000      
_________________________________________________________________
gru (GRU)                    (None, 64)                18432     
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 65        
=================================================================
Total params: 27,497
Trainable params: 27,497
Non-trainable params: 0
_________________________________________________________________

4. 训练

代码语言:javascript
代码运行次数:0
复制
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy']) # 配置模型
history = model.fit(X_train_tokens_pad, y_train,
                    batch_size=128, epochs=10, validation_split=0.2)
model.save("email_cat_lstm.h5") # 保存训练好的模型
  • 绘制训练曲线
代码语言:javascript
代码运行次数:0
复制
from matplotlib import pyplot as plt
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.show()

5. 测试

代码语言:javascript
代码运行次数:0
复制
pred_prob = model.predict(X_test_tokens_pad).squeeze()
pred_class = np.asarray(pred_prob > 0.5).astype(np.int32)
id = test['id']
output = pd.DataFrame({'id':id, 'Class': pred_class})
output.to_csv("submission_gru.csv",  index=False)
  • 3种RNN模型对比:
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/12/12 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章目录
  • 1. 读入数据
  • 2. 文本处理
  • 3. 建模
  • 4. 训练
  • 5. 测试
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档