From 83ae436764d761dfc2107270ac832c4b914c3c2b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Thu, 24 Oct 2019 18:16:29 +0200
Subject: [PATCH] Binning by intra-train pulse id
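
Adds a do_pulse_mean flag to DSSC.binning(). With do_pulse_mean=True
(the default) the pumped and unpumped pulses of each train are averaged
and the result is binned by the scan variable, as before. With
do_pulse_mean=False the pulse dimension is kept and the data is binned
by the intra-train pulse id, which is then exposed as the scan variable
of the result. The XGM filter mask is now also kept on the instance
(self.filter_mask) so that successive filtering calls combine.

A minimal usage sketch (the DSSC instance is assumed to be configured
for a proposal, run and scan as in the existing workflow):

    dssc = DSSC(...)                   # set up as usual
    dssc.binning(do_pulse_mean=False)  # bin by intra-train pulse id
    dssc.save()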

---
 DSSC.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 71 insertions(+), 19 deletions(-)

diff --git a/DSSC.py b/DSSC.py
index 8f05bb7..7108427 100644
--- a/DSSC.py
+++ b/DSSC.py
@@ -44,6 +44,7 @@ class DSSC:
         self.geom = None
         self.mask = None
         self.max_fraction_memory = 0.8
+        self.filter_mask = None  # cumulative per-pulse XGM filter mask
         
         print('DSSC configuration')
         print(f'Topic: {self.topic}')
@@ -72,6 +73,7 @@ class DSSC:
         print('Opening run data with karabo-data')
         self.run_nr = run_nr
         self.xgm = None
+        self.filter_mask = None
         self.scan_vname = None
         
         self.run = kd.open_run(self.proposal, self.run_nr)
@@ -214,7 +216,14 @@ class DSSC:
         self.xgm_low = xgm_low
         self.xgm_high = xgm_high
         
-        valid = ((self.xgm > self.xgm_low) * (self.xgm < self.xgm_high)).prod('dim_0').astype(bool)
+        filter_mask = (self.xgm > self.xgm_low) * (self.xgm < self.xgm_high)
+        # combine with the mask from any previous filtering of this run
+        if self.filter_mask is not None:
+            self.filter_mask = self.filter_mask * filter_mask
+        else:
+            self.filter_mask = filter_mask
+
+        valid = filter_mask.prod('dim_0').astype(bool)
         xgm_valid = self.xgm.where(valid)
         xgm_valid = xgm_valid.dropna('trainId')
         self.scan = self.scan.sel({'trainId': xgm_valid.trainId})
@@ -271,7 +280,10 @@
             vds_list.append([vds_filename, module_vds])
         return vds_list
     
-    def binning(self):
-        """ Bin the DSSC data by the predifined scan type (DSSC.define()) using multiprocessing
-        
-        """
+    def binning(self, do_pulse_mean=True):
+        """ Bin the DSSC data by the predefined scan type (DSSC.define()) using multiprocessing
+
+            do_pulse_mean: if True (default), average the pumped and unpumped pulses
+                of each train and bin the result by the scan variable; if False, keep
+                the pulse dimension and bin by the intra-train pulse id instead.
+        """
@@ -303,6 +312,7 @@ class DSSC:
             vdf_scan=self.vds_scan,
             nbunches=self.nbunches,
             run_nr=self.run_nr,
+            do_pulse_mean=do_pulse_mean,
             ))
             
         timestamp = strftime('%X')
@@ -318,11 +328,15 @@ class DSSC:
         self.module_data['run'] = self.run_nr
         self.module_data = self.module_data.transpose('scan_variable', 'module', 'x', 'y')
                    
-        self.module_data = xr.merge([self.module_data, self.scan.groupby('scan_variable').mean('trainId')])
+        if do_pulse_mean:
+            self.module_data = xr.merge([self.module_data, self.scan.groupby('scan_variable').mean('trainId')])
         self.module_data = self.module_data.squeeze()
         
-        self.module_data.attrs['scan_variable'] = self.scan_vname
-
+        if do_pulse_mean:
+            self.module_data.attrs['scan_variable'] = self.scan_vname
+        else:
+            self.module_data.attrs['scan_variable'] = 'pulse id'
+
     def save(self, save_folder=None, overwrite=False):
         """ Save the crunched data.
         
@@ -561,6 +575,7 @@ def process_one_module(job):
     scan_vdf = job['vdf_scan']
     chunksize = job['chunksize']
     nbunches = job['nbunches']
+    do_pulse_mean = job['do_pulse_mean']
 
     image_path = f'INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data'
     npulse_path = f'INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count'
@@ -574,14 +589,27 @@
     scan.name = 'scan'
     len_scan = len(scan.groupby(scan))
 
-    # create empty dataset to add actual data to
-    module_data = xr.DataArray(np.empty([len_scan, 128, 512], dtype=np.float64),
-                               dims=['scan_variable', 'x', 'y'],
-                               coords={'scan_variable': np.unique(scan)})
-    module_data = module_data.to_dataset(name='pumped')
-    module_data['unpumped'] = xr.full_like(module_data['pumped'], 0)
-    module_data['sum_count'] = xr.DataArray(np.zeros_like(np.unique(scan)), dims=['scan_variable'])
-    module_data['module'] = module
+    if do_pulse_mean:
+        # create empty dataset to add actual data to
+        module_data = xr.DataArray(np.empty([len_scan, 128, 512], dtype=np.float64),
+                                   dims=['scan_variable', 'x', 'y'],
+                                   coords={'scan_variable': np.unique(scan)})
+        module_data = module_data.to_dataset(name='pumped')
+        module_data['unpumped'] = xr.full_like(module_data['pumped'], 0)
+        module_data['sum_count'] = xr.DataArray(np.zeros_like(np.unique(scan)), dims=['scan_variable'])
+        module_data['module'] = module
+    else:
+        # bin by intra-train pulse id: all trains fall into a single dummy bin
+        scan = xr.full_like(scan, 1)
+        len_scan = len(scan.groupby(scan))
+        module_data = xr.DataArray(np.empty([len_scan, int(nbunches/2), 128, 512], dtype=np.float64),
+                                   dims=['scan_variable', 'pulse', 'x', 'y'],
+                                   coords={'scan_variable': np.unique(scan)})
+        module_data = module_data.to_dataset(name='pumped')
+        module_data['unpumped'] = xr.full_like(module_data['pumped'], 0)
+        module_data['sum_count'] = xr.full_like(module_data['pumped'][..., 0, 0], 0)
+        module_data['module'] = module
 
     # crunching
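+    # read the virtual dataset chunk-wise, split pumped/unpumped and accumulate sums per bin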
     with h5py.File(data_vdf, 'r') as m:
@@ -608,19 +634,40 @@
             coords = {'trainId': trainIds_chunk}
             data = np.reshape(data, [n_trains_actual, fpt, 128, 512])[:, :int(2 * nbunches)]
             data = xr.DataArray(data, dims=['trainId', 'pulse', 'x', 'y'], coords=coords)
-            data_pumped = (data[:, ::4]).mean('pulse')
-            data_unpumped = (data[:, 2::4]).mean('pulse')
+
+            if do_pulse_mean:
+                # average the pumped (::4) and unpumped (2::4) pulses of each train
+                data_pumped = data[:, ::4].mean('pulse')
+                data_unpumped = data[:, 2::4].mean('pulse')
+            else:
+                # keep the pulse dimension so the data can be binned by pulse id
+                data_pumped = data[:, ::4]
+                data_unpumped = data[:, 2::4]
+
             data = data_pumped.to_dataset(name='pumped')
             data['unpumped'] = data_unpumped
-            data['sum_count'] = xr.DataArray(np.ones(n_trains_actual), dims=['trainId'], coords=coords)
+            data['sum_count'] = xr.full_like(data['unpumped'][..., 0, 0], fill_value=1)
+
             # grouping and summing
             data['scan_variable'] = scan  # this only adds scan data for matching trainIds
             data = data.dropna('trainId')
             data = data.groupby('scan_variable').sum('trainId')
             where = {'scan_variable': data.scan_variable}
+
             for var in ['pumped', 'unpumped', 'sum_count']:
                 module_data[var].loc[where] = module_data[var].loc[where] + data[var]
+
+            if not do_pulse_mean:
+                # when binning by pulse id, only the first chunk of trains is used
+                break
+
     for var in ['pumped', 'unpumped']:
         module_data[var] = module_data[var] / module_data.sum_count
     #module_data = module_data.drop('sum_count')
+
+    if not do_pulse_mean:
+        # collapse the singleton dummy bin and promote the intra-train
+        # pulse id to the scan variable of the result
+        module_data = module_data.sum('scan_variable')
+        module_data = module_data.rename({'pulse': 'scan_variable'})
     return module_data
-- 
GitLab