diff options
Diffstat (limited to 'tests/test_diffusion_simple.py')
| -rw-r--r-- | tests/test_diffusion_simple.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/tests/test_diffusion_simple.py b/tests/test_diffusion_simple.py new file mode 100644 index 0000000..96e77df --- /dev/null +++ b/tests/test_diffusion_simple.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +""" +Simple test for diffusion processing with clear events. +""" + +import numpy as np +import matplotlib.pyplot as plt + +from transivent import ( + extract_event_waveforms, + calculate_msd_parallel, + calculate_acf, + fit_diffusion_linear, + plot_diffusion_comparison, + process_events_for_diffusion, +) + +# Create simple test data with clear events +np.random.seed(42) +n_points = 10000 +t = np.linspace(0, 0.01, n_points) # 10 ms +x = np.random.normal(0, 0.01, n_points) # Low noise + +# Add clear step events +events_list = [] +for i in range(10): + start = int((i * 0.0005 + 0.001) * n_points / 0.01) # Every 0.5 ms + end = int((i * 0.0005 + 0.0013) * n_points / 0.01) # 300 µs duration + x[start:end] += 1.0 # Clear step + events_list.append([t[start], t[end]]) + +events = np.array(events_list) + +print(f"Created {len(events)} events") +print(f"Event durations: {events[:, 1] - events[:, 0]}") + +# Test event extraction +waveforms = extract_event_waveforms(t, x, events) +print(f"Extracted {len(waveforms)} waveforms") + +# Process first few events for diffusion +diffusion_coeffs = [] +acf_values = [] + +for i, wf in enumerate(waveforms[:5]): + # Add some random walk to make it interesting + wf = wf + np.cumsum(np.random.normal(0, 0.01, len(wf))) + + # MSD calculation + taus, msds, counts = calculate_msd_parallel(wf, dt=1e-6, max_lag=100, n_jobs=1) + D = fit_diffusion_linear(taus, msds, time_limit=3e-5) + if not np.isnan(D): + diffusion_coeffs.append(D) + + # ACF calculation + lags, acf = calculate_acf(wf, dt=1e-6, max_lag=100) + acf_values.append(acf[0]) + +print(f"Calculated {len(diffusion_coeffs)} diffusion coefficients") +print(f"Mean D: {np.mean(diffusion_coeffs):.3e}" if diffusion_coeffs else "No valid diffusion coefficients") + +# Create a simple plot +if len(diffusion_coeffs) > 0: + fig, ax = plt.subplots(figsize=(7, 6)) + ax.scatter(diffusion_coeffs, acf_values, label="Test events") + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Diffusion Coefficient (m²/s)") + ax.set_ylabel("ACF (0-lag)") + ax.legend() + plt.savefig("test_diffusion_plot.png", dpi=150, bbox_inches="tight") + print("Saved test plot to: test_diffusion_plot.png") + plt.show() + +# Test the high-level function +results = process_events_for_diffusion( + name="test", + sampling_interval=1e-6, + data_path="", + t=t, + x=x, + events=events, + max_lag=100, + n_jobs=1, +) + +print(f"\nHigh-level function results:") +print(f"Event count: {results['event_count']}") +print(f"Diffusion coeffs: {len(results['diffusion_coeffs'])}") +if results['statistics']['mean_diffusion']: + print(f"Mean D: {results['statistics']['mean_diffusion']:.3e}")
\ No newline at end of file |
