Dikkat mekanizmalı RNN modelini eğitme ve test etme
PyBooks ekibi, daha önce dikkat mekanizması olmadan kelime tahmini için bir RNN modeli geliştirmişti. Bu ilk model rnn_model olarak anılıyor, eğitildi ve örneği önceden yüklendi. Görevin, yeni RNNWithAttentionModel modelini eğitmek ve tahminlerini önceki rnn_model ile karşılaştırmak.
Senin için aşağıdakiler önceden yüklendi:
inputs: tensörlerden oluşan girdi dizi listesitargets: her girdi dizisi için hedef kelimeleri içeren tensöroptimizer: Adam optimizasyon fonksiyonucriterion: CrossEntropyLoss fonksiyonupad_sequences: toplu işlem için girdi dizilerini doldurma fonksiyonuattention_model: önceki egzersizde tanımlanan model sınıfırnn_model: PyBooks ekibinin eğittiği RNN modeli
Bu egzersiz
PyTorch ile Metin için Deep Learning
kursunun bir parçasıdırEgzersiz talimatları
- Test verisiyle denemeden önce RNN modelini evaluation moduna al.
- Uygun girdiyi RNN modeline vererek RNN çıktısını al.
- RNN çıktısından en yüksek tahmin skoruna sahip kelimeyi çıkar.
- Benzer şekilde attention modeli için de attention çıktısından en yüksek tahmin skoruna sahip kelimeyi çıkar.
Uygulamalı interaktif egzersiz
Bu örnek kodu tamamlayarak bu egzersizi bitirin.
for epoch in range(epochs):
attention_model.train()
optimizer.zero_grad()
padded_inputs = pad_sequences(inputs)
outputs = attention_model(padded_inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
for input_seq, target in zip(input_data, target_data):
input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
# Set the RNN model to evaluation mode
rnn_model.____()
# Get the RNN output by passing the appropriate input
rnn_output = ____(____)
# Extract the word with the highest prediction score
rnn_prediction = ix_to_word[torch.____(____).item()]
attention_model.eval()
attention_output = attention_model(input_test)
# Extract the word with the highest prediction score
attention_prediction = ix_to_word[torch.____(____).item()]
print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
print(f"Target: {ix_to_word[target]}")
print(f"RNN prediction: {rnn_prediction}")
print(f"RNN with Attention prediction: {attention_prediction}")