在学习《动手学深度学习》时,实现下面代码时,报出raise NotImplementedError错误。



import collections
import torch
from d2l import torch as d2l
import math
from torch import nn

class Seq2SeqEncoder(d2l.Encoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super(Seq2SeqEncoder,self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)
    def forward(self, X, *args):
        X = self.embedding(X)
        X = X.permute(1,0,2)
        output,state = self.rnn(X)
        return output,state

encoder = Seq2SeqEncoder(10,8,16,2)
encoder.eval()
X = torch.zeros((4,7),dtype=torch.long)
output,state = encoder(X)
print(output.shape)


class Seq2SeqDecoder(d2l.Decoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super(Seq2SeqDecoder,self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)
        self.dense = nn.Linear(num_hiddens,vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def farward(self,X,state):
        X = self.embedding(X).permute(1,0,2)
        context = state[-1].repeat(X.shape[0],1,1)
        X_and_context = torch.cat((X,context),2)
        output,state = self.rnn(X_and_context,state)
        output = self.dense(output).permute(1,0,2)
        return output,state

decoder = Seq2SeqDecoder(10,8,16,2)
print(decoder.eval())
state = decoder.init_state(encoder(X))
output,state = decoder(X,state)
print(output.shape)

在这里插入图片描述
原因是类Seq2SeqDecoder在继承d2l.Decoder类时,需要重写父类的方法,而我把forward写成了farward。因此,出现了报错。

在深度学习中,子类继承父类时,需要重写父类的方法。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