From 7f02cc4e873411423ffd03cad59c6fdf069675cf Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Mon, 9 Oct 2023 16:12:12 +0200 Subject: [PATCH] Automated property definition. --- pes_to_spec/model.py | 52 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index a58e954..02512dc 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -603,6 +603,41 @@ class UncorrelatedDeviation(OutlierMixin, BaseEstimator): """ return accuracy_score(y, self.predict(X), sample_weight=sample_weight) +def get_model_with_resolution(low_res_data: Dict[str, np.ndarray], + high_res_data: np.ndarray, + high_res_photon_energy: np.ndarray, + pulse_energy: np.ndarray, + **kwargs) -> Tuple[Model, float]: + """ + Create a model to obtain the resolution and then use the discovered resolution + to update the model. + + Args: + low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`, + where i is a number between 1 and 4 and k is a letter between A and D. + For each dictionary entry, a numpy array is expected with shape + (train_id, ToF channel). + high_res_data: Reference high resolution data with a one-to-one match to the + low resolution data in the train_id dimension. Shape (train_id, ToF channel). + high_res_photon_energy: Photon energy axis for the high-resolution data. + pulse_energy: XGM intensity. + **kwargs: Other arguments sent to the Model constructor. + This includes the `channels` argument with a list of channels, for example. + + Returns: Model and resolution. + """ + #kwargs_no_smear = {k: v for k, v in kwargs.items() + # if k != "high_res_fwhm"} + #model = Model(**kwargs_no_smear) + #model.fit(low_res_data, high_res_data, high_res_photon_energy, pulse_energy=pulse_energy) + #resolution = model.resolution + #model = Model(high_res_fwhm=resolution*0.5, + # **kwargs_no_smear) + model = Model(**kwargs) + model.fit(low_res_data, high_res_data, high_res_photon_energy, pulse_energy=pulse_energy) + resolution = model.resolution + return model, resolution + class Model(TransformerMixin, BaseEstimator): """ Object representing a previous fit of the model to be used to predict high-resolution @@ -656,9 +691,9 @@ class Model(TransformerMixin, BaseEstimator): elif model_type == "bnn_rvm": self.fit_model = BNNModel(n_epochs=n_bnn_epochs, rvm=True) elif model_type == "ridge": - self.fit_model = MultiOutputRidgeWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) + self.fit_model = MultiOutputRidgeWithStd(BayesianRidge(tol=1e-8, verbose=True), n_jobs=8) elif model_type == "ard": - self.fit_model = MultiOutputGenericWithStd(ARDRegression(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) + self.fit_model = MultiOutputGenericWithStd(ARDRegression(tol=1e-8, verbose=True), n_jobs=8) self.model_type = model_type self.n_obs = 0 @@ -731,7 +766,7 @@ class Model(TransformerMixin, BaseEstimator): """ self.x_select.debug_peak_finding(low_res_data, filename) - def preprocess_high_res(self, high_res_data: np.ndarray) -> np.ndarray: + def preprocess_high_res(self, high_res_data: np.ndarray, resolution: Optional[float]=None) -> np.ndarray: """ Preprocess high-resolution data to remove high requency components. @@ -740,7 +775,10 @@ class Model(TransformerMixin, BaseEstimator): Returns: Smoothened spectrum. """ - return self.y_model['smoothen'].transform(high_res_data) + if resolution is None: + return self.y_model['smoothen'].transform(high_res_data) + s = HighResolutionSmoother(resolution) + return s.fit_transform(high_res_data, energy=self.get_energy_values()) def uniformize(self, intensity: np.ndarray) -> np.ndarray: """ @@ -810,7 +848,7 @@ class Model(TransformerMixin, BaseEstimator): pulse_energy: XGM intensity. ood: Whether to fit out-of-sample detection to test data compatibility later. - Returns: Smoothened high resolution spectrum. + Returns: Input high resolution spectrum. """ print("Checking data quality in high-resolution data.") peaks = self.count_peaks(high_res_data, high_res_photon_energy) @@ -877,7 +915,7 @@ class Model(TransformerMixin, BaseEstimator): # n: noise (uncertainty) # e: energy # true signal (as far as we can get -- it is smoothened, but this is the model target) - y = high_res[inliers & filter_hr] + y = high_res_data[inliers & filter_hr] y_pred, n = self.fit_model.predict(x_t[inliers & filter_hr], return_std=True) y_hat = self.y_model['pca'].inverse_transform(y_pred) @@ -952,7 +990,7 @@ class Model(TransformerMixin, BaseEstimator): print("End of fit.") - return high_res.reshape((B, P, -1)) + return high_res_data.reshape((B, P, -1)) def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> Dict[str, np.ndarray]: """ -- GitLab