#!/usr/bin/env python3
"""
Simple test for diffusion processing with clear events.

Builds a low-noise synthetic trace containing ten rectangular step events,
extracts the event waveforms, estimates a diffusion coefficient and a
zero-lag ACF value for the first few events, scatter-plots one against the
other, and finally runs the same data through
``process_events_for_diffusion``.
"""

import numpy as np
import matplotlib.pyplot as plt
from transivent import (
    extract_event_waveforms,
    calculate_msd_parallel,
    calculate_acf,
    fit_diffusion_linear,
    plot_diffusion_comparison,
    process_events_for_diffusion,
)

# Nominal sample spacing and maximum lag, shared by every analysis call below
# so the per-event loop and the high-level call cannot drift apart.
SAMPLING_INTERVAL = 1e-6
MAX_LAG = 100

# Create simple test data with clear events
np.random.seed(42)
n_points = 10000
# NOTE: linspace includes the endpoint, so the actual spacing is
# 0.01 / (n_points - 1), marginally larger than SAMPLING_INTERVAL.
t = np.linspace(0, 0.01, n_points)  # 10 ms
x = np.random.normal(0, 0.01, n_points)  # Low noise

# Add clear step events
events_list = []
for i in range(10):
    # round() rather than int(): int() truncates, so a product that lands one
    # ulp below the intended integer would shift the event edge by a sample.
    start = round((i * 0.0005 + 0.001) * n_points / 0.01)  # Every 0.5 ms
    end = round((i * 0.0005 + 0.0013) * n_points / 0.01)  # 300 µs duration
    x[start:end] += 1.0  # Clear step
    events_list.append([t[start], t[end]])

events = np.array(events_list)
print(f"Created {len(events)} events")
print(f"Event durations: {events[:, 1] - events[:, 0]}")

# Test event extraction
waveforms = extract_event_waveforms(t, x, events)
print(f"Extracted {len(waveforms)} waveforms")

# Process first few events for diffusion.
# The two lists are plotted against each other, so they must stay the same
# length: an event contributes to both of them or to neither.
diffusion_coeffs = []
acf_values = []

for wf in waveforms[:5]:
    # Add some random walk to make it interesting
    wf = wf + np.cumsum(np.random.normal(0, 0.01, len(wf)))

    # MSD calculation
    taus, msds, _ = calculate_msd_parallel(
        wf, dt=SAMPLING_INTERVAL, max_lag=MAX_LAG, n_jobs=1
    )
    D = fit_diffusion_linear(taus, msds, time_limit=3e-5)
    if np.isnan(D):
        # No usable fit for this event: skip it entirely so that no unpaired
        # ACF value is recorded.
        continue

    # ACF calculation
    _, acf = calculate_acf(wf, dt=SAMPLING_INTERVAL, max_lag=MAX_LAG)

    diffusion_coeffs.append(D)
    acf_values.append(acf[0])

print(f"Calculated {len(diffusion_coeffs)} diffusion coefficients")
print(
    f"Mean D: {np.mean(diffusion_coeffs):.3e}"
    if diffusion_coeffs
    else "No valid diffusion coefficients"
)

# Create a simple plot
if diffusion_coeffs:
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.scatter(diffusion_coeffs, acf_values, label="Test events")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Diffusion Coefficient (m²/s)")
    ax.set_ylabel("ACF (0-lag)")
    ax.legend()

    plt.savefig("test_diffusion_plot.png", dpi=150, bbox_inches="tight")
    print("Saved test plot to: test_diffusion_plot.png")
    plt.show()

# Test the high-level function
results = process_events_for_diffusion(
    name="test",
    sampling_interval=SAMPLING_INTERVAL,
    data_path="",
    t=t,
    x=x,
    events=events,
    max_lag=MAX_LAG,
    n_jobs=1,
)

print("\nHigh-level function results:")
print(f"Event count: {results['event_count']}")
print(f"Diffusion coeffs: {len(results['diffusion_coeffs'])}")
# Truthiness check: the mean is printed only when it is neither missing/None
# nor exactly zero.
if results['statistics']['mean_diffusion']:
    print(f"Mean D: {results['statistics']['mean_diffusion']:.3e}")