# -*- coding: utf-8 -*-
#
# This file is part of the famewoks project
#
# Copyright (c) Mauro Rovezzi, CNRS
# Distributed under the GNU GPLv3. See LICENSE for more info.
"""Data Reduction Workflow from BLISS to Larch.
Currently applied to BLISS data collected at ESRF/BM16 beamline
"""
from pathlib import Path
import glob
import time
import numpy as np
from typing import Any
from copy import deepcopy
from datetime import datetime
from larch.io.specfile_reader import DataSourceSpecH5, _str2rng
from larch.io.specfile_reader import __version__ as ds_version
from larch.io.mergegroups import merge_groups
from larch.math.deglitch import remove_spikes_medfilt1d
from larch import Group
from larch.io import AthenaProject
from larch.xafs import pre_edge
from larch.xafs.rebin_xafs import rebin_xafs
from xraydb import xray_edge
# from larch.xafs.pre_edge import energy_align
from famewoks.energy import energy_align
#
from famewoks import __version__ as wkfl_version
from famewoks.models import ExpSession, ExpSample, ExpDataset, XasScanGroup
from famewoks import _logger
# Log the versions of the HDF5 data reader and of this workflow at import time
# (DEBUG level only, useful when reporting processing issues).
_logger.debug(f"DataSourceSpecH5 version: {ds_version}")
_logger.debug(f"workflow version: {wkfl_version}")
[docs]
def render_tree(level: int, message: str) -> str:
    """Return ``message`` formatted as a branch of a plain-text tree.

    One ``|`` is emitted per nesting ``level`` (none for ``level <= 0``),
    followed by the ``|- `` branch marker, e.g. ``render_tree(1, "x")``
    gives ``"||- x"``.
    """
    branch = "".join(["|"] * level) + "|- "
    return f"{branch}{message}"
[docs]
def safe_pre_edge(session: ExpSession, group: Group, verbose: bool = True) -> Group:
    """Apply `pre_edge` to a group, safely catching errors and checking for correct E0.

    The theoretical edge energy of ``session.elem``/``session.edge`` is looked
    up with :func:`xraydb.xray_edge`. After :func:`larch.xafs.pre_edge` has
    run on ``group``, if the fitted ``group.e0`` deviates from the theoretical
    value by more than 1 %, ``group.e0`` is forced to the theoretical value.

    Parameters
    ----------
    session : ExpSession
        Provides ``elem`` and ``edge`` used to look up the reference E0.
    group : Group
        Larch group to process; modified in place.
    verbose : bool, optional
        Currently unused (kept for API compatibility).

    Returns
    -------
    Group
        The same ``group`` object.

    Notes
    -----
    - Errors of the edge lookup and of ``pre_edge`` are logged, never raised.
    - If the edge lookup fails, the E0 check is skipped (E0 is left as found
      by ``pre_edge``) instead of being forced to 0 eV.
    """
    infolab = getattr(group, "filename", "")
    # `label` may be missing or None -> fall back to an empty string
    glabel = getattr(group, "label", None) or ""
    e0ref = None
    try:
        e0ref = xray_edge(session.elem, session.edge).energy
        _logger.debug(
            f"{infolab}_{glabel} -> safe_pre_edge() -> E0 for {session.elem} {session.edge} edge: {e0ref:.2f} eV"
        )
    except Exception as err:
        _logger.error(
            f"{infolab}_{glabel} -> cannot get edge energy from session -> check element and edge"
        )
        _logger.debug(f"--- {type(err).__name__} ---> {err}")
        e0ref = None
    try:
        pre_edge(group)
    except Exception as err:
        _logger.warning(f"{infolab}_{glabel} -> `pre_edge()` failed")
        _logger.debug(f"--- {type(err).__name__} ---> {err}")
    if not e0ref:
        # no theoretical reference -> nothing to compare against
        return group
    e0 = getattr(group, "e0", None)
    if e0 is None:
        # pre_edge failed before setting e0
        _logger.debug(f"{infolab}_{glabel} -> no E0 found in group -> E0 check skipped")
        return group
    e0thr = e0ref * 0.01
    if abs(e0 - e0ref) > e0thr:
        _logger.debug(
            f"{infolab}_{glabel} -> found E0 is out of 1% threshold from theory ({e0ref:.2f} eV) -> forced from {e0:.2f} to {e0ref:.2f} eV"
        )
        group.e0 = e0ref
    return group
[docs]
def search_samples(
    session: ExpSession,
    verbose: bool = True,
    ignore_names: list[str] | None = None,
):
    """Search for sample names in the `{session.datadir}/{session.raw_data_name}` directory.

    Parameters
    ----------
    session : ExpSession
        The experimental session containing metadata and directory information.
        see the documentation :class:`famewoks.datamodel.ExpSession`.
    verbose : bool, optional [True]
        If True, logs the list of found samples via :func:`show_samples_info`.
    ignore_names : list of str or None, optional
        Entries whose name contains one of these strings (case-insensitive) are
        kept but flagged as bad (``flag=0``). If None (default), uses:
        'rack', 'mount', 'align', 'bl_', 'sample', 'holder'.
        Entries containing '.h5' are always skipped (top level HDF5 files).

    Returns
    -------
    list of ExpSample
        A list of `ExpSample` objects representing the found samples (empty if
        the search directory does not exist).

    Notes
    -----
    - Entries are sorted by ``st_ctime`` (creation time on Windows, last
      metadata change on POSIX systems).
    - Replaces the `session.samples` attribute with the found samples (not
      modified when the search directory does not exist).
    - Bad samples are not applied here: use
      :func:`famewoks.bliss2larch.set_bad_samples` afterwards.
    """
    if ignore_names is None:
        ignore_names = ["rack", "mount", "align", "bl_", "sample", "holder"]
    # names are matched against the lowercased entry name -> lowercase them too
    ignore_lower = [str(name).lower() for name in ignore_names]
    samples = []
    search_dir = Path(session.datadir) / session.raw_data_name
    if not search_dir.exists():
        _logger.error(f"Cannot access: {search_dir}")
        return samples
    fnames = sorted(search_dir.glob("*"), key=lambda x: x.stat().st_ctime)
    _logger.debug(f"Found {len(fnames)} files in {search_dir}")
    _logger.debug(
        f"{len(session.samples)} samples already in session.samples -> replacing with new search"
    )
    for fname in fnames:
        samp = fname.name
        _logger.debug(f"sample: {samp} -> ({fname})")
        samp_lower = samp.lower()
        if ".h5" in samp_lower:
            continue
        flag = 0 if any(ign in samp_lower for ign in ignore_lower) else 1
        samples.append(ExpSample(name=samp, flag=flag))
    session.samples = samples
    _logger.debug(f"{len(samples)} samples now in session.samples")
    if verbose:
        show_samples_info(session)
    return samples
