import multiprocessing from time import strftime from tqdm.auto import tqdm import os import warnings import psutil import extra_data as ed from extra_data.read_machinery import find_proposal import ToolBox as tb import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import ImageGrid import matplotlib.patches as patches import numpy as np import xarray as xr import h5py from glob import glob from imageio import imread class DSSC1module: def __init__(self, module, proposal): """ Create a DSSC object to process 1 module of DSSC data. inputs: module: module number to process proposal: (int,str) proposal number string """ self.module = module if isinstance(proposal,int): proposal = 'p{:06d}'.format(proposal) self.runFolder = find_proposal(proposal) self.semester = self.runFolder.split('/')[-2] self.proposal = proposal self.topic = self.runFolder.split('/')[-3] self.save_folder = os.path.join(self.runFolder, 'usr/condensed_runs/') self.px_pitch_h = 236 # horizontal pitch in microns self.px_pitch_v = 204 # vertical pitch in microns self.aspect = self.px_pitch_v/self.px_pitch_h # aspect ratio of the DSSC images print('DSSC configuration') print(f'DSSC module: {self.module}') print(f'Topic: {self.topic}') print(f'Semester: {self.semester}') print(f'Proposal: {self.proposal}') print(f'Default save folder: {self.save_folder}') if not os.path.exists(self.save_folder): warnings.warn(f'Default save folder does not exist: {self.save_folder}') self.dark_data = 0 self.max_fraction_memory = 0.8 self.Nworker = 10 self.rois = None self.maxSaturatedPixel = 1 def open_run(self, run_nr, t0=0.0): """ Open a run with extra-data and prepare the virtual dataset for multiprocessing inputs: run_nr: the run number t0: optional t0 in mm """ print('Opening run data with extra-data') self.run_nr = run_nr self.xgm = None self.run = ed.open_run(self.proposal, self.run_nr) self.plot_title = f'{self.proposal} run: {self.run_nr}' self.fpt = self.run.detector_info(f'SCS_DET_DSSC1M-1/DET/{self.module}CH0:xtdf')['frames_per_train'] self.nbunches = self.run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER', 'sase3.nPulses.value') self.nbunches = np.unique(self.nbunches) if len(self.nbunches) == 1: self.nbunches = self.nbunches[0] else: warnings.warn('not all trains have same length DSSC data') print(f'nbunches: {self.nbunches}') self.nbunches = self.nbunches[-1] print(f'DSSC frames per train: {self.fpt}') print(f'SA3 bunches per train: {self.nbunches}') print('Collecting DSSC module files') self.collect_dssc_module_file() print(f'Loading XGM data') self.xgm = self.run.get_array(tb.mnemonics['SCS_SA3']['source'], tb.mnemonics['SCS_SA3']['key'], roi=ed.by_index[:self.nbunches]) self.xgm = self.xgm.rename({'dim_0':'pulseId'}) self.xgm['pulseId'] = np.arange(0, 2*self.nbunches, 2) print(f'Loading mono nrj data') self.nrj = self.run.get_array(tb.mnemonics['nrj']['source'], tb.mnemonics['nrj']['key']) print(f'Loading delay line data') try: self.delay_mm = self.run.get_array(tb.mnemonics['PP800_DelayLine']['source'], tb.mnemonics['PP800_DelayLine']['key']) except: self.delay_mm = 0*self.nrj self.t0 = t0 self.delay_ps = tb.positionToDelay(self.delay_mm, origin=self.t0, invert=True) def collect_dssc_module_file(self): """ Collect the raw DSSC module h5 files. """ pattern = self.runFolder + f'/raw/r{self.run_nr:04d}/RAW-R{self.run_nr:04d}-DSSC{self.module:02d}-S*.h5' self.h5list = glob(pattern) def process(self, dark_pass=None): """ Process DSSC data from one module using multiprocessing dark_pass: if None, process data, if 'mean', compute the mean, if 'std', compute the std """ # get available memory in GB, we will try to use 80 % of it max_GB = psutil.virtual_memory().available/1024**3 print(f'max available memory: {max_GB} GB') # max_GB / (8byte * Nworker * 128px * 512px * N_pulses) self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt)) print('processing', self.chunksize, 'trains per chunk') if dark_pass == 'mean': rois = None dark = 0 mask = 1 elif dark_pass == 'std': dark = self.dark_data['dark_mean'] rois = None mask = 1 elif dark_pass is None: dark = self.dark_data['dark_mean'] rois = self.rois mask = self.dark_data['mask'] else: raise ValueError(f"dark_pass should be either None or 'mean' or 'std' but not {dark_pass}") jobs = [] for m,h5fname in enumerate(self.h5list): jobs.append(dict( fpt=self.fpt, module=self.module, h5fname=h5fname, chunksize=self.chunksize, nbunches=self.nbunches, workerId=m, Nworker=self.Nworker, dark_data=dark, rois=rois, mask=mask, maxSaturatedPixel=self.maxSaturatedPixel )) timestamp = strftime('%X') print(f'start time: {timestamp}') with multiprocessing.Pool(self.Nworker) as pool: res = pool.map(process_one_module, jobs) print('finished:', strftime('%X')) # rearange the multiprocessed data # this is to get rid of the worker dimension, there is no sum over worker really involved self.module_data = xr.concat(res, dim='worker').sum(dim='worker') # reorder the dimension if 'trainId' in self.module_data.dims: self.module_data = self.module_data.transpose('trainId', 'pulseId', 'x', 'y') else: self.module_data = self.module_data.transpose('pulseId', 'x', 'y') # fix some computation now that we have everything self.module_data['std_data'] = np.sqrt(self.module_data['std_data']/(self.module_data['counts'] - 1)) self.module_data['dark_corrected_data'] = self.module_data['dark_corrected_data']/self.module_data['counts'] self.module_data['run'] = self.run_nr if dark_pass == 'mean': self.dark_data = self.module_data['dark_corrected_data'].to_dataset('dark_mean') self.dark_data['run'] = self.run_nr elif dark_pass == 'std': self.dark_data['dark_std'] = self.module_data['std_data'] assert self.dark_data['run'] == self.run_nr, "noise map computed from different darks" else: self.module_data['xgm'] = self.xgm self.module_data['nrj'] = self.nrj self.module_data['delay_mm'] = self.delay_mm self.module_data['delay_ps'] = self.delay_ps self.module_data['t0'] = self.t0 self.plot_title = f"{self.proposal} run: {self.module_data['run'].values} dark: {self.dark_data['run'].values}" self.module_data.attrs['plot_title'] = self.plot_title def compute_mask(self, low=0.01, high=0.8): """ Compute a DSSC module mask from the noise map of a dark run. """ if self.dark_data['dark_std'] is None: raise ValueError('Cannot compute from from a missing dark noise map') self.dark_data['mask_low'] = low self.dark_data['mask_high'] = high m_std = self.dark_data['dark_std'].mean('pulseId') self.dark_data['mask'] = 1 - ((m_std > self.dark_data['mask_high']) + (m_std < self.dark_data['mask_low'])) def plot_module(self, plot_dark=False, low=1, high=98, vmin=None, vmax=None): """ Plot a module. inputs: plot_dark: if true, plot dark instead of run data. low: low percentile fraction of the display scale high: high percentile fraction of the display scale vmin: low value of the display scale, overwrites vmin computed from low vmax: max value of the display scale, overwrites vmax computed from high """ if plot_dark: mean = self.dark_data['dark_mean'].mean('pulseId') std = self.dark_data['dark_std'] title = f"{self.proposal} dark: {self.dark_data['run'].values}" else: mean = self.module_data['dark_corrected_data'].mean('pulseId') std = self.module_data['std_data'] title = self.plot_title fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, figsize=[5, 4*2.5]) _vmin, _vmax = np.percentile((mean.values[~self.dark_data['mask']]).flatten(), [low, high]) if vmin is None: vmin = _vmin if vmax is None: vmax = _vmax im = ax1.imshow(mean, vmin=vmin, vmax=vmax) fig.colorbar(im, ax=ax1) ax1.set_title('mean') fig.suptitle(title) im = ax2.imshow(std.mean('pulseId'), vmin=0, vmax=2) fig.colorbar(im, ax=ax2) ax2.set_title('std') ax3.hist(std.values.flatten(), bins=200, range=[0, 2], density=True) ax3.axvline(self.dark_data['mask_low'], ls='--', c='k') ax3.axvline(self.dark_data['mask_high'], ls='--', c='k') ax3.set_yscale('log') ax3.set_ylabel('density') ax3.set_xlabel('std values') im = ax4.imshow(self.dark_data['mask']) fig.colorbar(im, ax=ax4) def save(self, save_folder=None, overwrite=False, isDark=False): """ Save the crunched data. inputs: save_folder: string of the fodler where to save the data. overwrite: boolean whether or not to overwrite existing files. isDark: save the dark or the process data """ if save_folder is None: save_folder = self.save_folder if isDark: fname = f'run{self.run_nr}_dark.h5' # no scan data = self.dark_data else: fname = f'run{self.run_nr}.h5' # run with delay scan (change for other scan types!) data = self.module_data save_path = os.path.join(save_folder, fname) file_exists = os.path.isfile(save_path) if not file_exists or (file_exists and overwrite): if file_exists: warnings.warn(f'Overwriting file: {save_path}') os.remove(save_path) data.to_netcdf(save_path, group='data') data.close() os.chmod(save_path, 0o664) print('saving: ', save_path) else: print('file', save_path, 'exists and overwrite is False') def load_dark(self, dark_runNB, save_folder=None): """ Load dark data. inputs: save_folder: string of the folder where the data were saved. """ if save_folder is None: save_folder = self.save_folder self.run_nr = dark_runNB self.dark_data = xr.load_dataset(os.path.join(save_folder, f'run{dark_runNB}_dark.h5'), group='data') self.plot_title = f"{self.proposal} dark: {self.dark_data['run'].values}" def show_rois(self): fig, ax1 = plt.subplots(nrows=1, figsize=[5, 2.5]) try: ax1.imshow(self.module_data['dark_corrected_data'].mean('pulseId') * self.dark_data['mask']) except: ax1.imshow(self.dark_data['dark_mean'].mean('pulseId') * self.dark_data['mask']) for r,v in self.rois.items(): rect = patches.Rectangle((v['y'][0], v['x'][0]), v['y'][1] - v['y'][0], v['x'][1] - v['x'][0], linewidth=1, edgecolor='r', facecolor='none') ax1.add_patch(rect) fig.suptitle(self.plot_title) # since 'self' is not pickable, this function has to be outside the DSSC class so that it can be used # by the multiprocessing pool.map function def process_one_module(job): chunksize = job['chunksize'] Nworker = job['Nworker'] workerId = job['workerId'] dark_data = job['dark_data'] fpt = job['fpt'] module = job['module'] rois = job['rois'] mask = job['mask'] h5fname = job['h5fname'] maxSaturatedPixel = job['maxSaturatedPixel'] image_path = f"INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data" npulse_path = f"INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count" with h5py.File(h5fname, 'r') as m: all_trainIds = m['INDEX/trainId'][()] n_trains = len(all_trainIds) n_chunk = np.ceil(n_trains/chunksize) + 1 chunks = np.linspace(0, n_trains, n_chunk, endpoint=True, dtype=int) # create empty dataset to add actual data to module_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64), dims=['pulseId', 'x', 'y'], coords={'pulseId':np.arange(fpt)}).to_dataset(name='dark_corrected_data') module_data['std_data'] = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64), dims=['pulseId', 'x', 'y']) if rois is not None: for k in rois.keys(): module_data[k] = xr.DataArray(np.empty([n_trains], dtype=np.float64), dims=['trainId'], coords = {'trainId': all_trainIds}) module_data['counts'] = 0 # crunching with h5py.File(h5fname, 'r') as m: #chunk_start = np.arange(len(all_trainIds), step=job['chunksize'], dtype=int) trains_start = 0 # This line is the strange hack from https://github.com/tqdm/tqdm/issues/485 print(' ', end='', flush=True) for k,v in enumerate(tqdm(chunks[:-1], desc=f"pool.map#{workerId:02d}")): chunk_dssc = np.s_[int(chunks[k] * fpt):int(chunks[k+1] * fpt)] # for dssc data data = m[image_path][chunk_dssc].squeeze() trains = m['INDEX/trainId'][np.s_[int(chunks[k]):int(chunks[k+1])]] n_trains = len(trains) data = data.astype(np.float64) data = xr.DataArray(np.reshape(data, [n_trains, fpt, 128, 512]), dims=['trainId', 'pulseId', 'x', 'y'], coords={'trainId': trains}) temp = data - dark_data if rois is not None: temp2 = temp.where(mask) for k,v in rois.items(): bkg = dark_data.isel({'x':slice(v['x'][0], v['x'][1]), 'y':slice(v['y'][0], v['y'][1])}) im = data.isel({'x':slice(v['x'][0], v['x'][1]), 'y':slice(v['y'][0], v['y'][1])}) smask = mask.isel({'x':slice(v['x'][0], v['x'][1]), 'y':slice(v['y'][0], v['y'][1])}) im = im.where(smask) cim = im - bkg.where(smask) val = cim.sum(dim=['x','y']) tokeep = (im>254).sum(dim=['x', 'y']) < maxSaturatedPixel tokeep = tokeep.assign_coords(pulseId = val.coords['pulseId'].values) todrop = 1-tokeep Ndropped = todrop.sum().values percent_dropped = 100*Ndropped/(n_trains * fpt) print(f'Dropped: {Ndropped}, i.e. {percent_dropped:.2f}%') gval = val.where(tokeep, np.nan) module_data[k] = gval module_data['dark_corrected_data'] += temp.sum(dim='trainId') module_data['std_data'] += (temp**2).sum(dim='trainId') module_data['counts'] += n_trains return module_data