Integrator Benchmark Example
Tests Hiten integrators against SciPy on various ODE problems and provides detailed accuracy, performance, and error analysis.
Testing Hiten integrators against SciPy on various ODE problems
1#!/usr/bin/env python3
2"""Integrator benchmark script.
3
4Tests Hiten integrators against SciPy on various ODE problems and provides
5detailed accuracy, performance, and error analysis.
6"""
7
8import os
9import sys
10import time
11import warnings
12from dataclasses import dataclass
13from typing import List
14
15import matplotlib.pyplot as plt
16import numpy as np
17from scipy.integrate import solve_ivp
18
19sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', "src"))
20
21from hiten.algorithms.dynamics.rhs import create_rhs_system
22from hiten.algorithms.integrators import (AdaptiveRK, RungeKutta)
23
24warnings.filterwarnings('ignore', category=UserWarning)
25
26
@dataclass
class TestResult:
    """Outcome of running one integrator on one benchmark problem.

    Error fields are ``inf`` and ``converged`` is ``False`` when the run
    raised; ``error_message`` then carries the exception text.
    """
    # Label of the integrator that produced this result.
    integrator_name: str
    # Name of the benchmark problem (``TEST_PROBLEMS[i]["name"]``).
    problem_name: str
    # Euclidean norm of the pointwise error at the last output time.
    final_error: float
    # Largest absolute component error over all output times.
    max_error: float
    # max_error divided by the largest reference magnitude.
    relative_error: float
    # Wall-clock seconds for the timed integration.
    computation_time: float
    # Number of output intervals (or 0 on failure).
    n_steps: int
    # Energy drift at the final time, or None when the problem has no energy.
    energy_error: float = None
    # False when the integration or scoring raised.
    converged: bool = True
    # Exception text for failed runs; empty otherwise.
    error_message: str = ""
40
41
42# Test problems as simple functions
43def harmonic_oscillator(t, y):
44 """Harmonic oscillator: x'' + x = 0."""
45 return np.array([y[1], -y[0]])
46
def van_der_pol(t, y):
    """Van der Pol oscillator ``x'' - mu*(1 - x**2)*x' + x = 0`` with ``mu = 1``.

    State is ``[x, x']``; ``t`` is unused (autonomous system).
    """
    mu = 1.0
    x, v = y[0], y[1]
    accel = mu * (1 - x**2) * v - x
    return np.array([v, accel])
51
def lorenz(t, y):
    """Lorenz system (chaotic attractor) with sigma=10, rho=28, beta=8/3.

    State is ``[x, y, z]``; ``t`` is unused (autonomous system).
    """
    sigma, rho, beta = 10.0, 28.0, 8.0/3.0
    x, yy, z = y[0], y[1], y[2]
    dx = sigma * (yy - x)
    dy = x * (rho - z) - yy
    dz = x * yy - beta * z
    return np.array([dx, dy, dz])
60
def kepler(t, y):
    """Planar Kepler problem (central ``1/r`` potential) with ``mu = 1``.

    State layout is ``[x, y, vx, vy]``; returns ``[vx, vy, ax, ay]`` with
    ``a = -mu * r_vec / |r_vec|**3``. ``t`` is unused (autonomous system).
    """
    px, py, vx, vy = y
    mu = 1.0
    dist = np.sqrt(px**2 + py**2)
    dist_cubed = dist**3
    ax = -mu * px / dist_cubed
    ay = -mu * py / dist_cubed
    return np.array([vx, vy, ax, ay])
68
def duffing(t, y):
    """Forced, damped Duffing oscillator.

    Solves ``x'' + delta*x' + alpha*x + beta*x**3 = gamma*cos(omega*t)``
    with delta=0.1, alpha=-1, beta=1, gamma=0.3, omega=1.2. State is
    ``[x, x']``; the forcing makes the system explicitly time-dependent.
    """
    delta, alpha, beta = 0.1, -1.0, 1.0
    gamma, omega = 0.3, 1.2
    x, v = y[0], y[1]
    forcing = gamma * np.cos(omega * t)
    accel = -delta * v - alpha * x - beta * x**3 + forcing
    return np.array([v, accel])
78
# Test problem configurations. Each entry provides:
#   name   - display label
#   rhs    - right-hand side f(t, y)
#   y0     - initial state
#   t_span - integration interval passed to the SciPy solves
#   exact  - analytic solution t -> array with one row per time, or None
#   energy - conserved quantity E(y) used to measure drift, or None
TEST_PROBLEMS = [
    dict(
        name="Harmonic Oscillator",
        rhs=harmonic_oscillator,
        y0=np.array([1.0, 0.0]),
        t_span=(0.0, 4*np.pi),
        # x = cos t, x' = -sin t for y0 = [1, 0]
        exact=lambda t: np.column_stack([np.cos(t), -np.sin(t)]),
        energy=lambda y: 0.5 * (y[0]**2 + y[1]**2),
    ),
    dict(
        name="Van der Pol",
        rhs=van_der_pol,
        y0=np.array([2.0, 0.0]),
        t_span=(0.0, 20.0),
        exact=None,
        energy=None,
    ),
    dict(
        name="Lorenz",
        rhs=lorenz,
        y0=np.array([1.0, 1.0, 1.0]),
        t_span=(0.0, 20.0),
        exact=None,
        energy=None,
    ),
    dict(
        name="Kepler",
        rhs=kepler,
        y0=np.array([1.0, 0.0, 0.0, 0.8]),
        t_span=(0.0, 10.0),
        exact=None,
        # Kinetic plus potential energy with mu = 1
        energy=lambda y: 0.5 * (y[2]**2 + y[3]**2) - 1.0 / np.sqrt(y[0]**2 + y[1]**2),
    ),
    dict(
        name="Duffing",
        rhs=duffing,
        y0=np.array([1.0, 0.0]),
        t_span=(0.0, 20.0),
        exact=None,
        energy=None,
    ),
]
122
123
def run_integrator_test(integrator, system, problem_config, t_eval: np.ndarray) -> TestResult:
    """Run a single Hiten integrator on one problem and score the result.

    Parameters
    ----------
    integrator
        Hiten integrator exposing ``integrate(system, y0, t_eval)``. The
        returned object must provide ``states`` with one row per time in
        ``t_eval``.
    system
        Dynamical system built by ``create_rhs_system``. Its ``rhs`` is
        also used for the SciPy reference solve.
    problem_config : dict
        One entry of ``TEST_PROBLEMS`` (keys ``name``, ``y0``, ``t_span``,
        ``exact``, ``energy``).
    t_eval : numpy.ndarray
        Output times at which the solution is requested and compared.

    Returns
    -------
    TestResult
        Error metrics and wall-clock time, labelled with ``str(integrator)``.
        Any exception raised while integrating or scoring yields
        ``converged=False`` with infinite errors.
    """
    y0 = problem_config["y0"]

    # Warm-up JIT compilation (excluded from timing). Best effort: a failure
    # here is ignored and, if real, resurfaces in the timed run below.
    try:
        dt = t_eval[1] - t_eval[0] if t_eval.size >= 2 else 0.0
        warmup_t = np.array([t_eval[0], t_eval[0] + max(dt, 1e-8)], dtype=float)
        integrator.integrate(system, y0, warmup_t)
    except Exception:
        pass

    start_time = time.perf_counter()
    try:
        solution = integrator.integrate(system, y0, t_eval)
        computation_time = time.perf_counter() - start_time
        y_solution = solution.states

        # Score against the analytic solution when available. Otherwise use
        # a high-accuracy DOP853 reference, matching the reference used for
        # the SciPy solvers, so both groups are judged against the same
        # trajectory (the default RK45 at rtol=1e-12 differs from it).
        if problem_config["exact"] is not None:
            y_ref = problem_config["exact"](t_eval)
        else:
            scipy_sol = solve_ivp(
                system.rhs, problem_config["t_span"], y0,
                t_eval=t_eval, method="DOP853", rtol=1e-12, atol=1e-14
            )
            y_ref = scipy_sol.y.T

        errors = np.abs(y_solution - y_ref)
        max_error = np.max(errors)

        # Energy drift at the final time, if the problem defines an energy.
        energy_error = None
        energy_fn = problem_config["energy"]
        if energy_fn is not None:
            initial_energy = energy_fn(y0)
            drift = abs(energy_fn(y_solution[-1]) - initial_energy)
            # Relative drift; fall back to absolute drift for a zero-energy
            # start instead of dividing by zero.
            energy_error = drift / abs(initial_energy) if initial_energy != 0 else drift

        return TestResult(
            integrator_name=str(integrator),
            problem_name=problem_config["name"],
            final_error=np.linalg.norm(errors[-1]),
            max_error=max_error,
            relative_error=max_error / (np.max(np.abs(y_ref)) + 1e-16),
            computation_time=computation_time,
            # Output intervals, not internal steps (approximate for adaptive methods).
            n_steps=len(t_eval) - 1,
            energy_error=energy_error,
            converged=True,
        )

    except Exception as e:
        computation_time = time.perf_counter() - start_time
        return TestResult(
            integrator_name=str(integrator),
            problem_name=problem_config["name"],
            final_error=np.inf,
            max_error=np.inf,
            relative_error=np.inf,
            computation_time=computation_time,
            n_steps=0,
            converged=False,
            error_message=str(e),
        )
