前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Numpy简述神经网络模型权重搜索原理-Pytorch引文

Numpy简述神经网络模型权重搜索原理-Pytorch引文

作者头像
一个有趣的灵魂W
发布2023-10-06 17:05:06
1650
发布2023-10-06 17:05:06
举报

Tensorflow的bug太多了,我只能转投Pytorch的怀抱

01

最近Tensorflow(下称TF)已死的言论不知道大家是否接收到:

放弃支持Windows GPU、bug多,TensorFlow被吐槽:2.0后慢慢死去 https://zhuanlan.zhihu.com/p/656241342

主要是谷歌放弃了在Windows上对TF的支持。对普通开发者而言,顶层信息其实并没有太大的波澜,随波逐流就是。

但是,如果我们嗅到一丝丝警觉而不管不顾的话,早晚要被抛弃!

所以,Pytorch(下称torch)还不得不信手拈来。同时,让我们顺带复习一下基本的求导、前馈、权重、Loss等词汇在深度学习里是怎么运作的吧:

正文开始:

学习torch之前,容我们思考一下,深度学习模型的学习思维和逻辑过程。假如,面对我们的是一个线性模型:Y=wX。那我们最关键的是学习(训练、调整)权重w的值。

02

以下代码能让我们直观的感受到w的粗略学习过程:

代码语言:javascript
复制
import numpy as np
w_list=[]
mse_list=[]
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1 # Random value
def forward(x):
    return x*w
def loss(x, y):
    y_pred = forward(x)
    return (y_pred-y)*(y_pred-y)
for w in np.arange(0.0,4.1,0.1):
    print("w=", w)
    l_sum=0
    for x_val, y_val in zip (x_data, y_data):
        y_pred_val = forward(x_val)
        l = loss(x_val, y_val)
        l_sum+=l
        print("\t", x_val, y_val, y_pred_val, l)
        
    print("MSE=", l_sum/3)
    w_list.append(w)
    mse_list.append(l_sum/3)



import matplotlib.pyplot as plt
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

上述内容就是一个w的变化过程。从原始数据中我们可以简单判断出,w应该等于2。权重不断的在改变中经过了2,但并没有停止的意思。因此我们的模型并不能给我们最终结果为2。

03

由此,我们需要优化:

优化的过程需要涉及到求导,导数为0的时候就是我们线性函数的最优解(暂时)。

代码语言:javascript
复制
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

w_list = []
mse_list=[]
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0

# Function for forward pass to predict y
def forward(x):
    return x*w
def loss(x,y):
    y_pred = forward(x)
    return (y_pred-y)**2
def gradient(x,y):
    return 2*x*(x*w-y)

# Training loop

print('Predict (before training)', 4, forward(4))

# Training loop

for epoch in range(100):
    l_sum=0
    for x_val, y_val in zip(x_data, y_data):
        grad = gradient(x_val, y_val)
        w = w-0.01*grad
        print('\tgrad: ', x_val, y_val, grad)
        l=loss(x_val, y_val)
        l_sum+=l
        
    print('Progress: ', epoch, 'w=', w, 'loss=', l)
    w_list.append(w)
    mse_list.append(l_sum/3)
    
    
print('Predict (After training)', '4 hours', forward(4))    

plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

那现在,我们就能顺利得到循环n次后的最优解w=2。

这就是这个学习过程的基本思路,但它其实并不需要涉及到torch,这是因为我们目前还没涉及到自动微分的过程。

04

torch其实就是集成了许多核心运算形式,方便我们调用。这点TF其实也是一样的。只不过在使用过程中,许多开发者发现TF版本兼容性较差,动不动就因为版本原因产生bug。解决bug的成本太高了,所以许多人才转投torch等其他开源框架。

下期,我们将重点描述torch的基本入门操作。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-10-03 14:02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 一个有趣的灵魂W 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档