机器学习优化器和SGD和SGDM实验对比(编程实现SGD和SGDM)

news/2025/2/12 5:22:48/

机器学习优化器和SGD和SGDM实验对比

博主最近在学习优化器,于是呢,就做了一个SGD和SGDM的实验对比,可谓是不做不知道,一做吓一跳,这两个算法最终对结果的影响还是挺大的,在实验中SGDM明星要比SGD效果好太多了,SGD很容易陷入局部最优,而且非常容易发生梯度爆炸的情况,而SGDM做的实验当中还未出现这些情况。
在这次实验中博主发现了很多很多的特点对于SGDM和SGDM,下面博主列出这次实验的收获。
(1)SGDM相比SGD当拥有同样的学习率,SGDM更不容易发生梯度爆炸,SGD对于学习率的要求很高,大了,就会梯度爆炸,小了迭代特别慢。
(2)在本次此实验中,我们可以发现,小批量梯度下降比单个样本进行梯度下降区别极为大,单个样本做梯度下降时,特别容易发生梯度爆炸,模型不易收敛。
(3)SGDM相比SGD,loss下降曲线更加平稳,也不易陷入局部最优,但是他的训练较慢,可以说是非常慢了。
(4)超参数的设置对于这两个模型的影响都是很大的,要小心处理。
(5)数据集对于模型迭代也有很大影响,注意要对数据集进行适当的处理。
(6)随着训练轮次的增多,SGDM相比SGD更有可能取得更好的效果。

下面让我们看一看代码:

#coding=gbkimport torch
from torch.autograd import Variable
from torch.utils import data
import matplotlib.pyplot as pltX =torch.randn(100,4)
w=torch.tensor([1,2,3,4])Y =torch.matmul(X, w.type(dtype=torch.float))  + torch.normal(0, 0.1, (100, ))+6.5
print(Y)
Y=Y.reshape((-1, 1))#将X,Y转成200 batch大小,1维度的数据def loss_function(w,x,y,choice,b):if choice==1:return torch.abs(torch.sum(w@x)+b-y)else:#    print("fdasf:",torch.sum(w@x),y)#   print(torch.pow(torch.sum(w@x)-y,2))return torch.pow(torch.sum(w@x)-y+b,2)
index=0batch=32
learning_rating=0.03def SGDM(batch,step,beta,grad_s,grad_b_s):if step==0:grad=Variable(torch.tensor([0.0]),requires_grad=True)grad_b=Variable(torch.tensor([0.0]),requires_grad=True)loss=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index]+b)*(-1)*X[index]+gradgrad_b=(torch.sum(w@X[index])-Y[index]+b)*(-1)+grad_b#   print(loss_function(w,X[index],Y[index],2,b))loss=loss_function(w,X[index],Y[index],2,b)+lossindex=index+1except:index=0return  grad/batch,loss/batch,grad_b/batchelse:grad=Variable(torch.tensor([0.0]),requires_grad=True)grad_b=Variable(torch.tensor([0.0]),requires_grad=True)loss=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index]+b)*(-1)*X[index]+gradgrad_b=(torch.sum(w@X[index])-Y[index]+b)*(-1)+grad_bloss=loss_function(w,X[index],Y[index],2,b)+lossindex=index+1except:index=0return  (beta*grad_s+(1-beta)*grad)/batch,loss/batch,(beta*grad_b_s+(1-beta)*grad_b)/batchdef train(n):loss_list=[]setp=0global grad,grad_bgrad=0grad_b=0while n:n=n-1grad,loss,grad_b=SGDM(batch,setp,0.99,grad,grad_b)setp=setp+1# print(grad,loss,grad_b)w.data=w.data+learning_rating*grad*w.datab.data=b.data+learning_rating*grad_b# print("b",b)#print("grad_b",grad_b)#print("w:",w)#print("loss:",loss)#print("b:",b)loss_list.append(float(loss))#  b.data=b.data-(lear#  ning_rating*b.grad.data)#   print("b",b)print("w:",w)print("b:",b)print("loss:",loss)return loss_listdef SGD(batch):grad=Variable(torch.tensor([0.0]),requires_grad=True)grad_b=Variable(torch.tensor([0.0]),requires_grad=True)loss=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index]+b)*(-1)*X[index]+gradgrad_b=(torch.sum(w@X[index])-Y[index]+b)*(-1)+grad_bloss=loss_function(w,X[index],Y[index],2,b)+lossindex=index+1except:index=0return  grad/batch,loss/batch,grad_b/batchdef train_s(n):loss_list=[]while n:if n//100==0:print(n)n=n-1grad,loss,grad_b=SGD(batch)# print(grad,loss,grad_b)w.data=w.data+learning_rating*grad*w.datab.data=b.data+learning_rating*grad_b# print("b",b)#print("w:",w)#print("loss:",loss)#print("b:",b)#  b.data=b.data-(learning_rating*b.grad.data)#   print("b",b)loss_list.append(float(loss))print("w:",w)print("b:",b)print("loss:",loss)return loss_listw=torch.tensor([1,1.0,1,1])b=torch.tensor([1.0])
w=Variable(w,requires_grad=True)b=Variable(b,requires_grad=True)epoch=10000
epoch_list=list(range(1,epoch+1))loss_list=train(epoch)plt.plot(epoch_list,loss_list,label='SGDM')#SGD
w=torch.tensor([1,1.0,1,1])b=torch.tensor([1.0])
w=Variable(w,requires_grad=True)b=Variable(b,requires_grad=True)
print(w)epoch_list=list(range(1,epoch+1))loss_list=train_s(epoch)plt.plot(epoch_list,loss_list,label='SGD')
plt.legend()
plt.show()

