前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【tensorflow2.0】使用TPU训练模型

【tensorflow2.0】使用TPU训练模型

作者头像
西西嘛呦
发布于 2020-08-26 02:55:16
发布于 2020-08-26 02:55:16
1.2K00
代码可运行
举报
运行总次数:0
代码可运行

如果想尝试使用Google Colab上的TPU来训练模型,也是非常方便,仅需添加6行代码。

在Colab笔记本中:修改->笔记本设置->硬件加速器 中选择 TPU

注:以下代码只能在Colab 上才能正确执行。

可通过以下colab链接测试效果《tf_TPU》:

https://colab.research.google.com/drive/1XCIhATyE1R7lq6uwFlYlRsUr5d9_-r1s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import * 

一,准备数据

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
MAX_LEN = 300
BATCH_SIZE = 32
(x_train,y_train),(x_test,y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train,maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test,maxlen=MAX_LEN)
 
MAX_WORDS = x_train.max()+1
CAT_NUM = y_train.max()+1
 
ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()
 
ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()

二,定义模型

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tf.keras.backend.clear_session()
def create_model():
 
    model = models.Sequential()
 
    model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
    model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM,activation = "softmax"))
    return(model)
 
def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                loss=losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[metrics.SparseCategoricalAccuracy(),metrics.SparseTopKCategoricalAccuracy(5)]) 
    return(model)

三,训练模型

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 增加以下6行代码
import os
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
    model = create_model()
    model.summary()
    model = compile_model(model)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
INFO:tensorflow:Initializing the TPU system: grpc://10.62.22.122:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.62.22.122:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 300, 7)            216874    
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176      
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0         
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0         
_________________________________________________________________
dense (Dense)                (None, 46)                107502    
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
history = model.fit(ds_train,validation_data = ds_test,epochs = 10)

前面的都没问题,最后运行上面这句话时colab崩溃了,colab自动重启,不知道是什么原因,下面是原书中的结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Train for 281 steps, validate for 71 steps
Epoch 1/10
281/281 [==============================] - 12s 43ms/step - loss: 3.4466 - sparse_categorical_accuracy: 0.4332 - sparse_top_k_categorical_accuracy: 0.7180 - val_loss: 3.3179 - val_sparse_categorical_accuracy: 0.5352 - val_sparse_top_k_categorical_accuracy: 0.7195
Epoch 2/10
281/281 [==============================] - 6s 20ms/step - loss: 3.3251 - sparse_categorical_accuracy: 0.5405 - sparse_top_k_categorical_accuracy: 0.7302 - val_loss: 3.3082 - val_sparse_categorical_accuracy: 0.5463 - val_sparse_top_k_categorical_accuracy: 0.7235
Epoch 3/10
281/281 [==============================] - 6s 20ms/step - loss: 3.2961 - sparse_categorical_accuracy: 0.5729 - sparse_top_k_categorical_accuracy: 0.7280 - val_loss: 3.3026 - val_sparse_categorical_accuracy: 0.5499 - val_sparse_top_k_categorical_accuracy: 0.7217
Epoch 4/10
281/281 [==============================] - 5s 19ms/step - loss: 3.2751 - sparse_categorical_accuracy: 0.5924 - sparse_top_k_categorical_accuracy: 0.7276 - val_loss: 3.2957 - val_sparse_categorical_accuracy: 0.5543 - val_sparse_top_k_categorical_accuracy: 0.7217
Epoch 5/10
281/281 [==============================] - 5s 19ms/step - loss: 3.2655 - sparse_categorical_accuracy: 0.6008 - sparse_top_k_categorical_accuracy: 0.7290 - val_loss: 3.3022 - val_sparse_categorical_accuracy: 0.5490 - val_sparse_top_k_categorical_accuracy: 0.7231
Epoch 6/10
281/281 [==============================] - 5s 19ms/step - loss: 3.2616 - sparse_categorical_accuracy: 0.6041 - sparse_top_k_categorical_accuracy: 0.7295 - val_loss: 3.3015 - val_sparse_categorical_accuracy: 0.5503 - val_sparse_top_k_categorical_accuracy: 0.7235
Epoch 7/10
281/281 [==============================] - 6s 21ms/step - loss: 3.2595 - sparse_categorical_accuracy: 0.6059 - sparse_top_k_categorical_accuracy: 0.7322 - val_loss: 3.3064 - val_sparse_categorical_accuracy: 0.5454 - val_sparse_top_k_categorical_accuracy: 0.7266
Epoch 8/10
281/281 [==============================] - 6s 21ms/step - loss: 3.2591 - sparse_categorical_accuracy: 0.6063 - sparse_top_k_categorical_accuracy: 0.7327 - val_loss: 3.3025 - val_sparse_categorical_accuracy: 0.5481 - val_sparse_top_k_categorical_accuracy: 0.7231
Epoch 9/10
281/281 [==============================] - 5s 19ms/step - loss: 3.2588 - sparse_categorical_accuracy: 0.6062 - sparse_top_k_categorical_accuracy: 0.7332 - val_loss: 3.2992 - val_sparse_categorical_accuracy: 0.5521 - val_sparse_top_k_categorical_accuracy: 0.7257
Epoch 10/10
281/281 [==============================] - 5s 18ms/step - loss: 3.2577 - sparse_categorical_accuracy: 0.6073 - sparse_top_k_categorical_accuracy: 0.7363 - val_loss: 3.2981 - val_sparse_categorical_accuracy: 0.5516 - val_sparse_top_k_categorical_accuracy: 0.7306
CPU times: user 18.9 s, sys: 3.86 s, total: 22.7 s
Wall time: 1min 1s

