diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py index ceba1f2415119d516ff04a0499ee3e6b497cb456..628b0822462d768dc38e88cd2b86046646c6b507 100644 --- a/src/calng/CalibrationManager.py +++ b/src/calng/CalibrationManager.py @@ -6,7 +6,10 @@ from asyncio import gather, wait_for, TimeoutError as AsyncTimeoutError from collections import defaultdict +from collections.abc import Hashable from datetime import datetime +from inspect import ismethod +from itertools import chain, repeat from traceback import format_exc from urllib.parse import urlparse import json @@ -17,11 +20,14 @@ from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future from pkg_resources import parse_version from karabo.middlelayer import ( - KaraboError, Device, DeviceClientBase, Hash, Configurable, Slot, Node, + KaraboError, Device, DeviceClientBase, Descriptor, Hash, Configurable, + Slot, Node, Type, AccessMode, AccessLevel, Assignment, DaqPolicy, State, Unit, - UInt16, Int32, UInt32, Bool, Double, String, VectorString, VectorHash, - background, callNoWait, setNoWait, sleep, instantiate, slot, coslot, - getDevice, getTopology, getConfigurationFromPast) + UInt16, UInt32, Bool, Double, String, VectorString, VectorHash, + background, call, callNoWait, setNoWait, sleep, instantiate, slot, coslot, + getDevice, getTopology, getConfiguration, getConfigurationFromPast, + get_property) +from karabo.middlelayer_api.proxy import ProxyFactory from karabo import version as karaboVersion from ._version import version as deviceVersion @@ -37,6 +43,20 @@ Device states: ''' +# Copied from karabo MDL source (location depending on version) +# Will be part of MDL's public API in 2.12 +def get_instance_parent(instance): + """Find the parent of the instance""" + parent = instance + while True: + try: + parent = next(iter(parent._parents)) + except StopIteration: + break + + return parent + + class ClassIdsNode(Configurable): correctionClass = String( displayedName='Correction class', @@ -182,11 +202,6 @@ class DeviceServerRow(Configurable): displayedName='Webserver host') -class OperatingConditionRow(Configurable): - name = String() - value = String() - - class WebserverApiNode(Configurable): statePollInterval = Double( displayedName='Status poll interval', @@ -252,75 +267,27 @@ class InstantiationOptionsNode(Configurable): accessMode=AccessMode.RECONFIGURABLE) - def _device(self): - # Pretty sure this is not supposed to work like this. - return next(iter(self._parents.keys())) +class ManagedKeysNode(Configurable): + # Keys managed on detector DAQ devices. + DAQ_KEYS = {'DataDispatcher.trainStride': 'daqTrainStride'} @UInt32( displayedName='DAQ train stride', unitSymbol=Unit.COUNT, defaultValue=5, + allowedStates=[State.ACTIVE], minInc=1) - async def trainStride(self, value): - self.trainStride = value - background(self._device()._set_on_daq( + async def daqTrainStride(self, value): + self.daqTrainStride = value + background(get_instance_parent(self)._set_on_daq( 'DataDispatcher.trainStride', value)) - @Bool( - displayedName='Corrections enabled') - async def correctionsEnabled(self, value): - self.correctionsEnabled = value - background(self._device()._set_on_corrections( - 'applyCorrection', value)) - - @UInt32( - displayedName='Preview modulo', - defaultValue=10, - minInc=1) - async def previewModulo(self, value): - self.previewModulo = value - background(self._device()._set_on_corrections( - 'preview.trainIdModulo', value)) - - @Int32( - displayedName='Preview pulse index', - defaultValue=0, - minInc=-4, - maxExc=1000) - async def previewPulseIndex(self, value): - self.previewPulseIndex = value - background(self._device()._set_on_corrections( - 'preview.pulse', value)) - - # TODO: rework after revamping constant retrieval - @VectorHash( - displayedName='Constant parameters', - rows=OperatingConditionRow) - async def constantParameters(self, value): - self.constantParameters = value - to_set = [] - for (k, v) in self.constantParameters.value: - h = Hash() - h.set('name', k) - h.set('value', v) - to_set.append(h) - background(self._device()._set_on_corrections, - 'constants.Offset.detector_parameters', to_set) - - @String( - displayedName='Pulse filter', - defaultValue='') - async def pulseFilter(self, value): - self.pulseFilter = value - background(self._device()._set_on_corrections( - 'pulseFilter', value)) - @Slot( - displayedName='Load most recent constants', - allowedStates=[State.ACTIVE]) - async def loadMostRecentConstants(self): - background(self._device()._call_on_corrections( - 'loadMostRecentConstantsWrap')) +class ManagedKeysCloneFactory(ProxyFactory): + Proxy = ManagedKeysNode + SubProxy = Configurable + ProxyNode = Node + node_factories = dict(Slot=Slot) class CalibrationManager(DeviceClientBase, Device): @@ -414,13 +381,6 @@ class CalibrationManager(DeviceClientBase, Device): 'set by local secrets file', accessMode=AccessMode.READONLY) - caldbremoteAddress = String( - displayedName='calibrationDB ZMQ address', - description='ZMQ address for calibrationDBRemote instance to connect ' - 'to, to be removed in a future release.', - accessMode=AccessMode.INITONLY, - assignment=Assignment.MANDATORY) - webserverApi = Node( WebserverApiNode, displayedName='Webserver API', @@ -453,20 +413,11 @@ class CalibrationManager(DeviceClientBase, Device): self.state = State.CHANGING background(self._instantiate_pipeline()) - outputAxisOrder = String( - displayedName='Output axis order', - description='Axes of main data output can be reordered after correction. ' - 'Choose between "pixels-fast" (memory_cell, x, y), ' - '"memorycells-fast" (x, y, memory_cell), and "no-reshape" ' - '(memory_cell, y, x)', - options=('pixels-fast','memorycells-fast','no-reshape'), - defaultValue='pixels-fast', - accessMode=AccessMode.RECONFIGURABLE) - # TODO: Inject at runtime by scanning correction device schema. - runtimeParameters = Node( - RuntimeParametersNode, - displayedName='Runtime parameters') + managed = Node( + ManagedKeysNode, + displayedName='Managed keys', + description='Properties and slots managed on devices in the pipeline.') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -531,6 +482,12 @@ class CalibrationManager(DeviceClientBase, Device): self._http_client = AsyncHTTPClient() + # Check device servers and initialize webserver access. + await self._check_servers() + + # Inject schema for configuration of managed devices. + await self._inject_managed_keys() + if self.state == State.INIT: self._set_status('Calibration manager ready') self.state = State.ACTIVE @@ -572,6 +529,178 @@ class CalibrationManager(DeviceClientBase, Device): for device_id in devices: self._check_new_device(device_id, devices[device_id, 'classId']) + async def _get_shared_keys(self, device_ids, keys): + """Find the most common property values on devices.""" + + key_values = defaultdict(list) + + for device_id in device_ids: + try: + config = await wait_for(getConfiguration(device_id), + timeout=2.0) + except AsyncTimeoutError: + # Ignore this device if the configuration can no longer + # be obtained. + continue + + for key in keys: + value = config[key] + + if isinstance(value, Hashable): + # Value must be hashable to determine the most + # common one below. + key_values[key].append(value) + + return {key: max(set(values), key=values.count) for key, values + in key_values.items()} + + async def _inject_managed_keys(self): + """Attempt to retrieve the correction device's schema and insert + part of it as managed keys. + """ + + correction_device_servers = [ + server for _, server, _, _, _, _ in self.moduleGroups.value] + + up_corr_servers = await self._get_servers_in_state( + 'up', servers=correction_device_servers) + + if up_corr_servers: + # At least one correction server is up. + corr_server = next(iter(up_corr_servers)) + else: + # None of the correction servers is up, try to start the + # first one. + corr_server = correction_device_servers[0] + + try: + await wait_for(self._ensure_server_state(corr_server, 'up'), + timeout=self.webserverApi.upTimeout.value) + except AsyncTimeoutError: + self._set_fatal(f'Could not bring up correction device server ' + f'`{corr_server}` for schema retrieval') + return + + # Obtain the device schema from a correction device server. + managed_schema, _, _ = await call(corr_server, 'slotGetClassSchema', + self._correction_class_id) + + if managed_schema.name != self._correction_class_id: + self._set_fatal( + f'Correction class ID `{self._correction_class_id}` not known ' + f'or loadable by device server `{corr_server}`') + return + + # Collect the keys to be managed and build a nested hash + # expressing its hierarchy, leafs are set to None. + managed_keys = set(managed_schema.hash['managedKeys', 'defaultValue']) + managed_tree = Hash(*chain.from_iterable( + zip(managed_keys, repeat(None, len(managed_keys))))) + managed_paths = set(managed_tree.paths()) + + # Reduce the correction schema to the managed paths. + managed_hash = managed_schema.hash + for path in managed_hash.paths(): + if path not in managed_paths: + del managed_hash[path] + + # Check for current values of managed keys on DAQ devices and + # update schema correspondingly. + if self._daq_device_ids: + daq_vals = await self._get_shared_keys( + self._daq_device_ids, ManagedKeysNode.DAQ_KEYS.keys()) + + for orig_key, managed_key in ManagedKeysNode.DAQ_KEYS.items(): + managed_hash[managed_key, 'defaultValue'] = daq_vals[orig_key] + + # Check for current values of managed keys on correction + # devices and update schema correspondingly. + corr_vals = await self._get_shared_keys( + self._correction_device_ids, managed_keys) + + for key, value in corr_vals.items(): + managed_hash[key, 'defaultValue'] = value + + # Retrieve the attributes on the current managed node. The + # original implementation of toSchemaAndAttrs in the Node's + # superclass Descriptor is used to avoid Node-specific + # attributes in the attributes that are not valid in the + # property definition. + # The value are then obtained from the Node object again since + # enums are converted to their values by toSchemaAndAttrs, which + # in turn is not valid for property definition. + _, attrs = Descriptor.toSchemaAndAttrs(self.__class__.managed, + None, None) + managed_node_attrs = {key: getattr(self.__class__.managed, key) + for key in attrs.keys()} + + # Build a proxy from the managed schema, and create a new node + # based on it with the original node attributes. This code is + # heavily inspired by deviceClone. + managed_node = Node( + ManagedKeysCloneFactory.createProxy(managed_schema), + **managed_node_attrs) + + # Walk the managed tree to and sanitize all descriptors to our + # specifications. + def _sanitize_node(parent, tree, prefix=''): + for key, value in tree.items(): + # Fetch the descriptor class, not its instance! + descr = getattr(parent.cls, key) + + full_key = f'{prefix}.{key}' if prefix else key + + if isinstance(descr, Node): + _sanitize_node(descr, value, full_key) + + elif isinstance(descr, Slot): + async def _managed_slot_called(parent, fk=full_key): + background(self._call_on_corrections(fk)) + + _managed_slot_called.__name__ = f'managed.{full_key}' + descr.__call__(_managed_slot_called) + + # Managed slots can only be called in the ACTIVE + # state. + descr.allowedStates = [State.ACTIVE] + + elif isinstance(descr, Type): + # Regular property. + + if descr.accessMode == AccessMode.RECONFIGURABLE: + # Add a callback only if the original descriptor + # is reconfigurable. + + async def _managed_prop_changed(parent, v, k=key, + fk=full_key): + setattr(parent, k, v) + + if self.state != State.INIT: + # Do not propagate updates during injection. + background(self._set_on_corrections(fk, v)) + + descr.__call__(_managed_prop_changed) + + # Managed properties are always optional, + # reconfigurable and may only be changed in the + # ACTIVE state. + descr.assignment = Assignment.OPTIONAL + descr.accessMode = AccessMode.RECONFIGURABLE + descr.allowedStates = [State.ACTIVE] + + else: + self.logger.warn(f'Encountered unknown descriptor type ' + f'{type(descr)}') + + _sanitize_node(managed_node, managed_tree) + + # Inject the newly prepared node for managed keys. + self.__class__.managed = managed_node + await self.publishInjectedParameters() + self._managed_keys = managed_keys + + self.logger.debug('Managed schema injected') + def _set_status(self, text, level=logging.INFO): """Add and log a status message. @@ -635,7 +764,7 @@ class CalibrationManager(DeviceClientBase, Device): return body['servers'][0] - async def _get_servers_in_state(self, state): + async def _get_servers_in_state(self, state, servers=None): """List all servers in a particular state. Args: @@ -646,7 +775,8 @@ class CalibrationManager(DeviceClientBase, Device): """ servers = await gather(*[self._get_server_info(name) - for name in self._server_hosts.keys()]) + for name in self._server_hosts.keys() + if servers is None or name in servers]) return {server['karabo_name'] for server in servers if server['status'].startswith(state)} @@ -871,11 +1001,6 @@ class CalibrationManager(DeviceClientBase, Device): return self._set_error('Request unexpectedly failed while ' 'checking device server state', e) - self._daq_device_ids.clear() - self._domain_device_ids.clear() - self._correction_device_ids.clear() - await self._check_topology() - # Class and device ID templates per role. class_ids = {} device_id_templates = {} @@ -929,17 +1054,20 @@ class CalibrationManager(DeviceClientBase, Device): input_source_by_module[vname] = input_source config = Hash() + + # Legacy keys for calibrationBase. config['det_type'] = self.detectorType config['det_identifier'] = self.detectorIdentifier config['da_name'] = aggregator - config['caldb_zmq_interface'] = self.caldbremoteAddress config['dataInput.connectedOutputChannels'] = [input_channel] config['fastSources'] = [input_source] - config['dataFormat.outputImageDtype'] = 'float16' - config['dataFormat.outputAxisOrder'] = self.outputAxisOrder - config['dataFormat.pixelsX'] = 512 - config['dataFormat.pixelsY'] = 128 - config['dataFormat.memoryCells'] = 400 + + # Add managed keys. + for key in self._managed_keys: + value = get_property(self, f'managed.{key}') + + if not ismethod(value): + config[key] = value if not await self._instantiate_device( server_by_group[group], class_ids['correction'], device_id,