上节课,我们修改了Dataset类,数据可以正常加载进来了,接下来,还需要微调一下模型,然后就可以开始训练了。

代码示例

1、squeeze()导致的报错

max_pool2d 之后,输出 tensor.shape = [10, 256, 1, 1],之后的 squeeze(),正常情况下会压缩 tensor 的最后两维,变成 [10, 256]。但如果是我们的训练数据是101条,batch_size = 10,最后一次就只剩下一条数据,max_pool2d 输出 tensor.shape = [1, 256, 1, 1],再经过 squeeze(),会变成 [256],后续程序会报错,所以我们需要改为只压缩最后两维。

内容不可见,请联系管理员开通权限。

2、模型输出值修改

一般分类场景,直接输出 linear 层的结果,也是可以的,在做交叉熵损失计算的时候,模型会自动转化。但是在这个项目里面,我们把 linear 层返回的结果,包一层 sigmiod,返回值的范围就在0-1之间,相当于是一个概率值,后面有别的用处。

内容不可见,请联系管理员开通权限。

3、添加验证模块

内容不可见,请联系管理员开通权限。

4、评估函数修改

batch比较小时,标签可能出现空缺,比如[2, 5],classification_report 无法判断总的标签数量。

内容不可见,请联系管理员开通权限。

5、边训练边验证

内容不可见,请联系管理员开通权限。

现在,模型的训练流程就可以跑通了,接下来,还是老办法,把代码和数据,传到 Kaggle 上训练,这个过程已经讲过很多次了,就不在重复演示,大家课后自己完成。

本文链接:http://www.ichenhua.cn/edu/note/626

版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!