[docs]
def show_samples_info(
    session: ExpSession, all: bool = False, show_datasets: bool = False
):
    """Log an indexed list of the samples stored in ``session``.

    Parameters
    ----------
    session : ExpSession
        The session whose ``samples`` are listed.
    all : bool, optional [False]
        If True, list every sample; otherwise only samples whose ``flag`` is
        not 0. (The name shadows the builtin but is part of the public API.)
    show_datasets : bool, optional [False]
        If True, the datasets of each listed sample are appended below it
        (as built by :func:`show_datasets_info`), prefixed with ``|``.

    Returns
    -------
    None
        The information is emitted as a single INFO log record.
    """
    samples = session.samples
    selection = "all samples" if all else "good *only* samples"
    report = [
        f"{len(samples)} samples stored in session:",
        f"#- index: name ({selection})",
    ]
    for index, entry in enumerate(samples):
        if entry.flag == 0 and not all:
            continue
        report.append(f"- {index}: {entry.name}")
        if not show_datasets:
            continue
        report.extend(
            f"|{line}" for line in show_datasets_info(entry, verbose=False)
        )
    _logger.info("\n".join(report))
[docs]
def show_datasets_info(sample: ExpSample, verbose: bool = True) -> list[str]:
    """Build (and optionally log) an indexed list of the datasets of a sample.

    Parameters
    ----------
    sample : ExpSample
        Sample whose ``datasets`` are listed.
    verbose : bool, optional [True]
        If True, the list is also emitted as a single INFO log record.

    Returns
    -------
    list of str
        One ``"- {index}: {dataset name}"`` line per dataset. Returned in both
        modes (previously only when ``verbose=False``).
    """
    outlist = [
        f"- {idataset}: {dataset.name}"
        for idataset, dataset in enumerate(sample.datasets)
    ]
    if verbose:
        _logger.info("\n".join(outlist))
    return outlist
[docs]
def _resolve_search_samples(
    session: ExpSession, sample: int | str | None
) -> list[ExpSample] | None:
    """Return the samples to search, re-running `search_samples` once if `sample` is not found.

    Returns None when a requested sample cannot be found even after a new search.
    """
    if sample is None:
        return session.samples
    if isinstance(sample, str):
        names = [samp.name for samp in session.samples]
        if sample not in names:
            _ = search_samples(session, verbose=False)
            _logger.warning(
                f"sample `{sample}` not found -> performed `search_samples()`"
            )
            names = [samp.name for samp in session.samples]
            if sample not in names:
                _logger.error(f"sample `{sample}` not found -> giving up!")
                return None
        return [session.samples[names.index(sample)]]
    try:
        return [session.samples[sample]]
    except IndexError:
        _ = search_samples(session, verbose=False)
        _logger.warning(f"sample `{sample}` not found -> performed `search_samples()`")
    try:
        return [session.samples[sample]]
    except IndexError:
        _logger.error(f"sample `{sample}` not found -> giving up!")
        return None


def _read_h5_scans(
    session: ExpSession, samp: ExpSample, fname: str, fnroot: str
) -> tuple[list[XasScanGroup], list[float], list[float]]:
    """Collect the scans of one HDF5 file matching the session search filters.

    Returns ``(scans, scans_emin, scans_emax)``; the energies are the
    ``scan_start``/``scan_end`` values parsed from the scan title (unscaled).
    The data source is always closed, even if reading raises.
    """
    scans, scans_emin, scans_emax = [], [], []
    dat = DataSourceSpecH5(fname, verbose=False)
    dat._logger.setLevel("ERROR")
    try:
        for scanno, scantitle, scantstamp in dat.get_scans():
            if (session.search_scanno not in scanno) or (
                session.search_scantitle not in scantitle
            ):
                continue
            try:
                dat.set_scan(scanno)
            except Exception as err:
                _logger.error(f"--- {type(err).__name__} ---> {err}")
                continue
            scaninfo = dat.get_scan_info_from_title()
            scanint = int(scanno.split(".")[0])
            scans.append(
                XasScanGroup(
                    flag=1,
                    fname=fname,
                    sample=samp.name,
                    dataset=fnroot,
                    scanno=scanno,
                    scanint=scanint,
                    title=scantitle,
                    time=scantstamp,
                    comment="",
                    fluo=None,
                    fluos=None,
                    ref=None,
                    trans=None,
                )
            )
            scans_emin.append(float(scaninfo["scan_start"]))
            scans_emax.append(float(scaninfo["scan_end"]))
    finally:
        dat.close()
    return scans, scans_emin, scans_emax


def _search_sample_datasets(
    session: ExpSession, samp: ExpSample
) -> tuple[list[ExpDataset], list[int], list[str]]:
    """Build the datasets of one sample: returns ``(datasets, nscans_list, outinfo_lines)``."""
    search_str = f"{session.datadir}/{session.raw_data_name}/{samp.name}/**/*.h5"
    # NOTE: `recursive` is not set, so `**` matches exactly one directory level
    # (i.e. `{sample}/{dataset_dir}/*.h5`), as in the original implementation
    fnames = glob.glob(search_str)
    fnames.sort(key=lambda x: Path(x).stat().st_ctime)  # sort by ctime
    datasets, nscans_list, outinfo = [], [], []
    for fname in fnames:
        fnroot = Path(fname).name.split(".h5")[0]
        scans, scans_emin, scans_emax = _read_h5_scans(session, samp, fname, fnroot)
        nscans = len(scans)
        if nscans == 0:
            continue
        scans_types = get_scans_types(scans_emin, scans_emax)
        nscans_list.append(nscans)
        outinfo.append(
            f"- {len(datasets):2}: {fnroot:30} [{nscans:3} : {len(scans_types)}]"
        )
        datasets.append(
            ExpDataset(
                name=fnroot,
                flag=1,
                scans=scans,
                scans_names=[scan.scanint for scan in scans],
                scans_emin=np.array(scans_emin) * 1000,
                scans_emax=np.array(scans_emax) * 1000,
                scans_eshifts=np.zeros(nscans),
                scans_types=scans_types,
                energy=None,
                trans=None,
                fluo=None,
                ref=None,
                rebin_pars=session.rebin_pars,
            )
        )
    return datasets, nscans_list, outinfo


def search_datasets(
    session: ExpSession, sample: int | str | None = None, verbose: bool = True
) -> list[ExpDataset] | ExpSession | None:
    """Search for HDF5 files and energy scans, grouped by datasets.

    Parameters
    ----------
    session : ExpSession
        Experimental session with metadata ->
        :class:`famewoks.datamodel.ExpSession`.
    sample : int, str or None, optional
        The specific sample to search for datasets. If `None`, searches all
        samples in the session. Can be an index (int) or a sample name (str)
        (default is None).
    verbose : bool, optional
        If True, logs the per-dataset table (default is True). Forced to False
        when all samples are searched.

    Returns
    -------
    list of ExpDataset, ExpSession or None
        If a specific sample is searched, returns the list of `ExpDataset`
        found for it (empty if the sample is flagged bad). If all samples are
        searched, returns the updated `ExpSession`. Returns None if the
        requested sample cannot be found.

    Notes
    -----
    - Searches the `{session.datadir}/{session.raw_data_name}/{sample_name}/**/*.h5`
      pattern (non-recursive glob: one sub-directory level).
    - Keeps scans containing `session.search_scanno` (e.g. .1) in their scan
      number and `session.search_scantitle` (e.g. 'trigscan') in their title.
    - Sets `sample.datasets` for every searched sample with flag != 0.
    - If no samples are in the session, `search_samples()` is run first.
    """
    if len(session.samples) == 0:
        _ = search_samples(session, verbose=False)
        _logger.debug("no samples found in session -> performed `search_samples()`")
    if sample is None:
        verbose = False
    samps = _resolve_search_samples(session, sample)
    if samps is None:
        return None
    # computed after a possible re-search, so indexes match the current list
    samps_names = [samp.name for samp in session.samples]
    datasets = []
    for samp in samps:
        isamp = samps_names.index(samp.name)
        if samp.flag == 0:
            continue
        _logger.debug(f"{isamp}: {samp.name}")
        datasets, nscans_list, outinfo = _search_sample_datasets(session, samp)
        samp.datasets = datasets
        _logger.info(
            f"- {isamp:2}: {samp.name:25} -> {nscans_list}"
        )  # TODO: # -> types: {scans_types}")
        if verbose:
            header = (
                ["#- index: dataset name [nscans : ntypes]"]
                if sample is not None
                else []
            )
            _logger.info("\n".join(header + outinfo))
    if sample is not None:
        return datasets
    return session
