# -*- coding: utf-8 -*-
#
# This file is part of the famewoks project
"""Collection of custom plots."""
from itertools import cycle
import numpy as np
import palettable
from pandas import DataFrame
# from IPython.display import display
# from IPython import embed
import plotly.graph_objects as pgo
from plotly.subplots import make_subplots
# from larch.plot.plotly_xafsplots import plotlabels # PlotlyFigure
from famewoks.models import ExpDataset, ExpSession, DATA_NAMES
from famewoks.bliss2larch import get_scans, get_groups
from famewoks.utils import get_y
from famewoks import _logger
# Line colors cycled over the data channels: the first four shades of the
# RdBu and PRGn colorbrewer diverging palettes, each taken in normal and in
# reversed order (4 x 4 = 16 colors).
CFLUO = [
    *palettable.colorbrewer.diverging.RdBu_11.hex_colors[:4],
    *palettable.colorbrewer.diverging.RdBu_11_r.hex_colors[:4],
    *palettable.colorbrewer.diverging.PRGn_11.hex_colors[:4],
    *palettable.colorbrewer.diverging.PRGn_11_r.hex_colors[:4],
]
# Tableau "GreenOrange" and "PurpleGray" 12-color palettes, and both chained.
CTABGO12 = palettable.tableau.GreenOrange_12.hex_colors
CTABPG12 = palettable.tableau.PurpleGray_12.hex_colors
CTAB = CTABGO12 + CTABPG12
[docs]
def plot_data(
    dset: ExpDataset,
    data: str | None = "fluos",
    ynorm: str | None = None,
    show_i0: bool = False,
    show_e0: bool = False,
    show_deriv: bool = False,
    show_slide: bool = True,
    show_merge: bool | str = False,
):
    """Plot all scans in a dataset, optionally browsing them with a slider.

    Parameters
    ----------
    dset : ExpDataset
        The dataset containing the scans to be plotted,
        see :class:`famewoks.models.ExpDataset`.
    data : str or None, optional
        Name of the data to visualize (default "fluos"); must be one of
        ``DATA_NAMES``. If None, the I0 of the "trans" groups is plotted
        (labelled "I0"), and ``show_i0``, ``show_e0`` and ``show_merge``
        are disabled.
    ynorm : str or None, optional
        Normalization of the y data, applied through
        :func:`famewoks.utils.get_y`:

        - None: raw data,
        - "flat": flat normalization,
        - "area": normalization by the area under the curve.

        ``True`` is also accepted and forwarded as is.
    show_i0 : bool, optional
        If True, plot the I0 of each scan on a secondary y-axis
        (default False).
    show_e0 : bool, optional
        If True, draw a dashed vertical line at e0 (for "fluos" only for the
        first channel of each scan) and append e0 to the legend label; for
        non-"fluos" data the scan energy shift is reported instead when it
        is non-zero (default False).
    show_deriv : bool, optional
        If True, plot dmu/dE of each scan on a secondary y-axis
        (default False). Cannot be combined with ``show_i0``.
    show_slide : bool, optional
        If True, show one scan at a time with a slider (default True);
        otherwise all scans are overlaid.
    show_merge : bool or str, optional
        If True, overlay the merged data as a thick blue line: per scan for
        "fluos", once for the whole dataset otherwise (once per slider step
        when ``show_slide``). If "rebin", overlay the rebinned merged data
        instead (not available for "fluos"). Default False.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The figure, also displayed with ``fig.show()``.

    Raises
    ------
    AssertionError
        If ``data`` or ``ynorm`` is invalid, or if incompatible options are
        requested ("fluos" rebinned, I0 together with the derivative).

    Notes
    -----
    The traces of all scans are created up front; with ``show_slide`` each
    slider step makes only the traces of one scan visible. Missing merged
    data for a scan is logged as a warning and skipped.

    Here's a figure showing more details:

    .. image:: images/plot_data_explained.png
       :alt: Explanation of plot data
    """
    if data is None:
        # No data name: plot the I0 of the transmission groups; the merge,
        # e0 and secondary-axis I0 overlays make no sense in that case.
        show_i0 = False
        show_merge = False
        show_e0 = False
        yattr = "i0"
        data = "trans"
        datalabel = False
    else:
        yattr = "mu"
        datalabel = True
    yaxis_label = f"{data}"  # fallback, replaced by get_y() for each group
    assert data in DATA_NAMES, f"available plot data: {DATA_NAMES}"
    assert not (data == "fluos" and show_merge == "rebin"), (
        "cannot plot `fluos` rebinned"
    )
    assert ynorm in [None, "area", "flat", True], "invalid normalization method"
    assert not (show_i0 and show_deriv), "cannot plot both I0 and derivative"

    colors = cycle(CFLUO)
    trace_visible = not show_slide  # with a slider only the active scan shows
    fig = (
        make_subplots(specs=[[{"secondary_y": True}]])
        if show_i0 or show_deriv
        else pgo.Figure()
    )

    def add_trace(trace, itraces, **kwargs):
        """Add ``trace`` to ``fig`` and record its index in ``itraces``."""
        fig.add_trace(trace, **kwargs)
        itraces.append(len(fig.data) - 1)

    iframes = []  # one [scanint, trace indices] entry per scan (slider step)
    mrg_plotted = False
    scans = get_scans(dset)
    xmin, xmax = dset.scans_emin.min(), dset.scans_emax.max()
    for iscn, scn, _ in scans:
        groups = get_groups(scn, data)  #: list of Groups
        ndata = len(groups)
        itraces = []
        # For "fluos" only the first channel of each scan gets the e0 marker;
        # for any other data every group gets one.
        e0_pending = True
        for ig, g in enumerate(groups):
            is_last = ig == ndata - 1
            label = g.label if datalabel else "I0"
            glabel = label if show_slide else f"scan{scn.scanint}_{label}"
            # wrap around so that more groups than colors do not fail
            color = CFLUO[ig % len(CFLUO)] if show_slide else next(colors)
            y, yaxis_label = get_y(g, yattr, ynorm, f"{yattr}")
            if show_e0 and e0_pending:
                if data == "fluos":
                    e0_pending = False
                    glabel = f"{glabel} (e0: {g.e0:.2f})"
                else:
                    eshift = dset.scans_eshifts[iscn]
                    glabel = (
                        f"{glabel} (e0: {g.e0:.2f})"
                        if eshift == 0
                        else f"{glabel} (eshift: {eshift:.2f})"
                    )
                # the marker spans the scan data and, if shown, the merged data
                gmrg = (
                    _plot_data_merged_group(dset, scn, data, show_merge)
                    if show_merge
                    else None
                )
                add_trace(
                    pgo.Scatter(
                        visible=trace_visible,
                        x=[g.e0, g.e0],
                        y=[y.min(), _plot_data_window_max(y, gmrg, yattr, ynorm)],
                        name="",
                        marker=None,
                        showlegend=False,
                        line={"color": color, "width": 1, "dash": "dash"},
                    ),
                    itraces,
                )
            add_trace(
                pgo.Scatter(
                    visible=trace_visible,
                    x=g.energy,
                    y=y,
                    name=glabel,
                    marker=None,
                    line={"color": color, "width": 1},
                ),
                itraces,
            )
            if show_merge and is_last and not mrg_plotted:
                gmrg = _plot_data_merged_group(dset, scn, data, show_merge)
                if gmrg is None:
                    # Do not `continue`: the I0/derivative traces below must
                    # still be added for this scan.
                    _logger.warning(f"no merged data for scan {scn.scanint}, skipped")
                else:
                    ymrg, _ = get_y(gmrg, yattr, ynorm, yaxis_label)
                    mlabel = (
                        gmrg.label
                        if show_slide or data != "fluos"
                        else f"scan{scn.scanint}_{gmrg.label}"
                    )
                    add_trace(
                        pgo.Scatter(
                            visible=trace_visible,
                            x=gmrg.energy,
                            y=ymrg,
                            name=mlabel,
                            marker=None,
                            line={"color": "blue", "width": 4},
                        ),
                        itraces,
                    )
                    if data != "fluos":
                        # dataset-level merge: draw it only once when the
                        # scans are overlaid (no slider)
                        mrg_plotted = not show_slide
            if show_i0 and is_last:
                add_trace(
                    pgo.Scatter(
                        visible=trace_visible,
                        x=g.energy,
                        y=g.i0,
                        name=f"scan{scn.scanint}_I0",
                        marker=None,
                        line={
                            "color": "gray" if show_slide else next(colors),
                            "width": 1,
                        },
                    ),
                    itraces,
                    secondary_y=True,
                )
            if show_deriv and is_last:
                add_trace(
                    pgo.Scatter(
                        visible=trace_visible,
                        x=g.energy,
                        y=g.dmude,
                        name=f"scan{scn.scanint}_dmudE",
                        marker=None,
                        line={
                            "color": "gray" if show_slide else next(colors),
                            "width": 1,
                        },
                    ),
                    itraces,
                    secondary_y=True,
                )
        _logger.debug(f"scan {iscn} -> itraces: {itraces}")
        iframes.append([scn.scanint, itraces])
    ntraces = len(fig.data)
    _logger.debug(f"loaded {ntraces} traces == len(fig.data) {len(fig.data)}")
    if iframes:
        # the first scan is the one shown when the slider is at step 0
        iscn0, itraces0 = iframes[0]
        for itr in itraces0:
            fig.data[itr].visible = True
            _logger.debug(f"scan {iscn0} -> set visible itrace {itr}")
    else:
        _logger.warning(f"no scans to plot in dataset {dset.name}")
    fig.update_layout(
        height=600,
        width=1000,
        showlegend=True,
        sliders=(
            _plot_data_sliders(dset.name, iframes, ntraces) if show_slide else None
        ),
        title_text=f"dataset: {dset.name}",
        xaxis={"range": [xmin, xmax]},
        xaxis_title="energy, eV",
        yaxis_title=yaxis_label,
    )
    if show_i0:
        fig.update_yaxes(showgrid=False, secondary_y=True, title_text="i0")
    elif show_deriv:
        fig.update_yaxes(showgrid=False, secondary_y=True, title_text="dmu/dE")
    else:
        fig.update_yaxes(showgrid=True)
    fig.show()
    return fig


