diff --git a/xgm.py b/xgm.py
index b4a3c6473f600a0fa6085712caea530991aa50ed..ddb9eec7fb7d0759f7603acaa1b28b31d5a655aa 100644
--- a/xgm.py
+++ b/xgm.py
@@ -195,7 +195,7 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM'):
     return result
 
 
-def calcContribSASE(data, sase='sase1', xgm='SA3_XGM'):
+def saseContribution(data, sase='sase1', xgm='SA3_XGM'):
     ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses 
         for each train in the run. Supports fresh bunch, dedicated trains
         and pulse on demand modes.
@@ -211,13 +211,11 @@ def calcContribSASE(data, sase='sase1', xgm='SA3_XGM'):
     '''
     xgm_sa1 = selectSASEinXGM(data, 'sase1', xgm=xgm)
     xgm_sa3 = selectSASEinXGM(data, 'sase3', xgm=xgm)
-    if np.all(xgm_sa1.trainId.isin(xgm_sa3.trainId).values) == False:
-        print('Dedicated mode')
-        r = xr.align(*[xgm_sa1, xgm_sa3], join='outer', exclude=['SA3_XGM_dim', 'SA1_XGM_dim'])
-        xgm_sa1 = r[0]
-        xgm_sa1.fillna(0)
-        xgm_sa3 = r[1]
-        xgm_sa3.fillna(0)
+    #Fill missing train ids with 0
+    r = xr.align(*[xgm_sa1, xgm_sa3], join='outer', exclude=['XGMbunchId'])
+    xgm_sa1 = r[0].fillna(0)
+    xgm_sa3 = r[1].fillna(0)
+
     contrib = xgm_sa1.sum(axis=1)/(xgm_sa1.sum(axis=1) + xgm_sa3.sum(axis=1))
     if sase=='sase1':
         return contrib
@@ -300,7 +298,7 @@ def calibrateXGMs(data, rollingWindow=200, plot=False):
         print('calibration factor SA3 XGM: %f'%sa3_calib_factor)
 
     # Calibrate SCS XGM with SASE3-only contribution
-    sa3contrib = calcContribSASE(data, 'sase3', 'SA3_XGM')
+    sa3contrib = saseContribution(data, 'sase3', 'SA3_XGM')
     if not noSCS:
         scs_sase3_fast = selectSASEinXGM(data, 'sase3', 'SCS_XGM').mean(axis=1)
         meanFast = scs_sase3_fast.rolling(trainId=rollingWindow).mean()
@@ -438,8 +436,8 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
     return tim
 
 
-def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intstart=None, intstop=None,
-              bkgstart=None, bkgstop=None, t_offset=1760, npulses_apd=None):
+def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intstart=None,
+                 intstop=None, bkgstart=None, bkgstop=None, t_offset=1760, npulses_apd=None):
     ''' Calibrate TIM signal (Peak-integrated signal) to the slow ion signal of SCS_XGM
         (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
         The aim is to find F so that E_tim_peak[uJ] = F x TIM_peak. For this, we want to
@@ -486,9 +484,8 @@ def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intst
             print('not enough consecutive data points with the largest number of pulses per train')
         start += rollingWindow
         stop = np.min((ntrains, stop+rollingWindow))
-        #print(start, stop)
     filteredTIM = getTIMapd(data, mcp, use_apd, intstart, intstop, bkgstart, bkgstop, t_offset, npulses_apd)
-    sa3contrib = calcContribSASE(data, 'sase3', 'SA3_XGM')
+    sa3contrib = saseContribution(data, 'sase3', 'SA3_XGM')
     avgFast = filteredTIM.mean(axis=1).rolling(trainId=rollingWindow).mean()
     ratio = ((data['npulses_sase3']+data['npulses_sase1']) *
              data['SCS_XGM_SLOW'] * sa3contrib) / (avgFast*data['npulses_sase3'])
@@ -505,9 +502,7 @@ def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intst
         ax.plot(avgFast*F, label='Calibrated TIM rolling avg', color='C2')
         ax.legend(loc='upper left', fontsize=8)
         ax.set_ylabel('Energy [$\mu$J]', size=10)
-        #ax2=ax#.twinx()
         ax.plot(filteredTIM.mean(axis=1)*F, label='Calibrated TIM train avg', alpha=0.2, color='C9')
-        #ax2.set_ylabel('Calibrated TIM (MCP{}) [uJ]'.format(mcp))
         ax.legend(loc='best', fontsize=8, ncol=2)
         plt.xlabel('train in run')
         
@@ -530,10 +525,6 @@ def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intst
         ax.hist(filteredTIM.values.flatten()*F, bins=50, rwidth=0.8)
         ax.set_ylabel('number of pulses', size=10)
         ax.set_xlabel('Pulse energy MCP{} [uJ]'.format(mcp), size=10)
-        #ax2 = ax.twiny()
-        #ax2.set_xlabel('MCP 1 APD')
-        #toRaw = lambda x: x/F
-        #ax2.set_xlim((toRaw(ax.get_xlim()[0]),toRaw(ax.get_xlim()[1])))
         ax.set_yscale('log')
         
         ax = plt.subplot(236)