Integrator Benchmark Example

Tests Hiten integrators against SciPy on various ODE problems and provides detailed accuracy, performance, and error analysis.

Testing Hiten integrators against SciPy on various ODE problems
  1#!/usr/bin/env python3
  2"""Integrator benchmark script.
  3
  4Tests Hiten integrators against SciPy on various ODE problems and provides
  5detailed accuracy, performance, and error analysis.
  6"""
  7
  8import os
  9import sys
 10import time
 11import warnings
 12from dataclasses import dataclass
 13from typing import List
 14
 15import matplotlib.pyplot as plt
 16import numpy as np
 17from scipy.integrate import solve_ivp
 18
# Append ``<this script's directory>/../../src`` to the module search path so
# the ``hiten`` imports below can be resolved from there.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', "src"))
 20
 21from hiten.algorithms.dynamics.rhs import create_rhs_system
 22from hiten.algorithms.integrators import (AdaptiveRK, RungeKutta)
 23
# Install a process-wide filter that silences every UserWarning so the
# benchmark's console output stays readable.
warnings.filterwarnings('ignore', category=UserWarning)
 25
 26
@dataclass
class TestResult:
    """Store results from a single integrator test.

    Runs that raise are recorded with ``converged=False``, infinite error
    fields, ``n_steps=0`` and the exception text in ``error_message``.
    """
    integrator_name: str  # label of the integrator that produced the result
    problem_name: str  # the ``"name"`` entry of the problem configuration
    final_error: float  # Euclidean norm of the error vector at the last output time
    max_error: float  # largest absolute componentwise error over the whole grid
    relative_error: float  # max_error / (largest reference magnitude + 1e-16)
    computation_time: float  # seconds measured with time.perf_counter()
    n_steps: int  # number of output intervals, not necessarily internal solver steps
    energy_error: "float | None" = None  # relative energy drift; None without an energy function
    converged: bool = True
    error_message: str = ""  # failure reason when converged is False
 40
 41
 42# Test problems as simple functions
 43def harmonic_oscillator(t, y):
 44    """Harmonic oscillator: x'' + x = 0."""
 45    return np.array([y[1], -y[0]])
 46
 47def van_der_pol(t, y):
 48    """Van der Pol oscillator: x'' - μ(1-x²)x' + x = 0."""
 49    mu = 1.0
 50    return np.array([y[1], mu * (1 - y[0]**2) * y[1] - y[0]])
 51
 52def lorenz(t, y):
 53    """Lorenz system: chaotic attractor."""
 54    sigma, rho, beta = 10.0, 28.0, 8.0/3.0
 55    return np.array([
 56        sigma * (y[1] - y[0]),
 57        y[0] * (rho - y[2]) - y[1],
 58        y[0] * y[1] - beta * y[2]
 59    ])
 60
 61def kepler(t, y):
 62    """Kepler problem: 2D central force motion."""
 63    x, y_pos, vx, vy = y
 64    mu = 1.0
 65    r = np.sqrt(x**2 + y_pos**2)
 66    r3 = r**3
 67    return np.array([vx, vy, -mu * x / r3, -mu * y_pos / r3])
 68
 69def duffing(t, y):
 70    """Duffing oscillator: x'' + δx' + αx + βx³ = γcos(ωt)."""
 71    delta, alpha, beta = 0.1, -1.0, 1.0
 72    gamma, omega = 0.3, 1.2
 73    return np.array([
 74        y[1],
 75        -delta * y[1] - alpha * y[0] - beta * y[0]**3 + 
 76        gamma * np.cos(omega * t)
 77    ])
 78
 79# Test problem configurations
 80TEST_PROBLEMS = [
 81    {
 82        "name": "Harmonic Oscillator",
 83        "rhs": harmonic_oscillator,
 84        "y0": np.array([1.0, 0.0]),
 85        "t_span": (0.0, 4*np.pi),
 86        "exact": lambda t: np.column_stack([np.cos(t), -np.sin(t)]),
 87        "energy": lambda y: 0.5 * (y[0]**2 + y[1]**2)
 88    },
 89    {
 90        "name": "Van der Pol",
 91        "rhs": van_der_pol,
 92        "y0": np.array([2.0, 0.0]),
 93        "t_span": (0.0, 20.0),
 94        "exact": None,
 95        "energy": None
 96    },
 97    {
 98        "name": "Lorenz",
 99        "rhs": lorenz,
100        "y0": np.array([1.0, 1.0, 1.0]),
101        "t_span": (0.0, 20.0),
102        "exact": None,
103        "energy": None
104    },
105    {
106        "name": "Kepler",
107        "rhs": kepler,
108        "y0": np.array([1.0, 0.0, 0.0, 0.8]),
109        "t_span": (0.0, 10.0),
110        "exact": None,
111        "energy": lambda y: 0.5 * (y[2]**2 + y[3]**2) - 1.0 / np.sqrt(y[0]**2 + y[1]**2)
112    },
113    {
114        "name": "Duffing",
115        "rhs": duffing,
116        "y0": np.array([1.0, 0.0]),
117        "t_span": (0.0, 20.0),
118        "exact": None,
119        "energy": None
120    }
121]
122
123
def run_integrator_test(integrator, system, problem_config, t_eval: np.ndarray) -> TestResult:
    """Run a single Hiten integrator test and return its results.

    The integrator is first called once on a tiny two-point grid. That result
    is discarded and any exception is ignored, so one-off JIT compilation is
    excluded from the timed run. The timed run integrates
    ``problem_config["y0"]`` over *t_eval*. Errors are then measured against
    ``problem_config["exact"]`` when provided. Otherwise they are measured
    against a SciPy DOP853 reference (rtol=1e-12, atol=1e-14) solved over
    ``[t_eval[0], t_eval[-1]]``, the same reference used for the SciPy
    comparisons, so all errors share one baseline.

    Parameters
    ----------
    integrator : object
        Hiten integrator exposing ``integrate(system, y0, t_eval)``.
    system : object
        System built by ``create_rhs_system``; its ``rhs`` drives the
        SciPy reference.
    problem_config : dict
        One entry of ``TEST_PROBLEMS``.
    t_eval : numpy.ndarray
        Output time grid.

    Returns
    -------
    TestResult
        ``computation_time`` covers only the timed integration. If the
        integration or the error evaluation raises, ``converged`` is False,
        the errors are infinite and ``error_message`` holds the exception text.
    """
    label = str(integrator)
    y0 = problem_config["y0"]

    def failure(elapsed, message):
        return TestResult(
            integrator_name=label,
            problem_name=problem_config["name"],
            final_error=np.inf,
            max_error=np.inf,
            relative_error=np.inf,
            computation_time=elapsed,
            n_steps=0,
            converged=False,
            error_message=message,
        )

    # Warm-up JIT compilation (excluded from timing); best effort only.
    try:
        step = t_eval[1] - t_eval[0] if t_eval.size >= 2 else 0.0
        warmup_t = np.array([t_eval[0], t_eval[0] + max(step, 1e-8)], dtype=float)
        integrator.integrate(system, y0, warmup_t)
    except Exception:
        pass

    # Timed integration only: error evaluation below must not count towards
    # the integrator's runtime, nor be reported as its failure time.
    start_time = time.perf_counter()
    try:
        solution = integrator.integrate(system, y0, t_eval)
    except Exception as e:
        return failure(time.perf_counter() - start_time, str(e))
    computation_time = time.perf_counter() - start_time

    try:
        y_solution = np.asarray(solution.states)

        if problem_config["exact"] is not None:
            y_ref = problem_config["exact"](t_eval)
        else:
            # Integrate only as far as the grid reaches.
            ref_sol = solve_ivp(
                system.rhs, (t_eval[0], t_eval[-1]), y0,
                t_eval=t_eval, method="DOP853", rtol=1e-12, atol=1e-14,
            )
            if not ref_sol.success:
                raise RuntimeError(f"reference solution failed: {ref_sol.message}")
            y_ref = ref_sol.y.T

        errors = np.abs(y_solution - y_ref)
        final_error = np.linalg.norm(errors[-1])
        max_error = np.max(errors)
        relative_error = max_error / (np.max(np.abs(y_ref)) + 1e-16)

        energy_error = None
        energy = problem_config["energy"]
        if energy is not None:
            initial_energy = energy(y0)
            final_energy = energy(y_solution[-1])
            # Relative drift; fall back to absolute drift when E0 == 0.
            energy_error = abs(final_energy - initial_energy) / (abs(initial_energy) or 1.0)
    except Exception as e:
        return failure(computation_time, str(e))

    return TestResult(
        integrator_name=label,
        problem_name=problem_config["name"],
        final_error=final_error,
        max_error=max_error,
        relative_error=relative_error,
        computation_time=computation_time,
        # Count output intervals (approximate for adaptive methods)
        n_steps=len(t_eval) - 1,
        energy_error=energy_error,
        converged=True,
    )
