13李沐动手学深度学习v2/权重衰退从0开始实现
·
总结
- 什么东西需要计算梯度,
requires_grad=True - 什么时候开始计算梯度,
with torch.enable_grad(): - 正则项(惩罚项):λ\lambdaλ越大,www选择范围越小,降低模型复杂度,避免过拟合
- 一般lambda=1e-3=0.001,lambda不会寻到1等等大的值
# 权重衰退是广泛应用的正则化技术
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
人工数据集
y=0.05+∑i=1d0.01xi+ϵ,whereϵ∼η(0,0.012),ϵ是偏差y=0.05+\sum\limits^d_{i=1}0.01x_i+\epsilon, where \epsilon\sim \eta(0,0.01^2), \epsilon是偏差y=0.05+i=1∑d0.01xi+ϵ,whereϵ∼η(0,0.012),ϵ是偏差
# 数据
n_train,n_test,num_inputs,batch_size=20,100,200,5
# true_b=0.05
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
# 20个训练数据样本,模型容量大(num_inputs=100)+数据量小=容易发生过拟合
train_data=d2l.synthetic_data(true_w,true_b,n_train)
train_iter=d2l.load_array(train_data,batch_size)
# 5个测试数据样本,实际上是验证集
test_data=d2l.synthetic_data(true_w,true_b,n_test)
test_iter=d2l.load_array(test_data,batch_size,is_train=False)
# 参数初始化
def init_params():
# !requires_grad=True作用,需要对这个参数w计算梯度
w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)
# requires_grad=True作用,需要对这个参数w计算梯度
b=torch.zeros(1,requires_grad=True)
return [w,b]
# 惩罚项
def l2_penalty(w):
return torch.sum(w.pow(2))/2
# 训练
def train(lambd):
w,b=init_params()
# 模型和损失函数
net,loss=lambda X:d2l.linreg(X,w,b),d2l.squared_loss
# 超参数
num_epochs,lr=100,0.003
# 展示 xlabel, x轴代表什么;ylabel,y轴代表什么;yscale,y轴缩放类型;xlim,x轴范围限制;legend,铭文,图例
animator=d2l.Animator(xlabel='epochs',ylabel='loss',
yscale='log',
xlim=[5, num_epochs],
legend=['train', 'test'])
# 开始训练
for epoch in range(num_epochs):
for X,y in train_iter:
# 上下文管理器
# 先调用with后面的`troch.enable_grad()`的`__enter__()`方法,执行完with内部再调用troch.enable_grad()`的` __exit__()`
with torch.enable_grad():
# 增加了L2范数惩罚项。广播机制,w被复制batchs_size次
l=loss(net(X),y)+lambd*l2_penalty(w)
# sum()不影响梯度,因为是梯度是求偏导
# !书写损失函数时就需要梯度,后向传播之前就关闭梯度
l.sum().backward()
d2l.sgd([w,b],lr,batch_size)
# 运行完5个epoch之后展示1次
if(epoch+1)%5==0:
animator.add(epoch + 1,
(d2l.evaluate_loss(net, train_iter, loss),
d2l.evaluate_loss(net, test_iter, loss)))
# 均方损失
print('w的L2范数:',torch.norm(w).item())
# 无正则
train(lambd=0)
# 有正则项且lambda=3
train(lambd=3)
# 有正则项且lambda=10
train(lambd=10)
w的L2范数: 0.02175123244524002
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Zv80bQYp-1662210118936)(output_6_1.svg)]](https://i-blog.csdnimg.cn/blog_migrate/62b1d2f5b8ca85fbae1e05f4a3ad40fc.png)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ndb2YX0H-1662210118939)(output_6_2.svg)]](https://i-blog.csdnimg.cn/blog_migrate/081994279b826b7785836fc754e8132e.png)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g6UcJPbB-1662210118940)(output_6_3.svg)]](https://i-blog.csdnimg.cn/blog_migrate/525969ebefb4adea636d624b74d62368.png)
更多推荐
所有评论(0)