200
201
def run_benchmark():
    """Run every Hiten and SciPy integrator on every problem in ``TEST_PROBLEMS``.

    Hiten integrators are evaluated through ``run_integrator_test``. SciPy
    solvers are timed directly with ``solve_ivp`` at ``rtol=1e-8``,
    ``atol=1e-10``. They are scored against the analytic solution when the
    problem provides one, otherwise against a DOP853 reference solved at
    ``rtol=1e-12``, ``atol=1e-14``, which is computed once per problem.

    Returns
    -------
    list of TestResult
        One entry per (integrator, problem) pair, labelled with the keys of
        the integrator tables below. Failed runs are recorded with
        ``converged=False`` and infinite errors.
    """

    def report(label, result):
        """Print the console line(s) for one result."""
        if not result.converged:
            print(f"{label:15s}: FAILED - {result.error_message}")
            return
        print(f"{label:15s}: "
              f"Final Error: {result.final_error:.2e}, "
              f"Max Error: {result.max_error:.2e}, "
              f"Time: {result.computation_time:.4f}s")
        if result.energy_error is not None:
            print(f"{'':15s} Energy Error: {result.energy_error:.2e}")

    def failed(label, problem_name, elapsed, message):
        """Build a non-converged result with infinite errors."""
        return TestResult(
            integrator_name=label,
            problem_name=problem_name,
            final_error=np.inf,
            max_error=np.inf,
            relative_error=np.inf,
            computation_time=elapsed,
            n_steps=0,
            converged=False,
            error_message=message,
        )

    # Define integrators to test
    integrators = {
        # Fixed-step RK methods
        "RK4": RungeKutta(order=4),
        "RK6": RungeKutta(order=6),
        "RK8": RungeKutta(order=8),

        # Adaptive RK methods
        "RK45": AdaptiveRK(order=5, rtol=1e-8, atol=1e-10),
        "DOP853": AdaptiveRK(order=8, rtol=1e-8, atol=1e-10),
    }

    # SciPy reference methods (label -> solve_ivp ``method``)
    scipy_integrators = {
        "SciPy-RK45": "RK45",
        "SciPy-DOP853": "DOP853",
        "SciPy-RK23": "RK23",
        "SciPy-BDF": "BDF",
        "SciPy-LSODA": "LSODA"
    }

    # Time grid for evaluation, shared by every problem.
    t_eval = np.linspace(0, 10, 1001)
    # The Hiten integrators only integrate across t_eval, so SciPy is timed
    # over that same interval. Using the problem's t_span (up to t=20) would
    # make SciPy integrate further than Hiten and inflate its timings.
    bench_span = (t_eval[0], t_eval[-1])

    all_results = []

    print("Running comprehensive integrator tests...")
    print("=" * 60)

    for problem_config in TEST_PROBLEMS:
        problem_name = problem_config["name"]
        y0 = problem_config["y0"]
        print(f"\nTesting problem: {problem_name}")
        print("-" * 40)
        # Build a single dynamical system instance per problem for all tests
        system = create_rhs_system(problem_config["rhs"], dim=len(y0), name=problem_name)

        # Test Hiten integrators
        for name, integrator in integrators.items():
            result = run_integrator_test(integrator, system, problem_config, t_eval)
            # Label with the table key so the summary table and plots use the
            # same names as the console output and the SciPy rows.
            result.integrator_name = name
            all_results.append(result)
            report(name, result)

        # Reference trajectory, computed on the first successful SciPy run
        # and reused by the remaining methods of this problem.
        y_ref = None

        # Test SciPy integrators
        for name, method in scipy_integrators.items():
            start_time = time.perf_counter()
            try:
                scipy_sol = solve_ivp(
                    system.rhs, bench_span, y0,
                    t_eval=t_eval, method=method, rtol=1e-8, atol=1e-10
                )
                computation_time = time.perf_counter() - start_time

                if not scipy_sol.success:
                    result = failed(name, problem_name, computation_time, str(scipy_sol.message))
                    all_results.append(result)
                    report(name, result)
                    continue

                if y_ref is None:
                    if problem_config["exact"] is not None:
                        y_ref = problem_config["exact"](t_eval)
                    else:
                        ref_sol = solve_ivp(
                            system.rhs, problem_config["t_span"], y0,
                            t_eval=t_eval, method="DOP853", rtol=1e-12, atol=1e-14
                        )
                        y_ref = ref_sol.y.T

                y_solution = scipy_sol.y.T
                errors = np.abs(y_solution - y_ref)
                max_error = np.max(errors)

                # Energy drift at the final time, if the problem defines one.
                energy_error = None
                energy_fn = problem_config["energy"]
                if energy_fn is not None:
                    initial_energy = energy_fn(y0)
                    drift = abs(energy_fn(y_solution[-1]) - initial_energy)
                    # Relative drift; absolute drift for a zero-energy start.
                    energy_error = drift / abs(initial_energy) if initial_energy != 0 else drift

                result = TestResult(
                    integrator_name=name,
                    problem_name=problem_name,
                    final_error=np.linalg.norm(errors[-1]),
                    max_error=max_error,
                    relative_error=max_error / (np.max(np.abs(y_ref)) + 1e-16),
                    computation_time=computation_time,
                    # Output intervals (len(t_eval) - 1), not internal steps.
                    n_steps=len(scipy_sol.t) - 1,
                    energy_error=energy_error,
                    converged=True
                )

            except Exception as e:
                print(f"{name:15s}: ERROR - {e}")
                all_results.append(
                    failed(name, problem_name, time.perf_counter() - start_time, str(e))
                )
                continue

            all_results.append(result)
            report(name, result)

    return all_results