[docs]
def get_scans(
    dataset: ExpDataset,
    verbose: bool = False,
    ignore_flag: bool = False,
) -> list[list[int | str | XasScanGroup]]:
    """Return the enabled scans of a dataset, with their position and title.

    Parameters
    ----------
    dataset : ExpDataset
        Dataset whose ``scans`` are inspected.
    verbose : bool, optional [False]
        If True, log one INFO line per returned scan.
    ignore_flag : bool, optional [False]
        If True, scans with ``flag == 0`` are returned as well.

    Returns
    -------
    list of list
        ``[[position in dataset.scans, XasScanGroup, title], ...]``
    """
    selected = []
    for position, scn in enumerate(dataset.scans):
        disabled = scn.flag == 0
        if disabled and not ignore_flag:
            _logger.debug(
                f"[get_scans()] {dataset.name}/{scn.scanno} -> skipped (flag=0)"
            )
            continue
        if verbose:
            _logger.info(f"{position} -> {dataset.name}/{scn.scanno} [{scn.title}]")
        selected.append([position, scn, scn.title])
    return selected
[docs]
def get_scan(
    dataset: ExpDataset, scanint: int, ignore_flag: bool = False
) -> XasScanGroup | None:
    """Get scan object from the scans in the given dataset.

    Parameters
    ----------
    dataset : ExpDataset
        The dataset containing the list of scans. see the documentation of
        :class:`famewoks.datamodel.ExpDataset`.
    scanint : int
        The integer identifier of the scan to retrieve.
    ignore_flag : bool, optional [False]
        If True, disabled scans (``flag == 0``) are searched as well.

    Returns
    -------
    XasScanGroup or None
        The first `XasScanGroup` whose ``scanint`` matches, or None when no
        (enabled, unless ``ignore_flag``) scan matches.
    """
    scans = get_scans(dataset, verbose=False, ignore_flag=ignore_flag)
    for iscan, scan, stitle in scans:
        if scan.scanint == scanint:
            _logger.debug(
                f"- {iscan}) scan {scan.scanint}: {dataset.name}/{scan.scanno} [{stitle}]"
            )
            return scan
    _logger.debug(
        f"{dataset.name}: scan {scanint} not found (ignore_flag={ignore_flag})"
    )
    return None
[docs]
def get_groups(scan: XasScanGroup, data: str) -> list[Group]:
    """Get a list of enabled data groups from a scan.

    Parameters
    ----------
    scan : XasScanGroup
        scan object
    data : str
        name of the scan attribute holding the groups, e.g. "fluo", "fluos",
        "trans" or "ref"

    Returns
    -------
    list of Group
        groups with ``flag != 0``; empty when the data was not loaded yet
        (attribute is None).

    Raises
    ------
    AttributeError
        if ``scan`` has no attribute named ``data``.
    """
    groups = getattr(scan, data)
    if groups is None:
        return []
    return [grp for grp in groups if grp.flag != 0]
[docs]
def get_group(
    datain: ExpSession | ExpSample | ExpDataset,
    scanint: int,
    data: str,
    sample: ExpSample | int | str | None = None,
    dataset: ExpDataset | int | str | None = None,
) -> Group:
    """Get the first enabled data group of a scan.

    Parameters
    ----------
    datain : ExpSession, ExpSample or ExpDataset
        Container to start from. With a session, ``sample`` and ``dataset``
        are required; with a sample, ``dataset`` is required.
    scanint : int
        Integer identifier of the (enabled) scan.
    data : str
        Scan attribute holding the groups ("fluo", "fluos", "trans", "ref").
    sample, dataset : optional
        Object, index or name used to select the sample/dataset.

    Returns
    -------
    Group

    Raises
    ------
    ValueError
        if a required ``sample``/``dataset`` is missing or the scan is not found.
    TypeError
        if ``datain`` is not a session, sample or dataset.
    IndexError
        if the scan has no enabled group for ``data``.
    """
    if isinstance(datain, ExpSession):
        if sample is None:
            raise ValueError("`sample` not given")
        sample = sample if isinstance(sample, ExpSample) else datain.get_sample(sample)
        if dataset is None:
            raise ValueError("`dataset` not given")
        dset = (
            dataset if isinstance(dataset, ExpDataset) else sample.get_dataset(dataset)
        )
    elif isinstance(datain, ExpSample):
        if dataset is None:
            raise ValueError("`dataset` not given")
        dset = (
            dataset if isinstance(dataset, ExpDataset) else datain.get_dataset(dataset)
        )
    elif isinstance(datain, ExpDataset):
        dset = datain
    else:
        raise TypeError(
            f"`datain` must be ExpSession, ExpSample or ExpDataset, not {type(datain).__name__}"
        )
    scan = get_scan(dset, scanint)
    if scan is None:
        raise ValueError(f"scan {scanint} not found (or disabled) in `{dset.name}`")
    groups = get_groups(scan, data)
    if not groups:
        raise IndexError(f"no enabled `{data}` group in {dset.name}/scan{scanint}")
    return groups[0]
[docs]
def _try_pre_edge(group: Group, infolab: str, glabel: str) -> None:
    """Run :func:`larch.xafs.pre_edge` on ``group``, logging instead of raising on failure."""
    try:
        pre_edge(group)
    except Exception as err:
        _logger.warning(f"{infolab}_{glabel} -> `pre_edge()` failed")
        _logger.debug(f"--- {type(err).__name__} ---> {err}")


def _new_xas_group(filename, glabel, signal, energy, i0, mu, flag=1) -> Group:
    """Build the Larch ``Group`` used by the workflow to store one XAS spectrum."""
    return Group(
        filename=filename,
        label=glabel,
        signal=signal,
        energy=energy,
        i0=i0,
        mu=mu,
        datatype="xas",
        flag=flag,
    )


def _checked_mu(compute, like, infolab: str, glabel: str, ylab: str):
    """Evaluate ``compute()`` and validate it.

    Returns ``(mu, 1)`` if the result is finite everywhere, otherwise logs a
    warning and returns ``(np.zeros_like(like), 0)``.
    """
    try:
        mu = compute()
        # log(0) or x/0 give inf, not NaN -> reject every non-finite value
        if not np.all(np.isfinite(mu)):
            raise ValueError(f"NaN/inf values found in {glabel}={ylab}")
        return mu, 1
    except Exception as err:
        _logger.warning(f"{infolab} -> cannot calculate {glabel}")
        _logger.debug(f"--- {type(err).__name__} ---> {err}")
        return np.zeros_like(like), 0


