Coverage for pySDC/projects/Resilience/collocation_adaptivity.py: 95%

132 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import numpy as np 

2import matplotlib.pyplot as plt 

3from matplotlib.colors import TABLEAU_COLORS 

4 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.projects.Resilience.vdp import run_vdp 

7from pySDC.projects.Resilience.advection import run_advection 

8from pySDC.projects.Resilience.heat import run_heat 

9from pySDC.projects.Resilience.hook import LogData 

10from pySDC.projects.Resilience.accuracy_check import get_accuracy_order 

11from pySDC.implementations.convergence_controller_classes.adaptive_collocation import AdaptiveCollocation 

12from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import ( 

13 EstimateEmbeddedErrorCollocation, 

14) 

15from pySDC.core.hooks import Hooks 

16from pySDC.implementations.hooks.log_errors import LogLocalErrorPostIter 

17from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimatePostIter 

18 

19 

# define global parameters for running problems and plotting
CMAP = [*TABLEAU_COLORS.values()]


Tend = 0.015
base_params = dict(
    step_params=dict(maxiter=99),
    sweeper_params=dict(QI='LU', num_nodes=4),
    level_params=dict(restol=1e-8, dt=Tend),
)

# parameters that are cycled through by the adaptive collocation strategies
coll_params_inexact = dict(num_nodes=[2, 3, 4], restol=[1e-4, 1e-7, 1e-8])
coll_params_refinement = dict(num_nodes=[1, 2, 3, 4])
coll_params_reduce = dict(num_nodes=[4, 3, 2, 1])
coll_params_type = dict(quad_type=['GAUSS', 'RADAU-RIGHT', 'LOBATTO'])


def _with_adaptive_coll(coll_params):
    """
    Build a convergence controller dictionary that runs `EstimateEmbeddedErrorCollocation` with the given
    adaptive collocation parameters.

    Args:
        coll_params (dict): Adaptive collocation parameters

    Returns:
        dict: Convergence controllers to put in the description
    """
    return {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params}}


# convergence controllers for each adaptive collocation mode; 'standard' uses no adaptive collocation at all
special_params = {
    'inexact': _with_adaptive_coll(coll_params_inexact),
    'refinement': _with_adaptive_coll(coll_params_refinement),
    'reduce': _with_adaptive_coll(coll_params_reduce),
    'standard': {},
    'type': _with_adaptive_coll(coll_params_type),
}



# define a few hooks
class LogSweeperParams(Hooks):
    """
    Log the sweeper parameters and the order of the collocation method after every iteration to check if the
    adaptive collocation convergence controller is doing what it's supposed to.
    """

    def post_iteration(self, step, level_number):
        """
        Record the sweeper parameters (as type 'sweeper_params') and the collocation order (as type 'coll_order')
        of the given level.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        L = step.levels[level_number]

        # bookkeeping entries shared by both stats
        common = {
            'process': step.status.slot,
            'time': L.time,
            'level': L.level_index,
            'iter': step.status.iter,
            'sweep': L.status.sweep,
        }

        # Store a snapshot instead of the live ``__dict__`` of the parameter object. Otherwise all logged entries
        # of the same sweeper alias one dict, and any later in-place change to the parameters would rewrite the
        # logged history.
        self.add_to_stats(type='sweeper_params', value=L.sweep.params.__dict__.copy(), **common)
        self.add_to_stats(type='coll_order', value=L.sweep.coll.order, **common)



# plotting functions
def compare_adaptive_collocation(prob):
    """
    Run a problem with various modes of adaptive collocation.

    Every mode in `special_params` is run once. Residual and number of collocation nodes vs. iteration are plotted
    into a shared figure by `plot_residual`, which also checks the node distribution against reference values.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    _, ax = plt.subplots()
    node_ax = ax.twinx()

    for i, (key, convergence_controllers) in enumerate(special_params.items()):
        custom_description = {**base_params, 'convergence_controllers': convergence_controllers}
        # a fresh dict for every run in case the problem modifies it
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=Tend,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams],
        )

        plot_residual(stats, ax, node_ax, label=key, color=CMAP[i])



