Step-7: pySDC with external libraries

pySDC can be used with external libraries, in particular for spatial discretization, parallelization and solving of linear and/or nonlinear systems. In the following, we show a few examples of pySDC + X.

Part A: pySDC and FEniCS

In this example, pySDC is coupled with the FEniCS framework to use finite elements in space. This can require significant changes to the algorithm, depending on whether or not the mass matrix is inverted. SDC, MLSDC and PFASST can be used without changes when the right-hand side of the ODE is defined with the inverse of the mass matrix. Otherwise, the mass matrix has to be taken into account explicitly, e.g. in the tau-correction. This example tests different variants of this approach for SDC, MLSDC and PFASST.
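To make the difference between the two formulations concrete, here is a minimal NumPy sketch (no FEniCS, and not pySDC's problem classes). It assumes the 1D heat equation discretized with linear (P1) finite elements on a uniform mesh with homogeneous Dirichlet boundaries, so the semi-discrete system reads M u' = -nu K u. One implicit Euler step can be taken with the mass matrix, (M + dt nu K) u1 = M u0, or after inverting it, (I + dt nu M^-1 K) u1 = u0. Both give the same update; what differs is how the right-hand side, and thus residuals and tau-corrections, are represented.

```python
import numpy as np


def p1_matrices(n, length=1.0):
    """Mass and stiffness matrices for P1 elements (assumed setup):
    n interior nodes, uniform mesh on [0, length], homogeneous Dirichlet BCs."""
    h = length / (n + 1)
    ones = np.ones(n - 1)
    # M = h/6 * tridiag(1, 4, 1), K = 1/h * tridiag(-1, 2, -1)
    M = h / 6.0 * (4.0 * np.eye(n) + np.diag(ones, 1) + np.diag(ones, -1))
    K = 1.0 / h * (2.0 * np.eye(n) - np.diag(ones, 1) - np.diag(ones, -1))
    return M, K


def implicit_euler_with_mass(u0, M, K, nu, dt):
    # mass matrix formulation: solve (M + dt*nu*K) u1 = M u0, no inverse of M needed
    return np.linalg.solve(M + dt * nu * K, M @ u0)


def implicit_euler_mass_inv(u0, M, K, nu, dt):
    # mass inverse formulation: u' = A u with A = -nu * M^{-1} K
    A = -nu * np.linalg.solve(M, K)
    return np.linalg.solve(np.eye(len(u0)) - dt * A, u0)


n, nu, dt = 63, 0.1, 0.2
M, K = p1_matrices(n)
x = np.linspace(0.0, 1.0, n + 2)[1:-1]
u0 = np.sin(np.pi * x)

u_mass = implicit_euler_with_mass(u0, M, K, nu, dt)
u_inv = implicit_euler_mass_inv(u0, M, K, nu, dt)
print(np.max(np.abs(u_mass - u_inv)))  # agreement up to round-off
```

Multiplying the mass inverse system by M gives exactly the mass matrix system, which is why the two updates coincide. The mass inverse variant is easier to plug into an unmodified SDC sweeper, at the price of applying M^-1 in every right-hand-side evaluation.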

Important things to note:

  • This example shows that even core routines like the BaseTransfer can be overridden if needed.

  • It is also valuable to check out the data type and transfer classes required to work with FEniCS. Both can be found in the implementations folder.
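One reason the transfer classes need care is that, in general, applying the mass matrix and restricting to a coarser mesh do not commute. The following NumPy sketch illustrates this under illustrative assumptions that are not taken from mesh_to_mesh_fenics: P1 mass matrices on nested uniform 1D meshes and plain injection as the restriction operator.

```python
import numpy as np


def p1_mass(n, length=1.0):
    """P1 mass matrix for n interior nodes on a uniform mesh (Dirichlet BCs)."""
    h = length / (n + 1)
    ones = np.ones(n - 1)
    return h / 6.0 * (4.0 * np.eye(n) + np.diag(ones, 1) + np.diag(ones, -1))


def injection(n_fine):
    """Restriction by injection (assumed operator): keep every second fine node.
    Requires nested meshes, i.e. n_fine = 2 * n_coarse + 1."""
    n_coarse = (n_fine - 1) // 2
    R = np.zeros((n_coarse, n_fine))
    for i in range(n_coarse):
        R[i, 2 * i + 1] = 1.0  # coarse node i coincides with fine node 2*i + 1
    return R


n_c = 31
n_f = 2 * n_c + 1
M_f, M_c = p1_mass(n_f), p1_mass(n_c)
R = injection(n_f)

x_f = np.linspace(0.0, 1.0, n_f + 2)[1:-1]
u_f = np.sin(np.pi * x_f)

lhs = R @ (M_f @ u_f)  # apply fine mass matrix, then restrict
rhs = M_c @ (R @ u_f)  # restrict, then apply coarse mass matrix
print(np.max(np.abs(lhs - rhs)))  # clearly nonzero
```

For this smooth u, M_c R u is almost exactly twice R M_f u, since the mass matrix scales with the mesh width and the coarse mesh width is twice the fine one. A tau-correction that mixes the two orderings therefore picks up an O(1) inconsistency.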

Full code: pySDC/tutorial/step_7/A_pySDC_with_FEniCS.py

from pathlib import Path
import numpy as np

from pySDC.helpers.stats_helper import get_sorted

from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.problem_classes.HeatEquation_1D_FEniCS_matrix_forced import (
    fenics_heat_mass,
    fenics_heat,
    fenics_heat_mass_timebc,
)
from pySDC.implementations.sweeper_classes.imex_1st_order_mass import imex_1st_order_mass, imex_1st_order
from pySDC.implementations.transfer_classes.TransferFenicsMesh import mesh_to_mesh_fenics


def setup(t0=None, ml=None):
    """
    Helper routine to set up parameters

    Args:
        t0 (float): initial time
        ml (bool): use single or multiple levels

    Returns:
        description and controller_params parameter dictionaries
    """

    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 5e-10
    level_params['dt'] = 0.2

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 20

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    if ml:
        # Note that coarsening in the nodes actually HELPS MLSDC to converge (M=1 is exact on the coarse level)
        sweeper_params['num_nodes'] = [3, 1]
    else:
        sweeper_params['num_nodes'] = [3]

    problem_params = dict()
    problem_params['nu'] = 0.1
    problem_params['t0'] = t0  # ugly, but necessary to set up this ProblemClass
    problem_params['c_nvars'] = [128]
    problem_params['family'] = 'CG'
    problem_params['c'] = 1.0
    if ml:
        # We can do rather aggressive coarsening here. As long as we have 1 node on the coarse level, all is "well" (i.e.
        # MLSDC does not take more iterations than SDC, but also not fewer). If we just coarsen in the refinement (and
        # not in the nodes and order), the mass inverse approach is way better, i.e. it halves the number of iterations!
        problem_params['order'] = [4, 1]
        problem_params['refinements'] = [1, 0]
    else:
        problem_params['order'] = [4]
        problem_params['refinements'] = [1]

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30

    base_transfer_params = dict()
    base_transfer_params['finter'] = True

    # Fill description dictionary for easy hierarchy creation
    description = dict()
    description['problem_class'] = None
    description['problem_params'] = problem_params
    description['sweeper_class'] = None
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params
    description['space_transfer_class'] = mesh_to_mesh_fenics
    description['base_transfer_params'] = base_transfer_params

    return description, controller_params


def run_variants(variant=None, ml=None, num_procs=None):
    """
    Main routine to run the different implementations of the heat equation with FEniCS

    Args:
        variant (str): specifies the variant
        ml (bool): use single or multiple levels
        num_procs (int): number of processors in time
    """
    Tend = 1.0
    t0 = 0.0

    description, controller_params = setup(t0=t0, ml=ml)

    if variant == 'mass':
        # Note that we need to reduce the tolerance for the residual here, since otherwise the error will be too high
        description['level_params']['restol'] /= 500
        description['problem_class'] = fenics_heat_mass
        description['sweeper_class'] = imex_1st_order_mass
    elif variant == 'mass_inv':
        description['problem_class'] = fenics_heat
        description['sweeper_class'] = imex_1st_order
    elif variant == 'mass_timebc':
        # Can increase the tolerance here, errors are higher anyway
        description['level_params']['restol'] *= 20
        description['problem_class'] = fenics_heat_mass_timebc
        description['sweeper_class'] = imex_1st_order_mass
    else:
        raise NotImplementedError('Variant %s is not implemented' % variant)

    # quickly generate block of steps
    controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description)

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(0.0)

    # call main function to get things done...
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # compute exact solution and compare
    uex = P.u_exact(Tend)
    err = abs(uex - uend) / abs(uex)

    Path("data").mkdir(parents=True, exist_ok=True)
    f = open('data/step_7_A_out.txt', 'a')

    out = f'Variant {variant} with ml={ml} and num_procs={num_procs} -- error at time {Tend}: {err}'
    f.write(out + '\n')
    print(out)

    # filter statistics by type (number of iterations)
    iter_counts = get_sorted(stats, type='niter', sortby='time')

    niters = np.array([item[1] for item in iter_counts])
    out = '   Mean number of iterations: %4.2f' % np.mean(niters)
    f.write(out + '\n')
    print(out)
    out = '   Range of values for number of iterations: %2i ' % np.ptp(niters)
    f.write(out + '\n')
    print(out)
    out = '   Position of max/min number of iterations: %2i -- %2i' % (int(np.argmax(niters)), int(np.argmin(niters)))
    f.write(out + '\n')
    print(out)
    out = '   Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters)))
    f.write(out + '\n')
    print(out)

    timing = get_sorted(stats, type='timing_run', sortby='time')
    out = f'Time to solution: {timing[0][1]:6.4f} sec.'
    f.write(out + '\n')
    print(out)

    if num_procs == 1:
        assert np.mean(niters) <= 6.0, 'Mean number of iterations is too high, got %s' % np.mean(niters)
        if variant == 'mass' or variant == 'mass_inv':
            assert err <= 1.14e-08, 'Error is too high, got %s' % err
        else:
            assert err <= 3.25e-07, 'Error is too high, got %s' % err
    else:
        assert np.mean(niters) <= 11.6, 'Mean number of iterations is too high, got %s' % np.mean(niters)
        assert err <= 1.14e-08, 'Error is too high, got %s' % err

    f.write('\n')
    print()
    f.close()


def main():
    run_variants(variant='mass_inv', ml=False, num_procs=1)
    run_variants(variant='mass', ml=False, num_procs=1)
    run_variants(variant='mass_timebc', ml=False, num_procs=1)
    run_variants(variant='mass_inv', ml=True, num_procs=1)
    run_variants(variant='mass', ml=True, num_procs=1)
    run_variants(variant='mass_timebc', ml=True, num_procs=1)
    run_variants(variant='mass_inv', ml=True, num_procs=5)

    # WARNING: all other variants do NOT work, either because of FEniCS restrictions (weak forms with different meshes
    # will not work together) or because of inconsistent use of the mass matrix (locality condition for the tau
    # correction is not satisfied, mass matrix does not commute with restriction).
    # run_variants(variant='mass', ml=True, num_procs=5)


if __name__ == "__main__":
    main()

Results:

Variant mass_inv with ml=False and num_procs=1 -- error at time 1.0: 1.1387407230222816e-08
   Mean number of iterations: 6.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 1.6663 sec.

Variant mass with ml=False and num_procs=1 -- error at time 1.0: 1.1387594756569534e-08
   Mean number of iterations: 6.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 1.2271 sec.

Variant mass_timebc with ml=False and num_procs=1 -- error at time 1.0: 3.2473562155116167e-07
   Mean number of iterations: 6.00
   Range of values for number of iterations:  3 
   Position of max/min number of iterations:  3 --  0
   Std and var for number of iterations: 1.10 -- 1.20
Time to solution: 1.2417 sec.

Variant mass_inv with ml=True and num_procs=1 -- error at time 1.0: 1.138768636885694e-08
   Mean number of iterations: 6.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 2.2374 sec.

Variant mass with ml=True and num_procs=1 -- error at time 1.0: 1.1387216566052821e-08
   Mean number of iterations: 6.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 1.9052 sec.

Variant mass_timebc with ml=True and num_procs=1 -- error at time 1.0: 3.2473561574763597e-07
   Mean number of iterations: 6.00
   Range of values for number of iterations:  3 
   Position of max/min number of iterations:  3 --  0
   Std and var for number of iterations: 1.10 -- 1.20
Time to solution: 1.8955 sec.

Variant mass_inv with ml=True and num_procs=5 -- error at time 1.0: 1.1150087179536389e-08
   Mean number of iterations: 11.60
   Range of values for number of iterations:  9 
   Position of max/min number of iterations:  4 --  0
   Std and var for number of iterations: 3.26 -- 10.64
Time to solution: 3.6634 sec.
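The iteration statistics printed above are computed from the list of (time, value) pairs that get_sorted returns for the 'niter' entries. The sketch below applies the same NumPy reductions to a hypothetical list of iteration counts; the numbers are made up for illustration and are not taken from the runs above.

```python
import numpy as np

# hypothetical output of get_sorted(stats, type='niter', sortby='time'):
# one (time, number of iterations) pair per time step
iter_counts = [(0.0, 9), (0.2, 12), (0.4, 14), (0.6, 11), (0.8, 9)]

niters = np.array([item[1] for item in iter_counts])
mean = np.mean(niters)            # mean number of iterations
spread = np.ptp(niters)           # "range": max - min
pos_max = int(np.argmax(niters))  # index (not time) of the first maximum
pos_min = int(np.argmin(niters))  # index of the first minimum
std, var = float(np.std(niters)), float(np.var(niters))  # population std/var (ddof=0)

print(f'mean={mean:4.2f} range={spread} max/min at {pos_max} -- {pos_min} std={std:4.2f} var={var:4.2f}')
```

Since argmax and argmin return the first occurrence, a run where every step needs the same number of iterations reports positions 0 -- 0, as in several of the results above.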

Part B: mpi4py-fft for parallel Fourier transforms

The most prominent parallel solver is probably the FFT. While many implementations and wrappers exist for Python, we decided to use mpi4py-fft, which offered the easiest installation, a simple API and good parallel scaling. As an example, we test the nonlinear Schrödinger equation, using the IMEX sweeper to treat the nonlinear parts explicitly. The code can work both in real and in spectral space, with the latter usually being faster. This example tests SDC, MLSDC and PFASST.
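The IMEX splitting can be illustrated in one dimension with plain numpy.fft instead of mpi4py-fft. This sketch assumes the model form u_t = i u_xx + 2i |u|^2 u on a periodic domain (the exact form and parameters used by nonlinearschroedinger_imex may differ) and takes a single IMEX Euler step rather than an SDC sweep: the linear part is treated implicitly, which is a pointwise division in Fourier space, while the cubic term is evaluated explicitly in real space.

```python
import numpy as np


def imex_euler_nls(u, dt, length=2.0 * np.pi):
    """One IMEX Euler step for u_t = i*u_xx + 2i*|u|^2*u (assumed model form),
    periodic in x: implicit linear term, explicit nonlinear term."""
    n = u.size
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)  # angular wavenumbers
    nonlinear = 2j * np.abs(u) ** 2 * u                 # explicit part, real space
    rhs_hat = np.fft.fft(u + dt * nonlinear)
    # implicit part: (1 - dt * i * (-k^2)) u_hat_new = rhs_hat, diagonal in Fourier space
    u_hat_new = rhs_hat / (1.0 + 1j * dt * k**2)
    return np.fft.ifft(u_hat_new)


# plane wave amp * exp(i*k0*x): the step reduces to multiplication by a scalar factor
n, amp, k0, dt = 64, 0.5, 3, 0.05
x = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
u0 = amp * np.exp(1j * k0 * x)
u1 = imex_euler_nls(u0, dt)

factor = (1.0 + 2j * dt * amp**2) / (1.0 + 1j * dt * k0**2)
print(np.max(np.abs(u1 - factor * u0)))  # round-off level
```

The plane wave is a convenient check because both the explicit nonlinear term and the implicit linear solve act on it as scalar multiplications.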

Important things to note:

  • The code runs both in serial, using just python B_pySDC_with_mpi4pyfft.py, and in parallel, using mpirun -np 2 python B_pySDC_with_mpi4pyfft.py.

  • The nonlinear Schrödinger example is not expected to work well with PFASST: while SDC and MLSDC also converge for larger time steps, PFASST does not.

Full code: pySDC/tutorial/step_7/B_pySDC_with_mpi4pyfft.py

import numpy as np
from pathlib import Path
from mpi4py import MPI

from pySDC.helpers.stats_helper import get_sorted

from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import nonlinearschroedinger_imex
from pySDC.implementations.transfer_classes.TransferMesh_MPIFFT import fft_to_fft


def run_simulation(spectral=None, ml=None, num_procs=None):
    """
    A test program to do SDC, MLSDC and PFASST runs for the 2D NLS equation

    Args:
        spectral (bool): run in real or spectral space
        ml (bool): single or multiple levels
        num_procs (int): number of parallel processors
    """

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 1e-08
    level_params['dt'] = 1e-01 / 2
    level_params['nsweeps'] = [1]

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = [3]
    sweeper_params['QI'] = ['LU']  # For the IMEX sweeper, the LU-trick can be activated for the implicit part
    sweeper_params['initial_guess'] = 'zero'

    # initialize problem parameters
    problem_params = dict()
    if ml:
        problem_params['nvars'] = [(128, 128), (32, 32)]
    else:
        problem_params['nvars'] = [(128, 128)]
    problem_params['spectral'] = spectral
    problem_params['comm'] = comm

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 50

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30 if rank == 0 else 99
    # controller_params['predict_type'] = 'fine_only'

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_params'] = problem_params  # pass problem parameters
    description['problem_class'] = nonlinearschroedinger_imex
    description['sweeper_class'] = imex_1st_order
    description['sweeper_params'] = sweeper_params  # pass sweeper parameters
    description['level_params'] = level_params  # pass level parameters
    description['step_params'] = step_params  # pass step parameters
    description['space_transfer_class'] = fft_to_fft

    # set time parameters
    t0 = 0.0
    Tend = 1.0

    f = None
    if rank == 0:
        Path("data").mkdir(parents=True, exist_ok=True)
        f = open('data/step_7_B_out.txt', 'a')
        out = f'Running with ml={ml} and num_procs={num_procs}...'
        f.write(out + '\n')
        print(out)

    # instantiate controller
    controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description)

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # call main function to get things done...
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    uex = P.u_exact(Tend)
    err = abs(uex - uend)

    if rank == 0:
        # filter statistics by type (number of iterations)
        iter_counts = get_sorted(stats, type='niter', sortby='time')

        niters = np.array([item[1] for item in iter_counts])
        out = (
            f'   Min/Mean/Max number of iterations: '
            f'{np.min(niters):4.2f} / {np.mean(niters):4.2f} / {np.max(niters):4.2f}'
        )
        f.write(out + '\n')
        print(out)
        out = '   Range of values for number of iterations: %2i ' % np.ptp(niters)
        f.write(out + '\n')
        print(out)
        out = '   Position of max/min number of iterations: %2i -- %2i' % (
            int(np.argmax(niters)),
            int(np.argmin(niters)),
        )
        f.write(out + '\n')
        print(out)
        out = '   Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters)))
        f.write(out + '\n')
        print(out)

        out = f'Error: {err:6.4e}'
        f.write(out + '\n')
        print(out)

        timing = get_sorted(stats, type='timing_run', sortby='time')
        out = f'Time to solution: {timing[0][1]:6.4f} sec.'
        f.write(out + '\n')
        print(out)

        assert err <= 1.133e-05, 'Error is too high, got %s' % err
        if ml:
            if num_procs > 1:
                maxmean = 12.5
            else:
                maxmean = 6.6
        else:
            maxmean = 12.7
        assert np.mean(niters) <= maxmean, 'Mean number of iterations is too high, got %s' % np.mean(niters)

        f.write('\n')
        print()
        f.close()


def main():
    """
    Little helper routine to run the whole thing

    Note: This can also be run with "mpirun -np 2 python B_pySDC_with_mpi4pyfft.py"
    """
    run_simulation(spectral=False, ml=False, num_procs=1)
    run_simulation(spectral=True, ml=False, num_procs=1)
    run_simulation(spectral=False, ml=True, num_procs=1)
    run_simulation(spectral=True, ml=True, num_procs=1)
    run_simulation(spectral=False, ml=True, num_procs=10)
    run_simulation(spectral=True, ml=True, num_procs=10)


if __name__ == "__main__":
    main()

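Results:

(The runs appear in the order of main(): for each configuration, real space first, then spectral space.)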

Running with ml=False and num_procs=1...
   Min/Mean/Max number of iterations: 9.00 / 12.70 / 16.00
   Range of values for number of iterations:  7 
   Position of max/min number of iterations:  0 -- 19
   Std and var for number of iterations: 2.10 -- 4.41
Error: 1.1321e-05
Time to solution: 1.4778 sec.

Running with ml=False and num_procs=1...
   Min/Mean/Max number of iterations: 8.00 / 11.40 / 15.00
   Range of values for number of iterations:  7 
   Position of max/min number of iterations:  0 -- 19
   Std and var for number of iterations: 2.03 -- 4.14
Error: 4.1749e-06
Time to solution: 1.2048 sec.

Running with ml=True and num_procs=1...
   Min/Mean/Max number of iterations: 5.00 / 6.60 / 8.00
   Range of values for number of iterations:  3 
   Position of max/min number of iterations:  0 -- 16
   Std and var for number of iterations: 1.07 -- 1.14
Error: 1.1316e-05
Time to solution: 1.4295 sec.

Running with ml=True and num_procs=1...
   Min/Mean/Max number of iterations: 4.00 / 5.95 / 8.00
   Range of values for number of iterations:  4 
   Position of max/min number of iterations:  0 -- 19
   Std and var for number of iterations: 1.02 -- 1.05
Error: 4.1744e-06
Time to solution: 1.2943 sec.

Running with ml=True and num_procs=10...
   Min/Mean/Max number of iterations: 7.00 / 12.45 / 18.00
   Range of values for number of iterations: 11 
   Position of max/min number of iterations:  9 -- 10
   Std and var for number of iterations: 3.11 -- 9.65
Error: 1.1306e-05
Time to solution: 2.9191 sec.

Running with ml=True and num_procs=10...
   Min/Mean/Max number of iterations: 6.00 / 11.50 / 17.00
   Range of values for number of iterations: 11 
   Position of max/min number of iterations:  9 -- 10
   Std and var for number of iterations: 3.04 -- 9.25
Error: 4.1688e-06
Time to solution: 2.7466 sec.

Part C: Time-parallel pySDC with space-parallel PETSc

Because of their rather unfavorable scaling properties, parallel-in-time methods are only really useful when spatial parallelization is maxed out. To work with spatial parallelization, this part shows how to (1) include and work with an external library and (2) set up space- and time-parallel runs. We again use the forced heat equation as our testbed and PETSc for the space-parallel data structures and linear solver. See implementations/datatype_classes/petsc_dmda_grid.py and implementations/problem_classes/HeatEquation_2D_PETSc_forced.py for details on the PETSc bindings.
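The core task PETSc takes over here is solving the implicit linear systems on the distributed grid. As a serial, dense NumPy stand-in (no PETSc, and not the discretization or forcing used in heat2d_petsc_forced), the sketch below builds a 2D 5-point Laplacian on interior points with homogeneous Dirichlet boundaries and takes one implicit Euler step, i.e. it solves a system of the form (I - dt nu A) u1 = u0.

```python
import numpy as np


def laplacian_2d(n, length=1.0):
    """Dense 5-point Laplacian on an n x n grid of interior points (assumed setup),
    homogeneous Dirichlet BCs, built from 1D second differences via Kronecker products."""
    h = length / (n + 1)
    ones = np.ones(n - 1)
    d2 = (np.diag(ones, -1) - 2.0 * np.eye(n) + np.diag(ones, 1)) / h**2
    eye = np.eye(n)
    return np.kron(d2, eye) + np.kron(eye, d2), h


n, nu, dt = 15, 1.0, 0.125
A, h = laplacian_2d(n)
x = np.linspace(0.0, 1.0, n + 2)[1:-1]
X, Y = np.meshgrid(x, x, indexing='ij')
u0 = (np.sin(np.pi * X) * np.sin(np.pi * Y)).ravel()

# one implicit Euler step: (I - dt*nu*A) u1 = u0
u1 = np.linalg.solve(np.eye(n * n) - dt * nu * A, u0)

# u0 is a discrete eigenvector of A, so the step only rescales it
lam = -8.0 / h**2 * np.sin(np.pi * h / 2.0) ** 2
print(np.max(np.abs(u1 - u0 / (1.0 - dt * nu * lam))))  # round-off level
```

A dense solve like this scales poorly with grid size; distributing the grid and using an iterative PETSc solver is exactly what makes the spatial parallelization in this part possible.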

Important things to note:

  • We need processors in space and in time, which can be achieved with comm.Split and a suitable coloring. The space-communicator is then passed to the problem class.

  • Below, we run the code 3 times: with 1 and 2 processors in space as well as with 4 processors (2 in time and 2 in space). Do not expect good scaling here, since these runs are done in the CI environment.
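The coloring used in main() can be checked without MPI. The helper below is hypothetical and only mimics the integer arithmetic of the script: the first command-line argument is the number of processors in space, ranks with the same space color share a space-communicator, and ranks with the same time color share a time-communicator. The example assumes 4 ranks with 2 processors in space, i.e. a launch along the lines of mpirun -np 4 python C_pySDC_with_PETSc.py 2.

```python
def split_colors(world_size, procs_in_space):
    """Hypothetical helper mimicking the coloring in main(): group world ranks
    by their space color and by their time color."""
    groups_space, groups_time = {}, {}
    for rank in range(world_size):
        space_color = rank // procs_in_space  # consecutive ranks share space
        time_color = rank % procs_in_space    # strided ranks share time
        groups_space.setdefault(space_color, []).append(rank)
        groups_time.setdefault(time_color, []).append(rank)
    return groups_space, groups_time


space_groups, time_groups = split_colors(world_size=4, procs_in_space=2)
print(space_groups)  # {0: [0, 1], 1: [2, 3]}
print(time_groups)   # {0: [0, 2], 1: [1, 3]}
```

Each rank ends up in exactly one space group and one time group, so every (time step, subdomain) pair is handled by a single process.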

Full code: pySDC/tutorial/step_7/C_pySDC_with_PETSc.py

import sys
from pathlib import Path

import numpy as np
from mpi4py import MPI

from pySDC.helpers.stats_helper import get_sorted

from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
from pySDC.implementations.problem_classes.HeatEquation_2D_PETSc_forced import heat2d_petsc_forced
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
from pySDC.implementations.transfer_classes.TransferPETScDMDA import mesh_to_mesh_petsc_dmda


def main():
    """
    Program to demonstrate usage of PETSc data structures and spatial parallelization,
    combined with parallelization in time.
    """
    # set MPI communicator
    comm = MPI.COMM_WORLD

    world_rank = comm.Get_rank()
    world_size = comm.Get_size()

    # split world communicator to create space-communicators
    if len(sys.argv) >= 2:
        color = world_rank // int(sys.argv[1])
    else:
        color = world_rank
    space_comm = comm.Split(color=color)
    space_rank = space_comm.Get_rank()

    # split world communicator to create time-communicators
    if len(sys.argv) >= 2:
        color = world_rank % int(sys.argv[1])
    else:
        color = world_rank // world_size
    time_comm = comm.Split(color=color)
    time_rank = time_comm.Get_rank()

    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 1e-08
    level_params['dt'] = 0.125
    level_params['nsweeps'] = [1]

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = [3]
    sweeper_params['QI'] = ['LU']  # For the IMEX sweeper, the LU-trick can be activated for the implicit part
    sweeper_params['initial_guess'] = 'zero'

    # initialize problem parameters
    problem_params = dict()
    problem_params['nu'] = 1.0  # diffusion coefficient
    problem_params['freq'] = 2  # frequency for the test value
    problem_params['cnvars'] = [(65, 65)]  # number of degrees of freedom for the coarsest level
    problem_params['refine'] = [1, 0]  # number of refinements
    problem_params['comm'] = space_comm  # pass space-communicator to problem class
    problem_params['sol_tol'] = 1e-12  # set tolerance for PETSc's linear solver

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 50

    # initialize space transfer parameters
    space_transfer_params = dict()
    space_transfer_params['rorder'] = 2
    space_transfer_params['iorder'] = 2
    space_transfer_params['periodic'] = False

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 20 if space_rank == 0 else 99  # set level depending on rank
    controller_params['dump_setup'] = False

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = heat2d_petsc_forced  # pass problem class
    description['problem_params'] = problem_params  # pass problem parameters
    description['sweeper_class'] = imex_1st_order  # pass sweeper (see part B)
    description['sweeper_params'] = sweeper_params  # pass sweeper parameters
    description['level_params'] = level_params  # pass level parameters
    description['step_params'] = step_params  # pass step parameters
    description['space_transfer_class'] = mesh_to_mesh_petsc_dmda  # pass spatial transfer class
    description['space_transfer_params'] = space_transfer_params  # pass parameters for spatial transfer

    # set time parameters
    t0 = 0.0
    Tend = 0.25

    # instantiate controller
    controller = controller_MPI(controller_params=controller_params, description=description, comm=time_comm)

    # get initial values on finest level
    P = controller.S.levels[0].prob
    uinit = P.u_exact(t0)

    # call main function to get things done...
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # compute exact solution and compare
    uex = P.u_exact(Tend)
    err = abs(uex - uend)

    # filter statistics by type (number of iterations)
    iter_counts = get_sorted(stats, type='niter', sortby='time')

    niters = np.array([item[1] for item in iter_counts])

    # limit output to space-rank 0 (as before when setting the logger level)
    if space_rank == 0:
        if len(sys.argv) == 3:
            fname = str(sys.argv[2])
        else:
            fname = 'step_7_C_out.txt'
        Path("data").mkdir(parents=True, exist_ok=True)
        f = open('data/' + fname, 'a+')

        out = 'This is time-rank %i...' % time_rank
        f.write(out + '\n')
        print(out)

        # compute and print statistics
        for item in iter_counts:
            out = 'Number of iterations for time %4.2f: %2i' % item
            f.write(out + '\n')
            print(out)

        out = '   Mean number of iterations: %4.2f' % np.mean(niters)
        f.write(out + '\n')
        print(out)
        out = '   Range of values for number of iterations: %2i ' % np.ptp(niters)
        f.write(out + '\n')
        print(out)
        out = '   Position of max/min number of iterations: %2i -- %2i' % (
            int(np.argmax(niters)),
            int(np.argmin(niters)),
        )
        f.write(out + '\n')
        print(out)
        out = '   Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters)))
        f.write(out + '\n')
        print(out)

        timing = get_sorted(stats, type='timing_run', sortby='time')

        out = 'Time to solution: %6.4f sec.' % timing[0][1]
        f.write(out + '\n')
        print(out)
        out = 'Error vs. PDE solution: %6.4e' % err
        f.write(out + '\n')
        print(out)

        f.close()

    assert err < 2e-04, 'ERROR: did not match error tolerance, got %s' % err
    assert np.mean(niters) <= 12, 'ERROR: number of iterations is too high, got %s' % np.mean(niters)

    space_comm.Free()
    time_comm.Free()


if __name__ == "__main__":
    main()

Results:

1 processor in time, 1 processor in space

This is time-rank 0...
Number of iterations for time 0.00: 12
Number of iterations for time 0.12: 12
   Mean number of iterations: 12.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 2.3038 sec.
Error vs. PDE solution: 1.9479e-04

1 processor in time, 2 processors in space

This is time-rank 0...
Number of iterations for time 0.00: 12
Number of iterations for time 0.12: 12
   Mean number of iterations: 12.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 0.8779 sec.
Error vs. PDE solution: 1.9479e-04

2 processors in time, 2 processors in space

This is time-rank 0...
Number of iterations for time 0.00: 12
   Mean number of iterations: 12.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 0.7735 sec.
Error vs. PDE solution: 1.9479e-04
This is time-rank 1...
Number of iterations for time 0.12: 12
   Mean number of iterations: 12.00
   Range of values for number of iterations:  0 
   Position of max/min number of iterations:  0 --  0
   Std and var for number of iterations: 0.00 -- 0.00
Time to solution: 0.7726 sec.
Error vs. PDE solution: 1.9479e-04

Part D: pySDC and PyTorch

PyTorch is a library for machine learning. Its central data structure, the tensor, can be used on CPUs as well as GPUs and gives access to a wide range of machine learning methods. Since the potential for use in pySDC is very large, we have started on a datatype that allows PyTorch tensors to be used throughout pySDC.

This example trains a network to predict the results of implicit Euler solves for the heat equation. It is too simple to do anything useful, but it demonstrates how to use tensors in pySDC and then apply the enormous PyTorch infrastructure. This is work in progress at a very early stage! The tensor datatype is the simplest possible implementation rather than an efficient one. If you want to work on this, your input is appreciated!

Full code: pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC
from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor


def train_at_collocation_nodes():
    """
    For the first proof of concept, we want to train the model specifically for the collocation nodes we use in SDC.
    If successful, the initial guess would already be the exact solution and we would need no SDC iterations.

    What we find is that we can train the network to predict the solution to one very specific problem rather well.
    See the error during training for what happens when we ask the network to solve for exactly what it just trained.
    However, if we train for something else, i.e. solving to a different step size in this case, we can only use the
    model to predict the solution of what it's been trained for last and it loses the ability to solve for previously
    learned things. This is solely because we chose an overly simple model that is unsuitable for the task at hand and
    is likely easily solved with a bit of patience. This is just a demonstration of the interface between pySDC and
    PyTorch. If you want to do a project with this, feel free to take this as a starting point and do things that
    actually do something!

    The output shows the training loss during training and, after each of three training sessions is complete, the error
    of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print the
    error of all learned predictions after training is complete.
    """
    out = ''
    errors_mid_training = []
    errors_post_training = []

    # instantiate the pySDC problem and a model for PyTorch
    problem = Heat1DFDTensor()
    model = HeatEquationModel(problem)

    # setup neural network
    lr = 0.001
    num_epochs = 250
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # setup initial conditions
    t = 0
    initial_condition = problem.u_exact(t)

    # train the model to predict the solution at certain collocation nodes
    collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2
    for dt in collocation_nodes:

        # get target condition from implicit Euler step
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)

        # do the training
        for epoch in range(num_epochs):
            predicted_state = model(initial_condition, t, dt)
            loss = criterion(predicted_state.float(), target_condition.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 50 == 0:
                out += f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n'

        # evaluate model to compute error
        model_prediction = model(initial_condition, t, dt)
        errors_mid_training += [abs(target_condition - model_prediction)]
        out += f'Error of prediction at {dt:.2e} during training: {abs(target_condition-model_prediction):.2e}\n'

    # compare model and problem
    for dt in collocation_nodes:
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
        model_prediction = model(initial_condition, t, dt)
        errors_post_training += [abs(target_condition - model_prediction)]
        out += f'Error of prediction at {dt:.2e} after training: {abs(target_condition-model_prediction):.2e}\n'

    print(out)
    with open('data/step_7_D_out.txt', 'w') as file:
        file.write(out)

    # test that the training went as expected
    assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected'
    assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected'

    # save the model to use it throughout pySDC
    torch.save(model.state_dict(), 'data/heat_equation_model.pth')


if __name__ == '__main__':
    train_at_collocation_nodes()

Results:

Training for dt=1.55e-03: Epoch [  50/ 250], Loss: 6.0419e-03
Training for dt=1.55e-03: Epoch [ 100/ 250], Loss: 2.2858e-05
Training for dt=1.55e-03: Epoch [ 150/ 250], Loss: 1.0228e-07
Training for dt=1.55e-03: Epoch [ 200/ 250], Loss: 6.6946e-10
Training for dt=1.55e-03: Epoch [ 250/ 250], Loss: 2.5570e-12
Error of prediction at 1.55e-03 during training: 6.42e-06
Training for dt=6.45e-03: Epoch [  50/ 250], Loss: 8.0754e-05
Training for dt=6.45e-03: Epoch [ 100/ 250], Loss: 7.3524e-07
Training for dt=6.45e-03: Epoch [ 150/ 250], Loss: 4.0323e-09
Training for dt=6.45e-03: Epoch [ 200/ 250], Loss: 2.4311e-11
Training for dt=6.45e-03: Epoch [ 250/ 250], Loss: 8.8606e-14
Error of prediction at 6.45e-03 during training: 6.98e-07
Training for dt=1.00e-02: Epoch [  50/ 250], Loss: 2.0748e-05
Training for dt=1.00e-02: Epoch [ 100/ 250], Loss: 1.1005e-07
Training for dt=1.00e-02: Epoch [ 150/ 250], Loss: 6.0699e-10
Training for dt=1.00e-02: Epoch [ 200/ 250], Loss: 1.0267e-12
Training for dt=1.00e-02: Epoch [ 250/ 250], Loss: 1.7684e-14
Error of prediction at 1.00e-02 during training: 4.03e-07
Error of prediction at 1.55e-03 after training: 4.30e-01
Error of prediction at 6.45e-03 after training: 1.14e-01
Error of prediction at 1.00e-02 after training: 4.03e-07
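The forgetfulness visible in these results can be reproduced without PyTorch using a deliberately simple stand-in. The tutorial's model receives dt as an input but, per its docstring, is too simple to exploit it; this sketch takes that to the extreme with a linear model u1 = W u0 that ignores dt entirely. It replaces gradient-based training with Adam by a closed-form least-squares fit and uses a finite-difference heat equation on [0, 1]; both are assumptions made for brevity, and neither HeatEquationModel nor the pySDC problem class is used.

```python
import numpy as np


def heat_matrix(n, nu=1.0):
    """1D second-difference matrix, n interior points on [0, 1], Dirichlet BCs."""
    h = 1.0 / (n + 1)
    ones = np.ones(n - 1)
    return nu * (np.diag(ones, -1) - 2.0 * np.eye(n) + np.diag(ones, 1)) / h**2


def implicit_euler(u0, A, dt):
    return np.linalg.solve(np.eye(len(u0)) - dt * A, u0)


def fit_linear_model(A, dt, n_samples, rng):
    """'Train' a dt-agnostic linear model W by least squares on random initial states."""
    U0 = rng.standard_normal((n_samples, A.shape[0]))
    U1 = np.linalg.solve(np.eye(A.shape[0]) - dt * A, U0.T).T  # targets, one per row
    W_T, *_ = np.linalg.lstsq(U0, U1, rcond=None)              # solves U0 @ W.T = U1
    return W_T.T


rng = np.random.default_rng(0)
n = 31
A = heat_matrix(n)
x = np.linspace(0.0, 1.0, n + 2)[1:-1]
u0 = np.sin(np.pi * x) + 0.5 * np.sin(5 * np.pi * x)

errors_mid = {}
for dt in (1.55e-3, 1e-2):
    W = fit_linear_model(A, dt, n_samples=2 * n, rng=rng)  # each session overwrites W
    errors_mid[dt] = np.max(np.abs(W @ u0 - implicit_euler(u0, A, dt)))

# after the last session, re-evaluate all step sizes with the final model
errors_post = {dt: np.max(np.abs(W @ u0 - implicit_euler(u0, A, dt))) for dt in (1.55e-3, 1e-2)}
print(errors_mid)
print(errors_post)
```

Each fit reproduces its own step to round-off, but after the last session the model only represents the dt = 1e-2 step and misses the dt = 1.55e-3 step by a wide margin, mirroring the pattern in the results above. Retaining all step sizes requires a model that can actually represent the dependence on dt.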