Coverage for pySDC/projects/Performance/controller_MPI_scorep.py: 0%
283 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import numpy as np
2from mpi4py import MPI
4from pySDC.core.controller import Controller
5from pySDC.core.errors import ControllerError
6from pySDC.core.step import Step
7from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9import scorep.user as spu
class controller_MPI(Controller):
    """
    PFASST controller, running parallel version of PFASST in blocks (MG-style)

    This variant is instrumented for Score-P: inside ``run``, every call of a PFASST stage is wrapped in
    a user region named after the current stage and the slot (rank) of this step.
    """

    def __init__(self, controller_params, description, comm):
        """
        Initialization routine for PFASST controller

        Args:
            controller_params: parameter set for the controller and the step class
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
            comm: MPI communicator (owned by the caller, this controller never frees it)

        Raises:
            ControllerError: if more than one process and more than one level are used, but a level's
                collocation does not satisfy uend^k = u_M^k
        """

        # call parent's initialization routine
        super().__init__(controller_params)

        # create single step per processor
        self.S = Step(description)

        # pass communicator for future use
        self.comm = comm
        # add request handler for status send
        self.req_status = None

        num_procs = self.comm.Get_size()
        rank = self.comm.Get_rank()

        # insert data on time communicator to the steps (helpful here and there)
        self.S.status.time_size = num_procs

        if self.params.dump_setup and rank == 0:
            self.dump_setup(step=self.S, controller_params=controller_params, description=description)

        num_levels = len(self.S.levels)

        # add request handle container for isend
        self.req_send = [None] * num_levels

        if num_procs > 1 and num_levels > 1:
            for L in self.S.levels:
                if not L.sweep.coll.right_is_node or L.sweep.params.do_coll_update:
                    raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if num_levels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level.. ' 'predictor will be ignored'
            )

    def run(self, u0, t0, Tend):
        """
        Main driver for running the parallel version of SDC, MSSDC, MLSDC and PFASST

        Each active process handles one step of a block. After a block is done, the last active process
        broadcasts its end value as the new initial value. Processes whose next step would not start
        before Tend are split off into a separate communicator.

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration

        Raises:
            ControllerError: if no process has a step starting before Tend
        """

        # reset stats to prevent double entries from old runs
        self.hooks.reset_stats()

        # find active processes and put into new communicator
        rank = self.comm.Get_rank()
        num_procs = self.comm.Get_size()
        all_dt = self.comm.allgather(self.S.dt)
        all_time, all_active = self._block_times(t0, all_dt, Tend)
        time = all_time[rank]

        if not any(all_active):
            raise ControllerError('Nothing to do, check t0, dt and Tend')

        active = bool(all_active[rank])
        if not all(all_active):
            comm_active = self.comm.Split(int(active))
            rank = comm_active.Get_rank()
            num_procs = comm_active.Get_size()
        else:
            # no split needed: work on the caller's communicator directly (must never be freed here)
            comm_active = self.comm

        self.S.status.slot = rank

        # initialize block of steps with u0
        self.restart_block(num_procs, time, u0)
        uend = u0

        # call post-setup hook
        self.hooks.post_setup(step=None, level_number=None)

        # call pre-run hook
        self.hooks.pre_run(step=self.S, level_number=0)

        comm_active.Barrier()

        # while any process still active...
        while active:
            while not self.S.status.done:
                name = f'REGION -- {self.S.status.stage} -- {self.S.status.slot}'
                spu.region_begin(name)
                try:
                    self.pfasst(comm_active, num_procs)
                finally:
                    # always close the Score-P region, even if a stage raises
                    spu.region_end(name)

            time += self.S.dt

            # broadcast uend, set new times and find active processes
            tend = comm_active.bcast(time, root=num_procs - 1)
            uend = self.S.levels[0].uend.bcast(root=num_procs - 1, comm=comm_active)
            all_dt = comm_active.allgather(self.S.dt)
            all_time, all_active = self._block_times(tend, all_dt, Tend)
            time = all_time[rank]
            active = bool(all_active[rank])
            if not all(all_active):
                comm_active_new = comm_active.Split(int(active))
                # only free communicators created by Split above, never the one handed in by the caller
                if comm_active is not self.comm:
                    comm_active.Free()
                comm_active = comm_active_new
                rank = comm_active.Get_rank()
                num_procs = comm_active.Get_size()
                self.S.status.slot = rank

            # initialize block of steps with u0
            self.restart_block(num_procs, time, uend)

        # call post-run hook
        self.hooks.post_run(step=self.S, level_number=0)

        if comm_active is not self.comm:
            comm_active.Free()

        return uend, self.hooks.return_stats()

    @staticmethod
    def _block_times(t_start, all_dt, Tend):
        """
        Compute the start time of each step in a block and whether it still starts before Tend

        Args:
            t_start: start time of the first step of the block
            all_dt: list of step sizes, one per process of the block (in rank order)
            Tend: ending time

        Returns:
            list of start times (one per process)
            numpy boolean array, True where the step starts before Tend (up to 10 machine epsilons)
        """
        all_time = [t_start + sum(all_dt[0:i]) for i in range(len(all_dt))]
        all_active = np.array(all_time) < Tend - 10 * np.finfo(float).eps
        return all_time, all_active

    def restart_block(self, size, time, u0):
        """
        Helper routine to reset/restart block of (active) steps

        Resets the step of this process, marks it as first/last in the block, re-initializes it with u0,
        clears pending request handles and sets the time of all levels.

        Args:
            size: number of active time steps
            time: current time
            u0: initial value to distribute across the steps
        """

        # store link to previous step
        self.S.prev = self.S.status.slot - 1
        self.S.next = self.S.status.slot + 1

        # resets step
        self.S.reset_step()
        # determine whether I am the first and/or last in line
        self.S.status.first = self.S.prev == -1
        self.S.status.last = self.S.next == size
        # initialize step with u0
        self.S.init_step(u0)
        # reset some values
        self.S.status.done = False
        self.S.status.iter = 0
        self.S.status.stage = 'SPREAD'
        for l in self.S.levels:
            l.tag = None
        self.req_status = None
        self.req_send = [None] * len(self.S.levels)
        self.S.status.prev_done = False

        self.S.status.time_size = size

        for lvl in self.S.levels:
            lvl.status.time = time
            lvl.status.sweep = 1

    @staticmethod
    def recv(target, source, tag=None, comm=None):
        """
        Receive function

        Receives the initial value u[0] of the target level and re-evaluates f there.

        Args:
            target: level which will receive the values
            source: level which initiated the send
            tag: identifier to check if this message is really for me
            comm: communicator
        """
        target.u[0].recv(source=source, tag=tag, comm=comm)
        # re-evaluate f on left interval boundary
        target.f[0] = target.prob.eval_f(target.u[0], target.time)

    def predictor(self, comm):
        """
        Predictor function, extracted from the stepwise implementation (will be also used by matrix sweepers)

        Args:
            comm: communicator

        Raises:
            NotImplementedError: for predict_type 'fmg'
            ControllerError: for an unknown predict_type
        """

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'libpfasst_style':
            # restrict to coarsest level
            for l in range(1, len(self.S.levels)):
                self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])

            self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
            if not self.S.status.first:
                self.logger.debug(
                    'recv data predict: process %s, stage %s, time, %s, source %s, tag %s'
                    % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev, self.S.status.iter)
                )
                self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
            self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)

            # do the sweep with new values
            self.S.levels[-1].sweep.update_nodes()
            self.S.levels[-1].sweep.compute_end_point()

            self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
            if not self.S.status.last:
                self.logger.debug(
                    'send data predict: process %s, stage %s, time, %s, target %s, tag %s'
                    % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next, self.S.status.iter)
                )
                self.S.levels[-1].uend.send(dest=self.S.next, tag=self.S.status.iter, comm=comm)
            self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)

            # go back to fine level, sweeping
            for l in range(len(self.S.levels) - 1, 0, -1):
                # prolong values
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
                # on middle levels: do sweep as usual
                if l - 1 > 0:
                    self.S.levels[l - 1].sweep.update_nodes()

            # end with a fine sweep
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # restrict to coarsest level
            for l in range(1, len(self.S.levels)):
                self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])

            # slot s performs s + 1 coarse sweeps, receiving a new initial value before each but the first
            for p in range(self.S.status.slot + 1):
                self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
                if not p == 0 and not self.S.status.first:
                    self.logger.debug(
                        'recv data predict: process %s, stage %s, time, %s, source %s, tag %s, phase %s'
                        % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev, self.S.status.iter, p)
                    )
                    self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
                self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)

                # do the sweep with new values
                self.S.levels[-1].sweep.update_nodes()
                self.S.levels[-1].sweep.compute_end_point()

                self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
                if not self.S.status.last:
                    self.logger.debug(
                        'send data predict: process %s, stage %s, time, %s, target %s, tag %s, phase %s'
                        % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next, self.S.status.iter, p)
                    )
                    self.S.levels[-1].uend.send(dest=self.S.next, tag=self.S.status.iter, comm=comm)
                self.hooks.post_comm(
                    step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=(p == self.S.status.slot)
                )

            # interpolate back to finest level
            for l in range(len(self.S.levels) - 1, 0, -1):
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            # end this with a fine sweep
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

    def pfasst(self, comm, num_procs):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks

        Each call executes exactly one stage and sets the next one in self.S.status.stage:
            SPREAD   -> PREDICT (more than one level) or IT_CHECK
            PREDICT  -> IT_CHECK
            IT_CHECK -> DONE (converged), else IT_UP (more than one level),
                        IT_FINE (single process or mssdc_jac) or IT_COARSE
            IT_FINE  -> IT_CHECK
            IT_UP    -> IT_COARSE
            IT_COARSE -> IT_DOWN (more than one level) or IT_CHECK
            IT_DOWN  -> IT_FINE

        Args:
            comm: communicator
            num_procs (int): number of parallel processes

        Raises:
            ControllerError: for an unknown stage
        """

        stage = self.S.status.stage

        self.logger.debug(stage + ' - process ' + str(self.S.status.slot))

        if stage == 'SPREAD':
            # (potentially) serial spreading phase

            # first stage: spread values
            self.hooks.pre_step(step=self.S, level_number=0)

            # call predictor from sweeper
            self.S.levels[0].sweep.predict()

            # update stage
            if len(self.S.levels) > 1:  # MLSDC or PFASST with predict
                self.S.status.stage = 'PREDICT'
            else:
                self.S.status.stage = 'IT_CHECK'

        elif stage == 'PREDICT':
            # call predictor (serial)

            self.hooks.pre_predict(step=self.S, level_number=0)

            self.predictor(comm)

            self.hooks.post_predict(step=self.S, level_number=0)

            # update stage
            self.S.status.stage = 'IT_CHECK'

        elif stage == 'IT_CHECK':
            # check whether to stop iterating (parallel)

            self.hooks.pre_comm(step=self.S, level_number=0)

            # the pending isend still uses uend as buffer, so wait before overwriting it
            if self.req_send[0] is not None:
                self.req_send[0].wait()
            self.S.levels[0].sweep.compute_end_point()

            if not self.S.status.last and self.params.fine_comm:
                self.logger.debug(
                    'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                    % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next, 0, self.S.status.iter)
                )
                self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)

            if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
                self.logger.debug(
                    'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                    % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev, 0, self.S.status.iter)
                )
                self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)

            self.hooks.post_comm(step=self.S, level_number=0)

            self.S.levels[0].sweep.compute_residual()
            self.S.status.done = CheckConvergence.check_convergence(self.S)

            if self.params.all_to_done:
                self.hooks.pre_comm(step=self.S, level_number=0)
                self.S.status.done = comm.allreduce(sendobj=self.S.status.done, op=MPI.LAND)
                self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True)

            else:
                self.hooks.pre_comm(step=self.S, level_number=0)

                # check if an open request of the status send is pending
                if self.req_status is not None:
                    self.req_status.wait()

                # recv status
                if not self.S.status.first and not self.S.status.prev_done:
                    self.S.status.prev_done = comm.recv(source=self.S.prev, tag=99)
                    self.logger.debug(
                        'recv status: status %s, process %s, time %s, source %s, tag %s, iter %s'
                        % (
                            self.S.status.prev_done,
                            self.S.status.slot,
                            self.S.time,
                            self.S.prev,
                            99,
                            self.S.status.iter,
                        )
                    )
                    # a step can only be done if its predecessor is done as well
                    self.S.status.done = self.S.status.done and self.S.status.prev_done

                # send status forward
                if not self.S.status.last:
                    self.logger.debug(
                        'isend status: status %s, process %s, time %s, target %s, tag %s, iter %s'
                        % (self.S.status.done, self.S.status.slot, self.S.time, self.S.next, 99, self.S.status.iter)
                    )
                    self.req_status = comm.isend(self.S.status.done, dest=self.S.next, tag=99)

                self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True)

            if self.S.status.iter > 0:
                self.hooks.post_iteration(step=self.S, level_number=0)

            # if not ready, keep iterating
            if not self.S.status.done:
                # increment iteration count here (and only here)
                self.S.status.iter += 1

                self.hooks.pre_iteration(step=self.S, level_number=0)
                if len(self.S.levels) > 1:  # MLSDC or PFASST
                    self.S.status.stage = 'IT_UP'
                else:
                    if num_procs == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                        self.S.status.stage = 'IT_FINE'
                    else:
                        self.S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)

            else:
                # Need to finish all pending isend requests. These will occur for the first active process, since
                # in the last iteration the wait statement will not be called ("send and forget")
                for req in self.req_send:
                    if req is not None:
                        req.wait()
                if self.req_status is not None:
                    self.req_status.wait()

                self.hooks.post_step(step=self.S, level_number=0)
                self.S.status.stage = 'DONE'

        elif stage == 'IT_FINE':
            nsweeps = self.S.levels[0].params.nsweeps

            self.S.levels[0].status.sweep = 0

            # do fine sweep
            for k in range(nsweeps):
                self.S.levels[0].status.sweep += 1

                self.hooks.pre_comm(step=self.S, level_number=0)

                if self.req_send[0] is not None:
                    self.req_send[0].wait()
                self.S.levels[0].sweep.compute_end_point()

                if not self.S.status.last and self.params.fine_comm:
                    self.logger.debug(
                        'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                        % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next, 0, self.S.status.iter)
                    )
                    self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)

                if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
                    self.logger.debug(
                        'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                        % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev, 0, self.S.status.iter)
                    )
                    self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)

                self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=(k == nsweeps - 1))

                self.hooks.pre_sweep(step=self.S, level_number=0)
                self.S.levels[0].sweep.update_nodes()
                self.S.levels[0].sweep.compute_residual()
                self.hooks.post_sweep(step=self.S, level_number=0)

            # update stage
            self.S.status.stage = 'IT_CHECK'

        elif stage == 'IT_UP':
            # go up the hierarchy from finest to coarsest level (parallel)

            self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])

            # sweep and send on middle levels (not on finest, not on coarsest, though)
            for l in range(1, len(self.S.levels) - 1):
                nsweeps = self.S.levels[l].params.nsweeps

                for _ in range(nsweeps):
                    self.hooks.pre_comm(step=self.S, level_number=l)

                    if self.req_send[l] is not None:
                        self.req_send[l].wait()
                    self.S.levels[l].sweep.compute_end_point()

                    if not self.S.status.last and self.params.fine_comm:
                        self.logger.debug(
                            'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                            % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next, l, self.S.status.iter)
                        )
                        self.req_send[l] = self.S.levels[l].uend.isend(
                            dest=self.S.next, tag=self.S.status.iter, comm=comm
                        )

                    if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
                        self.logger.debug(
                            'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                            % (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev, l, self.S.status.iter)
                        )
                        self.recv(target=self.S.levels[l], source=self.S.prev, tag=self.S.status.iter, comm=comm)

                    self.hooks.post_comm(step=self.S, level_number=l)

                    self.hooks.pre_sweep(step=self.S, level_number=l)
                    self.S.levels[l].sweep.update_nodes()
                    self.S.levels[l].sweep.compute_residual()
                    self.hooks.post_sweep(step=self.S, level_number=l)

                # transfer further up the hierarchy
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])

            # update stage
            self.S.status.stage = 'IT_COARSE'

        elif stage == 'IT_COARSE':
            # sweeps on coarsest level (serial/blocking)

            # receive from previous step (if not first)
            self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
            if not self.S.status.first and not self.S.status.prev_done:
                self.logger.debug(
                    'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                    % (
                        self.S.status.slot,
                        self.S.status.stage,
                        self.S.time,
                        self.S.prev,
                        len(self.S.levels) - 1,
                        self.S.status.iter,
                    )
                )
                self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
            self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)

            # do the sweep
            self.hooks.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
            assert self.S.levels[-1].params.nsweeps == 1, (
                'ERROR: this controller can only work with one sweep on the coarse level, got %s'
                % self.S.levels[-1].params.nsweeps
            )
            self.S.levels[-1].sweep.update_nodes()
            self.S.levels[-1].sweep.compute_residual()
            self.hooks.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
            self.S.levels[-1].sweep.compute_end_point()

            # send to next step
            self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
            if not self.S.status.last:
                self.logger.debug(
                    'send data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                    % (
                        self.S.status.slot,
                        self.S.status.stage,
                        self.S.time,
                        self.S.next,
                        len(self.S.levels) - 1,
                        self.S.status.iter,
                    )
                )
                self.S.levels[-1].uend.send(dest=self.S.next, tag=self.S.status.iter, comm=comm)
            self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)

            # update stage
            if len(self.S.levels) > 1:  # MLSDC or PFASST
                self.S.status.stage = 'IT_DOWN'
            else:
                self.S.status.stage = 'IT_CHECK'  # MSSDC

        elif stage == 'IT_DOWN':
            # prolong corrections down to finest level (parallel)

            # receive and sweep on middle levels (except for coarsest level)
            for l in range(len(self.S.levels) - 1, 0, -1):
                # prolong values
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

                # on middle levels: do sweep as usual
                if l - 1 > 0:
                    nsweeps = self.S.levels[l - 1].params.nsweeps

                    for k in range(nsweeps):
                        self.hooks.pre_comm(step=self.S, level_number=l - 1)

                        if self.req_send[l - 1] is not None:
                            self.req_send[l - 1].wait()
                        self.S.levels[l - 1].sweep.compute_end_point()

                        if not self.S.status.last and self.params.fine_comm:
                            self.logger.debug(
                                'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                                % (
                                    self.S.status.slot,
                                    self.S.status.stage,
                                    self.S.time,
                                    self.S.next,
                                    l - 1,
                                    self.S.status.iter,
                                )
                            )
                            self.req_send[l - 1] = self.S.levels[l - 1].uend.isend(
                                dest=self.S.next, tag=self.S.status.iter, comm=comm
                            )

                        if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
                            self.logger.debug(
                                'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                                % (
                                    self.S.status.slot,
                                    self.S.status.stage,
                                    self.S.time,
                                    self.S.prev,
                                    l - 1,
                                    self.S.status.iter,
                                )
                            )
                            self.recv(
                                target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter, comm=comm
                            )

                        self.hooks.post_comm(step=self.S, level_number=l - 1, add_to_stats=(k == nsweeps - 1))

                        self.hooks.pre_sweep(step=self.S, level_number=l - 1)
                        self.S.levels[l - 1].sweep.update_nodes()
                        self.S.levels[l - 1].sweep.compute_residual()
                        self.hooks.post_sweep(step=self.S, level_number=l - 1)

            # update stage
            self.S.status.stage = 'IT_FINE'

        else:
            raise ControllerError('Weird stage, got %s' % self.S.status.stage)