import torch
import torch.nn as nn
import time
import json
==========================================
1. Preparação de Dados (Auto-regressivo)
==========================================
def get_batch(data, block_size, batch_size, device):
# Gera índices aleatórios para o início dos blocos
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([data[i:i+block_size] for i in ix])
# Y é o X deslocado uma posição para a direita
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
return x.to(device), y.to(device)
==========================================
2. Loop de Treinamento Otimizado
==========================================
def train_model(model, train_data, val_data, config):
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
model.to(config['device'])
for iter in range(config['max_iters']):
model.train()
xb, yb = get_batch(train_data, config['block_size'], config['batch_size'], config['device'])
# Mixed Precision para velocidade
with torch.cuda.amp.autocast():
logits = model(xb)
B, T, C = logits.shape
loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
optimizer.zero_grad(set_to_none=True)
loss.backward()
# Gradient Clipping: Impede a explosão do gradiente
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
if iter % 100 == 0:
ppl = torch.exp(loss) # Cálculo da Perplexidade
print(f"Iter {iter}: Loss {loss.item():.4f} | PPL {ppl.item():.2f}")
==========================================
3. Interface Interativa (CLI para Deploy)
==========================================
@torch.no_grad()
def interactive_chat(model, tokenizer, max_new_tokens=50):
model.eval()
print("\n--- Model GPT-Goiás Online (Digite 'sair' para encerrar) ---")
while True:
prompt = input("Você: ")
if prompt.lower() == 'sair': break
# Converte texto para IDs
context = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=config['device']).unsqueeze(0)
# Geração de resposta (sem gradiente para poupar memória)
generated_ids = model.generate(context, max_new_tokens=max_new_tokens, temperature=0.7)
resposta = tokenizer.decode(generated_ids[0].tolist())
print(f"IA: {resposta}\n")
==========================================
4. Salvamento de Artefatos (State Dict + Config)
==========================================
def save_deployment_artifacts(model, config, path="model_deploy.pt"):
checkpoint = {
'model_state': model.state_dict(),
'config': config,
'vocab_size': config['vocab_size']
}
torch.save(checkpoint, path)
with open("config.json", "w") as f:
json.dump(config, f)