From fcb27cf7c4915f7b02ed576d87d1f062832819b6 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Tue, 17 Mar 2020 20:30:23 +0100
Subject: [PATCH] Adds source argument in fastAdc to rename the peak dimension
 accordingly

---
 Load.py          |  6 +++---
 bunch_pattern.py | 21 +++++++++++++++++----
 xgm.py           | 49 +++++++++++++++++++++++++++++++-----------------
 3 files changed, 52 insertions(+), 24 deletions(-)

diff --git a/Load.py b/Load.py
index 7db62c2..70b0c78 100644
--- a/Load.py
+++ b/Load.py
@@ -40,13 +40,13 @@ mnemonics = {
                       'dim':None},
 
     #Bunch Arrival Monitors
-    "BAM6": {'source':'SCS_ILH_LAS/DOOCS/BAM_414_B2:output',
+    "BAM5": {'source':'SCS_ILH_LAS/DOOCS/BAM_414_B2:output',
                       'key':'data.lowChargeArrivalTime',
                       'dim':['BAMbunchId']},
-    "BAM7": {'source':'SCS_ILH_LAS/DOOCS/BAM_1932M_TL:output',
+    "BAM6": {'source':'SCS_ILH_LAS/DOOCS/BAM_1932M_TL:output',
                       'key':'data.lowChargeArrivalTime',
                       'dim':['BAMbunchId']},
-    "BAM8": {'source':'SCS_ILH_LAS/DOOCS/BAM_1932S_TL:output',
+    "BAM7": {'source':'SCS_ILH_LAS/DOOCS/BAM_1932S_TL:output',
                       'key':'data.lowChargeArrivalTime',
                       'dim':['BAMbunchId']},
     
diff --git a/bunch_pattern.py b/bunch_pattern.py
index 42f34a6..1617393 100644
--- a/bunch_pattern.py
+++ b/bunch_pattern.py
@@ -181,22 +181,31 @@ def repRate(data, sase='sase3'):
     f = 1/((sase[0,1] - sase[0,0])*12e-3/54.1666667)
     return f
 
-def sortBAMdata(data, key='sase3'):
+def sortBAMdata(data, key='scs_ppl', sa3Offset=0):
     ''' Extracts beam arrival monitor data from the raw arrays 'BAM6', 'BAM7', etc...
         according to the bunchPatternTable. The BAM arrays contain 7220 values, which
         corresponds to FLASH busrt length of 800 us @ 9 MHz. The bunchPatternTable
         only has 2700 values, corresponding to XFEL 600 us burst length @ 4.5 MHz.
-        Hence, we truncate the BAM arrays to 5400 with a stride of 2 and match them
+        Hence, the BAM arrays are truncated to 5400 with a stride of 2 and matched
         to the bunchPatternTable. If key is one of the sase, the given dimension name
         of the bam arrays is 'sa[sase number]_pId', to match other data (XGM, TIM...).
+        If key is 'scs_ppl', the dimension is named 'ol_pId'
         Inputs:
             data: xarray Dataset containing BAM arrays
             key: str, ['sase1', 'sase2', 'sase3', 'scs_ppl']
-            
+            sa3Offset: int, used if key=='scs_ppl'. Offset in number of pulse_id 
+                between the first OL and FEL pulses. An offset of 40 means that 
+                the first laser pulse comes 40 pulse_id later than the FEL on a 
+                grid of 4.5 MHz. Negative values shift the laser pulse before
+                the FEL one.
         Output:
             ndata: xarray Dataset with same keys as input data (but new bam arrays)
     '''
     a, b, mask = extractBunchPattern(key=key, runDir=data.attrs['run'])
+    if key == 'scs_ppl':
+        a3, b3, mask3 = extractBunchPattern(key='sase3', runDir=data.attrs['run'])
+        firstSa3_pId = a3.where(b3>0, drop=True)[0,0].values.astype(int)
+        mask = mask.roll(pulse_slot=firstSa3_pId+sa3Offset)
     mask = mask.rename({'pulse_slot':'BAMbunchId'})
     ndata = data
     dropList = []
@@ -208,7 +217,11 @@ def sortBAMdata(data, key='sase3'):
             bam = bam.where(mask, drop=True)
             if 'sase' in key:
                 name = f'sa{key[4]}_pId'
-                bam = bam.rename({'BAMbunchId':name})
+            elif key=='scs_ppl':
+                name = 'ol_pId'
+            else:
+                name = 'bam_pId'
+            bam = bam.rename({'BAMbunchId':name})
             mergeList.append(bam)
     mergeList.append(data.drop(dropList))
     ndata = xr.merge(mergeList, join='inner')
diff --git a/xgm.py b/xgm.py
index 7289fc7..192508a 100644
--- a/xgm.py
+++ b/xgm.py
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr
 from scipy.signal import find_peaks
+import ToolBox as tb
 
 # XGM
 def cleanXGMdata(data, npulses=None, sase3First=True):
@@ -815,7 +816,8 @@ def matchXgmTimPulseId(data, use_apd=True, intstart=None, intstop=None,
 
 # Fast ADC
 def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, 
-                 period=None, npulses=None, usePeakValue=False, peakType='pos'):
+                 period=None, npulses=None, source='scs_ppl', 
+                 usePeakValue=False, peakType='pos'):
     ''' Computes peak integration from raw FastADC traces.
     
         Inputs:
@@ -832,8 +834,13 @@ def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop,
                 two bunches @ 4.5 MHz. 
             npulses: number of pulses. If None, takes the maximum number of
                 pulses according to the bunch patter (field 'npulses_sase3')
+            source: str, nature of the pulses: 'sase[1,2 or 3]', or 'scs_ppl',
+                used to give name to the peak Id dimension.
             usePeakValue: bool, if True takes the peak value of the signal, 
                           otherwise integrates over integration region.
+            peakType: str, 'pos' or 'neg'. Used if usePeakValue is True to
+                indicate if min or max value should be extracted.
+                          
             
         Output:
             results: DataArray with dims trainId x max(sase3 pulses) 
@@ -842,20 +849,28 @@ def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop,
     keyraw = 'FastADC{}raw'.format(channel)
     if keyraw not in data:
         raise ValueError("Source not found: {}!".format(keyraw))
-    if npulses is None:
-        npulses = int(data['npulses_sase3'].max().values)
-    if period is None:
-        sa3 = data['sase3'].where(data['sase3']>1)
-        if npulses > 1:
-            #Calculate the number of pulses between two lasing pulses (step)
-            step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
-            step = int(step[1] - step[0])
-            #multiply by elementary pulse length (221.5 ns / 9.23 ns = 24 samples)
-            period = 24 * step
-        else:
-            period = 1
+    if npulses is None or period is None:
+        indices, npulses_bp, mask = tb.extractBunchPattern(runDir=data.attrs['run'], 
+                                                           key=source)
+        if npulses is None:
+            npulses = int(npulses_bp.max().values)
+        if period is None:
+            indices = indices_bp.where(indices_bp>1)
+            if npulses > 1:
+                #Calculate the number of pulses between two lasing pulses (step)
+                step = indices.where(npulses_bp>1, drop=True)[0,:2].values
+                step = int(step[1] - step[0])
+                #multiply by elementary pulse length (221.5 ns / 9.23 ns = 24 samples)
+                period = 24 * step
+            else:
+                period = 1
+    pulseId = 'peakId'
+    if source=='scs_ppl':
+        pulseId = 'ol_pId'
+    if 'sase' in source:
+        pulseId = f'sa{source[4]}_pId'
     results = xr.DataArray(np.empty((data.trainId.shape[0], npulses)), coords=data[keyraw].coords,
-                           dims=['trainId', 'peakId'.format(channel)])
+                           dims=['trainId', pulseId])
     for i in range(npulses):
         a = intstart + period*i
         b = intstop + period*i
@@ -872,8 +887,8 @@ def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop,
         results[:,i] = val
     return results
 
-def autoFindFastAdcPeaks(data, channel=5, window='small', usePeakValue=False, 
-                         display=False, plot=False):
+def autoFindFastAdcPeaks(data, channel=5, window='large', usePeakValue=False, 
+                         source='scs_ppl', display=False, plot=False):
     ''' Automatically finds peaks in channel of Fast ADC trace, a minimum width of 4 
         samples. The find_peaks function and determination of the peak integration 
         region and baseline subtraction is optimized for typical photodiode signals
@@ -929,7 +944,7 @@ def autoFindFastAdcPeaks(data, channel=5, window='small', usePeakValue=False,
               f'rep. rate={1e6/(9.230769*period):.3f} kHz')
     fAdcPeaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop,
                          bkgstart=bkgstart, bkgstop=bkgstop, period=period, npulses=npulses,
-                         usePeakValue=usePeakValue, peakType=posNeg[:3])
+                         source=source, usePeakValue=usePeakValue, peakType=posNeg[:3])
     if plot:
         plt.figure()
         plt.plot(trace_plot, 'o-', ms=3)
-- 
GitLab