Skip to content
Snippets Groups Projects
Commit 1954f34f authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Allow for different pulse spacings per channel.

parent 4db25a16
No related branches found
No related tags found
1 merge request!10Handle multi-pulse data
......@@ -478,7 +478,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
self.mean = dict()
self.std = dict()
def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: List[int]=[0]) -> np.ndarray:
def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: Optional[Dict[str, List[int]]]=None) -> np.ndarray:
"""
Get a dictionary with the channel names for the inut low resolution data and output
only the relevant input data in an array.
......@@ -493,11 +493,13 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
"""
if self.tof_start is None:
raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
if pulse_spacing is None:
pulse_spacing = {ch: [0] for ch in X.keys()}
y = X
if self.delta_tof is not None:
first = max(0, self.tof_start - self.delta_tof)
last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing], axis=1)
y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]], axis=1)
for channel, item in X.items()}
if not keep_dictionary_structure:
selected = list(y.values())
......@@ -1032,7 +1034,7 @@ class Model(TransformerMixin, BaseEstimator):
return high_res.reshape((B, P, -1))
def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: List[int]=[0]) -> Dict[str, np.ndarray]:
def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> Dict[str, np.ndarray]:
"""
Check if a new low-resolution data source is compatible with the one used in training, by
comparing the effect of the trained PCA model on it, but do it per channel.
......@@ -1055,7 +1057,7 @@ class Model(TransformerMixin, BaseEstimator):
result = {ch: is_inlier(low_res_selected[ch], ch) for ch in channels}
return result
def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: List[int]=[0]) -> np.ndarray:
def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> np.ndarray:
"""
Check if a new low-resolution data source is compatible with the one used in training, by
using a robust covariance matrix estimate of the data
......@@ -1069,7 +1071,7 @@ class Model(TransformerMixin, BaseEstimator):
low_res = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
B, P, _ = low_res.shape
pca_model = self.x_model
low_pca = pca_model.transform(low_res.reshape((B, P, -1)))
low_pca = pca_model.transform(low_res.reshape((B*P, -1)))
return self.ood['full'].predict(low_pca).reshape((B, P))
def xgm_profile(self) -> gaussian_kde:
......@@ -1080,7 +1082,7 @@ class Model(TransformerMixin, BaseEstimator):
"""Get KDE for the predicted intensity."""
return self.kde_intensity
def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: List[int]=[0]) -> Dict[str, np.ndarray]:
def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> Dict[str, np.ndarray]:
"""
Predict a high-resolution spectrum from a low resolution given one.
The output includes the uncertainty in its second and third entries of the first dimension.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment