# DSSC1module.py - forked from SCS / ToolBox (author: Loïc Le Guyader)
import multiprocessing
from time import strftime
from tqdm.auto import tqdm
import os
import warnings
import psutil
import extra_data as ed
from extra_data.read_machinery import find_proposal
import ToolBox as tb
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.patches as patches
import numpy as np
import xarray as xr
import h5py
from glob import glob
from imageio import imread
class DSSC1module:
def __init__(self, module, proposal):
""" Create a DSSC object to process 1 module of DSSC data.
inputs:
module: module number to process
proposal: (int,str) proposal number string
"""
self.module = module
if isinstance(proposal,int):
proposal = 'p{:06d}'.format(proposal)
self.runFolder = find_proposal(proposal)
self.semester = self.runFolder.split('/')[-2]
self.proposal = proposal
self.topic = self.runFolder.split('/')[-3]
self.save_folder = os.path.join(self.runFolder, 'usr/condensed_runs/')
self.px_pitch_h = 236 # horizontal pitch in microns
self.px_pitch_v = 204 # vertical pitch in microns
self.aspect = self.px_pitch_v/self.px_pitch_h # aspect ratio of the DSSC images
print('DSSC configuration')
print(f'DSSC module: {self.module}')
print(f'Topic: {self.topic}')
print(f'Semester: {self.semester}')
print(f'Proposal: {self.proposal}')
print(f'Default save folder: {self.save_folder}')
if not os.path.exists(self.save_folder):
warnings.warn(f'Default save folder does not exist: {self.save_folder}')
        self.dark_data = 0  # placeholder until a dark run is processed or loaded
        self.max_fraction_memory = 0.8  # fraction of the available memory to use at most
        self.Nworker = 10  # number of parallel worker processes
        self.rois = None  # regions of interest, to be set by the user before process()
        self.maxSaturatedPixel = 1  # ROI frames with at least this many saturated pixels are dropped
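    # A minimal usage sketch (run numbers, module number, t0 and ROI values
    # below are hypothetical):
    #
    #   mod = DSSC1module(15, 2212)
    #   mod.open_run(89)
    #   mod.process(dark_pass='mean')  # first pass: dark mean
    #   mod.process(dark_pass='std')   # second pass: dark noise map
    #   mod.compute_mask()
    #   mod.save(isDark=True)
    #   mod.open_run(90, t0=225.5)
    #   mod.rois = {'roi1': {'x': [0, 128], 'y': [0, 128]}}
    #   mod.process()
    #   mod.save()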
def open_run(self, run_nr, t0=0.0):
""" Open a run with extra-data and prepare the virtual dataset for multiprocessing
inputs:
run_nr: the run number
t0: optional t0 in mm
"""
print('Opening run data with extra-data')
self.run_nr = run_nr
self.xgm = None
self.run = ed.open_run(self.proposal, self.run_nr)
self.plot_title = f'{self.proposal} run: {self.run_nr}'
self.fpt = self.run.detector_info(f'SCS_DET_DSSC1M-1/DET/{self.module}CH0:xtdf')['frames_per_train']
self.nbunches = self.run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER', 'sase3.nPulses.value')
self.nbunches = np.unique(self.nbunches)
if len(self.nbunches) == 1:
self.nbunches = self.nbunches[0]
        else:
            warnings.warn('not all trains have the same number of DSSC frames')
            print(f'nbunches: {self.nbunches}')
            # keep the largest value
            self.nbunches = self.nbunches[-1]
print(f'DSSC frames per train: {self.fpt}')
print(f'SA3 bunches per train: {self.nbunches}')
print('Collecting DSSC module files')
self.collect_dssc_module_file()
        print('Loading XGM data')
self.xgm = self.run.get_array(tb.mnemonics['SCS_SA3']['source'],
tb.mnemonics['SCS_SA3']['key'],
roi=ed.by_index[:self.nbunches])
self.xgm = self.xgm.rename({'dim_0':'pulseId'})
self.xgm['pulseId'] = np.arange(0, 2*self.nbunches, 2)
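        # the XGM reports one value per SASE 3 pulse; the pulseId step of 2
        # presumably reflects the DSSC recording two frames per X-ray pulse,
        # so XGM values land on every other DSSC frame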
        print('Loading mono nrj data')
        self.nrj = self.run.get_array(tb.mnemonics['nrj']['source'],
                                      tb.mnemonics['nrj']['key'])
        print('Loading delay line data')
        try:
            self.delay_mm = self.run.get_array(tb.mnemonics['PP800_DelayLine']['source'],
                                               tb.mnemonics['PP800_DelayLine']['key'])
        except Exception:
            # runs without delay line data fall back to a zero delay trace
            self.delay_mm = 0*self.nrj
self.t0 = t0
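        # positionToDelay converts the delay stage position (mm) into a delay
        # (ps); for a double-pass stage this is roughly 6.67 ps per mm of
        # travel, and invert=True flips the sign convention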
self.delay_ps = tb.positionToDelay(self.delay_mm, origin=self.t0, invert=True)
def collect_dssc_module_file(self):
""" Collect the raw DSSC module h5 files.
"""
pattern = self.runFolder + f'/raw/r{self.run_nr:04d}/RAW-R{self.run_nr:04d}-DSSC{self.module:02d}-S*.h5'
self.h5list = glob(pattern)
def process(self, dark_pass=None):
""" Process DSSC data from one module using multiprocessing
dark_pass: if None, process data, if 'mean', compute the mean, if 'std', compute the std
"""
        # get the available memory in GB; we will use at most a fraction
        # max_fraction_memory of it (0.8 by default)
        max_GB = psutil.virtual_memory().available/1024**3
        print(f'max available memory: {max_GB:.1f} GB')
        # trains per chunk = fraction * max_GB / (8 bytes * Nworker * 128 px * 512 px * fpt)
        self.chunksize = int(self.max_fraction_memory*max_GB * 1024**3 // (8 * self.Nworker * 128 * 512 * self.fpt))
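        # e.g. (illustrative numbers) with 512 GiB available, Nworker = 10 and
        # fpt = 40: 0.8 * 512 GiB / (8 B * 10 * 128 * 512 * 40) ~ 2097 trains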
print('processing', self.chunksize, 'trains per chunk')
if dark_pass == 'mean':
rois = None
dark = 0
mask = 1
elif dark_pass == 'std':
dark = self.dark_data['dark_mean']
rois = None
mask = 1
elif dark_pass is None:
dark = self.dark_data['dark_mean']
rois = self.rois
mask = self.dark_data['mask']
else:
raise ValueError(f"dark_pass should be either None or 'mean' or 'std' but not {dark_pass}")
jobs = []
for m,h5fname in enumerate(self.h5list):
jobs.append(dict(
fpt=self.fpt,
module=self.module,
h5fname=h5fname,
chunksize=self.chunksize,
nbunches=self.nbunches,
workerId=m,
Nworker=self.Nworker,
dark_data=dark,
rois=rois,
mask=mask,
maxSaturatedPixel=self.maxSaturatedPixel
))
timestamp = strftime('%X')
print(f'start time: {timestamp}')
with multiprocessing.Pool(self.Nworker) as pool:
res = pool.map(process_one_module, jobs)
print('finished:', strftime('%X'))
        # rearrange the multiprocessed data: concatenating along a dummy
        # 'worker' dimension and summing merges the per-file results; the image
        # sums accumulate across workers, while the per-train ROI variables
        # cover disjoint trains, so for them the sum acts as a merge
        self.module_data = xr.concat(res, dim='worker').sum(dim='worker')
# reorder the dimension
if 'trainId' in self.module_data.dims:
self.module_data = self.module_data.transpose('trainId', 'pulseId', 'x', 'y')
else:
self.module_data = self.module_data.transpose('pulseId', 'x', 'y')
        # finalize quantities now that all partial sums are accumulated
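        # 'std_data' so far holds the accumulated sum of squared dark-subtracted
        # values; since that signal has mean ~0 for a dark run, dividing by
        # (counts - 1) and taking the square root yields the per-pixel noise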
self.module_data['std_data'] = np.sqrt(self.module_data['std_data']/(self.module_data['counts'] - 1))
self.module_data['dark_corrected_data'] = self.module_data['dark_corrected_data']/self.module_data['counts']
self.module_data['run'] = self.run_nr
if dark_pass == 'mean':
            self.dark_data = self.module_data['dark_corrected_data'].to_dataset(name='dark_mean')
self.dark_data['run'] = self.run_nr
elif dark_pass == 'std':
self.dark_data['dark_std'] = self.module_data['std_data']
            assert self.dark_data['run'] == self.run_nr, "noise map computed from a different dark run"
else:
self.module_data['xgm'] = self.xgm
self.module_data['nrj'] = self.nrj
self.module_data['delay_mm'] = self.delay_mm
self.module_data['delay_ps'] = self.delay_ps
self.module_data['t0'] = self.t0
self.plot_title = f"{self.proposal} run: {self.module_data['run'].values} dark: {self.dark_data['run'].values}"
self.module_data.attrs['plot_title'] = self.plot_title
def compute_mask(self, low=0.01, high=0.8):
""" Compute a DSSC module mask from the noise map of a dark run.
"""
if self.dark_data['dark_std'] is None:
raise ValueError('Cannot compute from from a missing dark noise map')
self.dark_data['mask_low'] = low
self.dark_data['mask_high'] = high
m_std = self.dark_data['dark_std'].mean('pulseId')
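        # the boolean sum below acts as a logical OR: mask is 0 (bad pixel)
        # where the pulse-averaged noise is above 'high' or below 'low'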
self.dark_data['mask'] = 1 - ((m_std > self.dark_data['mask_high']) + (m_std < self.dark_data['mask_low']))
def plot_module(self, plot_dark=False, low=1, high=98, vmin=None, vmax=None):
""" Plot a module.
inputs:
plot_dark: if true, plot dark instead of run data.
low: low percentile fraction of the display scale
high: high percentile fraction of the display scale
vmin: low value of the display scale, overwrites vmin computed from low
vmax: max value of the display scale, overwrites vmax computed from high
"""
if plot_dark:
mean = self.dark_data['dark_mean'].mean('pulseId')
std = self.dark_data['dark_std']
title = f"{self.proposal} dark: {self.dark_data['run'].values}"
else:
mean = self.module_data['dark_corrected_data'].mean('pulseId')
std = self.module_data['std_data']
title = self.plot_title
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, figsize=[5, 4*2.5])
        # compute the display range from the unmasked (good) pixels only;
        # the mask is cast to bool for boolean indexing
        _vmin, _vmax = np.percentile(mean.values[self.dark_data['mask'].values.astype(bool)], [low, high])
if vmin is None:
vmin = _vmin
if vmax is None:
vmax = _vmax
im = ax1.imshow(mean, vmin=vmin, vmax=vmax)
fig.colorbar(im, ax=ax1)
ax1.set_title('mean')
fig.suptitle(title)
im = ax2.imshow(std.mean('pulseId'), vmin=0, vmax=2)
fig.colorbar(im, ax=ax2)
ax2.set_title('std')
ax3.hist(std.values.flatten(), bins=200, range=[0, 2], density=True)
ax3.axvline(self.dark_data['mask_low'], ls='--', c='k')
ax3.axvline(self.dark_data['mask_high'], ls='--', c='k')
ax3.set_yscale('log')
ax3.set_ylabel('density')
ax3.set_xlabel('std values')
im = ax4.imshow(self.dark_data['mask'])
fig.colorbar(im, ax=ax4)
def save(self, save_folder=None, overwrite=False, isDark=False):
""" Save the crunched data.
inputs:
save_folder: string of the fodler where to save the data.
overwrite: boolean whether or not to overwrite existing files.
isDark: save the dark or the process data
"""
if save_folder is None:
save_folder = self.save_folder
if isDark:
fname = f'run{self.run_nr}_dark.h5' # no scan
data = self.dark_data
else:
fname = f'run{self.run_nr}.h5' # run with delay scan (change for other scan types!)
data = self.module_data
save_path = os.path.join(save_folder, fname)
file_exists = os.path.isfile(save_path)
        if not file_exists or overwrite:
if file_exists:
warnings.warn(f'Overwriting file: {save_path}')
os.remove(save_path)
data.to_netcdf(save_path, group='data')
data.close()
os.chmod(save_path, 0o664)
print('saving: ', save_path)
else:
print('file', save_path, 'exists and overwrite is False')
def load_dark(self, dark_runNB, save_folder=None):
""" Load dark data.
inputs:
save_folder: string of the folder where the data were saved.
"""
if save_folder is None:
save_folder = self.save_folder
self.run_nr = dark_runNB
self.dark_data = xr.load_dataset(os.path.join(save_folder, f'run{dark_runNB}_dark.h5'), group='data')
self.plot_title = f"{self.proposal} dark: {self.dark_data['run'].values}"
    def show_rois(self):
        """ Plot the ROIs on the mean dark-corrected image, or on the dark mean
            if no processed run data is available.
        """
        fig, ax1 = plt.subplots(nrows=1, figsize=[5, 2.5])
        try:
            ax1.imshow(self.module_data['dark_corrected_data'].mean('pulseId') * self.dark_data['mask'])
        except Exception:
            # no processed run data available, fall back to the dark mean
            ax1.imshow(self.dark_data['dark_mean'].mean('pulseId') * self.dark_data['mask'])
for r,v in self.rois.items():
rect = patches.Rectangle((v['y'][0], v['x'][0]),
v['y'][1] - v['y'][0],
v['x'][1] - v['x'][0],
linewidth=1, edgecolor='r', facecolor='none')
ax1.add_patch(rect)
fig.suptitle(self.plot_title)
# since 'self' is not picklable, this function has to be outside the DSSC1module class
# so that it can be used by the multiprocessing pool.map function
def process_one_module(job):
chunksize = job['chunksize']
Nworker = job['Nworker']
workerId = job['workerId']
dark_data = job['dark_data']
fpt = job['fpt']
module = job['module']
rois = job['rois']
mask = job['mask']
h5fname = job['h5fname']
maxSaturatedPixel = job['maxSaturatedPixel']
image_path = f"INSTRUMENT/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/data"
npulse_path = f"INDEX/SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf/image/count"
with h5py.File(h5fname, 'r') as m:
all_trainIds = m['INDEX/trainId'][()]
n_trains = len(all_trainIds)
        n_chunk = int(np.ceil(n_trains/chunksize)) + 1
        chunks = np.linspace(0, n_trains, n_chunk, endpoint=True, dtype=int)
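        # e.g. n_trains = 100 and chunksize = 40 gives n_chunk = 4 and
        # chunks = [0, 33, 66, 100], i.e. three roughly equal slices of trains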
# create empty dataset to add actual data to
module_data = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
dims=['pulseId', 'x', 'y'],
coords={'pulseId':np.arange(fpt)}).to_dataset(name='dark_corrected_data')
module_data['std_data'] = xr.DataArray(np.zeros([fpt, 128, 512], dtype=np.float64),
dims=['pulseId', 'x', 'y'])
if rois is not None:
for k in rois.keys():
            # pre-allocate with NaN so that unfilled trains are not uninitialized memory
            module_data[k] = xr.DataArray(np.full([n_trains], np.nan, dtype=np.float64),
                                          dims=['trainId'], coords={'trainId': all_trainIds})
module_data['counts'] = 0
# crunching
with h5py.File(h5fname, 'r') as m:
# This line is the strange hack from https://github.com/tqdm/tqdm/issues/485
print(' ', end='', flush=True)
        for chunk_idx, _ in enumerate(tqdm(chunks[:-1], desc=f"pool.map#{workerId:02d}")):
            # slice of DSSC frames (fpt frames per train) covered by this chunk
            chunk_dssc = np.s_[int(chunks[chunk_idx] * fpt):int(chunks[chunk_idx + 1] * fpt)]
            data = m[image_path][chunk_dssc].squeeze()
            trains = m['INDEX/trainId'][np.s_[int(chunks[chunk_idx]):int(chunks[chunk_idx + 1])]]
n_trains = len(trains)
data = data.astype(np.float64)
data = xr.DataArray(np.reshape(data, [n_trains, fpt, 128, 512]),
dims=['trainId', 'pulseId', 'x', 'y'],
coords={'trainId': trains})
temp = data - dark_data
if rois is not None:
for k,v in rois.items():
bkg = dark_data.isel({'x':slice(v['x'][0], v['x'][1]),
'y':slice(v['y'][0], v['y'][1])})
im = data.isel({'x':slice(v['x'][0], v['x'][1]),
'y':slice(v['y'][0], v['y'][1])})
smask = mask.isel({'x':slice(v['x'][0], v['x'][1]),
'y':slice(v['y'][0], v['y'][1])})
im = im.where(smask)
cim = im - bkg.where(smask)
val = cim.sum(dim=['x','y'])
                    # drop frames with too many saturated pixels (raw values above 254)
                    tokeep = (im > 254).sum(dim=['x', 'y']) < maxSaturatedPixel
                    tokeep = tokeep.assign_coords(pulseId=val.coords['pulseId'].values)
                    todrop = 1 - tokeep
Ndropped = todrop.sum().values
percent_dropped = 100*Ndropped/(n_trains * fpt)
print(f'Dropped: {Ndropped}, i.e. {percent_dropped:.2f}%')
gval = val.where(tokeep, np.nan)
                    # write this chunk's values into the pre-allocated per-train
                    # array; a plain assignment would keep only the last chunk
                    module_data[k].loc[{'trainId': gval.trainId}] = gval
module_data['dark_corrected_data'] += temp.sum(dim='trainId')
module_data['std_data'] += (temp**2).sum(dim='trainId')
module_data['counts'] += n_trains
return module_data
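# A minimal sketch of reading back data saved by DSSC1module.save()
# (the path and run number are hypothetical):
#
#   import xarray as xr
#   data = xr.load_dataset('<save_folder>/run90.h5', group='data')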