200
201
def _benchmark_failure(name, problem_config, computation_time, message):
    """Return a non-converged ``TestResult`` with infinite errors and *message*."""
    return TestResult(
        integrator_name=name,
        problem_name=problem_config["name"],
        final_error=np.inf,
        max_error=np.inf,
        relative_error=np.inf,
        computation_time=computation_time,
        n_steps=0,
        converged=False,
        error_message=message,
    )


def _benchmark_reference(system, problem_config, t_eval):
    """Return the reference trajectory for *problem_config* sampled on *t_eval*.

    Uses ``problem_config["exact"]`` when present. Otherwise it runs a SciPy
    DOP853 solve (rtol=1e-12, atol=1e-14) over ``[t_eval[0], t_eval[-1]]``.

    Raises
    ------
    RuntimeError
        If the SciPy reference solve reports failure.
    """
    if problem_config["exact"] is not None:
        return problem_config["exact"](t_eval)
    ref_sol = solve_ivp(
        system.rhs, (t_eval[0], t_eval[-1]), problem_config["y0"],
        t_eval=t_eval, method="DOP853", rtol=1e-12, atol=1e-14,
    )
    if not ref_sol.success:
        raise RuntimeError(ref_sol.message)
    return ref_sol.y.T


def _run_scipy_case(name, method, system, problem_config, t_eval, y_ref):
    """Time one SciPy ``solve_ivp`` method on *t_eval* and score it against *y_ref*.

    Returns
    -------
    tuple
        ``(result, status)`` where *status* is ``"ok"`` on success,
        ``"FAILED"`` when the solver reports failure, or ``"ERROR"`` when an
        exception was raised.
    """
    y0 = problem_config["y0"]
    start_time = time.perf_counter()
    try:
        # Same span as the Hiten runs (the grid's), so timings are comparable.
        scipy_sol = solve_ivp(
            system.rhs, (t_eval[0], t_eval[-1]), y0,
            t_eval=t_eval, method=method, rtol=1e-8, atol=1e-10,
        )
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        return _benchmark_failure(name, problem_config, elapsed, str(e)), "ERROR"
    computation_time = time.perf_counter() - start_time

    if not scipy_sol.success:
        result = _benchmark_failure(name, problem_config, computation_time, str(scipy_sol.message))
        return result, "FAILED"

    try:
        y_solution = scipy_sol.y.T
        errors = np.abs(y_solution - y_ref)
        final_error = np.linalg.norm(errors[-1])
        max_error = np.max(errors)
        relative_error = max_error / (np.max(np.abs(y_ref)) + 1e-16)

        energy_error = None
        energy = problem_config["energy"]
        if energy is not None:
            initial_energy = energy(y0)
            final_energy = energy(y_solution[-1])
            # Relative drift; fall back to absolute drift when E0 == 0.
            energy_error = abs(final_energy - initial_energy) / (abs(initial_energy) or 1.0)
    except Exception as e:
        return _benchmark_failure(name, problem_config, computation_time, str(e)), "ERROR"

    result = TestResult(
        integrator_name=name,
        problem_name=problem_config["name"],
        final_error=final_error,
        max_error=max_error,
        relative_error=relative_error,
        computation_time=computation_time,
        n_steps=len(scipy_sol.t) - 1,
        energy_error=energy_error,
        converged=True,
    )
    return result, "ok"


def _print_benchmark_result(name, result, failure_label="FAILED"):
    """Print one console line (plus an energy line when available) for *result*."""
    if not result.converged:
        print(f"{name:15s}: {failure_label} - {result.error_message}")
        return
    print(f"{name:15s}: "
          f"Final Error: {result.final_error:.2e}, "
          f"Max Error: {result.max_error:.2e}, "
          f"Time: {result.computation_time:.4f}s")
    if result.energy_error is not None:
        print(f"{'':15s} Energy Error: {result.energy_error:.2e}")


def run_benchmark():
    """Run every Hiten and SciPy integrator on every problem in ``TEST_PROBLEMS``.

    All solvers sample the shared grid ``t_eval = linspace(0, 10, 1001)``.
    SciPy methods integrate exactly over that grid's span, as the Hiten runs
    do, so timings are comparable. For each problem the reference trajectory
    (closed form, or a tight DOP853 solve) is computed once and reused for
    every SciPy method. If the reference cannot be computed, all SciPy
    methods for that problem are recorded as errors without being run.

    Returns
    -------
    list of TestResult
        One entry per (integrator, problem) pair, failures included
        (``converged=False``).
    """
    # Define integrators to test
    integrators = {
        # Fixed-step RK methods
        "RK4": RungeKutta(order=4),
        "RK6": RungeKutta(order=6),
        "RK8": RungeKutta(order=8),

        # Adaptive RK methods
        "RK45": AdaptiveRK(order=5, rtol=1e-8, atol=1e-10),
        "DOP853": AdaptiveRK(order=8, rtol=1e-8, atol=1e-10),
    }

    # SciPy reference
    scipy_integrators = {
        "SciPy-RK45": "RK45",
        "SciPy-DOP853": "DOP853",
        "SciPy-RK23": "RK23",
        "SciPy-BDF": "BDF",
        "SciPy-LSODA": "LSODA"
    }

    # Time grid for evaluation, shared by every problem and solver
    t_eval = np.linspace(0, 10, 1001)

    all_results = []

    print("Running comprehensive integrator tests...")
    print("=" * 60)

    for problem_config in TEST_PROBLEMS:
        print(f"\nTesting problem: {problem_config['name']}")
        print("-" * 40)
        # Build a single dynamical system instance per problem for all tests
        system = create_rhs_system(problem_config["rhs"], dim=len(problem_config["y0"]), name=problem_config["name"])

        # Test Hiten integrators
        for name, integrator in integrators.items():
            result = run_integrator_test(integrator, system, problem_config, t_eval)
            all_results.append(result)
            _print_benchmark_result(name, result, "FAILED")

        # The reference is identical for every SciPy method: compute it once.
        try:
            y_ref = _benchmark_reference(system, problem_config, t_eval)
            ref_error = None
        except Exception as e:
            y_ref, ref_error = None, f"reference solution failed: {e}"

        # Test SciPy integrators
        for name, method in scipy_integrators.items():
            if ref_error is not None:
                result, status = _benchmark_failure(name, problem_config, 0.0, ref_error), "ERROR"
            else:
                result, status = _run_scipy_case(name, method, system, problem_config, t_eval, y_ref)
            all_results.append(result)
            _print_benchmark_result(name, result, status)

    return all_results
