博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Task6.PyTorch理解更多神经网络优化方法
阅读量:5072 次
发布时间:2019-06-12

本文共 3520 字,大约阅读时间需要 11 分钟。

1.了解不同优化器

2.书写优化器代码

3.Momentum
4.二维优化,随机梯度下降法进行优化实现
5.Ada自适应梯度调节法
6.RMSProp
7.Adam
8.PyTorch种优化器选择

梯度下降法:

1.标准梯度下降法:GD

每个样本都下降一次,参考当前位置的最陡方向迈进容易得到局部最优,且训练速度慢

2.批量下降法:BGD

不再是一次输入样本调整一次,而是一批量数据后进行调整,模型参数的调整更新与全部输入样本的代价函数的和有关,即下山前掌握附近地势,选择最优方向。

3.随机梯度下降法SGD

在一批数据里随机选取一个样本。如盲人下山,并与用走一次计算一次梯度,总能到山底。但引入的噪声可能使得权值更新放下错误。,没法单独克服局部最优解。

动量优化法

标准动量优化momentum
当前权值的改变会受到上一次权值改变得影响。类似小球下滚得时候带上惯性,加快滚动速度。

NAG牛顿加速梯度

NAG牛顿加速梯度 施加当前速度后 ,往标准动量中添加一个校正因子。momentun小球盲目跟从梯度,但nag小球指走到坡底时速度慢下来,知道下一位置大致在哪,来更新当前位置参数。

Ada自适应梯度调节法: Adagrad:该算法的特点是自动调整学习率,适用于稀疏数据。梯度下降法在每一步对每一个参数使用相同的学习率,这种一刀切的做法不能有效的利用每一个数据集自身的特点。 Adadelta(Adagrad的改进算法):Adagrad的一个问题在于随着训练的进行,学习率快速单调衰减。Adadelta则使用梯度平方的移动平均来取代全部历史平方和。

RMSProp:RMSprop也是一种学习率调整的算法。Adagrad会累加之前所有的梯度平方,而RMSprop仅仅是计算对应的平均值,因此可缓解Adagrad算法学习率下降较快的问题。

Adam:如果把Adadelta里面梯度的平方和看成是梯度的二阶矩,那么梯度本身的求和就是一阶矩。Adam算法在Adadelta的二次矩基础之上又引入了一阶矩。而一阶矩,其实就类似于动量法里面的动量。

1 import torch 2 import torch.utils.data as Data 3 import torch.nn.functional as F 4 import matplotlib.pyplot as plt 5  6 LR = 0.01 7 BATCH_SIZE = 32 8 EPOCH = 12 9 10 x = torch.unsqueeze(torch.linspace(-1,1,1000),dim=1)11 y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))12 13 plt.scatter(x.numpy(),y.numpy())14 plt.show()15 16 torch_dataset = Data.TensorDataset(x,y)17 loader = Data.DataLoader(dataset=torch_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2)18 19 torch_dataset = Data.TensorDataset(x,y)20 loader = Data.DataLoader(21     dataset=torch_dataset,22     batch_size=BATCH_SIZE,23     shuffle=True,24     num_workers=2,25 )26 27 class Net(torch.nn.Module):28     def __init__(self):29         super(Net,self).__init__()30         31         self.hidden = torch.nn.Linear(1,20)32         self.predict = torch.nn.Linear(20,1)33         34     def forward(self,x):35         x = F.relu(self.hidden(x))36         x = self.predict(x)37         return x38     39 net_SGD         = Net()40 net_Momentum    = Net()41 net_RMSprop     = Net()42 net_Adam        = Net()43 nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]44 45 # different optimizers46 opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)47 opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)48 opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)49 opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))50 optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]51 52 loss_func = torch.nn.MSELoss()53 losses_his = [[], [], [], []]   # record loss54 55 # training56 for epoch in range(EPOCH):57     print('Epoch: ', epoch)58     for step, (b_x, b_y) in enumerate(loader):          # for each training step59         for net, opt, l_his in zip(nets, optimizers, losses_his):60             output = net(b_x)              # get output for every net61             loss = loss_func(output, b_y)  # compute loss for every net62             opt.zero_grad()                # clear gradients for next train63             loss.backward()                # backpropagation, compute gradients64             opt.step()                     # apply gradients65             l_his.append(loss.data.numpy())     # loss recoder66 67 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']68 for i, l_his in enumerate(losses_his):69     plt.plot(l_his, label=labels[i])70 plt.legend(loc='best')71 plt.xlabel('Steps')72 plt.ylabel('Loss')73 plt.ylim((0, 0.2))74 plt.show()

 

参考:https://blog.csdn.net/qingxuanmingye/article/details/90514018

转载于:https://www.cnblogs.com/NPC-assange/p/11373495.html

你可能感兴趣的文章
算法训练 Torry的困惑(基本型)
查看>>
SSM框架构建多模块之业务拆分实践
查看>>
lombok问题
查看>>
算法图解之散列表
查看>>
Redis 命令
查看>>
Random获得的随机数怎么样减少重复率
查看>>
C++模板之Vector与STL初探
查看>>
生成器模式(建造者模式)
查看>>
ros中删除某个包之后用apt安装的包找不到
查看>>
分享几个可用的rtsp, http測试url
查看>>
Hadoop - YARN 启动流程
查看>>
(八十六)使用系统自带的分享框架Social.framework
查看>>
gitlab wiki 500
查看>>
sql 执行顺序
查看>>
C和C++实务精选丛书
查看>>
强制 类型转换
查看>>
PWN_3 ORW
查看>>
Android快速开发不可或缺的11个工具类
查看>>
【原】docker部署单节点consul
查看>>
样式化复选框(Styling Checkbox)
查看>>