From 1f092c281115d8b363b52b376a876178e0dc6f3a Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Mon, 18 Nov 2019 21:12:56 +0100
Subject: [PATCH] Adds function to automatically find ADC peaks

---
 xgm.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/xgm.py b/xgm.py
index 59f2dac..dbc47a4 100644
--- a/xgm.py
+++ b/xgm.py
@@ -8,6 +8,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr
+from scipy.signal import find_peaks
 
 # Machine
 def pulsePatternInfo(data, plot=False):
@@ -967,6 +968,51 @@ def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, period=Non
         results[:,i] = integ
     return results
 
+def autoFindFastAdcPeaks(data, channel=5, threshold=35000, display=False, plot=False):
+    ''' Automatically finds positive peaks in channel of Fast ADC trace, assuming
+        a minimum absolute height of 'threshold' counts and a minimum width of 4 
+        samples. The find_peaks function and determination of the peak region and 
+        baseline subtraction is optimized for typical photodiode signals of the 
+        SCS instrument (ILH, FFT reflectometer, FFT diag stage).
+        Inputs:
+            data: xarray Dataset containing Fast ADC traces
+            key: data key of the array of traces
+            threshold: minimum height of the peaks
+            display: bool, displays info on the pulses found
+            plot: plots regions of integration of the first pulse in the trace
+        Output:
+            peaks: DataArray of the integrated peaks 
+    '''
+    
+    key = f'FastADC{channel}raw'
+    if key not in data:
+        raise ValueError(f'{key} not found in data set')
+    trace = data[key].where(data['npulses_sase3']>0, drop=True).isel(trainId=0).values
+    centers, peaks = find_peaks(trace, height=threshold, width=(4, None))
+    c = centers[0]
+    w = np.average(peaks['widths']).astype(int)
+    period = np.median(np.diff(centers)).astype(int)
+    npulses = centers.shape[0]
+    intstart = int(c - w/4) + 1
+    intstop = int(c + w/4) + 1
+    bkgstop = int(peaks['left_ips'][0])-5
+    bkgstart = bkgstop - 10
+    if display:
+        print(f'Found {npulses} pulses, avg. width={w}, period={period} samples, ' +
+              f'rep. rate={1e6/(9.230769*period):.3f} kHz')
+    fAdcPeaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop,
+                         bkgstart=bkgstart, bkgstop=bkgstop, period=period, npulses=npulses)
+    if plot:
+        plt.figure()
+        plt.plot(trace, 'o-', ms=3)
+        for i in range(npulses):
+            plt.axvline(intstart+i*period, ls='--', color='g')
+            plt.axvline(intstop+i*period, ls='--', color='r')
+            plt.axvline(bkgstart+i*period, ls='--', color='lightgrey')
+            plt.axvline(bkgstop+i*period, ls='--', color='grey')
+        plt.title(f'Fast ADC {channel} trace')
+        plt.xlim(bkgstart-10, intstop + 50)
+    return fAdcPeaks
 
 def mergeFastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, 
                       period=None, npulses=None, dim='lasPulseId'):
-- 
GitLab