from joblib import Parallel, delayed, parallel_backend from time import strftime import tempfile import shutil from tqdm.auto import tqdm import os import warnings import psutil import extra_data as ed from extra_data.read_machinery import find_proposal from extra_geom import DSSC_1MGeometry import ToolBox as tb import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import ImageGrid import numpy as np import xarray as xr import h5py from imageio import imread class DSSC: def __init__(self, proposal, distance=1): """ Create a DSSC object to process DSSC data. inputs: proposal: (int,str) proposal number string distance: (float) distance sample to DSSC detector in meter """ if isinstance(proposal,int): proposal = 'p{:06d}'.format(proposal) runFolder = find_proposal(proposal) self.semester = runFolder.split('/')[-2] self.proposal = proposal self.topic = runFolder.split('/')[-3] self.tempdir = None self.save_folder = os.path.join(runFolder, 'usr/condensed_runs/') self.distance = distance self.px_pitch_h = 236 # horizontal pitch in microns self.px_pitch_v = 204 # vertical pitch in microns self.aspect = self.px_pitch_v/self.px_pitch_h # aspect ratio of the DSSC images self.geom = None self.mask = None self.max_fraction_memory = 0.4 self.filter_mask = None self.Nworker = 16 print('DSSC configuration') print(f'Topic: {self.topic}') print(f'Semester: {self.semester}') print(f'Proposal: {self.proposal}') print(f'Default save folder: {self.save_folder}') print(f'Sample to DSSC distance: {self.distance} m') if not os.path.exists(self.save_folder): warnings.warn(f'Default save folder does not exist: {self.save_folder}') def __del__(self): # deleting temporay folder if self.tempdir: shutil.rmtree(self.tempdir) def open_run(self, run_nr, isDark=False): """ Open a run with extra-data and prepare the virtual dataset for multiprocessing inputs: run_nr: the run number isDark: a boolean to specify if the run is a dark run or not """ print('Opening run data with extra-data') self.run_nr = run_nr self.xgm = None self.filter_mask = None self.run = ed.open_run(self.proposal, self.run_nr) self.isDark = isDark self.plot_title = f'{self.proposal} run: {self.run_nr}' self.fpt = self.run.detector_info('SCS_DET_DSSC1M-1/DET/0CH0:xtdf')['frames_per_train'] self.nbunches = self.run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER', 'sase3.nPulses.value') self.nbunches = np.unique(self.nbunches) if len(self.nbunches) == 1: self.nbunches = self.nbunches[0] else: warnings.warn('not all trains have same length DSSC data') print(f'nbunches: {self.nbunches}') self.nbunches = self.nbunches[-1] print(f'DSSC frames per train: {self.fpt}') print(f'SA3 bunches per train: {self.nbunches}') if self.tempdir is not None: shutil.rmtree(self.tempdir) self.tempdir = tempfile.mkdtemp() print(f'Temporary directory: {self.tempdir}') print('Creating virtual dataset') self.vds_filenames = self.create_virtual_dssc_datasets(self.run, path=self.tempdir) # create a dummy scan variable for dark run # for other type or run, use DSSC.define_run function to overwrite it self.scan = xr.DataArray(np.ones_like(self.run.train_ids), dims=['trainId'], coords={'trainId': self.run.train_ids}).to_dataset( name='scan_variable') self.scan_vname = 'dummy' def define_scan(self, vname, bins): """ Prepare the binning of the DSSC data. inputs: vname: variable name for the scan, can be a mnemonic string from ToolBox or a dictionnary with ['source', 'key'] fields bins: step size (or bins_edge but not yet implemented) """ if type(vname) is dict: scan = self.run.get_array(vname['source'], vname['key']) elif type(vname) is str: if vname not in tb.mnemonics: raise ValueError(f'{vname} not found in the ToolBox mnemonics table') scan = self.run.get_array(tb.mnemonics[vname]['source'], tb.mnemonics[vname]['key']) else: raise ValueError(f'vname should be a string or a dict. We got {type(vname)}') if (type(bins) is int) or (type(bins) is float): scan = bins * np.round(scan / bins) else: # TODO: digitize the data raise ValueError(f'To be implemented') self.scan_vname = vname self.scan = scan.to_dataset(name='scan_variable') self.scan['xgm_pumped'] = self.xgm[:, :self.nbunches:2].mean('dim_0') self.scan['xgm_unpumped'] = self.xgm[:, 1:self.nbunches:2].mean('dim_0') self.scan_counts = xr.DataArray(np.ones(len(self.scan['scan_variable'])), dims=['scan_variable'], coords={'scan_variable': self.scan['scan_variable'].values}, name='counts') self.scan_points = self.scan.groupby('scan_variable').mean('trainId').coords['scan_variable'].values self.scan_points_counts = self.scan_counts.groupby('scan_variable').sum() self.plot_scan() def plot_scan(self): """ Plot a previously defined scan to see the scan range and the statistics. """ if self.scan: fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=[5, 5]) else: fig, ax1 = plt.subplots(nrows=1, figsize=[5, 2.5]) ax1.plot(self.scan_points, self.scan_points_counts, 'o-', ms=2) ax1.set_xlabel(f'{self.scan_vname}') ax1.set_ylabel('# trains') ax1.set_title(self.plot_title) if self.scan: ax2.plot(self.scan['scan_variable']) ax2.set_xlabel('train #') ax2.set_ylabel(f'{self.scan_vname}') def load_xgm(self): """ Loads pulse resolved dedicated SAS3 data from the SCS XGM. """ if self.xgm is None: self.xgm = self.run.get_array(tb.mnemonics['SCS_SA3']['source'], tb.mnemonics['SCS_SA3']['key'], roi=ed.by_index[:self.nbunches]) def plot_xgm_hist(self, nbins=100): """ Plots an histogram of the SCS XGM dedicated SAS3 data. inputs: nbins: number of the bins for the histogram. """ if self.xgm is None: self.load_xgm() hist, bins_edges = np.histogram(self.xgm, nbins, density=True) width = 1.0 * (bins_edges[1] - bins_edges[0]) bins_center = 0.5*(bins_edges[:-1] + bins_edges[1:]) plt.figure(figsize=(5,3)) plt.bar(bins_center, hist, align='center', width=width) plt.xlabel(f"{tb.mnemonics['SCS_SA3']['source']}{tb.mnemonics['SCS_SA3']['key']}") plt.ylabel('density') plt.title(self.plot_title) def xgm_filter(self, xgm_low=-np.inf, xgm_high=np.inf): """ Filters the data by train. If one pulse within a train has an SASE3 SCS XGM value below xgm_low or above xgm_high, that train will be dropped from the dataset. inputs: xgm_low: low threshold value xgm_high: high threshold value """ if self.xgm is None: self.load_xgm() if self.isDark: warnings.warn(f'This run was loaded as dark. Filtering on xgm makes no sense. Aborting') return self.xgm_low = xgm_low self.xgm_high = xgm_high filter_mask = (self.xgm > self.xgm_low) * (self.xgm < self.xgm_high) if self.filter_mask: self.filter_mask = self.filter_mask*filter_mask else: self.filter_mask = filter_mask valid = filter_mask.prod('dim_0').astype(bool) xgm_valid = self.xgm.where(valid) xgm_valid = xgm_valid.dropna('trainId') self.scan = self.scan.sel({'trainId': xgm_valid.trainId}) nrejected = len(self.run.train_ids) - len(self.scan) print((f'Rejecting {nrejected} out of {len(self.run.train_ids)} trains due to xgm ' f'thresholds: [{self.xgm_low}, {self.xgm_high}]')) def load_geom(self, geopath=None, quad_pos=None): """ Loads and return the DSSC geometry. inputs: geopath: path to the h5 geometry file. If None uses a default file. quad_pos: list of quadrants tuple position. If None uses a default position. output: return the loaded geometry """ if quad_pos is None: quad_pos = [(-124.100, 3.112), # TR (-133.068, -110.604), # BR ( 0.988, -125.236), # BL ( 4.528, -4.912) # TL ] if geopath is None: geopath = '/gpfs/exfel/sw/software/git/EXtra-geom/docs/dssc_geo_june19.h5' self.geom = DSSC_1MGeometry.from_h5_file_and_quad_positions(geopath, quad_pos) return self.geom def load_mask(self, fname, plot=True): """ Load a DSSC mask file. input: fname: string of the filename of the mask file plot: if True, the loaded mask is plotted """ dssc_mask = imread(fname) dssc_mask = dssc_mask.astype(float)[..., 0] // 255 dssc_mask[dssc_mask==0] = np.nan self.mask = dssc_mask if plot: plt.figure() plt.imshow(self.mask) def create_virtual_dssc_datasets(self, run, path=''): """ Create virtual datasets for each 16 DSSC modules used for the multiprocessing. input: run: extra-data run path: string where the virtual files are created output: dictionnary of key:module, value:virtual dataset filename """ vds_filenames = {} for module in tqdm(range(16)): fname = os.path.join(path, f'dssc{module}_vds.h5') if os.path.isfile(fname): os.remove(fname) vds = run.get_virtual_dataset(f'SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf', 'image.data', filename=fname) vds.file.close() # keep h5 file closed outside 'with' context vds_filenames[module] = fname return vds_filenames def binning(self, do_pulse_mean=True): """ Bin the DSSC data by the predifined scan type (DSSC.define()) using multiprocessing """ # get available memory in GB, we will try to use 80 % of it max_GB = psutil.virtual_memory().available/1024**3 print(f'max available memory: {max_GB} GB') # max_GB / (8byte * Nworker * 128px * 512px * N_pulses) self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt)) print('processing', self.chunksize, 'trains per chunk') jobs = [] for m in range(16): jobs.append(dict( module=m, fpt=self.fpt, vds=self.vds_filenames[m], chunksize=self.chunksize, scan=self.scan['scan_variable'], nbunches=self.nbunches, run_nr=self.run_nr, do_pulse_mean=do_pulse_mean )) if self.Nworker != 16: with warnings.catch_warnings(): warnings.simplefilter("default") warnings.warn(('Nworker other than 16 known to cause issue' + '(https://in.xfel.eu/gitlab/SCS/ToolBox/merge_requests/76)'), RuntimeWarning) timestamp = strftime('%X') print(f'start time: {timestamp}') with parallel_backend('loky', n_jobs=self.Nworker): module_data = Parallel(verbose=20)( delayed(process_one_module)(job) for job in tqdm(jobs) ) print('finished:', strftime('%X')) # rearange the multiprocessed data self.module_data = xr.concat(module_data, dim='module') self.module_data['run'] = self.run_nr self.module_data = self.module_data.transpose('scan_variable', 'module', 'x', 'y') if do_pulse_mean: self.module_data = xr.merge([self.module_data, self.scan.groupby('scan_variable').mean('trainId')]) elif self.xgm is not None: xgm_pumped = self.xgm[:, :self.nbunches:2].mean('trainId').to_dataset(name='xgm_pumped').rename({'dim_0':'scan_variable'}) xgm_unpumped = self.xgm[:, 1:self.nbunches:2].mean('trainId').to_dataset(name='xgm_unpumped').rename({'dim_0':'scan_variable'}) self.module_data = xr.merge([self.module_data, xgm_pumped, xgm_unpumped]) self.module_data = self.module_data.squeeze() if do_pulse_mean: self.module_data.attrs['scan_variable'] = self.scan_vname else: self.module_data.attrs['scan_variable'] = 'pulse id' def save(self, save_folder=None, overwrite=False): """ Save the crunched data. inputs: save_folder: string of the fodler where to save the data. overwrite: boolean whether or not to overwrite existing files. """ if save_folder is None: save_folder = self.save_folder if self.isDark: fname = f'run{self.run_nr}_dark.nc' # no scan else: fname = f'run{self.run_nr}.nc' # run with delay scan (change for other scan types!) save_path = os.path.join(save_folder, fname) file_exists = os.path.isfile(save_path) if not file_exists or (file_exists and overwrite): if file_exists: warnings.warn(f'Overwriting file: {save_path}') os.remove(save_path) self.module_data.to_netcdf(save_path, group='data') self.module_data.close() os.chmod(save_path, 0o664) print('saving: ', save_path) else: print('file', save_path, 'exists and overwrite is False') def load_binned(self, runNB, dark_runNB, xgm_norm = True, save_folder=None): """ load previously binned (crunched) DSSC data by DSSC.crunch() and DSSC.save() inputs: runNB: run number to load dark_runNB: run number of the corresponding dark xgm_norm: normlize by XGM data if True save_folder: path string where the crunched data are saved """ if save_folder is None: save_folder = self.save_folder self.plot_title = f'{self.proposal} run: {runNB} dark: {dark_runNB}' dark = xr.load_dataset(os.path.join(save_folder, f'run{dark_runNB}_dark.nc'), group='data', engine='netcdf4') binned = xr.load_dataset(os.path.join(save_folder, f'run{runNB}.nc'), group='data', engine='netcdf4') binned['pumped'] = (binned['pumped'] - dark['pumped'].values) binned['unpumped'] = (binned['unpumped'] - dark['unpumped'].values) if xgm_norm: binned['pumped'] = binned['pumped'] / binned['xgm_pumped'] binned['unpumped'] = binned['unpumped'] / binned['xgm_unpumped'] self.scan_points = binned['scan_variable'] self.scan_points_counts = binned['sum_count'][:, 0] self.scan_vname = binned.attrs['scan_variable'] self.scan = None self.binned = binned def plot_DSSC(self, use_mask = True, p_low = 1, p_high = 98, vmin = None, vmax = None): """ Plot pumped and unpumped DSSC images. inputs: use_mask: if True, a mask is applied on the DSSC. p_low: low percentile value to adjust the contrast scale on the unpumped and pumped image p_high: high percentile value to adjust the contrast scale on the unpumped and pumped image vmin: low value of the image scale vmax: high value of the image scale """ if use_mask: if self.mask is None: raise ValueError('No mask was loaded !') mask = self.mask mask_txt = ' masked' else: mask = 1 mask_txt = '' if self.geom is None: self.load_geom() im_pump_mean, _ = self.geom.position_modules_fast(self.binned['pumped'].mean('scan_variable')) im_unpump_mean, _ = self.geom.position_modules_fast(self.binned['unpumped'].mean('scan_variable')) self.im_pump_mean = mask*im_pump_mean self.im_unpump_mean = mask*im_unpump_mean fig = plt.figure(figsize=(9, 4)) grid = ImageGrid(fig, 111, nrows_ncols=(1,2), axes_pad=0.15, share_all=True, cbar_location="right", cbar_mode="single", cbar_size="7%", cbar_pad=0.15, ) _vmin, _vmax = np.percentile(self.im_pump_mean[~np.isnan(self.im_pump_mean)], [p_low, p_high]) if vmin is None: vmin = _vmin if vmax is None: vmax = _vmax im = grid[0].imshow(self.im_pump_mean, vmin=vmin, vmax=vmax, aspect=self.aspect) grid[0].set_title('pumped' + mask_txt) im = grid[1].imshow(self.im_unpump_mean, vmin=vmin, vmax=vmax, aspect=self.aspect) grid[1].set_title('unpumped' + mask_txt) grid[-1].cax.colorbar(im) grid[-1].cax.toggle_label(True) fig.suptitle(self.plot_title) def azimuthal_int(self, wl, center=None, angle_range=[0, 180-1e-6], dr=1, use_mask=True): """ Perform azimuthal integration of 1D binned DSSC run. inputs: wl: photon wavelength center: center of integration angle_range: angles of integration dr: dr use_mask: if True, use the loaded mask """ if self.geom is None: self.load_geom() if use_mask: if self.mask is None: raise ValueError('No mask was loaded !') mask = self.mask mask_txt = ' masked' else: mask = 1 mask_txt = '' im_pumped_arranged, c_geom = self.geom.position_modules_fast(self.binned['pumped'].values) im_unpumped_arranged, c_geom = self.geom.position_modules_fast(self.binned['unpumped'].values) im_pumped_arranged *= mask im_unpumped_arranged *= mask im_pumped_mean = im_pumped_arranged.mean(axis=0) im_unpumped_mean = im_unpumped_arranged.mean(axis=0) if center is None: center = c_geom ai = tb.azimuthal_integrator(im_pumped_mean.shape, center, angle_range, dr=dr) norm = ai(~np.isnan(im_pumped_mean)) az_pump = [] az_unpump = [] for i in tqdm(range(len(self.binned['scan_variable']))): az_pump.append(ai(im_pumped_arranged[i]) / norm) az_unpump.append(ai(im_unpumped_arranged[i]) / norm) az_pump = np.stack(az_pump) az_unpump = np.stack(az_unpump) coords = {'scan_variable': self.binned['scan_variable'], 'distance': ai.distance} azimuthal = xr.DataArray(az_pump, dims=['scan_variable', 'distance'], coords=coords) azimuthal = azimuthal.to_dataset(name='pumped') azimuthal['unpumped'] = xr.DataArray(az_unpump, dims=['scan_variable', 'distance'], coords=coords) azimuthal = azimuthal.transpose('distance', 'scan_variable') #t0 = 225.5 #azimuthal['delay'] = (t0 - azimuthal.delay)*6.6 #azimuthal['delay'] = azimuthal.delay azimuthal['delta_q (1/nm)'] = 2e-9 * np.pi * np.sin( np.arctan(azimuthal.distance * self.px_pitch_v*1e-6 / self.distance)) / wl azimuthal.attrs = self.binned.attrs self.azimuthal = azimuthal.swap_dims({'distance': 'delta_q (1/nm)'}) def plot_azimuthal_int(self, kind='difference', lim=None): """ Plot a computed azimuthal integration. inputs: kind: (str) either 'difference' or 'relative' to change the type of plot. """ fig, [ax1, ax2, ax3] = plt.subplots(nrows=3, sharex=True, sharey=True) xr.plot.imshow(self.azimuthal.pumped, ax=ax1, vmin=0, robust=True) ax1.set_title('pumped') xr.plot.imshow(self.azimuthal.unpumped, ax=ax2, vmin=0, robust=True) ax2.set_title('unpumped') if kind == 'difference': val = self.azimuthal.pumped - self.azimuthal.unpumped ax3.set_title('pumped - unpumped') elif kind == 'relative': val = (self.azimuthal.pumped - self.azimuthal.unpumped)/self.azimuthal.unpumped ax3.set_title('(pumped - unpumped)/unpumped') else: raise ValueError('kind should be either difference or relative') if lim is None: xr.plot.imshow(val, ax=ax3, robust=True) else: xr.plot.imshow(val, ax=ax3, vmin=lim[0], vmax=lim[1]) ax3.set_xlabel(self.scan_vname) fig.suptitle(f'{self.plot_title}') def plot_azimuthal_line_cut(self, data, qranges, qwidths): """ Plot line scans on top of the data. inputs: data: an azimuthal integrated xarray DataArray with 'delta_q (1/nm)' as one of its dimension. qranges: a list of q-range qwidth: a list of q-width, same length as qranges """ fig, [ax1, ax2] = plt.subplots(nrows=2, sharex=True, figsize=[8, 7]) xr.plot.imshow(data, ax=ax1, robust=True) # attributes are not propagated during xarray mathematical operation https://github.com/pydata/xarray/issues/988 # so we might not have in data the scan vaiable name anymore ax1.set_xlabel(self.scan_vname) fig.suptitle(f'{self.plot_title}') for i, (qr, qw) in enumerate(zip(qranges, qwidths)): sel = (data['delta_q (1/nm)'] > (qr - qw/2)) * (data['delta_q (1/nm)'] < (qr + qw/2)) val = data.where(sel).mean('delta_q (1/nm)') ax2.plot(data.scan_variable, val, c=f'C{i}', label=f'q = {qr:.2f}') ax1.axhline(qr - qw/2, c=f'C{i}', lw=1) ax1.axhline(qr + qw/2, c=f'C{i}', lw=1) ax2.legend() ax2.set_xlabel(self.scan_vname) # since 'self' is not pickable, this function has to be outside the DSSC class so that it can be used # by the multiprocessing pool.map function def process_one_module(job): module = job['module'] fpt = job['fpt'] vds = job['vds'] scan = job['scan'] chunksize = job['chunksize'] nbunches = job['nbunches'] do_pulse_mean = job['do_pulse_mean'] image_path = f'INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data' npulse_path = f'INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count' with h5py.File(vds, 'r') as m: all_trainIds = m['INDEX/trainId'][()] frames_per_train = m[npulse_path][()] trains_with_data = all_trainIds[frames_per_train == fpt] len_scan = len(scan.groupby(scan)) if do_pulse_mean: # create empty dataset to add actual data to module_data = xr.DataArray(np.empty([len_scan, 128, 512], dtype=np.float64), dims=['scan_variable', 'x', 'y'], coords={'scan_variable': np.unique(scan)}) module_data = module_data.to_dataset(name='pumped') module_data['unpumped'] = xr.full_like(module_data['pumped'], 0) module_data['sum_count'] = xr.DataArray(np.zeros_like(np.unique(scan)), dims=['scan_variable']) module_data['module'] = module else: scan = xr.full_like(scan, 1) len_scan = len(scan.groupby(scan)) module_data = xr.DataArray(np.empty([len_scan, int(nbunches/2), 128, 512], dtype=np.float64), dims=['scan_variable', 'pulse', 'x', 'y'], coords={'scan_variable': np.unique(scan)}) module_data = module_data.to_dataset(name='pumped') module_data['unpumped'] = xr.full_like(module_data['pumped'], 0) module_data['sum_count'] = xr.full_like(module_data['pumped'][..., 0, 0], 0) module_data['module'] = module # crunching with h5py.File(vds, 'r') as m: chunk_start = np.arange(len(all_trainIds), step=chunksize, dtype=int) trains_start = 0 # This line is the strange hack from https://github.com/tqdm/tqdm/issues/485 print(' ', end='', flush=True) for c0 in tqdm(chunk_start, desc=f'pool.map#{module:02d}', position=module): chunk_dssc = np.s_[int(c0 * fpt):int((c0 + chunksize) * fpt)] # for dssc data data = m[image_path][chunk_dssc].squeeze() data = data.astype(np.float64) n_trains = int(data.shape[0] // fpt) trainIds_chunk = np.unique(trains_with_data[trains_start:trains_start + n_trains]) trains_start += n_trains n_trains_actual = len(trainIds_chunk) coords = {'trainId': trainIds_chunk} data = np.reshape(data, [n_trains_actual, fpt, 128, 512])[:, :int(2 * nbunches)] data = xr.DataArray(data, dims=['trainId', 'pulse', 'x', 'y'], coords=coords) if do_pulse_mean: data_pumped = (data[:, ::4]).mean('pulse') data_unpumped = (data[:, 2::4]).mean('pulse') else: data_pumped = (data[:, ::4]) data_unpumped = (data[:, 2::4]) data = data_pumped.to_dataset(name='pumped') data['unpumped'] = data_unpumped data['sum_count'] = xr.full_like(data['unpumped'][..., 0, 0], fill_value=1) # grouping and summing data['scan_variable'] = scan # this only adds scan data for matching trainIds data = data.dropna('trainId') data = data.groupby('scan_variable').sum('trainId') where = {'scan_variable': data.scan_variable} for var in ['pumped', 'unpumped', 'sum_count']: module_data[var].loc[where] = module_data[var].loc[where] + data[var] for var in ['pumped', 'unpumped']: module_data[var] = module_data[var] / module_data.sum_count #module_data = module_data.drop('sum_count') if not do_pulse_mean: module_data = module_data.sum('scan_variable') module_data = module_data.rename({'pulse':'scan_variable'}) return module_data