Coverage for pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py: 100%
43 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1import numpy as np
2import torch
3import torch.nn as nn
4import torch.optim as optim
5from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC
6from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor
def train_at_collocation_nodes():
    """
    For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC.
    If successful, the initial guess would already be the exact solution and we would need no SDC iterations.

    What we find is that we can train the network to predict the solution to one very specific problem rather well.
    See the error during training for what happens when we ask the network to solve for exactly what it just trained.
    However, if we train for something else, i.e. solving to a different step size in this case, we can only use the
    model to predict the solution of what it's been trained for last and it loses the ability to solve for previously
    learned things. This is solely because we chose an overly simple model that is unsuitable to the task at hand and
    is likely easily solved with a bit of patience. This is just a demonstration of the interface between pySDC and
    PyTorch. If you want to do a project with this, feel free to take this as a starting point and do things that
    actually do something!

    The output shows the training loss during training and, after each of three training sessions is complete, the error
    of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print the
    error of all learned predictions after training is complete.

    Side effects:
        Prints the training log, writes it to ``data/step_7_D_out.txt`` and saves the trained weights
        (``model.state_dict()``) to ``data/heat_equation_model.pth``. The ``data`` directory is created
        if it does not exist yet.

    Raises:
        AssertionError: If the prediction errors during or after training exceed the expected thresholds.
    """
    import os

    log_lines = []
    errors_mid_training = []
    errors_post_training = []

    # instantiate the pySDC problem and a model for PyTorch
    problem = Heat1DFDTensor()
    model = HeatEquationModel(problem)

    # setup neural network
    lr = 0.001
    num_epochs = 250
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # setup initial conditions
    t = 0
    initial_condition = problem.u_exact(t)

    # train the model to predict the solution at certain collocation nodes
    # ((4 - sqrt(6)) / 10, (4 + sqrt(6)) / 10 and 1 are the right-Radau nodes for three collocation points on [0, 1],
    # scaled here to a step of size 1e-2)
    collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2
    for dt in collocation_nodes:
        # get target condition from implicit Euler step
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)

        # do the training
        for epoch in range(num_epochs):
            predicted_state = model(initial_condition, t, dt)
            loss = criterion(predicted_state.float(), target_condition.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 50 == 0:
                log_lines.append(
                    f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n'
                )

        # evaluate model to compute error for the step size it was just trained on
        error = _prediction_error(model, initial_condition, t, dt, target_condition)
        errors_mid_training.append(error)
        log_lines.append(f'Error of prediction at {dt:.2e} during training: {error:.2e}\n')

    # compare model and problem for all step sizes after training on the last one (shows the forgetfulness)
    for dt in collocation_nodes:
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
        error = _prediction_error(model, initial_condition, t, dt, target_condition)
        errors_post_training.append(error)
        log_lines.append(f'Error of prediction at {dt:.2e} after training: {error:.2e}\n')

    out = ''.join(log_lines)
    print(out)

    # the relative output paths below require the data directory to exist
    os.makedirs('data', exist_ok=True)
    with open('data/step_7_D_out.txt', 'w', encoding='utf-8') as file:
        file.write(out)

    # test that the training went as expected
    assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected'
    assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected'

    # save the model to use it throughout pySDC
    torch.save(model.state_dict(), 'data/heat_equation_model.pth')


def _prediction_error(model, initial_condition, t, dt, target_condition):
    """
    Evaluate the model for one step and return ``abs(target_condition - prediction)``.

    The forward pass runs under ``torch.no_grad()`` because the result is only used for reporting,
    so no autograd graph needs to be recorded.

    Args:
        model: The PyTorch model, called as ``model(initial_condition, t, dt)``.
        initial_condition: Initial state passed to the model.
        t: Time of the initial condition.
        dt: Step size to predict.
        target_condition: Reference solution to compare against.

    Returns:
        Whatever ``abs`` returns for the difference of the target and the prediction.
    """
    with torch.no_grad():
        model_prediction = model(initial_condition, t, dt)
        return abs(target_condition - model_prediction)
if __name__ == "__main__":
    # Executed as a script: train the model, check its errors and store the log and weights.
    train_at_collocation_nodes()