From dabb2b01766150bd8e06270b22507f2dd2698d17 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Mon, 28 Oct 2019 15:03:03 +0100
Subject: [PATCH] Simplify and generalize the DSSC single-module processing code

---
 DSSC1module.py | 374 ++++++++++++++++++-------------------------------
 1 file changed, 138 insertions(+), 236 deletions(-)

diff --git a/DSSC1module.py b/DSSC1module.py
index e0ca15f..81f6c91 100644
--- a/DSSC1module.py
+++ b/DSSC1module.py
@@ -12,6 +12,7 @@ from karabo_data.read_machinery import find_proposal
 import ToolBox as tb
 import matplotlib.pyplot as plt
 from mpl_toolkits.axes_grid1 import ImageGrid
+import matplotlib.patches as patches
 import numpy as np
 import xarray as xr
 import h5py
@@ -52,22 +53,22 @@ class DSSC1module:
         if not os.path.exists(self.save_folder):
             warnings.warn(f'Default save folder does not exist: {self.save_folder}')
             
-        self.dark_data = None
+        self.dark_data = 0
         self.max_fraction_memory = 0.8
         self.Nworker = 16
+        self.rois = None
         
     def __del__(self):
         # deleting temporary folder
         if self.tempdir:
             shutil.rmtree(self.tempdir)
     
-    def open_run(self, run_nr, isDark=False):
+    def open_run(self, run_nr, t0=0.0):
         """ Open a run with karabo-data and prepare the virtual dataset for multiprocessing
         
             inputs:
                 run_nr: the run number
-                isDark: a boolean to specify if the run is a dark run or not
-        
+                t0: optional delay line time zero (t0) position in mm
         """
         
         print('Opening run data with karabo-data')
@@ -75,7 +76,6 @@ class DSSC1module:
         self.xgm = None
         
         self.run = kd.open_run(self.proposal, self.run_nr)
-        self.isDark = isDark
         self.plot_title = f'{self.proposal} run: {self.run_nr}'
         
         self.fpt = self.run.detector_info('SCS_DET_DSSC1M-1/DET/0CH0:xtdf')['frames_per_train']
@@ -99,7 +99,23 @@ class DSSC1module:
 
         print('Creating virtual dataset')
         self.vdslist = self.create_virtual_dssc_datasets(self.run, path=self.tempdir)
-                
+        
+        
+        print(f'Loading XGM data')
+        self.xgm = self.run.get_array(tb.mnemonics['SCS_SA3']['source'],
+                                      tb.mnemonics['SCS_SA3']['key'],
+                                      roi=kd.by_index[:self.nbunches])
+        self.xgm = self.xgm.rename({'dim_0':'pulseId'})
+        self.xgm['pulseId'] = np.arange(0, 2*self.nbunches, 2)
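+        # SASE3 pulse k is matched to DSSC pulseId 2*k here, i.e. assuming the DSSC
+        # records two frames (image + intra-dark) per X-ray pulse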
+        
+        print(f'Loading mono nrj data')
+        self.nrj = self.run.get_array(tb.mnemonics['nrj']['source'],
+                                      tb.mnemonics['nrj']['key'])
+        print(f'Loading delay line data')
+        self.delay_mm = self.run.get_array(tb.mnemonics['PP800_DelayLine']['source'],
+                                      tb.mnemonics['PP800_DelayLine']['key'])
+        self.t0 = t0
+        self.delay_ps = tb.positionToDelay(self.delay_mm, origin=self.t0)
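+        # note: for a double-pass delay line, tb.positionToDelay presumably maps a stage
+        # position x (mm) to delay_ps ~ 2*(x - t0)*1e9/c ~ 6.67 ps per mm of travel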
                     
     def create_virtual_dssc_datasets(self, run, path=''):
         """ Create virtual datasets for each the DSSC module used for the multiprocessing.
@@ -120,55 +136,12 @@ class DSSC1module:
         vds_list.append([vds_filename, module_vds])
         
         return vds_list
-    
-    def process_dark(self):
-        
-        if not self.isDark:
-            raise ValueError('The run was not loaded as a dark')
-            
-        # get available memory in GB, we will try to use 80 % of it
-        max_GB = psutil.virtual_memory().available/1024**3
-        print(f'max available memory: {max_GB} GB')
-        
-        # max_GB / (8byte * Nworker * 128px * 512px * N_pulses)
-        self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt))
-        
-        print('processing', self.chunksize, 'trains per chunk')
-                   
-        jobs = []
-        for m in range(self.Nworker):
-            jobs.append(dict(
-            fpt=self.fpt,
-            module=self.module,
-            vdf_module=os.path.join(self.tempdir, f'dssc{self.module}_vds.h5'),
-            chunksize=self.chunksize,
-            nbunches=self.nbunches,
-            workerId=m,
-            Nworker=self.Nworker
-            ))
-            
-        timestamp = strftime('%X')
-        print(f'start time: {timestamp}')
-
-        #with multiprocessing.Pool(self.Nworker) as pool:
-        #    res = pool.map(sum_all_one_module, jobs)
-        
-        res = []
-        for w in range(self.Nworker):
-            res.append(sum_all_one_module(jobs[w]))
-    
-        print('finished:', strftime('%X'))
-    
-        # rearange the multiprocessed data
-        module_data = xr.concat(res, dim='worker').transpose('worker', 'pulse', 'x', 'y').sum(dim='worker')
-        self.dark_data = module_data['data']/module_data['counts']
-        self.dark_data['run'] = self.run_nr
-        
-        self.module_data = self.dark_data
-    
-    def process_run(self):
+       
+    def process(self, dark_pass=None):
         """ Process DSSC data from one module using multiprocessing
         
+            inputs:
+                dark_pass: None to process the data (dark subtraction and ROI extraction);
+                    'mean' to compute the dark mean; 'std' to compute the dark std
+        
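+            A typical two-pass dark workflow (a sketch; assumes the run was opened
+            with open_run first) is:
+                run.process(dark_pass='mean')   # 1st pass on a dark run: dark mean
+                run.process(dark_pass='std')    # 2nd pass on the same dark: noise map
+                run.save(isDark=True)
+                # later, on a data run:
+                run.load_dark(dark_runNB)
+                run.process()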
         """
 
         # get available memory in GB, we will try to use 80 % of it
@@ -179,6 +152,18 @@ class DSSC1module:
         self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt))
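+        # worked example (hypothetical numbers): with 400 GB free, max_fraction_memory=0.8,
+        # Nworker=16 and fpt=75, chunksize = 0.8*400*1024**3 // (8*16*128*512*75) ~ 546 trains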
                 
         print('processing', self.chunksize, 'trains per chunk')
+        
+        if dark_pass == 'mean':
+            rois = None
+            dark = 0
+        elif dark_pass == 'std':
+            dark = self.dark_data['dark_mean']
+            rois = None
+        elif dark_pass is None:
+            dark = self.dark_data['dark_mean']
+            rois = self.rois
+        else:
+            raise ValueError(f"dark_pass should be either None or 'mean' or 'std' but not {dark_pass}")
                    
         jobs = []
         for m in range(self.Nworker):
@@ -190,7 +175,8 @@ class DSSC1module:
             nbunches=self.nbunches,
             workerId=m,
             Nworker=self.Nworker,
-            dark_data=self.dark_data
+            dark_data=dark,
+            rois=rois
             ))
             
         timestamp = strftime('%X')
@@ -198,6 +184,7 @@ class DSSC1module:
 
         #with multiprocessing.Pool(self.Nworker) as pool:
         #    res = pool.map(sum_all_one_module, jobs)
         
         res = []
         for w in range(self.Nworker):
@@ -206,54 +193,83 @@ class DSSC1module:
         print('finished:', strftime('%X'))
     
         # rearange the multiprocessed data
-        module_data = xr.concat(res, dim='worker').transpose('worker', 'pulse', 'x', 'y').sum(dim='worker')
-        self.std_data = np.sqrt(module_data['std_data']/(module_data['counts'] - 1))
-        self.dark_corrected_data = module_data['dark_corrected_data']/module_data['counts']
-        self.std_data['run'] = self.run_nr
-        self.dark_corrected_data['run'] = self.run_nr
+        # collapse the worker dimension: accumulated images and counts add up across
+        # workers, while per-train ROI traces (disjoint trainIds) are simply merged
+        self.module_data = xr.concat(res, dim='worker').sum(dim='worker')
         
-        self.plot_title = f"{self.proposal} run: {self.std_data['run'].values} dark: {self.dark_data['run'].values}"
+        # reorder the dimensions
+        if 'trainId' in self.module_data.dims:
+            self.module_data = self.module_data.transpose('trainId', 'pulseId', 'x', 'y')
+        else:
+            self.module_data = self.module_data.transpose('pulseId', 'x', 'y')
+        
+        # finalize the statistics now that all chunks have been accumulated
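+        # std_data was accumulated as a sum over trains of (frame - dark)**2 with 'counts'
+        # trains, so sqrt(std_data/(counts - 1)) is the per-pixel standard deviation estimate,
+        # and dark_corrected_data/counts is the per-pixel mean dark-corrected signal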
+        self.module_data['std_data'] = np.sqrt(self.module_data['std_data']/(self.module_data['counts'] - 1))
+        self.module_data['dark_corrected_data'] = self.module_data['dark_corrected_data']/self.module_data['counts']
+             
+        self.module_data['run'] = self.run_nr
+      
+        if dark_pass == 'mean':
+            self.dark_data = self.module_data['dark_corrected_data'].to_dataset(name='dark_mean')
+            self.dark_data['run'] = self.run_nr
+        elif dark_pass == 'std':
+            self.dark_data['dark_std'] = self.module_data['std_data']
+            assert self.dark_data['run'] == self.run_nr, "noise map computed from different darks"
+        else:
+            self.module_data['xgm'] = self.xgm
+            self.module_data['nrj'] = self.nrj
+            self.module_data['delay_mm'] = self.delay_mm
+            self.module_data['delay_ps'] = self.delay_ps
+            self.module_data['t0'] = self.t0
+            
+            
+        self.plot_title = f"{self.proposal} run: {self.module_data['run'].values} dark: {self.dark_data['run'].values}"            
         
-    def comput_mask(self, low=0.01, high=0.8):
+    def compute_mask(self, low=0.01, high=0.8):
         """ Compute a DSSC module mask from the noise map of a dark run.
         """
         
-        if self.std_data is None:
-            raise ValueError('Cannot compute from from a missing noise map')
+        if 'dark_std' not in self.dark_data:
+            raise ValueError('Cannot compute a mask from a missing dark noise map')
         
         fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, figsize=[5, 4*2.5])
-        im = ax1.imshow(self.dark_data.mean('pulse'), vmin=0, vmax=90)
+        im = ax1.imshow(self.dark_data['dark_mean'].mean('pulseId'), vmin=0, vmax=90)
         fig.colorbar(im, ax=ax1)
         ax1.set_title('mean')
         fig.suptitle(self.plot_title)
         
-        im = ax2.imshow(self.std_data.mean('pulse'), vmin=0, vmax=2)
+        im = ax2.imshow(self.dark_data['dark_std'].mean('pulseId'), vmin=0, vmax=2)
         fig.colorbar(im, ax=ax2)
         ax2.set_title('std')
         
-        ax3.hist(self.std_data.values.flatten(), bins=200, range=[0, 2], density=True)
+        ax3.hist(self.dark_data['dark_std'].values.flatten(), bins=200, range=[0, 2], density=True)
+        ax3.axvline(low, ls='--', c='k')
+        ax3.axvline(high, ls='--', c='k')
         ax3.set_yscale('log')
         ax3.set_ylabel('density')
         ax3.set_xlabel('std values')
         
-        self.mask = 1 - (1.0*(self.std_data.min('pulse') > high) + (1.0*(self.std_data.max('pulse') < low)))
+        self.mask = 1 - (1.0*(self.dark_data['dark_std'].min('pulseId') > high)
+                         + (1.0*(self.dark_data['dark_std'].max('pulseId') < low)))
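+        # mask = 1 keeps a pixel; pixels whose noise never drops below `high` (hot/noisy)
+        # or never rises above `low` (dead) are set to 0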
         im = ax4.imshow(self.mask)
         fig.colorbar(im, ax=ax4)
 
-    def save(self, save_folder=None, overwrite=False):
+    def save(self, save_folder=None, overwrite=False, isDark=False):
         """ Save the crunched data.
         
             inputs:
                save_folder: string of the folder in which to save the data.
                 overwrite: boolean whether or not to overwrite existing files.
+                isDark: if True, save the dark data, otherwise the processed run data
         """
         if save_folder is None:
-            save_folder = this.save_folder
+            save_folder = self.save_folder
 
-        if self.isDark:
+        if isDark:
             fname = f'run{self.run_nr}_dark.h5'  # no scan
+            data = self.dark_data
         else:
             fname = f'run{self.run_nr}.h5'  # run with delay scan (change for other scan types!)
+            data = self.module_data
 
 
         save_path = os.path.join(save_folder, fname)
@@ -263,43 +279,38 @@ class DSSC1module:
             if file_exists:
                 warnings.warn(f'Overwriting file: {save_path}')
                 os.remove(save_path)
-            self.module_data.to_netcdf(save_path, group='data')
+            data.to_netcdf(save_path, group='data')
             os.chmod(save_path, 0o664)
             print('saving: ', save_path)
         else:
             print('file', save_path, 'exists and overwrite is False')
-                   
-    def load_binned(self, runNB, dark_runNB, xgm_norm = True, save_folder=None):
-        """ load previously binned (crunched) DSSC data by DSSC.crunch() and DSSC.save()
+            
+    def load_dark(self, dark_runNB, save_folder=None):
+        """ Load dark data.
         
             inputs:
-                runNB: run number to load
-                dark_runNB: run number of the corresponding dark
-                xgm_norm: normlize by XGM data if True
-                save_folder: path string  where the crunched data are saved
+                dark_runNB: run number of the dark run to load
+                save_folder: string of the folder where the dark data were saved.
         """
 
         if save_folder is None:
             save_folder = self.save_folder
 
-        self.plot_title = f'{self.proposal} run: {runNB} dark: {dark_runNB}'
-                   
-        dark = xr.open_dataset(os.path.join(save_folder, f'run{dark_runNB}_dark.h5'), group='data')
-        binned = xr.open_dataset(os.path.join(save_folder, f'run{runNB}.h5'), group='data')
-
-        binned['pumped'] = (binned['pumped'] - dark['pumped'].values)
-        binned['unpumped'] = (binned['unpumped'] - dark['unpumped'].values)
+        fname = f'run{dark_runNB}_dark.h5'  # no scan
+        self.dark_data = xr.open_dataset(os.path.join(save_folder, fname), group='data')
 
-        if xgm_norm:
-            binned['pumped'] = binned['pumped'] / binned['xgm_pumped']
-            binned['unpumped'] = binned['unpumped'] / binned['xgm_unpumped']
-        
-        self.scan_points = binned['scan_variable']
-        self.scan_points_counts = binned['sum_count'][:, 0]
-        self.scan_vname = binned.attrs['scan_variable']
-        self.scan = None
+    def show_rois(self):
+        """ Show the ROIs defined in self.rois on the mean dark-corrected module image.
+        """
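+        # self.rois is expected to map ROI names to pixel ranges, e.g. (hypothetical values):
+        # self.rois = {'probe': {'x': [0, 128], 'y': [100, 200]},
+        #              'reference': {'x': [0, 128], 'y': [300, 400]}}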
+        fig, ax1 = plt.subplots(nrows=1, figsize=[5, 2.5])
+        ax1.imshow(self.module_data['dark_corrected_data'].mean('pulseId') * self.mask)
+        for r,v in self.rois.items():
+            rect = patches.Rectangle((v['y'][0], v['x'][0]),
+                                     v['y'][1] - v['y'][0],
+                                     v['x'][1] - v['x'][0],
+                                     linewidth=1, edgecolor='r', facecolor='none')
 
-        self.binned = binned
+            ax1.add_patch(rect)
+            
+        fig.suptitle(self.plot_title)
+        
                    
     def plot_DSSC(self, use_mask = True, p_low = 1, p_high = 98, vmin = None, vmax = None):
         """ Plot pumped and unpumped DSSC images.
@@ -361,65 +372,6 @@ class DSSC1module:
  
 # since 'self' is not pickable, this function has to be outside the DSSC class so that it can be used
 # by the multiprocessing pool.map function
-def sum_all_one_module(job):
-    
-    chunksize = job['chunksize']
-    Nworker = job['Nworker']
-    workerId = job['workerId']
-    fpt = job['fpt']
-    module = job['module']
-    
-    image_path = f"INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data"
-    npulse_path = f"INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count"
-        
-    with h5py.File(job['vdf_module'], 'r') as m:
-        all_trainIds = m['INDEX/trainId'][()]
-    n_trains = len(all_trainIds)
-    
-    # chunks distributed between workers
-    worker_chunks = np.linspace(0, n_trains, Nworker+1, dtype=int)
-    
-    # this worker's chunck is evenly split into maximum chunksize by allowed distributed memory
-    n_chunks = np.ceil((worker_chunks[workerId+1] - worker_chunks[workerId])/chunksize)
-    if workerId == 0:
-        print(n_chunks)
-    chunks = np.linspace(worker_chunks[workerId], worker_chunks[workerId+1], n_chunks+1, dtype=int)
-    if workerId == 0:
-        print(chunks)
-    
-    # create empty dataset to add actual data to
-    module_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
-                               dims=['pulse', 'x', 'y'])
-    counts = 0
-    
-    # crunching
-    with h5py.File(job['vdf_module'], 'r') as m:
-
-        #chunk_start = np.arange(len(all_trainIds), step=job['chunksize'], dtype=int)
-        trains_start = 0
-                   
-        # This line is the strange hack from https://github.com/tqdm/tqdm/issues/485
-        print(' ', end='', flush=True)
-         
-        for k,v in enumerate(tqdm(chunks, desc=f"pool.map#{workerId:02d}")):
-            if k == chunks.shape[0] - 1:
-                continue
-                
-            chunk_dssc = np.s_[int(chunks[k] * fpt):int(chunks[k+1] * fpt)]  # for dssc data
-            data = m[image_path][chunk_dssc].squeeze()
-            data = data.astype(np.float64)
-            n_trains = int(data.shape[0] // fpt)
-            data = np.reshape(data, [n_trains, fpt, 128, 512])
-            
-            data = data.sum(axis=0)
-                       
-            module_data += data
-            counts += chunks[k+1] - chunks[k]
-    
-    module_data = module_data.to_dataset(name='data')
-    module_data['counts'] = counts
-    return module_data
-
 def process_one_module(job):
     
     chunksize = job['chunksize']
@@ -428,6 +380,7 @@ def process_one_module(job):
     dark_data = job['dark_data']
     fpt = job['fpt']
     module = job['module']
+    rois = job['rois']
     
     image_path = f"INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data"
     npulse_path = f"INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count"
@@ -438,6 +391,7 @@ def process_one_module(job):
     
     # chunks distributed between workers
     worker_chunks = np.linspace(0, n_trains, Nworker+1, dtype=int)
+    trainIds_chunk = all_trainIds[worker_chunks[workerId]:worker_chunks[workerId+1]]
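+    # e.g. with n_trains=1000 and Nworker=16, worker_chunks = [0, 62, 125, ..., 1000]
+    # and worker w processes the trains in [worker_chunks[w], worker_chunks[w+1])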
     
     # this worker's chunck is evenly split into maximum chunksize by allowed distributed memory
     n_chunks = np.ceil((worker_chunks[workerId+1] - worker_chunks[workerId])/chunksize)
@@ -448,11 +402,17 @@ def process_one_module(job):
         print(chunks)
     
     # create empty dataset to add actual data to
-    dark_corrected_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
-                               dims=['pulse', 'x', 'y'])
-    std_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
-                               dims=['pulse', 'x', 'y'])
-    counts = 0
+    module_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
+                               dims=['pulseId', 'x', 'y'],
+                               coords={'pulseId':np.arange(fpt)}).to_dataset(name='dark_corrected_data')
+    module_data['std_data'] = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
+                               dims=['pulseId', 'x', 'y'])
+
+    if rois is not None:
+        for k in rois.keys():
+            module_data[k] = xr.DataArray(np.empty([len(trainIds_chunk)], dtype=np.float64),
+                                   dims=['trainId'], coords = {'trainId': trainIds_chunk})
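+    # each ROI gets a per-train trace spanning this worker's whole train range;
+    # it is filled in chunk by chunk below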
+    module_data['counts'] = 0
     
     # crunching
     with h5py.File(job['vdf_module'], 'r') as m:
@@ -469,87 +429,29 @@ def process_one_module(job):
                 
             chunk_dssc = np.s_[int(chunks[k] * fpt):int(chunks[k+1] * fpt)]  # for dssc data
             data = m[image_path][chunk_dssc].squeeze()
+            
+            trains = m['INDEX/trainId'][np.s_[int(chunks[k]):int(chunks[k+1])]]
+            n_trains = len(trains)                   
+                    
             data = data.astype(np.float64)
-            n_trains = int(data.shape[0] // fpt)
             data = xr.DataArray(np.reshape(data, [n_trains, fpt, 128, 512]),
-                               dims=['trains', 'pulse', 'x', 'y'])
+                                dims=['trainId', 'pulseId', 'x', 'y'],
+                                coords={'trainId': trains})
             
             temp = data - dark_data
-            dark_corrected_data_ = temp.sum(axis=0)
-            std_data_ = (temp**2).sum(axis=0)
-                       
-            dark_corrected_data += dark_corrected_data_
-            std_data += std_data_
             
-            counts += chunks[k+1] - chunks[k]
-    
-    module_data = std_data.to_dataset(name='std_data')
-    module_data['dark_corrected_data'] = dark_corrected_data
-    module_data['counts'] = counts
-    return module_data
-
-def process_one_module_old(job):
-    module = job['module']
-    fpt = job['fpt']
-    data_vdf = job['vdf_module']
-    scan_vdf = job['vdf_scan']
-    chunksize = job['chunksize']
-    nbunches = job['nbunches']
-
-
-    # load scan variable
-    scan = xr.open_dataset(scan_vdf, group='data')['scan_variable']
-    scan.name = 'scan'
-    len_scan = len(scan.groupby(scan))
-
-    # create empty dataset to add actual data to
-    module_data = xr.DataArray(np.zeros([len_scan, 128, 512], dtype=np.float64),
-                               dims=['scan_variable', 'x', 'y'],
-                               coords={'scan_variable': np.unique(scan)})
-    module_data = module_data.to_dataset(name='pumped')
-    module_data['unpumped'] = xr.full_like(module_data['pumped'], 0)
-    module_data['sum_count'] = xr.DataArray(np.zeros_like(np.unique(scan)), dims=['scan_variable'])
-    module_data['module'] = module
-
-    # crunching
-    with h5py.File(data_vdf, 'r') as m:
-        #fpt_calc = int(len(m[image_path]) / n_trains)
-        #assert fpt_calc == fpt, f'data length does not match expected value (module {module})'
-        all_trainIds = m['INDEX/trainId'][()]
-        frames_per_train = m[npulse_path][()]
-        trains_with_data = all_trainIds[frames_per_train == fpt]
-        #print(np.unique(pulses_per_train), '/', fpt)
-        #print(len(trains_with_data))
-        chunk_start = np.arange(len(all_trainIds), step=chunksize, dtype=int)
-        trains_start = 0
-                   
-        # This line is the strange hack from https://github.com/tqdm/tqdm/issues/485
-        print(' ', end='', flush=True)
-        for c0 in tqdm(chunk_start, desc=f'pool.map#{module:02d}', position=module):
-            chunk_dssc = np.s_[int(c0 * fpt):int((c0 + chunksize) * fpt)]  # for dssc data
-            data = m[image_path][chunk_dssc].squeeze()
-            data = data.astype(np.float64)
-            n_trains = int(data.shape[0] // fpt)
-            trainIds_chunk = np.unique(trains_with_data[trains_start:trains_start + n_trains])
-            trains_start += n_trains
-            n_trains_actual = len(trainIds_chunk)
-            coords = {'trainId': trainIds_chunk}
-            data = np.reshape(data, [n_trains_actual, fpt, 128, 512])[:, :int(2 * nbunches)]
-            data = xr.DataArray(data, dims=['trainId', 'pulse', 'x', 'y'], coords=coords)
-            data_pumped = (data[:, ::4]).mean('pulse')
-            data_unpumped = (data[:, 2::4]).mean('pulse')
-            data = data_pumped.to_dataset(name='pumped')
-            data['unpumped'] = data_unpumped
-            data['sum_count'] = xr.DataArray(np.ones(n_trains_actual), dims=['trainId'], coords=coords)
-            # grouping and summing
-            data['scan_variable'] = scan  # this only adds scan data for matching trainIds
-            data = data.dropna('trainId')
-            data = data.groupby('scan_variable').sum('trainId')
-            where = {'scan_variable': data.scan_variable}
-            for var in ['pumped', 'unpumped', 'sum_count']:
-                module_data[var].loc[where] = module_data[var].loc[where] + data[var]
-    for var in ['pumped', 'unpumped']:
-        module_data[var] = module_data[var] / module_data.sum_count
-    #module_data = module_data.drop('sum_count')
-    return module_data
+            if rois is not None:
+                for k,v in rois.items():
+                    val = temp.isel({'x':slice(v['x'][0], v['x'][1]),
+                                     'y':slice(v['y'][0], v['y'][1])}).sum(dim=['x','y'])
+                    # store this chunk's per-train ROI sums at their trainId positions
+                    module_data[k].loc[{'trainId': val.trainId}] = val
 
+            module_data['dark_corrected_data'] += temp.sum(dim='trainId')
+            module_data['std_data'] += (temp**2).sum(dim='trainId')
+            module_data['counts'] += n_trains
+    
+    if workerId == 0:
+        print(module_data)    
+    return module_data
\ No newline at end of file
-- 
GitLab