323
324
def print_summary_table(results: List[TestResult]):
    """Print all results as a fixed-width table grouped by problem.

    Problems appear in alphabetical order; within each problem, rows are
    ordered by ascending final error. Missing energy errors print as "N/A".
    """
    separator = "=" * 100
    print("\n" + separator)
    print("SUMMARY TABLE")
    print(separator)
    print(f"{'Integrator':<15} {'Problem':<20} {'Final Error':<12} {'Max Error':<12} {'Time (s)':<10} {'Energy Error':<12}")
    print("-" * 100)

    for problem_name in sorted({r.problem_name for r in results}):
        # Filtering keeps the original order, so the stable sort ties match.
        rows = [r for r in results if r.problem_name == problem_name]
        rows.sort(key=lambda r: r.final_error)
        for row in rows:
            if row.energy_error is None:
                energy_str = "N/A"
            else:
                energy_str = f"{row.energy_error:.2e}"
            print(f"{row.integrator_name:<15} {row.problem_name:<20} "
                  f"{row.final_error:<12.2e} {row.max_error:<12.2e} "
                  f"{row.computation_time:<10.4f} {energy_str:<12}")
348
349
def create_performance_plots(results: List[TestResult]):
    """Save accuracy-vs-speed scatter plots for all converged results.

    Writes ``_debug/results/plots/performance_comparison.png`` with two
    log-log panels: final error vs computation time and max error vs
    computation time, using one colour per integrator. Non-converged
    results are skipped.

    Parameters
    ----------
    results : list of TestResult
        Benchmark results as returned by ``run_benchmark``.
    """
    # Create results directory
    os.makedirs('_debug/results/plots', exist_ok=True)

    converged = [r for r in results if r.converged]
    # Sorted so colour assignment and legend order are reproducible: set
    # iteration order for strings depends on the per-process hash seed.
    integrators = sorted({r.integrator_name for r in converged})
    colors = plt.cm.tab10(np.linspace(0, 1, len(integrators)))
    integrator_colors = dict(zip(integrators, colors))

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    panels = (
        (axes[0], 'final_error', 'Final Error', 'Accuracy vs Speed'),
        (axes[1], 'max_error', 'Max Error', 'Max Error vs Speed'),
    )
    for ax, attr, ylabel, title in panels:
        for integrator in integrators:
            subset = [r for r in converged if r.integrator_name == integrator]
            ax.loglog([r.computation_time for r in subset],
                      [getattr(r, attr) for r in subset],
                      'o', color=integrator_colors[integrator],
                      label=integrator, markersize=8)
        ax.set_xlabel('Computation Time (s)')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # legend() warns when there are no labelled artists (all runs failed).
        if integrators:
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig('_debug/results/plots/performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print("Performance plots saved to _debug/results/plots/")
401
402
def main():
    """Script entry point: run the benchmark, print the summary, save plots."""
    banner = "Hiten Integrators vs SciPy Comprehensive Benchmark"
    print(banner)
    print("=" * 60)

    results = run_benchmark()
    print_summary_table(results)
    create_performance_plots(results)

    print("\nBenchmark completed successfully!")
419
420
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, not on import.
    main()
Event Integrator Benchmark Example
Tests Hiten’s event-capable integrator (DOP853) against SciPy solvers on simple problems with known event times. Reports detection accuracy and speed, and saves a comparison plot.
Testing Hiten’s event-capable integrator (DOP853) against SciPy solvers on simple problems with known event times
1"""Event-handling integrator benchmark.
2
3Compares Hiten's event-capable integrator (DOP853) against SciPy solvers on
4simple problems with known event times. Reports detection accuracy and speed,
5and saves a comparison plot.
6"""
7
8import os
9import sys
10import time
11import warnings
12from dataclasses import dataclass
13from typing import Callable, Dict, List, Optional, Tuple
14
15import matplotlib.pyplot as plt
16import numpy as np
17from numba import njit, types
18from scipy.integrate import solve_ivp
19
20# Make project src importable when running from repository root
21sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
22
23from hiten.algorithms.dynamics.rhs import create_rhs_system
24from hiten.algorithms.integrators import RungeKutta
25from hiten.algorithms.integrators.configs import _EventConfig
26
27warnings.filterwarnings('ignore', category=UserWarning)
28
29
@dataclass
class EventBenchmarkResult:
    """Outcome of one event-detection run (one integrator on one problem).

    Failed runs carry ``converged=False``, ``t_error=inf``, ``None`` for the
    event fields, and a human-readable ``error_message``.
    """
    # Label of the integrator/solver that produced this result.
    integrator_name: str
    # Name of the event test problem.
    problem_name: str
    # Detected event time, or None when no event was found.
    t_event: Optional[float]
    # State at the detected event, or None.
    y_event: Optional[np.ndarray]
    # Absolute difference between detected and expected event time.
    t_error: float
    # State-error metric at the event (may be None).
    y_error: Optional[float]
    # Wall-clock seconds for the timed run.
    computation_time: float
    converged: bool = True
    error_message: str = ""
41
42
# --- Precompiled global RHS and event functions (stable identities) ---
# Defined once at module level so every solver receives the same compiled
# function object across runs.
@njit(types.float64(types.float64, types.float64[:]), cache=True, fastmath=True)
def g_y1_jit(t: float, y: np.ndarray) -> float:
    """Event function whose root is ``y[0] == 1``."""
    threshold = 1.0
    return y[0] - threshold
47
48
@njit(types.float64(types.float64, types.float64[:]), cache=True, fastmath=True)
def g_x0_jit(t: float, y: np.ndarray) -> float:
    """Event function whose root is ``y[0] == 0`` (first component crosses zero)."""
    crossing_value = y[0]
    return crossing_value
52
53
@njit(types.float64[:](types.float64, types.float64[:]), cache=True, fastmath=True)
def rhs_inc_jit(t: float, y: np.ndarray) -> np.ndarray:
    """Right-hand side of ``y' = 1`` (one-dimensional, increasing)."""
    return np.ones(1, dtype=np.float64)
59
60
@njit(types.float64[:](types.float64, types.float64[:]), cache=True, fastmath=True)
def rhs_dec_jit(t: float, y: np.ndarray) -> np.ndarray:
    """Right-hand side of ``y' = -1`` (one-dimensional, decreasing)."""
    return np.full(1, -1.0, np.float64)
66
67
@njit(types.float64[:](types.float64, types.float64[:]), cache=True, fastmath=True)
def rhs_ho_jit(t: float, y: np.ndarray) -> np.ndarray:
    """Harmonic oscillator ``x'' = -x`` as ``[x, x'] -> [x', -x]``."""
    position = y[0]
    velocity = y[1]
    dydt = np.empty(2, dtype=np.float64)
    dydt[0] = velocity
    dydt[1] = -position
    return dydt
74
75
# --- Event test problems using precompiled functions ---
def _problems() -> List[Dict]:
    """Return the event test problems with analytically known event times.

    Each dict holds the jitted ``rhs`` and ``event``, the initial state
    ``y0``, ``t_span``, the required crossing ``direction``, the system
    ``dim``, and ``t_expected(y0)``, the exact event time. The expected
    time formulas assume the integration starts at t = 0, as every
    ``t_span`` here does.
    """
    return [
        {
            "name": "y' = 1, hit y=1 (inc)",
            "rhs": rhs_inc_jit,
            "y0": np.array([0.2], dtype=float),
            "t_span": (0.0, 5.0),
            "event": g_y1_jit,
            "direction": +1,
            # y(t) = y0 + t reaches 1 at t = 1 - y0
            "t_expected": lambda y0: 1.0 - float(y0[0]),
            "dim": 1,
        },
        {
            "name": "y' = -1, hit y=1 (dec)",
            "rhs": rhs_dec_jit,
            "y0": np.array([2.0], dtype=float),
            "t_span": (0.0, 5.0),
            "event": g_y1_jit,
            "direction": -1,
            # y(t) = y0 - t reaches 1 at t = y0 - 1
            "t_expected": lambda y0: float(y0[0]) - 1.0,
            "dim": 1,
        },
        {
            "name": "Harmonic oscillator, hit x=0",
            "rhs": rhs_ho_jit,
            "y0": np.array([1.0, 0.0], dtype=float),
            "t_span": (0.0, 5.0),
            "event": g_x0_jit,
            "direction": -1,
            # x(t) = cos t first crosses zero (decreasing) at t = pi/2
            "t_expected": lambda y0: 0.5 * np.pi,
            "dim": 2,
        },
    ]
114
115
116# --- Warm-up helpers (avoid JIT bias in timing) ---
117def _make_warmup_times(t0: float) -> np.ndarray:
118 return np.array([t0, t0 + 1.0e1], dtype=float)
119
120
121def _is_fixed_step(integrator) -> bool:
122 try:
123 return hasattr(integrator, "_integrate_fixed_rk")
124 except Exception:
125 return False
126
127
def _warmup_hiten_event(
    integrator,
    system,
    y0: np.ndarray,
    event_fn: Callable[[float, np.ndarray], float],
    direction: int,
) -> None:
    """Run one untimed event-enabled integration to absorb JIT compilation.

    Integrates from 0 over the window returned by ``_make_warmup_times``
    with a terminal event. Fixed-step integrators get a 64-point grid so
    the step-based event path and its refinement are exercised too.
    Every exception is ignored: warm-up is best effort only.
    """
    try:
        window = _make_warmup_times(0.0)
        if _is_fixed_step(integrator):
            times = np.linspace(window[0], window[-1], 64)
        else:
            times = window
        integrator.integrate(
            system,
            y0,
            times,
            event_fn=event_fn,
            event_cfg=_EventConfig(direction=direction, terminal=True),
        )
    except Exception:
        pass
151
152
def _warmup_scipy_event(
    method: str,
    rhs: Callable[[float, np.ndarray], np.ndarray],
    y0: np.ndarray,
    t0: float,
    event_fn: Callable[[float, np.ndarray], float],
    direction: int,
) -> None:
    """Run one untimed ``solve_ivp`` call with the event attached.

    Integrates over the window from ``_make_warmup_times(t0)`` at loose
    tolerances (rtol=1e-6, atol=1e-8). Every exception is ignored:
    warm-up is best effort only.
    """
    try:
        def terminal_event(t, y):
            return event_fn(t, y)

        terminal_event.terminal = True
        terminal_event.direction = float(direction)
        window = _make_warmup_times(t0)
        solve_ivp(
            rhs,
            (window[0], window[-1]),
            y0,
            method=method,
            events=terminal_event,
            rtol=1e-6,
            atol=1e-8,
        )
    except Exception:
        pass
163
164
def run_hiten_event(
    integrator,
    system,
    y0: np.ndarray,
    t_span: Tuple[float, float],
    event_fn: Callable[[float, np.ndarray], float],
    direction: int,
    t_expected: float,
    problem_name: str,
) -> EventBenchmarkResult:
    """Time one Hiten integrator on a terminal-event problem.

    Parameters
    ----------
    integrator
        Hiten integrator; its ``integrate`` is called with ``event_fn`` and
        a terminal ``_EventConfig``.
    system
        System built by ``create_rhs_system`` for this problem.
    y0 : numpy.ndarray
        Initial state.
    t_span : tuple of float
        Integration interval ``(t0, t1)``.
    event_fn : callable
        Scalar event function ``g(t, y)``.
    direction : int
        Required sign of the crossing (+1 increasing, -1 decreasing).
    t_expected : float
        Analytic event time used to compute ``t_error``.
    problem_name : str
        Problem label stored in the result.

    Returns
    -------
    EventBenchmarkResult
        Labelled with ``str(integrator)``. ``t_error = |t_hit - t_expected|``
        and ``y_error`` is the event residual ``|event_fn(t_hit, y_hit)|``.
        A run that raises, or that stops at ``t_span[1]`` because no event
        was found, is returned with ``converged=False`` and ``t_error=inf``.
    """

    def failure(elapsed: float, message: str) -> EventBenchmarkResult:
        """Build a non-converged result."""
        return EventBenchmarkResult(
            integrator_name=str(integrator),
            problem_name=problem_name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=elapsed,
            converged=False,
            error_message=message,
        )

    _warmup_hiten_event(integrator, system, y0, event_fn, direction)
    start = time.perf_counter()
    try:
        # Adaptive integrators only need endpoints; fixed-step benefits from
        # a moderately fine grid.
        if _is_fixed_step(integrator):
            t_eval = np.linspace(t_span[0], t_span[1], 4097, dtype=float)
        else:
            t_eval = np.array([t_span[0], t_span[1]], dtype=float)
        sol = integrator.integrate(
            system,
            y0,
            t_eval,
            event_fn=event_fn,
            event_cfg=_EventConfig(direction=direction, terminal=True),
        )
        elapsed = time.perf_counter() - start

        # Event path returns two nodes [t0, t_hit] or [t0, tmax] when no hit.
        # Ending at t_span[1] therefore means no event was detected. Report it
        # as a failure (as the SciPy runner does) instead of scoring the
        # interval end as if it were the event time.
        t_event = float(sol.times[-1])
        if np.isclose(t_event, t_span[1]):
            return failure(elapsed, "No event detected within t_span.")

        y_event = sol.states[-1].copy()
        # Event residual: how far the returned state is from the event surface.
        y_err = float(abs(event_fn(t_event, np.ascontiguousarray(y_event, dtype=np.float64))))
        return EventBenchmarkResult(
            integrator_name=str(integrator),
            problem_name=problem_name,
            t_event=t_event,
            y_event=y_event,
            t_error=abs(t_event - t_expected),
            y_error=y_err,
            computation_time=elapsed,
            converged=True,
        )
    except Exception as e:
        return failure(time.perf_counter() - start, str(e))
221
222
def run_scipy_event(
    name: str,
    method: str,
    rhs: Callable[[float, np.ndarray], np.ndarray],
    y0: np.ndarray,
    t_span: Tuple[float, float],
    event_fn: Callable[[float, np.ndarray], float],
    direction: int,
    t_expected: float,
    problem_name: str,
) -> EventBenchmarkResult:
    """Time one SciPy ``solve_ivp`` method on a terminal-event problem.

    Parameters
    ----------
    name : str
        Label stored in the result (e.g. ``"SciPy-RK45"``).
    method : str
        ``solve_ivp`` method name.
    rhs, y0, t_span
        Problem definition passed straight to ``solve_ivp``.
    event_fn : callable
        Scalar event function ``g(t, y)``. The first root crossed in
        ``direction`` terminates the integration.
    direction : int
        Required sign of the crossing (+1 increasing, -1 decreasing).
    t_expected : float
        Analytic event time used to compute ``t_error``.
    problem_name : str
        Problem label stored in the result.

    Returns
    -------
    EventBenchmarkResult
        ``t_error = |t_hit - t_expected|``. ``y_error`` is the event residual
        ``|event_fn(t_hit, y_hit)|``, or ``None`` when SciPy reports no event
        state. A run that raises, or reaches the end of ``t_span`` without an
        event, is returned with ``converged=False`` and ``t_error=inf``.
    """
    _warmup_scipy_event(method, rhs, y0, t_span[0], event_fn, direction)

    # solve_ivp reads ``terminal``/``direction`` attributes from the event
    # callable, so set them on a plain Python wrapper rather than on event_fn.
    def ev(t, y):
        return event_fn(t, y)
    ev.terminal = True
    ev.direction = float(direction)

    def failure(elapsed: float, message: str) -> EventBenchmarkResult:
        """Build a non-converged result."""
        return EventBenchmarkResult(
            integrator_name=name,
            problem_name=problem_name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=elapsed,
            converged=False,
            error_message=message,
        )

    start = time.perf_counter()
    try:
        sol = solve_ivp(
            rhs,
            t_span,
            y0,
            method=method,
            events=ev,
            rtol=1e-8,
            atol=1e-10,
        )
        elapsed = time.perf_counter() - start
        if not (sol.t_events and len(sol.t_events[0]) > 0):
            return failure(elapsed, str(sol.message))

        t_event = float(sol.t_events[0][0])
        y_events = getattr(sol, 'y_events', None)
        y_event = y_events[0][0].copy() if y_events and len(y_events[0]) > 0 else None
        # Event residual: how far the reported state is from the event surface.
        y_err = (
            float(abs(event_fn(t_event, np.ascontiguousarray(y_event, dtype=np.float64))))
            if y_event is not None
            else None
        )
        return EventBenchmarkResult(
            integrator_name=name,
            problem_name=problem_name,
            t_event=t_event,
            y_event=y_event,
            t_error=abs(t_event - t_expected),
            y_error=y_err,
            computation_time=elapsed,
            converged=True,
        )
    except Exception as e:
        return failure(time.perf_counter() - start, str(e))
291
292
def run_event_benchmark() -> List[EventBenchmarkResult]:
    """Run every Hiten and SciPy integrator on every problem from ``_problems``.

    Returns
    -------
    list of EventBenchmarkResult
        One result per (integrator, problem) pair, labelled with the keys of
        the integrator tables below (e.g. ``"HITEN-RK4-fixed"``).
    """
    results: List[EventBenchmarkResult] = []

    def print_row(label: str, res: EventBenchmarkResult) -> None:
        """Print one console line for a result."""
        if res.converged:
            print(f"{label:15s}: t_hit={res.t_event:.10f}, |dt|={res.t_error:.2e}, time={res.computation_time:.4f}s")
        else:
            print(f"{label:15s}: FAILED - {res.error_message}")

    # Hiten integrators (adaptive and fixed-step)
    hiten_integrators = {
        "HITEN-RK45": RungeKutta(order=45, rtol=1e-8, atol=1e-10),
        "HITEN-DOP853": RungeKutta(order=853, rtol=1e-8, atol=1e-10),
        "HITEN-RK4-fixed": RungeKutta(order=4),
        "HITEN-RK6-fixed": RungeKutta(order=6),
        "HITEN-RK8-fixed": RungeKutta(order=8),
    }

    # SciPy integrators with event support
    scipy_integrators = {
        "SciPy-RK45": "RK45",
        "SciPy-DOP853": "DOP853",
        "SciPy-Radau": "Radau",
        "SciPy-BDF": "BDF",
        "SciPy-LSODA": "LSODA",
    }

    print("Event Integrator Benchmark (Hiten vs SciPy)")
    print("=" * 60)

    for prob in _problems():
        print(f"\nProblem: {prob['name']}")
        print("-" * 40)
        system = create_rhs_system(prob["rhs"], dim=prob["dim"], name=prob["name"])
        t_expected = float(prob["t_expected"](prob["y0"]))

        # Hiten
        for name, integrator in hiten_integrators.items():
            res = run_hiten_event(
                integrator=integrator,
                system=system,
                y0=prob["y0"],
                t_span=prob["t_span"],
                event_fn=prob["event"],
                direction=prob["direction"],
                t_expected=t_expected,
                problem_name=prob["name"],
            )
            # run_hiten_event labels results with str(integrator); use the
            # table key so the console, summary table and plot legend show
            # the same configuration names as the SciPy rows.
            res.integrator_name = name
            results.append(res)
            print_row(name, res)

        # SciPy
        for name, method in scipy_integrators.items():
            res = run_scipy_event(
                name=name,
                method=method,
                rhs=prob["rhs"],
                y0=prob["y0"],
                t_span=prob["t_span"],
                event_fn=prob["event"],
                direction=prob["direction"],
                t_expected=t_expected,
                problem_name=prob["name"],
            )
            results.append(res)
            print_row(name, res)

    return results
361
362
def print_summary_table(results: List[EventBenchmarkResult]) -> None:
    """Print event results as a fixed-width table grouped by problem.

    Problems appear alphabetically. Within each problem, converged runs
    come first, ordered by ascending event-time error. Missing event times
    print as "N/A" and non-finite errors as "inf".
    """
    rule = "=" * 100
    print("\n" + rule)
    print("SUMMARY TABLE (Events)")
    print(rule)
    print(f"{'Integrator':<15} {'Problem':<30} {'t_hit':<18} {'|dt|':<12} {'Time (s)':<10}")
    print("-" * 100)

    for problem_name in sorted({r.problem_name for r in results}):
        # Filtering preserves insertion order, so stable-sort ties match.
        rows = [r for r in results if r.problem_name == problem_name]
        rows.sort(key=lambda r: (not r.converged, r.t_error))
        for row in rows:
            t_hit_str = "N/A" if row.t_event is None else f"{row.t_event:.10f}"
            dt_str = "inf" if not np.isfinite(row.t_error) else f"{row.t_error:.2e}"
            print(f"{row.integrator_name:<15} {row.problem_name:<30} {t_hit_str:<18} {dt_str:<12} {row.computation_time:<10.4f}")
379
380
def create_performance_plot(results: List[EventBenchmarkResult]) -> None:
    """Save a log-log plot of event-time error against computation time.

    Uses one marker colour per integrator and draws only converged results.
    The figure is written to
    ``_debug/results/plots/events_performance_comparison.png``.
    """
    os.makedirs('_debug/results/plots', exist_ok=True)

    # Plot |dt| vs time, grouped by integrator names. Sorted so colours and
    # legend order are reproducible: set iteration order for strings depends
    # on the per-process hash seed.
    names = sorted({r.integrator_name for r in results})
    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(names))))
    cmap: Dict[str, np.ndarray] = dict(zip(names, colors))

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    for name in names:
        subset = [r for r in results if r.integrator_name == name and r.converged]
        if not subset:
            continue
        times = [r.computation_time for r in subset]
        terrs = [r.t_error for r in subset]
        ax.loglog(times, terrs, 'o', label=name, color=cmap[name], markersize=8)

    ax.set_xlabel('Computation Time (s)')
    ax.set_ylabel('Absolute Event Time Error |dt|')
    ax.set_title('Event Detection: Accuracy vs Speed')
    ax.grid(True, which='both', alpha=0.3)
    # legend() warns when nothing was plotted (e.g. every run failed).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    fig.tight_layout()
    out_path = '_debug/results/plots/events_performance_comparison.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Performance plot saved to {out_path}")
