summaryrefslogtreecommitdiff
path: root/examples/example_custom_data.py
blob: 5625d807d5ba4fb4f35168770627b2698821013a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Example: Using transivent with custom data formats.

This example demonstrates how to use transivent's building blocks
directly with your own time-series data, without requiring any
proprietary file formats.
"""

import numpy as np
from transivent import (
    calculate_initial_background,
    calculate_clean_background,
    detect_initial_events,
    detect_final_events,
    merge_overlapping_events,
    estimate_noise,
    analyze_thresholds,
    create_oscilloscope_plot,
    EventPlotter,
)

# Generate synthetic data (replace with your actual data loading).
# transivent only needs two equal-length 1-D arrays: a time axis `t` in
# seconds and the signal `x` sampled on that axis.
print("Generating synthetic data...")
np.random.seed(42)

# Time array (must be in seconds).
# Built from the sample index so the spacing is exactly 1 / sampling_rate.
# np.linspace(0, duration, n_points) would include `duration` itself as the
# last sample, giving a spacing of duration / (n_points - 1) instead, which
# disagrees with the sample indices computed from sampling_rate below.
duration = 0.05  # 50 ms
sampling_rate = 2_000_000  # 2 MHz
n_points = int(round(duration * sampling_rate))
t = np.arange(n_points) / sampling_rate

# Signal with some spikes
x = np.random.randn(n_points) * 0.1  # Background noise

# Add some synthetic Gaussian-shaped spikes
spike_times = [0.01, 0.023, 0.037, 0.045]  # spike centres, seconds
spike_amplitudes = [-1.5, -2.0, -0.8, -1.2]  # Negative spikes
spike_width = 50  # samples (full window applied around each centre)
spike_sigma = spike_width / 4  # Gaussian standard deviation, in samples

for spike_t, amp in zip(spike_times, spike_amplitudes):
    # round() rather than int(): a product such as spike_t * sampling_rate
    # can land just below an integer in floating point, and int() would
    # then truncate to the previous sample.
    spike_idx = int(round(spike_t * sampling_rate))
    spike_start = max(0, spike_idx - spike_width // 2)
    spike_end = min(n_points, spike_idx + spike_width // 2)
    offsets = np.arange(spike_start, spike_end) - spike_idx
    x[spike_start:spike_end] += amp * np.exp(-0.5 * (offsets / spike_sigma) ** 2)

print(f"Data shape: {t.shape}, Sampling interval: {t[1] - t[0]:.2e} s")

# ----------------------------------------------------------------------
# Analysis parameters -- adjust these to suit your own signal.
# ----------------------------------------------------------------------
# Sample spacing read off the time axis (seconds). It is not used elsewhere
# in this script; kept as a convenience when adapting the example.
sampling_interval = t[1] - t[0]

# Background estimation
smooth_n = 101  # smoothing window length, in samples

# Event detection / acceptance (values are passed to transivent unchanged)
detection_snr = 3.0
min_event_keep_snr = 5.0
min_event_n = 10  # shortest event to keep, in samples
widen_frac = 0.5  # forwarded as widen_frac; exact meaning defined by transivent

# Direction of the events being searched for
signal_polarity = -1  # -1 -> negative-going spikes

# ---- Steps 1-6: background estimation and two-pass event detection ----

# Step 1: first background estimate, taken from the raw signal.
print("\nStep 1: Calculating initial background...")
bg_initial = calculate_initial_background(t, x, smooth_n, filter_type="gaussian")

# Step 2: global noise level, measured against that first background.
print("Step 2: Estimating noise level...")
global_noise = estimate_noise(x, bg_initial)
print(f"Estimated noise: {global_noise:.3f}")

# Keyword options shared by the initial and the final detection pass.
detection_options = {
    "widen_frac": widen_frac,
    "signal_polarity": signal_polarity,
    "min_event_n": min_event_n,
}

# Step 3: first detection pass against the initial background.
print("\nStep 3: Initial event detection...")
events_initial = detect_initial_events(
    t, x, bg_initial, global_noise, detection_snr, min_event_keep_snr,
    **detection_options,
)
print(f"Found {len(events_initial)} initial events")

# Step 4: recompute the background with the initially detected events masked.
print("\nStep 4: Calculating clean background...")
bg_clean = calculate_clean_background(
    t, x, events_initial, smooth_n, bg_initial, filter_type="gaussian"
)

# Steps 5 + 6: second detection pass against the clean background, then
# merge any events whose intervals overlap.
print("\nStep 5: Final event detection...")
events = merge_overlapping_events(
    detect_final_events(
        t, x, bg_clean, global_noise, detection_snr, min_event_keep_snr,
        **detection_options,
    )
)
print(f"Final event count: {len(events)}")

# List every final event with its start/end time (seconds) and duration.
print("\nDetected events:")
for number, (t_start, t_end) in enumerate(events, start=1):
    span_us = 1e6 * (t_end - t_start)
    print(f"  Event {number}: {t_start:.6f}s to {t_end:.6f}s (duration: {span_us:.2f} µs)")

# ---- Step 7: plotting ----
print("\nStep 7: Creating visualizations...")

# Detection / keep thresholds relative to the clean background, for drawing.
detection_threshold, keep_threshold = analyze_thresholds(
    x, bg_clean, global_noise, detection_snr, min_event_keep_snr, signal_polarity
)

# Main plot of the full trace; it receives both background estimates,
# the thresholds and the final events.
plot = create_oscilloscope_plot(
    t,
    x,
    bg_initial,
    bg_clean,
    events,
    detection_threshold,
    keep_threshold,
    name="Custom Data Example",
    detection_snr=detection_snr,
    min_event_keep_snr=min_event_keep_snr,
    max_plot_points=10000,
    envelope_mode_limit=10e-3,
    smooth_n=smooth_n,
    global_noise=global_noise,
)

# Grid of individual event plots (at most 16), only when events were found.
if len(events) > 0:
    event_plotter = EventPlotter(plot, events, bg_clean=bg_clean, global_noise=global_noise)
    event_plotter.plot_events_grid(max_events=16)
    event_plotter.save("custom_data_events.png")

# Write the full-trace figure, then report which files were produced.
plot.save("custom_data_trace.png")
print("\nPlots saved:")
print("  - custom_data_trace.png: Full trace with events")
if len(events) > 0:
    print("  - custom_data_events.png: Individual event plots")

# The figures above are written to PNG files. To also open them in a window,
# uncomment the two lines below (assumes transivent draws with matplotlib --
# confirm for your installed version).
# import matplotlib.pyplot as plt
# plt.show()

print("\nDone! The transivent building blocks work with any time-series data.")