def _load_scan_groups(ds, scan, dataset, cnts, cnts_fluo, iskip, iend, filter_spikes):
    """Read one scan from the open data source ``ds`` and attach its groups.

    Sets ``scan.fluos`` (one group per readable fluorescence channel),
    ``scan.fluo`` (sum of those channels), ``scan.trans`` and ``scan.ref``.

    Returns the reference group, or None when the scan cannot be loaded (in
    that case ``scan.flag`` is set to 0).
    """
    infolab = f"{dataset.name}/scan{scan.scanint}"
    fileroot = f"{dataset.name}_scan{scan.scanint}"
    i0sig, i1sig, i2sig = cnts.ix[0], cnts.ix[1], cnts.ix[2]
    sel = slice(iskip, iend)
    _logger.debug(render_tree(0, f"{infolab}: start loading data"))
    #: load energy + Ix channels
    try:
        ene = ds.get_array(cnts.ene)[sel] * cnts.ene_to_ev  # in eV
        i0 = ds.get_array(i0sig)[sel]
        i1 = ds.get_array(i1sig)[sel]
        i2 = ds.get_array(i2sig)[sel]
        stime = ds.get_array(cnts.time)[sel]
    except Exception as err:
        scan.flag = 0
        _logger.warning(
            f"{infolab}: not loaded and flagged as bad scan -> probably something wrong with this scan?"
        )
        _logger.debug(err)
        return None
    #: check array points discrepancies between energy, I0 and time
    if not (ene.size == i0.size == stime.size):
        scan.flag = 0
        _logger.error(
            f"{infolab}: array shape mismatch ene/ix/time -> flagged as bad scan and skipped"
        )
        _logger.debug(
            f"--- array [size]: {cnts.ene} [{ene.size}], {cnts.ix[0]} [{i0.size}], {cnts.time} [{stime.size}]"
        )
        return None
    #: read fluorescence channels (zeros for unreadable ones)
    yfluos, has_fluos = [], []
    for ich, cnt_fluo in enumerate(cnts_fluo):
        try:
            yfluos.append(ds.get_array(cnt_fluo)[sel])
            has_fluos.append(ich)
        except Exception as err:
            _logger.warning(f"{infolab}: cannot read fluorescence channel `{cnt_fluo}`")
            _logger.debug(err)
            yfluos.append(np.zeros_like(ene))
    if has_fluos:
        #: check arrays points discrepancies with fluorescence channels
        npts = yfluos[has_fluos[0]].size
        ptsdiff = ene.size - npts
        if ptsdiff > 0:
            _logger.debug(
                f"{infolab}: ene.size - fluo.size = {ptsdiff} -> stripping last points (as PyMca)"
            )
            ene, i0, i1, i2, stime = (arr[:npts] for arr in (ene, i0, i1, i2, stime))
        elif ptsdiff < 0:
            scan.flag = 0
            _logger.error(f"{infolab}: ene.size - fluo.size = {ptsdiff} -> cannot load data")
            return None
    #: -> load fluos (= fluorescence channels) data
    _logger.debug(render_tree(1, "start loading fluorescence channels data [= fluos]"))
    gfluos = []
    ysum = np.zeros_like(ene)
    for idet, (sig, etime) in enumerate(zip(cnts_fluo, cnts.fluo_time)):
        if idet not in has_fluos:
            _logger.debug(f"{infolab}: skipped {sig}")
            continue
        glabel = f"fluo{idet}_{sig}"
        # counts / elapsed time * count time (dead-time / livetime normalization)
        ysig = (yfluos[idet] / ds.get_array(etime)[sel]) * stime
        ylab = f"{sig}"
        if filter_spikes:
            # despike the normalized signal (the raw-counts version discarded the normalization)
            ysig = remove_spikes_medfilt1d(ysig)
            ylab += "_despiked"
        #: normalize to monitor, keeping the average number of counts
        ysig = (ysig / i0) * np.average(i0)
        ylab += f"/{cnts.ix[0]}"
        g = _new_xas_group(f"{fileroot}_{glabel}", glabel, ylab, ene, i0, ysig)
        _try_pre_edge(g, infolab, glabel)
        gfluos.append(g)
        _logger.debug(render_tree(2, f"[fluos: {glabel}] -> loaded {ylab}"))
        ysum = ysum + ysig
    scan.fluos = gfluos
    #: -> add sum of fluorescence channels
    _logger.debug(render_tree(1, "add sum of fluorescence channels [fluo]"))
    glabel = f"fluo_sum{len(gfluos)}"
    gsum = _new_xas_group(f"{fileroot}_{glabel}", glabel, glabel, ene, i0, ysum)
    _try_pre_edge(gsum, infolab, glabel)
    scan.fluo = [gsum]
    #: -> load transmission data
    glabel = "mutrans"
    ylab = f"log({i0sig}/{i1sig})"
    _logger.debug(render_tree(1, f"load transmission data [= trans] ({glabel})"))
    mutrans, mutrans_flag = _checked_mu(lambda: np.log(i0 / i1), ene, infolab, glabel, ylab)
    gtrans = _new_xas_group(
        f"{fileroot}_{glabel}", glabel, ylab, ene, i0, mutrans, flag=mutrans_flag
    )
    _try_pre_edge(gtrans, infolab, glabel)
    scan.trans = [gtrans]
    #: -> load reference data
    glabel = "muref"
    if cnts.ref_fluo:
        ylab = f"{i2sig}/{i1sig}"
        compute_ref = lambda: i2 / i1  # noqa: E731
    else:
        ylab = f"log({i1sig}/{i2sig})"
        compute_ref = lambda: np.log(i1 / i2)  # noqa: E731
    _logger.debug(render_tree(1, f"load reference data [= ref] ({glabel})"))
    muref, muref_flag = _checked_mu(compute_ref, ene, infolab, glabel, ylab)
    gref = _new_xas_group(
        f"{fileroot}_{glabel}", glabel, ylab, ene, i1, muref, flag=muref_flag
    )
    _try_pre_edge(gref, infolab, glabel)
    scan.ref = [gref]
    return gref