409
410
def main() -> None:
    """Script entry point: run the event benchmark, print the table, save the plot."""
    benchmark_results = run_event_benchmark()
    print_summary_table(benchmark_results)
    create_performance_plot(benchmark_results)
    print("\nEvent benchmark completed.")
416
417
if __name__ == '__main__':
    # Run the event benchmark only when executed as a script, not on import.
    main()
Symplectic Event Integrator Benchmark Example
Tests Hiten’s extended symplectic integrators with event detection enabled against SciPy solvers on simple problems with known event times. Reports detection accuracy and speed, and saves a comparison plot.
Testing Hiten’s extended symplectic integrators with event detection enabled against SciPy solvers on simple problems with known event times
1#!/usr/bin/env python3
2"""Symplectic event-handling benchmark and example.
3
4This script mirrors ``event_integrator_benchmark.py`` but focuses on
5Hiten's extended symplectic integrators with event detection enabled.
6It constructs a truncated pendulum Hamiltonian, integrates it with
7event-enabled symplectic schemes, compares the detected event times
8against SciPy solvers, and saves an accuracy vs speed plot.
9"""
10
11from __future__ import annotations
12
13import os
14import sys
15import time
16import warnings
17from dataclasses import dataclass
18from typing import Dict, List, Optional, Tuple
19
20import matplotlib.pyplot as plt
21import numpy as np
22from numba import njit, types
23from numba.typed import List as NumbaList
24from scipy.integrate import solve_ivp
25
26# Make project src importable when running from repository root
27sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
28
29from hiten.algorithms.dynamics.hamiltonian import create_hamiltonian_system
30from hiten.algorithms.integrators import ExtendedSymplectic
31from hiten.algorithms.integrators.configs import _EventConfig
32from hiten.algorithms.integrators.symplectic import (
33 N_SYMPLECTIC_DOF,
34 N_VARS_POLY,
35 P_POLY_INDICES,
36 Q_POLY_INDICES,
37)
38from hiten.algorithms.polynomial.base import (
39 _create_encode_dict_from_clmo,
40 _encode_multiindex,
41 _init_index_tables,
42)
43
44warnings.filterwarnings('ignore', category=UserWarning)
45
46
@dataclass
class EventResult:
    """Outcome of one symplectic/SciPy event-detection run.

    Failed runs carry ``converged=False`` and an ``error_message``.
    """
    # Label of the solver that produced this result.
    solver_name: str
    # Name of the event problem.
    problem_name: str
    # Detected event time, or None when no event was found.
    t_event: Optional[float]
    # State at the detected event, or None.
    y_event: Optional[np.ndarray]
    # Error of the detected event time.
    t_error: float
    # State error at the event, when available.
    y_error: Optional[float]
    # Wall-clock seconds for the timed run.
    computation_time: float
    converged: bool = True
    error_message: str = ""
