"""
Wavelet basis functions. (Python/NumPy implementation)
"""

## This code is written by Davide Albanese, <albanese@fbk.eu> and
## Marco Chierici <chierici@fbk.eu>.
## (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## See: Practical Guide to Wavelet Analysis - C. Torrence and G. P. Compo.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.


from numpy import *
import math
import gslpy


# Public API: the Fourier-transformed wavelet generators. Module-level helpers
# such as normalization() and PI2 are not listed, so `from ... import *` does
# not export them.
__all__ = ["morletft", "paulft", "dogft"]


# 2*pi (a full turn in radians); used by normalization() below.
PI2 = pi * 2


def normalization(s, dt):
    """Normalization factor applied to a Fourier-transformed wavelet.

    Input
      * *s*  - scale (scalar or numpy array)
      * *dt* - time step
    Output
      * sqrt(2 * pi * s / dt), the per-scale factor Torrence & Compo use to
        give the transformed wavelet unit energy
    """
    ratio = PI2 * s / dt
    return sqrt(ratio)


def morletft(s, w, w0, dt, norm = True):
    """Fourier transformed Morlet function.

    For every scale s_j and angular frequency w_k this computes

        n_j * pi**(-1/4) * H(w_k) * exp(-(s_j * w_k - w0)**2 / 2)

    where H is the Heaviside step (1 for w > 0, 0 otherwise) and
    n_j = normalization(s_j, dt) = sqrt(2 * pi * s_j / dt) when *norm* is
    True, 1.0 otherwise (Torrence & Compo, Table 1).

    Input
      * *s*    - scales (1d array-like)
      * *w*    - angular frequencies (1d array-like)
      * *w0*   - omega0 (frequency)
      * *dt*   - time step (only used when *norm* is True)
      * *norm* - normalization (True or False)
    Output
      * (normalized) fourier transformed morlet function: a complex128
        array of shape (len(s), len(w)) whose columns with w <= 0 are zero
    """

    n = 1.0
    p = 0.75112554446494251 # pi**(-1.0/4.0)

    s = asarray(s)
    w = asarray(w)

    # Heaviside step: the transform must be exactly zero for w <= 0.
    # Clamping w to 0 instead would leave exp(-w0**2 / 2) there, not 0.
    pos = w > 0
    wavelet = zeros((s.shape[0], w.shape[0]), dtype = complex128)

    for i in range(s.shape[0]):
        if norm:
            n = normalization(s[i], dt)
        wavelet[i, pos] = n * p * exp(-(s[i] * w[pos] - w0)**2 / 2.0)

    return wavelet

    
def paulft(s, w, order, dt, norm = True):
    """Fourier transformed Paul function.

    For every scale s_j and angular frequency w_k this computes

        n_j * 2**m / sqrt(m * (2m - 1)!) * H(w_k) * (s_j w_k)**m * exp(-s_j w_k)

    where m = *order*, H is the Heaviside step (1 for w > 0, 0 otherwise)
    and n_j = normalization(s_j, dt) when *norm* is True, 1.0 otherwise
    (Torrence & Compo, Table 1).

    Input
      * *s*     - scales (1d array-like)
      * *w*     - angular frequencies (1d array-like)
      * *order* - wavelet order (m >= 1)
      * *dt*    - time step (only used when *norm* is True)
      * *norm*  - normalization (True or False)
    Output
      * (normalized) fourier transformed paul function: a complex128 array
        of shape (len(s), len(w)) whose columns with w <= 0 are zero
    """

    n = 1.0
    # (2m - 1)! == Gamma(2m); the stdlib gamma avoids the gslpy dependency.
    p = 2.0**order / math.sqrt(order * math.gamma(2 * order))

    s = asarray(s)
    w = asarray(w)

    # Heaviside step: only positive frequencies carry the wavelet.
    pos = w > 0
    wavelet = zeros((s.shape[0], w.shape[0]), dtype = complex128)

    for i in range(s.shape[0]):
        if norm:
            n = normalization(s[i], dt)
        sw = s[i] * w[pos]
        wavelet[i, pos] = n * p * sw**order * exp(-sw)

    return wavelet


def dogft(s, w, order, dt, norm = True):
    """Fourier transformed DOG (derivative of Gaussian) function.

    For every scale s_j and angular frequency w_k this computes

        n_j * (-(i**m) / sqrt(Gamma(m + 1/2))) * (s_j w_k)**m * exp(-(s_j w_k)**2 / 2)

    where m = *order* and n_j = normalization(s_j, dt) when *norm* is True,
    1.0 otherwise (Torrence & Compo, Table 1). All frequencies, positive and
    negative, are used; no Heaviside step is applied.

    Input
      * *s*     - scales (1d array-like)
      * *w*     - angular frequencies (1d array-like)
      * *order* - wavelet order (derivative order m)
      * *dt*    - time step (only used when *norm* is True)
      * *norm*  - normalization (True or False)
    Output
      * (normalized) fourier transformed DOG function: a complex128 array of
        shape (len(s), len(w))
    """

    n = 1.0
    # ** binds tighter than unary minus, so this is -(i**m). The stdlib
    # gamma replaces the gslpy dependency.
    p = -(0.0 + 1.0j)**order / math.sqrt(math.gamma(order + 0.5))

    s = asarray(s)
    w = asarray(w)

    wavelet = empty((s.shape[0], w.shape[0]), dtype = complex128)
    for i in range(s.shape[0]):
        if norm:
            n = normalization(s[i], dt)
        sw = s[i] * w
        wavelet[i] = n * p * sw**order * exp(-(sw**2 / 2.0))

    return wavelet
