Ao substituir a GRU por um LSTM, é necessário alterar a definição da camada recorrente e também o forward(), pois o LSTM utiliza dois estados internos: o hidden stateeo cell state. Assim, além de inicializar hidden, também é preciso inicializar cellcom a mesma dimensionalidade e passar ambos para a camada LSTM. A saída final continua sendo obtida a partir do último passo temporal, que é enviado para a camada linear de classificação.
implementação:
import torch
import torch.nn as nn
Exemplo de dicionário de configuração
args = {
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}
class RNN(nn.Module):
def init(self, tam_entrada, tam_feature, tam_saida):
super(RNN, self).init()
# Salva as dimensões principais
self.tam_entrada = tam_entrada # tamanho de cada elemento da sequência
self.tam_feature = tam_feature # tamanho do estado oculto
self.tam_saida = tam_saida # número de classes
# Substituição da GRU pela LSTM
# batch_first=True => entrada no formato (batch, seq_len, features)
self.rnn = nn.LSTM(
input_size=self.tam_entrada,
hidden_size=self.tam_feature,
batch_first=True
)
# Camada linear para classificar a saída da LSTM
self.linear = nn.Linear(self.tam_feature, self.tam_saida)
# LogSoftmax para saída em log-probabilidades
self.softmax = nn.LogSoftmax(dim=-1)
def forward(self, nome):
"""
nome: tensor de uma sequência
Exemplo de shape antes do unsqueeze:
(seq_len, tam_entrada)
Como batch_first=True, a LSTM espera:
(batch, seq_len, tam_entrada)
Por isso usamos unsqueeze(0) para criar batch = 1.
"""
# Como estamos processando um nome por vez, batch_size = 1
batch_size = 1
# Inicializa o hidden state com zeros
# Shape: (num_layers, batch_size, tam_feature)
hidden = torch.zeros(1, batch_size, self.tam_feature).to(args['device'])
# Inicializa o cell state com zeros
# Mesmo shape do hidden
cell = torch.zeros(1, batch_size, self.tam_feature).to(args['device'])
# Adiciona dimensão de batch: (seq_len, tam_entrada) -> (1, seq_len, tam_entrada)
nome = nome.unsqueeze(0).to(args['device'])
# Aplica a LSTM
# saida: saída para cada passo temporal
# hidden: hidden state final
# cell: cell state final
saida, (hidden, cell) = self.rnn(nome, (hidden, cell))
# Pegamos a saída do último passo da sequência
# Como batch_first=True, saida tem shape:
# (batch, seq_len, tam_feature)
# então saida[:, -1, :] seleciona o último passo temporal
saida = self.linear(saida[:, -1, :])
# Aplica log softmax para classificação
saida = self.softmax(saida)
return saida