Skip to content
Snippets Groups Projects
Commit e73a5bc9 authored by Egor Sobolev's avatar Egor Sobolev
Browse files

Merge branch 'fix/selection-based-on-cellid'

parents fc570222 f6993f23
No related branches found
No related tags found
1 merge request!814[AGIPD][CORRECT] Select frames using cellId instead of position
...@@ -101,7 +101,7 @@ install_requires = [ ...@@ -101,7 +101,7 @@ install_requires = [
"tabulate==0.8.6", "tabulate==0.8.6",
"traitlets==4.3.3", "traitlets==4.3.3",
"xarray==2022.3.0", "xarray==2022.3.0",
"EXtra-redu==0.0.7", "EXtra-redu==0.0.8",
"rich==12.6.0", "rich==12.6.0",
"httpx==0.23.0", "httpx==0.23.0",
] ]
......
...@@ -316,17 +316,21 @@ class CellSelection: ...@@ -316,17 +316,21 @@ class CellSelection:
raise NotImplementedError raise NotImplementedError
def get_cells_on_trains( def get_cells_on_trains(
self, train_sel: np.ndarray, nfrm: np.ndarray, cm: int = 0 self, train_sel: np.ndarray, nfrm: np.ndarray,
cellid: np.ndarray, cm: int = 0
) -> np.array: ) -> np.array:
"""Returns mask of cells selected for processing """Returns mask of cells selected for processing
:param train_sel: list of a train ids selected for processing :param train_sel: list of a train ids selected for processing
:param nfrm: the number of frames expected for every train in :param nfrm: the number of frames expected for every train in
the list `train_sel` the list `train_sel`
:param cellid: array of cell IDs in the same sequence as images to
filter
:param cm: flag indicates the final selection or interim selection :param cm: flag indicates the final selection or interim selection
for common-mode correction for common-mode correction
:returns:
:return: boolean array with flags indicating images for processing - boolean array with flags indicating images for processing
- integer array with number of selected frames in trains
""" """
raise NotImplementedError raise NotImplementedError
...@@ -337,17 +341,6 @@ class CellSelection: ...@@ -337,17 +341,6 @@ class CellSelection:
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod
def _sel_for_cm(flag, flag_cm, cm):
if cm == CellSelection.CM_NONE:
return flag
elif cm == CellSelection.CM_PRESEL:
return flag_cm
elif cm == CellSelection.CM_FINSEL:
return flag[flag_cm]
else:
raise ValueError("param 'cm' takes only 0,1,2")
class AgipdCorrections: class AgipdCorrections:
...@@ -535,7 +528,6 @@ class AgipdCorrections: ...@@ -535,7 +528,6 @@ class AgipdCorrections:
n_valid_trains = len(valid_train_ids) n_valid_trains = len(valid_train_ids)
data_dict["n_valid_trains"][0] = n_valid_trains data_dict["n_valid_trains"][0] = n_valid_trains
data_dict["valid_trains"][:n_valid_trains] = valid_train_ids data_dict["valid_trains"][:n_valid_trains] = valid_train_ids
data_dict["nimg_in_trains"][:n_valid_trains] = nimg_in_trains
if "AGIPD500K" in agipd_base: if "AGIPD500K" in agipd_base:
agipd_comp = components.AGIPD500K(im_dc) agipd_comp = components.AGIPD500K(im_dc)
...@@ -550,8 +542,11 @@ class AgipdCorrections: ...@@ -550,8 +542,11 @@ class AgipdCorrections:
cm = (self.cell_sel.CM_NONE if apply_sel_pulses cm = (self.cell_sel.CM_NONE if apply_sel_pulses
else self.cell_sel.CM_PRESEL) else self.cell_sel.CM_PRESEL)
img_selected = self.cell_sel.get_cells_on_trains( cellid = np.squeeze(im_dc[agipd_base, "image.cellId"].ndarray())
np.array(valid_train_ids), nimg_in_trains, cm=cm)
img_selected, nimg_in_trains = self.cell_sel.get_cells_on_trains(
np.array(valid_train_ids), nimg_in_trains, cellid, cm=cm)
data_dict["nimg_in_trains"][:n_valid_trains] = nimg_in_trains
frm_ix = np.flatnonzero(img_selected) frm_ix = np.flatnonzero(img_selected)
data_dict["cm_presel"][0] = (cm == self.cell_sel.CM_PRESEL) data_dict["cm_presel"][0] = (cm == self.cell_sel.CM_PRESEL)
...@@ -1022,11 +1017,13 @@ class AgipdCorrections: ...@@ -1022,11 +1017,13 @@ class AgipdCorrections:
ntrains = data_dict["n_valid_trains"][0] ntrains = data_dict["n_valid_trains"][0]
train_ids = data_dict["valid_trains"][:ntrains] train_ids = data_dict["valid_trains"][:ntrains]
nimg_in_trains = data_dict["nimg_in_trains"][:ntrains] nimg_in_trains = data_dict["nimg_in_trains"][:ntrains]
cellid = data_dict["cellId"][:n_img]
# Initializing can_calibrate array # Initializing can_calibrate array
can_calibrate = self.cell_sel.get_cells_on_trains( can_calibrate, nimg_in_trains = self.cell_sel.get_cells_on_trains(
train_ids, nimg_in_trains, cm=self.cell_sel.CM_FINSEL train_ids, nimg_in_trains, cellid, cm=self.cell_sel.CM_FINSEL
) )
data_dict["nimg_in_trains"][:ntrains] = nimg_in_trains
if np.all(can_calibrate): if np.all(can_calibrate):
return n_img return n_img
...@@ -1624,6 +1621,7 @@ class CellRange(CellSelection): ...@@ -1624,6 +1621,7 @@ class CellRange(CellSelection):
self.flag_cm[:self.max_cells] = self.flag self.flag_cm[:self.max_cells] = self.flag
self.flag_cm = (self.flag_cm.reshape(-1, self.row_size).any(1) self.flag_cm = (self.flag_cm.reshape(-1, self.row_size).any(1)
.repeat(self.row_size)[:self.max_cells]) .repeat(self.row_size)[:self.max_cells])
self.sel_type = [self.flag, self.flag_cm, self.flag]
def msg(self): def msg(self):
return ( return (
...@@ -1633,10 +1631,24 @@ class CellRange(CellSelection): ...@@ -1633,10 +1631,24 @@ class CellRange(CellSelection):
) )
def get_cells_on_trains( def get_cells_on_trains(
self, train_sel: np.ndarray, nfrm: np.ndarray, cm: int = 0 self, train_sel: np.ndarray, nfrm: np.ndarray,
cellid: np.ndarray, cm: int = 0
) -> np.array: ) -> np.array:
return np.tile(self._sel_for_cm(self.flag, self.flag_cm, cm), if cm < 0 or cm > 2:
len(train_sel)) raise ValueError("param 'cm' takes only 0,1,2")
flag = self.sel_type[cm]
sel = np.zeros(np.sum(nfrm), bool)
counts = np.zeros(len(nfrm), int)
i0 = 0
for i, nfrm_i in enumerate(nfrm):
iN = i0 + nfrm_i
f = flag[cellid[i0:iN]]
sel[i0:iN] = f
counts[i] = np.sum(f)
i0 = iN
return sel, counts
def filter_trains(self, train_sel: np.ndarray): def filter_trains(self, train_sel: np.ndarray):
return train_sel return train_sel
...@@ -1670,14 +1682,11 @@ class LitFrameSelection(CellSelection): ...@@ -1670,14 +1682,11 @@ class LitFrameSelection(CellSelection):
self.use_super_selection = use_super_selection self.use_super_selection = use_super_selection
if use_super_selection == 'off': if use_super_selection == 'off':
self.cm_sel_type = SelType.ROW self.sel_type = [SelType.CELL, SelType.ROW, SelType.CELL]
self.final_sel_type = SelType.CELL
elif use_super_selection == 'cm': elif use_super_selection == 'cm':
self.cm_sel_type = SelType.SUPER_ROW self.sel_type = [SelType.CELL, SelType.SUPER_ROW, SelType.CELL]
self.final_sel_type = SelType.CELL
elif use_super_selection == 'final': elif use_super_selection == 'final':
self.cm_sel_type = SelType.SUPER_ROW self.sel_type = [SelType.SUPER_CELL, SelType.SUPER_ROW, SelType.SUPER_CELL]
self.final_sel_type = SelType.SUPER_CELL
else: else:
raise ValueError("param 'use_super_selection' takes only " raise ValueError("param 'use_super_selection' takes only "
"'off', 'cm' or 'final'") "'off', 'cm' or 'final'")
...@@ -1713,12 +1722,16 @@ class LitFrameSelection(CellSelection): ...@@ -1713,12 +1722,16 @@ class LitFrameSelection(CellSelection):
) )
def get_cells_on_trains( def get_cells_on_trains(
self, train_sel: np.ndarray, nfrm: np.ndarray, cm: int = 0 self, train_sel: np.ndarray, nfrm: np.ndarray,
cellid: np.ndarray, cm: int = 0
) -> np.array: ) -> np.array:
if cm < 0 or cm > 2:
raise ValueError("param 'cm' takes only 0,1,2")
(sel, counts), = self._sel.litframes_on_trains(
train_sel, nfrm, cellid, [self.sel_type[cm]])
cell_flags, cm_flags = self._sel.litframes_on_trains( return sel, counts
train_sel, nfrm, [self.final_sel_type, self.cm_sel_type])
return self._sel_for_cm(cell_flags, cm_flags, cm)
def filter_trains(self, train_sel: np.ndarray): def filter_trains(self, train_sel: np.ndarray):
return self._sel.filter_trains(train_sel, drop_empty=True) return self._sel.filter_trains(train_sel, drop_empty=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment