Coverage for pySDC/tutorial/step_6/playground_parallelization.py: 0%
6 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
1import sys
2from pathlib import Path
4from mpi4py import MPI
6from pySDC.helpers.stats_helper import get_sorted
7from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
8from pySDC.tutorial.step_6.A_run_non_MPI_controller import set_parameters_ml
if __name__ == "__main__":
    # A simple test program to do MPI-parallel PFASST runs.
    #
    # All ranks of MPI.COMM_WORLD take part in the controller run; rank 0 then
    # appends a summary (process count, error vs. exact solution, iteration
    # counts per time) to data/<fname> and echoes it to stdout.
    # An optional single command-line argument overrides the output file name
    # (default: 'step_6_B_out.txt').

    # set MPI communicator
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # get parameters from Part A
    description, controller_params, t0, Tend = set_parameters_ml()

    # instantiate controller on the shared communicator
    controller = controller_MPI(controller_params=controller_params, description=description, comm=comm)

    # get initial values on finest level
    P = controller.S.levels[0].prob
    uinit = P.u_exact(t0)

    # call main functions to get things done...
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # filter statistics by type (number of iterations)
    iter_counts = get_sorted(stats, type='niter', sortby='time')

    # collect every rank's list on rank 0; comm.gather returns None on the other ranks
    iter_counts_list = comm.gather(iter_counts, root=0)

    if rank == 0:
        # we'd need to deal with variable file names here (for testing purpose only)
        fname = sys.argv[1] if len(sys.argv) == 2 else 'step_6_B_out.txt'

        out_dir = Path("data")
        out_dir.mkdir(parents=True, exist_ok=True)

        # context manager guarantees the file is flushed and closed, even if
        # one of the computations below raises
        with open(out_dir / fname, 'a') as f:
            out = 'Working with %2i processes...' % size
            f.write(out + '\n')
            print(out)

            # compute exact solution at Tend and compare with the computed result
            uex = P.u_exact(Tend)
            err = abs(uex - uend)

            out = 'Error vs. exact solution: %12.8e' % err
            f.write(out + '\n')
            print(out)

            # flatten the list of per-rank lists into one list, then sort by time
            iter_counts = sorted(
                (item for sublist in iter_counts_list for item in sublist),
                key=lambda tup: tup[0],
            )

            # compute and print statistics
            for item in iter_counts:
                out = 'Number of iterations for time %4.2f: %1i ' % (item[0], item[1])
                f.write(out + '\n')
                print(out)

            f.write('\n')
            print()

        # sanity check on the gathered iteration counts (run after the output file is closed)
        assert all(item[1] <= 8 for item in iter_counts), "ERROR: weird iteration counts, got %s" % iter_counts