323
324
def print_summary_table(results: List[TestResult]):
    """Print all results as a fixed-width table.

    Problems appear in alphabetical order. Within each problem, rows run
    from the smallest to the largest final error. Missing energy errors are
    shown as "N/A".
    """
    rule = "=" * 100
    print("\n" + rule)
    print("SUMMARY TABLE")
    print(rule)
    print(f"{'Integrator':<15} {'Problem':<20} {'Final Error':<12} {'Max Error':<12} {'Time (s)':<10} {'Energy Error':<12}")
    print("-" * 100)

    by_problem = {}
    for res in results:
        by_problem.setdefault(res.problem_name, []).append(res)

    for problem in sorted(by_problem):
        ranked = sorted(by_problem[problem], key=lambda r: r.final_error)
        for res in ranked:
            if res.energy_error is None:
                energy_str = "N/A"
            else:
                energy_str = f"{res.energy_error:.2e}"
            print(f"{res.integrator_name:<15} {res.problem_name:<20} "
                  f"{res.final_error:<12.2e} {res.max_error:<12.2e} "
                  f"{res.computation_time:<10.4f} {energy_str:<12}")
348
349
def create_performance_plots(results: List[TestResult]):
    """Save accuracy-vs-speed scatter plots for all converged results.

    Writes ``_debug/results/plots/performance_comparison.png``. It has two
    log-log panels: final error vs. computation time, and max error vs.
    computation time. Each integrator gets one colour.
    """
    # Create results directory
    os.makedirs('_debug/results/plots', exist_ok=True)

    converged = [r for r in results if r.converged]
    # Sorted rather than set order: str hashing (and so set iteration order)
    # is randomized per interpreter run, which would shuffle colours/legend.
    integrators = sorted({r.integrator_name for r in converged})
    colors = plt.cm.tab10(np.linspace(0, 1, len(integrators)))
    integrator_colors = dict(zip(integrators, colors))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    panels = (
        (ax1, "final_error", 'Final Error', 'Accuracy vs Speed'),
        (ax2, "max_error", 'Max Error', 'Max Error vs Speed'),
    )
    for ax, attr, ylabel, title in panels:
        for integrator in integrators:
            # Non-empty by construction: names come from converged results.
            subset = [r for r in converged if r.integrator_name == integrator]
            ax.loglog(
                [r.computation_time for r in subset],
                [getattr(r, attr) for r in subset],
                'o', color=integrator_colors[integrator],
                label=integrator, markersize=8,
            )
        ax.set_xlabel('Computation Time (s)')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if integrators:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('_debug/results/plots/performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print("Performance plots saved to _debug/results/plots/")
401
402
def main():
    """Entry point: run the benchmark, print the summary table, save plots."""
    print("Hiten Integrators vs SciPy Comprehensive Benchmark")
    print("=" * 60)

    results = run_benchmark()
    print_summary_table(results)
    create_performance_plots(results)

    print("\nBenchmark completed successfully!")


if __name__ == "__main__":
    main()

Event Integrator Benchmark Example

Tests Hiten’s event-capable integrators (adaptive RK45 and DOP853, plus fixed-step RK4, RK6 and RK8) against SciPy solvers on simple problems with known event times. Reports detection accuracy and speed, and saves a comparison plot.

Testing Hiten’s event-capable integrators (adaptive and fixed-step) against SciPy solvers on simple problems with known event times
  1"""Event-handling integrator benchmark.
  2
  3Compares Hiten's event-capable integrator (DOP853) against SciPy solvers on
  4simple problems with known event times. Reports detection accuracy and speed,
  5and saves a comparison plot.
  6"""
  7
  8import os
  9import sys
 10import time
 11import warnings
 12from dataclasses import dataclass
 13from typing import Callable, Dict, List, Optional, Tuple
 14
 15import matplotlib.pyplot as plt
 16import numpy as np
 17from numba import njit, types
 18from scipy.integrate import solve_ivp
 19
# Make project src importable: append ``<this script's directory>/../../src``
# to sys.path. The path is derived from ``__file__`` rather than from the
# current working directory.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
 22
 23from hiten.algorithms.dynamics.rhs import create_rhs_system
 24from hiten.algorithms.integrators import RungeKutta
 25from hiten.algorithms.integrators.configs import _EventConfig
 26
# Install a process-wide filter that silences every UserWarning so the
# benchmark's console output stays readable.
warnings.filterwarnings('ignore', category=UserWarning)
 28
 29
@dataclass
class EventBenchmarkResult:
    """Outcome of one event-detection run (one integrator on one problem).

    Failed runs have ``converged=False``, ``t_event``/``y_event``/``y_error``
    set to None, ``t_error=inf`` and the reason in ``error_message``.
    """
    integrator_name: str
    problem_name: str
    t_event: Optional[float]  # detected event time; None on failure
    y_event: Optional[np.ndarray]  # state at t_event; None when unavailable
    t_error: float  # |t_event - t_expected|; inf on failure
    y_error: Optional[float]  # state-side error measure at the event; None when unavailable
    computation_time: float  # seconds measured with time.perf_counter()
    converged: bool = True
    error_message: str = ""
 41
 42
# --- Precompiled global RHS and event functions (stable identities) ---
# Each is a module-level njit function compiled for the explicit signature in
# its decorator, so every problem that references it passes the same object.
@njit(types.float64(types.float64, types.float64[:]), cache=True, fastmath=True)
def g_y1_jit(t: float, y: np.ndarray) -> float:
    """Event function ``y[0] - 1``: zero when the first state component equals 1."""
    return y[0] - 1.0
 47
 48
@njit(types.float64(types.float64, types.float64[:]), cache=True, fastmath=True)
def g_x0_jit(t: float, y: np.ndarray) -> float:
    """Event function ``y[0]``: zero when the first state component crosses 0."""
    return y[0]
 52
 53
@njit(types.float64[:](types.float64, types.float64[:]), cache=True, fastmath=True)
def rhs_inc_jit(t: float, y: np.ndarray) -> np.ndarray:
    """RHS of the scalar ODE y' = 1 (returns a fresh length-1 array)."""
    out = np.empty(1, dtype=np.float64)
    out[0] = 1.0
    return out
 59
 60
@njit(types.float64[:](types.float64, types.float64[:]), cache=True, fastmath=True)
def rhs_dec_jit(t: float, y: np.ndarray) -> np.ndarray:
    """RHS of the scalar ODE y' = -1 (returns a fresh length-1 array)."""
    out = np.empty(1, dtype=np.float64)
    out[0] = -1.0
    return out
 66
 67
@njit(types.float64[:](types.float64, types.float64[:]), cache=True, fastmath=True)
def rhs_ho_jit(t: float, y: np.ndarray) -> np.ndarray:
    """RHS of the unit harmonic oscillator with state ``[x, x']``: returns ``[x', -x]``."""
    out = np.empty(2, dtype=np.float64)
    out[0] = y[1]
    out[1] = -y[0]
    return out
 74
 75
 76# --- Event test problems using precompiled functions ---
 77def _problems() -> List[Dict]:
 78    problems: List[Dict] = []
 79
 80    problems.append({
 81        "name": "y' = 1, hit y=1 (inc)",
 82        "rhs": rhs_inc_jit,
 83        "y0": np.array([0.2], dtype=float),
 84        "t_span": (0.0, 5.0),
 85        "event": g_y1_jit,
 86        "direction": +1,
 87        "t_expected": lambda y0: 1.0 - float(y0[0]),
 88        "dim": 1,
 89    })
 90
 91    problems.append({
 92        "name": "y' = -1, hit y=1 (dec)",
 93        "rhs": rhs_dec_jit,
 94        "y0": np.array([2.0], dtype=float),
 95        "t_span": (0.0, 5.0),
 96        "event": g_y1_jit,
 97        "direction": -1,
 98        "t_expected": lambda y0: float(y0[0]) - 1.0,
 99        "dim": 1,
100    })
101
102    problems.append({
103        "name": "Harmonic oscillator, hit x=0",
104        "rhs": rhs_ho_jit,
105        "y0": np.array([1.0, 0.0], dtype=float),
106        "t_span": (0.0, 5.0),
107        "event": g_x0_jit,
108        "direction": -1,
109        "t_expected": lambda y0: 0.5 * np.pi,
110        "dim": 2,
111    })
112
113    return problems
114
115
116# --- Warm-up helpers (avoid JIT bias in timing) ---
117def _make_warmup_times(t0: float) -> np.ndarray:
118    return np.array([t0, t0 + 1.0e1], dtype=float)
119
120
121def _is_fixed_step(integrator) -> bool:
122    try:
123        return hasattr(integrator, "_integrate_fixed_rk")
124    except Exception:
125        return False
126
127
def _warmup_hiten_event(integrator, system, y0: np.ndarray, event_fn: Callable[[float, np.ndarray], float], direction: int) -> None:
    """Run one throw-away terminal-event integration before timing starts.

    Fixed-step integrators get a 64-point grid spanning the warm-up window.
    Other integrators get only the window's two endpoints. The result is
    discarded and any exception is deliberately ignored (warm-up is best
    effort).
    """
    try:
        window = _make_warmup_times(0.0)
        if _is_fixed_step(integrator):
            # Use a small grid to trigger step-based event path and refinement
            grid = np.linspace(window[0], window[-1], 64)
        else:
            grid = window
        integrator.integrate(
            system,
            y0,
            grid,
            event_fn=event_fn,
            event_cfg=_EventConfig(direction=direction, terminal=True),
        )
    except Exception:
        pass
151
152
def _warmup_scipy_event(method: str, rhs: Callable[[float, np.ndarray], np.ndarray], y0: np.ndarray, t0: float, event_fn: Callable[[float, np.ndarray], float], direction: int) -> None:
    """Run one throw-away SciPy terminal-event solve (loose tolerances) before timing.

    The result is discarded and any exception is ignored (warm-up is best
    effort).
    """
    try:
        # A wrapper carries the solve_ivp event attributes, so nothing is set
        # on the event_fn object itself.
        def event(t, y):
            return event_fn(t, y)

        event.terminal = True
        event.direction = float(direction)
        lo, hi = _make_warmup_times(t0)
        solve_ivp(rhs, (lo, hi), y0, method=method, events=event, rtol=1e-6, atol=1e-8)
    except Exception:
        pass
163
164
def run_hiten_event(
    integrator,
    system,
    y0: np.ndarray,
    t_span: Tuple[float, float],
    event_fn: Callable[[float, np.ndarray], float],
    direction: int,
    t_expected: float,
    problem_name: str,
) -> EventBenchmarkResult:
    """Time one Hiten terminal-event integration and compare the hit with *t_expected*.

    The integrator is warmed up first (untimed). Fixed-step integrators then
    run on a 4097-point grid over *t_span*; adaptive ones get only the two
    endpoints. The last returned node is taken as the event.

    ``y_error`` is the event-function residual ``|event_fn(t_event, y_event)|``,
    i.e. how far the returned state lies from the event surface. The
    previous value, ``norm(y_event - y_event)``, was always 0.

    Returns
    -------
    EventBenchmarkResult
        On any exception: ``converged=False``, ``t_error=inf`` and the
        exception text in ``error_message``.
    """
    _warmup_hiten_event(integrator, system, y0, event_fn, direction)
    start = time.perf_counter()
    try:
        if _is_fixed_step(integrator):
            # Use a moderately fine grid for fixed-step methods
            t_eval = np.linspace(t_span[0], t_span[1], 4097, dtype=float)
        else:
            # Adaptive integrators only need the endpoints
            t_eval = np.array([t_span[0], t_span[1]], dtype=float)
        sol = integrator.integrate(
            system,
            y0,
            t_eval,
            event_fn=event_fn,
            event_cfg=_EventConfig(direction=direction, terminal=True),
        )
        dt = time.perf_counter() - start
        # Event path returns two nodes [t0, t_hit] or [t0, tmax] when no hit.
        # NOTE(review): a miss is therefore not flagged here; it shows up as a
        # large t_error and a non-negligible y_error residual.
        t_event = float(sol.times[-1])
        y_event = sol.states[-1].copy()
        t_err = abs(t_event - t_expected)
        # The jitted event functions take a contiguous float64 1-D state.
        residual_state = np.ascontiguousarray(y_event, dtype=np.float64)
        y_err = abs(float(event_fn(t_event, residual_state)))
        return EventBenchmarkResult(
            integrator_name=str(integrator),
            problem_name=problem_name,
            t_event=t_event,
            y_event=y_event,
            t_error=t_err,
            y_error=y_err,
            computation_time=dt,
            converged=True,
        )
    except Exception as e:
        dt = time.perf_counter() - start
        return EventBenchmarkResult(
            integrator_name=str(integrator),
            problem_name=problem_name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=dt,
            converged=False,
            error_message=str(e),
        )
221
222
def run_scipy_event(
    name: str,
    method: str,
    rhs: Callable[[float, np.ndarray], np.ndarray],
    y0: np.ndarray,
    t_span: Tuple[float, float],
    event_fn: Callable[[float, np.ndarray], float],
    direction: int,
    t_expected: float,
    problem_name: str,
) -> EventBenchmarkResult:
    """Time a SciPy ``solve_ivp`` run with a terminal event and score the first hit.

    A warm-up solve runs first (untimed). The timed solve uses rtol=1e-8 and
    atol=1e-10.

    ``y_error`` is the event-function residual ``|event_fn(t_event, y_event)|``,
    or None when SciPy returned no event state. The previous value,
    ``norm(y_event - y_event)``, was always 0.

    Returns
    -------
    EventBenchmarkResult
        ``converged=False`` with ``t_error=inf`` if no event was detected.
        The message then says so explicitly and appends SciPy's own message,
        which on its own may only report reaching the interval end. The
        same fields are set on any exception, with the exception text as
        the message.
    """
    _warmup_scipy_event(method, rhs, y0, t_span[0], event_fn, direction)

    # Wrapper carries the solve_ivp event attributes.
    def ev(t, y):
        return event_fn(t, y)
    ev.terminal = True
    ev.direction = float(direction)

    def failure(elapsed, message):
        return EventBenchmarkResult(
            integrator_name=name,
            problem_name=problem_name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=elapsed,
            converged=False,
            error_message=message,
        )

    start = time.perf_counter()
    try:
        sol = solve_ivp(
            rhs,
            t_span,
            y0,
            method=method,
            events=ev,
            rtol=1e-8,
            atol=1e-10,
        )
        dt = time.perf_counter() - start

        hits = sol.t_events[0] if sol.t_events else []
        if len(hits) == 0:
            return failure(dt, f"no event detected: {sol.message}")

        t_event = float(hits[0])
        y_events = getattr(sol, "y_events", None)
        y_event = y_events[0][0].copy() if y_events and len(y_events[0]) > 0 else None
        t_err = abs(t_event - t_expected)
        if y_event is not None:
            # The jitted event functions take a contiguous float64 1-D state.
            residual_state = np.ascontiguousarray(y_event, dtype=np.float64)
            y_err = abs(float(event_fn(t_event, residual_state)))
        else:
            y_err = None
        return EventBenchmarkResult(
            integrator_name=name,
            problem_name=problem_name,
            t_event=t_event,
            y_event=y_event,
            t_error=t_err,
            y_error=y_err,
            computation_time=dt,
            converged=True,
        )
    except Exception as e:
        return failure(time.perf_counter() - start, str(e))
291
292
def run_event_benchmark() -> List[EventBenchmarkResult]:
    """Run every Hiten and SciPy integrator on every problem from ``_problems()``.

    Every result, Hiten and SciPy alike, is labelled with its key from the
    integrator tables below (e.g. ``"HITEN-RK4-fixed"``, ``"SciPy-LSODA"``).
    Labels are therefore unique in the console output, the summary table and
    the plot.

    Returns
    -------
    list of EventBenchmarkResult
        One entry per (integrator, problem) pair, failures included.
    """
    results: List[EventBenchmarkResult] = []

    # Hiten integrators (adaptive and fixed-step)
    hiten_integrators = {
        "HITEN-RK45": RungeKutta(order=45, rtol=1e-8, atol=1e-10),
        "HITEN-DOP853": RungeKutta(order=853, rtol=1e-8, atol=1e-10),
        "HITEN-RK4-fixed": RungeKutta(order=4),
        "HITEN-RK6-fixed": RungeKutta(order=6),
        "HITEN-RK8-fixed": RungeKutta(order=8),
    }

    # SciPy integrators with event support
    scipy_integrators = {
        "SciPy-RK45": "RK45",
        "SciPy-DOP853": "DOP853",
        "SciPy-Radau": "Radau",
        "SciPy-BDF": "BDF",
        "SciPy-LSODA": "LSODA",
    }

    def record(label, res):
        """Store *res* and print its one-line console summary."""
        results.append(res)
        if res.converged:
            print(f"{label:15s}: t_hit={res.t_event:.10f}, |dt|={res.t_error:.2e}, time={res.computation_time:.4f}s")
        else:
            print(f"{label:15s}: FAILED - {res.error_message}")

    print("Event Integrator Benchmark (Hiten vs SciPy)")
    print("=" * 60)

    for prob in _problems():
        print(f"\nProblem: {prob['name']}")
        print("-" * 40)
        system = create_rhs_system(prob["rhs"], dim=prob["dim"], name=prob["name"])
        t_expected = float(prob["t_expected"](prob["y0"]))

        # Hiten
        for name, integrator in hiten_integrators.items():
            res = run_hiten_event(
                integrator=integrator,
                system=system,
                y0=prob["y0"],
                t_span=prob["t_span"],
                event_fn=prob["event"],
                direction=prob["direction"],
                t_expected=t_expected,
                problem_name=prob["name"],
            )
            # Label with the table key, as the SciPy rows are: str(integrator)
            # is not guaranteed to distinguish the configurations above.
            res.integrator_name = name
            record(name, res)

        # SciPy
        for name, method in scipy_integrators.items():
            res = run_scipy_event(
                name=name,
                method=method,
                rhs=prob["rhs"],
                y0=prob["y0"],
                t_span=prob["t_span"],
                event_fn=prob["event"],
                direction=prob["direction"],
                t_expected=t_expected,
                problem_name=prob["name"],
            )
            record(name, res)

    return results
361
362
def print_summary_table(results: List[EventBenchmarkResult]) -> None:
    """Print all event results as a fixed-width table.

    Problems appear alphabetically. Within each problem, converged runs come
    first, ordered by increasing event-time error.
    """
    heavy = "=" * 100
    print("\n" + heavy)
    print("SUMMARY TABLE (Events)")
    print(heavy)
    print(f"{'Integrator':<15} {'Problem':<30} {'t_hit':<18} {'|dt|':<12} {'Time (s)':<10}")
    print("-" * 100)

    by_problem: Dict[str, List[EventBenchmarkResult]] = {}
    for res in results:
        if res.problem_name not in by_problem:
            by_problem[res.problem_name] = []
        by_problem[res.problem_name].append(res)

    def rank(res):
        # False sorts before True, so converged runs lead.
        return (not res.converged, res.t_error)

    for problem in sorted(by_problem):
        for res in sorted(by_problem[problem], key=rank):
            t_hit_str = "N/A" if res.t_event is None else f"{res.t_event:.10f}"
            dt_str = f"{res.t_error:.2e}" if np.isfinite(res.t_error) else "inf"
            print(f"{res.integrator_name:<15} {res.problem_name:<30} {t_hit_str:<18} {dt_str:<12} {res.computation_time:<10.4f}")
379
380
def create_performance_plot(results: List[EventBenchmarkResult]) -> None:
    """Save a log-log scatter of event-time error vs. computation time.

    Writes ``_debug/results/plots/events_performance_comparison.png`` with one
    colour per integrator and only converged runs plotted. Exact hits
    (``t_error == 0``) cannot be drawn on a log axis and would silently
    disappear. They are therefore drawn at a floor of 1e-16.
    """
    os.makedirs('_debug/results/plots', exist_ok=True)

    # Smallest plotted error; exact (zero-error) hits are shown at this value.
    error_floor = 1e-16

    # Plot |dt| vs time, grouped by integrator names. Sorted rather than set
    # order (which changes between interpreter runs), so colours and legend
    # order are reproducible. Names of failed runs keep their colour slot.
    names = sorted({r.integrator_name for r in results})
    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(names))))
    cmap: Dict[str, np.ndarray] = dict(zip(names, colors))

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    plotted = False
    for name in names:
        subset = [r for r in results if r.integrator_name == name and r.converged]
        if not subset:
            continue
        times = [r.computation_time for r in subset]
        terrs = [max(r.t_error, error_floor) for r in subset]
        ax.loglog(times, terrs, 'o', label=name, color=cmap[name], markersize=8)
        plotted = True

    ax.set_xlabel('Computation Time (s)')
    ax.set_ylabel('Absolute Event Time Error |dt|')
    ax.set_title('Event Detection: Accuracy vs Speed')
    ax.grid(True, which='both', alpha=0.3)
    if plotted:
        ax.legend()

    plt.tight_layout()
    out_path = '_debug/results/plots/events_performance_comparison.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Performance plot saved to {out_path}")