58
59
@dataclass
class SymplecticEventProblem:
    """Configuration of one pendulum event-detection problem.

    The same initial condition is stored twice: ``y0_extended`` in the
    symplectic integrator's phase space and ``y0_canonical`` as the 2D
    ``(q, p)`` state used by the SciPy solvers.
    """
    # Short display label.
    name: str
    # One-line human-readable description.
    description: str
    # Initial state in the extended symplectic phase space.
    y0_extended: np.ndarray
    # Initial (q, p) state for the SciPy reference.
    y0_canonical: np.ndarray
    # Integration interval (t0, t1).
    t_span: Tuple[float, float]
    # Number of grid points for the symplectic integration.
    grid_size: int
    # Required sign of the q1 = 0 crossing (+1 or -1).
    direction: int
69
70
@njit(types.float64(types.float64, types.float64[:]), cache=True, fastmath=True)
def event_q1_jit(t: float, y: np.ndarray) -> float:
    """Event function whose root is ``q1 == 0`` (first state component)."""
    q1 = y[0]
    return q1
75
76
def truncated_pendulum_rhs(_t: float, y: np.ndarray) -> List[float]:
    """Pendulum equations ``q' = p``, ``p' = -sin q`` with sin truncated at q**5.

    Uses ``sin q ~ q - q**3/6 + q**5/120``. Returns a plain list
    ``[q', p']``; the time argument is unused.
    """
    angle, momentum = y
    sine_approx = angle - (angle ** 3) / 6.0 + (angle ** 5) / 120.0
    return [momentum, -sine_approx]
