diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/test_resample.py | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/tests/test_resample.py b/tests/test_resample.py new file mode 100644 index 0000000..6b85878 --- /dev/null +++ b/tests/test_resample.py @@ -0,0 +1,98 @@ +"""Test the resample module.""" + +import numpy as np +import pytest + +from transivent.resample import average_downsample, downsample_to_interval + + +def test_average_downsample_basic(): + """Test basic downsampling functionality.""" + # Create test data + t = np.linspace(0, 1, 1000) + x = np.ones(1000) + + # Downsample by factor of 10 + t_down, x_down = average_downsample(t, x, q=10) + + assert len(t_down) == 100 + assert len(x_down) == 100 + assert np.allclose(x_down, 1.0) # Mean should be preserved + assert np.allclose(t_down[1] - t_down[0], t[10] - t[0]) # Check time step + + +def test_average_downsample_q1(): + """Test that q=1 returns unchanged data.""" + t = np.linspace(0, 1, 100) + x = np.random.randn(100) + + t_out, x_out = average_downsample(t, x, q=1) + + np.testing.assert_array_equal(t_out, t) + np.testing.assert_array_equal(x_out, x) + + +def test_average_downsample_invalid_q(): + """Test error handling for invalid downsample factor.""" + t = np.linspace(0, 1, 100) + x = np.random.randn(100) + + with pytest.raises(ValueError, match="q must be >= 1"): + average_downsample(t, x, q=0) + + with pytest.raises(ValueError, match="q must be >= 1"): + average_downsample(t, x, q=-1) + + +def test_average_downsample_short_array(): + """Test downsampling when array is shorter than factor.""" + t = np.linspace(0, 1, 5) + x = np.ones(5) + + with pytest.raises(ValueError, match="Input length .* is less than downsample factor"): + average_downsample(t, x, q=10) + + +def test_downsample_to_interval(): + """Test interval-based downsampling.""" + # 10 kHz data (exact interval) + t = np.arange(10000) * 1e-4 + x = np.random.randn(10000) + + # Downsample to 1 kHz + t_down, x_down = downsample_to_interval(t, x, target_interval=1e-3) + + # Should downsample by factor of 10 + assert len(t_down) == 1000 + assert len(x_down) == 1000 + # Original dt was 1e-4, new dt should be 1e-3 + assert np.allclose(t_down[1] - t_down[0], 1e-3) + + +def test_downsample_to_interval_upsampling(): + """Test error for upsampling (not supported).""" + t = np.linspace(0, 1, 1000) # dt = 0.001 + x = np.random.randn(1000) + + with pytest.raises(ValueError, match="Upsampling not supported"): + downsample_to_interval(t, x, target_interval=5e-4) # Smaller interval + + +def test_average_downsample_preserves_amplitude(): + """Test that amplitude is preserved through averaging.""" + t = np.linspace(0, 1, 1000) + # Create a step function + x = np.concatenate([np.ones(500), 2 * np.ones(500)]) + + # Downsample by factor of 10 + t_down, x_down = average_downsample(t, x, q=10) + + # Check the transition is at the right place + transition_idx = 50 # 500 / 10 + # First part should be ~1, second part should be ~2 + assert np.allclose(x_down[:transition_idx], 1.0) + assert np.allclose(x_down[transition_idx:], 2.0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
\ No newline at end of file |