def _plot_data_merged_group(dset, scn, data, show_merge):
    """Return the merged group overlaid by :func:`plot_data` (may be None).

    "rebin" selects ``getattr(dset, data).rebinned``; otherwise "fluos"
    uses the scan-level ``scn.fluo[0]`` and any other data the
    dataset-level ``getattr(dset, data)``.
    """
    if show_merge == "rebin":
        return getattr(dset, data).rebinned
    if data == "fluos":
        # NOTE(review): the e0-marker height previously read ``scn.fluos[0]``
        # while the plotted merge read ``scn.fluo[0]``; both now use the
        # plotted one — confirm this is the per-scan merged fluorescence.
        return scn.fluo[0]
    return getattr(dset, data)


def _plot_data_window_max(y, gmrg, yattr, ynorm):
    """Return the top of an e0 marker: max of ``y`` and of ``gmrg`` (if any)."""
    ymax = y.max()
    if gmrg is not None:
        ymrg, _ = get_y(gmrg, yattr, ynorm, "")
        ymax = max(ymrg.max(), ymax)
    return ymax


def _plot_data_sliders(dset_name, iframes, ntraces):
    """Build a Plotly slider with one step per ``[scanint, itraces]`` frame.

    Each step makes only the traces listed for its scan visible and updates
    the figure title.
    """
    steps = []
    for scanint, itraces in iframes:
        visible = [False] * ntraces
        for itr in itraces:
            visible[itr] = True
        steps.append(
            {
                "method": "update",
                "args": [
                    {"visible": visible},
                    {"title": f"dataset: {dset_name} | scan {scanint}"},
                ],
                "label": f"{scanint}",
            }
        )
    return [
        {
            "active": 0,
            "currentvalue": {"prefix": "Scan: "},
            "pad": {"t": 20},
            "steps": steps,
        }
    ]
