假设,我们在TensorFlow流中拟合了一个模型
model.fit(
train_generator,
epochs=epochs,
verbose=1,
steps_per_epoch=steps_per_epoch,
validation_data=valid_generator,
validation_steps=val_steps_per_epoch).history
在下一步中,我们将生成预测。
Y_pred = model.predict_generator(valid_generator, np.ceil(valid_generator.samples / valid_generator.batch_size))
我想知道是否有可能保存预测并将其从磁盘加载,以便调试后续代码,而无需在每次重新启动后每次都对数据进行再培训和预测。
当然,保存和加载模型是可能的,但在预测方面仍然存在一些开销。
任何想法都会受到高度赞赏。提前感谢
发布于 2021-05-02 11:45:05
根据我从注释框中的理解,这里有一些可能的解决方案用于您的查询,请告诉我它是否适合您。
,我想知道是否有可能保存预测并将其从磁盘加载,以便调试后续代码,而无需在每次重新启动后每次都对模型和数据进行预测。
首先,我们建立一个模型,并首先对其进行培训。
import tensorflow as tf
# Model
input = tf.keras.Input(shape=(28, 28))
base_maps = tf.keras.layers.Flatten(input_shape=(28, 28))(input)
base_maps = tf.keras.layers.Dense(128, activation='relu')(base_maps)
base_maps = tf.keras.layers.Dense(units=10, activation='softmax', name='primary')(base_maps)
model = tf.keras.Model(inputs=[input], outputs=[base_maps])
# compile
model.compile(
loss = tf.keras.losses.CategoricalCrossentropy(),
metrics = ['accuracy'],
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3) )
# data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = tf.divide(x_train, 255)
y_train = tf.one_hot(y_train , depth=10)
# customized fit
model.fit(x_train, y_train, batch_size=512, epochs=3, verbose = 1)
接下来,我们使用这个经过训练的模型来预测未见数据(x_test
),并将预测保存到磁盘中,以便以后可以调试模型性能问题。
import numpy as np
import pandas as pd
y_pred = model.predict(x_test) # get prediction
y_pred = np.argmax(y_pred, axis=-1) # get class labels
# save ground truth and prediction to local disk as CSV file
oof = pd.DataFrame(dict(
gt = y_test,
pd = y_pred,
))
oof.to_csv('oof.csv', index=False)
oof.head(20)
# compute how many prediction are accurate or match
oof['check'] = np.where((oof['gt'] == oof['pd']), 'Match', 'No Match')
oof.check.value_counts()
Match 9492
No Match 508
Name: check, dtype: int64
像这样,我们可以做各种类型的分析,从模型预测和地面真相。但是,为了节省概率(而不是实际的标签),我们也可以这样做:reference。
y_pred = model.predict(x_test)
np.savetxt("y_pred.csv", y_pred , delimiter=",")
https://stackoverflow.com/questions/67354697
复制相似问题