409
410
def main() -> None:
    """Entry point: run the event benchmark, print the summary, save the plot."""
    event_results = run_event_benchmark()
    print_summary_table(event_results)
    create_performance_plot(event_results)
    print("\nEvent benchmark completed.")


if __name__ == "__main__":
    main()
420
421

Symplectic Event Integrator Benchmark Example

Tests Hiten’s extended symplectic integrators with event detection enabled against SciPy solvers on simple problems with known event times. Reports detection accuracy and speed, and saves a comparison plot.

Testing Hiten’s extended symplectic integrators with event detection enabled against SciPy solvers on simple problems with known event times
  1#!/usr/bin/env python3
  2"""Symplectic event-handling benchmark and example.
  3
  4This script mirrors ``event_integrator_benchmark.py`` but focuses on
  5Hiten's extended symplectic integrators with event detection enabled.
  6It constructs a truncated pendulum Hamiltonian, integrates it with
  7event-enabled symplectic schemes, compares the detected event times
  8against SciPy solvers, and saves an accuracy vs speed plot.
  9"""
 10
 11from __future__ import annotations
 12
 13import os
 14import sys
 15import time
 16import warnings
 17from dataclasses import dataclass
 18from typing import Dict, List, Optional, Tuple
 19
 20import matplotlib.pyplot as plt
 21import numpy as np
 22from numba import njit, types
 23from numba.typed import List as NumbaList
 24from scipy.integrate import solve_ivp
 25
 26# Make project src importable when running from repository root
 27sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
 28
 29from hiten.algorithms.dynamics.hamiltonian import create_hamiltonian_system
 30from hiten.algorithms.integrators import ExtendedSymplectic
 31from hiten.algorithms.integrators.configs import _EventConfig
 32from hiten.algorithms.integrators.symplectic import (
 33    N_SYMPLECTIC_DOF,
 34    N_VARS_POLY,
 35    P_POLY_INDICES,
 36    Q_POLY_INDICES,
 37)
 38from hiten.algorithms.polynomial.base import (
 39    _create_encode_dict_from_clmo,
 40    _encode_multiindex,
 41    _init_index_tables,
 42)
 43
 44warnings.filterwarnings('ignore', category=UserWarning)
 45
 46
 47@dataclass
 48class EventResult:
 49    solver_name: str
 50    problem_name: str
 51    t_event: Optional[float]
 52    y_event: Optional[np.ndarray]
 53    t_error: float
 54    y_error: Optional[float]
 55    computation_time: float
 56    converged: bool = True
 57    error_message: str = ""
 58
 59
 60@dataclass
 61class SymplecticEventProblem:
 62    name: str
 63    description: str
 64    y0_extended: np.ndarray
 65    y0_canonical: np.ndarray
 66    t_span: Tuple[float, float]
 67    grid_size: int
 68    direction: int
 69
 70
 71@njit(types.float64(types.float64, types.float64[:]), cache=True, fastmath=True)
 72def event_q1_jit(t: float, y: np.ndarray) -> float:
 73    """Event function detecting q1 = 0 crossings."""
 74    return y[0]
 75
 76
 77def truncated_pendulum_rhs(_t: float, y: np.ndarray) -> List[float]:
 78    """Return dynamics for the 2D truncated pendulum system."""
 79    q, p = y
 80    sin_taylor = q - (q ** 3) / 6.0 + (q ** 5) / 120.0
 81    return [p, -sin_taylor]
 82
 83
 84def embed_state(qp_state: np.ndarray) -> np.ndarray:
 85    """Embed a 2D (q, p) state into the 6D symplectic phase space."""
 86    state = np.zeros(2 * N_SYMPLECTIC_DOF, dtype=np.float64)
 87    state[0] = qp_state[0]
 88    state[N_SYMPLECTIC_DOF] = qp_state[1]
 89    return state
 90
 91
 92def build_truncated_pendulum_system(max_degree: int = 6):
 93    """Construct a polynomial Hamiltonian system for a truncated pendulum."""
 94    psi_tables, clmo_tables = _init_index_tables(max_degree)
 95    encode_dict_list = _create_encode_dict_from_clmo(clmo_tables)
 96
 97    H_blocks = [
 98        np.zeros(psi_tables[N_VARS_POLY, deg], dtype=np.complex128)
 99        for deg in range(max_degree + 1)
100    ]
101
102    idx_p1 = P_POLY_INDICES[0]
103    idx_q1 = Q_POLY_INDICES[0]
104    mono = np.zeros(N_VARS_POLY, dtype=np.int64)
105
106    # Kinetic term: p1^2 / 2 (degree 2)
107    mono[:] = 0
108    mono[idx_p1] = 2
109    encoded = _encode_multiindex(mono, 2, encode_dict_list)
110    if encoded != -1:
111        H_blocks[2][encoded] = 0.5
112
113    # Potential offset: -1 (degree 0)
114    mono[:] = 0
115    encoded = _encode_multiindex(mono, 0, encode_dict_list)
116    if encoded != -1:
117        H_blocks[0][encoded] = -1.0
118
119    # Quadratic potential: +q1^2 / 2 (degree 2)
120    mono[:] = 0
121    mono[idx_q1] = 2
122    encoded = _encode_multiindex(mono, 2, encode_dict_list)
123    if encoded != -1:
124        H_blocks[2][encoded] += 0.5
125
126    # Quartic correction: -q1^4 / 24 (degree 4)
127    if max_degree >= 4:
128        mono[:] = 0
129        mono[idx_q1] = 4
130        encoded = _encode_multiindex(mono, 4, encode_dict_list)
131        if encoded != -1:
132            H_blocks[4][encoded] = -1.0 / 24.0
133
134    # Sextic correction: +q1^6 / 720 (degree 6)
135    if max_degree >= 6:
136        mono[:] = 0
137        mono[idx_q1] = 6
138        encoded = _encode_multiindex(mono, 6, encode_dict_list)
139        if encoded != -1:
140            H_blocks[6][encoded] = 1.0 / 720.0
141
142    H_blocks_typed = NumbaList()
143    for arr in H_blocks:
144        H_blocks_typed.append(arr.copy())
145
146    system = create_hamiltonian_system(
147        H_blocks=H_blocks_typed,
148        degree=max_degree,
149        psi_table=psi_tables,
150        clmo_table=clmo_tables,
151        encode_dict_list=encode_dict_list,
152        n_dof=N_SYMPLECTIC_DOF,
153        name="Truncated Pendulum Hamiltonian",
154    )
155
156    return system
157
158
def make_event_problems() -> List[SymplecticEventProblem]:
    """Build the two mirror-image pendulum release problems.

    Both start at rest at +/-45 degrees and integrate over ``(0, 5)`` on a
    4097-point grid. The +45 degree release looks for a negative-direction
    ``q1 = 0`` crossing; the -45 degree release looks for a positive-direction
    crossing.
    """
    angle = np.deg2rad(45.0)

    # (sign of initial angle, name, description, crossing direction)
    specs = (
        (
            1.0,
            "Pendulum release (+45 deg)",
            "Pendulum released from +45 degrees, expect q1 -> 0 with negative crossing.",
            -1,
        ),
        (
            -1.0,
            "Pendulum release (-45 deg)",
            "Pendulum released from -45 degrees, expect q1 -> 0 with positive crossing.",
            +1,
        ),
    )

    problems: List[SymplecticEventProblem] = []
    for sign, label, text, crossing in specs:
        q0 = sign * angle
        extended = np.zeros(2 * N_SYMPLECTIC_DOF, dtype=np.float64)
        extended[0] = q0
        problems.append(
            SymplecticEventProblem(
                name=label,
                description=text,
                y0_extended=extended,
                y0_canonical=np.array([q0, 0.0], dtype=np.float64),
                t_span=(0.0, 5.0),
                grid_size=4097,
                direction=crossing,
            )
        )
    return problems
