Skip to content
Snippets Groups Projects
Commit 73e78ede authored by Egor Sobolev's avatar Egor Sobolev Committed by spbonc
Browse files

Fix according refactoring, add photon rounding and extra masks

parent e8cb2bf9
No related branches found
No related tags found
1 merge request!70Autocorrelation addon
import h5py
import numpy as np
from calngUtils import shmem_utils
from karabo.bound import (
DOUBLE_ELEMENT, INT32_ELEMENT, NDARRAY_ELEMENT, STRING_ELEMENT,
VECTOR_STRING_ELEMENT)
from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT, STRING_ELEMENT
from .. import base_kernel_runner
from .base_addon import BaseCorrectionAddon
from .. import shmem_utils
def block(b):
......@@ -77,6 +80,11 @@ class Autocorrelation(BaseCorrectionAddon):
@staticmethod
def extend_device_schema(schema, prefix):
(
INT32_ELEMENT(schema)
.key(f"{prefix}.moduleNumber")
.readOnly().initialValue(-1)
.commit(),
STRING_ELEMENT(schema)
.key(f"{prefix}.boundary")
.tags("managed")
......@@ -93,6 +101,22 @@ class Autocorrelation(BaseCorrectionAddon):
.defaultValue(9.3)
.reconfigurable()
.commit(),
DOUBLE_ELEMENT(schema)
.key(f"{prefix}.roundingThreshold")
.tags("managed")
.assignmentOptional()
.defaultValue(0.7)
.reconfigurable()
.commit(),
VECTOR_STRING_ELEMENT(schema)
.key(f"{prefix}.maskPaths")
.tags("managed")
.assignmentOptional()
.defaultValue([])
.reconfigurable()
.commit(),
)
@staticmethod
......@@ -104,35 +128,51 @@ class Autocorrelation(BaseCorrectionAddon):
.commit(),
)
def __init__(self, config):
def __init__(self, device, prefix, config):
super().__init__(device, prefix, config)
self._shape = None
self._shmem_buffer = None
global cupy
import cupy
global fft
kernel_type = base_kernel_runner.KernelRunnerTypes[
device.unsafe_get("kernelType")
]
if kernel_type is base_kernel_runner.KernelRunnerTypes.CPU:
import scipy.fft as fft
else:
import cupyx.scipy.fft as fft
self._intensity_per_photon = config["intenstityPerPhoton"]
self._autocorr = AUTOCORR_FUN[config["boundary"]]
self._rounding_threshold = config["roundingThreshold"]
self._shmem_buffer = None
da = device["constantParameters.karaboDa"]
self._modno = int(da[-2:])
device.set(f"{prefix}.moduleNumber", self._modno)
self._load_mask(config["maskPaths"])
def __del__(self):
del self._shmem_buffer
super().__del__()
def _initialization(self):
global fft
if self._device.kernel_runner._gpu_based:
import cupyx.scipy.fft as fft
else:
import scipy.fft as fft
self._update_buffer()
def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
nf, nx, ny = processed_data.shape
if nf != self._nframe:
self._update_buffer()
def post_correction(self, train_id, processed_data, cell_table, pulse_table, output_hash):
shape = processed_data.shape
if shape != self._shape:
self._update_buffer(shape)
data = np.around(processed_data / self._intensity_per_photon)
data[np.isnan(data) | (data < 0.0)] = 0.0
np.around(
processed_data / self._intensity_per_photon - (self._rounding_threshold - 0.5),
out=processed_data
)
processed_data[processed_data < 0.0] = 0.0
if self._mask is not None:
processed_data[:, self._mask] = np.nan
data = processed_data.copy()
data[np.isnan(data)] = 0.0
autocorr = self._autocorr(data)
......@@ -150,23 +190,18 @@ class Autocorrelation(BaseCorrectionAddon):
else:
output_hash.set("image.autocorr", buffer_array)
delete_paths = output_hash.get("deleteThese", default=[])
output_hash.set("deleteThese",
delete_paths + ["image.autocorr"])
# def post_reshape(self, reshaped_data, cell_table, pulse_table, output_hash):
# pass
def reconfigure(self, changed_config):
if changed_config.has("intenstityPerPhoton"):
self._intensity_per_photon = changed_config["intenstityPerPhoton"]
if changed_config.has("boundary"):
self._autocorr = AUTOCORR_FUN[changed_config["boundary"]]
if changed_config.has("roundingThreshold"):
self._rounding_threshold = changed_config["roundingThreshold"]
if changed_config.has("maskPaths"):
self._load_mask(changed_config["maskPaths"])
def _update_buffer(self):
nf, _, nx, ny = self._device.input_data_shape
self._nframe = nf
shape = (nf, nx, ny)
def _update_buffer(self, shape):
self._shape = shape
if self._shmem_buffer is None:
shmem_buffer_name = self._device.getInstanceId() + ":dataOutput/autocorrelation"
......@@ -181,3 +216,16 @@ class Autocorrelation(BaseCorrectionAddon):
self._shmem_buffer.cuda_pin()
else:
self._shmem_buffer.change_shape(shape)
def _load_mask(self, paths):
self._mask = None
for fn in paths:
try:
with h5py.File(fn, 'r') as f:
mask = f["entry_1/data_1/mask"][self._modno] != 0
if self._mask is None:
self._mask = mask
else:
self._mask = self._mask | mask
except Exception as e:
self._device.logger.warn(f"Aurocorrelation addon: {e}")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment