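If the loss does not converge when you train an image classifier based on MobileNetV2, there can be several causes. The points below cover the most common ones.

First, preprocess the input images with MobileNetV2's own preprocess_input, which scales pixel values from [0, 255] to [-1, 1], the range used when the ImageNet weights were trained: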
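import numpy as np
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

image = ...  # Your image data, e.g. a NumPy array of RGB pixel values in [0, 255]
image = preprocess_input(image)  # Scales pixel values to [-1, 1]
# Do not also call preprocess_input here if the same images later go through an
# ImageDataGenerator with preprocessing_function=preprocess_input (that scales them twice)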
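Feeding the network raw pixels, or pixels scaled twice, is an easy way to get a loss that never drops. The NumPy sketch below reproduces the x / 127.5 - 1 scaling that MobileNetV2's preprocess_input applies, so you can see what a correct input range looks like and what double preprocessing does to it. Both function names are hypothetical helpers for checking data, not part of Keras; in the training code itself, keep using preprocess_input.

```python
import numpy as np


def mobilenet_v2_scale(x):
    """Sketch of MobileNetV2's input scaling: [0, 255] -> [-1, 1].

    For illustration only; real code should call
    tensorflow.keras.applications.mobilenet_v2.preprocess_input.
    """
    x = np.asarray(x, dtype=np.float32)
    return x / 127.5 - 1.0


def looks_preprocessed(x, tol=1e-6):
    """Rough heuristic (hypothetical helper): True if every value lies in [-1, 1].

    Note that images already scaled to [0, 1] would also pass this check.
    """
    x = np.asarray(x)
    return bool(x.min() >= -1.0 - tol and x.max() <= 1.0 + tol)


raw = np.array([0, 127.5, 255], dtype=np.float32)
once = mobilenet_v2_scale(raw)    # [-1, 0, 1]: the range the pretrained weights expect
twice = mobilenet_v2_scale(once)  # about [-1.008, -1, -0.992]: everything collapses near -1
print(once, looks_preprocessed(raw), looks_preprocessed(once))
```

Running such a check on one batch (for example `looks_preprocessed(next(iter(generator))[0])`) before a long training run is cheap and rules out the scaling mistakes above.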
Next, use tf.keras.preprocessing.image.ImageDataGenerator for data augmentation:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
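train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,  # MobileNetV2 scaling, applied to every image
    rotation_range=20,         # random rotations of up to 20 degrees
    width_shift_range=0.2,     # horizontal shifts of up to 20% of the width
    height_shift_range=0.2,    # vertical shifts of up to 20% of the height
    shear_range=0.2,           # shear angle in degrees (counter-clockwise)
    zoom_range=0.2,            # random zoom factor in [0.8, 1.2]
    horizontal_flip=True,      # random left-right flips
    fill_mode='nearest'        # fill pixels exposed by a transform with the nearest edge values
)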
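Very aggressive augmentation can make the task harder to learn, so if the loss stays high it can help to check what the augmented images actually look like. To make two of the settings above concrete, here is a NumPy sketch of horizontal_flip=True and width_shift_range=0.2 with fill_mode='nearest'. It is a hypothetical illustration of the semantics, not Keras's implementation (Keras applies a combined affine transform that also covers rotation, shear and zoom).

```python
import numpy as np


def shift_width_nearest(img, shift):
    """Shift an (H, W, C) image horizontally by `shift` pixels (positive = right),
    filling the exposed columns with the nearest edge column, like fill_mode='nearest'."""
    if shift == 0:
        return img.copy()
    pad = abs(shift)
    # Repeat the edge columns `pad` times on both sides, then cut out a W-wide window
    padded = np.pad(img, ((0, 0), (pad, pad), (0, 0)), mode='edge')
    start = pad - shift
    return padded[:, start:start + img.shape[1], :]


def random_flip_and_shift(img, rng, width_shift_range=0.2):
    """Hypothetical sketch: random left-right flip, then a random horizontal shift
    of up to `width_shift_range` times the image width."""
    if rng.random() < 0.5:
        img = img[:, ::-1, :]  # horizontal flip
    max_shift = width_shift_range * img.shape[1]
    shift = int(round(rng.uniform(-max_shift, max_shift)))
    return shift_width_nearest(img, shift)


rng = np.random.default_rng(0)
demo = np.arange(5, dtype=np.float32).reshape(1, 5, 1)  # a single row: 0 1 2 3 4
print(shift_width_nearest(demo, 2)[0, :, 0])   # [0. 0. 0. 1. 2.]
print(random_flip_and_shift(demo, rng).shape)  # (1, 5, 1): the image size is preserved
```

For a 224-pixel-wide input, width_shift_range=0.2 allows shifts of up to about 45 pixels, which is worth keeping in mind if the objects in your images are small or close to the border.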
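Then add a classification head on top of the pretrained MobileNetV2 backbone:
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

num_classes = 10  # Hypothetical placeholder: set this to the number of classes in your dataset

# MobileNetV2 backbone with ImageNet weights, without the original 1000-class classifier
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Custom classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# categorical_crossentropy expects one-hot labels; use sparse_categorical_crossentropy for integer labels
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])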
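Two quick sanity checks follow from this compile call. First, categorical_crossentropy needs one-hot labels (Keras provides tf.keras.utils.to_categorical for this). Second, for a K-class softmax classifier whose predictions are still roughly uniform, the cross-entropy is ln(K), so the loss at the start of training should be in that neighborhood. The NumPy helpers below are hypothetical sketches that make both points concrete; they are not the Keras implementations.

```python
import numpy as np


def to_one_hot(labels, num_classes):
    """Integer class ids of shape (N,) -> one-hot matrix of shape (N, num_classes)."""
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.size, num_classes), dtype=np.float32)
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy between one-hot targets and predicted class probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))


num_classes = 5
y = to_one_hot([0, 3, 1, 4], num_classes)
# An untrained head predicts roughly uniform probabilities over the classes
uniform = np.full((4, num_classes), 1.0 / num_classes)
print(categorical_crossentropy(y, uniform), np.log(num_classes))  # both about 1.609
```

If the first-epoch loss is far above ln(K), or stays pinned at ln(K) for many epochs, check the label encoding and the input scaling before tuning anything else.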
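Train the model with the fit method (fit_generator is deprecated in TensorFlow 2.x, and fit accepts generators directly). If you use data augmentation, make sure it is actually applied during training, i.e. pass the augmenting generator to fit instead of the raw arrays:
batch_size = 32
epochs = 50

# x_train, y_train, x_val, y_val: your own data; labels must be one-hot for categorical_crossentropy
history = model.fit(
    train_datagen.flow(x_train, y_train, batch_size=batch_size),  # augmented, preprocessed batches
    steps_per_epoch=len(x_train) // batch_size,
    epochs=epochs,
    # The generator applies preprocess_input to the training images; assuming x_val holds raw
    # [0, 255] pixels like x_train, the validation images need the same preprocessing
    validation_data=(preprocess_input(x_val.astype('float32')), y_val)
)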
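Two details of this call are worth checking by hand. steps_per_epoch=len(x_train) // batch_size uses floor division, so the last partial batch is skipped each epoch; since flow shuffles by default this is usually harmless. And history.history holds per-epoch values such as 'loss' and 'val_loss', which you can inspect to confirm whether the loss has really stopped improving. The helper names below are hypothetical, and the loss values in the example are made up for illustration.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_remainder=True):
    """Number of batches per epoch; floor division skips the final partial batch."""
    if drop_remainder:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def loss_has_plateaued(losses, window=5, min_delta=1e-3):
    """Hypothetical helper: True if the best loss in the last `window` epochs
    improved on the best earlier loss by less than `min_delta`."""
    if len(losses) <= window:
        return False  # not enough epochs to judge
    best_before = min(losses[:-window])
    best_recent = min(losses[-window:])
    return best_before - best_recent < min_delta


# With 1000 training images and batch_size = 32:
print(steps_per_epoch(1000, 32), steps_per_epoch(1000, 32, drop_remainder=False))  # 31 32

# In practice, pass history.history['loss'] (or 'val_loss') returned by model.fit
flat = [2.30, 2.31, 2.30, 2.30, 2.31, 2.30, 2.30]
print(loss_has_plateaued(flat))  # True
```

On a 10-class problem, a loss stuck near 2.30 (about ln(10)) means the model is still at chance level. Keras's built-in tf.keras.callbacks.ReduceLROnPlateau and EarlyStopping monitor the same kind of signal automatically during training.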
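Finally, make sure the backbone is loaded with ImageNet pretrained weights (weights='imagenet'); a randomly initialized MobileNetV2 (weights=None) is usually much harder to train to convergence on a small dataset:
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))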