190
191
def compute_reference_event(problem: SymplecticEventProblem) -> Tuple[float, np.ndarray]:
    """Locate the first event with a tight-tolerance SciPy Radau solve.

    Integrates ``truncated_pendulum_rhs`` from ``problem.y0_canonical`` with
    ``rtol=1e-12`` and ``atol=1e-14``, stopping at the first ``q1 = 0``
    crossing in ``problem.direction``.

    Returns
    -------
    (t_event, y_event)
        Event time and a copy of the 2D ``(q, p)`` state at the event.

    Raises
    ------
    RuntimeError
        If the solve does not terminate on the event.
    """
    def q1_crossing(t, y):
        return y[0]

    q1_crossing.terminal = True
    q1_crossing.direction = float(problem.direction)

    solution = solve_ivp(
        truncated_pendulum_rhs,
        problem.t_span,
        problem.y0_canonical,
        method="Radau",
        events=q1_crossing,
        rtol=1e-12,
        atol=1e-14,
    )

    hits = solution.t_events[0] if solution.t_events else []
    if solution.status != 1 or len(hits) == 0:
        raise RuntimeError(f"Reference solver failed to detect event for problem '{problem.name}'")

    return float(hits[0]), solution.y_events[0][0].copy()
214
215
def _warmup_symplectic_event(integrator, system, problem: SymplecticEventProblem) -> None:
    """Run a tiny event-enabled integration so compilation is not timed.

    Integrates over the first 0.01 time units on an 8-point grid with the
    same event function and direction as the timed run. Any exception is
    ignored on purpose: this is a best-effort warm-up.

    NOTE(review): the 0.01 window probably ends before the first q1 crossing,
    so event-localization code paths may still compile inside the timed run.
    Confirm if timings look inflated.
    """
    try:
        t0 = problem.t_span[0]
        grid = np.linspace(t0, t0 + 1.0e-2, 8, dtype=np.float64)
        cfg = _EventConfig(direction=problem.direction, terminal=True)
        integrator.integrate(
            system,
            problem.y0_extended,
            grid,
            event_fn=event_q1_jit,
            event_cfg=cfg,
        )
    except Exception:
        # Best effort only; a genuine failure is expected to resurface in the
        # subsequent timed run.
        pass