[docs]
def plot_eshift(
    session: ExpSession,
    dset: ExpDataset,
    array: str = "dmude",
    show_e0: bool = True,
):
    """Compare each scan's reference signal with the energy-alignment reference.

    For every scan of ``dset`` (one slider step per scan) and each of its
    reference groups (``get_groups(scn, data="ref")``), the figure overlays:

    - the group's ``array`` (blue),
    - the same curve shifted by ``dset.scans_eshifts[iscn]`` (green),
    - the session alignment reference ``session.enealign`` (red).

    Parameters
    ----------
    session : ExpSession
        Session providing the alignment reference group ``enealign``. Its
        e0 plus ``pars_energy_align["emin"]`` / ``["emax"]``, widened by
        15 eV on each side, set the x-axis range.
    dset : ExpDataset
        Dataset whose scans are compared with the reference.
    array : str, optional
        Name of the group attribute to plot (default "dmude").
    show_e0 : bool, optional
        If True (default), draw dashed vertical lines at the e0 of the
        reference (red), of the shifted scan (green) and of the unshifted
        scan (blue), spanning from 0 to the maximum of the plotted array.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The figure, also displayed with ``fig.show()``.
    """
    scans = get_scans(dset)
    fig = pgo.Figure()
    erefgrp = session.enealign
    yref = getattr(erefgrp, array)
    xmin = erefgrp.e0 + erefgrp.pars_energy_align["emin"] - 15
    xmax = erefgrp.e0 + erefgrp.pars_energy_align["emax"] + 15

    def add_trace(trace, itraces):
        """Add ``trace`` to ``fig`` and record its index in ``itraces``."""
        fig.add_trace(trace)
        itraces.append(len(fig.data) - 1)

    def e0_line(e0, ymax, color):
        """Dashed vertical marker at ``e0`` from 0 to ``ymax``."""
        return pgo.Scatter(
            x=[e0, e0],
            y=[0, ymax],
            line_width=3,
            line_dash="dash",
            line_color=color,
            visible=False,
            showlegend=False,
        )

    iframes = []  # one [scanint, trace indices] entry per scan (slider step)
    for iscn, scn, _ in scans:
        itraces = []
        for grp in get_groups(scn, data="ref"):
            eshift = dset.scans_eshifts[iscn]
            x = grp.energy
            y = getattr(grp, array)
            add_trace(
                pgo.Scatter(
                    visible=False,
                    x=x,
                    y=y,
                    name=grp.label,
                    marker=None,
                    line={"color": "blue", "width": 1},
                ),
                itraces,
            )
            add_trace(
                pgo.Scatter(
                    visible=False,
                    x=x + eshift,
                    y=y,
                    name=f"{grp.label} shifted: {eshift:.3f} eV",
                    marker=None,
                    line={"color": "green", "width": 1},
                ),
                itraces,
            )
            add_trace(
                pgo.Scatter(
                    visible=False,
                    x=erefgrp.energy,
                    y=yref,
                    name=f"{erefgrp.filename} (ref)",
                    marker=None,
                    line={"color": "red", "width": 1},
                ),
                itraces,
            )
            if show_e0:
                # marker heights follow the plotted array (not always dmude)
                add_trace(e0_line(erefgrp.e0, yref.max(), "red"), itraces)
                add_trace(e0_line(grp.e0 + eshift, y.max(), "green"), itraces)
                add_trace(e0_line(grp.e0, y.max(), "blue"), itraces)
        iframes.append([scn.scanint, itraces])
    ntraces = len(fig.data)
    if iframes:
        # the first scan is the one shown when the slider is at step 0
        iscn0, itraces0 = iframes[0]
        for itr in itraces0:
            fig.data[itr].visible = True
            _logger.debug(f"scan {iscn0} -> set visible itrace {itr}")
    else:
        _logger.warning(f"no scans to plot in dataset {dset.name}")
    steps = []
    for iscn, itraces in iframes:
        visible = [False] * ntraces
        for itr in itraces:
            visible[itr] = True  # show only the traces of this scan
        steps.append(
            {
                "method": "update",
                "args": [
                    {"visible": visible},
                    {"title": f"dataset: {dset.name} | scan {iscn}"},
                ],
                "label": f"{iscn}",
            }
        )
    sliders = [
        {
            "active": 0,
            "currentvalue": {"prefix": "Scan: "},
            "pad": {"t": 20},
            "steps": steps,
        }
    ]
    fig.update_layout(
        height=600,
        width=1000,
        showlegend=True,
        sliders=sliders,
        title_text=f"dataset: {dset.name}",
        xaxis={"range": [xmin, xmax]},
        xaxis_title="energy, eV",
    )
    fig.show()
    return fig
