x= self.seq(x)

假设经过这一步,输出了需要分类的特征。

 head1 = nn.Linear(linear_size, n_classes)
 head2 = nn.Linear(linear_size, n_classes)
 .........

然后,就可以做分类了:

x = self.seq(x)
head1_out = head1(x)
head2_out = head2(x)
......

问题:

  1. 那可不可以同时复用head1的而不用new这么多head?
    若是你的head层的权重是不被调整更新的那可以,(但应该把它过滤掉,设置为不随训练更新,参考transformer的位置编码)

  2. head多了,那就一个个手动改代码么?
    可以做成pipeline

# head_num 为头的数量,这个示例是 linear_size 和 n_classes是相同的,你做个列表,就可以自动提取不相同了。
nn.ModuleList([nn.Linear(linear_size3, n_classes) for _ in range(head_num)])
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