Coverage for pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py: 100%

43 statements  

coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC
from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor


def train_at_collocation_nodes():
    """
    For a first proof of concept, we train the model specifically for the collocation nodes we use in SDC.
    If successful, the initial guess would already be the exact solution and no SDC iterations would be needed.

    We find that the network can be trained to predict the solution to one very specific problem rather well.
    The error printed during training shows what happens when we ask the network to solve exactly what it was just
    trained on. However, if we then train for something else, in this case solving for a different step size, the
    model can only predict the solution of whatever it was trained on last: it loses the ability to solve previously
    learned problems. This is solely because we chose an overly simple model that is unsuitable for the task at hand,
    and it can likely be fixed with a bit of patience. This is just a demonstration of the interface between pySDC and
    PyTorch. If you want to do a project with this, feel free to take it as a starting point and build something that
    actually does something useful!

    The output shows the training loss during training and, after each of the three training sessions is complete, the
    error of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print
    the error of all learned predictions after training is complete.
    """
    out = ''
    errors_mid_training = []
    errors_post_training = []

    # instantiate the pySDC problem and a model for PyTorch
    problem = Heat1DFDTensor()
    model = HeatEquationModel(problem)

    # set up the loss function and optimizer for training the neural network
    lr = 0.001
    num_epochs = 250
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # set up initial conditions
    t = 0
    initial_condition = problem.u_exact(t)

    # train the model to predict the solution at certain collocation nodes
    # (three right-Radau nodes on [0, 1], scaled to a step of size 1e-2)
    collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2
    for dt in collocation_nodes:

        # get target condition from an implicit Euler step of size dt
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)

        # do the training
        for epoch in range(num_epochs):
            predicted_state = model(initial_condition, t, dt)
            loss = criterion(predicted_state.float(), target_condition.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 50 == 0:
                out += f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n'

        # evaluate model to compute error
        model_prediction = model(initial_condition, t, dt)
        errors_mid_training += [abs(target_condition - model_prediction)]
        out += f'Error of prediction at {dt:.2e} during training: {abs(target_condition-model_prediction):.2e}\n'

    # compare model and problem after all training sessions are complete
    for dt in collocation_nodes:
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
        model_prediction = model(initial_condition, t, dt)
        errors_post_training += [abs(target_condition - model_prediction)]
        out += f'Error of prediction at {dt:.2e} after training: {abs(target_condition-model_prediction):.2e}\n'

    print(out)
    with open('data/step_7_D_out.txt', 'w') as file:
        file.write(out)

    # test that the training went as expected
    assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected'
    assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected'

    # save the model to use it throughout pySDC
    torch.save(model.state_dict(), 'data/heat_equation_model.pth')


if __name__ == '__main__':
    train_at_collocation_nodes()
```
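The three hard-coded values in `collocation_nodes` appear to be the right-Radau (Radau IIA) nodes for three collocation points on [0, 1], i.e. (4 - sqrt(6))/10, (4 + sqrt(6))/10 and 1, scaled to a step of size 1e-2. Each training target is then one implicit Euler step of that size from the initial condition, obtained by solving (I - dt*A)u = u0. The sketch below reproduces both ingredients with NumPy only and does not call pySDC. The helpers `radau_right_nodes_3`, `laplacian_1d` and `implicit_euler_step` are hypothetical, and the grid size `n`, the diffusion coefficient `nu`, the homogeneous Dirichlet boundaries and the sine initial condition are assumptions that need not match what `Heat1DFDTensor` actually uses.

```python
import numpy as np


def radau_right_nodes_3():
    """Right-Radau nodes on [0, 1] for three collocation points.

    They are the roots of 10x^3 - 18x^2 + 9x - 1 = (x - 1)(10x^2 - 8x + 1),
    i.e. half the second derivative of x^2 (x - 1)^3.
    """
    roots = np.roots([10.0, -18.0, 9.0, -1.0]).real
    return np.sort(roots)


def laplacian_1d(n, nu):
    """Second-order finite-difference Laplacian on n interior points of (0, 1).

    Homogeneous Dirichlet boundaries are an assumption of this sketch.
    """
    dx = 1.0 / (n + 1)
    main = -2.0 * np.ones(n)
    off = np.ones(n - 1)
    return nu * (np.diag(main) + np.diag(off, 1) + np.diag(off, -1)) / dx**2


def implicit_euler_step(A, u0, dt):
    """Solve (I - dt * A) u = u0, i.e. one implicit Euler step of size dt."""
    return np.linalg.solve(np.eye(A.shape[0]) - dt * A, u0)


n, nu = 127, 1.0  # hypothetical grid size and diffusion coefficient
x = np.linspace(0, 1, n + 2)[1:-1]  # interior grid points
u0 = np.sin(np.pi * x)  # hypothetical initial condition
A = laplacian_1d(n, nu)

nodes = radau_right_nodes_3()
# one implicit Euler target per scaled collocation node, as in the tutorial loop
targets = {dt: implicit_euler_step(A, u0, dt) for dt in nodes * 1e-2}

for dt, u in targets.items():
    print(f'dt={dt:.2e}: max(u) = {u.max():.6f}')
```

Because sin(πx) is an eigenvector of this discrete Laplacian with eigenvalue λ = ν(2cos(πΔx) - 2)/Δx², each target is simply u0 / (1 - Δt·λ), which gives a quick sanity check for the assumed setup.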
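The forgetfulness described in the docstring of `train_at_collocation_nodes` can be reproduced without PyTorch. In the hypothetical toy below, the "model" is a single learnable vector `b` whose prediction ignores `dt` entirely, so it is deliberately too simple for the task. It is not `HeatEquationModel`; it only illustrates how an under-parametrized model trained on one target after another ends up fitting only the last one. The targets, learning rate and epoch count are assumptions, and plain gradient descent on the mean squared error stands in for Adam.

```python
import numpy as np

# Hypothetical training targets, one per step size dt (any distinct vectors work).
dts = [node * 1e-2 for node in (0.15505102572168285, 0.6449489742783183, 1.0)]
targets = {dt: np.full(8, 1.0 - 10.0 * dt) for dt in dts}


def predict(b, dt):
    """Deliberately too simple 'model': the prediction ignores dt."""
    return b


def mse(a, b):
    """Mean squared error between two vectors."""
    return float(np.mean((a - b) ** 2))


b = np.zeros(8)  # the only trainable parameter
lr, num_epochs = 0.1, 250
errors_mid_training, errors_post_training = [], []

# train on one target after another, like the tutorial loop over collocation nodes
for dt in dts:
    for _ in range(num_epochs):
        residual = predict(b, dt) - targets[dt]
        grad = 2.0 * residual / residual.size  # gradient of the MSE with respect to b
        b -= lr * grad
    errors_mid_training.append(mse(predict(b, dt), targets[dt]))

# re-evaluate every target with the final parameters
for dt in dts:
    errors_post_training.append(mse(predict(b, dt), targets[dt]))

print('during training:', errors_mid_training)
print('after training: ', errors_post_training)
```

Each target is fitted well right after its own training session, but afterwards only the last one still is, because the single parameter vector has been pulled towards the most recent target. A model that actually uses `dt` as an input, or training on all step sizes jointly, would be the natural places to start fixing this.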