[docs]
def plot_monitoring(
    df: DataFrame, pltcnts: list[list[int, str, bool, float]], range_slider: bool = True
):
    """Build and show an interactive time plot of monitoring counters.

    Parameters
    ----------
    df : pandas.DataFrame
        Data container; must hold a ``"time"`` column (x-axis) and one
        column per counter named in ``pltcnts``.
    pltcnts : list of list
        Counter settings ``[[flag, counter_name, is_y2, yscale], ...]``:
        counters with a falsy ``flag`` (e.g. 0) are not plotted, ``is_y2``
        puts the counter on the secondary y-axis, and the counter values
        are multiplied by ``yscale`` (shown in the legend as ``(x<yscale>)``).
    range_slider : bool, optional
        If True (default), add a range slider and "1d" / "1w" / "all"
        range-selector buttons on a date-typed x-axis.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The figure, also displayed with ``fig.show()``.
    """
    fig = make_subplots(rows=1, cols=1, specs=[[{"type": "xy", "secondary_y": True}]])
    for flag, name, y2, yscale in pltcnts:
        if not flag:
            # counter disabled in the settings: do not plot it
            continue
        fig.add_trace(
            {
                "type": "scatter",
                "x": df["time"],
                "y": df[name] * yscale,
                "name": f"{name} (x{yscale})",
            },
            secondary_y=y2,
        )
    if range_slider:
        fig.update_layout(
            xaxis={
                "rangeselector": {
                    "buttons": [
                        {
                            "count": 1,
                            "label": "1d",
                            "step": "day",
                            "stepmode": "backward",
                        },
                        {
                            "count": 7,
                            "label": "1w",
                            "step": "day",
                            "stepmode": "backward",
                        },
                        {"step": "all"},
                    ]
                },
                "rangeslider": {"visible": True},
                "type": "date",
            }
        )
    fig.show()
    return fig
