import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse as sp
from matplotlib import cm


def finite_difference_elliptic_homogeneous_2D(f, N):
    """Solve -∇²u = f on the unit square with homogeneous Dirichlet conditions.

    Uses the standard 5-point finite-difference stencil on a uniform
    (N + 1) x (N + 1) grid with spacing h = 1 / N. Only the (N - 1)**2
    interior nodes are unknowns; boundary values are fixed to zero.

    Parameters
    ----------
    f : callable
        Right-hand side, called once as ``f(X, Y)`` with 2-D arrays holding
        the interior node coordinates. It may return an array of that shape
        or anything that broadcasts to it (e.g. a scalar constant).
    N : int
        Number of grid intervals per direction; must be at least 2 so that
        there is at least one interior node.

    Returns
    -------
    X, Y, U : ndarray, each of shape (N + 1, N + 1)
        Grid coordinates and discrete solution, all in ``'ij'`` layout:
        ``X[i, j] = x_i``, ``Y[i, j] = y_j`` and ``U[i, j] ≈ u(x_i, y_j)``.
        The outer ring of ``U`` (the boundary) is zero.

    Raises
    ------
    ValueError
        If ``N < 2``.
    """
    # Imported explicitly: ``import scipy.sparse`` alone does not guarantee
    # that the ``scipy.sparse.linalg`` submodule has been loaded.
    from scipy.sparse.linalg import spsolve

    if N < 2:
        raise ValueError(f"N must be at least 2 (got {N!r})")

    n = N - 1  # interior nodes per direction
    h = 1 / N

    # Sub-blocks of the (h**2-scaled) 5-point stencil: A is the diagonal
    # block (4 on the diagonal, -1 for same-row neighbours) and B = -I
    # couples a row of nodes to the neighbouring rows.
    A = sp.diags([-1, 4, -1], [-1, 0, 1], shape=(n, n), dtype=np.float64)
    B = -sp.eye(n, dtype=np.float64)

    # Block pattern: A on the block diagonal, B on the first off-diagonals.
    F_A = sp.eye(n, dtype=np.float64)
    F_B = sp.diags([1, 1], [-1, 1], shape=(n, n), dtype=np.float64)

    # Kronecker products assemble the block-tridiagonal system matrix;
    # CSR is the format the direct solver works with efficiently.
    M = (sp.kron(F_A, A) + sp.kron(F_B, B)).tocsr()

    # Interior node coordinates in 'ij' layout. A C-order ravel then gives
    # unknown index k = i * n + j, matching the Kronecker ordering above.
    points = np.linspace(0, 1, N + 1)
    inner_points = points[1:-1]
    X_in, Y_in = np.meshgrid(inner_points, inner_points, indexing='ij')

    # broadcast_to allows f to return a scalar (e.g. a constant source).
    rhs = np.broadcast_to(np.asarray(f(X_in, Y_in), dtype=np.float64),
                          X_in.shape)
    b = h ** 2 * rhs.ravel()

    u_vec = spsolve(M, b)

    # Embed the interior solution in a zero-boundary grid.
    U = np.zeros((N + 1, N + 1))
    U[1:-1, 1:-1] = u_vec.reshape((n, n))

    # Must be 'ij' to match U's layout; the default 'xy' layout would
    # transpose X and Y relative to U.
    X, Y = np.meshgrid(points, points, indexing='ij')

    return X, Y, U


if __name__ == "__main__":
    # Manufactured source term: -∇²u = f has the exact solution
    # u(x, y) = sin(2πx) sin(2πy), which vanishes on the whole boundary.
    def f(x, y):
        return 8 * np.pi ** 2 * np.sin(2 * np.pi * x) * np.sin(2 * np.pi * y)

    X, Y, U = finite_difference_elliptic_homogeneous_2D(f, 100)

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    surf = ax.plot_surface(X, Y, U, cmap=cm.viridis,
                           linewidth=0, antialiased=True)

    # The solver's matrix (4 on the diagonal, -1 for neighbours) discretizes
    # -∇², so the equation actually solved is -∇²u = f.
    ax.set_title(r'Solución de la Ecuación de Poisson $-\nabla^2 u = f$')
    ax.set_xlabel('Eje X')
    ax.set_ylabel('Eje Y')
    ax.set_zlabel(r'$u(x,y)$')
    fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()