summaryrefslogtreecommitdiff
path: root/tests/test_diffusion_simple.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_diffusion_simple.py')
-rw-r--r--tests/test_diffusion_simple.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/tests/test_diffusion_simple.py b/tests/test_diffusion_simple.py
new file mode 100644
index 0000000..96e77df
--- /dev/null
+++ b/tests/test_diffusion_simple.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+"""
+Simple test for diffusion processing with clear events.
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from transivent import (
+ extract_event_waveforms,
+ calculate_msd_parallel,
+ calculate_acf,
+ fit_diffusion_linear,
+ plot_diffusion_comparison,
+ process_events_for_diffusion,
+)
+
+# Create simple test data with clear events
+np.random.seed(42)
+n_points = 10000
+t = np.linspace(0, 0.01, n_points) # 10 ms
+x = np.random.normal(0, 0.01, n_points) # Low noise
+
+# Add clear step events
+events_list = []
+for i in range(10):
+ start = int((i * 0.0005 + 0.001) * n_points / 0.01) # Every 0.5 ms
+ end = int((i * 0.0005 + 0.0013) * n_points / 0.01) # 300 µs duration
+ x[start:end] += 1.0 # Clear step
+ events_list.append([t[start], t[end]])
+
+events = np.array(events_list)
+
+print(f"Created {len(events)} events")
+print(f"Event durations: {events[:, 1] - events[:, 0]}")
+
+# Test event extraction
+waveforms = extract_event_waveforms(t, x, events)
+print(f"Extracted {len(waveforms)} waveforms")
+
+# Process first few events for diffusion
+diffusion_coeffs = []
+acf_values = []
+
+for i, wf in enumerate(waveforms[:5]):
+ # Add some random walk to make it interesting
+ wf = wf + np.cumsum(np.random.normal(0, 0.01, len(wf)))
+
+ # MSD calculation
+ taus, msds, counts = calculate_msd_parallel(wf, dt=1e-6, max_lag=100, n_jobs=1)
+ D = fit_diffusion_linear(taus, msds, time_limit=3e-5)
+ if not np.isnan(D):
+ diffusion_coeffs.append(D)
+
+ # ACF calculation
+ lags, acf = calculate_acf(wf, dt=1e-6, max_lag=100)
+ acf_values.append(acf[0])
+
+print(f"Calculated {len(diffusion_coeffs)} diffusion coefficients")
+print(f"Mean D: {np.mean(diffusion_coeffs):.3e}" if diffusion_coeffs else "No valid diffusion coefficients")
+
+# Create a simple plot
+if len(diffusion_coeffs) > 0:
+ fig, ax = plt.subplots(figsize=(7, 6))
+ ax.scatter(diffusion_coeffs, acf_values, label="Test events")
+ ax.set_xscale("log")
+ ax.set_yscale("log")
+ ax.set_xlabel("Diffusion Coefficient (m²/s)")
+ ax.set_ylabel("ACF (0-lag)")
+ ax.legend()
+ plt.savefig("test_diffusion_plot.png", dpi=150, bbox_inches="tight")
+ print("Saved test plot to: test_diffusion_plot.png")
+ plt.show()
+
+# Test the high-level function
+results = process_events_for_diffusion(
+ name="test",
+ sampling_interval=1e-6,
+ data_path="",
+ t=t,
+ x=x,
+ events=events,
+ max_lag=100,
+ n_jobs=1,
+)
+
+print(f"\nHigh-level function results:")
+print(f"Event count: {results['event_count']}")
+print(f"Diffusion coeffs: {len(results['diffusion_coeffs'])}")
+if results['statistics']['mean_diffusion']:
+ print(f"Mean D: {results['statistics']['mean_diffusion']:.3e}") \ No newline at end of file