import multiprocessing
from time import strftime
from tqdm.auto import tqdm
import os
import warnings
import psutil

import extra_data as ed
from extra_data.read_machinery import find_proposal
import ToolBox as tb
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.patches as patches
import numpy as np
import xarray as xr
import h5py
from glob import glob

from imageio import imread

class DSSC1module:
        
    def __init__(self, module, proposal):
        """ Create a DSSC object to process 1 module of DSSC data.
        
            inputs:
                module: module number to process
                proposal: (int, str) proposal number, as an int or a 'pXXXXXX' string
            
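            Example (module and proposal numbers are hypothetical):
                dssc = DSSC1module(15, 2212)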
        """
        self.module = module
        
        if isinstance(proposal,int):
            proposal = 'p{:06d}'.format(proposal)
        self.runFolder = find_proposal(proposal)
        self.semester = self.runFolder.split('/')[-2]
        self.proposal = proposal
        self.topic = self.runFolder.split('/')[-3]
        self.save_folder = os.path.join(self.runFolder, 'usr/condensed_runs/')
        self.px_pitch_h = 236 # horizontal pitch in microns
        self.px_pitch_v = 204 # vertical pitch in microns
        self.aspect = self.px_pitch_v/self.px_pitch_h # aspect ratio of the DSSC images
        
        print('DSSC configuration')
        print(f'DSSC module: {self.module}')
        print(f'Topic: {self.topic}')
        print(f'Semester: {self.semester}')
        print(f'Proposal: {self.proposal}')
        print(f'Default save folder: {self.save_folder}')
        
        if not os.path.exists(self.save_folder):
            warnings.warn(f'Default save folder does not exist: {self.save_folder}')
            
        self.dark_data = 0  # placeholder until a dark run is processed or loaded
        self.max_fraction_memory = 0.8
        self.Nworker = 10
        self.rois = None  # dict of ROIs, e.g. {'roi1': {'x': [x0, x1], 'y': [y0, y1]}}
        self.maxSaturatedPixel = 1
           
    def open_run(self, run_nr, t0=0.0):
        """ Open a run with extra-data and prepare the virtual dataset for multiprocessing
        
            inputs:
                run_nr: the run number
                t0: optional time zero position of the delay line, in mm
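            
            Example (run number and t0 value are hypothetical):
                self.open_run(180, t0=225.5)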
        """
        
        print('Opening run data with extra-data')
        self.run_nr = run_nr
        self.xgm = None
        
        self.run = ed.open_run(self.proposal, self.run_nr)
        self.plot_title = f'{self.proposal} run: {self.run_nr}'
        
        self.fpt = self.run.detector_info(f'SCS_DET_DSSC1M-1/DET/{self.module}CH0:xtdf')['frames_per_train']
        self.nbunches = self.run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER', 'sase3.nPulses.value')
        self.nbunches = np.unique(self.nbunches)
        if len(self.nbunches) == 1:
            self.nbunches = self.nbunches[0]
        else:
            warnings.warn('not all trains have the same number of SA3 bunches')
            print(f'nbunches: {self.nbunches}')
            # np.unique returns sorted values, so keep the largest
            self.nbunches = self.nbunches[-1]
            
        print(f'DSSC frames per train: {self.fpt}')
        print(f'SA3 bunches per train: {self.nbunches}')
        
        print('Collecting DSSC module files')
        self.collect_dssc_module_file()
        
        
        print('Loading XGM data')
        self.xgm = self.run.get_array(tb.mnemonics['SCS_SA3']['source'],
                                      tb.mnemonics['SCS_SA3']['key'],
                                      roi=ed.by_index[:self.nbunches])
        self.xgm = self.xgm.rename({'dim_0':'pulseId'})
        self.xgm['pulseId'] = np.arange(0, 2*self.nbunches, 2)
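        # assumed mapping: the DSSC records at twice the SASE3 pulse rate here, so
        # the XGM pulses land on every other DSSC pulseId (hence the step of 2)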
        
        print('Loading mono nrj data')
        self.nrj = self.run.get_array(tb.mnemonics['nrj']['source'],
                                      tb.mnemonics['nrj']['key'])
        print('Loading delay line data')
        try:
            self.delay_mm = self.run.get_array(tb.mnemonics['PP800_DelayLine']['source'],
                                               tb.mnemonics['PP800_DelayLine']['key'])
        except Exception:
            # no delay line data in this run; fall back to an all-zero array
            self.delay_mm = 0*self.nrj
        self.t0 = t0
        self.delay_ps = tb.positionToDelay(self.delay_mm, origin=self.t0, invert=True)
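        # note: positionToDelay converts the stage position to a delay in ps relative
        # to t0; for the usual double-pass delay line this is ~6.67 ps per mm of
        # travel (2 mm extra path / c), with invert flipping the sign (assumption,
        # check tb.positionToDelay for the exact convention)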
                    
    def collect_dssc_module_file(self):
        """ Collect the raw DSSC module h5 files.
        """
        
        pattern = self.runFolder + f'/raw/r{self.run_nr:04d}/RAW-R{self.run_nr:04d}-DSSC{self.module:02d}-S*.h5'
        self.h5list = glob(pattern)
       
    def process(self, dark_pass=None):
        """ Process DSSC data from one module using multiprocessing
        
            inputs:
                dark_pass: if None, process the data with dark correction and ROIs;
                    if 'mean', compute the dark mean; if 'std', compute the dark std
        
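            Typical sequence (sketch; compute_mask and open_run are separate steps):
                self.process(dark_pass='mean')   # 1st pass on a dark run
                self.process(dark_pass='std')    # 2nd pass on the same dark run
                self.process()                   # on a data run, after compute_mask()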
        """

        # get available memory in GB; we will try to use max_fraction_memory of it
        max_GB = psutil.virtual_memory().available/1024**3
        print(f'max available memory: {max_GB:.1f} GB')
        
        # trains per chunk = usable bytes / (8 bytes * Nworker * 128 px * 512 px * frames per train)
        self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt))
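        # worked example of the scaling (illustrative numbers, not measured): with
        # ~256 GiB available, max_fraction_memory=0.8, Nworker=10 and fpt=75, this
        # gives int(0.8*256*1024**3 // (8*10*128*512*75)) = 559 trains per chunk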
                
        print('processing', self.chunksize, 'trains per chunk')
        
        if dark_pass == 'mean':
            rois = None
            dark = 0
            mask = 1
        elif dark_pass == 'std':
            dark = self.dark_data['dark_mean']
            rois = None
            mask = 1
        elif dark_pass is None:
            dark = self.dark_data['dark_mean']
            rois = self.rois
            mask = self.dark_data['mask']
        else:
            raise ValueError(f"dark_pass should be either None or 'mean' or 'std' but not {dark_pass}")
                   
        jobs = []
        for m,h5fname in enumerate(self.h5list):
            jobs.append(dict(
            fpt=self.fpt,
            module=self.module,
            h5fname=h5fname,
            chunksize=self.chunksize,
            nbunches=self.nbunches,
            workerId=m,
            Nworker=self.Nworker,
            dark_data=dark,
            rois=rois,
            mask=mask,
            maxSaturatedPixel=self.maxSaturatedPixel
            ))
            
        timestamp = strftime('%X')
        print(f'start time: {timestamp}')

        with multiprocessing.Pool(self.Nworker) as pool:
            res = pool.map(process_one_module, jobs)
        
        print('finished:', strftime('%X'))
        
        # rearrange the multiprocessed data: each worker returns partial sums over
        # disjoint trains, so summing over the worker dimension just merges them
        self.module_data = xr.concat(res, dim='worker').sum(dim='worker')
        
        # reorder the dimension
        if 'trainId' in self.module_data.dims:
            self.module_data = self.module_data.transpose('trainId', 'pulseId', 'x', 'y')
        else:
            self.module_data = self.module_data.transpose('pulseId', 'x', 'y')
        
        # finalize the accumulated sums now that we have everything: std_data held
        # the sum of squared dark-subtracted values, so divide by (counts - 1) and
        # take the square root; dark_corrected_data held a plain sum, so divide by
        # counts to get the mean
        self.module_data['std_data'] = np.sqrt(self.module_data['std_data']/(self.module_data['counts'] - 1))
        self.module_data['dark_corrected_data'] = self.module_data['dark_corrected_data']/self.module_data['counts']
             
        self.module_data['run'] = self.run_nr
      
        if dark_pass == 'mean':
            self.dark_data = self.module_data['dark_corrected_data'].to_dataset(name='dark_mean')
            self.dark_data['run'] = self.run_nr
        elif dark_pass == 'std':
            self.dark_data['dark_std'] = self.module_data['std_data']
            assert self.dark_data['run'] == self.run_nr, "noise map computed from different darks"
        else:
            self.module_data['xgm'] = self.xgm
            self.module_data['nrj'] = self.nrj
            self.module_data['delay_mm'] = self.delay_mm
            self.module_data['delay_ps'] = self.delay_ps
            self.module_data['t0'] = self.t0
            
            
        self.plot_title = f"{self.proposal} run: {self.module_data['run'].values} dark: {self.dark_data['run'].values}"
        self.module_data.attrs['plot_title'] = self.plot_title
        
    def compute_mask(self, low=0.01, high=0.8):
        """ Compute a DSSC module mask from the noise map of a dark run.
        
            inputs:
                low: pixels whose mean dark std is below this value are masked out
                high: pixels whose mean dark std is above this value are masked out
        """
        
        if not isinstance(self.dark_data, xr.Dataset) or 'dark_std' not in self.dark_data:
            raise ValueError('Cannot compute a mask from a missing dark noise map')
            
        self.dark_data['mask_low'] = low
        self.dark_data['mask_high'] = high
        
        m_std = self.dark_data['dark_std'].mean('pulseId')
        
        # good pixels get mask 1; pixels whose mean noise falls outside [low, high] get 0
        self.dark_data['mask'] = 1 - ((m_std > self.dark_data['mask_high']) + (m_std < self.dark_data['mask_low']))

    def plot_module(self, plot_dark=False, low=1, high=98, vmin=None, vmax=None):
        """ Plot a module.
        
            inputs:
                plot_dark: if True, plot the dark instead of the run data.
                low: low percentile used to compute the display scale
                high: high percentile used to compute the display scale
                vmin: low value of the display scale; overrides the value computed from low
                vmax: high value of the display scale; overrides the value computed from high
                
        """
        
        if plot_dark:
            mean = self.dark_data['dark_mean'].mean('pulseId')
            std = self.dark_data['dark_std']
            title = f"{self.proposal} dark: {self.dark_data['run'].values}"
        else:
            mean = self.module_data['dark_corrected_data'].mean('pulseId')
            std = self.module_data['std_data']
            title = self.plot_title
            
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, figsize=[5, 4*2.5])
        # compute the display percentiles over the good (unmasked) pixels only
        _vmin, _vmax = np.percentile(mean.values[self.dark_data['mask'].values.astype(bool)], [low, high])
        if vmin is None:
            vmin = _vmin
        if vmax is None:
            vmax = _vmax
        im = ax1.imshow(mean, vmin=vmin, vmax=vmax)
        fig.colorbar(im, ax=ax1)
        ax1.set_title('mean')
        fig.suptitle(title)
        
        im = ax2.imshow(std.mean('pulseId'), vmin=0, vmax=2)
        fig.colorbar(im, ax=ax2)
        ax2.set_title('std')
        
        ax3.hist(std.values.flatten(), bins=200, range=[0, 2], density=True)
        ax3.axvline(self.dark_data['mask_low'], ls='--', c='k')
        ax3.axvline(self.dark_data['mask_high'], ls='--', c='k')
        ax3.set_yscale('log')
        ax3.set_ylabel('density')
        ax3.set_xlabel('std values')
        
        im = ax4.imshow(self.dark_data['mask'])
        fig.colorbar(im, ax=ax4)

    def save(self, save_folder=None, overwrite=False, isDark=False):
        """ Save the crunched data.
        
            inputs:
                save_folder: string of the folder where to save the data.
                overwrite: boolean, whether or not to overwrite existing files.
                isDark: if True, save the dark data, otherwise save the processed data
        """
        if save_folder is None:
            save_folder = self.save_folder

        if isDark:
            fname = f'run{self.run_nr}_dark.h5'  # no scan
            data = self.dark_data
        else:
            fname = f'run{self.run_nr}.h5'  # run with delay scan (change for other scan types!)
            data = self.module_data


        save_path = os.path.join(save_folder, fname)
        file_exists = os.path.isfile(save_path)

        if not file_exists or overwrite:
            if file_exists:
                warnings.warn(f'Overwriting file: {save_path}')
                os.remove(save_path)
            data.to_netcdf(save_path, group='data')
            data.close()
            os.chmod(save_path, 0o664)
            print('saving: ', save_path)
        else:
            print('file', save_path, 'exists and overwrite is False')
            
    def load_dark(self, dark_runNB, save_folder=None):
        """ Load dark data.
        
            inputs:
                save_folder: string of the folder where the data were saved.
        """

        if save_folder is None:
            save_folder = self.save_folder
            
        self.run_nr = dark_runNB
        self.dark_data = xr.load_dataset(os.path.join(save_folder, f'run{dark_runNB}_dark.h5'), group='data')
        self.plot_title = f"{self.proposal} dark: {self.dark_data['run'].values}"

    def show_rois(self):
        """ Plot the ROIs on top of the masked mean image of the module.
        """
        fig, ax1 = plt.subplots(nrows=1, figsize=[5, 2.5])
        try:
            ax1.imshow(self.module_data['dark_corrected_data'].mean('pulseId') * self.dark_data['mask'])
        except (KeyError, AttributeError):
            # no processed run available; fall back to the dark mean
            ax1.imshow(self.dark_data['dark_mean'].mean('pulseId') * self.dark_data['mask'])
        for r,v in self.rois.items():
            rect = patches.Rectangle((v['y'][0], v['x'][0]),
                                     v['y'][1] - v['y'][0],
                                     v['x'][1] - v['x'][0],
                                     linewidth=1, edgecolor='r', facecolor='none')

            ax1.add_patch(rect)
            
        fig.suptitle(self.plot_title)

 
# since 'self' is not picklable, this function has to be outside the DSSC1module class
# so that it can be used by the multiprocessing pool.map function
def process_one_module(job):
    
    chunksize = job['chunksize']
    Nworker = job['Nworker']
    workerId = job['workerId']
    dark_data = job['dark_data']
    fpt = job['fpt']
    module = job['module']
    rois = job['rois']
    mask = job['mask']
    h5fname = job['h5fname']
    maxSaturatedPixel = job['maxSaturatedPixel']
    
    image_path = f"INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data"
    npulse_path = f"INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count"
        
    with h5py.File(h5fname, 'r') as m:
        all_trainIds = m['INDEX/trainId'][()]
    n_trains = len(all_trainIds)
    
    # chunk boundaries in train indices: n_chunk points define n_chunk - 1 chunks
    n_chunk = int(np.ceil(n_trains/chunksize)) + 1
    
    chunks = np.linspace(0, n_trains, n_chunk, endpoint=True, dtype=int)
        
    # create an empty accumulator dataset; the running sums are finalized in DSSC1module.process
    module_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
                               dims=['pulseId', 'x', 'y'],
                               coords={'pulseId':np.arange(fpt)}).to_dataset(name='dark_corrected_data')
    module_data['std_data'] = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
                               dims=['pulseId', 'x', 'y'])

    if rois is not None:
        for k in rois.keys():
            # per-train, per-pulse ROI sums; NaN marks trains that get dropped or not processed
            module_data[k] = xr.DataArray(np.full([n_trains, fpt], np.nan, dtype=np.float64),
                                          dims=['trainId', 'pulseId'],
                                          coords={'trainId': all_trainIds, 'pulseId': np.arange(fpt)})
    module_data['counts'] = 0
    
    # crunching
    with h5py.File(h5fname, 'r') as m:

                   
        # This line is the strange hack from https://github.com/tqdm/tqdm/issues/485
        print(' ', end='', flush=True)
         
        for chunk_idx in tqdm(range(len(chunks) - 1), desc=f"pool.map#{workerId:02d}"):
            chunk_dssc = np.s_[int(chunks[chunk_idx] * fpt):int(chunks[chunk_idx+1] * fpt)]  # for dssc data
            data = m[image_path][chunk_dssc].squeeze()
            
            trains = m['INDEX/trainId'][np.s_[int(chunks[chunk_idx]):int(chunks[chunk_idx+1])]]
            n_trains_chunk = len(trains)
            
            data = data.astype(np.float64)
            data = xr.DataArray(np.reshape(data, [n_trains_chunk, fpt, 128, 512]),
                                dims=['trainId', 'pulseId', 'x', 'y'],
                                coords={'trainId': trains})
            
            temp = data - dark_data
            
            if rois is not None:
                for roi_name, v in rois.items():
                    bkg = dark_data.isel({'x':slice(v['x'][0], v['x'][1]),
                                     'y':slice(v['y'][0], v['y'][1])})
                    
                    im = data.isel({'x':slice(v['x'][0], v['x'][1]),
                                     'y':slice(v['y'][0], v['y'][1])})
                    
                    smask = mask.isel({'x':slice(v['x'][0], v['x'][1]),
                                     'y':slice(v['y'][0], v['y'][1])})
                    
                    im = im.where(smask)
                    
                    cim = im - bkg.where(smask)
                    
                    val = cim.sum(dim=['x','y'])
                    
                    # drop pulses with too many saturated pixels in the ROI
                    tokeep = (im>254).sum(dim=['x', 'y']) < maxSaturatedPixel
                    tokeep = tokeep.assign_coords(pulseId = val.coords['pulseId'].values)
                    todrop = 1-tokeep
                    
                    Ndropped = todrop.sum().values
                    percent_dropped = 100*Ndropped/(n_trains_chunk * fpt)
                    
                    print(f'Dropped: {Ndropped}, i.e. {percent_dropped:.2f}%')
                    gval = val.where(tokeep, np.nan)
                    # write this chunk's trains into the preallocated ROI array
                    module_data[roi_name].loc[{'trainId': trains}] = gval

            # accumulate the running sums over trains; finalized in DSSC1module.process
            module_data['dark_corrected_data'] += temp.sum(dim='trainId')
            module_data['std_data'] += (temp**2).sum(dim='trainId')
            module_data['counts'] += n_trains_chunk
       
    return module_data
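
# Minimal usage sketch (an illustrative addition, not part of the ToolBox API;
# the module, proposal and run numbers below are hypothetical placeholders):
# crunch a dark run in two passes, derive the pixel mask, then process a data run.
if __name__ == '__main__':
    dssc = DSSC1module(15, 2212)

    # dark run: the 'mean' pass accumulates the dark mean, the 'std' pass the noise map
    dssc.open_run(180)
    dssc.process(dark_pass='mean')
    dssc.process(dark_pass='std')
    dssc.compute_mask(low=0.01, high=0.8)
    dssc.save(isDark=True)

    # data run: dark-corrected, masked, with per-ROI sums extracted
    dssc.rois = {'roi1': {'x': [0, 128], 'y': [0, 128]}}
    dssc.open_run(181, t0=225.5)
    dssc.process()
    dssc.save()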