前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >TensorFlow 2.0 - tf.distribute 分布式训练

TensorFlow 2.0 - tf.distribute 分布式训练

作者头像
Michael阿明
发布2021-02-19 14:20:00
发布2021-02-19 14:20:00
40300
代码可运行
举报
运行总次数:0
代码可运行

文章目录

学习于:简单粗暴 TensorFlow 2

1. 单机多卡 MirroredStrategy

代码语言:javascript
代码运行次数:0
复制
# 分布式训练
import tensorflow as tf
import tensorflow_datasets as tfds

# 1 单机多卡 MirroredStrategy

strategy = tf.distribute.MirroredStrategy()
# 指定设备
strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0'])
# ------------------------------------------------
num_epochs = 5
batch_size_per_replica = 64
learning_rate = 1e-4

# 定义策略
strategy = tf.distribute.MirroredStrategy()

print("设备数量:{}".format(strategy.num_replicas_in_sync))
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync


def resize(img, label):  # 处理图片
    img = tf.image.resize(img, [224, 224]) / 255.0
    return img, label


# 载入猫狗分类数据集
dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)

# 使用策略
with strategy.scope():
	# 模型构建代码放入 with 
    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )

model.fit(dataset, epochs=num_epochs)

2. 多机训练 MultiWorkerMirroredStrategy

  • 相比上面,多了以下配置
  • 'task': {'type': 'worker', 'index': 0} 每台机器 index 不一样
代码语言:javascript
代码运行次数:0
复制
num_workers = 2
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:20000", "localhost:20001"]
    },
    'task': {'type': 'worker', 'index': 0} # 每台机器的 index 不同
})

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
batch_size = batch_size_per_replica * num_workers

3. TPU 张量处理单元

可以在 Colab 上运行

代码语言:javascript
代码运行次数:0
复制
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/02/03 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章目录
  • 1. 单机多卡 MirroredStrategy
  • 2. 多机训练 MultiWorkerMirroredStrategy
  • 3. TPU 张量处理单元
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档