import scipy as sp
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as sla
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D
# Surface plot of the values L(I, J) computed below over an n x n grid of
# integer mode indices I, J in [0, n).
# NOTE(review): the formula matches the weighted-Jacobi amplification factor
# for an anisotropic 2D Laplacian (weight eps on the J term) -- confirm
# against the surrounding notebook.
fig = plt.figure(figsize=(8, 8))
# Figure.gca() stopped accepting keyword arguments in Matplotlib 3.6
# (deprecated since 3.4), so gca(projection='3d') now raises TypeError.
# add_subplot(111, projection='3d') creates the 3D axes on old and new
# Matplotlib versions alike.
ax = fig.add_subplot(111, projection='3d')

n = 16
I = np.arange(0, n)
J = np.arange(0, n)
I, J = np.meshgrid(I, J)  # 2D index grids, each of shape (n, n)

omega = 2.0 / 3.0  # weight factor (presumably the Jacobi relaxation parameter)
eps = 0.01  # weight applied to the J-direction term
L = 1 - (2 * omega / (1 + eps)) * (
    np.sin(I * np.pi / (2 * n)) ** 2 + eps * np.sin(J * np.pi / (2 * n)) ** 2
)

surf = ax.plot_surface(I, J, L, rstride=1, cstride=1, cmap=plt.cm.coolwarm,
                       linewidth=0, antialiased=False)
ax.set_zlim(-1.01, 1.01)
fig.colorbar(surf, shrink=0.5, aspect=5)