diff --git a/xgm.py b/xgm.py
index 97d8518369b4e5ca64fb7aa5e8acd2d172f06417..b17fbf4ff6ad7bc16ed7c61b66986f6f9629a3a1 100644
--- a/xgm.py
+++ b/xgm.py
@@ -141,17 +141,15 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM'):
     #get a list of indices where a change occured:
     idx_change_sa3 = np.argwhere(np.isin(npulses_sa3.trainId.values,
                                      diff.trainId.values, assume_unique=True))[:,0]
-    #add index 0 to get the initial pulse number per train:
-    idx_change_sa3 = np.insert(idx_change_sa3, 0, 0)
 
     #Same for SASE 1:
     diff = npulses_sa1.diff(dim='trainId')
     diff = diff.where(diff !=0, drop=True)
     idx_change_sa1 = np.argwhere(np.isin(npulses_sa1.trainId.values,
                                      diff.trainId.values, assume_unique=True))[:,0]
-    idx_change_sa1 = np.insert(idx_change_sa1, 0, 0)
 
     #create index that locates all changes of pulse number in both SASE1 and 3:
+    #add index 0 to get the initial pulse number per train:
     idx_change = np.unique(np.concatenate(([0], idx_change_sa3, idx_change_sa1))).astype(int)
     if sase=='sase1':
         npulses = npulses_sa1
@@ -229,24 +227,21 @@ def filterOnTrains(data, key='sase3'):
             filtered xarray Dataset
     '''
     key = 'npulses_' + key
-    res = {}
-    for d in data:
-        res[d] = data[d].where(data[key]>0, drop=True)
-    return xr.Dataset(res)
-
+    res = data.where(data[key]>0, drop=True)
+    return res
 
 def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, t_offset=1760, mcp=1, npulses=None):
     ''' Computes peak integration from raw MCP traces.
     
         Inputs:
             data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
-            npulses: number of pulses
             intstart: trace index of integration start
             intstop: trace index of integration stop
             bkgstart: trace index of background start
             bkgstop: trace index of background stop
             t_offset: index separation between two pulses
             mcp: MCP channel number
+            npulses: number of pulses
             
         Output:
             results: DataArray with dims trainId x max(sase3 pulses)*1MHz/intra-train rep.rate 
@@ -257,8 +252,8 @@ def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, t_offset=1760, mcp=1, n
         raise ValueError("Source not found: {}!".format(keyraw))
     if npulses is None:
         npulses = int((data['sase3'].max().values + 1)/4)
-        sa3 = data['sase3'].where(data['sase3']>1)/4
-        sa3 -= sa3[:,0]
+    sa3 = data['sase3'].where(data['sase3']>1)/4
+    sa3 -= sa3[:,0]
     results = xr.DataArray(np.empty((sa3.shape[0], npulses)), coords=sa3.coords,
                            dims=['trainId', 'MCP{}fromRaw'.format(mcp)])
     for i in range(npulses):
@@ -328,3 +323,59 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
             tim = xr.concat([tim, temp], dim='trainId')
     return tim
     
+def calibrateTIM(data, rollingWindow=200, mcp=1, use_apd=True, intstart=None, intstop=None,
+              bkgstart=None, bkgstop=None, t_offset=1760, npulses_apd=None):
+    ''' Calibrate TIM signal (Peak-integrated signal) to the slow ion signal of SCS_XGM
+        (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
+        The aim is to find F so that E_tim_peak[uJ] = F x TIM_peak. For this, we want to
+        match the SASE3-only average TIM pulse peak per train (TIM_avg) to the slow XGM
+        signal E_slow.
+        Since E_slow is the average energy per pulse over all SASE1 and SASE3
+        pulses (N1 and N3), we first extract the relative contribution C of the SASE3 pulses
+        by looking at the pulse-resolved signals of the SA3_XGM in the tunnel.
+        There, the signal of SASE1 is usually strong enough to be above noise level.
+        Let TIM_avg be the average of the TIM pulses (SASE3 only).
+        The calibration factor is then defined as: F = E_slow * C * (N1+N3) / ( N3 * TIM_avg ).
+        If N3 changes during the run, we locate the indices for which N3 is maximum and define
+        a window where to apply calibration (indices start/stop).
+
+        Warning: the calibration does not include the transmission by the KB mirrors!
+
+        Inputs:
+            data: xarray Dataset
+            rollingWindow: number of trains to perform a running average on to match
+                           TIM_avg and E_slow
+            mcp: MCP channel number
+            use_apd: boolean. If False, the TIM pulse peaks are extracted from raw traces
+                     using getTIMapd
+            intstart: trace index of integration start
+            intstop: trace index of integration stop
+            bkgstart: trace index of background start
+            bkgstop: trace index of background stop
+            t_offset: index separation between two pulses
+            npulses_apd: number of pulses
+
+        Output:
+            F: float, TIM calibration factor (median of the per-train ratio over start:stop).
+
+    '''
+    start = 0
+    stop = None
+    npulses = data['npulses_sase3']
+    ntrains = npulses.shape[0]
+    if not np.all(npulses == npulses[0]):
+        start = np.argmax(npulses.values)
+        # NOTE(review): the last-maximum index would be ntrains - 1 - argmax(reversed); confirm intent
+        stop = ntrains + np.argmax(npulses.values[::-1]) - 1
+        if stop - start < rollingWindow:
+            print('not enough consecutive data points with the largest number of pulses per train')
+        start += rollingWindow
+        stop = np.min((ntrains, stop+rollingWindow))
+    filteredTIM = getTIMapd(data, mcp, use_apd, intstart, intstop, bkgstart, bkgstop, t_offset, npulses_apd)
+    sa3contrib = calcContribSASE(data, 'sase3', 'SA3_XGM')
+    avgFast = filteredTIM.mean(axis=1).rolling(trainId=rollingWindow).mean()
+    ratio = ((data['npulses_sase3']+data['npulses_sase1']) *
+             data['SCS_XGM_SLOW'] * sa3contrib) / (avgFast*data['npulses_sase3'])
+    F = float(ratio[start:stop].median().values)
+    return F
+