| field | value | date |
|---|---|---|
| author | Sam Scholten | 2025-10-23 15:06:25 +1000 |
| committer | Sam Scholten | 2025-10-23 15:22:54 +1000 |
| commit | 307bf648d8e3fe852d7daf2fa1567d1896e50f7e (patch) | |
| tree | d15344eab2003fd0a12544cc1ed9fbfef3e871d9 /src | |
| parent | 4a7026759e099e5c81cc9c77f19182a23d2f0275 (diff) | |
| download | transivent-307bf648d8e3fe852d7daf2fa1567d1896e50f7e.tar.gz, transivent-307bf648d8e3fe852d7daf2fa1567d1896e50f7e.zip | |
Release v2.0.0 (tag: v2.0.0)
Major API refactoring with simplified public interface.
- Added EventProcessor for high-level event processing workflow
- New utility functions for data preprocessing
- Additional example scripts for different use cases
- Comprehensive test suite
- Updated documentation with migration guide
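For orientation before the diff, here is the shape of the new public interface. This is essentially the Examples block from the `detect()` docstring added in this commit (synthetic, noise-only data, so it may legitimately report zero events):

```python
import numpy as np
from transivent import detect

# Synthetic 1 s trace: 100k samples of Gaussian noise
t = np.linspace(0, 1, 100_000)
x = np.random.randn(100_000) * 0.1

# One call replaces the old hand-assembled pipeline
results = detect(t, x, name="My Data")
print(f"Found {len(results['events'])} events")
```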
Diffstat (limited to 'src')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/transivent/__init__.py | 45 |
| -rw-r--r-- | src/transivent/analysis.py | 389 |
| -rw-r--r-- | src/transivent/event_detector.py | 83 |
| -rw-r--r-- | src/transivent/event_processor.py | 659 |
| -rw-r--r-- | src/transivent/utils.py | 130 |
5 files changed, 1170 insertions, 136 deletions
```diff
diff --git a/src/transivent/__init__.py b/src/transivent/__init__.py
index db3e824..a858d8c 100644
--- a/src/transivent/__init__.py
+++ b/src/transivent/__init__.py
@@ -1,36 +1,35 @@
 """
 High-level analysis and plotting for transient events.
+
+This module provides tools for detecting and analyzing transient events in
+time-series data. Two main entry points are available:
+
+- detect(): For analyzing custom time-series data (arrays)
+- detect_from_wfm(): For analyzing Wfm binary files with XML sidecars
+
+For advanced usage, building blocks are available in submodules:
+- transivent.analysis: Background, noise, and event detection functions
+- transivent.event_detector: Low-level event detection algorithms
+- transivent.diffusion: Diffusion analysis tools (optional)
 """
-from .analysis import (
-    analyze_thresholds,
-    calculate_initial_background,
-    calculate_smoothing_parameters,
-    configure_logging,
-    create_oscilloscope_plot,
-    get_final_events,
-    initialize_state,
-    process_chunk,
-    process_file,
-)
+from .analysis import detect, detect_from_wfm
 from .event_detector import detect_events, merge_overlapping_events
 from .event_plotter import EventPlotter
-from .io import get_waveform_params, rd, rd_chunked
+from .event_processor import extract_event_waveforms
+from .io import get_waveform_params, rd
 
 __all__ = [
-    "analyze_thresholds",
-    "calculate_initial_background",
-    "calculate_smoothing_parameters",
-    "configure_logging",
-    "create_oscilloscope_plot",
+    # Main entry points
+    "detect",
+    "detect_from_wfm",
+    # Building blocks
     "detect_events",
+    "merge_overlapping_events",
+    "extract_event_waveforms",
+    # Visualization
     "EventPlotter",
-    "get_final_events",
+    # I/O utilities
    "get_waveform_params",
-    "initialize_state",
-    "merge_overlapping_events",
-    "process_chunk",
-    "process_file",
     "rd",
-    "rd_chunked",
 ]
```
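Every name removed from `__all__` above has a v2 replacement path, so migration is mostly an import rewrite. A sketch (the v1.x lines are reconstructed from the deleted import list; exact call-site changes depend on how the old pipeline functions were composed):

```python
# v1.x: callers composed the pipeline themselves
# from transivent import process_file, rd_chunked, analyze_thresholds

# v2.0.0: two entry points cover the common workflows...
from transivent import detect, detect_from_wfm

# ...while the useful building blocks stay public
from transivent import detect_events, merge_overlapping_events
from transivent import extract_event_waveforms, EventPlotter
```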
- """ - mask = np.ones(len(t), dtype=np.bool_) - if len(events) == 0: - return mask - - for i in range(len(events)): - t_start = events[i, 0] - t_end = events[i, 1] - - start_idx = np.searchsorted(t, t_start, side="left") - end_idx = np.searchsorted(t, t_end, side="left") - - if start_idx < end_idx: - mask[start_idx:end_idx] = False - - return mask +# Import utility functions to avoid duplication +from .utils import create_event_mask_numba def extract_preview_image(sidecar_path: str, output_path: str) -> Optional[str]: @@ -458,7 +428,7 @@ def calculate_clean_background( start_time = time.time() # Fast masking with numba - mask = _create_event_mask_numba(t, events_initial) + mask = create_event_mask_numba(t, events_initial) mask_time = time.time() logger.debug( @@ -1236,3 +1206,356 @@ def process_file( ) logger.warning("Plotting is disabled in chunked processing mode.") + + +def detect( + t: np.ndarray, + x: np.ndarray, + sampling_interval: Optional[float] = None, + name: str = "Analysis", + detection_snr: float = 3.0, + min_event_keep_snr: float = 6.0, + min_event_t: float = 0.75e-6, + smooth_win_t: float = 10e-3, + widen_frac: float = 10.0, + signal_polarity: int = -1, + filter_type: str = "gaussian", + filter_order: int = 2, + max_plot_points: int = 10000, + envelope_mode_limit: float = 10e-3, + save_plots: bool = False, + plot_dir: Optional[str] = None, +) -> Dict[str, Any]: + """ + Detect transient events in custom time-series data. + + This is the primary entry point for analyzing any time-series data + (CSV, NumPy arrays, HDF5, etc.). It returns all detected events and + analysis results. + + Parameters + ---------- + t : np.ndarray + Time array in seconds. + x : np.ndarray + Signal array (any units). + sampling_interval : Optional[float], default=None + Sampling interval in seconds. If None, calculated from t array. + name : str, default="Analysis" + Name for the analysis (used in plots and logging). + detection_snr : float, default=3.0 + Detection SNR threshold (sigma units above background). + min_event_keep_snr : float, default=6.0 + Minimum event keep SNR threshold (sigma units). + min_event_t : float, default=0.75e-6 + Minimum event duration in seconds. + smooth_win_t : float, default=10e-3 + Smoothing window for background estimation in seconds. + widen_frac : float, default=10.0 + Fraction to widen detected events. + signal_polarity : int, default=-1 + Signal polarity: -1 for negative spikes, +1 for positive spikes. + filter_type : str, default="gaussian" + Filter type: "savgol", "gaussian", "moving_average", "median". + filter_order : int, default=2 + Order of Savitzky-Golay filter (if filter_type="savgol"). + max_plot_points : int, default=10000 + Maximum points for plot decimation. + envelope_mode_limit : float, default=10e-3 + Envelope mode time limit. + save_plots : bool, default=False + Whether to save plots to disk. + plot_dir : Optional[str], default=None + Directory to save plots (used if save_plots=True). 
+ + Returns + ------- + Dict[str, Any] + Dictionary containing: + - 'events': np.ndarray of [start_time, end_time] for each event + - 'bg_initial': Initial background estimate + - 'bg_clean': Clean background estimate (events masked) + - 'global_noise': Estimated noise level + - 'plot': OscilloscopePlot object (None if save_plots=False) + - 't': Time array + - 'x': Signal array + + Examples + -------- + >>> import numpy as np + >>> from transivent import detect + >>> + >>> # Create sample data + >>> t = np.linspace(0, 1, 100000) + >>> x = np.random.randn(100000) * 0.1 + >>> + >>> # Detect events + >>> results = detect(t, x, name="My Data") + >>> print(f"Found {len(results['events'])} events") + """ + start_time = time.time() + logger.info(f"Detecting events in: {name}") + + # Validate inputs + t = np.asarray(t, dtype=np.float32) + x = np.asarray(x, dtype=np.float32) + + if len(t) != len(x): + raise ValueError(f"Time array length ({len(t)}) != signal array length ({len(x)})") + + # Calculate sampling interval if not provided + if sampling_interval is None: + if len(t) < 2: + raise ValueError("Need at least 2 time points to calculate sampling interval") + sampling_interval = float(t[1] - t[0]) + logger.info(f"Calculated sampling interval: {sampling_interval:.3e} s") + else: + logger.info(f"Using provided sampling interval: {sampling_interval:.3e} s") + + # Calculate smoothing parameters + smooth_n, min_event_n = calculate_smoothing_parameters( + sampling_interval, + smooth_win_t, + None, + min_event_t, + detection_snr, + min_event_keep_snr, + widen_frac, + signal_polarity, + ) + + # Run analysis pipeline + logger.info("Running analysis pipeline...") + + bg_initial = calculate_initial_background(t, x, smooth_n, filter_type) + global_noise = estimate_noise(x, bg_initial) + + events_initial = detect_initial_events( + t, x, bg_initial, global_noise, detection_snr, + min_event_keep_snr, widen_frac, signal_polarity, min_event_n + ) + + bg_clean = calculate_clean_background( + t, x, events_initial, smooth_n, bg_initial, filter_type, filter_order + ) + + events = detect_final_events( + t, x, bg_clean, global_noise, detection_snr, + min_event_keep_snr, widen_frac, signal_polarity, min_event_n + ) + + analyze_events(t, x, bg_clean, events, global_noise, signal_polarity) + + logger.success(f"Detection complete in {time.time() - start_time:.3f}s") + logger.success(f"Found {len(events)} events") + + # Create plot if requested + plot = None + if save_plots or plot_dir is not None: + detection_threshold, keep_threshold = analyze_thresholds( + x, bg_clean, global_noise, detection_snr, min_event_keep_snr, signal_polarity + ) + + plot = create_oscilloscope_plot( + t, x, bg_initial, bg_clean, events, + detection_threshold, keep_threshold, + name, detection_snr, min_event_keep_snr, + max_plot_points, envelope_mode_limit, smooth_n, + global_noise=global_noise + ) + + if plot_dir is not None: + plot.save(os.path.join(plot_dir, f"{name}_trace.png")) + logger.info(f"Saved plot to {plot_dir}") + + return { + "events": events, + "bg_initial": bg_initial, + "bg_clean": bg_clean, + "global_noise": global_noise, + "plot": plot, + "t": t, + "x": x, + } + + +def detect_from_wfm( + name: str, + sampling_interval: float, + data_path: str, + detection_snr: float = 3.0, + min_event_keep_snr: float = 6.0, + min_event_t: float = 0.75e-6, + smooth_win_t: float = 10e-3, + widen_frac: float = 10.0, + signal_polarity: int = -1, + filter_type: str = "gaussian", + filter_order: int = 2, + max_plot_points: int = 10000, + 
envelope_mode_limit: float = 10e-3, + sidecar: Optional[str] = None, + crop: Optional[List[int]] = None, + save_plots: bool = True, + plot_dir: Optional[str] = None, + chunk_size: Optional[int] = None, +) -> Dict[str, Any]: + """ + Detect transient events in a Wfm binary file with XML sidecar. + + This is the entry point for analyzing proprietary Wfm format files + that include XML metadata sidecars. It handles loading the file and + returns all detected events and analysis results. + + Parameters + ---------- + name : str + Filename of the Wfm binary file. + sampling_interval : float + Sampling interval in seconds (can be overridden by XML sidecar). + data_path : str + Path to the data directory. + detection_snr : float, default=3.0 + Detection SNR threshold (sigma units above background). + min_event_keep_snr : float, default=6.0 + Minimum event keep SNR threshold (sigma units). + min_event_t : float, default=0.75e-6 + Minimum event duration in seconds. + smooth_win_t : float, default=10e-3 + Smoothing window for background estimation in seconds. + widen_frac : float, default=10.0 + Fraction to widen detected events. + signal_polarity : int, default=-1 + Signal polarity: -1 for negative spikes, +1 for positive spikes. + filter_type : str, default="gaussian" + Filter type: "savgol", "gaussian", "moving_average", "median". + filter_order : int, default=2 + Order of Savitzky-Golay filter (if filter_type="savgol"). + max_plot_points : int, default=10000 + Maximum points for plot decimation. + envelope_mode_limit : float, default=10e-3 + Envelope mode time limit. + sidecar : Optional[str], default=None + XML sidecar filename. If None, auto-detected. + crop : Optional[List[int]], default=None + Crop indices [start, end] for signal. + save_plots : bool, default=True + Whether to save plots to disk. + plot_dir : Optional[str], default=None + Directory to save plots. If None, uses data_path_analysis/. + chunk_size : Optional[int], default=None + Chunk size for processing large files. If None, loads entire file. + + Returns + ------- + Dict[str, Any] + Dictionary containing: + - 'events': np.ndarray of [start_time, end_time] for each event + - 'bg_initial': Initial background estimate + - 'bg_clean': Clean background estimate (events masked) + - 'global_noise': Estimated noise level + - 'plot': OscilloscopePlot object + - 't': Time array + - 'x': Signal array + + Examples + -------- + >>> from transivent import detect_from_wfm + >>> + >>> results = detect_from_wfm( + ... name="data.Wfm.bin", + ... sampling_interval=5e-7, + ... data_path="/path/to/data/", + ... detection_snr=3.0 + ... 
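`detect()` and `detect_from_wfm()` return the same dict, so downstream code can ignore where the data came from. A minimal sketch of consuming that dict (here called `results`, as in the quickstart above); note that `create_event_mask_numba` lives in the new `transivent.utils` module, which is importable even though it is not re-exported in `__all__`:

```python
import numpy as np
from transivent.utils import create_event_mask_numba

events = results["events"]          # shape (n_events, 2): [t_start, t_end]
durations = events[:, 1] - events[:, 0]
print(f"median event duration: {np.median(durations):.3e} s")

# True outside all detected events: handy for re-checking the noise floor
quiet = create_event_mask_numba(results["t"], events)
print(f"event-free std: {results['x'][quiet].std():.3e}")
```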
```diff
diff --git a/src/transivent/event_detector.py b/src/transivent/event_detector.py
index 22d76d4..3726b53 100644
--- a/src/transivent/event_detector.py
+++ b/src/transivent/event_detector.py
@@ -4,6 +4,8 @@ import numpy as np
 from loguru import logger
 from numba import njit
 
+from .utils import validate_detection_inputs
+
 # --- Constants ---
 MEDIAN_TO_STD_FACTOR = (
     1.4826  # Factor to convert median absolute deviation to standard deviation
@@ -235,7 +237,7 @@ def detect_events(
     bg = np.asarray(bg, dtype=np.float32)
 
     # Validate input data
-    _validate_detection_inputs(
+    validate_detection_inputs(
         time, signal, bg, snr_threshold, min_event_len, global_noise
     )
@@ -269,86 +271,7 @@ def detect_events(
 
     return events_array, np.float32(global_noise)
 
-def _validate_detection_inputs(
-    time: np.ndarray,
-    signal: np.ndarray,
-    bg: np.ndarray,
-    snr_threshold: np.float32,
-    min_event_len: int,
-    global_noise: np.float32,
-) -> None:
-    """
-    Validate inputs for event detection.
-
-    Parameters
-    ----------
-    time : np.ndarray
-        Time array.
-    signal : np.ndarray
-        Signal array.
-    bg : np.ndarray
-        Background array.
-    snr_threshold : np.float32
-        SNR threshold.
-    min_event_len : int
-        Minimum event length.
-    global_noise : np.float32
-        Global noise level.
-
-    Raises
-    ------
-    ValueError
-        If inputs are invalid.
-    """
-    # Check array lengths
-    if not (len(time) == len(signal) == len(bg)):
-        logger.warning(
-            f"Validation Warning: Array length mismatch: time={len(time)}, signal={len(signal)}, bg={len(bg)}. "
-            "This may lead to unexpected behaviour."
-        )
-
-    # Check for empty arrays
-    if len(time) == 0:
-        logger.warning(
-            "Validation Warning: Input arrays are empty. This may lead to unexpected behaviour."
-        )
-
-    # Check time monotonicity with a small tolerance for floating-point comparisons
-    if len(time) > 1:
-        # Use a small epsilon for floating-point comparison
-        # np.finfo(time.dtype).eps is the smallest representable positive number such that 1.0 + eps != 1.0
-        # Multiplying by a small factor (e.g., 10) provides a reasonable tolerance.
-        tolerance = np.finfo(time.dtype).eps * 10
-        if not np.all(np.diff(time) > tolerance):
-            # Log the problematic differences for debugging
-            problematic_diffs = np.diff(time)[np.diff(time) <= tolerance]
-            logger.warning(
-                f"Validation Warning: Time array is not strictly monotonic increasing within tolerance {tolerance}. "
-                f"Problematic diffs (first 10): {problematic_diffs[:10]}. This may lead to unexpected behaviour."
-            )
-
-    # Check parameter validity
-    if snr_threshold <= 0:
-        logger.warning(
-            f"Validation Warning: SNR threshold must be positive, got {snr_threshold}. This may lead to unexpected behaviour."
-        )
-
-    if min_event_len <= 0:
-        logger.warning(
-            f"Validation Warning: Minimum event length must be positive, got {min_event_len}. This may lead to unexpected behaviour."
-        )
-
-    if global_noise <= 0:
-        logger.warning(
-            f"Validation Warning: Global noise must be positive, got {global_noise}. This may lead to unexpected behaviour."
-        )
-    # Check for NaN/inf values
-    for name, arr in [("time", time), ("signal", signal), ("bg", bg)]:
-        if not np.all(np.isfinite(arr)):
-            logger.warning(
-                f"Validation Warning: {name} array contains NaN or infinite values. This may lead to unexpected behaviour."
-            )
 
 
 def merge_overlapping_events(events: np.ndarray) -> np.ndarray:
```
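`merge_overlapping_events` stays public, which matters once event windows are padded by hand and start to collide. A sketch (the 100 µs padding is arbitrary; `results` is the `detect()` output from earlier):

```python
import numpy as np
from transivent import merge_overlapping_events

padded = results["events"].copy()   # (n_events, 2) rows of [t_start, t_end]
padded[:, 0] -= 1e-4                # widen each window by 100 us per side
padded[:, 1] += 1e-4

merged = merge_overlapping_events(padded)
print(f"{len(padded)} padded windows -> {len(merged)} after merging")
```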
+""" + +from typing import Dict, List, Optional, Tuple, Union, Any + +import matplotlib.pyplot as plt +import numpy as np +from joblib import Parallel, delayed +from loguru import logger +from matplotlib.patches import Ellipse +from matplotlib.transforms import ScaledTranslation +from scipy.stats import norm + + +def extract_event_waveforms( + t: np.ndarray, + x: np.ndarray, + events: np.ndarray, + bg_clean: Optional[np.ndarray] = None, + subtract_background: bool = True, +) -> List[np.ndarray]: + """ + Extract signal segments for each detected event. + + Parameters + ---------- + t : np.ndarray + Time array. + x : np.ndarray + Signal array. + events : np.ndarray + Array of [start_time, end_time] for each event with shape (n_events, 2). + bg_clean : np.ndarray, optional + Clean background array. If provided and subtract_background=True, + the background will be subtracted from each event. + subtract_background : bool, default=True + Whether to subtract the background from each event. + + Returns + ------- + List[np.ndarray] + List of signal waveforms, one per event. + + Raises + ------ + ValueError + If input arrays have incompatible shapes or events array is invalid. + """ + # Validate inputs + t = np.asarray(t, dtype=np.float32) + x = np.asarray(x, dtype=np.float32) + events = np.asarray(events, dtype=np.float32) + + if len(t) != len(x): + raise ValueError(f"Time array length ({len(t)}) != signal array length ({len(x)})") + + if events.ndim != 2 or events.shape[1] != 2: + raise ValueError("Events array must have shape (n_events, 2)") + + if bg_clean is not None and len(bg_clean) != len(x): + raise ValueError("Background array must have same length as signal array") + + waveforms = [] + logger.info(f"Extracting waveforms for {len(events)} events") + + for i, (t_start, t_end) in enumerate(events): + # Extract indices for this event + mask = (t >= t_start) & (t < t_end) + + if not np.any(mask): + logger.warning(f"Event {i+1}: No data found in time range [{t_start:.6f}, {t_end:.6f}]") + continue + + event_signal = x[mask].copy() + + # Subtract background if requested and available + if subtract_background and bg_clean is not None: + event_signal -= bg_clean[mask] + + waveforms.append(event_signal) + + logger.info(f"Successfully extracted {len(waveforms)} event waveforms") + return waveforms + + +def _calculate_msd_single_lag(data: np.ndarray, lag: int) -> Tuple[float, int]: + """ + Calculate MSD for a single lag time. + + This is a helper function for parallel processing. + + Parameters + ---------- + data : np.ndarray + Signal data. + lag : int + Lag time in samples. + + Returns + ------- + Tuple[float, int] + MSD value and count of displacement pairs. + """ + if lag >= len(data): + return 0.0, 0 + + displacements = data[lag:] - data[:-lag] + msd = np.mean(displacements**2) + count = len(displacements) + return msd, count + + +def calculate_msd_parallel( + data: np.ndarray, + dt: float = 1e-6, + max_lag: int = 1000, + n_jobs: int = -1, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Calculate Mean Square Displacement using parallel processing. + + Based on the joblib implementation from heehun_diffusion_processing.py. + + Parameters + ---------- + data : np.ndarray + Signal data. + dt : float, default=1e-6 + Time step in seconds. + max_lag : int, default=1000 + Maximum lag time in samples. + n_jobs : int, default=-1 + Number of parallel jobs. -1 uses all available cores. 
```diff
+
+
+def _calculate_msd_single_lag(data: np.ndarray, lag: int) -> Tuple[float, int]:
+    """
+    Calculate MSD for a single lag time.
+
+    This is a helper function for parallel processing.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        Signal data.
+    lag : int
+        Lag time in samples.
+
+    Returns
+    -------
+    Tuple[float, int]
+        MSD value and count of displacement pairs.
+    """
+    if lag >= len(data):
+        return 0.0, 0
+
+    displacements = data[lag:] - data[:-lag]
+    msd = np.mean(displacements**2)
+    count = len(displacements)
+    return msd, count
+
+
+def calculate_msd_parallel(
+    data: np.ndarray,
+    dt: float = 1e-6,
+    max_lag: int = 1000,
+    n_jobs: int = -1,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Calculate Mean Square Displacement using parallel processing.
+
+    Based on the joblib implementation from heehun_diffusion_processing.py.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        Signal data.
+    dt : float, default=1e-6
+        Time step in seconds.
+    max_lag : int, default=1000
+        Maximum lag time in samples.
+    n_jobs : int, default=-1
+        Number of parallel jobs. -1 uses all available cores.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, np.ndarray]
+        taus (lag times in seconds), msds (MSD values), counts (number of pairs).
+    """
+    data = np.asarray(data, dtype=np.float64)
+    logger.debug(f"Calculating MSD for {len(data)} points with max_lag={max_lag}")
+
+    # Parallel calculation of MSD for each lag
+    results = Parallel(n_jobs=n_jobs)(
+        delayed(_calculate_msd_single_lag)(data, i) for i in range(1, max_lag + 1)
+    )
+
+    msds, counts = zip(*results)
+    taus = np.arange(1, max_lag + 1) * dt
+
+    return taus, np.array(msds), np.array(counts)
+
+
+def calculate_acf(
+    x: np.ndarray, dt: float = 1e-6, max_lag: int = 100
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Calculate autocorrelation function.
+
+    Directly from heehun_diffusion_processing.py.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input signal.
+    dt : float, default=1e-6
+        Time step in seconds.
+    max_lag : int, default=100
+        Maximum lag in samples.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray]
+        lags (time values), acf (autocorrelation values).
+    """
+    x = np.asarray(x, dtype=np.float64)
+    n = x.size
+
+    # Remove mean for proper autocorrelation calculation
+    x = x - x.mean()
+
+    acf = np.zeros(max_lag + 1)
+    for lag in range(max_lag + 1):
+        acf[lag] = np.dot(x[:n - lag], x[lag:]) / (n - lag)
+
+    lags = np.arange(len(acf)) * dt
+    return lags, acf
+
+
+def fit_diffusion_linear(
+    taus: np.ndarray,
+    msds: np.ndarray,
+    time_limit: float = 3e-5,
+    min_points: int = 10,
+) -> float:
+    """
+    Fit linear region of MSD to extract diffusion coefficient.
+
+    For normal diffusion, MSD = 2*D*t in 1D, so D = slope/2.
+
+    Parameters
+    ----------
+    taus : np.ndarray
+        Lag times in seconds.
+    msds : np.ndarray
+        MSD values.
+    time_limit : float, default=3e-5
+        Upper time limit for linear fit (seconds).
+    min_points : int, default=10
+        Minimum number of points required for fit.
+
+    Returns
+    -------
+    float
+        Diffusion coefficient (m²/s).
+
+    Raises
+    ------
+    ValueError
+        If insufficient data points for fitting.
+    """
+    # Mask for fitting region
+    mask = (taus <= time_limit) & (taus > 0)
+
+    if np.sum(mask) < min_points:
+        logger.warning(
+            f"Insufficient points for fit: {np.sum(mask)} < {min_points}. "
+            f"Consider increasing time_limit or decreasing max_lag."
+        )
+        return np.nan
+
+    taus_fit = taus[mask]
+    msds_fit = msds[mask]
+
+    # Linear fit: MSD = 2*D*τ
+    slope, intercept = np.polyfit(taus_fit, msds_fit, 1)
+    diffusion_coeff = slope / 2
+
+    logger.debug(
+        f"Linear fit: slope={slope:.3e}, intercept={intercept:.3e}, "
+        f"D={diffusion_coeff:.3e} m²/s"
+    )
+
+    return diffusion_coeff
```
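Per event, the diffusion workflow is three calls: MSD over a range of lags, a linear fit for D (MSD = 2*D*τ in 1D), and the zero-lag autocorrelation. A sketch on one extracted waveform (dt stands in for your sampling interval; 3e-5 s is the fit function's default window):

```python
from transivent.event_processor import (
    calculate_acf,
    calculate_msd_parallel,
    fit_diffusion_linear,
)

dt = 1e-6                     # sampling interval in seconds (assumed)
wf = waveforms[0]             # one event from extract_event_waveforms

taus, msds, counts = calculate_msd_parallel(wf, dt=dt, max_lag=1000)
D = fit_diffusion_linear(taus, msds, time_limit=3e-5)  # slope / 2

lags, acf = calculate_acf(wf, dt=dt, max_lag=100)
print(f"D = {D:.3e}, ACF(0) = {acf[0]:.3e}")
```

The statistics and plotting helpers follow.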
```diff
+
+
+def calculate_diffusion_statistics(
+    diffusion_coeffs: np.ndarray, acf_values: np.ndarray
+) -> Dict[str, Any]:
+    """
+    Calculate statistical measures for diffusion analysis.
+
+    Parameters
+    ----------
+    diffusion_coeffs : np.ndarray
+        Array of diffusion coefficients.
+    acf_values : np.ndarray
+        Array of zero-lag ACF values.
+
+    Returns
+    -------
+    Dict[str, Any]
+        Dictionary containing:
+        - 'mean_diffusion': Mean diffusion coefficient
+        - 'std_diffusion': Standard deviation of diffusion coefficients
+        - 'mean_acf': Mean ACF value
+        - 'std_acf': Standard deviation of ACF values
+        - 'covariance_matrix': Covariance matrix of log-transformed values
+        - 'eigenvalues': Eigenvalues of covariance matrix
+        - 'eigenvectors': Eigenvectors of covariance matrix
+    """
+    # Convert to log space for statistical calculations (as in original script)
+    log_D = np.log10(diffusion_coeffs)
+    log_acf = np.log10(acf_values)
+
+    # Remove any invalid values
+    valid_mask = np.isfinite(log_D) & np.isfinite(log_acf)
+    if not np.all(valid_mask):
+        logger.warning(
+            f"Found {np.sum(~valid_mask)} invalid values in diffusion data"
+        )
+        log_D = log_D[valid_mask]
+        log_acf = log_acf[valid_mask]
+
+    # Basic statistics
+    stats = {
+        "mean_diffusion": np.mean(diffusion_coeffs),
+        "std_diffusion": np.std(diffusion_coeffs),
+        "mean_acf": np.mean(acf_values),
+        "std_acf": np.std(acf_values),
+    }
+
+    # Covariance and eigenanalysis in log space
+    if len(log_D) > 1:
+        cov_matrix = np.cov(log_D, log_acf)
+        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
+
+        stats.update({
+            "covariance_matrix": cov_matrix,
+            "eigenvalues": eigenvalues,
+            "eigenvectors": eigenvectors,
+        })
+    else:
+        logger.warning("Insufficient data for covariance analysis")
+        stats.update({
+            "covariance_matrix": np.array([[np.nan, np.nan], [np.nan, np.nan]]),
+            "eigenvalues": np.array([np.nan, np.nan]),
+            "eigenvectors": np.array([[np.nan, np.nan], [np.nan, np.nan]]),
+        })
+
+    return stats
+
+
+def create_confidence_ellipse(
+    center: Tuple[float, float],
+    cov_matrix: np.ndarray,
+    n_std: float = 1.2,
+    ax: Optional[plt.Axes] = None,
+    **ellipse_kwargs,
+) -> Ellipse:
+    """
+    Create confidence ellipse for statistical visualization.
+
+    Parameters
+    ----------
+    center : Tuple[float, float]
+        Center coordinates (x, y) of the ellipse.
+    cov_matrix : np.ndarray
+        2x2 covariance matrix.
+    n_std : float, default=1.2
+        Number of standard deviations for the ellipse size.
+    ax : plt.Axes, optional
+        Matplotlib axes to add the ellipse to. If None, creates new ellipse.
+    **ellipse_kwargs
+        Additional keyword arguments passed to Ellipse.
+
+    Returns
+    -------
+    Ellipse
+        Matplotlib Ellipse patch.
+
+    Notes
+    -----
+    The ellipse is created in log space to match the original visualization.
+    """
+    # Extract eigenvalues and eigenvectors
+    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
+
+    # Calculate ellipse dimensions
+    width = 2 * n_std * np.sqrt(eigenvalues[1])
+    height = 2 * n_std * np.sqrt(eigenvalues[0])
+    angle = np.degrees(np.arctan2(eigenvectors[1, 1], eigenvectors[0, 1]))
+
+    # Create ellipse
+    ellipse = Ellipse(
+        xy=(0, 0),  # Will be transformed to correct position
+        width=width,
+        height=height,
+        angle=angle,
+        **ellipse_kwargs
+    )
+
+    # Apply transformation to position ellipse correctly
+    if ax is not None:
+        ell_offset = ScaledTranslation(center[0], center[1], ax.transScale)
+        ell_tform = ell_offset + ax.transLimits + ax.transAxes
+        ellipse.set_transform(ell_tform)
+
+    return ellipse
+
+
+def create_diffusion_scatter(
+    diffusion_coeffs: np.ndarray,
+    acf_values: np.ndarray,
+    label: Optional[str] = None,
+    color: Optional[str] = None,
+    marker: str = "o",
+    size: int = 8,
+    ax: Optional[plt.Axes] = None,
+    show_stats: bool = True,
+    **plot_kwargs,
+) -> plt.Axes:
+    """
+    Create scatter plot of diffusion coefficient vs ACF values.
+
+    Parameters
+    ----------
+    diffusion_coeffs : np.ndarray
+        Array of diffusion coefficients.
+    acf_values : np.ndarray
+        Array of zero-lag ACF values.
+    label : str, optional
+        Label for the data series.
+    color : str, optional
+        Color for the scatter plot.
+    marker : str, default="o"
+        Marker style.
+    size : int, default=8
+        Marker size.
+    ax : plt.Axes, optional
+        Matplotlib axes to plot on. If None, creates new figure.
+    show_stats : bool, default=True
+        Whether to show confidence ellipse and statistics.
+    **plot_kwargs
+        Additional keyword arguments passed to scatter.
+
+    Returns
+    -------
+    plt.Axes
+        Matplotlib axes with the plot.
+    """
+    if ax is None:
+        fig, ax = plt.subplots(figsize=(7, 6))
+
+    # Convert to appropriate units for display (m²/s from pm²/s)
+    D_display = diffusion_coeffs * 1e24  # pm²/s to m²/s
+    acf_display = acf_values * 1e24  # pm²/s to m²/s
+
+    # Plot scatter
+    ax.scatter(
+        D_display,
+        acf_display,
+        marker=marker,
+        s=size,
+        color=color,
+        label=label,
+        **plot_kwargs
+    )
+
+    # Add confidence ellipse and statistics if requested
+    if show_stats and label is not None and len(D_display) > 1:
+        stats = calculate_diffusion_statistics(diffusion_coeffs, acf_values)
+
+        # Create confidence ellipse in log space
+        center = (np.mean(np.log10(D_display)), np.mean(np.log10(acf_display)))
+        ellipse = create_confidence_ellipse(
+            center=center,
+            cov_matrix=stats["covariance_matrix"],
+            n_std=1.2,
+            ax=ax,
+            color=color if color else "blue",
+            alpha=0.15,
+            fill=True,
+            lw=2
+        )
+        ax.add_patch(ellipse)
+
+    # Formatting
+    ax.set_xscale("log")
+    ax.set_yscale("log")
+    ax.set_xlabel("Diffusion Coefficient (m²/s)")
+    ax.set_ylabel("ACF (0-lag)")
+
+    if label:
+        ax.legend()
+
+    return ax
+
+
+def plot_diffusion_comparison(
+    results_dict: Dict[str, Dict[str, np.ndarray]],
+    show_ellipses: bool = True,
+    colors: Optional[List[str]] = None,
+    markers: Optional[List[str]] = None,
+    figsize: Tuple[float, float] = (7, 6),
+    **plot_kwargs,
+) -> plt.Figure:
+    """
+    Plot multiple datasets with confidence ellipses.
+
+    Parameters
+    ----------
+    results_dict : Dict[str, Dict[str, np.ndarray]]
+        Dictionary mapping dataset names to their results.
+        Each entry should contain 'diffusion_coeffs' and 'acf_values'.
+    show_ellipses : bool, default=True
+        Whether to show confidence ellipses.
+    colors : List[str], optional
+        Colors for each dataset. If None, uses default colors.
+    markers : List[str], optional
+        Markers for each dataset. If None, uses default markers.
+    figsize : Tuple[float, float], default=(7, 6)
+        Figure size.
+    **plot_kwargs
+        Additional keyword arguments passed to scatter plots.
+
+    Returns
+    -------
+    plt.Figure
+        Matplotlib figure with the comparison plot.
+    """
+    fig, ax = plt.subplots(figsize=figsize)
+
+    # Default colors and markers
+    if colors is None:
+        colors = ["#1f77b4", "gold", "green", "red", "purple", "orange"]
+    if markers is None:
+        markers = ["o", "^", "s", "d", "v", "p"]
+
+    # Plot each dataset
+    for i, (name, results) in enumerate(results_dict.items()):
+        color = colors[i % len(colors)] if colors else None
+        marker = markers[i % len(markers)] if markers else "o"
+
+        create_diffusion_scatter(
+            diffusion_coeffs=results["diffusion_coeffs"],
+            acf_values=results["acf_values"],
+            label=name,
+            color=color,
+            marker=marker,
+            ax=ax,
+            show_stats=show_ellipses,
+            **plot_kwargs
+        )
+
+    # Final formatting
+    ax.set_xscale("log")
+    ax.set_yscale("log")
+    ax.set_xlabel("Diffusion Coefficient (m²/s)")
+    ax.set_ylabel("ACF (0-lag)")
+    ax.legend()
+    ax.margins(x=0, y=0)
+
+    return fig
```
```diff
+
+
+def process_events_for_diffusion(
+    name: str,
+    sampling_interval: float,
+    data_path: str,
+    events: Optional[np.ndarray] = None,
+    t: Optional[np.ndarray] = None,
+    x: Optional[np.ndarray] = None,
+    bg_clean: Optional[np.ndarray] = None,
+    max_lag: int = 1000,
+    fit_time_limit: float = 3e-5,
+    n_jobs: int = -1,
+    subtract_background: bool = True,
+) -> Dict[str, Any]:
+    """
+    High-level wrapper for complete diffusion analysis.
+
+    Combines event extraction, MSD calculation, ACF calculation,
+    and statistical analysis into one function.
+
+    Parameters
+    ----------
+    name : str
+        Name of the data file (used for loading if t, x not provided).
+    sampling_interval : float
+        Sampling interval in seconds.
+    data_path : str
+        Path to data directory.
+    events : np.ndarray, optional
+        Pre-computed events array. If None, will detect events.
+    t : np.ndarray, optional
+        Time array. If None, will load from file.
+    x : np.ndarray, optional
+        Signal array. If None, will load from file.
+    bg_clean : np.ndarray, optional
+        Clean background array. If None, will calculate.
+    max_lag : int, default=1000
+        Maximum lag time in samples for MSD/ACF calculation.
+    fit_time_limit : float, default=3e-5
+        Upper time limit for linear fit (seconds).
+    n_jobs : int, default=-1
+        Number of parallel jobs for MSD calculation.
+    subtract_background : bool, default=True
+        Whether to subtract background from events.
+
+    Returns
+    -------
+    Dict[str, Any]
+        Dictionary containing:
+        - 'diffusion_coeffs': Array of diffusion coefficients
+        - 'acf_values': Array of zero-lag ACF values
+        - 'event_count': Number of events processed
+        - 'statistics': Statistical analysis results
+        - 'msd_results': Full MSD results for each event (optional)
+    """
+    logger.info(f"Processing events for diffusion analysis: {name}")
+
+    # Load data if not provided
+    if t is None or x is None:
+        from .io import rd
+        t, x = rd(name, sampling_interval, data_path=data_path)
+
+    # Detect events if not provided
+    if events is None:
+        logger.warning("No events provided. You should run event detection first.")
+        return {
+            "diffusion_coeffs": np.array([]),
+            "acf_values": np.array([]),
+            "event_count": 0,
+            "statistics": {},
+        }
+
+    # Calculate background if not provided
+    if bg_clean is None and subtract_background:
+        logger.warning("No background provided. Events will not be background-subtracted.")
+        bg_clean = None
+
+    # Extract event waveforms
+    waveforms = extract_event_waveforms(
+        t, x, events, bg_clean=bg_clean, subtract_background=subtract_background
+    )
+
+    if not waveforms:
+        logger.warning("No valid events found for diffusion analysis")
+        return {
+            "diffusion_coeffs": np.array([]),
+            "acf_values": np.array([]),
+            "event_count": 0,
+            "statistics": {},
+        }
+
+    # Calculate diffusion coefficients
+    diffusion_coeffs = []
+    acf_values = []
+
+    logger.info(f"Processing {len(waveforms)} events for diffusion analysis")
+    for i, wf in enumerate(waveforms):
+        # MSD calculation
+        taus, msds, counts = calculate_msd_parallel(
+            wf, dt=sampling_interval, max_lag=max_lag, n_jobs=n_jobs
+        )
+        D = fit_diffusion_linear(taus, msds, time_limit=fit_time_limit)
+        diffusion_coeffs.append(D)
+
+        # ACF calculation
+        lags, acf = calculate_acf(wf, dt=sampling_interval, max_lag=max_lag)
+        acf_values.append(acf[0])
+
+    # Convert to arrays
+    diffusion_coeffs = np.array(diffusion_coeffs)
+    acf_values = np.array(acf_values)
+
+    # Calculate statistics
+    statistics = calculate_diffusion_statistics(diffusion_coeffs, acf_values)
+
+    result = {
+        "diffusion_coeffs": diffusion_coeffs,
+        "acf_values": acf_values,
+        "event_count": len(waveforms),
+        "statistics": statistics,
+    }
+
+    logger.success(
+        f"Processed {len(waveforms)} events: "
+        f"mean D = {statistics['mean_diffusion']:.3e} ± {statistics['std_diffusion']:.3e} m²/s"
+    )
+
+    return result
\ No newline at end of file
```
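Detection and diffusion analysis chain together naturally; passing the detection results through avoids reloading the file. A sketch with placeholder paths and sampling interval:

```python
from transivent import detect_from_wfm
from transivent.event_processor import process_events_for_diffusion

det = detect_from_wfm(
    name="data.Wfm.bin",            # placeholder filename
    sampling_interval=5e-7,
    data_path="/path/to/data/",
)

diff = process_events_for_diffusion(
    name="data.Wfm.bin",
    sampling_interval=5e-7,
    data_path="/path/to/data/",
    events=det["events"],           # reuse detection output;
    t=det["t"], x=det["x"],         # nothing is reloaded from disk
    bg_clean=det["bg_clean"],
)
stats = diff["statistics"]
print(f"{diff['event_count']} events, mean D = {stats['mean_diffusion']:.3e}")
```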
```diff
diff --git a/src/transivent/utils.py b/src/transivent/utils.py
new file mode 100644
index 0000000..c58ff68
--- /dev/null
+++ b/src/transivent/utils.py
@@ -0,0 +1,130 @@
+"""
+Utility functions shared across modules.
+"""
+
+import numpy as np
+from numba import njit
+from typing import Tuple
+
+__all__ = [
+    "create_event_mask_numba",
+    "validate_detection_inputs",
+]
+
+
+@njit
+def create_event_mask_numba(t: np.ndarray, events: np.ndarray) -> np.ndarray:
+    """
+    Create a boolean mask to exclude event regions using numba for speed.
+
+    Parameters
+    ----------
+    t : np.ndarray
+        Time array.
+    events : np.ndarray
+        Events array with shape (n_events, 2) where each row is [t_start, t_end].
+
+    Returns
+    -------
+    np.ndarray
+        Boolean mask where True means keep the sample, False means exclude.
+    """
+    mask = np.ones(len(t), dtype=np.bool_)
+    if len(events) == 0:
+        return mask
+
+    for i in range(len(events)):
+        t_start = events[i, 0]
+        t_end = events[i, 1]
+
+        start_idx = np.searchsorted(t, t_start, side="left")
+        end_idx = np.searchsorted(t, t_end, side="left")
+
+        if start_idx < end_idx:
+            mask[start_idx:end_idx] = False
+
+    return mask
+
+
+def validate_detection_inputs(
+    time: np.ndarray,
+    signal: np.ndarray,
+    bg: np.ndarray,
+    snr_threshold: np.float32,
+    min_event_len: int,
+    global_noise: np.float32,
+) -> None:
+    """
+    Validate inputs for event detection.
+
+    Parameters
+    ----------
+    time : np.ndarray
+        Time array.
+    signal : np.ndarray
+        Signal array.
+    bg : np.ndarray
+        Background array.
+    snr_threshold : np.float32
+        SNR threshold.
+    min_event_len : int
+        Minimum event length.
+    global_noise : np.float32
+        Global noise level.
+
+    Raises
+    ------
+    ValueError
+        If inputs are invalid.
+    """
+    from loguru import logger
+
+    # Check array lengths
+    if not (len(time) == len(signal) == len(bg)):
+        logger.warning(
+            f"Validation Warning: Array length mismatch: time={len(time)}, signal={len(signal)}, bg={len(bg)}. "
+            "This may lead to unexpected behaviour."
+        )
+
+    # Check for empty arrays
+    if len(time) == 0:
+        logger.warning(
+            "Validation Warning: Input arrays are empty. This may lead to unexpected behaviour."
+        )
+
+    # Check time monotonicity with a small tolerance for floating-point comparisons
+    if len(time) > 1:
+        # Use a small epsilon for floating-point comparison
+        # np.finfo(time.dtype).eps is the smallest representable positive number such that 1.0 + eps != 1.0
+        # Multiplying by a small factor (e.g., 10) provides a reasonable tolerance.
+        tolerance = np.finfo(time.dtype).eps * 10
+        if not np.all(np.diff(time) > tolerance):
+            # Log the problematic differences for debugging
+            problematic_diffs = np.diff(time)[np.diff(time) <= tolerance]
+            logger.warning(
+                f"Validation Warning: Time array is not strictly monotonic increasing within tolerance {tolerance}. "
+                f"Problematic diffs (first 10): {problematic_diffs[:10]}. This may lead to unexpected behaviour."
+            )
+
+    # Check parameter validity
+    if snr_threshold <= 0:
+        logger.warning(
+            f"Validation Warning: SNR threshold must be positive, got {snr_threshold}. This may lead to unexpected behaviour."
+        )
+
+    if min_event_len <= 0:
+        logger.warning(
+            f"Validation Warning: Minimum event length must be positive, got {min_event_len}. This may lead to unexpected behaviour."
+        )
+
+    if global_noise <= 0:
+        logger.warning(
+            f"Validation Warning: Global noise must be positive, got {global_noise}. This may lead to unexpected behaviour."
+        )
+
+    # Check for NaN/inf values
+    for name, arr in [("time", time), ("signal", signal), ("bg", bg)]:
+        if not np.all(np.isfinite(arr)):
+            logger.warning(
+                f"Validation Warning: {name} array contains NaN or infinite values. This may lead to unexpected behaviour."
+            )
\ No newline at end of file
```
