1import numpy as np

2import torch

3import torch.nn as nn

4import torch.optim as optim

5from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC

6from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor

7

8

def train_at_collocation_nodes():
    """
    Train a small PyTorch model to reproduce single steps of a pySDC heat equation problem.

    For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC.
    If successful, the initial guess would already be the exact solution and we would need no SDC iterations.

    What we find is that we can train the network to predict the solution to one very specific problem rather well.
    See the error during training for what happens when we ask the network to solve for exactly what it just trained.
    However, if we train for something else, i.e. solving to a different step size in this case, we can only use the
    model to predict the solution of what it has been trained for last and it loses the ability to solve for
    previously learned things. This is solely because we chose an overly simple model that is unsuitable to the task
    at hand and is likely easily solved with a bit of patience. This is just a demonstration of the interface between
    pySDC and PyTorch. If you want to do a project with this, feel free to take this as a starting point and do
    things that actually do something!

    The output shows the training loss during training and, after each of three training sessions is complete, the
    error of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print
    the error of all learned predictions after training is complete.

    Side effects:
        - prints the collected log to stdout
        - writes the same log to ``data/step_7_D_out.txt`` (the ``data`` directory is created if it is missing)
        - saves the trained model's ``state_dict`` to ``data/heat_equation_model.pth``

    Raises:
        AssertionError: if the errors during or after training exceed the expected thresholds. In that case the
            model is not saved.
    """
    import os

    log_lines = []
    errors_mid_training = []
    errors_post_training = []

    # instantiate the pySDC problem and a model for PyTorch
    problem = Heat1DFDTensor()
    model = HeatEquationModel(problem)

    # setup neural network
    lr = 0.001
    num_epochs = 250
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # setup initial conditions
    t = 0
    initial_condition = problem.u_exact(t)

    # train the model to predict the solution at certain collocation nodes (scaled by 1e-2 to get the step sizes)
    collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2
    for dt in collocation_nodes:

        # get target condition from implicit Euler step
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)

        # do the training
        for epoch in range(num_epochs):
            predicted_state = model(initial_condition, t, dt)
            loss = criterion(predicted_state.float(), target_condition.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # only log every 50th epoch to keep the output short
            if (epoch + 1) % 50 == 0:
                log_lines.append(
                    f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n'
                )

        # evaluate model to compute error; no gradients are needed for a pure evaluation
        with torch.no_grad():
            model_prediction = model(initial_condition, t, dt)
            error = abs(target_condition - model_prediction)
        errors_mid_training.append(error)
        log_lines.append(f'Error of prediction at {dt:.2e} during training: {error:.2e}\n')

    # compare model and problem: revisit every step size with the final state of the network
    for dt in collocation_nodes:
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
        with torch.no_grad():
            model_prediction = model(initial_condition, t, dt)
            error = abs(target_condition - model_prediction)
        errors_post_training.append(error)
        log_lines.append(f'Error of prediction at {dt:.2e} after training: {error:.2e}\n')

    out = ''.join(log_lines)
    print(out)

    # make sure the output directory exists so writing does not fail on a fresh checkout
    os.makedirs('data', exist_ok=True)
    with open('data/step_7_D_out.txt', 'w') as file:
        file.write(out)

    # test that the training went as expected
    assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected'
    assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected'

    # save the model to use it throughout pySDC
    torch.save(model.state_dict(), 'data/heat_equation_model.pth')

86

87

# Script entry point: run the training demonstration only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    train_at_collocation_nodes()