82
83
def embed_state(qp_state: np.ndarray) -> np.ndarray:
    """Embed a 2D ``(q, p)`` state into the symplectic phase space.

    The result has length ``2 * N_SYMPLECTIC_DOF``, with ``q`` at index 0
    and ``p`` at index ``N_SYMPLECTIC_DOF``. All other entries are zero.
    """
    embedded = np.zeros(2 * N_SYMPLECTIC_DOF, dtype=np.float64)
    q_value, p_value = qp_state[0], qp_state[1]
    embedded[0] = q_value
    embedded[N_SYMPLECTIC_DOF] = p_value
    return embedded
90
91
def build_truncated_pendulum_system(max_degree: int = 6):
    """Construct a polynomial Hamiltonian system for a truncated pendulum.

    The Hamiltonian is the Taylor truncation of ``p1**2/2 - cos(q1)``::

        H = p1**2/2 - 1 + q1**2/2 - q1**4/24 + q1**6/720

    The quartic and sextic terms are included only when ``max_degree``
    reaches 4 and 6 respectively. With ``max_degree >= 6``, the q1 force
    term ``-dH/dq1`` equals ``-(q1 - q1**3/6 + q1**5/120)``, the same
    truncated sine used by ``truncated_pendulum_rhs``.

    Parameters
    ----------
    max_degree : int, default 6
        Maximum polynomial degree. Must be at least 2, because the kinetic
        and quadratic potential terms are always present.

    Returns
    -------
    The system object returned by ``create_hamiltonian_system``.

    Raises
    ------
    ValueError
        If ``max_degree < 2``.
    """
    if max_degree < 2:
        raise ValueError(
            f"max_degree must be >= 2 for the truncated pendulum, got {max_degree}"
        )

    psi_tables, clmo_tables = _init_index_tables(max_degree)
    encode_dict_list = _create_encode_dict_from_clmo(clmo_tables)

    # One coefficient array per homogeneous degree 0..max_degree.
    H_blocks = [
        np.zeros(psi_tables[N_VARS_POLY, deg], dtype=np.complex128)
        for deg in range(max_degree + 1)
    ]

    idx_p1 = P_POLY_INDICES[0]
    idx_q1 = Q_POLY_INDICES[0]

    # (variable index or None for the constant, exponent = degree, coefficient)
    terms = [
        (idx_p1, 2, 0.5),           # kinetic term p1^2 / 2
        (None, 0, -1.0),            # constant offset -1 from -cos(q1)
        (idx_q1, 2, 0.5),           # quadratic potential +q1^2 / 2
        (idx_q1, 4, -1.0 / 24.0),   # quartic correction -q1^4 / 24
        (idx_q1, 6, 1.0 / 720.0),   # sextic correction +q1^6 / 720
    ]
    for var_idx, degree, coeff in terms:
        if degree > max_degree:
            continue  # term truncated away
        mono = np.zeros(N_VARS_POLY, dtype=np.int64)
        if var_idx is not None:
            mono[var_idx] = degree
        encoded = _encode_multiindex(mono, degree, encode_dict_list)
        # -1 means the monomial is not representable in the tables; skip it.
        if encoded != -1:
            H_blocks[degree][encoded] += coeff

    H_blocks_typed = NumbaList()
    for arr in H_blocks:
        H_blocks_typed.append(arr.copy())

    system = create_hamiltonian_system(
        H_blocks=H_blocks_typed,
        degree=max_degree,
        psi_table=psi_tables,
        clmo_table=clmo_tables,
        encode_dict_list=encode_dict_list,
        n_dof=N_SYMPLECTIC_DOF,
        name="Truncated Pendulum Hamiltonian",
    )

    return system
157
158
def make_event_problems() -> List[SymplecticEventProblem]:
    """Configure the two pendulum-release event problems.

    Both start at rest at +/-45 degrees and integrate over ``(0, 5)`` on a
    4097-point grid. The positive release expects a decreasing (-1) q1 = 0
    crossing, and the negative release expects an increasing (+1) one.
    """
    angle = np.deg2rad(45.0)

    def release(q0: float, name: str, description: str, direction: int) -> SymplecticEventProblem:
        """Build one problem released from rest at ``q1 = q0``."""
        extended = np.zeros(2 * N_SYMPLECTIC_DOF, dtype=np.float64)
        extended[0] = q0
        return SymplecticEventProblem(
            name=name,
            description=description,
            y0_extended=extended,
            y0_canonical=np.array([q0, 0.0], dtype=np.float64),
            t_span=(0.0, 5.0),
            grid_size=4097,
            direction=direction,
        )

    # Positive release, expect crossing with negative direction
    problem_pos = release(
        angle,
        "Pendulum release (+45 deg)",
        "Pendulum released from +45 degrees, expect q1 -> 0 with negative crossing.",
        -1,
    )
    # Negative release, expect crossing with positive direction
    problem_neg = release(
        -angle,
        "Pendulum release (-45 deg)",
        "Pendulum released from -45 degrees, expect q1 -> 0 with positive crossing.",
        +1,
    )
    return [problem_pos, problem_neg]
