summaryrefslogtreecommitdiff
path: root/tests/test_diffusion_simple.py
blob: 96e77dfea9a97324a541b54bbe8259ec87201978 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
"""
Simple test for diffusion processing with clear events.
"""

import numpy as np
import matplotlib.pyplot as plt

from transivent import (
    extract_event_waveforms,
    calculate_msd_parallel,
    calculate_acf,
    fit_diffusion_linear,
    plot_diffusion_comparison,
    process_events_for_diffusion,
)

# Create simple test data with clear events
np.random.seed(42)
n_points = 10000
duration = 0.01  # 10 ms record
# endpoint=False makes the sample spacing exactly duration / n_points = 1 µs,
# which matches the dt=1e-6 / sampling_interval=1e-6 used in the analysis
# calls of this script. With the endpoint included the spacing would be
# duration / (n_points - 1) instead.
t = np.linspace(0, duration, n_points, endpoint=False)
x = np.random.normal(0, 0.01, n_points)  # Low noise

# Add clear step events: 10 steps of height 1.0, each 300 µs long,
# one every 0.5 ms, the first starting at 1 ms.
samples_per_second = n_points / duration
event_spacing = 0.0005  # 0.5 ms between event starts
event_offset = 0.001  # first event starts at 1 ms
event_length = 0.0003  # 300 µs duration
events_list = []
for i in range(10):
    event_start = i * event_spacing + event_offset
    # round(), not int(): floating-point products such as 2499.9999999 would
    # otherwise be truncated, shifting an event edge by one sample.
    start = round(event_start * samples_per_second)
    end = round((event_start + event_length) * samples_per_second)
    x[start:end] += 1.0  # Clear step
    # The step occupies samples start..end-1; t[end] is the first sample after.
    events_list.append([t[start], t[end]])

events = np.array(events_list)

print(f"Created {len(events)} events")
print(f"Event durations: {events[:, 1] - events[:, 0]}")

# Test event extraction
waveforms = extract_event_waveforms(t, x, events)
print(f"Extracted {len(waveforms)} waveforms")

# Process first few events for diffusion.
# diffusion_coeffs and acf_values are kept index-aligned (one entry each per
# event whose linear fit is not NaN) because they are scatter-plotted against
# each other; appending the ACF unconditionally would let the lists diverge in
# length whenever a fit fails.
diffusion_coeffs = []
acf_values = []

for waveform in waveforms[:5]:
    # Add some random walk to make it interesting
    walk = waveform + np.cumsum(np.random.normal(0, 0.01, len(waveform)))

    # MSD calculation
    taus, msds, _ = calculate_msd_parallel(walk, dt=1e-6, max_lag=100, n_jobs=1)
    D = fit_diffusion_linear(taus, msds, time_limit=3e-5)

    # ACF calculation
    _, acf = calculate_acf(walk, dt=1e-6, max_lag=100)

    # Record both values together, or neither.
    if not np.isnan(D):
        diffusion_coeffs.append(D)
        acf_values.append(acf[0])

print(f"Calculated {len(diffusion_coeffs)} diffusion coefficients")
print(f"Mean D: {np.mean(diffusion_coeffs):.3e}" if diffusion_coeffs else "No valid diffusion coefficients")

# Create a simple plot of each event's diffusion coefficient against its
# zero-lag ACF value, save it as a PNG, then display it.
if diffusion_coeffs:
    plot_path = "test_diffusion_plot.png"
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.scatter(diffusion_coeffs, acf_values, label="Test events")
    ax.set_xlabel("Diffusion Coefficient (m²/s)")
    ax.set_ylabel("ACF (0-lag)")
    # Both axes use a logarithmic scale.
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.legend()
    plt.savefig(plot_path, dpi=150, bbox_inches="tight")
    print(f"Saved test plot to: {plot_path}")
    plt.show()

# Test the high-level function on the full synthetic record.
results = process_events_for_diffusion(
    name="test",
    sampling_interval=1e-6,
    data_path="",
    t=t,
    x=x,
    events=events,
    max_lag=100,
    n_jobs=1,
)

print("\nHigh-level function results:")
print(f"Event count: {results['event_count']}")
print(f"Diffusion coeffs: {len(results['diffusion_coeffs'])}")
# Compare against None explicitly: a plain truthiness test would also skip a
# legitimate mean of 0.0. NOTE(review): assumes the library reports a missing
# mean as None — confirm against process_events_for_diffusion.
mean_diffusion = results["statistics"]["mean_diffusion"]
if mean_diffusion is not None:
    print(f"Mean D: {mean_diffusion:.3e}")