下面是一张跑出的实验图,事实上,我做了很多很多的实验,这是一件十分有趣的事情,在实验中,你可以看到这些优化器的特点,这很有趣,当然前提是这个优化器是你自己完全编程写的。
在这里插入图片描述


http://www.ppmy.cn/news/659212.html

相关文章

2022-2027年中国智能炒菜机行业发展监测及投资战略研究报告

【报告格式】电子版、纸介版 【出品单位】华经产业研究院 本报告由华经产业研究院出品,对中国智能炒菜机行业的发展现状、竞争格局及市场供需形势进行了具体分析,并从行业的政策环境、经济环境、社会环境及技术环境等方面分析行业面临的机遇及挑战。还…

高性能电工电子电拖及自动化技术实训与考核装置

ZN-800D高性能电工电子电拖及自动化技术实训与考核装置 一、概述: ZN-800D高性能电工电子电拖及自动化技术实训与考核装置是吸取国内外同类产品合理的实训方法,先进、科学的实训手段并加以消化、整合、提高而研制的针对高等院校、职业技术教育而开发的综合性实训考核装置。其…

1253. 重构 2 行二进制矩阵(力扣)

1253. 重构 2 行二进制矩阵(力扣) 题目第一种方式分析测试代码运行结果 第二种方式测试代码运行结果 题目 给你一个 2 行 n 列的二进制数组: 矩阵是一个二进制矩阵,这意味着矩阵中的每个元素不是 0 就是 1。 第 0 行的元素之和为…

浅谈基于分项计量的校园能源监管平台解决方案设计

张心志 关注acrelzxz 安科瑞电气股份有限公司 上海嘉定 201801 摘要:伴随着我国经济的飞速发展,国家机关办公建筑和大型公共建筑高耗能的问题日益突出,如何解决建筑能耗己成为一个国家总能耗的重要组成部分。学校是肩负着教育、科研和社会服…

windows store 下载软件出现错误代码: 0x80D03805

windows store 下载软件出现错误代码: 0x80D03805 8265 8260的网卡驱动安装后在Microsoft store会出现错误。 把网卡驱动卸载后 在重启 就可以更新了 不过更新好后 在把驱动装上。。。。 以下是原话 The latest ProSet/Driver package of v20.80.1 with driver v20.70.2.1 fo…

thinkpad 安装XP全过程

相信很多人对微软的Vista系统都很厌烦,我先前并不已为然。自从这次使用Vista系统后深刻体会到了这点。因此我将原有系统换成XP系统。本人记录下该过程以备不时之需:步骤:1. 按F1,进入到BISS,选择Config下的Serial ATA(SATA),将ahic模式修改成compatibili…

谈IBM ThinkPad随机软件的介绍说明

用IBM的笔记本的朋友非常的多,大多数人都喜欢删除原带的系统,自己来按装新的操作系统,可是重装系统后对于IBM随机带的软件可以说是一头雾水,因为 软件大多都为英文,而且数量又多,没办法许多朋友就用全部按装法,把只要是随机带的全部装到机器上,可是有一部分是完全没有必要装的,…

没有鼠标怎么还原计算机,系统恢复选项鼠标键盘不能用,鼠标没有右键功能

鼠标是操作电脑的必备工具,很多人为了更好的使用鼠标,都会对鼠标速度、指针移动速度等进行调节,一般可以通过控制面板中的鼠标选项来进行相关设置,然而有win7系统用户发现打开鼠标选项的时候,弹出一个窗口,…