def load_data(
    session: ExpSession,
    sample: int | str | ExpSample,
    dataset: int | str | ExpDataset,
    use_fluo_corr: bool = False,
    filter_spikes: bool = False,
    skip_scans: str | list[int] | None = None,
    iskip: int | None = None,
    istrip: int | None = None,
    calc_eshift: bool = False,
    merge: bool = False,
    **kws: dict[str, Any],
):
    """Load data from disk into the data model (=memory).

    Parameters
    ----------
    session:
        the session object :class:`famewoks.models.ExpSession`
    sample:
        sample identification, int (=index), str (=sample name) or
        :class:`famewoks.models.ExpSample`
    dataset:
        dataset identification, int (=index), str (=dataset name) or
        :class:`famewoks.models.ExpDataset`
    use_fluo_corr : bool [False]
        when True, use the dead-time corrected fluorescence channels otherwise
        the uncorrected `roi` is used
    filter_spikes: bool [False]
        if True, remove spikes via median filter (with default parameters)
        from the normalized fluorescence channels
    skip_scans : List[int], str or None
        the scans to exclude. if a string is given is parsed by :func:`_str2rng`.
        None re-enables all scans (see :func:`set_bad_scans`).
    iskip: int [None]
        skip first iskip points
    istrip: int [None]
        strip last istrip points (0 or None: keep all points)
    calc_eshift: bool [False]
        to calculate the energy shift of each scan against
        ``session.enealign`` (set from the first loaded scan if missing)
    merge: bool [False]
        if True, run :func:`merge_data` at the end
    kws:
        unused, accepted for compatibility

    Returns
    -------
    None
    """
    sample = session.get_sample(sample)
    dataset = sample.get_dataset(dataset)
    cnts = session.counters
    cnts_fluo = cnts.fluo_corr if use_fluo_corr else cnts.fluo_roi
    # a stop index of `-0` would give empty arrays -> 0 means "strip nothing"
    iend = -istrip if istrip else None
    set_bad_scans(session, sample, dataset, scans=skip_scans)
    scans = get_scans(dataset)
    ntot = len(scans)
    good_scans = []
    e0ref_grp = None
    eshift_iscans = []  # indexes (in dataset.scans) with a computed eshift
    for iscan, scan, _ in scans:
        infolab = f"{dataset.name}/scan{scan.scanint}"
        ds = DataSourceSpecH5(scan.fname)
        try:
            ds.set_scan(scan.scanno)
            gref = _load_scan_groups(
                ds, scan, dataset, cnts, cnts_fluo, iskip, iend, filter_spikes
            )
        finally:
            ds.close()
        if gref is None:
            continue
        good_scans.append(scan.scanint)
        #: -> calculate energy shifts -> dataset.scans_eshifts
        if not calc_eshift:
            continue
        if gref.flag == 0:
            _logger.warning(
                f"{infolab}_muref -> no valid reference data -> energy shift not calculated"
            )
            continue
        if e0ref_grp is None:
            # first usable scan (not necessarily the first one of the dataset)
            if session.enealign is None:
                _logger.warning(
                    f"{infolab}_muref -> a reference spectrum for the energy calibration is not set -> using current scan"
                )
                set_enealign(session, gref)
            e0ref_grp = session.enealign
        eshift = get_eshift(
            gref,
            e0ref_grp,
            emin=e0ref_grp.pars_energy_align["emin"],
            emax=e0ref_grp.pars_energy_align["emax"],
        )
        dataset.scans_eshifts[iscan] = float(eshift)
        eshift_iscans.append(iscan)
    if calc_eshift and len(dataset.scans_types) > 1 and eshift_iscans:
        _logger.warning(
            "Not the same type of scans detected, we'll do an interpolation between the first and the last eshift"
        )
        first = dataset.scans_eshifts[eshift_iscans[0]]
        last = dataset.scans_eshifts[eshift_iscans[-1]]
        interp = np.linspace(first, last, len(eshift_iscans))
        for iscan, value in zip(eshift_iscans, interp):
            dataset.scans_eshifts[iscan] = float(value)
    _logger.info(
        f"{dataset.name} -> loaded {len(good_scans)}/{ntot} scans -> {good_scans}"
    )
    if merge:
        merge_data(dataset)
[docs]
def get_eshift(
    scan: Group,
    scan_ref: Group,
    emin: float = -15,
    emax: float = 35,
    **kws: dict[str, Any],
):
    """Return the energy shift of ``scan`` with respect to ``scan_ref``.

    The shift is computed by ``energy_align`` over the ``[emin, emax]``
    window (passed through unchanged) and logged at INFO level. Extra keyword
    arguments are accepted but ignored.
    """
    align_kws = dict(dat=scan, ref=scan_ref, emin=emin, emax=emax)
    shift = energy_align(**align_kws)
    names = f"{scan.filename} [{scan_ref.filename}]"
    _logger.info(f"{names} energy shift: {shift}")
    return shift
[docs]
def apply_eshift(dset: ExpDataset, data: str):
    """Shift the energy of each enabled scan's ``data`` group, then re-merge.

    For every enabled scan, ``dset.scans_eshifts[index]`` is added to the
    energy array of the first group in ``scan.<data>``. Then
    :func:`merge_data` is run on the dataset.

    Parameters
    ----------
    dset : ExpDataset
        Dataset with loaded scans and computed ``scans_eshifts``.
    data : str
        One of "fluo", "ref", "trans".

    Raises
    ------
    ValueError
        if ``data`` is not one of the allowed names.

    Notes
    -----
    The shift is cumulative: calling this twice shifts the energies twice.
    Scans whose ``data`` is not loaded are skipped with a warning.
    """
    if data not in ("fluo", "ref", "trans"):
        raise ValueError("data must be one of ['fluo', 'ref', 'trans']")
    for iscan, scn, _ in get_scans(dset):
        groups = getattr(scn, data)
        if not groups:
            _logger.warning(f"{dset.name}/{scn.scanint}/{data} -> no data loaded, skipped")
            continue
        eshift = dset.scans_eshifts[iscan]
        grp = groups[0]
        # rebind instead of in-place `+=`: `load_data` creates all groups of a
        # scan with the same energy array object, so an in-place update may
        # also shift the other channels
        grp.energy = grp.energy + eshift
        _logger.info(f"{dset.name}/{scn.scanint}/{data} -> energy shifted by {eshift}")
    merge_data(dset)
[docs]
def set_bad_samples(session: ExpSession, samples: str | int | list[int | str] | None):
    """Flag samples as BAD (flag=0), or reset all of them to GOOD (flag=1).

    Parameters
    ----------
    session : ExpSession
        Session holding the samples.
    samples : str, int, list or None
        Sample name, index, range string parsed by ``_str2rng`` (when the
        string is not an existing sample name), or a list of names/indexes.
        None flags every sample as GOOD.

    Notes
    -----
    Unknown samples (``session.get_sample`` raising ValueError) are ignored.
    When at least one sample is flagged, the change is appended to
    ``SAMPLES.log`` through ``session._write_logfile_str``.
    """
    known = [smp.name for smp in session.samples]
    if samples is None:
        flag, flag_str = 1, "GOOD"
        targets = known
    else:
        flag, flag_str = 0, "BAD"
        if isinstance(samples, str):
            targets = [samples] if samples in known else _str2rng(samples)
        elif isinstance(samples, int):
            targets = [samples]
        else:
            targets = samples
    flagged = []
    for ident in targets:  # type: ignore
        try:
            smp = session.get_sample(ident)
        except ValueError:
            continue
        smp.flag = flag
        flagged.append(smp.name)
        _logger.debug(f"{smp.name} flagged as {flag_str}")
    if not flagged:
        return
    _logger.info(f"flagged {len(flagged)} samples as {flag_str}: {flagged}")
    #: store info to file
    session._write_logfile_str(
        "SAMPLES.log",
        f"{flag}, {flagged}",
        "#YYYY-MM-DD HH:MM:SS, flag (0=BAD, 1=GOOD), samples",
    )
