import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import readligo as rl
import template

# -- Read in the file
# rl.loaddata is unpacked into the strain samples, the matching time stamps,
# and a mapping of data-quality channels keyed by channel name.
DATA_FILE = 'H-H1_LOSC_4_V1-817061888-4096.hdf5'
strain, time, dq = rl.loaddata(DATA_FILE)

# -- Sample spacing (from the first two time stamps) and sampling rate
dt = time[1] - time[0]
fs = 1.0 / dt

# -- Show which data-quality channels are available
channel_names = list(dq.keys())
print(channel_names)

# -- Plot the CBC injection channel
# The HW_CBC trace is shifted up by 2 so that it is drawn above the DEFAULT
# trace (presumably both are 0/1 flags -- confirm against the data file).
dq_fig, dq_ax = plt.subplots()
dq_ax.plot(dq['HW_CBC'] + 2, label='HW_CBC')
dq_ax.plot(dq['DEFAULT'], label='Good Data')
dq_ax.set_xlabel('Time since ' + str(time[0]) + ' (s)')
dq_ax.axis([0, 4096, -1, 6])
dq_ax.legend()

# -- Get the injection segment
# Only the first segment returned for the HW_CBC channel is analysed; the
# same slice selects both the strain samples and their time stamps.
inj_slice = rl.dq_channel_to_seglist(dq['HW_CBC'])[0]
inj_data = strain[inj_slice]
inj_time = time[inj_slice]

# -- Get the noise segment
# Use up to 8 injection-lengths of data immediately before the injection.
# The start index is clamped at 0: a negative slice start would be counted
# from the end of the array and silently select the wrong samples (or none).
noise_start = max(inj_slice.start - 8*len(inj_data), 0)
noise_slice = slice(noise_start, inj_slice.start)
noise_data = strain[noise_slice]

# -- How long is the segment?
seg_time = len(inj_data) / fs
print("The injection segment is {0} s long".format(seg_time))

# -- Make a frequency domain template
# NOTE(review): the arguments are presumably (sample rate in Hz, duration in
# seconds, mass 1, mass 2) -- confirm against template.createTemplate.
temp, temp_freq = template.createTemplate(4096, seg_time, 10, 10)

# -- LIGO noise is very high below 25 Hz, so we won't search there
below_search_band = temp_freq < 25
temp[below_search_band] = 0

# -- Plot the template magnitude on log-log axes
temp_fig, temp_ax = plt.subplots()
temp_ax.loglog(temp_freq, np.abs(temp))
temp_ax.axis([10, 1000, 1e-22, 1e-19])
temp_ax.set_xlabel("Frequency (Hz)")
temp_ax.set_ylabel("Template value (Strain/Hz)")
temp_ax.grid()

# -- Window and FFT the data
# A Blackman window is applied to the segment before the one-sided real FFT.
window = np.blackman(inj_data.size)
data_fft = np.fft.rfft(inj_data*window)

# -- Take PSD of noise segment
# NFFT is set to the injection length so the PSD has one value per rfft bin
# and can be divided into the spectrum element by element.
Pxx, psd_freq = mlab.psd(noise_data, Fs=fs, NFFT=len(inj_data))

# -- Multiply the data and the template, weighted by the PSD
# Plain np.conjugate is used instead of np.ma.conjugate: the inputs are
# ordinary arrays and nothing downstream reads a mask, so routing the
# product and division through masked-array arithmetic served no purpose.
integrand = data_fft*np.conjugate(temp)/Pxx

# -- Zero pad before IFFT
# The one-sided spectrum is extended with zeros up to the segment length.
num_zeros = len(inj_data) - len(data_fft)
padded_int = np.append(integrand, np.zeros(num_zeros))

# -- Inverse Fourier Transform to get the matched filter result in time domain
z = 4*np.fft.ifft(padded_int)

# -- Calculate sigma for normalization
# sigma^2 = 4 * df * sum(|template|^2 / PSD), where df is the spacing
# between adjacent PSD frequency bins.
df = psd_freq[1] - psd_freq[0]
kernal = np.abs(temp)**2 / Pxx
sig_sqr = 4*np.sum(kernal)*df
sigma = np.sqrt(sig_sqr)

# -- Calculate the expected SNR, using the known distance in Mpc
injection_distance_mpc = 25
expected_SNR = sigma / injection_distance_mpc

# -- Construct inverse window
# Multiplying by 1/window undoes the taper applied before the FFT.
inv_win = (1.0 / window)

# -- reject 20 seconds on edges due to window edge effects
# The edge length in samples is derived from the measured sampling rate fs
# rather than a hard-coded 4096 Hz, so it stays 20 s for any sample rate.
edge_samples = int(round(20*fs))
inv_win[:edge_samples] = 0
inv_win[-edge_samples:] = 0

# -- Calculate the SNR, rho, for each time offset
rho = abs(z) / sigma * inv_win

# -- Plot rho as a function of time
# Only every 8th sample is drawn; the x axis is measured from the first
# time stamp of the file.
snr_fig, snr_ax = plt.subplots()
seconds_since_start = inj_time[::8] - time[0]
snr_ax.plot(seconds_since_start, rho[::8])
snr_ax.set_xlabel("Seconds since GPS {0:.0f}".format(time[0]))
snr_ax.set_ylabel("SNR")

# -- Find which time offset gives the maximum value of SNR
# found_time holds the time stamp of every sample that attains the peak;
# the first one is reported below.
snr = rho.max()
peak_indices = np.flatnonzero(rho == snr)
found_time = inj_time[peak_indices]

# -- Report the results
print("\n  --- Printing Results ---")
print("Expected to find SNR {0:.1f}".format(expected_SNR))
print("Recovered SNR {0:.1f}".format(snr))
print("Recovered time GPS {0:.1f}".format(found_time[0]))
