Coverage for pySDC / projects / Resilience / collocation_adaptivity.py: 95%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-13 09:00 +0000

1import numpy as np 

2import matplotlib.pyplot as plt 

3from matplotlib.colors import TABLEAU_COLORS 

4 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.projects.Resilience.vdp import run_vdp 

7from pySDC.projects.Resilience.advection import run_advection 

8from pySDC.projects.Resilience.heat import run_heat 

9from pySDC.projects.Resilience.hook import LogData 

10from pySDC.projects.Resilience.accuracy_check import get_accuracy_order 

11from pySDC.implementations.convergence_controller_classes.adaptive_collocation import AdaptiveCollocation 

12from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import ( 

13 EstimateEmbeddedErrorCollocation, 

14) 

15from pySDC.core.hooks import Hooks 

16from pySDC.implementations.hooks.log_errors import LogLocalErrorPostIter 

17from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimatePostIter 

18 

# define global parameters for running problems and plotting
# colours used to tell the different modes apart in the plots
CMAP = list(TABLEAU_COLORS.values())


# end time of the runs; since `dt` below equals `Tend`, runs that use `Tend=Tend` consist of a single step
Tend = 0.015
base_params = {
    'step_params': {'maxiter': 99},
    'sweeper_params': {
        'QI': 'LU',
        'num_nodes': 4,
    },
    'level_params': {'restol': 1e-8, 'dt': Tend},
}

# The `coll_params_*` dictionaries map sweeper / level parameter names to the sequence of values the adaptive
# collocation strategy steps through. Lists in the same dictionary are presumably consumed index by index
# (e.g. 2 nodes with restol 1e-4, then 3 nodes with restol 1e-7, ...) -- confirm against the controller.
coll_params_inexact = {
    'num_nodes': [2, 3, 4],
    'restol': [1e-4, 1e-7, 1e-8],
}
coll_params_refinement = {
    'num_nodes': [1, 2, 3, 4],
}
coll_params_reduce = {
    'num_nodes': [4, 3, 2, 1],
}
coll_params_type = {
    'quad_type': ['GAUSS', 'RADAU-RIGHT', 'LOBATTO'],
}

# convergence controllers for each mode of adaptive collocation; 'standard' runs plain SDC without any.
# The key order matters: `compare_adaptive_collocation` picks the plot colour by position in this dictionary.
special_params = {
    'inexact': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_inexact}},
    'refinement': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_refinement}},
    'reduce': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_reduce}},
    'standard': {},
    'type': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_type}},
}

55 

56 

# define a few hooks

class LogSweeperParams(Hooks):
    """
    Log the sweeper parameters after every iteration to check if the adaptive collocation convergence controller is
    doing what it's supposed to.
    """

    def post_iteration(self, step, level_number):
        """
        Record the sweeper parameters and the order of the collocation method after every iteration.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        L = step.levels[level_number]

        # both entries are logged with identical metadata
        meta = {
            'process': step.status.slot,
            'time': L.time,
            'level': L.level_index,
            'iter': step.status.iter,
            'sweep': L.status.sweep,
        }

        # store a snapshot instead of the params object's live attribute dict; otherwise every logged entry
        # referring to the same sweeper would change retroactively if the parameters were modified in place
        self.add_to_stats(type='sweeper_params', value=dict(L.sweep.params.__dict__), **meta)
        self.add_to_stats(type='coll_order', value=L.sweep.coll.order, **meta)

95 

96 

# plotting functions

def compare_adaptive_collocation(prob):
    """
    Run a problem with various modes of adaptive collocation.

    Every mode in `special_params` is run once on top of `base_params`. Its residual and number of collocation
    nodes are plotted vs. iteration, and `plot_residual` checks the node distribution against reference results.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    _, ax = plt.subplots()
    node_ax = ax.twinx()  # second y-axis for the number of collocation nodes

    for i, (key, convergence_controllers) in enumerate(special_params.items()):
        custom_description = {**base_params, 'convergence_controllers': convergence_controllers}
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=Tend,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams],
        )

        # wrap around the colour map so that adding more modes cannot raise an IndexError
        plot_residual(stats, ax, node_ax, label=key, color=CMAP[i % len(CMAP)])

123 

124 

