import XEFI
import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt

# --- Simulation inputs -------------------------------------------------------
# Photon energy in eV; 8050.92 eV corresponds to a wavelength of 1.54 Å.
energy = 8050.92

# Angles of incidence in degrees: 3000 evenly spaced samples over [0.1, 0.4].
angles = np.linspace(0.1, 0.4, num=3000)

# Z-coordinates of the interfaces between consecutive layers
# (four named layers -> three interfaces).
z = [0, -800, -1340]
layer_names = ["Air", "PS", "P3HT", "Si"]

# Complex refractive index of each layer, in the same order as `layer_names`.
# Values were calculated at 8050.92 eV using kkcalc2, but any complex
# refractive index (or a callable that returns one) could be supplied.
refractive_indices = [
    complex(1.0, 0.0),              # Air/Vacuum
    complex(0.99999637, 4.96e-09),  # Polystyrene (C8H8)
    complex(0.99999536, 3.31e-08),  # Poly(3-hexylthiophene) (P3HT, C10H14S)
    complex(0.99999243, 1.72e-07),  # Silicon (Si)
]

# --- Roughness sweep ---------------------------------------------------------
# Run the sliced XEF calculation once per roughness value and overlay the
# summed intensity versus angle of incidence on a single axis.
fig, ax = plt.subplots(figsize=(10, 6))

# Perfect squares 1, 4, 9, ..., 100 (labelled in Å in the legend).
roughness_values = [n**2 for n in range(1, 11)]

# Depths at which the intensity is summed: 1000 points from z = 0 to z = -800.
# NOTE(review): this range presumably spans the PS layer (first two entries
# of `z`) -- confirm against the XEFI documentation.
summation_depths = np.linspace(0, -800, 1000)

for roughness in roughness_values:
    # The same roughness is applied to each of the three interfaces.
    result_rough: XEFI.SlicedResult = XEFI.XEF_Sliced(
        energies=energy,
        angles=angles,
        z=z,
        z_roughness=[roughness] * 3,
        refractive_indices=refractive_indices,
        layer_names=layer_names,
        slice_thickness=1.0,
        sigmas=4.0,
    )
    intensity_rough = result_rough.summed_intensity(summation_depths)
    ax.plot(angles, intensity_rough, label=f"{roughness} Å")

# Called once, after the loop, on the result of the final iteration
# (roughness = 100).
result_rough._add_crit_angles(ax=ax)

ax.set_xlabel("Angle of Incidence (degrees)")
ax.set_ylabel("Intensity (a.u.)")
ax.set_yscale("log")
ax.legend(title="Roughness", loc="upper right", ncol=2)

# --- Figure styling ----------------------------------------------------------
fig.tight_layout()

# Nearly transparent white figure background; fully transparent axes background.
figure_background = fig.patch
figure_background.set_facecolor("white")
figure_background.set_alpha(0.05)
axes_background = ax.patch
axes_background.set_facecolor('None')
axes_background.set_alpha(0)

# Frame the plotting area with four semi-opaque white strips placed just
# outside the axes. They are drawn below other artists (zorder=-1).
# NOTE(review): presumably these keep tick and axis labels legible against the
# mostly transparent figure background -- confirm intent.
patch_alpha = 0.8
strip_style = dict(
    transform=ax.transAxes,  # coordinates below are axes fractions
    facecolor='white',
    alpha=patch_alpha,
    zorder=-1,
)

# (anchor xy, width, height) for each strip, in axes-fraction coordinates.
strip_geometry = (
    ((0, 0), 1, -0.1),          # bottom: negative height extends below the axes
    ((-0.1, -0.1), 0.1, 1.1),   # left
    ((1, -0.1), 0.05, 1.1),     # right
    ((-0.1, 1), 1.15, 0.05),    # top
)
rect, rect2, rect3, rect4 = (
    patches.Rectangle(anchor, width, height, **strip_style)
    for anchor, width, height in strip_geometry
)

for strip in (rect, rect2, rect3, rect4):
    # The strips lie outside the axes rectangle, so clipping must be disabled
    # or they would not be drawn.
    strip.set_clip_on(False)
    ax.add_patch(strip)