[docs]
def set_bad_scans(
    session: ExpSession,
    sample: int | str | ExpSample,
    dataset: int | str | ExpDataset,
    scans: str | int | list[int] | None,
):
    """Set flag=0 to bad scans -> not included in merge.

    Parameters
    ----------
    session : ExpSession
        Session holding the sample.
    sample, dataset :
        Object, index or name of the sample/dataset.
    scans : str, int, list of int or None
        Scan numbers (``scanint``) to disable; a string is parsed by
        ``_str2rng``. None re-enables (flag=1) every scan of the dataset.

    Notes
    -----
    Unknown scan numbers are ignored (debug log). Changes are appended to
    ``SCANS.log``, then :func:`merge_data` is run on the dataset.
    """
    sample = session.get_sample(sample)
    dataset = sample.get_dataset(dataset)
    known = [scn.scanint for scn in dataset.scans]
    if scans is None:
        flag, flag_str = 1, "GOOD"
        scans = known
    else:
        flag, flag_str = 0, "BAD"
        if isinstance(scans, str):
            scans = _str2rng(scans)
        elif isinstance(scans, int):
            scans = [scans]
    changed = []
    for scanint in scans:  # type: ignore
        if scanint not in known:
            _logger.debug(f"cannot find scan {scanint} in: {known}")
            continue
        target = get_scan(dataset, scanint, ignore_flag=True)  # type: ignore
        target.flag = flag
        changed.append(target.scanint)
    if changed:
        _logger.info(f"flagged {len(changed)} scans as *{flag_str}*: {changed}")
        #: store info to file
        session._write_logfile_str(
            "SCANS.log",
            f"{sample.name}, {dataset.name}, {flag}, {changed}",
            "#YYYY-MM-DD HH:MM:SS, sample, dataset, flag (0=BAD, 1=GOOD), scans",
        )
    merge_data(dataset)
[docs]
def set_bad_fluo_channels(
    session: ExpSession,
    sample: int | str | ExpSample,
    dataset: int | str | ExpDataset,
    channels: str | int | list[int] | None,
    scans: str | int | list[int] | None = None,
):
    """Set bad fluo channels and recompute the fluorescence sum of each scan.

    Parameters
    ----------
    session:
        ExpSession object
    sample:
        int, str or ExpSample object
    dataset:
        int, str or ExpDataset object
    channels:
        str (parsed by str2rng), int or list of int (channel indexes); if None,
        all channels are enabled back!
    scans:
        str (parsed by str2rng), int or list of ints (scan numbers); if None,
        set for all scans in the current dataset

    Raises
    ------
    ValueError
        if ``channels`` is a string that ``_str2rng`` cannot parse.

    Notes
    -----
    Disabled scans and scans without loaded fluorescence data are skipped.
    The merged dataset groups are not refreshed here; run :func:`merge_data`
    to propagate the change.
    """
    sample = session.get_sample(sample)
    dataset = sample.get_dataset(dataset)
    allchannels = [ich for ich, _ in enumerate(session.counters.fluo_roi)]
    #: parse scans input
    if scans is None:
        scans = [scn.scanint for scn in dataset.scans]
    elif isinstance(scans, str):
        scans = _str2rng(scans)
    elif isinstance(scans, int):
        scans = [scans]
    #: get scans objects (first match for each scan number)
    scans_by_int = {}
    for scn in dataset.scans:
        scans_by_int.setdefault(scn.scanint, scn)
    scans_objs = []
    for scanint in scans:  # type: ignore
        scn = scans_by_int.get(scanint)
        if scn is None:
            _logger.error(f"{sample.name}/{dataset.name}/{scanint} -> scan not found")
            continue
        scans_objs.append(scn)
    #: check channels input
    flag, flag_str = 0, "BAD"
    if isinstance(channels, str):
        try:
            channels = _str2rng(channels)  # type: ignore
        except Exception as err:
            raise ValueError("given a wrong `channels` string") from err
    if channels is None:
        flag, flag_str = 1, "GOOD"
        channels = allchannels  # type: ignore
    if isinstance(channels, int):
        channels = [channels]
    processed = []
    for scn in scans_objs:
        if scn.flag == 0:
            _logger.debug(
                f"scan {scn.scanint} is disabled: cannot set {flag_str} fluo channels"
            )
            continue
        if scn.fluos is None or scn.fluo is None:
            _logger.error("`set_bad_fluo_channels()` failed -> do `load_data()` first")
            continue
        for ichannel, gfluo in enumerate(scn.fluos):
            if ichannel in channels:  # type: ignore
                gfluo.flag = flag
        #: reprocess the sum of the enabled channels
        sum_fluo_channels(dataset, scn)
        processed.append(scn)
    if len(processed) > 0 and len(channels) > 0:  # type: ignore
        _logger.info(
            f"flagged {len(channels)} fluo channels as {flag_str} for {len(processed)} scans"
        )
        _logger.debug(
            f"{flag_str} channels: {channels} for scans: {[scn.scanint for scn in processed]}"
        )
        #: store information to file
        outfile = "CHANNELS.log"
        outheader = "#YYYY-MM-DD HH:MM:SS, sample, dataset, scans, flag (0=BAD, 1=GOOD), fluo_channels"
        outstr = f"{sample.name}, {dataset.name}, {scans}, {flag}, {channels}"
        session._write_logfile_str(outfile, outstr, outheader)
[docs]
def sum_fluo_channels(dataset: ExpDataset, scan: XasScanGroup):
    """Recompute the summed fluorescence group of a scan from its enabled channels.

    Parameters
    ----------
    dataset : ExpDataset
        Dataset owning the scan; its name is used in the group filename.
    scan : XasScanGroup
        Scan whose ``fluo[0]`` group is overwritten with the sum of the ``mu``
        arrays of the ``fluos`` groups having ``flag != 0``.

    Notes
    -----
    - ``fluo[0].filename``/``label`` are renamed to ``..._fluo_sum{N}`` with N
      the number of summed channels.
    - A warning is logged when the sum is identically zero.
    - ``pre_edge`` is re-applied; its failure is logged, not raised.
    """
    target = scan.fluo[0]
    zero = np.zeros_like(target.energy)
    enabled = [chan for chan in scan.fluos if chan.flag != 0]
    total = sum((chan.mu for chan in enabled), zero)
    target.mu = total
    label = f"fluo_sum{len(enabled)}"
    name = f"{dataset.name}_scan{scan.scanint}_{label}"
    target.filename = name
    target.label = label
    if np.all(total == 0):
        _logger.warning(f"{name}: zero sum of fluorescence channels")
    try:
        pre_edge(target)
    except Exception as exc:
        _logger.warning(f"{name}: `pre_edge()` failed")
        _logger.debug(f"--- {type(exc).__name__} ---> {exc}")