def plot_residual(stats, ax, node_ax, **kwargs):
    """
    Plot residual and nodes vs. iteration.
    Also a test is performed to see if we can reproduce previously obtained results.

    Args:
        stats (pySDC.stats): The stats object of the run
        ax (Matplotlib.pyplot.axes): Somewhere to plot the residual
        node_ax (Matplotlib.pyplot.axes): Somewhere to plot the number of collocation nodes
        **kwargs: Passed on to `plot`. Must contain `label`, which selects the reference node distribution.

    Returns:
        None

    Raises:
        KeyError: if `label` is missing or has no reference node distribution
        AssertionError: if the number of nodes per iteration differs from the reference
    """
    sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
    residual = get_sorted(stats, type='residual_post_iteration', sortby='iter')

    # number of collocation nodes in every iteration, which shows when it changed
    nodes = [me[1]['num_nodes'] for me in sweeper_params]

    # test if the expected outcome was achieved
    label = kwargs['label']
    expect = {
        'inexact': [2, 2, 3, 3, 4],
        'refinement': [1, 2, 2, 2, 2, 3, 3, 3, 4],
        'reduce': [4, 4, 4, 4, 4, 3, 2, 2, 1],
        'standard': [4, 4, 4, 4, 4],
        'type': [4, 4, 4, 4, 4, 4],
    }
    # `np.array_equal` also requires equal lengths; `np.allclose` would broadcast a length-1 sequence (false pass)
    # and raise a ValueError instead of this assertion for other length mismatches
    assert np.array_equal(
        nodes, expect[label]
    ), f"Unexpected distribution of nodes vs. iteration in {label}! Expected {expect[label]}, got {nodes}"

    ax.plot([me[0] for me in residual], [me[1] for me in residual], **kwargs)
    ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$k$')
    ax.set_ylabel(r'residual')

    node_ax.plot([me[0] for me in sweeper_params], nodes, **kwargs, ls='--')
    node_ax.set_ylabel(r'nodes')

165 

166 

def _check_accuracy_order(errors, key, expected_order):
    """
    Compute the order of accuracy of `errors[key]` and assert that it matches the expected order.

    Args:
        errors (dict): Step sizes under 'dt' and one list of errors per integer key, as assembled in `check_order`
        key (int): Which list of errors to check
        expected_order (float): The order we expect

    Returns:
        float: The mean observed order, or None if some errors are missing (None) and nothing was checked
    """
    if None in errors[key]:
        return None
    order = np.mean(get_accuracy_order(errors, key=key, thresh=1e-9))
    assert np.isclose(order, expected_order, atol=0.3), f"Expected order: {expected_order}, got {order:.2f}!"
    return order


def check_order(prob, coll_name, ax, k_ax):
    """
    Make plot of the order of the collocation problems and check if they are as expected.

    For a range of step sizes, a single step is computed with the adaptive collocation strategy `coll_name`.
    The order of the local error, and of the embedded error estimate where available, of every collocation method
    must be within 0.3 of the collocation order + 1.

    Args:
        prob (function): A problem from the resilience project to run
        coll_name (str): The name of the collocation refinement strategy
        ax (Matplotlib.pyplot.axes): Somewhere to plot the errors
        k_ax (Matplotlib.pyplot.axes): Somewhere to plot the extra iterations

    Returns:
        None

    Raises:
        AssertionError: if an observed order deviates from the expected one
    """
    dt_range = [2.0 ** (-i) for i in range(2, 11)]

    # which sweeper parameter distinguishes the collocation methods; used for the legend
    label_keys = {
        'type': 'quad_type',
    }
    label_key = label_keys.get(coll_name, 'num_nodes')

    res = []
    for dt in dt_range:
        new_params = {
            'level_params': {'restol': 1e-9, 'dt': dt},
            'sweeper_params': {'num_nodes': 2, 'QI': 'IE'},
        }
        custom_description = {**base_params, 'convergence_controllers': special_params[coll_name], **new_params}
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=dt,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams, LogLocalErrorPostIter, LogEmbeddedErrorEstimatePostIter],
        )

        # an iteration is the last one of its collocation method if the sweeper parameters change afterwards
        sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
        converged_solution = [
            current[1] != following[1] for current, following in zip(sweeper_params, sweeper_params[1:])
        ] + [True]
        idx = np.arange(len(converged_solution))[converged_solution]
        labels = [sweeper_params[j][1][label_key] for j in idx]
        e_loc = np.array([me[1] for me in get_sorted(stats, type='e_local_post_iteration', sortby='iter')])[
            converged_solution
        ]

        # pad with None to get one entry per collocation method: appended for 'refinement', prepended otherwise.
        # NOTE(review): assumes exactly one estimate fewer than collocation methods -- confirm
        e_em_raw = [
            me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation_post_iteration', sortby='iter')
        ]
        e_em = np.array((e_em_raw + [None] if coll_name == 'refinement' else [None] + e_em_raw))
        coll_order = np.array([me[1] for me in get_sorted(stats, type='coll_order', sortby='iter')])[converged_solution]

        res += [(dt, e_loc, idx[1:] - idx[:-1], labels, coll_order, e_em)]

    # assemble sth we can compute the order from
    result = {'dt': [me[0] for me in res]}
    embedded_errors = {'dt': [me[0] for me in res]}
    num_sols = len(res[0][1])
    for i in range(num_sols):
        result[i] = [me[1][i] for me in res]
        embedded_errors[i] = [me[5][i] for me in res]

        label = res[0][3][i]
        expected_order = res[0][4][i] + 1
        color = CMAP[i % len(CMAP)]

        ax.scatter(result['dt'], embedded_errors[i], color=color)

        # check the orders of the local error and of the embedded estimate (skipped if the latter is missing)
        local_order = _check_accuracy_order(result, i, expected_order)
        _check_accuracy_order(embedded_errors, i, expected_order)

        # plot the local error once, labelled with its own order
        if local_order is not None:
            ax.loglog(result['dt'], result[i], label=f'{label} nodes: order: {local_order:.1f}', color=color)

        if i > 0:
            # iterations between the converged solutions of collocation methods i - 1 and i
            extra_iter = [me[2][i - 1] for me in res]
            k_ax.plot(result['dt'], extra_iter, ls='--', color=color)
    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$e_\mathrm{local}$ (lines), $e_\mathrm{embedded}$ (dots)')
    k_ax.set_ylabel(r'extra iterations')