参考:

开源电子书地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/

GitHub 项目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-04-13 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
API架构风格的深度解析与选择策略:SOAP、REST、GraphQL与RPC
API作为系统间通信的桥梁,其设计风格也在持续发展和完善。SOAP、REST、GraphQL和RPC作为四种主流的API架构风格,各自具有鲜明的特点和适用场景。
公众号:码到三十五
2024/12/09
2080
API架构风格的深度解析与选择策略:SOAP、REST、GraphQL与RPC
API 架构风格抉择:SOAP、REST、GraphQL 和 RPC 的特性、优势与局限
两个独立的应用程序需要一个中介来相互通信。因此,开发人员通常会构建桥梁——应用程序编程接口 (API) ——以允许一个系统访问另一个系统的信息或功能。
架构精进之路
2025/04/16
1470
API 架构风格抉择:SOAP、REST、GraphQL 和 RPC 的特性、优势与局限
安息吧 REST API,GraphQL 长存
即使与 REST API 打交道这么多年,当我第一次了解到 GraphQL 和它试图解决的问题时,我还是禁不住把本文的标题发在了 Twitter 上。
疯狂的技术宅
2019/03/27
2.8K0
安息吧 REST API,GraphQL 长存
GraphQL与传统API对比介绍教程
在现代应用程序开发中,API(应用程序接口)扮演着至关重要的角色。随着技术的发展,API的实现方式也在不断进化。本文将介绍两种常见的API实现方式:传统API(主要是REST)和GraphQL,并对它们进行对比分析。
IT蜗壳-Tango
2024/06/22
2760
与我一起学习微服务架构设计模式8—外部API模式
Web应用在防火墙内部运行,它们通过高带宽、低延迟的局域网访问服务。其他客户端在防火墙之外运行,通过较低带宽、较高延迟的互联网或移动网路访问。
java达人
2019/12/06
1.4K0
新一代数据查询语言GraphQL来啦!
1. GraphQL来啦! 当Facebook构建移动应用的时候,它需要的是一个强大的数据获取API: 足够强大,满足Facebook自身复杂业务的需求; 足够简单,对开发者和使用者来说很容易上手与使
IMWeb前端团队
2017/12/29
3K1
新一代数据查询语言GraphQL来啦!
理解 GraphQL:现代 API 查询语言的详解与实践
GraphQL 是一种用于 API 的查询语言,以及一个用于执行查询的服务器端运行时。它允许客户端精确地请求所需的数据,避免冗余和不足。GraphQL 由 Facebook 于 2012 年开发,并在 2015 年开源。
编程小妖女
2024/12/28
1560
理解 GraphQL:现代 API 查询语言的详解与实践
4种主流的API架构风格对比
本文讨论了四种主要的 API 架构风格,比较它们的优缺点,并重点介绍每种情况下最适合的 API 架构风格。
深度学习与Python
2021/01/21
2.4K0
标准化技术下的软件开发
聊到集成测试、单元测试等测试分类,我想大多数人都有类似困惑或讨论,集成测试和 E2E 测试到底有啥区别。甚至还有一些系统测试、配置项测试等概念,不但让我们这种非 QA 专业的人弄不清楚,在和我们的 QA 同学讨论时也很难得到清晰的结论。
ThoughtWorks
2019/09/29
9640
标准化技术下的软件开发
深入解析 RESTful API:从设计到实践的完整指南
在当今的互联网世界中,不同系统之间的数据交互和通信是构建现代应用的核心需求。无论是移动应用、Web 平台,还是微服务架构,RESTful API 都扮演着桥梁的角色。它以其简洁性、灵活性和可扩展性,成为开发者构建分布式系统的首选方案。本文将从基础概念到实际应用,一步步拆解 RESTful API 的设计与实现,助你掌握这一关键技术。
DevKevin
2025/02/16
2880
GraphQL与OpenAPI:数据治理的优缺点
一位财富 50 强公司的 CTO 评估了 OpenAPI 和 GraphQL API 标准的优缺点,以及它们与数据治理的相关性。
云云众生s
2024/08/13
1850
API接口架构REST vs GraphQL
无论是创建网站,还是移动应用程序,我们都需要通过 API 来传递数据,通过 API 我们可以获取到数据库中的数据,可以操作数据库,可以处理一些业务逻辑。现在最流行的 API 架构是 REST。但是,GraphQL 正在逐渐追赶着它。
程序那些事儿
2023/03/07
1.7K0
API接口架构REST vs GraphQL
常见形式 Web API 的简单分类总结
请求--响应类的API的典型做法是,通过基于HTTP的Web服务器暴露一个/套接口。API定义一些端点,客户端发送数据的请求到这些端点,Web服务器处理这些请求,然后返回响应。响应的格式通常是JSON或XML。
solenovex
2018/10/15
3.2K0
【RESTful】RESTful API 接口设计规范 | 示例
参考官方文档:https://tools.ietf.org/html/rfc2616
前端修罗场
2023/10/07
1.8K0
【RESTful】RESTful API 接口设计规范 | 示例
RESTful架构API风格与相关规范 极客开发者
以上是对RESTful架构的概述,在本文中,我将使用自己的理解完整的表述RESTful的规范,以及如何设计符合RESTful规范的API。实际上,在对计算机技术的理解中,一百个人可能会有一百种理解方式,尽管见仁见智,但我们的目的都是把技术当作工具,去实现我们的程序功能。如果在本文中的描述有所错误,或您有所不解,欢迎留言评论!
极客开发者
2022/01/18
4220
人人都是 API 设计者:我对 RESTful API、GraphQL、RPC API 的思考
有一段时间没怎么写文章了,今天提笔写一篇自己对 API 设计的思考。首先,为什么写这个话题呢?其一,我阅读了《阿里研究员谷朴:API 设计最佳实践的思考》一文后受益良多,前两天并转载了这篇文章也引发了广大读者的兴趣,我觉得我应该把我自己的思考整理成文与大家一起分享与碰撞。其二,我觉得我针对这个话题,可以半个小时之内搞定,争取在 1 点前关灯睡觉,哈哈。
用户2781897
2019/05/17
1.1K0
架构师该如何为应用选择合适的API
架构师的主要活动是做出正确的技术决策。选择合适的API是一项重要的技术决策。那么今天就看看API的选择问题。
yuanyi928
2020/06/17
1.7K0
最流行六种的 API 架构风格(附 Node.js DEMO)
本篇将介绍六种最流行的 API 架构风格,分别是 SOAP、RESTful、GraphQL、gRPC、WebSocket 和 Webhook。对于每种 API 架构风格,我们将深入探讨其优点、缺点以及适用场景,并提供相应的 DEMO 以帮助读者更好地理解每种 API 架构的实现方法和运作原理。
Cellinlab
2023/06/01
2.2K0
最流行六种的 API 架构风格(附 Node.js DEMO)
GraphQL 初体验,Node.js 构建 GraphQL API 指南
过去几年中,GraphQL 已经成为一种非常流行的 API 规范,该规范专注于使客户端(无论是客户端、前端还是第三方)的数据获取更加容易。
coder_koala
2021/01/08
8.4K1
【REST架构】OData、JsonAPI、GraphQL 有什么区别?
我在职业生涯中使用过很多 OData,现在我来自不同团队的同事中很少有人建议我们迁移到 JsonAPI 和 GraphQL,因为它与 Microsoft 无关。我对这两种查询语言都没有太多经验。据我所知,OData 是 Salesforce、IBM、Microsoft 使用的标准,并且非常成熟。为什么要切换到 JsonAPI 和/或 GraphQL?有真正的好处吗?JsonAPI 和 GraphQL 是新标准吗?根据受欢迎程度更改公共 api 实现似乎没有用,尤其是在没有太大好处的情况下。
架构师研究会
2022/05/29
1.7K0
推荐阅读
相关推荐
API架构风格的深度解析与选择策略:SOAP、REST、GraphQL与RPC
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档