From 34145d7a49e00fea68ca094633ebbdf9a2b4e828 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Thu, 27 Oct 2022 10:53:50 +0200
Subject: [PATCH] Add conversion to delay, improve plotting

---
 src/toolbox_scs/routines/Reflectivity.py | 66 ++++++++++++++----------
 1 file changed, 38 insertions(+), 28 deletions(-)

diff --git a/src/toolbox_scs/routines/Reflectivity.py b/src/toolbox_scs/routines/Reflectivity.py
index 6813233..c6bba0e 100644
--- a/src/toolbox_scs/routines/Reflectivity.py
+++ b/src/toolbox_scs/routines/Reflectivity.py
@@ -9,11 +9,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr
 import re
-from functools import partial
 from toolbox_scs.misc.laser_utils import positionToDelay as pTd
-from toolbox_scs.misc.laser_utils import delayToPosition as dTp
-from toolbox_scs.base.knife_edge import (prepare_arrays,
-                                         function_fit)
 from toolbox_scs.routines.XAS import xas
 
 __all__ = [
@@ -59,9 +55,10 @@ def prepare_reflectivity_ds(ds, Iokey, Irkey, alternateTrains,
 
 
 def reflectivity(data, Iokey='FastADC5peaks', Irkey='FastADC3peaks',
-                 delaykey='PP800_DelayLine', binWidth=0.05, 
+                 delaykey='PP800_DelayLine', binWidth=0.05,
+                 positionToDelay=True, origin=None, invert=False,
                  pumpedOnly=False, alternateTrains=False, pumpOnEven=True,
-                 Ioweights=False, plot=True, plotErrors=True, origin=None,
+                 Ioweights=False, plot=True, plotErrors=True, units='mm'
                  ):
     """
     Computes the reflectivity R = 100*(Ir/Io[pumped] / Ir/Io[unpumped] - 1)
@@ -85,6 +82,9 @@ def reflectivity(data, Iokey='FastADC5peaks', Irkey='FastADC3peaks',
         optical delay in ps)
     binWidth: float
         width of bin in units of delay variable
+    positionToDelay: bool
+    origin: float
+    invert: bool
     pumpedOnly: bool
         Assumes that all trains and pulses are pumped. In this case,
         Delta R is defined as Ir/Io.
@@ -102,8 +102,6 @@ def reflectivity(data, Iokey='FastADC5peaks', Irkey='FastADC3peaks',
         If True, plots the results.
     plotErrors: bool
         If True, plots the 95% confidence interval.
-    origin: float
-        Used for plotting a vertical line at the position.
 
     Output
     ------
@@ -114,18 +112,18 @@ def reflectivity(data, Iokey='FastADC5peaks', Irkey='FastADC3peaks',
     # select relevant variables from dataset
     variables = [Iokey, Irkey, delaykey]
     ds = data[variables]
-
     # prepare dataset according to pulse pattern
     ds = prepare_reflectivity_ds(ds, Iokey, Irkey, alternateTrains,
                                  pumpOnEven, pumpedOnly)
-    
+
     if (len(ds[delaykey].dims) > 1) and (ds[delaykey].dims !=
                                          ds[Iokey].dims):
         raise ValueError("Dimensions mismatch: delay variable has dims "
                          f"{ds[delaykey].dims} but (It, Io) variables have "
-                         f"dims {ds['deltaR'].dims}.")
+                         f"dims {ds[Iokey].dims}.")
 
     bin_delays = binWidth * np.round(ds[delaykey] / binWidth)
+    ds[delaykey+'_binned'] = bin_delays
     counts = xr.ones_like(ds[Iokey]).groupby(bin_delays).sum(...)
     if Ioweights is False:
         ds['deltaR'] = ds[Irkey]/ds[Iokey]
@@ -164,42 +162,54 @@ def reflectivity(data, Iokey='FastADC5peaks', Irkey='FastADC3peaks',
                               name='counts',
                               coords={delaykey: xas_pumped['nrj']})
         binned = xr.merge([deltaR, stddev, stderr, counts])
-    
+
     # copy attributes
     for key, val in data.attrs.items():
         binned.attrs[key] = val
 
+    binned = binned.rename({delaykey: 'delay'})
+
     if plot:
-        plot_reflectivity(binned, delaykey, origin, plotErrors)
+        plot_reflectivity(binned, delaykey, positionToDelay,
+                          origin, invert, plotErrors, units)
 
     return binned
 
 
-def plot_reflectivity(data, delaykey, origin, plotErrors):
-    fig, ax = plt.subplots(figsize=(6,3.5), constrained_layout=True)
-    ax.plot(data[delaykey], data['deltaR'], 'o-', color='C0')  
-    ax.set_xlabel(delaykey)
-    ax.set_ylabel(r'$\Delta R$ [%]', color='C0')
-    ax.grid()
+def plot_reflectivity(data, delaykey, positionToDelay, origin,
+                      invert, plotErrors, units):
+    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
+    ax.plot(data['delay'], data['deltaR'], 'o-', color='C0')
+    xlabel = delaykey + f' [{units}]'
     if plotErrors:
-        ax.fill_between(data[delaykey],
+        ax.fill_between(data['delay'],
                         data['deltaR'] - 1.96*data['deltaR_stderr'],
                         data['deltaR'] + 1.96*data['deltaR_stderr'],
                         color='C0', alpha=0.2)
         ax2 = ax.twinx()
-        ax2.bar(data[delaykey], data['counts'],
-            width=0.80*(data[delaykey][1]-data[delaykey][0]),
-            color='C1', alpha=0.2)
-        ax2.set_ylabel('counts', color='C1')
+        ax2.bar(data['delay'], data['counts'],
+                width=0.80*(data['delay'][1]-data['delay'][0]),
+                color='C1', alpha=0.2)
+        ax2.set_ylabel('counts', color='C1', fontsize=13)
         ax2.set_ylim(0, data['counts'].max()*3)
     if origin is not None:
         ax.axvline(origin, color='grey', ls='--')
+        if positionToDelay:
+            ax3 = ax.twiny()
+            xmin, xmax = ax.get_xlim()
+            ax3.set_xlim(pTd(xmin, origin, invert),
+                         pTd(xmax, origin, invert),)
+            ax3.set_xlabel('delay [ps]', fontsize=13)
     try:
         proposalNB = int(re.findall(r'p(\d{6})',
                                     data.attrs['runFolder'])[0])
         runNB = int(re.findall(r'r(\d{4})', data.attrs['runFolder'])[0])
-        ax.set_title(f'run {runNB} p{proposalNB}')
-    except Exception as err:
+        ax.set_title(f'run {runNB} p{proposalNB}', fontsize=14)
+    except Exception:
         if 'plot_title' in data.attrs:
-            fig.suptitle(data.attrs['plot_title'])
-        print('error', err)
+            ax.set_title(data.attrs['plot_title'])
+    ax.set_xlabel(xlabel, fontsize=13)
+    ax.set_ylabel(r'$\Delta R$ [%]', color='C0', fontsize=13)
+    ax.grid()
+
+    return fig, ax
-- 
GitLab