def plot_residual(stats, ax, node_ax, **kwargs):
    """
    Plot residual and number of collocation nodes vs. iteration.

    Also a test is performed to see if we can reproduce previously obtained results. The number of nodes in every
    iteration is compared against a hard-coded reference selected by ``kwargs['label']``.

    Args:
        stats (pySDC.stats): The stats object of the run
        ax (Matplotlib.pyplot.axes): Somewhere to plot the residual
        node_ax (Matplotlib.pyplot.axes): Somewhere to plot the number of nodes
        **kwargs: Passed on to ``plot``. Must contain ``label``, which selects the reference node distribution

    Returns:
        None
    """
    sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
    residual = get_sorted(stats, type='residual_post_iteration', sortby='iter')

    # number of collocation nodes used in each iteration
    nodes = [me[1]['num_nodes'] for me in sweeper_params]

    # test if the expected outcome was achieved
    label = kwargs['label']
    expect = {
        'inexact': [2, 2, 3, 3, 4],
        'refinement': [1, 2, 2, 2, 2, 3, 3, 3, 4],
        'reduce': [4, 4, 4, 4, 4, 3, 2, 2, 1],
        'standard': [4, 4, 4, 4, 4],
        'type': [4, 4, 4, 4, 4, 4],
    }
    expected = expect[label]
    # Compare lengths explicitly, because ``np.allclose`` broadcasts. Without this, a single-entry ``nodes`` would
    # match any constant reference, and differing lengths would raise a broadcasting error instead of this message.
    assert len(nodes) == len(expected) and np.allclose(
        nodes, expected
    ), f"Unexpected distribution of nodes vs. iteration in {label}! Expected {expected}, got {nodes}"

    ax.plot([me[0] for me in residual], [me[1] for me in residual], **kwargs)
    ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$k$')
    ax.set_ylabel(r'residual')

    node_ax.plot([me[0] for me in sweeper_params], nodes, **kwargs, ls='--')
    node_ax.set_ylabel(r'nodes')



def check_order(prob, coll_name, ax, k_ax):
    """
    Make plot of the order of the collocation problems and check if they are as expected.

    A single step is computed for a range of step sizes with the adaptive collocation strategy ``coll_name``. For
    every collocation problem visited during the step, the following are recorded at the last iteration before the
    sweeper parameters change:

    - the local error,
    - the embedded error estimate,
    - the collocation order.

    For each problem, the order of accuracy of the errors in the step size is asserted to match the collocation
    order plus one (within 0.3).

    Args:
        prob (function): A problem from the resilience project to run
        coll_name (str): The name of the collocation refinement strategy (a key of `special_params`)
        ax (Matplotlib.pyplot.axes): Somewhere to plot the errors
        k_ax (Matplotlib.pyplot.axes): Somewhere to plot the number of extra iterations

    Returns:
        None
    """
    dt_range = [2.0 ** (-i) for i in range(2, 11)]

    # sweeper parameter used to label the individual collocation problems
    label_keys = {
        'type': 'quad_type',
    }
    label_key = label_keys.get(coll_name, 'num_nodes')

    res = []
    for dt in dt_range:
        new_params = {
            'level_params': {'restol': 1e-9, 'dt': dt},
            'sweeper_params': {'num_nodes': 2, 'QI': 'IE'},
        }
        custom_description = {**base_params, 'convergence_controllers': special_params[coll_name], **new_params}
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=dt,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams, LogLocalErrorPostIter, LogEmbeddedErrorEstimatePostIter],
        )

        sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')

        # An iteration is selected if the sweeper parameters change after it; the final iteration is always selected.
        converged_solution = [
            sweeper_params[j][1] != sweeper_params[j + 1][1] for j in range(len(sweeper_params) - 1)
        ] + [True]
        idx = np.arange(len(converged_solution))[converged_solution]
        labels = [sweeper_params[j][1][label_key] for j in idx]
        e_loc = np.array([me[1] for me in get_sorted(stats, type='e_local_post_iteration', sortby='iter')])[
            converged_solution
        ]

        e_em_raw = [
            me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation_post_iteration', sortby='iter')
        ]
        # pad with ``None`` for the collocation problem without an embedded estimate: last for 'refinement',
        # first otherwise
        e_em = np.array((e_em_raw + [None] if coll_name == 'refinement' else [None] + e_em_raw))
        coll_order = np.array([me[1] for me in get_sorted(stats, type='coll_order', sortby='iter')])[converged_solution]

        # idx[1:] - idx[:-1] counts the iterations between consecutive selected iterations
        res += [(dt, e_loc, idx[1:] - idx[:-1], labels, coll_order, e_em)]

    # assemble sth we can compute the order from
    result = {'dt': [me[0] for me in res]}
    embedded_errors = {'dt': [me[0] for me in res]}
    num_sols = len(res[0][1])
    for i in range(num_sols):
        result[i] = [me[1][i] for me in res]
        embedded_errors[i] = [me[5][i] for me in res]

        label = res[0][3][i]
        expected_order = res[0][4][i] + 1

        ax.scatter(result['dt'], embedded_errors[i], color=CMAP[i])

        # check the order of the local error and of the embedded estimate; unavailable estimates are skipped
        mean_orders = {}
        for name, errors in (('local', result), ('embedded', embedded_errors)):
            if None in errors[i]:
                continue
            order = get_accuracy_order(errors, key=i, thresh=1e-9)
            assert np.isclose(
                np.mean(order), expected_order, atol=0.3
            ), f"Expected order: {expected_order}, got {np.mean(order):.2f}!"
            mean_orders[name] = np.mean(order)

        # The lines show the local error, so label them with its order rather than whichever order was computed last.
        local_order = mean_orders.get('local')
        order_str = 'n/a' if local_order is None else f'{local_order:.1f}'
        ax.loglog(result['dt'], result[i], label=f'{label} nodes: order: {order_str}', color=CMAP[i])

        if i > 0:
            extra_iter = [me[2][i - 1] for me in res]
            k_ax.plot(result['dt'], extra_iter, ls='--', color=CMAP[i])
    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$e_\mathrm{local}$ (lines), $e_\mathrm{embedded}$ (dots)')
    k_ax.set_ylabel(r'extra iterations')



