from joblib import Parallel, delayed, parallel_backend
from time import strftime
import tempfile
import shutil
from tqdm.auto import tqdm
import os
import warnings
import psutil
import extra_data as ed
from extra_data.read_machinery import find_proposal
from extra_geom import DSSC_1MGeometry
import ToolBox as tb
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import xarray as xr
import h5py

from imageio import imread


class DSSC:
    def __init__(self, proposal, distance=1):
        """ Create a DSSC object to process DSSC data.
        
            inputs:
                proposal: (int, str) proposal number
                distance: (float) distance sample to DSSC detector in meter
        """
        if isinstance(proposal, int):
            proposal = 'p{:06d}'.format(proposal)
        runFolder = find_proposal(proposal)
        self.semester = runFolder.split('/')[-2]
        self.proposal = proposal
        self.topic = runFolder.split('/')[-3]
        self.tempdir = None
        self.save_folder = os.path.join(runFolder, 'usr/condensed_runs/')
        self.distance = distance
        self.px_pitch_h = 236 # horizontal pitch in microns
        self.px_pitch_v = 204 # vertical pitch in microns
        self.aspect = self.px_pitch_v/self.px_pitch_h # aspect ratio of the DSSC images
        self.geom = None
        self.mask = None
        self.max_fraction_memory = 0.4
        self.filter_mask = None
        self.Nworker = 16
        
        print('DSSC configuration')
        print(f'Topic: {self.topic}')
        print(f'Semester: {self.semester}')
        print(f'Proposal: {self.proposal}')
        print(f'Default save folder: {self.save_folder}')
        print(f'Sample to DSSC distance: {self.distance} m')
        
        if not os.path.exists(self.save_folder):
            warnings.warn(f'Default save folder does not exist: {self.save_folder}')
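
    # A minimal construction sketch (proposal number and distance are
    # illustrative only), assuming the proposal folder can be resolved by
    # extra_data's find_proposal, as on the Maxwell cluster:
    #
    #   dssc = DSSC(2212, distance=3.4)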
        
    def __del__(self):
        # deleting temporary folder
        if self.tempdir:
            shutil.rmtree(self.tempdir)
    
    def open_run(self, run_nr, isDark=False):
        """ Open a run with extra-data and prepare the virtual dataset for multiprocessing
        
            inputs:
                run_nr: the run number
                isDark: a boolean to specify if the run is a dark run or not
        
        """
        
        print('Opening run data with extra-data')
        self.run_nr = run_nr
        self.xgm = None
        self.filter_mask = None
        self.run = ed.open_run(self.proposal, self.run_nr)
        self.isDark = isDark
        self.plot_title = f'{self.proposal} run: {self.run_nr}'
        
        self.fpt = self.run.detector_info('SCS_DET_DSSC1M-1/DET/0CH0:xtdf')['frames_per_train']
        self.nbunches = self.run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER', 'sase3.nPulses.value')
        self.nbunches = np.unique(self.nbunches)
        if len(self.nbunches) == 1:
            self.nbunches = self.nbunches[0]
        else:
            warnings.warn('not all trains have the same number of SASE3 bunches')
            print(f'nbunches: {self.nbunches}')
            self.nbunches = self.nbunches[-1]
            
        print(f'DSSC frames per train: {self.fpt}')
        print(f'SA3 bunches per train: {self.nbunches}')
        
        if self.tempdir is not None:
            shutil.rmtree(self.tempdir)
        
        self.tempdir = tempfile.mkdtemp()
        print(f'Temporary directory: {self.tempdir}')

        print('Creating virtual dataset')
        self.vds_filenames = self.create_virtual_dssc_datasets(self.run, path=self.tempdir)
        
        # create a dummy scan variable for dark runs
        # for any other type of run, use DSSC.define_scan to overwrite it
        self.scan = xr.DataArray(np.ones_like(self.run.train_ids), dims=['trainId'],
                                 coords={'trainId': self.run.train_ids}).to_dataset(
                        name='scan_variable')
        self.scan_vname = 'dummy'
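
    # A usage sketch (run numbers are illustrative only):
    #
    #   dssc.open_run(89)               # a run to be binned
    #   dssc.open_run(88, isDark=True)  # the corresponding dark run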
        
    def define_scan(self, vname, bins):
        """
            Prepare the binning of the DSSC data.
            
            inputs:
                vname: variable name for the scan, can be a mnemonic string from ToolBox
                    or a dictionary with ['source', 'key'] fields
                bins: bin step size (binning by bin edges is not yet implemented)
        """

        if type(vname) is dict:
            scan = self.run.get_array(vname['source'], vname['key'])
        elif type(vname) is str:
            if vname not in tb.mnemonics:
                raise ValueError(f'{vname} not found in the ToolBox mnemonics table')
            scan = self.run.get_array(tb.mnemonics[vname]['source'], tb.mnemonics[vname]['key'])
        else:
            raise ValueError(f'vname should be a string or a dict. We got {type(vname)}')
            
        if (type(bins) is int) or (type(bins) is float):
            scan = bins * np.round(scan / bins)
        else:
            # TODO: digitize the data
            raise ValueError('Binning by bin edges is not yet implemented')
        self.scan_vname = vname
       
        self.scan = scan.to_dataset(name='scan_variable')
        if self.xgm is None:
            self.load_xgm()
        self.scan['xgm_pumped'] = self.xgm[:, :self.nbunches:2].mean('dim_0')
        self.scan['xgm_unpumped'] = self.xgm[:, 1:self.nbunches:2].mean('dim_0')

        self.scan_counts = xr.DataArray(np.ones(len(self.scan['scan_variable'])),
                                        dims=['scan_variable'],
                                        coords={'scan_variable': self.scan['scan_variable'].values},
                                        name='counts')
        self.scan_points = self.scan.groupby('scan_variable').mean('trainId').coords['scan_variable'].values
        self.scan_points_counts = self.scan_counts.groupby('scan_variable').sum()
        self.plot_scan()
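
    # A binning sketch: the mnemonic and step size below are illustrative
    # only; any key of the ToolBox mnemonics table can be used:
    #
    #   dssc.define_scan('PP800_DelayLine', bins=0.05)  # 0.05 mm delay steps
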
    def plot_scan(self):
        """ Plot a previously defined scan to see the scan range and the statistics.
        """
        if self.scan is not None:
            fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=[5, 5])
        else:
            fig, ax1 = plt.subplots(nrows=1, figsize=[5, 2.5])
            
        ax1.plot(self.scan_points, self.scan_points_counts, 'o-', ms=2)
        ax1.set_xlabel(f'{self.scan_vname}')
        ax1.set_ylabel('# trains')
        ax1.set_title(self.plot_title)
        
        if self.scan is not None:
            ax2.plot(self.scan['scan_variable'])
            ax2.set_xlabel('train #')
            ax2.set_ylabel(f'{self.scan_vname}')

    def load_xgm(self):
        """ Loads pulse-resolved dedicated SASE3 data from the SCS XGM.
        
        """
        if self.xgm is None:
            self.xgm = self.run.get_array(tb.mnemonics['SCS_SA3']['source'],
                                          tb.mnemonics['SCS_SA3']['key'], roi=ed.by_index[:self.nbunches])

    def plot_xgm_hist(self, nbins=100):
        """ Plots an histogram of the SCS XGM dedicated SAS3 data.
        
            inputs:
                nbins: number of the bins for the histogram.
        """
        if self.xgm is None:
            self.load_xgm()
            
        hist, bins_edges = np.histogram(self.xgm, nbins, density=True)
        width = 1.0 * (bins_edges[1] - bins_edges[0])
        bins_center = 0.5*(bins_edges[:-1] + bins_edges[1:])
        
        plt.figure(figsize=(5,3))
        plt.bar(bins_center, hist, align='center', width=width)
        plt.xlabel(f"{tb.mnemonics['SCS_SA3']['source']}{tb.mnemonics['SCS_SA3']['key']}")
        plt.ylabel('density')
        plt.title(self.plot_title)
        
    def xgm_filter(self, xgm_low=-np.inf, xgm_high=np.inf):
        """ Filters the data by train. If one pulse within a train has an SASE3 SCS XGM value below
            xgm_low or above xgm_high, that train will be dropped from the dataset.
            
            inputs:
                xgm_low: low threshold value
                xgm_high: high threshold value
        """
                   
        if self.xgm is None:
            self.load_xgm()
        
        if self.isDark:
            warnings.warn('This run was loaded as dark. Filtering on XGM makes no sense. Aborting.')
            return
        
        self.xgm_low = xgm_low
        self.xgm_high = xgm_high
        
        filter_mask = (self.xgm > self.xgm_low) * (self.xgm < self.xgm_high)
                   
        if self.filter_mask is not None:
            self.filter_mask = self.filter_mask * filter_mask
        else:
            self.filter_mask = filter_mask
                   
        valid = filter_mask.prod('dim_0').astype(bool)
        xgm_valid = self.xgm.where(valid)
        xgm_valid = xgm_valid.dropna('trainId')
        self.scan = self.scan.sel({'trainId': xgm_valid.trainId})
        nrejected = len(self.run.train_ids) - len(self.scan['trainId'])
        print((f'Rejecting {nrejected} out of {len(self.run.train_ids)} trains due to xgm '
               f'thresholds: [{self.xgm_low}, {self.xgm_high}]'))
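
    # A filtering sketch (threshold value illustrative only, in the units of
    # the XGM intensity readout):
    #
    #   dssc.xgm_filter(xgm_low=1e3)  # drop trains containing any weak pulse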

    def load_geom(self, geopath=None, quad_pos=None):
        """ Loads and return the DSSC geometry.

            inputs:
                geopath: path to the h5 geometry file. If None uses a default file.
                quad_pos: list of quadrants tuple position. If None uses a default position.

            output:
                return the loaded geometry
        """
        if quad_pos is None:
            quad_pos = [(-124.100,    3.112),  # TR
                        (-133.068, -110.604),  # BR
                        (   0.988, -125.236),  # BL
                        (   4.528,   -4.912)]  # TL

        if geopath is None:
            geopath = '/gpfs/exfel/sw/software/git/EXtra-geom/docs/dssc_geo_june19.h5'

        self.geom = DSSC_1MGeometry.from_h5_file_and_quad_positions(geopath, quad_pos)

        return self.geom
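
    # A geometry-override sketch (coordinates are illustrative only): pass
    # refined quadrant positions, e.g. from a geometry calibration, instead
    # of the defaults above:
    #
    #   quad_pos = [(-124.1, 3.1), (-133.1, -110.6), (1.0, -125.2), (4.5, -4.9)]
    #   geom = dssc.load_geom(quad_pos=quad_pos)
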
    def load_mask(self, fname, plot=True):
        """ Load a DSSC mask file.
            
            input:
                fname: string of the filename of the mask file
                plot: if True, the loaded mask is plotted
        """
                   

        dssc_mask = imread(fname)
        dssc_mask = dssc_mask.astype(float)[..., 0] // 255
        dssc_mask[dssc_mask==0] = np.nan
        self.mask = dssc_mask
        
        if plot:
            plt.figure()
            plt.imshow(self.mask)
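
    # A sketch of producing a compatible mask file (filename and masked
    # region are illustrative only): the first image channel should be 255
    # for good pixels and 0 for masked ones, with the shape of the assembled
    # detector image:
    #
    #   import imageio
    #   shape = dssc.geom.position_modules_fast(np.zeros((16, 128, 512)))[0].shape
    #   mask = np.full(shape + (3,), 255, dtype=np.uint8)
    #   mask[100:200, 300:400] = 0  # a hypothetical bad region
    #   imageio.imwrite('mask.png', mask)
    #   dssc.load_mask('mask.png')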

    def create_virtual_dssc_datasets(self, run, path=''):
        """ Create virtual datasets for each 16 DSSC modules used for the multiprocessing.
            
            input:
                path: string where the virtual files are created
            output:
                dictionnary of key:module, value:virtual dataset filename
        vds_filenames = {}

        for module in tqdm(range(16)):
            fname = os.path.join(path, f'dssc{module}_vds.h5')
            if os.path.isfile(fname):
                os.remove(fname)

            vds = run.get_virtual_dataset(f'SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf',
                                                 'image.data', filename=fname)

            vds.file.close() # keep h5 file closed outside 'with' context

            vds_filenames[module] = fname

        return vds_filenames
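
    # A sketch of how the worker processes read these virtual files back
    # (module 0 shown; the dataset paths mirror those used in
    # process_one_module below):
    #
    #   with h5py.File(vds_filenames[0], 'r') as f:
    #       train_ids = f['INDEX/trainId'][()]
    #       frames = f['INSTRUMENT/SCS_DET_DSSC1M-1/DET/0CH0:xtdf/image/data']
    #       print(train_ids.shape, frames.shape)
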
    def binning(self, do_pulse_mean=True):
        """ Bin the DSSC data by the predifined scan type (DSSC.define()) using multiprocessing
        # get available memory in GB, we will try to use 80 % of it
        max_GB = psutil.virtual_memory().available/1024**3
        print(f'max available memory: {max_GB} GB')
        
        # max_GB / (8byte * Nworker * 128px * 512px * N_pulses)
        self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt))
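        # Worked example with illustrative numbers: 256 GB available,
        # max_fraction_memory = 0.4, Nworker = 16 and fpt = 80 give
        # 0.4 * 256 * 1024**3 // (8 * 16 * 128 * 512 * 80) ~= 163 trains per chunk.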
        
        print('processing', self.chunksize, 'trains per chunk')
                   
        jobs = []
        for m in range(16):
            jobs.append(dict(
                module=m,
                fpt=self.fpt,
                vds_filename=self.vds_filenames[m],
                chunksize=self.chunksize,
                scan=self.scan['scan_variable'],
                nbunches=self.nbunches,
                run_nr=self.run_nr,
                do_pulse_mean=do_pulse_mean
            ))

        if self.Nworker != 16:
            with warnings.catch_warnings():
                warnings.simplefilter("default")
                warnings.warn(('Nworker other than 16 is known to cause issues '
                    '(https://in.xfel.eu/gitlab/SCS/ToolBox/merge_requests/76)'),
                    RuntimeWarning)

        timestamp = strftime('%X')
        print(f'start time: {timestamp}')

        with parallel_backend('loky', n_jobs=self.Nworker):
            module_data = Parallel(verbose=20)(
                delayed(process_one_module)(job) for job in tqdm(jobs)
            )
        print('finished:', strftime('%X'))
    
        # rearrange the multiprocessed data
        self.module_data = xr.concat(module_data, dim='module')
        self.module_data['run'] = self.run_nr
        self.module_data = self.module_data.transpose('scan_variable', 'module', 'x', 'y')
                   
        if do_pulse_mean:
            self.module_data = xr.merge([self.module_data, self.scan.groupby('scan_variable').mean('trainId')])
        elif self.xgm is not None:
            xgm_pumped = self.xgm[:, :self.nbunches:2].mean('trainId').to_dataset(name='xgm_pumped').rename({'dim_0':'scan_variable'})
            xgm_unpumped = self.xgm[:, 1:self.nbunches:2].mean('trainId').to_dataset(name='xgm_unpumped').rename({'dim_0':'scan_variable'})
            self.module_data = xr.merge([self.module_data, xgm_pumped, xgm_unpumped])
        self.module_data = self.module_data.squeeze()
        if do_pulse_mean:
            self.module_data.attrs['scan_variable'] = self.scan_vname
        else:
            self.module_data.attrs['scan_variable'] = 'pulse id'
                   
    def save(self, save_folder=None, overwrite=False):
        """ Save the crunched data.
        
            inputs:
                save_folder: string of the folder where to save the data.
                overwrite: boolean whether or not to overwrite existing files.
        """
        if save_folder is None:
            save_folder = self.save_folder

        if self.isDark:
            fname = f'run{self.run_nr}_dark.nc'  # no scan
        else:
            fname = f'run{self.run_nr}.nc'  # run with delay scan (change for other scan types!)

        save_path = os.path.join(save_folder, fname)
        file_exists = os.path.isfile(save_path)

        if not file_exists or (file_exists and overwrite):
            if file_exists:
                warnings.warn(f'Overwriting file: {save_path}')
                os.remove(save_path)
            self.module_data.to_netcdf(save_path, group='data')
            os.chmod(save_path, 0o664)
            print('saving: ', save_path)
        else:
            print('file', save_path, 'exists and overwrite is False')
                   
    def load_binned(self, runNB, dark_runNB, xgm_norm=True, save_folder=None):
        """ Load previously binned (crunched) DSSC data saved by DSSC.binning() and DSSC.save().
        
            inputs:
                runNB: run number to load
                dark_runNB: run number of the corresponding dark
                xgm_norm: normalize by XGM data if True
                save_folder: path string where the crunched data are saved
        """

        if save_folder is None:
            save_folder = self.save_folder

        self.plot_title = f'{self.proposal} run: {runNB} dark: {dark_runNB}'
                   
        dark = xr.load_dataset(os.path.join(save_folder, f'run{dark_runNB}_dark.nc'), group='data',
            engine='netcdf4')
        binned = xr.load_dataset(os.path.join(save_folder, f'run{runNB}.nc'), group='data',
            engine='netcdf4')

        binned['pumped'] = (binned['pumped'] - dark['pumped'].values)
        binned['unpumped'] = (binned['unpumped'] - dark['unpumped'].values)

        if xgm_norm:
            binned['pumped'] = binned['pumped'] / binned['xgm_pumped']
            binned['unpumped'] = binned['unpumped'] / binned['xgm_unpumped']
        
        self.scan_points = binned['scan_variable']
        self.scan_points_counts = binned['sum_count'][:, 0]
        self.scan_vname = binned.attrs['scan_variable']
        self.scan = None

        self.binned = binned
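
    # A loading sketch (run numbers are illustrative only):
    #
    #   dssc.load_binned(89, 88)  # binned run 89, corresponding dark run 88
    #   dssc.plot_DSSC()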
                   
    def plot_DSSC(self, use_mask=True, p_low=1, p_high=98, vmin=None, vmax=None):
        """ Plot pumped and unpumped DSSC images.
        
            inputs:
                use_mask: if True, a mask is applied on the DSSC.
                p_low: low percentile value to adjust the contrast scale on the unpumped and pumped image
                p_high: high percentile value to adjust the contrast scale on the unpumped and pumped image
                vmin: low value of the image scale
                vmax: high value of the image scale
        """
                   
        if use_mask:
            if self.mask is None:
                raise ValueError('No mask was loaded!')
                   
            mask = self.mask
            mask_txt = ' masked'
        else:
            mask = 1
            mask_txt = ''
        
        if self.geom is None:
            self.load_geom()
                   
        im_pump_mean, _ = self.geom.position_modules_fast(self.binned['pumped'].mean('scan_variable'))
        im_unpump_mean, _ = self.geom.position_modules_fast(self.binned['unpumped'].mean('scan_variable'))
        
        self.im_pump_mean = mask*im_pump_mean
        self.im_unpump_mean = mask*im_unpump_mean
                           
        fig = plt.figure(figsize=(9, 4))
        grid = ImageGrid(fig, 111,
                 nrows_ncols=(1,2),
                 axes_pad=0.15,
                 share_all=True,
                 cbar_location="right",
                 cbar_mode="single",
                 cbar_size="7%",
                 cbar_pad=0.15,
                 )

        _vmin, _vmax = np.percentile(self.im_pump_mean[~np.isnan(self.im_pump_mean)], [p_low, p_high])
        if vmin is None:
            vmin = _vmin
        if vmax is None:
            vmax = _vmax
                         
        im = grid[0].imshow(self.im_pump_mean, vmin=vmin, vmax=vmax, aspect=self.aspect)
        grid[0].set_title('pumped' + mask_txt)

        im = grid[1].imshow(self.im_unpump_mean, vmin=vmin, vmax=vmax, aspect=self.aspect)
        grid[1].set_title('unpumped' + mask_txt)
                   
        grid[-1].cax.colorbar(im)
        grid[-1].cax.toggle_label(True)
        
        fig.suptitle(self.plot_title)
                   
                   
    def azimuthal_int(self, wl, center=None, angle_range=[0, 180-1e-6], dr=1, use_mask=True):
        """ Perform azimuthal integration of 1D binned DSSC run.
        
            inputs:
                wl: photon wavelength (in meter)
                center: center of integration
                angle_range: angles of integration
                dr: width of the radial integration bins (in pixels)
                use_mask: if True, use the loaded mask
        """

        if self.geom is None:
            self.load_geom()

        if use_mask:
            if self.mask is None:
                raise ValueError('No mask was loaded!')

            mask = self.mask
            mask_txt = ' masked'
        else:
            mask = 1
            mask_txt = ''

        im_pumped_arranged, c_geom = self.geom.position_modules_fast(self.binned['pumped'].values)
        im_unpumped_arranged, c_geom = self.geom.position_modules_fast(self.binned['unpumped'].values)

        im_pumped_arranged *= mask
        im_unpumped_arranged *= mask

        im_pumped_mean = im_pumped_arranged.mean(axis=0)
        im_unpumped_mean = im_unpumped_arranged.mean(axis=0)

        if center is None:
            center = c_geom

        ai = tb.azimuthal_integrator(im_pumped_mean.shape, center, angle_range, dr=dr)
        norm = ai(~np.isnan(im_pumped_mean))

        az_pump = []
        az_unpump = []

        for i in tqdm(range(len(self.binned['scan_variable']))):
            az_pump.append(ai(im_pumped_arranged[i]) / norm)
            az_unpump.append(ai(im_unpumped_arranged[i]) / norm)

        az_pump = np.stack(az_pump)
        az_unpump = np.stack(az_unpump)

        coords = {'scan_variable': self.binned['scan_variable'], 'distance': ai.distance}
        azimuthal = xr.DataArray(az_pump, dims=['scan_variable', 'distance'], coords=coords)
        azimuthal = azimuthal.to_dataset(name='pumped')
        azimuthal['unpumped'] = xr.DataArray(az_unpump, dims=['scan_variable', 'distance'], coords=coords)
        azimuthal = azimuthal.transpose('distance', 'scan_variable')

        #t0 = 225.5
        #azimuthal['delay'] = (t0 - azimuthal.delay)*6.6
        #azimuthal['delay'] = azimuthal.delay

        azimuthal['delta_q (1/nm)'] = 2e-9 * np.pi * np.sin(
            np.arctan(azimuthal.distance * self.px_pitch_v * 1e-6 / self.distance)) / wl
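        # The conversion above: the scattering angle is 2*theta = arctan(r/L),
        # with r the radial detector distance (pixel index times the 204 um
        # vertical pitch) and L the sample-detector distance; the momentum
        # transfer q = 4*pi*sin(theta)/lambda is approximated here by
        # 2*pi*sin(2*theta)/lambda (small angles), and 1e-9 converts to 1/nm.
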
        azimuthal.attrs = self.binned.attrs

        self.azimuthal = azimuthal.swap_dims({'distance': 'delta_q (1/nm)'})
                   
    def plot_azimuthal_int(self, kind='difference', lim=None):
        """ Plot a computed azimuthal integration.
        
            inputs:
                kind: (str) either 'difference' or 'relative' to change the type of plot.
                lim: list [vmin, vmax] fixing the color scale of the difference plot; if None, a robust scale is used.
        """
        fig, [ax1, ax2, ax3] = plt.subplots(nrows=3, sharex=True, sharey=True)
        xr.plot.imshow(self.azimuthal.pumped, ax=ax1, vmin=0, robust=True)
        ax1.set_title('pumped')

        xr.plot.imshow(self.azimuthal.unpumped, ax=ax2, vmin=0, robust=True)
        ax2.set_title('unpumped')
                   
        if kind == 'difference':
            val = self.azimuthal.pumped - self.azimuthal.unpumped
            ax3.set_title('pumped - unpumped')
        elif kind == 'relative':
            val = (self.azimuthal.pumped - self.azimuthal.unpumped)/self.azimuthal.unpumped
            ax3.set_title('(pumped - unpumped)/unpumped')
        else:
            raise ValueError('kind should be either difference or relative')

        if lim is None:
            xr.plot.imshow(val, ax=ax3, robust=True)
        else:
            xr.plot.imshow(val, ax=ax3, vmin=lim[0], vmax=lim[1])
        ax3.set_xlabel(self.scan_vname)
        fig.suptitle(f'{self.plot_title}')

    def plot_azimuthal_line_cut(self, data, qranges, qwidths):
        """ Plot line scans on top of the data.
        
            inputs:
                data: an azimuthal integrated xarray DataArray with 'delta_q (1/nm)' as one of its dimensions.
                qranges: a list of q-ranges
                qwidths: a list of q-widths, same length as qranges
        """
                   
        fig, [ax1, ax2] = plt.subplots(nrows=2, sharex=True, figsize=[8, 7])

        xr.plot.imshow(data, ax=ax1, robust=True)

        # attributes are not propagated through xarray mathematical operations
        # (https://github.com/pydata/xarray/issues/988), so data might no longer
        # carry the scan variable name
        ax1.set_xlabel(self.scan_vname)
        fig.suptitle(f'{self.plot_title}')
    
        for i, (qr, qw) in enumerate(zip(qranges, qwidths)):
            sel = (data['delta_q (1/nm)'] > (qr - qw/2)) * (data['delta_q (1/nm)'] < (qr + qw/2))
            val = data.where(sel).mean('delta_q (1/nm)')
            ax2.plot(data.scan_variable, val, c=f'C{i}', label=f'q = {qr:.2f}')
        
            ax1.axhline(qr - qw/2, c=f'C{i}', lw=1)
            ax1.axhline(qr + qw/2, c=f'C{i}', lw=1)
        ax2.legend()
        ax2.set_xlabel(self.scan_vname)
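
    # An end-to-end analysis sketch (all numbers illustrative only):
    #
    #   dssc = DSSC(2212, distance=3.4)
    #   dssc.open_run(88, isDark=True); dssc.binning(); dssc.save()
    #   dssc.open_run(89)
    #   dssc.define_scan('PP800_DelayLine', bins=0.05)
    #   dssc.binning(); dssc.save()
    #   dssc.load_binned(89, 88)
    #   dssc.azimuthal_int(wl=2.3e-9)  # hypothetical 2.3 nm photon wavelength
    #   dssc.plot_azimuthal_int()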
        
                   
# since 'self' is not picklable, this function has to be outside the DSSC class
# so that it can be used by the joblib parallel workers
def process_one_module(job):
    module = job['module']
    fpt = job['fpt']
    vds = job['vds_filename']
    chunksize = job['chunksize']
    scan = job['scan']
    nbunches = job['nbunches']
    do_pulse_mean = job['do_pulse_mean']

    image_path = f'INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data'
    npulse_path = f'INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count'
    with h5py.File(vds, 'r') as m:
        all_trainIds = m['INDEX/trainId'][()]
        frames_per_train = m[npulse_path][()]
    trains_with_data = all_trainIds[frames_per_train == fpt]

    len_scan = len(scan.groupby(scan))

    if do_pulse_mean:
        # create empty dataset to add actual data to
        module_data = xr.DataArray(np.empty([len_scan, 128, 512], dtype=np.float64),
                                   dims=['scan_variable', 'x', 'y'],
                                   coords={'scan_variable': np.unique(scan)})
        module_data = module_data.to_dataset(name='pumped')
        module_data['unpumped'] = xr.full_like(module_data['pumped'], 0)
        module_data['sum_count'] = xr.DataArray(np.zeros_like(np.unique(scan)), dims=['scan_variable'])
        module_data['module'] = module
    else:
        scan = xr.full_like(scan, 1)
        len_scan = len(scan.groupby(scan))                   
        module_data = xr.DataArray(np.empty([len_scan, int(nbunches/2), 128, 512], dtype=np.float64),
                                   dims=['scan_variable', 'pulse', 'x', 'y'],
                                   coords={'scan_variable': np.unique(scan)})
        module_data = module_data.to_dataset(name='pumped')
        module_data['unpumped'] = xr.full_like(module_data['pumped'], 0)
        module_data['sum_count'] = xr.full_like(module_data['pumped'][..., 0, 0], 0)
        module_data['module'] = module
    with h5py.File(vds, 'r') as m:
        chunk_start = np.arange(len(all_trainIds), step=chunksize, dtype=int)
        trains_start = 0
                   
        # This line is the strange hack from https://github.com/tqdm/tqdm/issues/485
        print(' ', end='', flush=True)
        for c0 in tqdm(chunk_start, desc=f'pool.map#{module:02d}', position=module):
            chunk_dssc = np.s_[int(c0 * fpt):int((c0 + chunksize) * fpt)]  # for dssc data
            data = m[image_path][chunk_dssc].squeeze()
            data = data.astype(np.float64)
            n_trains = int(data.shape[0] // fpt)
            trainIds_chunk = np.unique(trains_with_data[trains_start:trains_start + n_trains])
            trains_start += n_trains
            n_trains_actual = len(trainIds_chunk)
            coords = {'trainId': trainIds_chunk}
            data = np.reshape(data, [n_trains_actual, fpt, 128, 512])[:, :int(2 * nbunches)]
            data = xr.DataArray(data, dims=['trainId', 'pulse', 'x', 'y'], coords=coords)
            
            if do_pulse_mean:
                data_pumped = (data[:, ::4]).mean('pulse')
                data_unpumped = (data[:, 2::4]).mean('pulse')
            else:
                data_pumped = (data[:, ::4])
                data_unpumped = (data[:, 2::4])
                   
            data = data_pumped.to_dataset(name='pumped')
            data['unpumped'] = data_unpumped
            data['sum_count'] = xr.full_like(data['unpumped'][..., 0, 0], fill_value=1)
                   
            # grouping and summing
            data['scan_variable'] = scan  # this only adds scan data for matching trainIds
            data = data.dropna('trainId')
            data = data.groupby('scan_variable').sum('trainId')
            where = {'scan_variable': data.scan_variable}
            for var in ['pumped', 'unpumped', 'sum_count']:
                module_data[var].loc[where] = module_data[var].loc[where] + data[var]
    for var in ['pumped', 'unpumped']:
        module_data[var] = module_data[var] / module_data.sum_count
    #module_data = module_data.drop('sum_count')
    
    if not do_pulse_mean:
        module_data = module_data.sum('scan_variable')
        module_data = module_data.rename({'pulse': 'scan_variable'})
    return module_data