### DEPRECATED / TO REMOVE ###
[docs]
def plot_curves(
    dset: ExpDataset,
    scan: int | None = None,
    yoffset: int = 0,
    ynorm: str | None = None,
):
    """Plot the enabled fluorescence channels of each enabled scan (deprecated).

    One figure per scan is built and shown; scans and channels whose flag is
    0 are skipped.

    Parameters
    ----------
    dset : ExpDataset
        Dataset holding the scans.
    scan : int or None, optional
        Scan to plot, looked up in ``dset.scans_names``; None (default)
        plots every scan.
    yoffset : int, optional
        Vertical offset added cumulatively between successive plotted
        channels (default 0: the curves overlap).
    ynorm : str or None, optional
        "area" divides each channel by ``np.trapezoid(y)`` (integral with
        unit sample spacing); any other value leaves the data unchanged.

    Raises
    ------
    AssertionError
        If ``scan`` is not in ``dset.scans_names``.
    """
    cfluos = cycle(CFLUO)
    if scan is None:
        scans = dset.scans
    else:
        assert scan in dset.scans_names, f"available scans: {dset.scans_names}"
        scans = [dset.scans[dset.scans_names.index(scan)]]
    for scn in scans:
        if scn.flag == 0:
            continue  # scan disabled
        fig = make_subplots(rows=1, cols=1)
        x = scn.energy
        yshift = 0
        for idx, (label, flag) in enumerate(zip(scn.fluos_names, scn.fluos_flags)):
            if flag == 0:
                continue  # channel disabled
            y = scn.fluos[idx, :]
            if ynorm == "area":
                y = y / np.trapezoid(y)
            fig.add_trace(
                pgo.Scatter(
                    x=x,
                    y=y + yshift,  # apply the stacking offset
                    name=label,
                    marker=None,
                    line={"color": next(cfluos)},
                ),
                row=1,
                col=1,
            )
            # the next plotted channel sits `yoffset` above this one
            yshift += yoffset
        fig.update_layout(
            height=600,
            width=1000,
            title_text=f"dataset: {dset.name} | scan: {scn.scanno}",
            showlegend=True,
        )
        fig.show()
[docs]
def plot_groups(dset, data="fluo"):
    """Plot the merged Larch group stored on a dataset (deprecated).

    Parameters
    ----------
    dset : ExpDataset
        Dataset holding the merged group as attribute ``data``.
    data : str, optional
        Name of that attribute (default "fluo"); its ``energy`` and ``mu``
        arrays are plotted as a blue line labelled ``<dset.name>_<data>``.
    """
    fig = make_subplots(rows=1, cols=1)
    merged = getattr(dset, data)
    curve = pgo.Scatter(
        x=merged.energy,
        y=merged.mu,
        name=f"{dset.name}_{data}",
        marker=None,
        line={"color": "blue"},
    )
    fig.add_trace(curve)
    fig.update_layout(showlegend=True)
    fig.show()