Skip to content
Snippets Groups Projects
Commit aa3e455d authored by Weiwei Chen's avatar Weiwei Chen
Browse files

[mod]: create ShortTimeFFTFactory class to unit the forward and backward fft in the same instance.

parent c7f7bb0f
No related branches found
No related tags found
No related merge requests found
Pipeline #246141 passed
......@@ -11,6 +11,25 @@ from scipy.signal import ShortTimeFFT
import h5reader
class ShortTimeFFTFactory(object):
def __init__(self, segment_len, sample_rate, fft_mode='twosided'):
hop_len = int(segment_len/2)
dual_win = hann(segment_len, sym=True)
# print(f" NOLA: {check_NOLA(dual_win, segment_len, hop_len)}")
self.ShortTimeFFT_factory = ShortTimeFFT.from_dual(
dual_win, hop_len, sample_rate, fft_mode=fft_mode)
def forward_fft(self, samples):
ffted_samples = self.ShortTimeFFT_factory.stft(samples)
return ffted_samples
def backward_fft(self, samples):
ffted_samples = self.ShortTimeFFT_factory.istft(samples)
return ffted_samples
def generate_complex_samples(shape):
'''
Generate complex samples in given shape
......@@ -70,38 +89,7 @@ def create_interpolater(iv, matrix, k=3):
interpolater = interpolate.make_interp_spline(iv, matrix, k=k)
return interpolater
def perform_short_time_fft(samples, segment_len, sample_rate, direction = 1,
fft_mode='twosided'):
'''
perfrom windowed and segmented fft.
arguments:
samples: time series to be ffted
segment_len: length of each segment
sample_rate: sample rate
direction: 1 for forward, -1 for backward
fft_mode: 'twosided' or ...
returns:
ffted samples in dimension: [-1, channel, time]
'''
hop_len = int(segment_len/2)
dual_win = hann(segment_len, sym=True)
# print(f" NOLA: {check_NOLA(dual_win, segment_len, hop_len)}")
ShortTimeFFT_generator = ShortTimeFFT.from_dual(
dual_win, hop_len, sample_rate, fft_mode=fft_mode)
# print(f"invertible : {ShortTimeFFT_generator.invertible}")
if direction == 1:
ffted_samples = ShortTimeFFT_generator.stft(samples)
else:
ffted_samples = ShortTimeFFT_generator.istft(samples)
return ffted_samples
def generate_interpolated_acm(dataset, acm_path,
frequencies, fft_frequencies):
def generate_interpolated_acm(dataset, acm_path, frequencies, fft_frequencies):
'''
Interpolate the acm through frequencies.
......@@ -127,8 +115,7 @@ def generate_interpolated_acm(dataset, acm_path,
return interpolated_acm
def generate_channelized_samples(length, element_num, segment_len,
sample_rate, origin='same'):
def generate_channelized_samples(length, element_num, fft_function, origin='same'):
'''
Generate channelized samples
......@@ -136,9 +123,8 @@ def generate_channelized_samples(length, element_num, segment_len,
arguments:
length: length of the samples
element_num: number of elements
segment_len: length of each segment for short time FFT
sample_rate: sample rate
same_origin: weather use the same samples for all element
fft_function: function for the forward fft
origin: weather use the same samples for all element
returns:
channelized samples
'''
......@@ -147,16 +133,13 @@ def generate_channelized_samples(length, element_num, segment_len,
if origin =='same':
random_complex = generate_complex_samples([length,])
# shape: (segment_len, -1)
fourier_samples = perform_short_time_fft(
random_complex, segment_len, sample_rate, fft_mode='twosided')
fourier_samples = fft_function(random_complex)
fourier_samples_matrix = np.tile(fourier_samples, [element_num, 1, 1])
elif origin == 'different':
random_complex = generate_complex_samples([element_num, length,])
fourier_samples_matrix = perform_short_time_fft(
random_complex, segment_len, sample_rate, fft_mode='twosided')
fourier_samples_matrix = fft_function(random_complex)
elif hasattr(origin, '__iter__'):
fourier_samples_matrix = perform_short_time_fft(
origin, segment_len, sample_rate, fft_mode='twosided')
fourier_samples_matrix = fft_function(origin)
return fourier_samples_matrix
......@@ -186,21 +169,18 @@ def weight_samples_using_acm(acm, channel_len, input_fourier_samples_matrix):
return fourier_samples_matrix
def convert_channelized_data_to_time_series(fourier_samples_matrix, full_length,
segment_len, element_num, sample_rate):
fft_function):
'''
Convert channelized data to time series
arguments:
fourier_samples_matrix: fourier samples input
full_length: full length of the samples
segment_len: length of each segment for short time FFT
element_num: number of elements
sample_rate: sample rate
fft_function: function for the backward fft
returns:
time series
'''
time_series = perform_short_time_fft(fourier_samples_matrix,
segment_len, sample_rate, fft_mode='twosided', direction = -1)
time_series = fft_function(fourier_samples_matrix)
return time_series[:, :full_length]
def main():
......@@ -240,16 +220,16 @@ def main():
# signal
channel_len = len(fine_frequencies)
segment_len = fft_size
short_time_fft_factory = ShortTimeFFTFactory(segment_len, sample_rate)
interpolated_acm = generate_interpolated_acm(acm_matrix,
signal_path, frequencies, fine_frequencies)
channelized_samples_signal = generate_channelized_samples(output_length,
element_num, segment_len, sample_rate, origin='same')
element_num, short_time_fft_factory.forward_fft, origin='same')
channelized_correlated_signal = weight_samples_using_acm(
interpolated_acm, channel_len, channelized_samples_signal)
correlated_signal = convert_channelized_data_to_time_series(
channelized_correlated_signal, output_length, segment_len,
element_num, sample_rate)
channelized_correlated_signal, output_length,
short_time_fft_factory.backward_fft)
# verification
acm_0 = (channelized_correlated_signal[:, 0, :]
......@@ -265,15 +245,16 @@ def main():
# noise
channel_len = len(fine_frequencies)
segment_len = fft_size
short_time_fft_factory = ShortTimeFFTFactory(segment_len, sample_rate)
interpolated_acm = generate_interpolated_acm(acm_matrix,
t_noise_path, frequencies, fine_frequencies)
channelized_samples_noise = generate_channelized_samples(output_length,
element_num, segment_len, sample_rate, origin='different')
element_num, short_time_fft_factory.forward_fft, origin='different')
channelized_correlated_noise = weight_samples_using_acm(
interpolated_acm, channel_len, channelized_samples_noise)
correlated_noise = convert_channelized_data_to_time_series(
channelized_correlated_noise, output_length, segment_len,
element_num, sample_rate)
channelized_correlated_noise, output_length,
short_time_fft_factory.backward_fft)
# verification
acm_1 = (channelized_correlated_noise[:, 0, :]
......@@ -298,6 +279,7 @@ def main():
rfi_channel_len = fft_size
segment_len = fft_size
short_time_fft_factory = ShortTimeFFTFactory(segment_len, sample_rate)
sinusoidal_freq = frequencies_rfi[0]
complex_sinusoidal = generate_complex_sinusoidal(
[element_num, output_length], sinusoidal_freq, sample_rate)
......@@ -307,12 +289,13 @@ def main():
rfi_path, None, np.tile(frequencies_rfi, fft_size))
channelized_samples_rfi = generate_channelized_samples(output_length,
element_num, segment_len, sample_rate, origin=complex_sinusoidal)
element_num, short_time_fft_factory.forward_fft,
origin=complex_sinusoidal)
channelized_correlated_rfi = weight_samples_using_acm(
interpolated_acm_rfi, rfi_channel_len, channelized_samples_rfi)
correlated_rfi = convert_channelized_data_to_time_series(
channelized_correlated_rfi, output_length, segment_len,
element_num, sample_rate)
channelized_correlated_rfi, output_length,
short_time_fft_factory.backward_fft)
acm_rfi_reconstructed = (channelized_correlated_rfi[:, 5, :]
@ channelized_correlated_rfi[:, 5, :].conj().T / output_length)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment