diff --git a/src/calng/calcat_utils.py b/src/calng/calcat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..880ed1d8ccc13b7d46b30014c3626a51f68a0b25 --- /dev/null +++ b/src/calng/calcat_utils.py @@ -0,0 +1,442 @@ +import copy +import enum +import functools +import pathlib +import threading +import typing + +import calibration_client +import h5py +import numpy as np +from calibration_client.modules import ( + Calibration, + CalibrationConstant, + CalibrationConstantVersion, + Detector, + DetectorType, + Parameter, + PhysicalDetectorUnit, +) +from karabo.bound import FLOAT_ELEMENT, NODE_ELEMENT, STRING_ELEMENT, UINT32_ELEMENT + +from . import utils + + +class ConditionNotFound(Exception): + pass + + +class DetectorNotFound(Exception): + pass + + +class ModuleNotFound(Exception): + pass + + +class CalibrationNotFound(Exception): + pass + + +class CalibrationClientConfigError(Exception): + pass + + +def _check_resp(resp, exception=Exception): + # TODO: probably verify using "info" that exception is the right one + if not resp["success"]: + # TODO: probably more types of app_info errors? + if resp["app_info"]: + if "not found" in resp["info"]: + # this was likely the exception exception + raise exception(resp["info"]) + else: + # but could also be authorization or similar issue + raise CalibrationClientConfigError(resp["app_info"]) + raise exception(resp["info"]) + + +class DetectorStandin(typing.NamedTuple): + detector_name: str + modno_to_source: dict + frames_per_train: int + module_shape: tuple + + +class OperatingConditions(dict): + # TODO: support deviation? + def encode(self): + return { + "parameters": [ + { + "parameter_name": key, + "lower_deviation_value": 0.0, + "upper_deviation_value": 0.0, + "flg_available": False, + "value": value, + } + for (key, value) in self.items() + ] + } + + def __hash__(self): + # this takes me back to pre-screening interview time... + return hash(tuple(sorted(self.items()))) + + +class BaseCalcatFriend: + _constant_enum_class = None # subclass should set + _constants_need_conditions = None # subclass should set + + @staticmethod + def add_schema(schema, prefix="constantParameters"): + """Add elements needed by this object to device's schema (expectedSchema) + + All elements added to schema go under prefix which should end with name of + node which does not exist yet. To change default values and add more fields, + extend this method in subclass. + + The values set on the device schema can handily be queried via _get. This + should only be done in helper methods generating "conditions"; see for example + AgipdCalcatFriend.dark_condition. + """ + + # Settings needed to use calibration client and find constants + ( + NODE_ELEMENT(schema).key(prefix).commit(), + NODE_ELEMENT(schema).key(f"{prefix}.calCat").commit(), + STRING_ELEMENT(schema) + .key(f"{prefix}.calCat.baseUrl") + .assignmentMandatory() + .commit(), + STRING_ELEMENT(schema) + .key(f"{prefix}.calCat.clientId") + .assignmentMandatory() + .commit(), + STRING_ELEMENT(schema) + .key(f"{prefix}.calCat.clientSecret") + .assignmentMandatory() + .commit(), + STRING_ELEMENT(schema) + .key(f"{prefix}.calCat.userEmail") + .assignmentOptional() + .defaultValue("") + .commit(), + STRING_ELEMENT(schema) + .key(f"{prefix}.calCat.caldbStore") + .displayedName("Location of caldb_store") + .assignmentOptional() + .defaultValue("/gpfs/exfel/d/cal/caldb_store") + .options("/gpfs/exfel/d/cal/caldb_store,/common/cal/caldb_store") + .commit(), + ) + + # Parameters which any detector would probably have (extend this in subclass) + # TODO: probably switch to floating point for everything, including mem cells + ( + STRING_ELEMENT(schema) + .key(f"{prefix}.detectorType") + .assignmentMandatory() + .commit(), + STRING_ELEMENT(schema) + .key(f"{prefix}.detectorName") + .assignmentMandatory() + .commit(), + UINT32_ELEMENT(schema) + .key(f"{prefix}.memoryCells") + .assignmentMandatory() + .commit(), + UINT32_ELEMENT(schema) + .key(f"{prefix}.pixelsX") + .assignmentOptional() + .defaultValue(512) + .commit(), + UINT32_ELEMENT(schema) + .key(f"{prefix}.pixelsY") + .assignmentOptional() + .defaultValue(512) + .commit(), + UINT32_ELEMENT(schema) + .key(f"{prefix}.biasVoltage") + .assignmentOptional() + .defaultValue(300) + .commit(), + ) + + def __init__( + self, + device, + prefix="constantParameters", + ): + self.device = device + self.prefix = prefix + # TODO: can constants be accessed from ONC? + self.caldb_store = pathlib.Path(self._get("calCat.caldbStore")) + if not self.caldb_store.is_dir(): + raise ValueError(f"caldb_store location '{self.caldb_store}' is not dir") + + # TODO: secret / token management + base_url = self._get("calCat.baseUrl") + self.client = calibration_client.CalibrationClient( + client_id=self._get("calCat.clientId"), + client_secret=self._get("calCat.clientSecret"), + user_email=self._get("calCat.userEmail"), + base_api_url=f"{base_url}/api/", + token_url=f"{base_url}/oauth/token", + refresh_url=f"{base_url}/oauth/token", + auth_url=f"{base_url}/oauth/authorize", + scope="public", + session_token=None, + ) + + def _get(self, key): + """Helper to get value from attached device schema""" + return self.device.get(f"{self.prefix}.{key}") + + @functools.cached_property + def detector_id(self): + resp = Detector.get_by_identifier(self.client, self._get("detectorName")) + _check_resp(resp, DetectorNotFound) + return resp["data"]["id"] + + @functools.cached_property + def detector_type_id(self): + resp = DetectorType.get_by_name(self.client, self._get("detectorType")) + _check_resp(resp, DetectorNotFound) + return resp["data"]["id"] + + # TODO: support updating mapping maybe? + # (just means del self.pdus and the properties depending on it) + @functools.cached_property + def pdus(self): + resp = PhysicalDetectorUnit.get_all_by_detector_id( + self.client, self.detector_id + ) + _check_resp(resp) + for irrelevant_key in ("detector", "detector_type", "flg_available"): + for pdu in resp["data"]: + del pdu[irrelevant_key] + return resp["data"] + + @functools.cached_property + def _karabo_da_to_float_uuid(self): + return {pdu["karabo_da"]: pdu["float_uuid"] for pdu in self.pdus} + + @functools.cached_property + def _karabo_da_to_id(self): + return {pdu["karabo_da"]: pdu["id"] for pdu in self.pdus} + + @utils.threadsafe_cache + def calibration_id(self, calibration_name: str): + resp = Calibration.get_by_name(self.client, calibration_name) + # TODO: dump name in exception + _check_resp(resp, CalibrationNotFound) + return resp["data"]["id"] + + @utils.threadsafe_cache + def condition_id(self, pdu, condition): + # modifying condition parameter messes with cache + condition_with_detector = copy.copy(condition) + condition_with_detector["Detector UUID"] = pdu + resp = self.client.search_possible_conditions_from_dict( + "", condition_with_detector.encode() + ) + _check_resp(resp, ConditionNotFound) + return resp["data"][0]["id"] + + @utils.threadsafe_cache + def constant_id(self, calibration_id, condition_id): + resp = CalibrationConstant.get_by_uk( + self.client, + calibration_id=calibration_id, + detector_type_id=self.detector_type_id, + condition_id=condition_id, + ) + _check_resp(resp) + return resp["data"]["id"] + + def get_constant_version(self, karabo_da, constant, snapshot_at=None): + # TODO: support snapshot + # TODO: support creation time + # TODO: move karabo_da into device config (no need to be able to get constants for other modules) + self.device.log.DEBUG(f"Going looking for {constant} for {karabo_da}") + if karabo_da not in self._karabo_da_to_float_uuid: + raise ModuleNotFound(f"Module map did not include {karabo_da}") + if isinstance(constant, str): + constant = self._constant_enum_class[constant] + calibration_id = self.calibration_id(constant.name) + condition = self._constants_need_conditions[constant]() + condition_id = self.condition_id( + self._karabo_da_to_float_uuid[karabo_da], condition + ) + constant_id = self.constant_id( + calibration_id=calibration_id, condition_id=condition_id + ) + + resp = CalibrationConstantVersion.get_by_uk( + self.client, + calibration_constant_id=constant_id, + physical_detector_unit_id=self._karabo_da_to_id[karabo_da], + event_at=None, + snapshot_at=None, + ) + _check_resp(resp) + timestamp = resp["data"][ + "begin_at" + ] # TODO: check which key we like (also has begin_validity_at) + file_path = ( + self.caldb_store / resp["data"]["path_to_file"] / resp["data"]["file_name"] + ) + # TODO: return time stamps and such, those are useful + # TODO: handle FileNotFoundError if we are led astray + with h5py.File(file_path, "r") as fd: + ary = np.array(fd[resp["data"]["data_set_name"]]["data"]) + return timestamp, ary + + def get_constant_version_and_call_me_back( + self, karabo_da, constant, callback, snapshot_at=None + ): + """WIP. Callback function will get constant name, constant data, and hopefully + soon also some metadata or whatever. + """ + # TODO: thread safe caching throughout this class + # TODO: do we want to use asyncio / "modern" async? + def aux(): + data = self.get_constant_version(karabo_da, constant, snapshot_at) + callback(constant, data) + + thread = threading.Thread(target=aux) + thread.start() + return thread + + +class AgipdConstants(enum.Enum): + SlopesFF = enum.auto() + ThresholdsDark = enum.auto() + Offset = enum.auto() + SlopesPC = enum.auto() + BadPixelsDark = enum.auto() + BadPixelsPC = enum.auto() + BadPixelsFF = enum.auto() + + +class AgipdCalcatFriend(BaseCalcatFriend): + _constant_enum_class = AgipdConstants + + def __init__( + self, + device, + prefix="constantParameters", + ): + super().__init__(device, prefix) + self._constants_need_conditions = { + AgipdConstants.ThresholdsDark: self.dark_condition, + AgipdConstants.Offset: self.dark_condition, + AgipdConstants.SlopesPC: self.dark_condition, + AgipdConstants.SlopesFF: self.illuminated_condition, + AgipdConstants.BadPixelsDark: self.dark_condition, + AgipdConstants.BadPixelsPC: self.dark_condition, + AgipdConstants.BadPixelsFF: self.illuminated_condition, + } + + @staticmethod + def add_schema(schema, prefix="constantParameters"): + super(AgipdCalcatFriend, AgipdCalcatFriend).add_schema(schema, prefix) + + ( + FLOAT_ELEMENT(schema) + .key(f"{prefix}.acquisitionRate") + .assignmentOptional() + .defaultValue(1.1) + .commit(), + FLOAT_ELEMENT(schema) + .key(f"{prefix}.gainSetting") + .assignmentOptional() + .defaultValue(0) + .commit(), + FLOAT_ELEMENT(schema) + .key(f"{prefix}.photonEnergy") + .assignmentOptional() + .defaultValue(9.3) + .commit(), + FLOAT_ELEMENT(schema) + .key(f"{prefix}.gainMode") + .assignmentOptional() + .defaultValue(0) + .commit(), + FLOAT_ELEMENT(schema) + .key(f"{prefix}.integrationtime") + .assignmentOptional() + .defaultValue(12) + .commit(), + ) + + def dark_condition(self): + res = OperatingConditions() + res["Pixels X"] = self._get("pixelsX") + res["Pixels Y"] = self._get("pixelsY") + res["Memory cells"] = self._get("memoryCells") + res["Acquisition rate"] = self._get("acquisitionRate") + res["Sensor Bias Voltage"] = self._get("biasVoltage") + return res + + def illuminated_condition(self): + res = OperatingConditions() + res["Memory cells"] = self._get("memoryCells") + res["Sensor Bias Voltage"] = self._get("biasVoltage") + res["Pixels X"] = self._get("pixelsX") + res["Pixels Y"] = self._get("pixelsY") + res["Source Energy"] = self._get("photonEnergy") + res["Acquisition rate"] = self._get("acquisitionRate") + res["Gain Setting"] = self._get("gainSetting") + return res + + +class DsscConstants(enum.Enum): + Offset = enum.auto() + + +class DsscCalcatFriend(BaseCalcatFriend): + _constant_enum_class = DsscConstants + + def __init__( + self, + device, + prefix="constantParameters", + ): + super().__init__(device, prefix) + self._constants_need_conditions = { + DsscConstants.Offset: self.dark_condition, + } + + @staticmethod + def add_schema(schema, prefix="constantParameters"): + super(DsscCalcatFriend, DsscCalcatFriend).add_schema(schema, prefix) + ( + FLOAT_ELEMENT(schema) + .key(f"{prefix}.pulseIdChecksum") + .assignmentOptional() + .defaultValue(2.8866323107820637e-36) + .commit(), + FLOAT_ELEMENT(schema) + .key(f"{prefix}.acquisitionRate") + .assignmentOptional() + .defaultValue(4.5) + .commit(), + FLOAT_ELEMENT(schema) + .key(f"{prefix}.encodedGain") + .assignmentOptional() + .defaultValue(67328) + .commit(), + ) + + def dark_condition(self): + res = OperatingConditions() + res["Memory cells"] = self._get("memoryCells") + res["Sensor Bias Voltage"] = self._get("biasVoltage") + res["Pixels X"] = self._get("pixelsX") + res["Pixels Y"] = self._get("pixelsY") + # res["Pulse id checksum"] = self._get("pulseIdChecksum") + # res["Acquisition rate"] = self._get("acquisitionRate") + # res["Encoded gain"] = self._get("encodedGain") + return res diff --git a/src/calng/utils.py b/src/calng/utils.py index 9b9a3786acbc4d620a93f511ae7acd7f0e8dcbbf..d28139e715140e15da007ac3eb9ac100049ed054 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -1,9 +1,43 @@ +import functools +import inspect import threading import time import timeit import numpy as np + +def threadsafe_cache(fun): + """This decorator imitates functools.cache, but threadsafer + + With multiple threads hitting a function cached by functools.cache, it is possible + to trigger recomputation. This decorator adds granular locking: each key in the + cache (derived from arguments) has its own lock. + """ + + locks = {} + results = {} + fun_sig = inspect.signature(fun) + + @functools.wraps(fun) + def aux(*args, **kwargs): + bound_args = fun_sig.bind(*args, **kwargs) + bound_args.apply_defaults() + key = bound_args.args + tuple(bound_args.kwargs.items()) + if key in results: + return results[key] + with locks.setdefault(key, threading.Lock()): + if key in results: + # someone else did this - may still be processing + return results[key] + else: + res = fun(*args, **kwargs) + results[key] = res + return res + + return aux + + _np_typechar_to_c_typestring = { "?": "bool", "B": "unsigned char", diff --git a/src/tests/test_calcat_utils.py b/src/tests/test_calcat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7aaa0fb5ea0503d1b945cc64831f78a7c6d67e69 --- /dev/null +++ b/src/tests/test_calcat_utils.py @@ -0,0 +1,169 @@ +import pathlib +import timeit + +from calng import calcat_utils +from karabo.bound import Hash, Schema + +# TODO: secrets management +_test_dir = pathlib.Path(__file__).absolute().parent +with (_test_dir / "calibration-client-secrets.txt").open("r") as fd: + base_url, client_id, client_secret, user_email = fd.read().splitlines() + + +class DummyLogger: + DEBUG = print + INFO = print + + +class Stopwatch: + def __enter__(self): + self.start_ts = timeit.default_timer() + self.running = True + return self + + def __exit__(self, t, v, tb): + self.stop_ts = timeit.default_timer() + self.running = False + + @property + def elapsed(self): + if self.running: + return timeit.default_timer() - self.start_ts + else: + return self.stop_ts - self.start_ts + + +# TODO: consider testing by attaching to real karabo.bound.PythonDevice +class DummyAgipdDevice: + device_class_schema = Schema() + + @staticmethod + def expectedParameters(expected): + calcat_utils.AgipdCalcatFriend.add_schema(expected, "constantParameters") + + def __init__(self, config): + self.schema = config + self.calibration_constant_manager = calcat_utils.AgipdCalcatFriend( + self, + "constantParameters", + ) + self.log = DummyLogger() + + def get(self, key): + return self.schema.get(key) + + +DummyAgipdDevice.expectedParameters(DummyAgipdDevice.device_class_schema) + + +class DummyDsscDevice: + device_class_schema = Schema() + + @staticmethod + def expectedParameters(expected): + # super(DummyDsscDevice, DummyDsscDevice).expectedParameters(expected) + calcat_utils.DsscCalcatFriend.add_schema(expected, "constantParameters") + + def __init__(self, config): + # TODO: check config against schema (as Karabo would) + self.schema = config + self.calibration_constant_manager = calcat_utils.DsscCalcatFriend( + self, + "constantParameters", + ) + self.log = DummyLogger() + + def get(self, key): + return self.schema.get(key) + + +DummyDsscDevice.expectedParameters(DummyDsscDevice.device_class_schema) + + +def test_agipd_constants_and_caching_and_async(): + # def test_agipd_constants(): + conf = Hash() + conf["constantParameters.calCat.baseUrl"] = base_url + conf["constantParameters.calCat.clientId"] = client_id + conf["constantParameters.calCat.clientSecret"] = client_secret + conf["constantParameters.calCat.userEmail"] = user_email + conf["constantParameters.calCat.caldbStore"] = "/gpfs/exfel/d/cal/caldb_store" + + conf["constantParameters.detectorType"] = "AGIPD-Type" + conf["constantParameters.detectorName"] = "SPB_DET_AGIPD1M-1" + conf["constantParameters.pixelsX"] = 512 + conf["constantParameters.pixelsY"] = 128 + conf["constantParameters.memoryCells"] = 352 + conf["constantParameters.acquisitionRate"] = 1.1 + conf["constantParameters.biasVoltage"] = 300 + conf["constantParameters.gainSetting"] = 0 + conf["constantParameters.photonEnergy"] = 9.2 + device = DummyAgipdDevice(conf) + + def backcall(constant_name, metadata_and_data): + # TODO: think of something reasonable to check + timestamp, data = metadata_and_data + assert data.nbytes > 1000 + + with Stopwatch() as timer_async_cold: + # TODO: put this sort of thing in BaseCalcatFriend + threads = [] + for constant in calcat_utils.AgipdConstants: + thread = device.calibration_constant_manager.get_constant_version_and_call_me_back( + "AGIPD00", constant, backcall + ) + threads.append(thread) + for thread in threads: + thread.join() + + with Stopwatch() as timer_async_warm: + threads = [] + for constant in calcat_utils.AgipdConstants: + thread = device.calibration_constant_manager.get_constant_version_and_call_me_back( + "AGIPD00", constant, backcall + ) + threads.append(thread) + for thread in threads: + thread.join() + + with Stopwatch() as timer_sync_warm: + for constant in calcat_utils.AgipdConstants: + ts, ary = device.calibration_constant_manager.get_constant_version( + "AGIPD00", + constant, + ) + assert ts is not None, "Some constants should be found" + + print(f"Cold async took {timer_async_cold.elapsed} s") + print(f"Warm async took {timer_async_warm.elapsed} s") + print(f"Warm sync took {timer_sync_warm.elapsed} s") + assert ( + timer_async_cold.elapsed > timer_async_warm.elapsed + ), "Caching should make second go faster" + assert timer_sync_warm.elapsed > timer_async_warm.elapsed, "Async should be faster" + + +def test_dssc_constants(): + conf = Hash() + + conf["constantParameters.calCat.baseUrl"] = base_url + conf["constantParameters.calCat.clientId"] = client_id + conf["constantParameters.calCat.clientSecret"] = client_secret + conf["constantParameters.calCat.userEmail"] = user_email + conf["constantParameters.calCat.caldbStore"] = "/gpfs/exfel/d/cal/caldb_store" + + conf["constantParameters.detectorType"] = "DSSC-Type" + conf["constantParameters.detectorName"] = "SCS_DET_DSSC1M-1" + conf["constantParameters.memoryCells"] = 400 + conf["constantParameters.biasVoltage"] = 100 + conf["constantParameters.pixelsX"] = 512 + conf["constantParameters.pixelsY"] = 128 + # conf["constantParameters.pulseIdChecksum"] = 2.8866323107820637e-36 + # conf["constantParameters.acquisitionRate"] = 4.5 + # conf["constantParameters.encodedGain"] = 67328 + device = DummyDsscDevice(conf) + ts, offset_map = device.calibration_constant_manager.get_constant_version( + "DSSC00", "Offset" + ) + + assert ts is not None diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py index d8e8c6ce407e70e1fb9e9386e2a44a3868051c76..91b1f280f63bb61f54e0325d94216ecfe5d90a44 100644 --- a/src/tests/test_utils.py +++ b/src/tests/test_utils.py @@ -1,7 +1,12 @@ -import numpy as np +import random +import threading +import time +import timeit +import numpy as np from calng import utils + def test_get_c_type(): assert utils.np_dtype_to_c_type(np.float16) == "half" assert utils.np_dtype_to_c_type(np.float32) == "float" @@ -16,3 +21,82 @@ def test_get_c_type(): assert utils.np_dtype_to_c_type(np.int16) == "short" assert utils.np_dtype_to_c_type(np.int32) == "int" assert utils.np_dtype_to_c_type(np.int64) == "long" + + +class TestThreadsafeCache: + def test_arg_key_wrap(self): + calls = [] + + @utils.threadsafe_cache + def fun(a, b, c=1, d=2, *args, **kwargs): + calls.append((a, b, c, d, args, kwargs)) + + # reordering kwargs /does/ matter because dicts are ordered now + # (note: functools.lru_cache doesn't sort, claims because of speed) + fun(1, 2, 3, 4, 5, six=6, seven=7) + fun(1, 2, 3, 4, 5, seven=7, six=6) + assert len(calls) == 2, "kwargs order matters" + calls.clear() + + # reordering kw-style positional args does not matter + fun(1, 2, 1, 2) + fun(a=1, c=1, b=2, d=2) + assert len(calls) == 1, "reordering regular args as kws doesn't matter" + # and omitting default values does not matter + fun(b=2, a=1) + fun(1, 2) + assert len(calls) == 1, "omitting default args doesn't matter" + + def test_threadsafeness(self): + # wow, synchronization (presumably) makes this take forever *without* the decorator... + from_was_called = [] + + base_sleep = 1 + random_sleep = 0.1 + + @utils.threadsafe_cache + def was_called(x): + time.sleep(random.random() * random_sleep + base_sleep) + from_was_called.append(x) + + threads = [] + num_threads = 1000 + letters = "abcd" + start_ts = timeit.default_timer() + for i in range(num_threads): + for l in letters: + thread = threading.Thread(target=was_called, args=(l,)) + thread.start() + threads.append(thread) + submitted_ts = timeit.default_timer() + print(f"Right after: {len(from_was_called)}") + for thread in threads: + thread.join() + stop_ts = timeit.default_timer() + total_time = stop_ts - start_ts + print(f"After join: {len(from_was_called)}") + print(f"Time to submit: {submitted_ts - start_ts}") + print(f"Wait for join: {stop_ts - submitted_ts}") + print(f"Total: {total_time}") + + # check that function was only called with each letter once + # this is where the decorator from functools will fail + assert len(from_was_called) == len( + letters + ), "Caching prevents recomputation due to threading" + + # check that the function was not locked too broadly (should run faster than sequential lower bound) + reasonable_time_to_spawn_thread = 0.45 / 1000 + cutoff = ( + len(letters) * base_sleep + reasonable_time_to_spawn_thread * num_threads + ) + print(f"Cutoff (sequential lower bound): {cutoff}") + assert ( + total_time < cutoff + ), "Locking should not be so broad as to make sequential" + print( + f"Each thread would have slept [{base_sleep}, {base_sleep + random_sleep})" + ) + + # check that time doesn't go backwards suddenly + assert total_time >= base_sleep, "These tests should measure time correctly"