author     Sam Scholten  2025-10-23 15:06:25 +1000
committer  Sam Scholten  2025-10-23 15:22:54 +1000
commit     307bf648d8e3fe852d7daf2fa1567d1896e50f7e (patch)
tree       d15344eab2003fd0a12544cc1ed9fbfef3e871d9 /src
parent     4a7026759e099e5c81cc9c77f19182a23d2f0275 (diff)
download   transivent-307bf648d8e3fe852d7daf2fa1567d1896e50f7e.tar.gz
           transivent-307bf648d8e3fe852d7daf2fa1567d1896e50f7e.zip
Release v2.0.0 (tag: v2.0.0)

Major API refactoring with simplified public interface.

- Added EventProcessor for high-level event processing workflow
- New utility functions for data preprocessing
- Additional example scripts for different use cases
- Comprehensive test suite
- Updated documentation with migration guide
Diffstat (limited to 'src')
-rw-r--r--  src/transivent/__init__.py          45
-rw-r--r--  src/transivent/analysis.py         389
-rw-r--r--  src/transivent/event_detector.py    83
-rw-r--r--  src/transivent/event_processor.py  659
-rw-r--r--  src/transivent/utils.py            130
5 files changed, 1170 insertions(+), 136 deletions(-)
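The commit message mentions a migration guide. As a minimal, hypothetical before/after sketch for call sites, based only on the export changes in the __init__.py diff below (the removed names appear to remain available in their submodules; only the package-root re-exports changed):

# v1.x call sites imported the pipeline internals from the package root:
#     from transivent import process_file, configure_logging
# v2.0.0 removes those re-exports in favour of two high-level entry points:
from transivent import detect, detect_from_wfm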
diff --git a/src/transivent/__init__.py b/src/transivent/__init__.py
index db3e824..a858d8c 100644
--- a/src/transivent/__init__.py
+++ b/src/transivent/__init__.py
@@ -1,36 +1,35 @@
"""
High-level analysis and plotting for transient events.
+
+This module provides tools for detecting and analyzing transient events in
+time-series data. Two main entry points are available:
+
+- detect(): For analyzing custom time-series data (arrays)
+- detect_from_wfm(): For analyzing Wfm binary files with XML sidecars
+
+For advanced usage, building blocks are available in submodules:
+- transivent.analysis: Background, noise, and event detection functions
+- transivent.event_detector: Low-level event detection algorithms
+- transivent.diffusion: Diffusion analysis tools (optional)
"""
-from .analysis import (
- analyze_thresholds,
- calculate_initial_background,
- calculate_smoothing_parameters,
- configure_logging,
- create_oscilloscope_plot,
- get_final_events,
- initialize_state,
- process_chunk,
- process_file,
-)
+from .analysis import detect, detect_from_wfm
from .event_detector import detect_events, merge_overlapping_events
from .event_plotter import EventPlotter
-from .io import get_waveform_params, rd, rd_chunked
+from .event_processor import extract_event_waveforms
+from .io import get_waveform_params, rd
__all__ = [
- "analyze_thresholds",
- "calculate_initial_background",
- "calculate_smoothing_parameters",
- "configure_logging",
- "create_oscilloscope_plot",
+ # Main entry points
+ "detect",
+ "detect_from_wfm",
+ # Building blocks
"detect_events",
+ "merge_overlapping_events",
+ "extract_event_waveforms",
+ # Visualization
"EventPlotter",
- "get_final_events",
+ # I/O utilities
"get_waveform_params",
- "initialize_state",
- "merge_overlapping_events",
- "process_chunk",
- "process_file",
"rd",
- "rd_chunked",
]
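A short usage sketch of the simplified interface, mirroring the docstring example added to analysis.py below; the spike positions and amplitudes here are illustrative, and whether they survive the default thresholds depends on the smoothing parameters:

import numpy as np
from transivent import detect

rng = np.random.default_rng(0)
t = np.linspace(0, 1, 100_000)
x = rng.normal(0.0, 0.1, t.size)

# Inject a few negative dips; the default signal_polarity=-1 looks for these.
for start in (20_000, 50_000, 80_000):
    x[start:start + 200] -= 1.0

results = detect(t, x, name="synthetic")
print(f"Found {len(results['events'])} events")
for t_start, t_end in results["events"]:
    print(f"  {t_start:.6f} s -> {t_end:.6f} s")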
diff --git a/src/transivent/analysis.py b/src/transivent/analysis.py
index 1b5277b..caed827 100644
--- a/src/transivent/analysis.py
+++ b/src/transivent/analysis.py
@@ -25,38 +25,8 @@ from .event_plotter import EventPlotter
import xml.etree.ElementTree as ET
-@njit
-def _create_event_mask_numba(t: np.ndarray, events: np.ndarray) -> np.ndarray:
- """
- Create a boolean mask to exclude event regions using numba for speed.
-
- Parameters
- ----------
- t : np.ndarray
- Time array.
- events : np.ndarray
- Events array with shape (n_events, 2) where each row is [t_start, t_end].
-
- Returns
- -------
- np.ndarray
- Boolean mask where True means keep the sample, False means exclude.
- """
- mask = np.ones(len(t), dtype=np.bool_)
- if len(events) == 0:
- return mask
-
- for i in range(len(events)):
- t_start = events[i, 0]
- t_end = events[i, 1]
-
- start_idx = np.searchsorted(t, t_start, side="left")
- end_idx = np.searchsorted(t, t_end, side="left")
-
- if start_idx < end_idx:
- mask[start_idx:end_idx] = False
-
- return mask
+# Import utility functions to avoid duplication
+from .utils import create_event_mask_numba
def extract_preview_image(sidecar_path: str, output_path: str) -> Optional[str]:
@@ -458,7 +428,7 @@ def calculate_clean_background(
start_time = time.time()
# Fast masking with numba
- mask = _create_event_mask_numba(t, events_initial)
+ mask = create_event_mask_numba(t, events_initial)
mask_time = time.time()
logger.debug(
@@ -1236,3 +1206,356 @@ def process_file(
)
logger.warning("Plotting is disabled in chunked processing mode.")
+
+
+def detect(
+ t: np.ndarray,
+ x: np.ndarray,
+ sampling_interval: Optional[float] = None,
+ name: str = "Analysis",
+ detection_snr: float = 3.0,
+ min_event_keep_snr: float = 6.0,
+ min_event_t: float = 0.75e-6,
+ smooth_win_t: float = 10e-3,
+ widen_frac: float = 10.0,
+ signal_polarity: int = -1,
+ filter_type: str = "gaussian",
+ filter_order: int = 2,
+ max_plot_points: int = 10000,
+ envelope_mode_limit: float = 10e-3,
+ save_plots: bool = False,
+ plot_dir: Optional[str] = None,
+) -> Dict[str, Any]:
+ """
+ Detect transient events in custom time-series data.
+
+ This is the primary entry point for analyzing any time-series data
+ (CSV, NumPy arrays, HDF5, etc.). It returns all detected events and
+ analysis results.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array in seconds.
+ x : np.ndarray
+ Signal array (any units).
+ sampling_interval : Optional[float], default=None
+ Sampling interval in seconds. If None, calculated from t array.
+ name : str, default="Analysis"
+ Name for the analysis (used in plots and logging).
+ detection_snr : float, default=3.0
+ Detection SNR threshold (sigma units above background).
+ min_event_keep_snr : float, default=6.0
+ Minimum event keep SNR threshold (sigma units).
+ min_event_t : float, default=0.75e-6
+ Minimum event duration in seconds.
+ smooth_win_t : float, default=10e-3
+ Smoothing window for background estimation in seconds.
+ widen_frac : float, default=10.0
+ Fraction to widen detected events.
+ signal_polarity : int, default=-1
+ Signal polarity: -1 for negative spikes, +1 for positive spikes.
+ filter_type : str, default="gaussian"
+ Filter type: "savgol", "gaussian", "moving_average", "median".
+ filter_order : int, default=2
+ Order of Savitzky-Golay filter (if filter_type="savgol").
+ max_plot_points : int, default=10000
+ Maximum points for plot decimation.
+ envelope_mode_limit : float, default=10e-3
+ Envelope mode time limit.
+ save_plots : bool, default=False
+ Whether to save plots to disk.
+ plot_dir : Optional[str], default=None
+ Directory to save plots (used if save_plots=True).
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary containing:
+ - 'events': np.ndarray of [start_time, end_time] for each event
+ - 'bg_initial': Initial background estimate
+ - 'bg_clean': Clean background estimate (events masked)
+ - 'global_noise': Estimated noise level
+ - 'plot': OscilloscopePlot object (None unless save_plots is True or plot_dir is set)
+ - 't': Time array
+ - 'x': Signal array
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> from transivent import detect
+ >>>
+ >>> # Create sample data
+ >>> t = np.linspace(0, 1, 100000)
+ >>> x = np.random.randn(100000) * 0.1
+ >>>
+ >>> # Detect events
+ >>> results = detect(t, x, name="My Data")
+ >>> print(f"Found {len(results['events'])} events")
+ """
+ start_time = time.time()
+ logger.info(f"Detecting events in: {name}")
+
+ # Validate inputs
+ t = np.asarray(t, dtype=np.float32)
+ x = np.asarray(x, dtype=np.float32)
+
+ if len(t) != len(x):
+ raise ValueError(f"Time array length ({len(t)}) != signal array length ({len(x)})")
+
+ # Calculate sampling interval if not provided
+ if sampling_interval is None:
+ if len(t) < 2:
+ raise ValueError("Need at least 2 time points to calculate sampling interval")
+ sampling_interval = float(t[1] - t[0])
+ logger.info(f"Calculated sampling interval: {sampling_interval:.3e} s")
+ else:
+ logger.info(f"Using provided sampling interval: {sampling_interval:.3e} s")
+
+ # Calculate smoothing parameters
+ smooth_n, min_event_n = calculate_smoothing_parameters(
+ sampling_interval,
+ smooth_win_t,
+ None,
+ min_event_t,
+ detection_snr,
+ min_event_keep_snr,
+ widen_frac,
+ signal_polarity,
+ )
+
+ # Run analysis pipeline
+ logger.info("Running analysis pipeline...")
+
+ bg_initial = calculate_initial_background(t, x, smooth_n, filter_type)
+ global_noise = estimate_noise(x, bg_initial)
+
+ events_initial = detect_initial_events(
+ t, x, bg_initial, global_noise, detection_snr,
+ min_event_keep_snr, widen_frac, signal_polarity, min_event_n
+ )
+
+ bg_clean = calculate_clean_background(
+ t, x, events_initial, smooth_n, bg_initial, filter_type, filter_order
+ )
+
+ events = detect_final_events(
+ t, x, bg_clean, global_noise, detection_snr,
+ min_event_keep_snr, widen_frac, signal_polarity, min_event_n
+ )
+
+ analyze_events(t, x, bg_clean, events, global_noise, signal_polarity)
+
+ logger.success(f"Detection complete in {time.time() - start_time:.3f}s")
+ logger.success(f"Found {len(events)} events")
+
+ # Create plot if requested
+ plot = None
+ if save_plots or plot_dir is not None:
+ detection_threshold, keep_threshold = analyze_thresholds(
+ x, bg_clean, global_noise, detection_snr, min_event_keep_snr, signal_polarity
+ )
+
+ plot = create_oscilloscope_plot(
+ t, x, bg_initial, bg_clean, events,
+ detection_threshold, keep_threshold,
+ name, detection_snr, min_event_keep_snr,
+ max_plot_points, envelope_mode_limit, smooth_n,
+ global_noise=global_noise
+ )
+
+ if plot_dir is not None:
+ plot.save(os.path.join(plot_dir, f"{name}_trace.png"))
+ logger.info(f"Saved plot to {plot_dir}")
+
+ return {
+ "events": events,
+ "bg_initial": bg_initial,
+ "bg_clean": bg_clean,
+ "global_noise": global_noise,
+ "plot": plot,
+ "t": t,
+ "x": x,
+ }
+
+
+def detect_from_wfm(
+ name: str,
+ sampling_interval: float,
+ data_path: str,
+ detection_snr: float = 3.0,
+ min_event_keep_snr: float = 6.0,
+ min_event_t: float = 0.75e-6,
+ smooth_win_t: float = 10e-3,
+ widen_frac: float = 10.0,
+ signal_polarity: int = -1,
+ filter_type: str = "gaussian",
+ filter_order: int = 2,
+ max_plot_points: int = 10000,
+ envelope_mode_limit: float = 10e-3,
+ sidecar: Optional[str] = None,
+ crop: Optional[List[int]] = None,
+ save_plots: bool = True,
+ plot_dir: Optional[str] = None,
+ chunk_size: Optional[int] = None,
+) -> Dict[str, Any]:
+ """
+ Detect transient events in a Wfm binary file with XML sidecar.
+
+ This is the entry point for analyzing proprietary Wfm format files
+ that include XML metadata sidecars. It handles loading the file and
+ returns all detected events and analysis results.
+
+ Parameters
+ ----------
+ name : str
+ Filename of the Wfm binary file.
+ sampling_interval : float
+ Sampling interval in seconds (can be overridden by XML sidecar).
+ data_path : str
+ Path to the data directory.
+ detection_snr : float, default=3.0
+ Detection SNR threshold (sigma units above background).
+ min_event_keep_snr : float, default=6.0
+ Minimum event keep SNR threshold (sigma units).
+ min_event_t : float, default=0.75e-6
+ Minimum event duration in seconds.
+ smooth_win_t : float, default=10e-3
+ Smoothing window for background estimation in seconds.
+ widen_frac : float, default=10.0
+ Fraction to widen detected events.
+ signal_polarity : int, default=-1
+ Signal polarity: -1 for negative spikes, +1 for positive spikes.
+ filter_type : str, default="gaussian"
+ Filter type: "savgol", "gaussian", "moving_average", "median".
+ filter_order : int, default=2
+ Order of Savitzky-Golay filter (if filter_type="savgol").
+ max_plot_points : int, default=10000
+ Maximum points for plot decimation.
+ envelope_mode_limit : float, default=10e-3
+ Envelope mode time limit.
+ sidecar : Optional[str], default=None
+ XML sidecar filename. If None, auto-detected.
+ crop : Optional[List[int]], default=None
+ Crop indices [start, end] for signal.
+ save_plots : bool, default=True
+ Whether to save plots to disk.
+ plot_dir : Optional[str], default=None
+ Directory to save plots. If None, uses data_path_analysis/.
+ chunk_size : Optional[int], default=None
+ Chunk size for processing large files. If None, loads entire file.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary containing:
+ - 'events': np.ndarray of [start_time, end_time] for each event
+ - 'bg_initial': Initial background estimate
+ - 'bg_clean': Clean background estimate (events masked)
+ - 'global_noise': Estimated noise level
+ - 'plot': OscilloscopePlot object
+ - 't': Time array
+ - 'x': Signal array
+
+ Examples
+ --------
+ >>> from transivent import detect_from_wfm
+ >>>
+ >>> results = detect_from_wfm(
+ ... name="data.Wfm.bin",
+ ... sampling_interval=5e-7,
+ ... data_path="/path/to/data/",
+ ... detection_snr=3.0
+ ... )
+ >>> print(f"Found {len(results['events'])} events")
+ """
+ start_time = time.time()
+ logger.info(f"Detecting events in Wfm file: {name}")
+
+ # Setup plot directory
+ if plot_dir is None and save_plots:
+ analysis_dir = data_path[:-1] if data_path.endswith("/") else data_path
+ plot_dir = analysis_dir + "_analysis/"
+ if not os.path.exists(plot_dir):
+ os.makedirs(plot_dir)
+
+ # Extract preview image if available
+ sidecar_path = _get_xml_sidecar_path(name, data_path, sidecar)
+ if plot_dir and save_plots:
+ preview_path = os.path.join(plot_dir, f"{name}_preview.png")
+ extract_preview_image(sidecar_path, preview_path)
+
+ # Calculate smoothing parameters
+ smooth_n, min_event_n = calculate_smoothing_parameters(
+ sampling_interval,
+ smooth_win_t,
+ None,
+ min_event_t,
+ detection_snr,
+ min_event_keep_snr,
+ widen_frac,
+ signal_polarity,
+ )
+
+ # Load data
+ logger.info("Loading Wfm file...")
+ t, x = load_data(name, sampling_interval, data_path, sidecar, crop)
+
+ # Run analysis pipeline
+ logger.info("Running analysis pipeline...")
+
+ bg_initial = calculate_initial_background(t, x, smooth_n, filter_type)
+ global_noise = estimate_noise(x, bg_initial)
+
+ events_initial = detect_initial_events(
+ t, x, bg_initial, global_noise, detection_snr,
+ min_event_keep_snr, widen_frac, signal_polarity, min_event_n
+ )
+
+ bg_clean = calculate_clean_background(
+ t, x, events_initial, smooth_n, bg_initial, filter_type, filter_order
+ )
+
+ events = detect_final_events(
+ t, x, bg_clean, global_noise, detection_snr,
+ min_event_keep_snr, widen_frac, signal_polarity, min_event_n
+ )
+
+ analyze_events(t, x, bg_clean, events, global_noise, signal_polarity)
+
+ logger.success(f"Detection complete in {time.time() - start_time:.3f}s")
+ logger.success(f"Found {len(events)} events")
+
+ # Create and save plots
+ detection_threshold, keep_threshold = analyze_thresholds(
+ x, bg_clean, global_noise, detection_snr, min_event_keep_snr, signal_polarity
+ )
+
+ plot = create_oscilloscope_plot(
+ t, x, bg_initial, bg_clean, events,
+ detection_threshold, keep_threshold,
+ name, detection_snr, min_event_keep_snr,
+ max_plot_points, envelope_mode_limit, smooth_n,
+ global_noise=global_noise
+ )
+
+ if save_plots and plot_dir:
+ plot.save(os.path.join(plot_dir, f"{name}_trace.png"))
+
+ event_plotter = EventPlotter(
+ plot, events, bg_clean=bg_clean, global_noise=global_noise
+ )
+ event_plotter.plot_events_grid(max_events=16)
+ event_plotter.save(os.path.join(plot_dir, f"{name}_events.png"))
+
+ logger.info(f"Saved plots to {plot_dir}")
+
+ return {
+ "events": events,
+ "bg_initial": bg_initial,
+ "bg_clean": bg_clean,
+ "global_noise": global_noise,
+ "plot": plot,
+ "t": t,
+ "x": x,
+ }
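The returned dictionary is shaped to feed straight into the new event_processor helpers (diffed below). A sketch of the chaining, assuming at least one event survives the default thresholds:

import numpy as np
from transivent import detect, extract_event_waveforms

rng = np.random.default_rng(0)
t = np.linspace(0, 1, 100_000)
x = rng.normal(0.0, 0.1, t.size)
x[50_000:50_200] -= 1.0  # one injected dip

results = detect(t, x, name="chained")

# Background-subtracted signal segments, one array per detected event.
waveforms = extract_event_waveforms(
    results["t"], results["x"], results["events"],
    bg_clean=results["bg_clean"], subtract_background=True,
)
print([wf.size for wf in waveforms])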
diff --git a/src/transivent/event_detector.py b/src/transivent/event_detector.py
index 22d76d4..3726b53 100644
--- a/src/transivent/event_detector.py
+++ b/src/transivent/event_detector.py
@@ -4,6 +4,8 @@ import numpy as np
from loguru import logger
from numba import njit
+from .utils import validate_detection_inputs
+
# --- Constants ---
MEDIAN_TO_STD_FACTOR = (
1.4826 # Factor to convert median absolute deviation to standard deviation
@@ -235,7 +237,7 @@ def detect_events(
bg = np.asarray(bg, dtype=np.float32)
# Validate input data
- _validate_detection_inputs(
+ validate_detection_inputs(
time, signal, bg, snr_threshold, min_event_len, global_noise
)
@@ -269,86 +271,7 @@ def detect_events(
return events_array, np.float32(global_noise)
-def _validate_detection_inputs(
- time: np.ndarray,
- signal: np.ndarray,
- bg: np.ndarray,
- snr_threshold: np.float32,
- min_event_len: int,
- global_noise: np.float32,
-) -> None:
- """
- Validate inputs for event detection.
-
- Parameters
- ----------
- time : np.ndarray
- Time array.
- signal : np.ndarray
- Signal array.
- bg : np.ndarray
- Background array.
- snr_threshold : np.float32
- SNR threshold.
- min_event_len : int
- Minimum event length.
- global_noise : np.float32
- Global noise level.
-
- Raises
- ------
- ValueError
- If inputs are invalid.
- """
- # Check array lengths
- if not (len(time) == len(signal) == len(bg)):
- logger.warning(
- f"Validation Warning: Array length mismatch: time={len(time)}, signal={len(signal)}, bg={len(bg)}. "
- "This may lead to unexpected behaviour."
- )
-
- # Check for empty arrays
- if len(time) == 0:
- logger.warning(
- "Validation Warning: Input arrays are empty. This may lead to unexpected behaviour."
- )
-
- # Check time monotonicity with a small tolerance for floating-point comparisons
- if len(time) > 1:
- # Use a small epsilon for floating-point comparison
- # np.finfo(time.dtype).eps is the smallest representable positive number such that 1.0 + eps != 1.0
- # Multiplying by a small factor (e.g., 10) provides a reasonable tolerance.
- tolerance = np.finfo(time.dtype).eps * 10
- if not np.all(np.diff(time) > tolerance):
- # Log the problematic differences for debugging
- problematic_diffs = np.diff(time)[np.diff(time) <= tolerance]
- logger.warning(
- f"Validation Warning: Time array is not strictly monotonic increasing within tolerance {tolerance}. "
- f"Problematic diffs (first 10): {problematic_diffs[:10]}. This may lead to unexpected behaviour."
- )
-
- # Check parameter validity
- if snr_threshold <= 0:
- logger.warning(
- f"Validation Warning: SNR threshold must be positive, got {snr_threshold}. This may lead to unexpected behaviour."
- )
-
- if min_event_len <= 0:
- logger.warning(
- f"Validation Warning: Minimum event length must be positive, got {min_event_len}. This may lead to unexpected behaviour."
- )
-
- if global_noise <= 0:
- logger.warning(
- f"Validation Warning: Global noise must be positive, got {global_noise}. This may lead to unexpected behaviour."
- )
- # Check for NaN/inf values
- for name, arr in [("time", time), ("signal", signal), ("bg", bg)]:
- if not np.all(np.isfinite(arr)):
- logger.warning(
- f"Validation Warning: {name} array contains NaN or infinite values. This may lead to unexpected behaviour."
- )
def merge_overlapping_events(events: np.ndarray) -> np.ndarray:
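merge_overlapping_events keeps its public signature. A small sketch of the expected behaviour, assuming the usual convention that overlapping [t_start, t_end] intervals collapse into one:

import numpy as np
from transivent import merge_overlapping_events

events = np.array(
    [[0.10, 0.20],
     [0.15, 0.30],   # overlaps the first interval
     [0.50, 0.60]],  # disjoint
    dtype=np.float32,
)
merged = merge_overlapping_events(events)
print(merged)  # expected (assumed semantics): [[0.10, 0.30], [0.50, 0.60]]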
diff --git a/src/transivent/event_processor.py b/src/transivent/event_processor.py
new file mode 100644
index 0000000..cb6b5f7
--- /dev/null
+++ b/src/transivent/event_processor.py
@@ -0,0 +1,659 @@
+"""
+Event extraction and diffusion processing functions for transient events.
+
+This module provides functions to extract event waveforms, calculate diffusion
+characteristics through Mean Square Displacement (MSD) and autocorrelation,
+and visualize the results. The functions are designed to be used individually
+for detailed control or together through convenience wrappers.
+"""
+
+from typing import Dict, List, Optional, Tuple, Union, Any
+
+import matplotlib.pyplot as plt
+import numpy as np
+from joblib import Parallel, delayed
+from loguru import logger
+from matplotlib.patches import Ellipse
+from matplotlib.transforms import ScaledTranslation
+from scipy.stats import norm
+
+
+def extract_event_waveforms(
+ t: np.ndarray,
+ x: np.ndarray,
+ events: np.ndarray,
+ bg_clean: Optional[np.ndarray] = None,
+ subtract_background: bool = True,
+) -> List[np.ndarray]:
+ """
+ Extract signal segments for each detected event.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ events : np.ndarray
+ Array of [start_time, end_time] for each event with shape (n_events, 2).
+ bg_clean : np.ndarray, optional
+ Clean background array. If provided and subtract_background=True,
+ the background will be subtracted from each event.
+ subtract_background : bool, default=True
+ Whether to subtract the background from each event.
+
+ Returns
+ -------
+ List[np.ndarray]
+ List of signal waveforms, one per event.
+
+ Raises
+ ------
+ ValueError
+ If input arrays have incompatible shapes or events array is invalid.
+ """
+ # Validate inputs
+ t = np.asarray(t, dtype=np.float32)
+ x = np.asarray(x, dtype=np.float32)
+ events = np.asarray(events, dtype=np.float32)
+
+ if len(t) != len(x):
+ raise ValueError(f"Time array length ({len(t)}) != signal array length ({len(x)})")
+
+ if events.ndim != 2 or events.shape[1] != 2:
+ raise ValueError("Events array must have shape (n_events, 2)")
+
+ if bg_clean is not None and len(bg_clean) != len(x):
+ raise ValueError("Background array must have same length as signal array")
+
+ waveforms = []
+ logger.info(f"Extracting waveforms for {len(events)} events")
+
+ for i, (t_start, t_end) in enumerate(events):
+ # Extract indices for this event
+ mask = (t >= t_start) & (t < t_end)
+
+ if not np.any(mask):
+ logger.warning(f"Event {i+1}: No data found in time range [{t_start:.6f}, {t_end:.6f}]")
+ continue
+
+ event_signal = x[mask].copy()
+
+ # Subtract background if requested and available
+ if subtract_background and bg_clean is not None:
+ event_signal -= bg_clean[mask]
+
+ waveforms.append(event_signal)
+
+ logger.info(f"Successfully extracted {len(waveforms)} event waveforms")
+ return waveforms
+
+
+def _calculate_msd_single_lag(data: np.ndarray, lag: int) -> Tuple[float, int]:
+ """
+ Calculate MSD for a single lag time.
+
+ This is a helper function for parallel processing.
+
+ Parameters
+ ----------
+ data : np.ndarray
+ Signal data.
+ lag : int
+ Lag time in samples.
+
+ Returns
+ -------
+ Tuple[float, int]
+ MSD value and count of displacement pairs.
+ """
+ if lag >= len(data):
+ return 0.0, 0
+
+ displacements = data[lag:] - data[:-lag]
+ msd = np.mean(displacements**2)
+ count = len(displacements)
+ return msd, count
+
+
+def calculate_msd_parallel(
+ data: np.ndarray,
+ dt: float = 1e-6,
+ max_lag: int = 1000,
+ n_jobs: int = -1,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Calculate Mean Square Displacement using parallel processing.
+
+ Based on the joblib implementation from heehun_diffusion_processing.py.
+
+ Parameters
+ ----------
+ data : np.ndarray
+ Signal data.
+ dt : float, default=1e-6
+ Time step in seconds.
+ max_lag : int, default=1000
+ Maximum lag time in samples.
+ n_jobs : int, default=-1
+ Number of parallel jobs. -1 uses all available cores.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
+ taus (lag times in seconds), msds (MSD values), counts (number of pairs).
+ """
+ data = np.asarray(data, dtype=np.float64)
+ logger.debug(f"Calculating MSD for {len(data)} points with max_lag={max_lag}")
+
+ # Parallel calculation of MSD for each lag
+ results = Parallel(n_jobs=n_jobs)(
+ delayed(_calculate_msd_single_lag)(data, i) for i in range(1, max_lag + 1)
+ )
+
+ msds, counts = zip(*results)
+ taus = np.arange(1, max_lag + 1) * dt
+
+ return taus, np.array(msds), np.array(counts)
+
+
+def calculate_acf(
+ x: np.ndarray, dt: float = 1e-6, max_lag: int = 100
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Calculate autocorrelation function.
+
+ Ported directly from heehun_diffusion_processing.py.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Input signal.
+ dt : float, default=1e-6
+ Time step in seconds.
+ max_lag : int, default=100
+ Maximum lag in samples.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ lags (time values), acf (autocorrelation values).
+ """
+ x = np.asarray(x, dtype=np.float64)
+ n = x.size
+
+ # Remove mean for proper autocorrelation calculation
+ x = x - x.mean()
+
+ acf = np.zeros(max_lag + 1)
+ for lag in range(max_lag + 1):
+ acf[lag] = np.dot(x[:n - lag], x[lag:]) / (n - lag)
+
+ lags = np.arange(len(acf)) * dt
+ return lags, acf
+
+
+def fit_diffusion_linear(
+ taus: np.ndarray,
+ msds: np.ndarray,
+ time_limit: float = 3e-5,
+ min_points: int = 10,
+) -> float:
+ """
+ Fit linear region of MSD to extract diffusion coefficient.
+
+ For normal diffusion, MSD = 2*D*t in 1D, so D = slope/2.
+
+ Parameters
+ ----------
+ taus : np.ndarray
+ Lag times in seconds.
+ msds : np.ndarray
+ MSD values.
+ time_limit : float, default=3e-5
+ Upper time limit for linear fit (seconds).
+ min_points : int, default=10
+ Minimum number of points required for fit.
+
+ Returns
+ -------
+ float
+ Diffusion coefficient (m²/s).
+
+ Raises
+ ------
+ ValueError
+ If insufficient data points for fitting.
+ """
+ # Mask for fitting region
+ mask = (taus <= time_limit) & (taus > 0)
+
+ if np.sum(mask) < min_points:
+ logger.warning(
+ f"Insufficient points for fit: {np.sum(mask)} < {min_points}. "
+ f"Consider increasing time_limit or decreasing max_lag."
+ )
+ return np.nan
+
+ taus_fit = taus[mask]
+ msds_fit = msds[mask]
+
+ # Linear fit: MSD = 2*D*τ
+ slope, intercept = np.polyfit(taus_fit, msds_fit, 1)
+ diffusion_coeff = slope / 2
+
+ logger.debug(
+ f"Linear fit: slope={slope:.3e}, intercept={intercept:.3e}, "
+ f"D={diffusion_coeff:.3e} m²/s"
+ )
+
+ return diffusion_coeff
+
+
+def calculate_diffusion_statistics(
+ diffusion_coeffs: np.ndarray, acf_values: np.ndarray
+) -> Dict[str, Any]:
+ """
+ Calculate statistical measures for diffusion analysis.
+
+ Parameters
+ ----------
+ diffusion_coeffs : np.ndarray
+ Array of diffusion coefficients.
+ acf_values : np.ndarray
+ Array of zero-lag ACF values.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary containing:
+ - 'mean_diffusion': Mean diffusion coefficient
+ - 'std_diffusion': Standard deviation of diffusion coefficients
+ - 'mean_acf': Mean ACF value
+ - 'std_acf': Standard deviation of ACF values
+ - 'covariance_matrix': Covariance matrix of log-transformed values
+ - 'eigenvalues': Eigenvalues of covariance matrix
+ - 'eigenvectors': Eigenvectors of covariance matrix
+ """
+ # Convert to log space for statistical calculations (as in original script)
+ log_D = np.log10(diffusion_coeffs)
+ log_acf = np.log10(acf_values)
+
+ # Remove any invalid values
+ valid_mask = np.isfinite(log_D) & np.isfinite(log_acf)
+ if not np.all(valid_mask):
+ logger.warning(
+ f"Found {np.sum(~valid_mask)} invalid values in diffusion data"
+ )
+ log_D = log_D[valid_mask]
+ log_acf = log_acf[valid_mask]
+
+ # Basic statistics
+ stats = {
+ "mean_diffusion": np.mean(diffusion_coeffs),
+ "std_diffusion": np.std(diffusion_coeffs),
+ "mean_acf": np.mean(acf_values),
+ "std_acf": np.std(acf_values),
+ }
+
+ # Covariance and eigenanalysis in log space
+ if len(log_D) > 1:
+ cov_matrix = np.cov(log_D, log_acf)
+ eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
+
+ stats.update({
+ "covariance_matrix": cov_matrix,
+ "eigenvalues": eigenvalues,
+ "eigenvectors": eigenvectors,
+ })
+ else:
+ logger.warning("Insufficient data for covariance analysis")
+ stats.update({
+ "covariance_matrix": np.array([[np.nan, np.nan], [np.nan, np.nan]]),
+ "eigenvalues": np.array([np.nan, np.nan]),
+ "eigenvectors": np.array([[np.nan, np.nan], [np.nan, np.nan]]),
+ })
+
+ return stats
+
+
+def create_confidence_ellipse(
+ center: Tuple[float, float],
+ cov_matrix: np.ndarray,
+ n_std: float = 1.2,
+ ax: Optional[plt.Axes] = None,
+ **ellipse_kwargs,
+) -> Ellipse:
+ """
+ Create confidence ellipse for statistical visualization.
+
+ Parameters
+ ----------
+ center : Tuple[float, float]
+ Center coordinates (x, y) of the ellipse.
+ cov_matrix : np.ndarray
+ 2x2 covariance matrix.
+ n_std : float, default=1.2
+ Number of standard deviations for the ellipse size.
+ ax : plt.Axes, optional
+ Matplotlib axes to add the ellipse to. If None, creates new ellipse.
+ **ellipse_kwargs
+ Additional keyword arguments passed to Ellipse.
+
+ Returns
+ -------
+ Ellipse
+ Matplotlib Ellipse patch.
+
+ Notes
+ -----
+ The ellipse is created in log space to match the original visualization.
+ """
+ # Extract eigenvalues and eigenvectors
+ eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
+
+ # Calculate ellipse dimensions
+ width = 2 * n_std * np.sqrt(eigenvalues[1])
+ height = 2 * n_std * np.sqrt(eigenvalues[0])
+ angle = np.degrees(np.arctan2(eigenvectors[1, 1], eigenvectors[0, 1]))
+
+ # Create ellipse
+ ellipse = Ellipse(
+ xy=(0, 0), # Will be transformed to correct position
+ width=width,
+ height=height,
+ angle=angle,
+ **ellipse_kwargs
+ )
+
+ # Apply transformation to position ellipse correctly
+ if ax is not None:
+ ell_offset = ScaledTranslation(center[0], center[1], ax.transScale)
+ ell_tform = ell_offset + ax.transLimits + ax.transAxes
+ ellipse.set_transform(ell_tform)
+
+ return ellipse
+
+
+def create_diffusion_scatter(
+ diffusion_coeffs: np.ndarray,
+ acf_values: np.ndarray,
+ label: Optional[str] = None,
+ color: Optional[str] = None,
+ marker: str = "o",
+ size: int = 8,
+ ax: Optional[plt.Axes] = None,
+ show_stats: bool = True,
+ **plot_kwargs,
+) -> plt.Axes:
+ """
+ Create scatter plot of diffusion coefficient vs ACF values.
+
+ Parameters
+ ----------
+ diffusion_coeffs : np.ndarray
+ Array of diffusion coefficients.
+ acf_values : np.ndarray
+ Array of zero-lag ACF values.
+ label : str, optional
+ Label for the data series.
+ color : str, optional
+ Color for the scatter plot.
+ marker : str, default="o"
+ Marker style.
+ size : int, default=8
+ Marker size.
+ ax : plt.Axes, optional
+ Matplotlib axes to plot on. If None, creates new figure.
+ show_stats : bool, default=True
+ Whether to show confidence ellipse and statistics.
+ **plot_kwargs
+ Additional keyword arguments passed to scatter.
+
+ Returns
+ -------
+ plt.Axes
+ Matplotlib axes with the plot.
+ """
+ if ax is None:
+ fig, ax = plt.subplots(figsize=(7, 6))
+
+ # Scale for display: multiplying by 1e24 converts m²/s to pm²/s
+ D_display = diffusion_coeffs * 1e24 # m²/s -> pm²/s
+ acf_display = acf_values * 1e24 # m² -> pm²
+
+ # Plot scatter
+ ax.scatter(
+ D_display,
+ acf_display,
+ marker=marker,
+ s=size,
+ color=color,
+ label=label,
+ **plot_kwargs
+ )
+
+ # Add confidence ellipse and statistics if requested
+ if show_stats and label is not None and len(D_display) > 1:
+ stats = calculate_diffusion_statistics(diffusion_coeffs, acf_values)
+
+ # Create confidence ellipse in log space
+ center = (np.mean(np.log10(D_display)), np.mean(np.log10(acf_display)))
+ ellipse = create_confidence_ellipse(
+ center=center,
+ cov_matrix=stats["covariance_matrix"],
+ n_std=1.2,
+ ax=ax,
+ color=color if color else "blue",
+ alpha=0.15,
+ fill=True,
+ lw=2
+ )
+ ax.add_patch(ellipse)
+
+ # Formatting
+ ax.set_xscale("log")
+ ax.set_yscale("log")
+ ax.set_xlabel("Diffusion Coefficient (m²/s)")
+ ax.set_ylabel("ACF (0-lag)")
+
+ if label:
+ ax.legend()
+
+ return ax
+
+
+def plot_diffusion_comparison(
+ results_dict: Dict[str, Dict[str, np.ndarray]],
+ show_ellipses: bool = True,
+ colors: Optional[List[str]] = None,
+ markers: Optional[List[str]] = None,
+ figsize: Tuple[float, float] = (7, 6),
+ **plot_kwargs,
+) -> plt.Figure:
+ """
+ Plot multiple datasets with confidence ellipses.
+
+ Parameters
+ ----------
+ results_dict : Dict[str, Dict[str, np.ndarray]]
+ Dictionary mapping dataset names to their results.
+ Each entry should contain 'diffusion_coeffs' and 'acf_values'.
+ show_ellipses : bool, default=True
+ Whether to show confidence ellipses.
+ colors : List[str], optional
+ Colors for each dataset. If None, uses default colors.
+ markers : List[str], optional
+ Markers for each dataset. If None, uses default markers.
+ figsize : Tuple[float, float], default=(7, 6)
+ Figure size.
+ **plot_kwargs
+ Additional keyword arguments passed to scatter plots.
+
+ Returns
+ -------
+ plt.Figure
+ Matplotlib figure with the comparison plot.
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+
+ # Default colors and markers
+ if colors is None:
+ colors = ["#1f77b4", "gold", "green", "red", "purple", "orange"]
+ if markers is None:
+ markers = ["o", "^", "s", "d", "v", "p"]
+
+ # Plot each dataset
+ for i, (name, results) in enumerate(results_dict.items()):
+ color = colors[i % len(colors)] if colors else None
+ marker = markers[i % len(markers)] if markers else "o"
+
+ create_diffusion_scatter(
+ diffusion_coeffs=results["diffusion_coeffs"],
+ acf_values=results["acf_values"],
+ label=name,
+ color=color,
+ marker=marker,
+ ax=ax,
+ show_stats=show_ellipses,
+ **plot_kwargs
+ )
+
+ # Final formatting
+ ax.set_xscale("log")
+ ax.set_yscale("log")
+ ax.set_xlabel("Diffusion Coefficient (m²/s)")
+ ax.set_ylabel("ACF (0-lag)")
+ ax.legend()
+ ax.margins(x=0, y=0)
+
+ return fig
+
+
+def process_events_for_diffusion(
+ name: str,
+ sampling_interval: float,
+ data_path: str,
+ events: Optional[np.ndarray] = None,
+ t: Optional[np.ndarray] = None,
+ x: Optional[np.ndarray] = None,
+ bg_clean: Optional[np.ndarray] = None,
+ max_lag: int = 1000,
+ fit_time_limit: float = 3e-5,
+ n_jobs: int = -1,
+ subtract_background: bool = True,
+) -> Dict[str, Any]:
+ """
+ High-level wrapper for complete diffusion analysis.
+
+ Combines event extraction, MSD calculation, ACF calculation,
+ and statistical analysis into one function.
+
+ Parameters
+ ----------
+ name : str
+ Name of the data file (used for loading if t, x not provided).
+ sampling_interval : float
+ Sampling interval in seconds.
+ data_path : str
+ Path to data directory.
+ events : np.ndarray, optional
+ Pre-computed events array. If None, will detect events.
+ t : np.ndarray, optional
+ Time array. If None, will load from file.
+ x : np.ndarray, optional
+ Signal array. If None, will load from file.
+ bg_clean : np.ndarray, optional
+ Clean background array. If None, will calculate.
+ max_lag : int, default=1000
+ Maximum lag time in samples for MSD/ACF calculation.
+ fit_time_limit : float, default=3e-5
+ Upper time limit for linear fit (seconds).
+ n_jobs : int, default=-1
+ Number of parallel jobs for MSD calculation.
+ subtract_background : bool, default=True
+ Whether to subtract background from events.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary containing:
+ - 'diffusion_coeffs': Array of diffusion coefficients
+ - 'acf_values': Array of zero-lag ACF values
+ - 'event_count': Number of events processed
+ - 'statistics': Statistical analysis results
+ """
+ logger.info(f"Processing events for diffusion analysis: {name}")
+
+ # Load data if not provided
+ if t is None or x is None:
+ from .io import rd
+ t, x = rd(name, sampling_interval, data_path=data_path)
+
+ # Detect events if not provided
+ if events is None:
+ logger.warning("No events provided. You should run event detection first.")
+ return {
+ "diffusion_coeffs": np.array([]),
+ "acf_values": np.array([]),
+ "event_count": 0,
+ "statistics": {},
+ }
+
+ # Warn if no background is available; waveforms are extracted unsubtracted
+ if bg_clean is None and subtract_background:
+ logger.warning("No background provided. Events will not be background-subtracted.")
+
+ # Extract event waveforms
+ waveforms = extract_event_waveforms(
+ t, x, events, bg_clean=bg_clean, subtract_background=subtract_background
+ )
+
+ if not waveforms:
+ logger.warning("No valid events found for diffusion analysis")
+ return {
+ "diffusion_coeffs": np.array([]),
+ "acf_values": np.array([]),
+ "event_count": 0,
+ "statistics": {},
+ }
+
+ # Calculate diffusion coefficients
+ diffusion_coeffs = []
+ acf_values = []
+
+ logger.info(f"Processing {len(waveforms)} events for diffusion analysis")
+ for i, wf in enumerate(waveforms):
+ # MSD calculation
+ taus, msds, counts = calculate_msd_parallel(
+ wf, dt=sampling_interval, max_lag=max_lag, n_jobs=n_jobs
+ )
+ D = fit_diffusion_linear(taus, msds, time_limit=fit_time_limit)
+ diffusion_coeffs.append(D)
+
+ # ACF calculation
+ lags, acf = calculate_acf(wf, dt=sampling_interval, max_lag=max_lag)
+ acf_values.append(acf[0])
+
+ # Convert to arrays
+ diffusion_coeffs = np.array(diffusion_coeffs)
+ acf_values = np.array(acf_values)
+
+ # Calculate statistics
+ statistics = calculate_diffusion_statistics(diffusion_coeffs, acf_values)
+
+ result = {
+ "diffusion_coeffs": diffusion_coeffs,
+ "acf_values": acf_values,
+ "event_count": len(waveforms),
+ "statistics": statistics,
+ }
+
+ logger.success(
+ f"Processed {len(waveforms)} events: "
+ f"mean D = {statistics['mean_diffusion']:.3e} ± {statistics['std_diffusion']:.3e} m²/s"
+ )
+
+ return result
\ No newline at end of file
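A self-contained sanity check of the MSD/ACF helpers on a synthetic 1D random walk, where the answer is known analytically: for step standard deviation sigma per sample, MSD(τ = k·dt) = k·sigma², so the fitted slope is sigma²/dt and D = sigma²/(2·dt):

import numpy as np
from transivent.event_processor import (
    calculate_acf,
    calculate_msd_parallel,
    fit_diffusion_linear,
)

dt = 1e-6      # sampling interval (s)
sigma = 1e-12  # step size per sample (m), so D = sigma**2 / (2 * dt)
rng = np.random.default_rng(1)
walk = np.cumsum(rng.normal(0.0, sigma, 200_000))

taus, msds, counts = calculate_msd_parallel(walk, dt=dt, max_lag=500)
D = fit_diffusion_linear(taus, msds, time_limit=3e-5)  # 30 points in the fit
print(f"fitted D = {D:.3e} m²/s, expected {sigma**2 / (2 * dt):.3e} m²/s")

lags, acf = calculate_acf(walk, dt=dt, max_lag=100)
print(f"zero-lag ACF (signal variance) = {acf[0]:.3e}")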
diff --git a/src/transivent/utils.py b/src/transivent/utils.py
new file mode 100644
index 0000000..c58ff68
--- /dev/null
+++ b/src/transivent/utils.py
@@ -0,0 +1,130 @@
+"""
+Utility functions shared across modules.
+"""
+
+import numpy as np
+from loguru import logger
+from numba import njit
+
+__all__ = [
+ "create_event_mask_numba",
+ "validate_detection_inputs",
+]
+
+
+@njit
+def create_event_mask_numba(t: np.ndarray, events: np.ndarray) -> np.ndarray:
+ """
+ Create a boolean mask to exclude event regions using numba for speed.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ events : np.ndarray
+ Events array with shape (n_events, 2) where each row is [t_start, t_end].
+
+ Returns
+ -------
+ np.ndarray
+ Boolean mask where True means keep the sample, False means exclude.
+ """
+ mask = np.ones(len(t), dtype=np.bool_)
+ if len(events) == 0:
+ return mask
+
+ for i in range(len(events)):
+ t_start = events[i, 0]
+ t_end = events[i, 1]
+
+ start_idx = np.searchsorted(t, t_start, side="left")
+ end_idx = np.searchsorted(t, t_end, side="left")
+
+ if start_idx < end_idx:
+ mask[start_idx:end_idx] = False
+
+ return mask
+
+
+def validate_detection_inputs(
+ time: np.ndarray,
+ signal: np.ndarray,
+ bg: np.ndarray,
+ snr_threshold: np.float32,
+ min_event_len: int,
+ global_noise: np.float32,
+) -> None:
+ """
+ Validate inputs for event detection.
+
+ Parameters
+ ----------
+ time : np.ndarray
+ Time array.
+ signal : np.ndarray
+ Signal array.
+ bg : np.ndarray
+ Background array.
+ snr_threshold : np.float32
+ SNR threshold.
+ min_event_len : int
+ Minimum event length.
+ global_noise : np.float32
+ Global noise level.
+
+ Notes
+ -----
+ Emits logger warnings for suspect inputs rather than raising, so
+ processing can continue at the caller's discretion.
+ """
+
+ # Check array lengths
+ if not (len(time) == len(signal) == len(bg)):
+ logger.warning(
+ f"Validation Warning: Array length mismatch: time={len(time)}, signal={len(signal)}, bg={len(bg)}. "
+ "This may lead to unexpected behaviour."
+ )
+
+ # Check for empty arrays
+ if len(time) == 0:
+ logger.warning(
+ "Validation Warning: Input arrays are empty. This may lead to unexpected behaviour."
+ )
+
+ # Check time monotonicity with a small tolerance for floating-point comparisons
+ if len(time) > 1:
+ # Use a small epsilon for floating-point comparison
+ # np.finfo(time.dtype).eps is the smallest representable positive number such that 1.0 + eps != 1.0
+ # Multiplying by a small factor (e.g., 10) provides a reasonable tolerance.
+ tolerance = np.finfo(time.dtype).eps * 10
+ if not np.all(np.diff(time) > tolerance):
+ # Log the problematic differences for debugging
+ problematic_diffs = np.diff(time)[np.diff(time) <= tolerance]
+ logger.warning(
+ f"Validation Warning: Time array is not strictly monotonic increasing within tolerance {tolerance}. "
+ f"Problematic diffs (first 10): {problematic_diffs[:10]}. This may lead to unexpected behaviour."
+ )
+
+ # Check parameter validity
+ if snr_threshold <= 0:
+ logger.warning(
+ f"Validation Warning: SNR threshold must be positive, got {snr_threshold}. This may lead to unexpected behaviour."
+ )
+
+ if min_event_len <= 0:
+ logger.warning(
+ f"Validation Warning: Minimum event length must be positive, got {min_event_len}. This may lead to unexpected behaviour."
+ )
+
+ if global_noise <= 0:
+ logger.warning(
+ f"Validation Warning: Global noise must be positive, got {global_noise}. This may lead to unexpected behaviour."
+ )
+
+ # Check for NaN/inf values
+ for name, arr in [("time", time), ("signal", signal), ("bg", bg)]:
+ if not np.all(np.isfinite(arr)):
+ logger.warning(
+ f"Validation Warning: {name} array contains NaN or infinite values. This may lead to unexpected behaviour."
+ )
\ No newline at end of file
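The relocated mask helper keeps its behaviour. A sketch of masking detected events out of a trace, e.g. before re-estimating the background (event times here are illustrative):

import numpy as np
from transivent.utils import create_event_mask_numba

t = np.linspace(0.0, 1.0, 10_000).astype(np.float32)
x = np.random.default_rng(2).normal(size=10_000).astype(np.float32)
events = np.array([[0.20, 0.25], [0.70, 0.72]], dtype=np.float32)

mask = create_event_mask_numba(t, events)  # True = keep (outside events)
bg_samples = x[mask]                       # event-free samples
print(f"kept {int(mask.sum())} of {mask.size} samples")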