[docs]
def merge_data(dset: ExpDataset):
    """Merge scans in ExpDataset for datas ("fluo", "ref", "trans").

    For each data type, the first group of every enabled scan is collected
    and merged with :func:`larch.io.mergegroups.merge_groups` (cubic
    interpolation on the energy grid of the first group, trimmed, with
    ``yerr``). A single group is deep-copied instead. ``pre_edge`` is applied
    to the result, which is stored as ``dset.fluo``/``dset.ref``/``dset.trans``.
    Finally :func:`rebin_data` is run on the dataset.

    Parameters
    ----------
    dset : ExpDataset
        The dataset containing scans to be merged. see the documentation of
        :class:`famewoks.datamodel.ExpDataset`.

    Notes
    -----
    - Only scans with ``flag != 0`` are used.
    - Groups flagged 0 (e.g. transmission/reference that could not be
      computed) are excluded when at least one valid group exists; if all are
      flagged 0 they are merged anyway and the merged group gets ``flag=0``.
    - The merged group is named ``{dset.name}_mrg{N}_{data}``.
    - Data types without any loaded group are left unchanged.
    """
    scans = get_scans(dset)
    for data in ["fluo", "ref", "trans"]:
        groups_to_merge = []
        for _, scan, _ in scans:
            data_groups = getattr(scan, data)
            if (
                data_groups is None
                or not hasattr(data_groups, "__getitem__")
                or len(data_groups) == 0
            ):
                continue
            if isinstance(data_groups[0], Group):
                groups_to_merge.append(data_groups[0])
        valid = [grp for grp in groups_to_merge if getattr(grp, "flag", 1) != 0]
        mrg_flag = 1 if valid else 0
        if valid:
            groups_to_merge = valid
        glabel = f"mrg{len(groups_to_merge)}_{data}"
        mrgname = f"{dset.name}_{glabel}"
        if not groups_to_merge:
            _logger.debug(f"{mrgname} -> no groups to merge")
            continue
        if len(groups_to_merge) == 1:
            grp_mrg = deepcopy(groups_to_merge[0])
        else:
            try:
                grp_mrg = merge_groups(
                    groups_to_merge,
                    master=groups_to_merge[0],
                    xarray="energy",
                    yarray="mu",
                    kind="cubic",
                    trim=True,
                    calc_yerr=True,
                )
                _logger.info(f"{mrgname} -> merged {len(groups_to_merge)} groups")
            except Exception as err:
                _logger.warning(f"{mrgname} -> `merge_groups()` failed to merge groups")
                _logger.debug(f"--- {type(err).__name__} ---> {err}")
                continue
        if grp_mrg.energy.size != grp_mrg.mu.size:
            _logger.warning(f"{mrgname} -> array shape mismatch energy/mu")
        try:
            pre_edge(grp_mrg)
        except Exception as err:
            _logger.warning(f"{mrgname} -> `pre_edge()` failed on merged group")
            _logger.debug(f"--- {type(err).__name__} ---> {err}")
        grp_mrg.filename = mrgname
        grp_mrg.label = glabel
        grp_mrg.flag = mrg_flag
        setattr(dset, data, grp_mrg)
    rebin_data(dset)
[docs]
def rebin_data(dset: ExpDataset, rebin_pars: dict[str, Any] | None = None):
    """Rebin merged data groups in an ExpDataset object.

    Each of ``dset.fluo``, ``dset.ref`` and ``dset.trans`` is passed to
    :func:`rebin_group`, which stores the result in ``<group>.rebinned``.

    Parameters
    ----------
    dset : ExpDataset
        The dataset containing the merged data groups to be rebinned.
    rebin_pars : dict or None, optional
        Keyword arguments for ``rebin_xafs``. If None, ``dset.rebin_pars`` is
        used; if that is also None, the ``rebin_xafs`` defaults are used.

    Raises
    ------
    TypeError
        if the rebinning parameters are not a dictionary.

    Notes
    -----
    - Data groups not merged yet (None) are skipped with a debug message:
      run :func:`merge_data` first.
    - Errors during rebinning are logged, not raised.
    """
    if rebin_pars is None:
        rebin_pars = dset.rebin_pars
    if rebin_pars is None:
        rebin_pars = {}
    if not isinstance(rebin_pars, dict):
        raise TypeError("`rebin_pars` should be a dictionary")
    for dat in ["fluo", "ref", "trans"]:
        infolab = f"{dset.name}/{dat}"
        data_group = getattr(dset, dat, None)
        if data_group is None:
            # normal before the first merge -> keep it quiet
            _logger.debug(f"{infolab}: cannot get merged data -> run `merge_data()` first")
            continue
        try:
            rebin_group(data_group, **rebin_pars)
        except Exception as err:
            _logger.debug(f"--- {type(err).__name__} ---> {err}")
            continue
[docs]
def rebin_group(grp: Group, **kws: dict[str, Any]):
    """Rebin a single Larch group in place (result in ``grp.rebinned``).

    Parameters
    ----------
    grp : Group
        The group to be rebinned.
    kws : dict
        Keyword arguments forwarded to the Larch `rebin_xafs()` function.
        *Default parameters for Larch `rebin_xafs()` function* :
            - `pre1`: pre_step*int((min(energy)-e0)/pre_step)
            - `pre2` : -30
            - `pre_step`: 2
            - `exafs1` :-15
            - `exafs2` : max(energy)-e0
            - `xanes_step` : e0/25000 , round down to 0.05
            - `method` : centroid

    Notes
    -----
    ``pre_edge`` is applied to the rebinned group, which is renamed
    ``{filename}_rebinned`` / ``{label}_rebinned``. Any failure is logged
    (ERROR + DEBUG details), never raised.
    """
    try:
        rebin_xafs(grp, **kws)
        out = grp.rebinned
        pre_edge(out)
        out.filename = f"{grp.filename}_rebinned"
        out.label = f"{grp.label}_rebinned"
    except Exception as exc:
        _logger.error(f"{grp.filename} -> cannot rebin")
        _logger.debug(f"--- {type(exc).__name__} ---> {exc}")
    else:
        _logger.info(f"{grp.filename} -> rebinned")
[docs]
def save_data(
    dset: ExpDataset,
    data: list[str] | str = "all",
    datadir: str | None = None,
    save_rebinned: bool = False,
    suffix: str | None = None,
):
    """Save dataset groups to Athena project files, one file per data channel.

    Parameters
    ----------
    dset : ExpDataset
        The :class:`famewoks.datamodel.ExpDataset` object containing the data
        to be saved.
    data : list of str or str, optional (default: 'all')
        Data channel(s) to be saved, among ["fluo", "trans", "ref"]. Either
        "all", a single channel name (e.g. "fluo") or a list of names.
    datadir : str or None, optional
        Base directory for the output. If None, defaults to the system
        temporary directory (`tempfile.gettempdir()`).
    save_rebinned : bool, optional
        Whether to also add the rebinned merged group (``<merged>.rebinned``)
        to each project file. Defaults to False.
    suffix : str or None, optional
        Additional string appended (as ``_<suffix>``) to the output filename.

    Notes
    -----
    For each requested channel, a project file named
    ``{dset.name}_{YYMMDD_HHMM}_{channel}{_suffix}.prj`` is written in
    ``{datadir}/PROCESSED_DATA/famewoks``, which is created if missing.
    The file contains:

    - the first group of that channel for every scan whose ``flag`` is
      non-zero;
    - the merged group;
    - optionally, the rebinned merged group.

    Groups that cannot be retrieved or added are skipped with a log message,
    and the project file is still saved. Invalid channel names are logged as
    errors and nothing is saved.
    """
    import logging
    import tempfile

    # silence silx messages while the project files are written
    logging.getLogger("silx").setLevel("ERROR")
    tstamp = time.strftime("%y%m%d_%H%M")  #: e.g. 251118_1702

    data_all = ["fluo", "trans", "ref"]
    if data == "all":
        data = data_all
    elif isinstance(data, str):
        # A single channel given as a plain string: iterating it directly
        # would check its characters instead of the channel name.
        data = [data]
    wrong = [dat for dat in data if dat not in data_all]
    if wrong:
        _logger.error(f"wrong data type {wrong}, possible choices: {data_all}")
        return

    if datadir is None:
        datadir = tempfile.gettempdir()
    #: TODO: session.processed_data_name
    dirout = Path(datadir) / "PROCESSED_DATA" / "famewoks"
    dirout.mkdir(parents=True, exist_ok=True)
    suffix = "" if suffix is None else f"_{suffix}"

    for dat in data:
        fnameout = str(dirout / f"{dset.name}_{tstamp}_{dat}{suffix}.prj")
        apj = AthenaProject(fnameout)
        outgrps = ["saved groups:"]

        # individual scans (those with flag == 0 are skipped)
        for _, scan in zip(dset.scans_names, dset.scans):
            if scan.flag == 0:
                continue
            try:
                grp = getattr(scan, dat)[0]
            except (AttributeError, TypeError, IndexError):
                grp = None
            if grp is None:
                _logger.warning(f"{dset.name}/{dat} -> cannot get the scan group")
                continue
            _add_group_to_project(apj, grp, outgrps)

        # merged group and, optionally, its rebinned version
        gmrg = getattr(dset, dat, None)
        if gmrg is None:
            _logger.debug(f"{dset.name}/{dat} -> cannot get the merged group")
        else:
            _add_group_to_project(apj, gmrg, outgrps)
            if save_rebinned:
                try:
                    grb = gmrg.rebinned
                except Exception as err:
                    # A missing rebinned group must not prevent saving the
                    # groups already added to this project.
                    _logger.warning(
                        f"{getattr(gmrg, 'filename', repr(gmrg))} -> cannot get the rebinned group"
                    )
                    _logger.debug(f"--- {type(err).__name__} ---> {err}")
                else:
                    _add_group_to_project(apj, grb, outgrps)

        apj.save()
        _logger.info(f"data saved in {fnameout}")
        _logger.info("\n".join(outgrps))
        apj = None  #: close the file