234
235
def _warmup_scipy_event(method: str, problem: SymplecticEventProblem) -> None:
    """Run a short, loose-tolerance SciPy solve so first-call overhead is not timed.

    Uses the same event setup as the timed run over the first 0.01 time units
    with ``rtol=1e-6`` and ``atol=1e-8``. Failures are ignored on purpose.
    """
    try:
        def crossing(t, y):
            return y[0]

        crossing.terminal = True
        crossing.direction = float(problem.direction)
        t_start = problem.t_span[0]
        solve_ivp(
            truncated_pendulum_rhs,
            (t_start, t_start + 1.0e-2),
            problem.y0_canonical,
            method=method,
            events=crossing,
            rtol=1e-6,
            atol=1e-8,
        )
    except Exception:
        # Best effort only; a genuine failure is expected to resurface in the
        # subsequent timed run.
        pass
254
255
def run_symplectic_event(
    solver_name: str,
    integrator,
    system,
    problem: SymplecticEventProblem,
    t_expected: float,
    y_expected: np.ndarray,
) -> EventResult:
    """Execute symplectic integration with event detection and score it.

    The integrator is first warmed up outside the timed region. It is then run
    over ``problem.grid_size`` evenly spaced samples of ``problem.t_span``
    with a terminal ``q1 = 0`` event in ``problem.direction``
    (``xtol = gtol = 1e-12``). The last returned sample is taken as the event
    point.

    If the returned trajectory reaches the end of the time grid, the terminal
    event never fired. Such a run is reported as not converged, mirroring the
    ``status == 1`` check used for the SciPy runs, instead of scoring the
    final grid time as if it were the event.

    Parameters
    ----------
    solver_name : str
        Label stored in the result.
    integrator, system
        Event-capable integrator and the Hamiltonian system to integrate.
    problem : SymplecticEventProblem
        Initial condition, interval, grid size and crossing direction.
    t_expected : float
        Reference event time.
    y_expected : np.ndarray or None
        Reference event state in the extended layout; if None, ``y_error``
        is None.

    Returns
    -------
    EventResult
        On failure (exception or undetected event), ``converged=False``,
        ``t_error=inf`` and the reason in ``error_message``.
    """
    time_grid = np.linspace(
        problem.t_span[0], problem.t_span[1], problem.grid_size, dtype=np.float64
    )
    event_cfg = _EventConfig(direction=problem.direction, terminal=True, xtol=1.0e-12, gtol=1.0e-12)

    def _failure(elapsed: float, message: str) -> EventResult:
        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=elapsed,
            converged=False,
            error_message=message,
        )

    _warmup_symplectic_event(integrator, system, problem)

    start = time.perf_counter()
    try:
        sol = integrator.integrate(
            system,
            problem.y0_extended,
            time_grid,
            event_fn=event_q1_jit,
            event_cfg=event_cfg,
        )
        elapsed = time.perf_counter() - start

        t_event = float(sol.times[-1])
        y_event = sol.states[-1].copy()

        # A terminal event truncates the trajectory before the grid ends; if
        # we got all the way to the end, no crossing was detected.
        if t_event >= time_grid[-1]:
            return _failure(elapsed, "event not detected before end of t_span")

        t_error = abs(t_event - t_expected)
        y_error = float(np.linalg.norm(y_event - y_expected)) if y_expected is not None else None

        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=t_event,
            y_event=y_event,
            t_error=t_error,
            y_error=y_error,
            computation_time=elapsed,
            converged=True,
        )
    except Exception as exc:
        return _failure(time.perf_counter() - start, str(exc))
