Skip to content
Snippets Groups Projects

Handle multi-pulse data

Merged Danilo Enoque Ferreira de Lima requested to merge pulses into main
1 file
+ 11
8
Compare changes
  • Side-by-side
  • Inline
+ 11
8
@@ -478,7 +478,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
@@ -478,7 +478,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
self.mean = dict()
self.mean = dict()
self.std = dict()
self.std = dict()
def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False) -> np.ndarray:
def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: List[int]=[0]) -> np.ndarray:
"""
"""
Get a dictionary with the channel names for the inut low resolution data and output
Get a dictionary with the channel names for the inut low resolution data and output
only the relevant input data in an array.
only the relevant input data in an array.
@@ -487,8 +487,9 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
@@ -487,8 +487,9 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
X: Dictionary with keys named channel_{i}_{k},
X: Dictionary with keys named channel_{i}_{k},
where i is a number between 1 and 4 and k is a letter between A and D.
where i is a number between 1 and 4 and k is a letter between A and D.
keep_dictionary_structure: Whether to concatenate all channels, or keep them as a dictionary.
keep_dictionary_structure: Whether to concatenate all channels, or keep them as a dictionary.
 
pulse_spacing: Distances between pulses in multi-pulse data. If there is only one pulse, set it to a list containing only the element zero.
Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
Returns: Concatenated and pre-processed low-resolution data of shape (train_id, pulse_id, features).
"""
"""
if self.tof_start is None:
if self.tof_start is None:
raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
@@ -496,12 +497,13 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
@@ -496,12 +497,13 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
if self.delta_tof is not None:
if self.delta_tof is not None:
first = max(0, self.tof_start - self.delta_tof)
first = max(0, self.tof_start - self.delta_tof)
last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
y = {channel: item[:, first:last] for channel, item in X.items()}
y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing], axis=1)
 
for channel, item in X.items()}
if not keep_dictionary_structure:
if not keep_dictionary_structure:
selected = list(y.values())
selected = list(y.values())
if self.poly:
if self.poly:
selected += [np.sqrt(np.fabs(v)) for v in y.values()]
selected += [np.sqrt(np.fabs(v)) for v in y.values()]
return np.concatenate(selected, axis=1)
return np.concatenate(selected, axis=-1)
return y
return y
def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int:
def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int:
@@ -992,7 +994,7 @@ class Model(TransformerMixin, BaseEstimator):
@@ -992,7 +994,7 @@ class Model(TransformerMixin, BaseEstimator):
#fig.savefig("tmp.png")
#fig.savefig("tmp.png")
#plt.close(fig)
#plt.close(fig)
Hmod = np.real(np.absolute(H))
#Hmod = np.real(np.absolute(H))
Gdir = np.fft.fftshift(np.fft.ifft(G))
Gdir = np.fft.fftshift(np.fft.ifft(G))
self.wiener_filter = Gdir
self.wiener_filter = Gdir
self.wiener_filter_ft = G
self.wiener_filter_ft = G
@@ -1005,7 +1007,7 @@ class Model(TransformerMixin, BaseEstimator):
@@ -1005,7 +1007,7 @@ class Model(TransformerMixin, BaseEstimator):
energy_mu = np.sum(e_axis*hmod)/np.sum(hmod)
energy_mu = np.sum(e_axis*hmod)/np.sum(hmod)
energy_var = np.sum(((e_axis - energy_mu)**2)*hmod)/np.sum(hmod)
energy_var = np.sum(((e_axis - energy_mu)**2)*hmod)/np.sum(hmod)
self.resolution = np.sqrt(energy_var)
self.resolution = np.sqrt(energy_var)
print("Resolution:", self.resolution)
#print("Resolution:", self.resolution)
# get intensity effect
# get intensity effect
intensity = np.sum(z, axis=1)
intensity = np.sum(z, axis=1)
@@ -1069,13 +1071,14 @@ class Model(TransformerMixin, BaseEstimator):
@@ -1069,13 +1071,14 @@ class Model(TransformerMixin, BaseEstimator):
"""Get KDE for the predicted intensity."""
"""Get KDE for the predicted intensity."""
return self.kde_intensity
return self.kde_intensity
def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: List[int]=[0]) -> Dict[str, np.ndarray]:
"""
"""
Predict a high-resolution spectrum from a low resolution given one.
Predict a high-resolution spectrum from a low resolution given one.
The output includes the uncertainty in its second and third entries of the first dimension.
The output includes the uncertainty in its second and third entries of the first dimension.
Args:
Args:
low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
 
pulse_spacing: Where each pulse starts, relative to the first pulse (which is at 0). For single pulse data, set this to [0].
Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
the expected prediction in key "expected", the stat. uncertainty in key "unc" and
the expected prediction in key "expected", the stat. uncertainty in key "unc" and
@@ -1086,7 +1089,7 @@ class Model(TransformerMixin, BaseEstimator):
@@ -1086,7 +1089,7 @@ class Model(TransformerMixin, BaseEstimator):
#t += [time_ns()*1e-9]
#t += [time_ns()*1e-9]
#n += ["Initial"]
#n += ["Initial"]
low_res_pre = self.x_select.transform(low_res_data)
low_res_pre = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
#t += [time_ns()*1e-9]
#t += [time_ns()*1e-9]
#n += ["Select"]
#n += ["Select"]
Loading