Google Cloud ML Engine 是一个强大的云服务,用于训练、部署和管理机器学习模型。TensorFlow 是一个流行的开源机器学习库。在 Google Cloud ML Engine 中使用 TensorFlow 时,input_fn()
是一个关键函数,它负责准备数据以供模型训练或预测。
在 input_fn()
中执行预处理和标记化(tokenization)是很常见的,因为这样可以确保数据在送入模型之前已经被适当地处理。以下是一个简单的例子,展示了如何在 input_fn()
中执行这些操作:
确保你已经安装了 TensorFlow 和其他必要的库。
pip install tensorflow
input_fn()
以下是一个简单的 input_fn()
示例,它执行文本数据的预处理和标记化:
import tensorflow as tf
import numpy as np
def input_fn(data_file, batch_size, num_epochs, shuffle):
"""Input function for training and evaluation.
Args:
data_file: File path to the CSV file containing the data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
shuffle: Boolean, whether to shuffle the data.
Returns:
A tuple (features, labels) where features is a dictionary of input features,
and labels is the target tensor.
"""
# Load and preprocess the data
def parse_csv(value):
columns = tf.io.decode_csv(value, record_defaults=[[0]] * 3)
features = {'text': columns[0]}
labels = columns[1:]
return features, labels
# Read the CSV file
dataset = tf.data.TextLineDataset(data_file)
if shuffle:
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.map(parse_csv, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Tokenization and preprocessing
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=10000, oov_token='<OOV>')
def tokenize_text(features, labels):
text = features['text']
text = tf.strings.lower(text) # Convert to lowercase
text = tf.strings.regex_replace(text, '[%s]' % re.escape(string.punctuation), '') # Remove punctuation
sequences = tokenizer.texts_to_sequences([text.numpy()[0]])[0] # Tokenize
padded = tf.keras.preprocessing.sequence.pad_sequences([sequences], maxlen=100) # Pad sequences
features['text'] = padded
return features, labels
dataset = dataset.map(tokenize_text, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.padded_batch(batch_size, padded_shapes=({'text': [None]}, [None]))
dataset = dataset.repeat(num_epochs)
return dataset
tf.data.experimental.AUTOTUNE
来自动调整并行处理的线程数。要在 Google Cloud ML Engine 中使用此 input_fn()
,你需要将其集成到你的 TensorFlow 估计器中,并确保你的 model_fn()
正确处理输入特征。
领取专属 10元无门槛券
手把手带您无忧上云