311
312
def run_scipy_event(
    solver_name: str,
    method: str,
    solver_opts: Dict,
    problem: SymplecticEventProblem,
    t_expected: float,
    y_expected: np.ndarray,
) -> EventResult:
    """Run SciPy ``solve_ivp`` with event detection on the 2D system.

    The solver is warmed up outside the timed region, then run over
    ``problem.t_span`` with a terminal ``q1 = 0`` event in
    ``problem.direction``. The detected 2D event state is embedded into the
    extended layout before being compared with ``y_expected``.

    Parameters
    ----------
    solver_name : str
        Label stored in the result.
    method : str
        ``solve_ivp`` method name (e.g. ``"RK45"``).
    solver_opts : dict
        Extra keyword arguments forwarded to ``solve_ivp`` (tolerances).
    problem : SymplecticEventProblem
        Initial condition, interval and crossing direction.
    t_expected : float
        Reference event time.
    y_expected : np.ndarray or None
        Reference event state in the extended layout; if None, ``y_error``
        is None.

    Returns
    -------
    EventResult
        On failure, ``converged=False`` and ``t_error=inf``. ``error_message``
        says the event was not detected when the solver reached the end of
        the interval, or carries the solver's own message or exception text
        otherwise.
    """
    def q1_event(t, y):
        return y[0]

    q1_event.terminal = True
    q1_event.direction = float(problem.direction)

    def _failure(elapsed: float, message: str) -> EventResult:
        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=None,
            y_event=None,
            t_error=np.inf,
            y_error=None,
            computation_time=elapsed,
            converged=False,
            error_message=message,
        )

    _warmup_scipy_event(method, problem)

    start = time.perf_counter()
    try:
        sol = solve_ivp(
            truncated_pendulum_rhs,
            problem.t_span,
            problem.y0_canonical,
            method=method,
            events=q1_event,
            **solver_opts,
        )
        elapsed = time.perf_counter() - start

        detected = sol.status == 1 and sol.t_events and len(sol.t_events[0]) > 0
        if not detected:
            # status 0 means the interval was integrated successfully with no
            # crossing; sol.message would then read like a success.
            if sol.status == 0:
                message = "event not detected before end of t_span"
            else:
                message = str(sol.message)
            return _failure(elapsed, message)

        t_event = float(sol.t_events[0][0])
        y_event = embed_state(sol.y_events[0][0].copy())
        t_error = abs(t_event - t_expected)
        y_error = float(np.linalg.norm(y_event - y_expected)) if y_expected is not None else None
        return EventResult(
            solver_name=solver_name,
            problem_name=problem.name,
            t_event=t_event,
            y_event=y_event,
            t_error=t_error,
            y_error=y_error,
            computation_time=elapsed,
            converged=True,
        )
    except Exception as exc:
        return _failure(time.perf_counter() - start, str(exc))
