diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 695dc689a0de450137035c752e2fd207dff01b73..ac27237030cdae02c01efae1c7bf3cc82383c5bc 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -478,7 +478,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): self.mean = dict() self.std = dict() - def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False) -> np.ndarray: + def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: List[int]=[0]) -> np.ndarray: """ Get a dictionary with the channel names for the inut low resolution data and output only the relevant input data in an array. @@ -487,8 +487,9 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): X: Dictionary with keys named channel_{i}_{k}, where i is a number between 1 and 4 and k is a letter between A and D. keep_dictionary_structure: Whether to concatenate all channels, or keep them as a dictionary. + pulse_spacing: Distances between pulses in multi-pulse data. If there is only one pulse, set it to a list containing only the element zero. - Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features). + Returns: Concatenated and pre-processed low-resolution data of shape (train_id, pulse_id, features). """ if self.tof_start is None: raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.") @@ -496,12 +497,13 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): if self.delta_tof is not None: first = max(0, self.tof_start - self.delta_tof) last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof) - y = {channel: item[:, first:last] for channel, item in X.items()} + y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing], axis=1) + for channel, item in X.items()} if not keep_dictionary_structure: selected = list(y.values()) if self.poly: selected += [np.sqrt(np.fabs(v)) for v in y.values()] - return np.concatenate(selected, axis=1) + return np.concatenate(selected, axis=-1) return y def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int: @@ -992,7 +994,7 @@ class Model(TransformerMixin, BaseEstimator): #fig.savefig("tmp.png") #plt.close(fig) - Hmod = np.real(np.absolute(H)) + #Hmod = np.real(np.absolute(H)) Gdir = np.fft.fftshift(np.fft.ifft(G)) self.wiener_filter = Gdir self.wiener_filter_ft = G @@ -1005,7 +1007,7 @@ class Model(TransformerMixin, BaseEstimator): energy_mu = np.sum(e_axis*hmod)/np.sum(hmod) energy_var = np.sum(((e_axis - energy_mu)**2)*hmod)/np.sum(hmod) self.resolution = np.sqrt(energy_var) - print("Resolution:", self.resolution) + #print("Resolution:", self.resolution) # get intensity effect intensity = np.sum(z, axis=1) @@ -1069,13 +1071,14 @@ class Model(TransformerMixin, BaseEstimator): """Get KDE for the predicted intensity.""" return self.kde_intensity - def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: List[int]=[0]) -> Dict[str, np.ndarray]: """ Predict a high-resolution spectrum from a low resolution given one. The output includes the uncertainty in its second and third entries of the first dimension. Args: low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel). + pulse_spacing: Where each pulse starts, relative to the first pulse (which is at 0). For single pulse data, set this to [0]. Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing the expected prediction in key "expected", the stat. uncertainty in key "unc" and @@ -1086,7 +1089,7 @@ class Model(TransformerMixin, BaseEstimator): #t += [time_ns()*1e-9] #n += ["Initial"] - low_res_pre = self.x_select.transform(low_res_data) + low_res_pre = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing) #t += [time_ns()*1e-9] #n += ["Select"]