首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何将一个热编码存储为一个对象?

如何将一个热编码存储为一个对象?
EN

Stack Overflow用户
提问于 2019-09-23 16:33:58
回答 1查看 442关注 0票数 0

首先,介绍一下我的模型的架构背景。

对我的keras模型的输入相当简单:

  • 范畴变量A
  • 范畴变量B
  • 数值输入C,在0,1范围内。

该模型具有一个单一的输出:

  • A 0,1上的数字

在培训模型时,我的输入数据是来自使用pd.read_sql()的SQL数据库的数据。我用以下函数对分类变量A和B(分别在dataframe col1col2original_data中)进行了一次热编码:

代码语言:javascript
运行
复制
from keras import utils as np_utils

def preprocess_categorical_features(self):
        col1 = np_utils.to_categorical(np.copy(self.original_data.CURRENT_RTIF.values))
        col2 = np_utils.to_categorical(np.copy(self.original_data.NEXT_RTIF.values))
        cat_input_data = np.append(col1,col2,axis=1)
        return cat_input_data

稍后,当我需要从这个模型进行预测时,输入数据来自RabbitMQ以字典形式提供的实时提要。这个RabbitMQ数据必须由它自己的(不同的) reprocess_categorical_features()函数处理。

这就引出了我的问题:无论是对数据库的数据进行预处理,还是rabbitMQ提要,我如何确保一次热编码完全相同?

应用于数据库数据的A的一次热编码:

代码语言:javascript
运行
复制
|---------------------|------------------|
|          A          | One-Hot-Encoding |
|---------------------|------------------|
|       "coconut"     |      <0,1,0,0>   |
|---------------------|------------------|
|       "apple"       |      <1,0,0,0>   |
|---------------------|------------------|
|       "quince"      |      <0,0,0,1>   |
|---------------------|------------------|
|       "plum"        |      <0,1,0,0>   |
|---------------------|------------------|

应用于RabbitMQ数据的A的一次热编码(它们必须相同):

代码语言:javascript
运行
复制
|---------------------|------------------|
|          A          | One-Hot-Encoding |
|---------------------|------------------|
|       "coconut"     |      <0,1,0,0>   |
|---------------------|------------------|
|       "apple"       |      <1,0,0,0>   |
|---------------------|------------------|
|       "quince"      |      <0,0,0,1>   |
|---------------------|------------------|
|       "plum"        |      <0,1,0,0>   |
|---------------------|------------------|

是否有办法将编码保存为数据、numpy ndarray或字典,以便将编码从预处理训练数据的函数传递到预处理输入数据的函数?对于OHE,我愿意使用Keras以外的其他库,但我想知道是否有一种方法可以使用我目前正在使用的keras‘范畴化函数来完成这一任务。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-09-24 17:14:59

我没有依赖keras的utils.to_categorical方法,而是决定使用sklearn.preprocessing.OneHotEncoder。这允许我在处理培训数据时声明一个单一热编码器对象self.encoder

代码语言:javascript
运行
复制
class TrainingData:
    def preprocess_categorical_features(self):
        # declare OneHotEncoder object to save for later
        self.encoder = OneHotEncoder(sparse=False)

        # fit encoder to data
        self.encoder.fit(self.original_data.CURRENT_RTIF.values.reshape(-1,1))

        # perform one-hot-encoding on columns 1 and 2 of the training data
        col1 = self.encoder.transform(self.original_data.CURRENT_RTIF.values.reshape(-1,1))
        col2 = self.encoder.transform(self.original_data.NEXT_RTIF.values.reshape(-1,1))

        # return on-hot-encoded data as a numpy ndarray
        cat_input_data = np.append(col1,col2,axis=1)
        return cat_input_data

稍后,我可以重用该编码器(通过将它作为参数传递,training_data_ohe_encoder)到处理最终作出预测所需的输入数据的方法。

代码语言:javascript
运行
复制
class LiveData:
    def preprocess_categorical_features(self, training_data_ohe_encoder):
        # notice the training_data_ohe_encoder parameter; this is the 
        # encoder attribute from the Training Data Class.

        # one-hot-encode the live data using the training_data_ohe_encoder encoder
        col1 = training_data_ohe_encoder.transform(np.copy(self.preprocessed_data.CURRENT_RTIF.values).reshape(-1, 1))
        col2 = training_data_ohe_encoder.transform(np.copy(self.preprocessed_data.NEXT_RTIF.values).reshape(-1, 1))

        # return on-hot-encoded data as a numpy ndarray
        cat_input_data = np.append(col1,col2,axis=1)
        return cat_input_data
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58066673

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档