def _add_group_to_project(apj, grp, outgrps: list[str]) -> None:
    """Add `grp` to the Athena project `apj` and record its name in `outgrps`.

    Errors raised by `apj.add_group()` are logged as warnings, so that one
    faulty group does not prevent saving the others.
    """
    name = f"{getattr(grp, 'filename', repr(grp))}"
    try:
        apj.add_group(grp)
    except Exception as err:
        _logger.warning(f"{name} -> cannot add to the Athena project")
        _logger.debug(f"--- {type(err).__name__} ---> {err}")
        return
    outgrps.append(name)
    _logger.debug(f"{name} -> added to the Athena project")
[docs]
def get_scans_types(
    scans_emin: list[float], scans_emax: list[float], energy_window: float = 10
) -> list[list]:
    """Cluster scans based on their energy ranges.

    Scans are processed in order. A scan joins the first existing type whose
    reference range is within `energy_window` of its own range for both
    bounds (inclusive). The reference range is the (emin, emax) of the scan
    that created the type. If no type matches, the scan starts a new type.

    Parameters
    ----------
    scans_emin : list of float
        A list of minimum energy values for each scan.
    scans_emax : list of float
        A list of maximum energy values for each scan. If the two lists
        differ in length, the extra items of the longer one are ignored.
    energy_window : float, optional
        The energy window used for clustering (default is 10).

    Returns
    -------
    list
        List of scan types. Each type is a (mutable) list
        ``[emin, emax, iscans]``, where ``iscans`` is the list of indices of
        the scans belonging to that type.
    """
    scan_types: list[list] = []
    for iscan, (emin, emax) in enumerate(zip(scans_emin, scans_emax)):
        if not scan_types:
            scan_types.append([emin, emax, [iscan]])
            _logger.debug(f"{iscan} -> {emin} - {emax} [case 0: initial type]")
            continue
        for ityp, (smin, smax, iscans) in enumerate(scan_types):
            if abs(emin - smin) <= energy_window and abs(emax - smax) <= energy_window:
                iscans.append(iscan)
                _logger.debug(
                    f"{iscan} -> {emin} - {emax} [case 1: in energy window -> add to type {ityp}]"
                )
                break
        else:
            # for-else: reached only when no existing type matched (no break)
            scan_types.append([emin, emax, [iscan]])
            # report the 0-based index of the new type, as in case 1
            _logger.debug(
                f"{iscan} -> {emin} - {emax} [case 2: out of energy window -> new type ({len(scan_types) - 1})]"
            )
    return scan_types
[docs]
def set_enealign(session: ExpSession, gref: Group, data: str = "ref"):
    """Set the energy alignment reference group of a session.

    Parameters
    ----------
    session : ExpSession
        The session for which the energy calibration is set. Its
        `rebin_pars` are used if `gref` has not been rebinned yet.
    gref : Group
        The reference group used for energy calibration. `pre_edge()` is run
        on it if it has no `e0` attribute. It is rebinned with
        `rebin_group()` if it has no `rebinned` attribute.
    data : str, optional
        The type of data used for energy calibration. Defaults to "ref".
        Currently not used by this function.

    Returns
    -------
    None
        Sets `session.enealign` to a deep copy of `gref.rebinned`. The copy
        carries a `pars_energy_align` dict with the ``emin``/``emax`` range
        (relative to e0) used for the alignment.

    Raises
    ------
    AttributeError
        If `gref` has no rebinned group after the rebinning attempt. The
        rebinning error itself is logged by `rebin_group()`.
    """
    if not hasattr(gref, "e0"):
        pre_edge(gref)
    erelmin = float(gref.energy.min() - gref.e0)
    erelmax = float(gref.energy.max() - gref.e0)
    # Use [-15, 35] when the data extend beyond [-25, 45] around e0;
    # otherwise use the data's own bound.
    emin = -15.0 if erelmin < -25 else erelmin
    emax = 35.0 if erelmax > 45 else erelmax
    if not hasattr(gref, "rebinned"):
        rebin_group(gref, **session.rebin_pars)
    if not hasattr(gref, "rebinned"):
        # rebin_group() logs and swallows its errors: fail here explicitly
        # rather than with an opaque attribute lookup error below.
        raise AttributeError(
            f"{getattr(gref, 'filename', repr(gref))} -> no rebinned group, "
            "cannot set it as energy alignment reference"
        )
    session.enealign = deepcopy(gref.rebinned)
    session.enealign.pars_energy_align = {"emin": emin, "emax": emax}
    _logger.info(
        f"{session.enealign.filename} -> setted as reference to calculate the energy shifts"
    )
[docs]
def set_eshift(dset: ExpDataset, scan: int, eshift: float):
    """Set the energy shift of one scan of a dataset.

    Parameters
    ----------
    dset : ExpDataset
        Dataset containing the scan. Its `scans_eshifts` list is updated at
        the position of the scan in `dset.scans`.
    scan : int
        Scan number, matched against the `scanint` attribute of the scans.
    eshift : float
        The energy shift to store.

    Notes
    -----
    If no scan of `dset` has ``scanint == scan``, an error is logged and the
    dataset is left unchanged.
    """
    idx = next(
        (i for i, sc in enumerate(dset.scans) if sc.scanint == scan), None
    )
    if idx is None:
        _logger.error(f"scan {scan} not found in dataset {dset.name} -> shift not set")
        return
    dset.scans_eshifts[idx] = eshift
    _logger.info(f"scan {scan} from dataset {dset.name} shift is now set : {eshift}")
if __name__ == "__main__":
    # This module only defines workflow functions; running it as a script
    # does nothing.
    ...