249 

250 

def order_stuff(prob):
    """
    Check and plot the order of accuracy of several adaptive collocation strategies side by side.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    modes = ['type', 'refinement', 'reduce']
    # `squeeze=False` always yields an array of axes, even for a single mode
    fig, axs = plt.subplots(1, len(modes), figsize=(14, 4), sharex=True, sharey=True, squeeze=False)
    flat_axs = axs.flatten()

    k_axs = []
    for ax, mode in zip(flat_axs, modes):
        k_ax = ax.twinx()
        k_axs.append(k_ax)
        check_order(prob, mode, ax, k_ax)
        ax.set_title(mode)

    # the axes are shared, so keep the twin-axis label only on the last panel
    for k_ax in k_axs[:-1]:
        k_ax.set_ylabel('')

    # ... and the main-axis labels only on the first panel
    for ax in flat_axs[1:]:
        ax.set_xlabel('')
        ax.set_ylabel('')
    fig.tight_layout()

267 

268 

def adaptivity_collocation(plotting=False):
    """
    Run the van der Pol problem with collocation based adaptivity and check the embedded error estimate.

    The estimate must never exceed the tolerance. Apart from the first and last entry, it must also not drop below
    a tenth of the tolerance, which would indicate an over-resolved problem.

    Args:
        plotting (bool): Plot the step sizes if True

    Returns:
        None

    Raises:
        AssertionError: if no estimate was recorded or one of the checks above fails
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityCollocation

    e_tol = 1e-7

    description = {
        'convergence_controllers': {
            AdaptivityCollocation: {'adaptive_coll_params': {'num_nodes': [2, 3]}, 'e_tol': e_tol},
        },
        'step_params': {'maxiter': 99},
        'level_params': {'restol': 1e-8},
    }
    controller_params = {'logger_level': 30}

    stats, _, _ = run_vdp(custom_description=description, custom_controller_params=controller_params)

    e_em = [me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation', recomputed=False)]
    # fail with a readable message instead of a ValueError from max() on an empty list
    assert e_em, "No embedded error estimate was recorded when using collocation based adaptivity"
    assert max(e_em) <= e_tol, "Exceeded threshold for local tolerance when using collocation based adaptivity"

    # the first and last entries are excluded from the over-resolution check (presumably the initial step size
    # guess and a final step shortened to reach Tend -- TODO confirm); skip the check if nothing is left
    interior = e_em[1:-1]
    assert (
        not interior or min(interior) >= e_tol / 10
    ), "Over resolved problem when using collocation based adaptivity"

    if plotting:
        from pySDC.projects.Resilience.vdp import plot_step_sizes

        _, ax = plt.subplots()

        plot_step_sizes(stats, ax, 'error_embedded_estimate_collocation')

310 

311 

def main(plotting=False):
    """
    Run every experiment in this module: collocation based adaptivity, order checks and the comparison of the
    adaptive collocation modes.

    Args:
        plotting (bool): Passed on to `adaptivity_collocation` to toggle its step size plot

    Returns:
        None
    """
    adaptivity_collocation(plotting=plotting)

    experiments = (
        (order_stuff, run_advection),
        (compare_adaptive_collocation, run_vdp),
    )
    for experiment, problem in experiments:
        experiment(problem)

316 

317 

if __name__ == "__main__":
    # run all experiments with plotting enabled, then display the figures
    main(plotting=True)
    plt.show()