Em Redes Neurais MLP(Multilayer Perceptron) o Backpropagation calcula os erros de cada unidade na camada oculta atual L
, somando os erros de todas as unidades da camada seguinte L+1
multiplicando pelos pesos que fazem a conexão entre a unidade atual da camada oculta com as unidades da camada seguinte L+1
. Mais isso numa MLP. E nas MLPs não existe instantes de tempo, e não existem estados ocultos.
Porém eu sei que o Backpropagation Through Time (BPTT) de RNN é diferente, e que é bem mais complexo que o Backpropagation tradicional de uma MLP.
Por isso, eu fico um pouco confuso quando tempo usar o que eu sei sobre o algoritmo do Backpropagation tradicional de uma MLP para tentar entender como o Backpropagation Through Time (BPTT) de RNN funciona. Não consigo encontrar muitos pontos em comum pra ajudar eu a entender o BPTT.
Por exemplo, fazendo perguntas para o ChatGPT, eu tive a leve impressão de que, o Backpropagation Through Time (BPTT) usados nas RNNs poderia ser semelhante ao Backpropagation tradicional de uma MLP que usa treinamento FullBatch(pois ele faz uma acumulação de gradientes) e o Backpropagation Through Time (BPTT) parece fazer uma acumulação de gradientes também, só que para cada instante de tempo. Porém eu não consigo ter certeza.
Também, pelo que entendi, o Backpropagation Through Time (BPTT) só atualiza os pesos quando ele termina de processar o ultimo instante de tempo(ou seja, a ultima amostra) da sequencia.
Eu reescrevi com minhas palavras algumas coisas que o Chat GPT me falou abaixo:
Para atualizar os pesos das entradas, nós fazemos exatamente igual numa MLP tradicional, usando o gradiente do erro em relação a cada peso.
Porém, para atualizar os pesos de um estado oculto t
nós precisamos levar em conta o gradiente acumulado de todos os instantes de tempo futuros: t+1
, t+2
etc.... Isso por que, o erro total da rede é calculado no ultimo instante de tempo(ou seja, na ultima amostra), e então, ele é retropropagado para traz. Ou seja, esse processo começou no ultimo instante de tempo T e vai até o instante de tempo 1(o primeiro instante de tempo). E isso tudo significa que, para cada instante de tempo t
, o gradiente do erro em relação ao estado oculto Ht
depende não só do erro atual, mas também do erro propagado dos instantes futuros( t+1
, t+2
, etc...., T
). O BPTT acumula gradientes ao longo do tempo, mas essa acumulação não é uma simples soma dos gradientes de cada instante. Em vez disso, os gradientes são propagados de forma recursiva, levando em conta a dependência temporal entre os estados ocultos. Isso é o que diferencia o BPTT de um backpropagation tradicional em uma MLP com fullbatch. Esse comportamento de considerar todos os instantes futuros ao calcular o gradiente se aplica tanto aos pesos associados ao estado oculto quanto aos pesos associados às entradas da rede. Então, o gradiente de W(dos pesos) acumula as contribuições ao longo de todos os instantes de tempo. Isso ocorre por que, os instantes seguintes(t+1
, t+2
, etc...) vão depender mesmo que indiretamente do instante atual t
, por causa da recorrência dos estados ocultos explicada aqui.
ENTÂO: Quando chegamos ao último instante de tempo(na ultima amostra do dataset), o BPTT começa a ser aplicado do último instante de tempo até o primeiro instante de tempo, propagando os gradientes de volta no tempo. Durante essa propagação:
Gradientes dos pesos das entradas(pesos W): São acumulados ao longo de todos os instantes de tempo.
Gradientes dos pesos dos estados ocultos(pesos U): São calculados a cada passo de retropropagação e acumulados também ao longo do tempo.
E TAMBÈM, Após acumular os gradientes para todos os instantes de tempo, ocorre uma única atualização dos pesos:
- Os pesos das entradas(pesos W) e os pesos dos estados ocultos(pesos U) são atualizados simultaneamente, utilizando os gradientes acumulados ao longo de todos os instantes.
Porem eu ainda não sei de como o Backpropagation Through Time (BPTT) realmente funciona.
Algumas coisas que eu li nas respostas do modelo soaram um pouco confusas, e eu não tenho certeza se podem haver erros. Então não confio totalmente no que o modelo diz.
DUVIDAS:
Como funciona o Backpropagation Through Time (BPTT) ?
Qual o passo a passo que ele segue ?
Será que as coisas que o ChatGPT disse faz sentido?