Skip to content

Commit

Permalink
Upgrade to rand v0.6, and added more thorough integration tests (#23)
Browse files Browse the repository at this point in the history
* Upgrade to rand v0.8, and added more thorough integration tests

* Ran fmt, and downgraded to rand 0.6
  • Loading branch information
ejmahler committed Dec 24, 2020
1 parent f94ae37 commit a67f46a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ strength_reduce = "^0.2.1"
transpose = "0.2"

[dev-dependencies]
rand = "0.5"
rand = "0.6"
6 changes: 3 additions & 3 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use num_traits::Zero;

use std::sync::Arc;

use rand::distributions::{Distribution, Normal};
use rand::{SeedableRng, StdRng};
use rand::distributions::{Distribution, Uniform};
use rand::{rngs::StdRng, SeedableRng};

use crate::algorithm::{butterflies, DFT};
use crate::FFT;
Expand All @@ -18,7 +18,7 @@ const RNG_SEED: [u8; 32] = [

pub fn random_signal(length: usize) -> Vec<Complex<f32>> {
let mut sig = Vec::with_capacity(length);
let normal_dist = Normal::new(0.0, 10.0);
let normal_dist = Uniform::new(0.0, 10.0);
let mut rng: StdRng = SeedableRng::from_seed(RNG_SEED);
for _ in 0..length {
sig.push(Complex {
Expand Down
140 changes: 101 additions & 39 deletions tests/accuracy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
//! for a variety of lengths, and test that our FFT algorithm matches our
//! DFT calculation for those signals.

use std::f32;
use std::sync::Arc;

use rustfft::num_complex::Complex;
use num_traits::Float;
use rustfft::num_traits::Zero;
use rustfft::{
algorithm::{Bluesteins, Radix4},
num_complex::Complex,
FFTnum, FFTplanner, FFT,
};

use rand::distributions::{Distribution, Normal};
use rand::{SeedableRng, StdRng};
use rustfft::algorithm::DFT;
use rustfft::{FFTplanner, FFT};
use rand::distributions::{uniform::SampleUniform, Distribution, Uniform};
use rand::{rngs::StdRng, SeedableRng};

/// The seed for the random number generator used to generate
/// random signals. It's defined here so that we have deterministic
Expand All @@ -23,82 +26,141 @@ const RNG_SEED: [u8; 32] = [

/// Returns true if the mean difference in the elements of the two vectors
/// is small
fn compare_vectors(vec1: &[Complex<f32>], vec2: &[Complex<f32>]) -> bool {
fn compare_vectors<T: FFTnum + Float>(vec1: &[Complex<T>], vec2: &[Complex<T>]) -> bool {
assert_eq!(vec1.len(), vec2.len());
let mut sse = 0f32;
let mut sse = T::zero();
for (&a, &b) in vec1.iter().zip(vec2.iter()) {
sse = sse + (a - b).norm();
}
return (sse / vec1.len() as f32) < 0.1f32;
return (sse / T::from_usize(vec1.len()).unwrap()) < T::from_f32(0.1).unwrap();
}

fn fft_matches_dft(signal: Vec<Complex<f32>>, inverse: bool) -> bool {
let mut signal_dft = signal.clone();
let mut signal_fft = signal.clone();
fn fft_matches_control<T: FFTnum + Float>(control: Arc<dyn FFT<T>>, input: &[Complex<T>]) -> bool {
let mut control_input = input.to_vec();
let mut test_input = input.to_vec();

let mut spectrum_dft = vec![Zero::zero(); signal.len()];
let mut spectrum_fft = vec![Zero::zero(); signal.len()];
let mut control_output = vec![Zero::zero(); control.len()];
let mut test_output = vec![Zero::zero(); control.len()];

let mut planner = FFTplanner::new(inverse);
let fft = planner.plan_fft(signal.len());
let mut planner = FFTplanner::new(control.is_inverse());
let fft = planner.plan_fft(control.len());
assert_eq!(
fft.len(),
signal.len(),
control.len(),
"FFTplanner created FFT of wrong length"
);
assert_eq!(
fft.is_inverse(),
inverse,
control.is_inverse(),
"FFTplanner created FFT of wrong direction"
);

fft.process(&mut signal_fft, &mut spectrum_fft);
control.process(&mut control_input, &mut control_output);
fft.process(&mut test_input, &mut test_output);

let dft = DFT::new(signal.len(), inverse);
dft.process(&mut signal_dft, &mut spectrum_dft);

return compare_vectors(&spectrum_dft[..], &spectrum_fft[..]);
return compare_vectors(&test_output, &control_output);
}

fn random_signal(length: usize) -> Vec<Complex<f32>> {
fn random_signal<T: FFTnum + SampleUniform>(length: usize) -> Vec<Complex<T>> {
let mut sig = Vec::with_capacity(length);
let normal_dist = Normal::new(0.0, 10.0);
let dist: Uniform<T> = Uniform::new(T::zero(), T::from_f64(10.0).unwrap());
let mut rng: StdRng = SeedableRng::from_seed(RNG_SEED);
for _ in 0..length {
sig.push(Complex {
re: (normal_dist.sample(&mut rng) as f32),
im: (normal_dist.sample(&mut rng) as f32),
re: (dist.sample(&mut rng)),
im: (dist.sample(&mut rng)),
});
}
return sig;
}

// A cache that makes setup for integration tests faster
struct ControlCache<T: FFTnum> {
fft_cache: Vec<Arc<dyn FFT<T>>>,
}
impl<T: FFTnum> ControlCache<T> {
pub fn new(max_outer_len: usize, inverse: bool) -> Self {
let max_inner_len = (max_outer_len * 2 - 1).checked_next_power_of_two().unwrap();
let max_power = max_inner_len.trailing_zeros() as usize;

Self {
fft_cache: (0..=max_power)
.map(|i| {
let len = 1 << i;
Arc::new(Radix4::new(len, inverse)) as Arc<dyn FFT<_>>
})
.collect(),
}
}

pub fn plan_fft(&self, len: usize) -> Arc<dyn FFT<T>> {
let inner_fft_len = (len * 2 - 1).checked_next_power_of_two().unwrap();
let inner_fft_index = inner_fft_len.trailing_zeros() as usize;
let inner_fft = Arc::clone(&self.fft_cache[inner_fft_index]);
Arc::new(Bluesteins::new(len, inner_fft))
}
}

const TEST_MAX: usize = 1001;

/// Integration tests that verify our FFT output matches the direct DFT calculation
/// for random signals.
#[test]
fn test_fft() {
for len in 1..100 {
fn test_planned_fft_forward_f32() {
let is_inverse = false;
let cache: ControlCache<f32> = ControlCache::new(TEST_MAX, is_inverse);

for len in 1..TEST_MAX {
let control = cache.plan_fft(len);
assert_eq!(control.len(), len);
assert_eq!(control.is_inverse(), is_inverse);

let signal = random_signal(len);
assert!(fft_matches_dft(signal, false), "length = {}", len);
assert!(fft_matches_control(control, &signal), "length = {}", len);
}
}

#[test]
fn test_planned_fft_inverse_f32() {
let is_inverse = true;
let cache: ControlCache<f32> = ControlCache::new(TEST_MAX, is_inverse);

for len in 1..TEST_MAX {
let control = cache.plan_fft(len);
assert_eq!(control.len(), len);
assert_eq!(control.is_inverse(), is_inverse);

//test some specific lengths > 100
for &len in &[256, 768] {
let signal = random_signal(len);
assert!(fft_matches_dft(signal, false), "length = {}", len);
assert!(fft_matches_control(control, &signal), "length = {}", len);
}
}

#[test]
fn test_fft_inverse() {
for len in 1..100 {
fn test_planned_fft_forward_f64() {
let is_inverse = false;
let cache: ControlCache<f64> = ControlCache::new(TEST_MAX, is_inverse);

for len in 1..TEST_MAX {
let control = cache.plan_fft(len);
assert_eq!(control.len(), len);
assert_eq!(control.is_inverse(), is_inverse);

let signal = random_signal(len);
assert!(fft_matches_dft(signal, true), "length = {}", len);
assert!(fft_matches_control(control, &signal), "length = {}", len);
}
}

#[test]
fn test_planned_fft_inverse_f64() {
let is_inverse = true;
let cache: ControlCache<f64> = ControlCache::new(TEST_MAX, is_inverse);

for len in 1..TEST_MAX {
let control = cache.plan_fft(len);
assert_eq!(control.len(), len);
assert_eq!(control.is_inverse(), is_inverse);

//test some specific lengths > 100
for &len in &[256, 768] {
let signal = random_signal(len);
assert!(fft_matches_dft(signal, true), "length = {}", len);
assert!(fft_matches_control(control, &signal), "length = {}", len);
}
}

0 comments on commit a67f46a

Please sign in to comment.