diff --git a/xgm.py b/xgm.py
index 3f4556ee4b28d4038295be97cab649b6ee999e0c..11294d641519de9f8d436d1e26f1c30ab4a372c3 100644
--- a/xgm.py
+++ b/xgm.py
@@ -667,38 +667,42 @@ def checkTimApdWindow(data, mcp=1, use_apd=True, intstart=None, intstop=None):
         Output:
             Plot    
     '''
+    mcpToChannel={1:'D', 2:'C', 3:'B', 4:'A'}
+    apdChannels={1:3, 2:2, 3:1, 4:0}
     npulses_max = data['npulses_sase3'].max().values
     tid = data['npulses_sase3'].where(data['npulses_sase3'] == npulses_max,
                                       drop=True)[0].trainId.values
     if 'MCP{}raw'.format(mcp) not in data:
         tid, data_from_train = data.attrs['run'].train_from_id(tid)
-        trace = data_from_train['SCS_UTC1_ADQ/ADC/1:network']['digitizers.channel_1_D.raw.samples']
-        print('no raw data for MCP{}. Loading trace from MCP1'.format(mcp))
-        label_trace='MCP1 Voltage [V]'
+        trace = data_from_train['SCS_UTC1_ADQ/ADC/1:network']['digitizers.channel_1_'
+                                +'{}.raw.samples'.format(mcpToChannel[mcp])]
+        print('no raw data for MCP{}. Loading trace from MCP{}'.format(mcp, mcp))
     else:
-        idx = np.argwhere(data['MCP{}raw'.format(mcp)].trainId.values == tid)[0]
-        trace = data['MCP{}raw'.format(mcp)][idx].T
-        label_trace='MCP{} Voltage [V]'.format(mcp)
+        trace = data['MCP{}raw'.format(mcp)].sel(trainId=tid).T
     if use_apd:
         pulseStart = data.attrs['run'].get_array(
-            'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStart.value')[0].values
+            'SCS_UTC1_ADQ/ADC/1', 
+            'board1.apd.channel_{}.pulseStart.value'.format(apdChannels[mcp]))[0].values
         pulseStop = data.attrs['run'].get_array(
-            'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStop.value')[0].values
+            'SCS_UTC1_ADQ/ADC/1', 
+            'board1.apd.channel_{}.pulseStop.value'.format(apdChannels[mcp]))[0].values
         initialDelay = data.attrs['run'].get_array(
-                        'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.initialDelay.value')[0].values
+            'SCS_UTC1_ADQ/ADC/1', 
+            'board1.apd.channel_{}.initialDelay.value'.format(apdChannels[mcp]))[0].values
         upperLimit = data.attrs['run'].get_array(
-                        'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.upperLimit.value')[0].values
-        nsamples = upperLimit - initialDelay
+            'SCS_UTC1_ADQ/ADC/1', 
+            'board1.apd.channel_{}.upperLimit.value'.format(apdChannels[mcp]))[0].values
     else:
         pulseStart = intstart
         pulseStop = intstop
-        if npulses_max > 1:
-            sa3 = data['sase3'].where(data['sase3']>1)
-            step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
-            step = int(step[1] - step[0])
-            nsamples = 440 * step
-        else:
-            nsamples = 0
+    if npulses_max > 1:
+        sa3 = data['sase3'].where(data['sase3']>1)
+        step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
+        step = int(step[1] - step[0])
+        nsamples = 440 * step
+        print(nsamples)
+    else:
+        nsamples = 0
 
     fig, ax = plt.subplots(figsize=(5,3))
     ax.plot(trace[:pulseStop+25], color='C1', label='first pulse')
@@ -707,7 +711,7 @@ def checkTimApdWindow(data, mcp=1, use_apd=True, intstart=None, intstop=None):
     ax.axvline(pulseStop, color='gray', ls='--')
     ax.set_xlim(pulseStart-25, pulseStop+25)
     ax.locator_params(axis='x', nbins=4)
-    ax.set_ylabel(label_trace)
+    ax.set_ylabel('MCP{} Voltage [V]'.format(mcp))
     ax.set_xlabel('First pulse sample #')
     if npulses_max > 1:
         pulseStart = pulseStart + nsamples*(npulses_max-1)