import numpy as np
import matplotlib.pyplot as plt


def finite_difference_elliptic_homogeneous_1D(f, N):
    """Solve -u'' = f on [0, 1] with u(0) = u(1) = 0 by centred finite differences.

    The interval is split into N uniform cells of width h = 1/N. At every
    interior node the second derivative is replaced by the three-point stencil
    (-u[i-1] + 2*u[i] - u[i+1]) / h**2, and the resulting tridiagonal linear
    system is solved directly.

    Parameters
    ----------
    f : callable
        Right-hand side. It is called once with the array of interior nodes
        and may return an array of that shape or anything that broadcasts to
        it (e.g. a scalar for a constant source).
    N : int
        Number of cells; must be at least 2 so that there is an interior node.

    Returns
    -------
    x : ndarray, shape (N+1,)
        Mesh nodes, including both boundary points.
    u_num : ndarray, shape (N+1,)
        Approximate solution at the nodes; the boundary entries are 0.

    Raises
    ------
    ValueError
        If N < 2.
    """
    if N < 2:
        raise ValueError(f"N must be at least 2 (got {N})")

    h = 1/N
    x = np.linspace(0, 1, N+1)
    interior = x[1:-1]

    # Stencil matrix without the 1/h**2 factor; that factor is applied to the
    # solution instead (u = h**2 * A^{-1} f).
    A = 2*np.eye(N-1) - np.eye(N-1, k=1) - np.eye(N-1, k=-1)

    # Broadcast so that a scalar-valued f (constant source) is accepted.
    rhs = np.broadcast_to(f(interior), interior.shape)

    u_num = np.zeros(N+1)
    u_num[1:-1] = h**2 * np.linalg.solve(A, rhs)

    return x, u_num

def finite_difference_elliptic_1D(f, a, b, N):
    """Solve -u'' = f on [0, 1] with Dirichlet data u(0) = a, u(1) = b.

    The known boundary values enter the stencil of the first and last interior
    nodes as a/h**2 and b/h**2. They are moved to the right-hand side, and the
    resulting problem is handed to finite_difference_elliptic_homogeneous_1D.
    The boundary entries of the returned solution are then set to a and b.

    Parameters
    ----------
    f : callable
        Right-hand side, evaluated at the interior nodes. It may return an
        array of that shape or anything that broadcasts to it.
    a, b : float
        Values of u at x = 0 and x = 1.
    N : int
        Number of uniform cells (h = 1/N).

    Returns
    -------
    x : ndarray, shape (N+1,)
        Mesh nodes.
    u_num : ndarray, shape (N+1,)
        Approximate solution at the nodes, with u_num[0] = a and u_num[-1] = b.
    """
    h = 1/N

    def fh(x):
        # Work on a fresh float copy of f's output rather than mutating it in
        # place. In-place edits would corrupt an array that f returns by
        # reference (e.g. f = lambda x: x returns a view of the mesh). They
        # would also fail for integer-valued output and for a scalar result.
        fx = np.array(np.broadcast_to(f(x), np.shape(x)), dtype=float)
        fx[0] += a/h**2
        fx[-1] += b/h**2
        return fx

    x, u_num = finite_difference_elliptic_homogeneous_1D(fh, N)
    u_num[0] = a
    u_num[-1] = b
    return x, u_num

if __name__ == "__main__":
    # Manufactured solution: u(x) = sin(2*pi*x) satisfies -u'' = f with the f below.
    u = lambda x: np.sin(2*np.pi*x)
    f = lambda x: 4*np.pi**2*np.sin(2*np.pi*x)

    Ns = 2**np.arange(2, 10)
    hs = 1/Ns
    errors = np.empty_like(Ns, dtype=float)
    for i, N in enumerate(Ns):
        # The boundary data must come from the exact solution (u(0) = u(1) = 0).
        # With a = b = 1 the discrete solution approximates sin(2*pi*x) + 1, so
        # the error would stay near 1 instead of decaying like h**2.
        x, u_num = finite_difference_elliptic_1D(f, u(0.0), u(1.0), N)
        errors[i] = np.max(np.abs(u_num - u(x)))

    fig, ax = plt.subplots()

    ax.loglog(hs, errors, "o-r")

    ax.set_title("Error de discretización en función de la malla")
    ax.set_xlabel("h")
    ax.set_ylabel(r"$||e_h||_{h, \infty}$")

    # Observed order of convergence: slope of log(error) against log(h).
    p = np.polyfit(np.log(hs), np.log(errors), 1)[0]
    # Smallest C for which e_h <= C h^2 on every mesh tested, so the bound shown
    # below actually holds. The regression intercept is paired with exponent p
    # rather than 2, and it is not an upper bound.
    C = np.max(errors / hs**2)

    ax.text(0.87, 0.10, f" Pendiente: {p:0.2f}", transform=ax.transAxes, verticalalignment='top')

    ax.text(0.87, 0.05, r"$||e_h||_{h, \infty} \leq$" + f"{C:0.2f}" + r"$h^2$",
            transform=ax.transAxes, verticalalignment='top')

    plt.show()