def order_stuff(prob, modes=None):
    """
    Check and plot the order of accuracy of several adaptive collocation strategies in side-by-side panels.

    Args:
        prob (function): A problem from the resilience project to run
        modes (list): Keys of `special_params` to check, one panel each. Defaults to
            ``['type', 'refinement', 'reduce']``

    Returns:
        None
    """
    modes = ['type', 'refinement', 'reduce'] if modes is None else modes

    # squeeze=False keeps an array of axes even for a single mode
    fig, axs = plt.subplots(1, len(modes), figsize=(14, 4), sharex=True, sharey=True, squeeze=False)
    axs = axs.flatten()

    k_axs = []
    for ax, mode in zip(axs, modes):
        k_axs += [ax.twinx()]
        check_order(prob, mode, ax, k_axs[-1])
        ax.set_title(mode)

    # only the rightmost panel keeps the label of the secondary y-axis
    for k_ax in k_axs[:-1]:
        k_ax.set_ylabel('')

    # only the leftmost panel keeps the axis labels
    for ax in axs[1:]:
        ax.set_xlabel('')
        ax.set_ylabel('')
    fig.tight_layout()



def adaptivity_collocation(plotting=False):
    """
    Run the van der Pol problem with collocation based adaptivity and check the embedded error estimate.

    The estimate must never exceed ``e_tol``. Apart from the first and last entries, it must not fall below
    ``e_tol / 10``.

    Args:
        plotting (bool): Whether to plot the step sizes

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityCollocation

    e_tol = 1e-7

    description = {
        'convergence_controllers': {
            AdaptivityCollocation: {
                'adaptive_coll_params': {'num_nodes': [2, 3]},
                'e_tol': e_tol,
            },
        },
        'step_params': {'maxiter': 99},
        'level_params': {'restol': 1e-8},
    }
    controller_params = {'logger_level': 30}

    stats, _, _ = run_vdp(custom_description=description, custom_controller_params=controller_params)

    estimates = [
        entry[1] for entry in get_sorted(stats, type='error_embedded_estimate_collocation', recomputed=False)
    ]
    assert max(estimates) <= e_tol, "Exceeded threshold for local tolerance when using collocation based adaptivity"
    assert min(estimates[1:-1]) >= e_tol / 10, "Over resolved problem when using collocation based adaptivity"

    if not plotting:
        return

    from pySDC.projects.Resilience.vdp import plot_step_sizes

    _, ax = plt.subplots()

    plot_step_sizes(stats, ax, 'error_embedded_estimate_collocation')



def main(plotting=False):
    """
    Run all experiments of this file.

    Args:
        plotting (bool): Forwarded to `adaptivity_collocation` to toggle its step size plot. `order_stuff` and
            `compare_adaptive_collocation` create their figures regardless.

    Returns:
        None
    """
    adaptivity_collocation(plotting=plotting)
    order_stuff(prob=run_advection)
    compare_adaptive_collocation(prob=run_vdp)



if __name__ == "__main__":
    # run everything with plotting enabled and display the resulting figures
    main(plotting=True)
    plt.show()