0
respostas

[Projeto] questão é verificar se você entendeu a diferença entre usar: RNNCell/GRU e usarLSTM

Ao substituir a GRU por um LSTM, é necessário alterar a definição da camada recorrente e também o forward(), pois o LSTM utiliza dois estados internos: o hidden stateeo cell state. Assim, além de inicializar hidden, também é preciso inicializar cellcom a mesma dimensionalidade e passar ambos para a camada LSTM. A saída final continua sendo obtida a partir do último passo temporal, que é enviado para a camada linear de classificação.

implementação:
import torch
import torch.nn as nn

Exemplo de dicionário de configuração

args = {
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

class RNN(nn.Module):
def init(self, tam_entrada, tam_feature, tam_saida):
super(RNN, self).init()

    # Salva as dimensões principais
    self.tam_entrada = tam_entrada   # tamanho de cada elemento da sequência
    self.tam_feature = tam_feature   # tamanho do estado oculto
    self.tam_saida   = tam_saida     # número de classes

    # Substituição da GRU pela LSTM
    # batch_first=True => entrada no formato (batch, seq_len, features)
    self.rnn = nn.LSTM(
        input_size=self.tam_entrada,
        hidden_size=self.tam_feature,
        batch_first=True
    )

    # Camada linear para classificar a saída da LSTM
    self.linear = nn.Linear(self.tam_feature, self.tam_saida)

    # LogSoftmax para saída em log-probabilidades
    self.softmax = nn.LogSoftmax(dim=-1)

def forward(self, nome):
    """
    nome: tensor de uma sequência
    Exemplo de shape antes do unsqueeze:
        (seq_len, tam_entrada)

    Como batch_first=True, a LSTM espera:
        (batch, seq_len, tam_entrada)

    Por isso usamos unsqueeze(0) para criar batch = 1.
    """

    # Como estamos processando um nome por vez, batch_size = 1
    batch_size = 1

    # Inicializa o hidden state com zeros
    # Shape: (num_layers, batch_size, tam_feature)
    hidden = torch.zeros(1, batch_size, self.tam_feature).to(args['device'])

    # Inicializa o cell state com zeros
    # Mesmo shape do hidden
    cell = torch.zeros(1, batch_size, self.tam_feature).to(args['device'])

    # Adiciona dimensão de batch: (seq_len, tam_entrada) -> (1, seq_len, tam_entrada)
    nome = nome.unsqueeze(0).to(args['device'])

    # Aplica a LSTM
    # saida: saída para cada passo temporal
    # hidden: hidden state final
    # cell: cell state final
    saida, (hidden, cell) = self.rnn(nome, (hidden, cell))

    # Pegamos a saída do último passo da sequência
    # Como batch_first=True, saida tem shape:
    # (batch, seq_len, tam_feature)
    # então saida[:, -1, :] seleciona o último passo temporal
    saida = self.linear(saida[:, -1, :])

    # Aplica log softmax para classificação
    saida = self.softmax(saida)

    return saida