TD5 - Wavelets

In this TD, we study the wavelet decomposition using the Daubechies 4 filters.

Hereafter, we compute a multiscale decomposition in the spirit of wavelets using Fourier.

We are looking for a linear operator

\[\begin{aligned} W:\ &\mathbb{C}^{N_0} \to \mathbb{C}^{N_1} \times \ldots \times \mathbb{C}^{N_n}\\ &u_0 \mapsto (u_1,\ldots,u_n)\,, \end{aligned}\]

that is an isometry, which means:

\[ \forall u \in \mathbb{C}^{N_0},\quad \| u \|_2 = \| Wu\|_2\,. \]

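As a quick numerical illustration of this isometry property, here is a minimal 1-D sketch. It assumes the Daubechies 4 filters h and g given later in this TD, periodic boundary conditions, and a single decomposition step (a coarse part and a fine part, each of half length). The helper names `periodic_filter` and `analysis_step` are hypothetical.

```python
import numpy as np

# Daubechies 4 filters as given in the TD, normalised so that ||h||_2 = 1
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))


def periodic_filter(u, f):
    """Periodic filtering y[n] = sum_i f[i] u[(n + i - pc) mod N] (odd-length f assumed)."""
    pc = (len(f) - 1) // 2
    return sum(f[i] * np.roll(u, pc - i) for i in range(len(f)))


def analysis_step(u):
    """One decomposition step: filter by h (coarse) and by g (fine), keep every other sample."""
    return periodic_filter(u, h)[::2], periodic_filter(u, g)[::2]


rng = np.random.default_rng(0)
u0 = rng.standard_normal(64)
coarse_part, fine_part = analysis_step(u0)
energy_in = np.linalg.norm(u0) ** 2
energy_out = np.linalg.norm(coarse_part) ** 2 + np.linalg.norm(fine_part) ** 2
print(energy_in, energy_out)  # equal up to rounding errors
```

The two printed energies should agree up to rounding errors, even though each output part has only half as many samples as the input.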
from __future__ import division
#from scipy.misc import imsave
import numpy as np
import scipy as sp
import scipy.fftpack as ft
%pylab inline
%matplotlib inline
Populating the interactive namespace from numpy and matplotlib
# load an image
import imageio as imio
colored_image = imio.imread('grumpycat.png')
n = 220
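# grey level: weighted sum of the R, G, B channels (the image is assumed to be RGBA; alpha gets weight 0)
image = np.sum(colored_image * [0.21, 0.72, 0.07, 0], axis=-1)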
plt.figure(figsize = (10,10))
plt.subplot(1,2,1)
plt.gray()
cropped = image[n:n+1024,n:n + 1024]
plt.title(" Grey Level Image " + str(np.shape(cropped)))

plt.imshow(cropped)
print(np.mean(cropped),np.std(cropped),(np.min(cropped)))
print(image.shape)
image0 = np.copy(image)
118.3223597049713 46.019983171402444 28.049999999999997
(1720, 1720)
[Figure: Grey Level Image (1024, 1024)]

What are the conditions for the filters to be admissible wavelet filters? The first condition, on the low-pass filter \(h\), is

\[ \begin{cases} |\hat{h}(\omega)|^2 + |\hat{h}(\omega + \frac 12)|^2 = 2\\ \hat{h}(0) = \sqrt{2}\,. \end{cases} \]

The condition on the high-pass filter is the following:

\[ \begin{cases} |\hat{g}(\omega)|^2 + |\hat{g}(\omega + \frac 12)|^2 = 2\\ \hat{g}(\omega)\, \hat{h}^*(\omega) + \hat{g}(\omega + \frac 12)\, \hat{h}^*(\omega + \frac 12) = 0\,. \end{cases} \]

These conditions can be obtained by writing the orthogonality relations between the scaling function and the wavelet function at scales 0 and 1. The computation is similar to the one presented in the previous lecture.
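These conditions can be checked numerically for the Daubechies 4 filters given below. A minimal sketch, assuming the convention \(\hat{f}(\omega) = \sum_k f[k] e^{-2i\pi k\omega}\) (frequencies of period 1, consistent with the shift by \(\frac 12\) in the formulas above); the helper `dtft` is hypothetical.

```python
import numpy as np

# Daubechies 4 filters, as defined in the TD
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))


def dtft(f, omega):
    """f_hat(omega) = sum_k f[k] exp(-2 i pi k omega), evaluated at each omega."""
    k = np.arange(len(f))
    return np.exp(-2j * np.pi * np.outer(omega, k)) @ f


omega = np.linspace(0, 1, 257)
H, H2 = dtft(h, omega), dtft(h, omega + 0.5)
G, G2 = dtft(g, omega), dtft(g, omega + 0.5)

err_h = np.max(np.abs(np.abs(H) ** 2 + np.abs(H2) ** 2 - 2))    # low-pass condition
err_h0 = abs(dtft(h, np.array([0.0]))[0] - np.sqrt(2))          # h_hat(0) = sqrt(2)
err_g = np.max(np.abs(np.abs(G) ** 2 + np.abs(G2) ** 2 - 2))    # high-pass condition
err_hg = np.max(np.abs(G * np.conj(H) + G2 * np.conj(H2)))      # orthogonality of g and h
print(err_h, err_h0, err_g, err_hg)  # all at rounding level
```

The condition \(\hat{h}(0) = \sqrt{2}\) is equivalent to \(\sum_k h[k] = \sqrt{2}\), which holds here together with \(\|h\|_2 = 1\).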

Below are given the Daubechies 4 filters: the filter h is the low-pass filter and the filter g is the high-pass filter. They are designed so that the first two moments (of order 0 and 1) of the wavelet function vanish. Thus, they encode more efficiently the parts of the image that behave locally like a polynomial of degree at most 1: the corresponding fine-scale coefficients are close to zero.
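The vanishing moments can be checked directly on g: the discrete moments \(\sum_k k^m g[k]\) should vanish for \(m = 0, 1\) but not for \(m = 2\), so filtering by g annihilates sampled affine signals but not sampled parabolas. A minimal sketch; it uses np.convolve with mode="valid" so that the boundaries play no role.

```python
import numpy as np

# Daubechies 4 filters as given in the TD
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))

# discrete moments sum_k k^m g[k] for m = 0, 1, 2
k = np.arange(len(g))
moments = [np.sum(k ** m * g) for m in range(3)]
print(moments)  # orders 0 and 1 vanish up to rounding, order 2 does not

n = np.arange(64, dtype=float)
ramp = 3.0 + 0.5 * n   # polynomial of degree 1
parabola = n ** 2      # polynomial of degree 2
# mode="valid" keeps only the outputs that do not involve the boundaries
ramp_response = np.convolve(ramp, g, mode="valid")
parabola_response = np.convolve(parabola, g, mode="valid")
print(np.abs(ramp_response).max(), np.abs(parabola_response).max())
```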

h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)  # make sure the filter has norm 1
u = (-1.0) ** np.arange(1, len(h))  # alternating signs -1, +1, -1, +1
g = np.concatenate(([0], h[-1:0:-1] * u))  # g[k] = (-1)^k h[5 - k] for k = 1, ..., 4
print("low-pass filter associated with the scaling function",h)
print("high-pass filter associated with the wavelet function",g)
low-pass filter associated with the scaling function [ 0.          0.48296291  0.8365163   0.22414387 -0.12940952]
high-pass filter associated with the wavelet function [ 0.          0.12940952  0.22414387 -0.8365163   0.48296291]

Below, we use periodic boundary conditions for the filtering, which is the simplest choice. More involved treatments of the boundaries have been developed.

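# Circular (periodic) filtering of x by the filter h along dimension d (0: rows index, 1: columns index).
# With this index convention, y[n] = sum_i h[i] x[(n + i - pc) mod N], i.e. a centred circular correlation.
def cconv(x, h, d):
    if d == 1:
        # filter along axis 1 by transposing, filtering along axis 0 and transposing back
        return np.transpose(cconv(np.transpose(x), h, 0))
    y = np.zeros(x.shape)
    p = len(h)
    pc = int(round((p - 1) / 2))  # centre of the filter
    for i in range(0, p):
        y = y + h[i] * np.roll(x, pc - i, axis=0)
    return y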
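With periodic boundary conditions, this filtering is a circular operation, so it is diagonalised by the discrete Fourier transform; this is the link with the Fourier point of view mentioned at the beginning. A minimal check, assuming NumPy's np.fft module, that the np.roll implementation matches an FFT-based one; `cconv_fft` is a hypothetical helper, and cconv is copied so that the block runs on its own.

```python
import numpy as np


def cconv(x, h, d):
    """Copy of the TD's periodic filtering (np.roll version)."""
    if d == 1:
        return np.transpose(cconv(np.transpose(x), h, 0))
    y = np.zeros(x.shape)
    p = len(h)
    pc = int(round((p - 1) / 2))
    for i in range(p):
        y = y + h[i] * np.roll(x, pc - i, axis=0)
    return y


def cconv_fft(x, h):
    """Hypothetical helper: the same periodic filtering along axis 0, computed with the FFT."""
    N, p = x.shape[0], len(h)
    pc = int(round((p - 1) / 2))
    c = np.zeros(N)
    for i in range(p):
        c[(pc - i) % N] += h[i]  # circular kernel: c[pc - i] = h[i], indices taken mod N
    C = np.fft.fft(c)[:, None]   # transfer function, broadcast over the columns of x
    return np.real(np.fft.ifft(np.fft.fft(x, axis=0) * C, axis=0))


rng = np.random.default_rng(0)
x = rng.standard_normal((32, 8))
f = rng.standard_normal(5)
print(np.allclose(cconv(x, f, 0), cconv_fft(x, f)))
```

For large filters the FFT version is cheaper; for the 5-tap filters of this TD, the np.roll loop is simpler and fast enough.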

Q1: Test that the subsampling and upsampling operators defined below are adjoint operators.

That is, check on random vectors that the following equality holds: \(\langle \text{subsampling}(x),y \rangle = \langle \text{upsampling}(y),x \rangle\).

def subsampling(x, d):
    p = 2
    if d == 0:
        y = x[::p, :]
    elif d == 1:
        y = x[:, ::p]
    return y


def upsampling(x, d):
    p = 2
    s = x.shape
    if d == 0:
        y = np.zeros((p * s[0], s[1]))
        y[::p, :] = x
    elif d == 1:
        y = np.zeros((s[0], p * s[1]))
        y[:, ::p] = x
    return y

def reverse(x):
    return x[::-1]

N = 16
EPSILON = 1e-12
x = np.random.rand(N, N)
y = np.random.rand(N // 2, N)
d = 0
# np.vdot flattens its arguments and returns their dot product.
print(np.vdot(subsampling(x, d), y))
print(np.vdot(upsampling(y, d), x))
print(np.abs(np.vdot(subsampling(x, d), y) - np.vdot(upsampling(y, d), x)) < EPSILON)
34.52433096707686
34.52433096707686
True
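The inverse transform of Q5 also relies on the fact that the adjoint of periodic filtering by a filter is periodic filtering by the reversed filter. The same random test as above can check it. This sketch copies cconv and assumes an odd-length filter, as is the case for h and g (5 taps including the leading zero).

```python
import numpy as np


def cconv(x, h, d):
    """Copy of the TD's periodic filtering, so that this block runs on its own."""
    if d == 1:
        return np.transpose(cconv(np.transpose(x), h, 0))
    y = np.zeros(x.shape)
    p = len(h)
    pc = int(round((p - 1) / 2))
    for i in range(p):
        y = y + h[i] * np.roll(x, pc - i, axis=0)
    return y


rng = np.random.default_rng(1)
f = rng.standard_normal(5)        # any odd-length filter
x = rng.standard_normal((16, 16))
y = rng.standard_normal((16, 16))
for d in (0, 1):
    lhs = np.vdot(cconv(x, f, d), y)        # <filter(x), y>
    rhs = np.vdot(x, cconv(y, f[::-1], d))  # <x, filter_reversed(y)>
    print(d, abs(lhs - rhs))
```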

To extend the wavelet transform to 2D images, simply apply the 1D wavelet transform successively along each dimension of the image.

First choose, for instance, the lines of the image and perform the wavelet transform along them, thus obtaining \(\text{low scale}(0)\) and \(\text{high scale}(0)\) (0 is the index for the lines and 1 for the columns).

Then apply the wavelet transform again in the other direction (the columns) to each of the two parts of the transformed signal.

This gives \(\text{high scale}(1)(\text{low scale}(0))\), \(\text{high scale}(1)(\text{high scale}(0))\), \(\text{low scale}(1)(\text{high scale}(0))\) and \(\text{low scale}(1)(\text{low scale}(0))\).

The one-step wavelet transform of an image therefore has four parts.

They are usually arranged in an array of the form \(\begin{pmatrix} LL & HL \\ LH & HH\end{pmatrix}\), where \(L\) and \(H\) stand for low and high scale respectively.
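Because one step acts separately on rows and columns, it can also be written with a single \(N \times N\) matrix \(M\) (coarse rows on top of fine rows): the one-step 2-D transform of \(X\) is \(M X M^T\), laid out as \(\begin{pmatrix} LL & HL \\ LH & HH\end{pmatrix}\) with the conventions of W_0 below. A minimal sketch, assuming the TD's filters and periodic filtering; `filter0` and `analysis_matrix` are hypothetical helpers, and building dense matrices is only sensible for small images. The orthogonality of \(M\) is what makes the transform an isometry.

```python
import numpy as np

# Daubechies 4 filters as given in the TD
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))


def filter0(x, f):
    """Periodic filtering along axis 0, same convention as cconv(x, f, 0) in the TD."""
    pc = int(round((len(f) - 1) / 2))
    return sum(f[i] * np.roll(x, pc - i, axis=0) for i in range(len(f)))


def analysis_matrix(N):
    """Hypothetical helper: N x N matrix of one 1-D step (coarse rows on top, fine rows below)."""
    I = np.eye(N)
    return np.vstack((filter0(I, h)[::2], filter0(I, g)[::2]))


N = 16
M = analysis_matrix(N)
rng = np.random.default_rng(0)
X = rng.standard_normal((N, N))
Y = M @ X @ M.T  # one 2-D step, laid out as [[LL, HL], [LH, HH]]

print(np.allclose(M @ M.T, np.eye(N)))       # M is orthogonal
print(np.linalg.norm(X), np.linalg.norm(Y))  # same Frobenius norm
print(np.allclose(M.T @ Y @ M, X))           # the inverse is the transpose
```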

Q2: Implement below one step of the transform W of the lecture (coarse- and fine-scale transforms) and test it on your image.

#https://pywavelets.readthedocs.io/en/latest/
#http://gwyddion.net/documentation/user-guide-en/wavelet-transform.html

def coarse(image, d):
    # low-pass (coarse-scale) part along dimension d:
    # periodic filtering of the image by h along d, then keep one sample out of two along d
    return subsampling(cconv(image, h, d), d)


def fine(image, d):
    # high-pass (fine-scale) part along dimension d:
    # periodic filtering of the image by g along d, then keep one sample out of two along d
    return subsampling(cconv(image, g, d), d)

def W_0(image,wavelet_coeff=None):
    # start with the first dimension index 0 to get coarse and fine description along this axis
    #lowscale and highscale
    l0 = coarse(image, 0)
    h0 = fine(image, 0)
    # process again the previous results to get coarse and fine description along the axis of index 1.
    # you thus should have 4 sub-images of the same size which is the initial image size divided by 4 (divide columns and rows by 2).
    LL = coarse(l0, 1)
    HL = fine(l0, 1)
    LH = coarse(h0, 1)
    HH = fine(h0, 1)
    # arrange the four sub-images as [[LL, HL], [LH, HH]] in a single array called result
    L = np.concatenate((LL, LH), axis=0)
    H = np.concatenate((HL, HH), axis=0)
    result = np.concatenate((L, H), axis=1)
    return result

print(W_0(image))
[[ 2.01740000e+02  2.01740000e+02  2.01740000e+02 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 2.01740000e+02  2.01740000e+02  2.01740000e+02 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 2.01740000e+02  2.01740000e+02  2.01740000e+02 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 ...
 [-1.00485917e-14 -1.00485917e-14 -1.00485917e-14 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-1.00485917e-14 -1.00485917e-14 -1.00485917e-14 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-1.00485917e-14 -1.00485917e-14 -1.00485917e-14 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]]

Q3: Implement the full fast wavelet transform for a given number of scales.

def WaveletTransform(image, iterations=1):
    ## iterate the one-step wavelet transform implemented above.
    if iterations == 0:
        return image
    # first step: W_0 on the full image
    waveletTransform = W_0(image)
    # the next steps only act on the top-left (LL) quarter
    size = waveletTransform.shape[0] // 2
    waveletTmp = waveletTransform[0:size, 0:size].copy()
    waveletTransform[0:size, 0:size] = WaveletTransform(waveletTmp, iterations - 1)
    return waveletTransform
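For comparison, here is a sketch of the multiscale recursion written with the matrix form of one step \(Y = M X M^T\) (an alternative formulation, not the TD's W_0). Each new step acts only on the current top-left (LL) block, whose size is halved at every scale. It assumes square images whose size is divisible by \(2^J\) and remains large compared with the filter length; `filter0`, `analysis_matrix` and `fwt` are hypothetical helpers.

```python
import numpy as np

# Daubechies 4 filters as given in the TD
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))


def filter0(x, f):
    """Periodic filtering along axis 0, same convention as cconv(x, f, 0) in the TD."""
    pc = int(round((len(f) - 1) / 2))
    return sum(f[i] * np.roll(x, pc - i, axis=0) for i in range(len(f)))


def analysis_matrix(N):
    """N x N matrix of one 1-D step (coarse rows on top, fine rows below)."""
    I = np.eye(N)
    return np.vstack((filter0(I, h)[::2], filter0(I, g)[::2]))


def fwt(X, J):
    """J-scale transform: each step only acts on the current top-left block."""
    Y = np.array(X, dtype=float)
    n = Y.shape[0]
    for _ in range(J):
        M = analysis_matrix(n)
        Y[:n, :n] = M @ Y[:n, :n] @ M.T
        n //= 2  # the next step works on the LL quarter only
    return Y


rng = np.random.default_rng(0)
X = rng.standard_normal((64, 64))
Y3 = fwt(X, 3)
print(np.linalg.norm(X), np.linalg.norm(Y3))              # energy is preserved
print(np.allclose(Y3[32:, 32:], fwt(X, 1)[32:, 32:]))     # finest HH band untouched by later steps
```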

Q4: Below are functions to obtain a better plot of the results. Using the functions below, plot the wavelet transform.

def rescale(f,a=0,b=1):
    ## linearly rescale f to the interval [a, b].
    rang = f.max() - f.min()
    g = (f - f.min()).copy()
    if rang > 0:
        g = g / rang
    return a + g*(b-a)

def rescaleWav(A):
    ## function dedicated to a proper rescale of wavelet representation (particular case of the previous one).
    sup = abs(A).max()
    B = A.copy()
    if sup > 0:
        B = .5 + .5 * A / sup
    return B
    
def PlotWavelet(wavelet, iterations=1):
    ## Plot the wavelet representation, need to specify the number of iterations (or scales at which the FWT is computed).
    n = wavelet.shape[1]
    Jmax = int(np.log2(n)) - 1
    Jmin = Jmax - iterations + 1
    
    U = wavelet.copy()
    for j in np.arange(Jmax, Jmin - 1, -1):
        a, b = 2 ** j, 2 ** (j + 1)
        U[:a, a:b] = rescaleWav(U[:a, a:b])    # top-right detail band
        U[a:b, :a] = rescaleWav(U[a:b, :a])    # bottom-left detail band
        U[a:b, a:b] = rescaleWav(U[a:b, a:b])  # bottom-right detail band
    # coarse scale (top-left block left after the last, coarsest, step)
    U[:a, :a] = rescale(U[:a, :a])
    # plot underlying image
    plt.imshow(U)
    # display crosses
    for j in np.arange(Jmax, Jmin - 1, -1):
        plt.plot([0, 2 ** (j + 1)], [2 ** j, 2 ** j], 'b')
        plt.plot([2 ** j, 2 ** j], [0, 2 ** (j + 1)], 'b')
    return U

plt.figure(figsize=(10,10))
PlotWavelet(WaveletTransform(cropped, 1))
array([[0.39351842, 0.32848455, 0.32848455, ..., 0.49981076, 0.49954092,
        0.39466087],
       [0.32848455, 0.32848455, 0.32848455, ..., 0.5       , 0.5       ,
        0.5       ],
       [0.32848455, 0.32848455, 0.32848455, ..., 0.5       , 0.5       ,
        0.5       ],
       ...,
       [0.50051351, 0.5       , 0.5       , ..., 0.50367352, 0.50012416,
        0.49946431],
       [0.49859707, 0.5       , 0.5       , ..., 0.50454897, 0.49586309,
        0.50330812],
       [0.38217186, 0.5       , 0.5       , ..., 0.49936489, 0.49926823,
        0.63154009]])
[Figure: one-scale wavelet transform of the cropped image]

Q5: The wavelet transform is an orthogonal transformation so its inverse is given by its transpose (also called adjoint). Use the two operators defined below to implement the inverse of the wavelet transform.

def AdjointCoarse(image,d):
    adjoint_h = reverse(h)
    return cconv(upsampling(image,d),adjoint_h,d)
def AdjointFine(image,d):
    adjoint_g = reverse(g)
    return cconv(upsampling(image,d),adjoint_g,d)
# Hereafter, you just implement the inverse of a one-step wavelet transform.
def IW_0(wavelet):
    # split the array into its four sub-images [[LL, HL], [LH, HH]]
    n = np.shape(wavelet)[0]//2
    temp_hc_vc = wavelet[0:n,0:n]
    temp_hc_vf = wavelet[0:n,n:]
    temp_hf_vc = wavelet[n:,0:n]
    temp_hf_vf = wavelet[n:,n:]

    # undo the step along dimension 1: rebuild the coarse and fine parts along dimension 0
    coarse0 = AdjointCoarse(temp_hc_vc, 1) + AdjointFine(temp_hc_vf, 1)
    fine0 = AdjointCoarse(temp_hf_vc, 1) + AdjointFine(temp_hf_vf, 1)
    # undo the step along dimension 0
    result = AdjointCoarse(coarse0, 0) + AdjointFine(fine0, 0)
    return result

# once the function above is implemented, the full wavelet inversion reads like this:
def IWT(wavelet,iterations = 3):
    image = np.copy(wavelet)
    for i in reverse(range(iterations)):
        cut = np.shape(wavelet)[0]//(2**i)
        image[0:cut,0:cut] = IW_0(image[0:cut,0:cut])
    return image
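The inverse can be sketched in matrix form as well: since one step is \(Y = M X M^T\) with \(M\) orthogonal, its inverse is \(X = M^T Y M\), and the steps are undone from the coarsest block up to the full array, in the same order as IWT. This is a hypothetical self-contained sketch (`filter0`, `analysis_matrix`, `fwt` and `ifwt` are illustrative names), which also shows that using a different number of scales for the transform and its inverse does not recover the image.

```python
import numpy as np

# Daubechies 4 filters as given in the TD
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))


def filter0(x, f):
    """Periodic filtering along axis 0, same convention as cconv(x, f, 0) in the TD."""
    pc = int(round((len(f) - 1) / 2))
    return sum(f[i] * np.roll(x, pc - i, axis=0) for i in range(len(f)))


def analysis_matrix(N):
    I = np.eye(N)
    return np.vstack((filter0(I, h)[::2], filter0(I, g)[::2]))


def fwt(X, J):
    Y = np.array(X, dtype=float)
    n = Y.shape[0]
    for _ in range(J):
        M = analysis_matrix(n)
        Y[:n, :n] = M @ Y[:n, :n] @ M.T
        n //= 2
    return Y


def ifwt(Y, J):
    """Undo fwt: start with the coarsest (smallest) block, as IWT does."""
    X = np.array(Y, dtype=float)
    N = X.shape[0]
    for j in reversed(range(J)):
        n = N // 2 ** j
        M = analysis_matrix(n)
        X[:n, :n] = M.T @ X[:n, :n] @ M  # inverse of one step = transpose
    return X


rng = np.random.default_rng(0)
X = rng.standard_normal((64, 64))
print(np.abs(ifwt(fwt(X, 3), 3) - X).max())  # rounding level
print(np.abs(ifwt(fwt(X, 2), 1) - X).max())  # mismatched number of scales: large error
```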

Q6: Test on an image that the implemented inverse is correct. That is, check on an image \(x\) that \(IWT(\text{WaveletTransform(x)}) = x\) up to numerical errors.

waveletTransform = WaveletTransform(cropped, 1) 
iwv = IWT(waveletTransform,1)

# Plot
plt.figure(figsize=(10,10))

# Original
plt.subplot(1, 3, 1)
plt.title("Grey Level Image " + str(np.shape(cropped)))
plt.imshow(cropped)

# Wavelet transform
plt.subplot(1, 3, 2)
plt.title("Wavelet transform")
PlotWavelet(waveletTransform)

# Inverted
plt.subplot(1, 3, 3)
plt.title("Inverse")
plt.imshow(iwv)

waveletTransform = WaveletTransform(cropped, 2) 
iwv = IWT(waveletTransform,1)

# Plot
plt.figure(figsize=(10,10))

# Original
plt.subplot(1, 3, 1)
plt.title("Grey Level Image " + str(np.shape(cropped)))
plt.imshow(cropped)

# Wavelet transform
plt.subplot(1, 3, 2)
plt.title("Wavelet transform")
PlotWavelet(waveletTransform, 2)

# Inverted
plt.subplot(1, 3, 3)
plt.title("Inverse")
plt.imshow(iwv)

print('diff between original and recovered:', np.abs(cropped-iwv).sum())
diff between original and recovered: 134730229.40767026
[Figures: Grey Level Image (1024, 1024), Wavelet transform and Inverse, for 1 and then 2 iterations]

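The large difference above is expected: the transform was computed with 2 iterations while the inverse was called with only 1, so only one inverse step is applied and the result does not match the original image. The number of iterations passed to IWT must be the same as the one used in WaveletTransform.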

Use of wavelets for compression.

Q7: Plot the histogram of the wavelet coefficients, by modifying the code below, to see the distribution of small and large coefficients. You should see that most of the coefficients are close to \(0\) and only a few are large. This suggests that lossy compression should perform well on this data.

temp = waveletTransform.flatten()
plt.hist(temp, bins=np.linspace(-2, 200, 100))
temp.sort()
print(len(temp))
print(temp[500])
1048576
-143.2512625397739
[Figure: histogram of the wavelet coefficients]

Q8: Keep the 20% largest coefficients (in absolute value) of the wavelet transform, set all the remaining coefficients to \(0\), and show the resulting approximating image.

Do the same with the Fourier transform by calling the \(fft2\) function, and compare the two approximations visually and with the signal-to-noise ratio (SNR) function implemented below.

import scipy.fftpack as sf

def snr(x, y):
    s =  np.linalg.norm(x - y)
    if s == 0:
        return "Equal inputs"
    return 20 * np.log10(np.linalg.norm(x) /s)

iwv = real(sf.fft2(waveletTransform))

# Plot
plt.figure(figsize=(10,10))

# Original
plt.subplot(1, 3, 1)
plt.title("Grey Level Image " + str(np.shape(cropped)))
plt.imshow(cropped)

# Wavelet transform
plt.subplot(1, 3, 2)
plt.title("Wavelet transform")
PlotWavelet(waveletTransform)

# FFT2
plt.subplot(1, 3, 3)
plt.title("FFT2")
plt.imshow(iwv)
[Figure: Grey Level Image (1024, 1024), Wavelet transform, FFT2]
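Here is a sketch of the Q8 experiment under stated assumptions. Since the grumpycat image is not available here, it uses a synthetic 64×64 image (a smooth pattern plus a sharp diagonal edge); it uses the matrix form of the transform rather than the TD's WaveletTransform/IWT, NumPy's np.fft.fft2 rather than scipy.fftpack, and a simplified copy of snr. The helpers `keep_largest`, `fwt` and `ifwt` are hypothetical. Which method gives the better SNR depends on the image, so no result is assumed here.

```python
import numpy as np

# Daubechies 4 filters as given in the TD
h = np.array([0, .482962913145, .836516303738, .224143868042, -.129409522551])
h = h / np.linalg.norm(h)
g = np.concatenate(([0], h[-1:0:-1] * (-1.0) ** np.arange(1, len(h))))


def filter0(x, f):
    pc = int(round((len(f) - 1) / 2))
    return sum(f[i] * np.roll(x, pc - i, axis=0) for i in range(len(f)))


def analysis_matrix(N):
    I = np.eye(N)
    return np.vstack((filter0(I, h)[::2], filter0(I, g)[::2]))


def fwt(X, J):
    Y = np.array(X, dtype=float)
    n = Y.shape[0]
    for _ in range(J):
        M = analysis_matrix(n)
        Y[:n, :n] = M @ Y[:n, :n] @ M.T
        n //= 2
    return Y


def ifwt(Y, J):
    X = np.array(Y, dtype=float)
    N = X.shape[0]
    for j in reversed(range(J)):
        n = N // 2 ** j
        M = analysis_matrix(n)
        X[:n, :n] = M.T @ X[:n, :n] @ M
    return X


def keep_largest(c, fraction=0.2):
    """Keep the `fraction` largest coefficients in absolute value, set the others to 0."""
    thr = np.quantile(np.abs(c), 1 - fraction)
    return np.where(np.abs(c) >= thr, c, 0)


def snr(x, y):
    # simplified copy of the TD's snr (in dB)
    return 20 * np.log10(np.linalg.norm(x) / np.linalg.norm(x - y))


# synthetic test image (assumption): smooth pattern plus a sharp diagonal edge
N = 64
xx, yy = np.meshgrid(np.arange(N), np.arange(N))
img = 100 + 50 * np.sin(2 * np.pi * xx / N) * np.cos(2 * np.pi * yy / N) + 80 * (xx > yy)

J = 3
rec_wav = ifwt(keep_largest(fwt(img, J)), J)
rec_fft = np.real(np.fft.ifft2(keep_largest(np.fft.fft2(img))))
print("wavelet SNR (dB):", snr(img, rec_wav))
print("Fourier SNR (dB):", snr(img, rec_fft))
```

Both transforms are orthogonal (up to a constant factor for the FFT), so in each case the squared error equals the energy of the discarded coefficients; the comparison is therefore really about which basis concentrates the energy of the image in fewer coefficients.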

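Finally, note that one can look at a wavelet function by applying the inverse one-step transform IW_0 to a Dirac mass, i.e. an array with a single nonzero coefficient. Run and play with the code below: with n = 32, the Dirac at position (20, 20) lies in the bottom-right (HH) block, so the result is the diagonal wavelet.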

n = 32
dirac = np.zeros((n,n))
dirac[20,20] = 1
ondelette = IW_0(dirac)
plt.imshow(ondelette)
[Figure: wavelet obtained from the Dirac mass]