190
191
def compute_reference_event(problem: SymplecticEventProblem) -> Tuple[float, np.ndarray]:
    """High-accuracy reference using SciPy Radau.

    Integrates the canonical 2D system with ``method="Radau"`` at
    ``rtol=1e-12``/``atol=1e-14``. Integration stops at the first zero of
    ``y[0]`` crossed in ``problem.direction``.

    Returns
    -------
    (float, numpy.ndarray)
        Event time and a copy of the canonical state at the event.

    Raises
    ------
    RuntimeError
        If the solver does not terminate on the event within ``t_span``.
    """
    def q1_crossing(t, y):
        return y[0]

    # solve_ivp reads these attributes off the event callable.
    q1_crossing.terminal = True
    q1_crossing.direction = float(problem.direction)

    ref = solve_ivp(
        truncated_pendulum_rhs,
        problem.t_span,
        problem.y0_canonical,
        method="Radau",
        events=q1_crossing,
        rtol=1e-12,
        atol=1e-14,
    )

    # status == 1 means a terminal event stopped the integration.
    found = ref.status == 1 and ref.t_events and len(ref.t_events[0]) > 0
    if not found:
        raise RuntimeError(f"Reference solver failed to detect event for problem '{problem.name}'")

    return float(ref.t_events[0][0]), ref.y_events[0][0].copy()
214
215
def _warmup_symplectic_event(integrator, system, problem: SymplecticEventProblem) -> None:
    """Trigger compilation overhead outside the timed region.

    Runs one event-enabled integration over the first 0.01 time units of
    ``problem.t_span`` on an 8-point grid. Any exception is deliberately
    ignored because this is best-effort warm-up. The timed call in
    ``run_symplectic_event`` is the one that records failures.
    """
    try:
        t0 = problem.t_span[0]
        short_grid = np.linspace(t0, t0 + 1.0e-2, 8, dtype=np.float64)
        warm_cfg = _EventConfig(direction=problem.direction, terminal=True)
        integrator.integrate(
            system,
            problem.y0_extended,
            short_grid,
            event_fn=event_q1_jit,
            event_cfg=warm_cfg,
        )
    except Exception:
        # Best-effort only; see docstring.
        pass
234
235
def _warmup_scipy_event(method: str, problem: SymplecticEventProblem) -> None:
    """Trigger SciPy solver overhead outside the timed region.

    Performs a short ``solve_ivp`` run (first 0.01 time units, ``rtol=1e-6``,
    ``atol=1e-8``). It uses the same terminal q1 event and crossing direction
    as the timed run. Exceptions are ignored on purpose: ``run_scipy_event``
    is the one that records failures.
    """
    try:
        def q1_crossing(t, y):
            return y[0]

        q1_crossing.terminal = True
        q1_crossing.direction = float(problem.direction)

        t0 = problem.t_span[0]
        solve_ivp(
            truncated_pendulum_rhs,
            (t0, t0 + 1.0e-2),
            problem.y0_canonical,
            method=method,
            events=q1_crossing,
            rtol=1e-6,
            atol=1e-8,
        )
    except Exception:
        # Best-effort only; see docstring.
        pass
254
255
def run_symplectic_event(
    solver_name: str,
    integrator,
    system,
    problem: SymplecticEventProblem,
    t_expected: float,
    y_expected: np.ndarray,
) -> EventResult:
    """Execute symplectic integration with terminal event detection.

    The integrator runs on a uniform grid of ``problem.grid_size`` points over
    ``problem.t_span``. Event tolerances are ``xtol=gtol=1e-12``. A warm-up
    call happens before the timer starts. The last returned sample is
    reported as the event.

    Parameters
    ----------
    solver_name : str
        Label stored in the result.
    integrator, system
        Integrator object and Hamiltonian system passed to ``integrate``.
    problem : SymplecticEventProblem
        Supplies ``t_span``, ``grid_size``, ``y0_extended`` and ``direction``.
    t_expected : float
        Reference event time used for ``t_error``.
    y_expected : numpy.ndarray or None
        Reference extended state. When ``None``, ``y_error`` is ``None``.

    Returns
    -------
    EventResult
        ``converged=False`` with the exception text if anything in the timed
        section raises.
    """
    grid = np.linspace(
        problem.t_span[0],
        problem.t_span[1],
        problem.grid_size,
        dtype=np.float64,
    )
    cfg = _EventConfig(
        direction=problem.direction,
        terminal=True,
        xtol=1.0e-12,
        gtol=1.0e-12,
    )

    _warmup_symplectic_event(integrator, system, problem)

    tic = time.perf_counter()
    try:
        sol = integrator.integrate(
            system,
            problem.y0_extended,
            grid,
            event_fn=event_q1_jit,
            event_cfg=cfg,
        )
        wall = time.perf_counter() - tic

        # With a terminal event the integrator presumably stops at the
        # crossing, so the final sample is taken as the event (TODO confirm).
        t_hit = float(sol.times[-1])
        y_hit = sol.states[-1].copy()
        dt = abs(t_hit - t_expected)
        dy = None if y_expected is None else float(np.linalg.norm(y_hit - y_expected))

        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=t_hit,
            y_event=y_hit,
            t_error=dt,
            y_error=dy,
            computation_time=wall,
            converged=True,
        )
    except Exception as exc:
        wall = time.perf_counter() - tic
        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=wall,
            converged=False,
            error_message=str(exc),
        )
311
312
def run_scipy_event(
    solver_name: str,
    method: str,
    solver_opts: Dict,
    problem: SymplecticEventProblem,
    t_expected: float,
    y_expected: np.ndarray,
) -> EventResult:
    """Run SciPy ``solve_ivp`` with event detection on the 2D system.

    The event is the terminal zero crossing of ``y[0]`` in
    ``problem.direction``. A warm-up call is made before the timer starts.

    Parameters
    ----------
    solver_name : str
        Label stored in the result.
    method : str
        ``solve_ivp`` method name, e.g. ``"RK45"``.
    solver_opts : dict
        Extra keyword arguments forwarded to ``solve_ivp``, such as tolerances.
    problem : SymplecticEventProblem
        Supplies ``t_span``, ``y0_canonical`` and ``direction``.
    t_expected : float
        Reference event time used for ``t_error``.
    y_expected : numpy.ndarray or None
        Reference state in extended coordinates. The detected canonical state
        is passed through ``embed_state`` before comparison. When ``None``,
        ``y_error`` is ``None``.

    Returns
    -------
    EventResult
        ``converged=True`` with the first detected event. Otherwise
        ``converged=False`` with an ``error_message``. If the solver reaches
        the end of ``t_span`` without the event firing, the message says so
        explicitly. It does not relay SciPy's status text, which reports a
        *successful* run in that case.
    """
    def q1_crossing(t, y):
        return y[0]

    # solve_ivp reads these attributes off the event callable.
    q1_crossing.terminal = True
    q1_crossing.direction = float(problem.direction)

    def _failure(elapsed, message):
        """Build the non-converged result shared by both failure paths."""
        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=elapsed,
            converged=False,
            error_message=message,
        )

    _warmup_scipy_event(method, problem)

    start = time.perf_counter()
    try:
        sol = solve_ivp(
            truncated_pendulum_rhs,
            problem.t_span,
            problem.y0_canonical,
            method=method,
            events=q1_crossing,
            **solver_opts,
        )
        elapsed = time.perf_counter() - start

        # status == 1: a terminal event stopped the integration.
        if sol.status == 1 and sol.t_events and len(sol.t_events[0]) > 0:
            t_event = float(sol.t_events[0][0])
            qp_event = sol.y_events[0][0].copy()
            y_event = embed_state(qp_event)
            t_error = abs(t_event - t_expected)
            y_error = float(np.linalg.norm(y_event - y_expected)) if y_expected is not None else None
            return EventResult(
                solver_name=solver_name,
                problem_name=problem.name,
                t_event=t_event,
                y_event=y_event,
                t_error=t_error,
                y_error=y_error,
                computation_time=elapsed,
                converged=True,
            )

        if sol.status == 0:
            # status == 0: reached t_span[1] with no event. SciPy's message
            # ("...successfully reached the end...") is misleading in a FAILED row.
            message = f"event not detected within t_span={problem.t_span}"
        else:
            # status == -1: an integration step failed; SciPy's message explains why.
            message = getattr(sol, "message", None) or "event not detected"
        return _failure(elapsed, message)
    except Exception as exc:
        return _failure(time.perf_counter() - start, str(exc))
