From 1bd27bbe30b82b99f48cd27fcae1d5133b18ecc4 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Tue, 10 Mar 2020 22:14:44 +0100
Subject: [PATCH] Improved autoFindFastAdcPeaks()

---
 xgm.py | 47 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 33 insertions(+), 14 deletions(-)

diff --git a/xgm.py b/xgm.py
index 40d61ea..24f2250 100644
--- a/xgm.py
+++ b/xgm.py
@@ -863,18 +863,18 @@ def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, period=Non
         results[:,i] = integ
     return results
 
-def autoFindFastAdcPeaks(data, channel=5, threshold=35000, display=False, plot=False):
-    ''' Automatically finds positive peaks in channel of Fast ADC trace, assuming
-        a minimum absolute height of 'threshold' counts and a minimum width of 4 
-        samples. The find_peaks function and determination of the peak region and 
-        baseline subtraction is optimized for typical photodiode signals of the 
-        SCS instrument (ILH, FFT reflectometer, FFT diag stage).
+def autoFindFastAdcPeaks(data, channel=5, window='small', display=False, plot=False):
+    ''' Automatically finds peaks in channel of Fast ADC trace, a minimum width of 4 
+        samples. The find_peaks function and determination of the peak integration 
+        region and baseline subtraction is optimized for typical photodiode signals
+        of the SCS instrument (ILH, FFT reflectometer, FFT diag stage).
         Inputs:
             data: xarray Dataset containing Fast ADC traces
             key: data key of the array of traces
-            threshold: minimum height of the peaks
+            window: 'small' or 'large': defines the width of the integration region
+                centered on the peak.
             display: bool, displays info on the pulses found
-            plot: plots regions of integration of the first pulse in the trace
+            plot: bool, plots regions of integration of the first pulse in the trace
         Output:
             peaks: DataArray of the integrated peaks 
     '''
@@ -882,25 +882,44 @@ def autoFindFastAdcPeaks(data, channel=5, threshold=35000, display=False, plot=F
     key = f'FastADC{channel}raw'
     if key not in data:
         raise ValueError(f'{key} not found in data set')
-    tid = data[key].where(data[key]>threshold, drop=True).trainId[0]
-    trace = data[key].sel(trainId=tid)
+    #average over the 100 first traces to get at least one train with signal
+    trace = data[key].isel(trainId=slice(0,100)).mean(dim='trainId').values
+    if plot:
+        trace_plot = np.copy(trace)
+    #subtract baseline and check if peaks are positive or negative
+    bl = np.median(trace)
+    trace_no_bl = trace - bl
+    if np.max(trace_no_bl) >= np.abs(np.min(trace_no_bl)):
+        posNeg = 'positive'
+    else:
+        posNeg = 'negative'
+        trace_no_bl *= -1
+        trace = bl + trace_no_bl
+    threshold = bl + np.max(trace_no_bl) / 2
+    #find peaks
     centers, peaks = find_peaks(trace, height=threshold, width=(4, None))
     c = centers[0]
     w = np.average(peaks['widths']).astype(int)
     period = np.median(np.diff(centers)).astype(int)
     npulses = centers.shape[0]
-    intstart = int(c - w/4) + 1
-    intstop = int(c + w/4) + 1
+    if window not in ['small', 'large']:
+        raise ValueError(f"'window argument should be either 'small' or 'large', not {window}")
+    if window=='small':
+        intstart = int(c - w/4) + 1
+        intstop = int(c + w/4) + 1
+    if window=='large':
+        intstart = int(peaks['left_ips'][0])
+        intstop = int(peaks['right_ips'][0]) + w
     bkgstop = int(peaks['left_ips'][0])-5
     bkgstart = bkgstop - 10
     if display:
-        print(f'Found {npulses} pulses, avg. width={w}, period={period} samples, ' +
+        print(f'Found {npulses} {posNeg} pulses, avg. width={w}, period={period} samples, ' +
               f'rep. rate={1e6/(9.230769*period):.3f} kHz')
     fAdcPeaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop,
                          bkgstart=bkgstart, bkgstop=bkgstop, period=period, npulses=npulses)
     if plot:
         plt.figure()
-        plt.plot(trace, 'o-', ms=3)
+        plt.plot(trace_plot, 'o-', ms=3)
         for i in range(npulses):
             plt.axvline(intstart+i*period, ls='--', color='g')
             plt.axvline(intstop+i*period, ls='--', color='r')
-- 
GitLab