import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import XEFI

energy = 8050.92 # eV, corresponding to a wavelength of 1.54 Å
angles = np.linspace(0.1, 0.4, 3000) # Angles of Incidence in degrees
z = [0, -800, -1340]  # Z-coordinates for the multilayer interface
layer_names = ["Air", "PS", "P3HT", "Si"]

# Calculated at 8050.92 eV using kkcalc2, but could be any complex
# refractive index or callable that returns a complex refractive index.
refractive_indices = [
    1.0        + 0j,         # Air/Vacuum
    0.99999637 + 4.96e-09j,  # Polystyrene (C8H8)
    0.99999536 + 3.31e-08j,  # Poly(3-hexylthiophene) (P3HT, C10H14S)
    0.99999243 + 1.72e-07j,  # Silicon (Si)
 ]

result: XEFI.BasicResult = XEFI.XEF_Basic(
    energies=energy,
    angles=angles,
    z=z,
    refractive_indices=refractive_indices,
    layer_names=layer_names,
)

intensity_full = result.summed_intensity(np.linspace(0, -800, 1000))
top_depth = 100
intensity_top = result.summed_intensity(
    np.linspace(0, -800, 1000), bounds=(0, -top_depth)
)
intensity_bot = result.summed_intensity(
    np.linspace(0, -800, 1000), bounds=(-top_depth, -800)
)

fig,ax = plt.subplots(figsize=(10,6))
ax.plot(angles, intensity_full, label="Total Intensity")
ax.plot(angles, intensity_top, label=f"Top {top_depth} Å Intensity")
ax.plot(angles, intensity_bot, label=f"Bottom {800 - top_depth} Å Intensity")
ax.set_xlabel("Angle of Incidence (degrees)")
ax.set_ylabel("Intensity (a.u.)")
ax.set_yscale("log")

result._add_crit_angles(ax=ax)

# Add white rectangle behind axis
patch_alpha = 0.8
rect = patches.Rectangle(
    (0, 0), 1, -0.1, transform=ax.transAxes, facecolor='white', alpha=patch_alpha, zorder=-1
)
rect2 = patches.Rectangle(
    (-0.1, -0.1), 0.1, 1.1, transform=ax.transAxes, facecolor='white', alpha=patch_alpha, zorder=-1
)
rect3 = patches.Rectangle(
    (1, -0.1), 0.05, 1.1, transform=ax.transAxes, facecolor='white', alpha=patch_alpha, zorder=-1
)
rect4 = patches.Rectangle(
    (-0.1, 1), 1.15, 0.05, transform=ax.transAxes, facecolor='white', alpha=patch_alpha, zorder=-1
)
rect.set_clip_on(False)
rect2.set_clip_on(False)
rect3.set_clip_on(False)
rect4.set_clip_on(False)
ax.add_patch(rect)
ax.add_patch(rect2)
ax.add_patch(rect3)
ax.add_patch(rect4)

leg = ax.legend()
fig.tight_layout()
fig.patch.set_facecolor("white")
fig.patch.set_alpha(0.05)
ax.patch.set_facecolor('None')
ax.patch.set_alpha(0)