382
383
def _report_run(solver_name: str, res: EventResult) -> None:
    """Print the one-line progress report for a single solver run."""
    if not res.converged:
        print(f"{solver_name:20s}: FAILED - {res.error_message}")
        return
    print(
        f"{solver_name:20s}: t_hit={res.t_event:.10f}, |dt|={res.t_error:.2e}, "
        f"time={res.computation_time:.4f}s"
    )


def run_benchmark() -> List[EventResult]:
    """Execute the symplectic event benchmark.

    For every problem returned by ``make_event_problems``:

    1. compute a Radau reference event with ``compute_reference_event``, and
       record it as a "SciPy-Radau (reference)" row with zero error and zero
       time;
    2. run each event-enabled ``ExtendedSymplectic`` scheme (orders 4, 6, 8);
    3. run each SciPy solver (RK45, DOP853, BDF) on the equivalent 2D system.

    Progress is printed as each run completes.

    Returns
    -------
    list of EventResult
        Reference and solver results in execution order.

    Raises
    ------
    RuntimeError
        Propagated from ``compute_reference_event`` if the reference solve
        misses the event.
    """
    system = build_truncated_pendulum_system()
    problems = make_event_problems()

    symplectic_schemes = {
        "Symplectic4": ExtendedSymplectic(order=4, c_omega_heuristic=15.0),
        "Symplectic6": ExtendedSymplectic(order=6, c_omega_heuristic=20.0),
        "Symplectic8": ExtendedSymplectic(order=8, c_omega_heuristic=25.0),
    }

    scipy_solvers = {
        "SciPy-RK45": ("RK45", {"rtol": 1.0e-8, "atol": 1.0e-10}),
        "SciPy-DOP853": ("DOP853", {"rtol": 1.0e-9, "atol": 1.0e-11}),
        "SciPy-BDF": ("BDF", {"rtol": 1.0e-8, "atol": 1.0e-10}),
    }

    results: List[EventResult] = []

    print("Symplectic Event Integrator Benchmark")
    print("=" * 70)

    for problem in problems:
        print(f"\nProblem: {problem.name}")
        print(problem.description)
        print("-" * 70)

        t_ref, y_ref_qp = compute_reference_event(problem)
        y_ref_ext = embed_state(y_ref_qp)

        print(f"Reference event time (Radau): {t_ref:.10f} s")

        results.append(
            EventResult(
                solver_name="SciPy-Radau (reference)",
                problem_name=problem.name,
                t_event=t_ref,
                y_event=y_ref_ext,
                t_error=0.0,
                y_error=0.0,
                computation_time=0.0,
                converged=True,
            )
        )

        for solver_name, integrator in symplectic_schemes.items():
            res = run_symplectic_event(
                solver_name,
                integrator,
                system,
                problem,
                t_ref,
                y_ref_ext,
            )
            results.append(res)
            _report_run(solver_name, res)

        for solver_name, (method, opts) in scipy_solvers.items():
            res = run_scipy_event(
                solver_name,
                method,
                opts,
                problem,
                t_ref,
                y_ref_ext,
            )
            results.append(res)
            _report_run(solver_name, res)

    return results
465
466
def print_summary_table(results: List[EventResult]) -> None:
    """Print all runs as rows of a fixed-width summary table.

    Rows are grouped by problem, with problems in sorted order. Within a
    problem, converged runs come first, ordered by increasing event-time
    error; failed runs follow and are shown as ``FAILED``.
    """
    rule = "=" * 110
    print("\n" + rule)
    print("SUMMARY TABLE (Symplectic Events)")
    print(rule)
    print(f"{'Integrator':<22} {'Problem':<32} {'t_hit':<18} {'|dt|':<12} {'|dy|':<12} {'Time (s)':<10}")
    print("-" * 110)

    def emit_row(entry, t_hit, dt, dy):
        print(
            f"{entry.solver_name:<22} {entry.problem_name:<32} {t_hit:<18} "
            f"{dt:<12} {dy:<12} {entry.computation_time:<10.4f}"
        )

    by_problem: Dict[str, List[EventResult]] = {}
    for entry in results:
        by_problem.setdefault(entry.problem_name, []).append(entry)

    for key in sorted(by_problem):
        ranked = sorted(by_problem[key], key=lambda r: (not r.converged, r.t_error))
        for entry in ranked:
            if not entry.converged:
                emit_row(entry, 'FAILED', 'inf', 'n/a')
                continue
            t_hit = "N/A" if entry.t_event is None else f"{entry.t_event:.10f}"
            dt = f"{entry.t_error:.2e}" if np.isfinite(entry.t_error) else "inf"
            has_dy = entry.y_error is not None and np.isfinite(entry.y_error)
            dy = f"{entry.y_error:.2e}" if has_dy else "n/a"
            emit_row(entry, t_hit, dt, dy)
498
499
def create_performance_plot(results: List[EventResult]) -> None:
    """Save a log-log plot of event-time error |dt| versus computation time.

    Only converged runs are plotted, and only those whose ``t_error`` is
    finite and strictly positive and whose ``computation_time`` is strictly
    positive. Log axes cannot show other values. This excludes the
    Radau reference rows, which are recorded with zero error and zero time,
    and any run with an exactly-zero error.

    The figure is written to
    ``_debug/results/plots/symplectic_events_performance.png``; the directory
    is created if needed.
    """
    plot_dir = os.path.join('_debug', 'results', 'plots')
    os.makedirs(plot_dir, exist_ok=True)

    plottable = [
        r for r in results
        if r.converged
        and np.isfinite(r.t_error)
        and r.t_error > 0.0
        and r.computation_time > 0.0
    ]

    # Sorted for reproducible colour assignment and legend order.
    names = sorted({r.solver_name for r in plottable})
    if not names:
        print("No successful runs to plot.")
        return

    colors = plt.cm.tab10(np.linspace(0, 1, len(names)))

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    for name, color in zip(names, colors):
        subset = [r for r in plottable if r.solver_name == name]
        ax.loglog(
            [r.computation_time for r in subset],
            [r.t_error for r in subset],
            'o',
            label=name,
            color=color,
            markersize=8,
        )

    ax.set_xlabel('Computation Time (s)')
    ax.set_ylabel('Absolute Event Time Error |dt|')
    ax.set_title('Symplectic Event Detection: Accuracy vs Speed')
    ax.grid(True, which='both', alpha=0.3)
    ax.legend()

    fig.tight_layout()
    out_path = os.path.join(plot_dir, 'symplectic_events_performance.png')
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    print(f"Performance plot saved to {out_path}")
532
533
def main() -> None:
    """Run the symplectic event benchmark, print the summary, and save the plot."""
    all_results = run_benchmark()
    print_summary_table(all_results)
    create_performance_plot(all_results)
    print("\nSymplectic event benchmark completed.")


if __name__ == "__main__":
    main()
543