1. Learn
  2. /
  3. Cursuri
  4. /
  5. Introducere în Deep Learning cu PyTorch

Connected

exercițiu

Scrierea unei bucle de antrenament

În scikit-learn, bucla de antrenament este inclusă în metoda .fit(), în timp ce în PyTorch aceasta se configurează manual. Deși acest lucru oferă mai multă flexibilitate, necesită o implementare personalizată.

În acest exercițiu, vei crea o buclă pentru a antrena un model de predicție a salariilor.

Funcția show_results() este furnizată pentru a te ajuta să vizualizezi câteva predicții exemplu.

Importurile de pachete disponibile sunt: pandas ca pd, torch, torch.nn ca nn, torch.optim ca optim, precum și DataLoader și TensorDataset din torch.utils.data.

Au fost create următoarele variabile: num_epochs, care conține numărul de epoci (setat la 5); dataloader, care conține dataloader-ul; model, care conține rețeaua neuronală; criterion, care conține funcția de pierdere, nn.MSELoss(); optimizer, care conține optimizatorul SGD.

Instrucțiuni 1/3

undefined XP
    1
    2
    3
  • Scrie o buclă for care iterează peste dataloader; aceasta trebuie să fie imbricată într-o buclă for care iterează peste un interval egal cu numărul de epoci.
  • Setează gradienții optimizatorului la zero.