382
383
def run_benchmark() -> List[EventResult]:
    """Execute the symplectic event benchmark.

    For each problem from ``make_event_problems`` the function:

    1. computes a Radau reference event. It is recorded as a result row with
       zero error and zero time.
    2. runs every ExtendedSymplectic scheme.
    3. runs every SciPy method.

    One progress line is printed per solver. A failure in the reference
    computation propagates and aborts the benchmark.

    Returns
    -------
    list of EventResult
        All rows, in the order they were produced.
    """
    system = build_truncated_pendulum_system()
    problems = make_event_problems()

    symplectic_schemes = {
        "Symplectic4": ExtendedSymplectic(order=4, c_omega_heuristic=15.0),
        "Symplectic6": ExtendedSymplectic(order=6, c_omega_heuristic=20.0),
        "Symplectic8": ExtendedSymplectic(order=8, c_omega_heuristic=25.0),
    }

    # label -> (solve_ivp method, keyword options)
    scipy_solvers = {
        "SciPy-RK45": ("RK45", {"rtol": 1.0e-8, "atol": 1.0e-10}),
        "SciPy-DOP853": ("DOP853", {"rtol": 1.0e-9, "atol": 1.0e-11}),
        "SciPy-BDF": ("BDF", {"rtol": 1.0e-8, "atol": 1.0e-10}),
    }

    def _report(label, outcome):
        """Print the one-line progress message for a single run."""
        if not outcome.converged:
            print(f"{label:20s}: FAILED - {outcome.error_message}")
            return
        print(
            f"{label:20s}: t_hit={outcome.t_event:.10f}, |dt|={outcome.t_error:.2e}, "
            f"time={outcome.computation_time:.4f}s"
        )

    collected: List[EventResult] = []

    print("Symplectic Event Integrator Benchmark")
    print("=" * 70)

    for problem in problems:
        print(f"\nProblem: {problem.name}")
        print(problem.description)
        print("-" * 70)

        t_ref, y_ref_qp = compute_reference_event(problem)
        y_ref_ext = embed_state(y_ref_qp)

        print(f"Reference event time (Radau): {t_ref:.10f} s")

        collected.append(
            EventResult(
                solver_name="SciPy-Radau (reference)",
                problem_name=problem.name,
                t_event=t_ref,
                y_event=y_ref_ext,
                t_error=0.0,
                y_error=0.0,
                computation_time=0.0,
                converged=True,
            )
        )

        for label, integrator in symplectic_schemes.items():
            outcome = run_symplectic_event(label, integrator, system, problem, t_ref, y_ref_ext)
            collected.append(outcome)
            _report(label, outcome)

        for label, (method, opts) in scipy_solvers.items():
            outcome = run_scipy_event(label, method, opts, problem, t_ref, y_ref_ext)
            collected.append(outcome)
            _report(label, outcome)

    return collected
465
466
def print_summary_table(results: List[EventResult]) -> None:
    """Print a fixed-width summary of every run, grouped by problem.

    Problems appear in alphabetical order. Within a problem, converged runs
    come first, ordered by ascending ``|dt|``. Failed runs follow in their
    original order, since they all share ``t_error=inf`` and the sort is
    stable.
    """
    rule = "=" * 110
    print("\n" + rule)
    print("SUMMARY TABLE (Symplectic Events)")
    print(rule)
    print(f"{'Integrator':<22} {'Problem':<32} {'t_hit':<18} {'|dt|':<12} {'|dy|':<12} {'Time (s)':<10}")
    print("-" * 110)

    by_problem: Dict[str, List[EventResult]] = {}
    for entry in results:
        by_problem.setdefault(entry.problem_name, []).append(entry)

    def _rank(r):
        # False sorts before True, so converged runs lead.
        return (not r.converged, r.t_error)

    for problem_name in sorted(by_problem):
        for entry in sorted(by_problem[problem_name], key=_rank):
            if not entry.converged:
                print(
                    f"{entry.solver_name:<22} {entry.problem_name:<32} {'FAILED':<18} "
                    f"{'inf':<12} {'n/a':<12} {entry.computation_time:<10.4f}"
                )
                continue

            t_hit_str = "N/A" if entry.t_event is None else f"{entry.t_event:.10f}"
            dt_str = f"{entry.t_error:.2e}" if np.isfinite(entry.t_error) else "inf"
            has_dy = entry.y_error is not None and np.isfinite(entry.y_error)
            dy_str = f"{entry.y_error:.2e}" if has_dy else "n/a"
            print(
                f"{entry.solver_name:<22} {entry.problem_name:<32} {t_hit_str:<18} "
                f"{dt_str:<12} {dy_str:<12} {entry.computation_time:<10.4f}"
            )
498
499
def create_performance_plot(results: List[EventResult]) -> None:
    """Save a log-log scatter plot of event-time error |dt| vs computation time.

    Only converged runs whose ``t_error`` and ``computation_time`` are both
    finite and strictly positive are plotted. Log axes cannot display zero.
    Without this filter, a row recorded with zero error/time (such as a
    reference entry) would produce an empty legend entry and no visible
    point.

    The figure is written to
    ``_debug/results/plots/symplectic_events_performance.png``, relative to
    the current working directory. The directory is created if needed.
    """
    plot_dir = os.path.join('_debug', 'results', 'plots')
    os.makedirs(plot_dir, exist_ok=True)

    # One predicate for both the legend names and the plotted points, so a
    # solver never gets a legend entry without at least one drawable point.
    plottable = [
        r for r in results
        if r.converged
        and np.isfinite(r.t_error) and r.t_error > 0.0
        and np.isfinite(r.computation_time) and r.computation_time > 0.0
    ]

    names = sorted({r.solver_name for r in plottable})
    if not names:
        print("No successful runs to plot.")
        return

    colors = plt.cm.tab10(np.linspace(0, 1, len(names)))
    color_map = dict(zip(names, colors))

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        for name in names:
            subset = [r for r in plottable if r.solver_name == name]
            times = [r.computation_time for r in subset]
            terrs = [r.t_error for r in subset]
            ax.loglog(times, terrs, 'o', label=name, color=color_map[name], markersize=8)

        ax.set_xlabel('Computation Time (s)')
        ax.set_ylabel('Absolute Event Time Error |dt|')
        ax.set_title('Symplectic Event Detection: Accuracy vs Speed')
        ax.grid(True, which='both', alpha=0.3)
        ax.legend()

        fig.tight_layout()
        out_path = os.path.join(plot_dir, 'symplectic_events_performance.png')
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
    print(f"Performance plot saved to {out_path}")
532
533
def main() -> None:
    """Run the benchmark, print the summary table, and save the performance plot."""
    all_results = run_benchmark()
    print_summary_table(all_results)
    create_performance_plot(all_results)
    print("\nSymplectic event benchmark completed.")
539
540
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
543