From cf1a60f70f6d5e334e19d9d846ccd33f41d6fcaf Mon Sep 17 00:00:00 2001
From: Mads Jakobsen <mads.jakobsen@xfel.eu>
Date: Thu, 13 Mar 2025 10:41:39 +0100
Subject: [PATCH] added abstraction layer to handle adding plots to hash

---
 .../src/onlinemid_karabo/processor.py         | 50 ++++++++++++-------
 python_package/src/onlinemid/__init__.py      | 33 ++++++------
 .../src/onlinemid/utils/image_plotting.py     | 30 +++++++++++
 3 files changed, 81 insertions(+), 32 deletions(-)

diff --git a/karabo_devices/shared_mem_processor/src/onlinemid_karabo/processor.py b/karabo_devices/shared_mem_processor/src/onlinemid_karabo/processor.py
index 2f0d8f0..1440650 100644
--- a/karabo_devices/shared_mem_processor/src/onlinemid_karabo/processor.py
+++ b/karabo_devices/shared_mem_processor/src/onlinemid_karabo/processor.py
@@ -39,6 +39,7 @@ from onlinemid.utils.karabo import hash_from_dict
 from onlinemid.utils.agipd import generate_test_data
 from onlinemid.utils.karabo import Periodic_Output
 from onlinemid.dataplotters.agipd_module import AgipdModulePlotter
+from onlinemid.utils.plotting import DataToPlotToHash
 
 from onlinexpcsutils import onlinexpcsutils
 import runningstats
@@ -151,7 +152,7 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         # Periodic_Output
         self.periodic_output = Periodic_Output() # default time at 300 sec
 
-        self.onlineImagePlotter = AgipdModulePlotter()
+        self.onlineImagePlotter = DataToPlotToHash()
 
         # expose the get xpcs settings function
         self.KARABO_SLOT(self.get_OnlineMIDSettings)
@@ -268,7 +269,7 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         print(strUpdateSettings)
         '''
            
-    def get_detector_data_list_from_matched_sources_and_the_gain_stage_data(self, sources):
+    def get_calNG_data(self, sources):
         '''
         assemble matched sources into a list of pairs (module name, module data)
         of types:
@@ -279,26 +280,27 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         
         # variable that will contain the sum of the np arrays that have the pixel count per gain stage
         numPixelPerGainStage = None
+
         for source, (data, timestamp) in sources.items():   
             
             unique_string_id = agipd_module_name_from_mid_agipd_source(source)
 
             self._shmem_handler.dereference_shmem_handles(data) 
-            
             detectorData = data['image']['data']
             
             # ensure 3 dimensional
             if len(detectorData.shape) == 2:
                 (slowDim, fastDim) = detectorData.shape
                 detectorData.reshape(shape=(1,slowDim,fastDim))
-            
+
+            detector_data_list.append((unique_string_id, detectorData))
+
+
             if numPixelPerGainStage is None:
                 numPixelPerGainStage = data['numPixelsPerGainStage']
             else:
                 numPixelPerGainStage += data['numPixelsPerGainStage']
-            
-            detector_data_list.append((unique_string_id, detectorData))
- 
+             
         return detector_data_list, numPixelPerGainStage
                     
     
@@ -344,7 +346,7 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
                 my_detector_data_list.append((agipd_module_name_from_agipd_module_number(15) , generate_test_data(q=q)))
             
             OnlineMIDProcessor.my_generated_train_id += 1
-            self.handle_detector_data_list(OnlineMIDProcessor.my_generated_train_id, my_detector_data_list)
+            self.handle_detector_data_list(OnlineMIDProcessor.my_generated_train_id, my_detector_data_list, numPixelPerGainStage = np.random.randint(0,high=1000, size=150).reshape(50,3))
             time.sleep(0.4) 
 
         self.log.WARN("Generating sample output has ended")
@@ -359,8 +361,16 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         message = f"processing detector data from modules {[pair[0] for pair in detector_data_list]} of shape {[pair[1].shape for pair in detector_data_list]}"
         self.periodic_output.log(self.log.INFO, message)
 
-        print("asking to plot")
-        current_img = self.onlineImagePlotter.plot(train_id, detector_data_list[0][1][0])
+        #print("asking to plot")
+        #current_img = self.onlineImagePlotter.plot(train_id, detector_data_list[0][1][0])
+
+        plot_hash = Hash()
+
+        plot_hash = self.onlineImagePlotter.add_to_hash(plot_hash, 'data.current_det_image', AgipdModulePlotter(), train_id, detector_data_list[0][1][0])
+        
+        #if numPixelPerGainStage is not None:
+        #    plot_hash = self.onlineImagePlotter.add_to_hash(plot_hash, 'my_added_key_pixel_per_gain_stage', PixelperGainStagePlotter(), numPixelPerGainStage)
+
         
 
         # time_to_handle_list_0 = time.perf_counter()
@@ -407,7 +417,7 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         # plot accumulating sum and lit pixel counter for panel data
         # time_to_handle_list_3 = time.perf_counter()
         
-        plot_hash = Hash()
+       
         
         #        npArray_sum = onlinexpcsutils.get_accumulating_sum(self.xpcs_processing_memory, selected_module_key)
         #        if npArray_sum is not None:
@@ -418,9 +428,9 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         #            plot_hash['data.accumulating_lit_bunch'] = ImageData(npArray_lit_pixel_counter)
 
         #npArray_currentImage = onlinexpcsutils.get_current_Image(self.online_processing_memory, selected_module_key)
-        if current_img is not None:
-            print("Current imnage is not none:")
-            plot_hash['data.current_det_image'] = ImageData(current_img)
+        #if current_img is not None:
+        #    print("Current imnage is not none:")
+        #    plot_hash['data.current_det_image'] = ImageData(current_img)
 
 
 
@@ -431,8 +441,11 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         # print("time_to_handle_list 3", time.perf_counter() - time_to_handle_list_3)
         # send hash
 
+
+
         # time_to_handle_list_4 = time.perf_counter()
 
+        ## send plot hash
         my_timestamp = Timestamp(Epochstamp(), Trainstamp(train_id))
         myname = self.get('deviceId')
         if self.assembled_output is not None:
@@ -440,6 +453,8 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
                 plot_hash, ChannelMetaData(f'{myname}:plots', my_timestamp), copyAllData=False
                 )
         self.assembled_output.update(safeNDArray=True)
+        ##
+
 
 
         # print("time_to_handle_list _ 4", time.perf_counter() - time_to_handle_list_4)        
@@ -479,9 +494,10 @@ class OnlineMIDProcessor(TrainMatcher.TrainMatcher):
         # device processing timer start
         ts_start = default_timer()
 
-        ########
-        ## XPCS ANALYSIS START
-        detector_data_list, numPixelPerGainStage = self.get_detector_data_list_from_matched_sources_and_the_gain_stage_data(sources)
+        ## Get data from calNG
+        # the detector data in the form of a list 
+        # with tuples of the form (str: unique_module_name, np array: detector data of shape 352x512x128) 
+        detector_data_list, numPixelPerGainStage = self.get_calNG_data(sources)
 
 
         self.handle_detector_data_list(train_id, detector_data_list, numPixelPerGainStage=numPixelPerGainStage)
diff --git a/python_package/src/onlinemid/__init__.py b/python_package/src/onlinemid/__init__.py
index 7ac05c9..e48b23d 100644
--- a/python_package/src/onlinemid/__init__.py
+++ b/python_package/src/onlinemid/__init__.py
@@ -1,18 +1,21 @@
-#from .dataplotters.agipd_module import AgipdModulePlotter
-#
-#from .utils.agipd import agipd_module_number_from_mid_agipd_source
-#from .utils.agipd import mid_agipd_source_from_agipd_module_number
-#from .utils.agipd import agipd_module_name_from_mid_agipd_source
-#from .utils.agipd import mid_agipd_source_from_agipd_module_name
-#from .utils.agipd import agipd_module_name_from_agipd_module_number
-#from .utils.agipd import agipd_module_number_from_agipd_module_name
-#from .utils.agipd import generate_test_data
-#
-#from utils.karabo import dict_from_hash
-#from utils.karabo import hash_from_dict
-#from utils.karabo import Periodic_Output
-#
-#from utils.online_data_structures import NpVec
+from .dataplotters.agipd_module import AgipdModulePlotter
 
+from .utils.image_plotting import DataToPlotToHash
 #from .utils.image_plotting import OnlineImagePlotter
 
+from .utils.agipd import agipd_module_number_from_mid_agipd_source
+from .utils.agipd import mid_agipd_source_from_agipd_module_number
+from .utils.agipd import agipd_module_name_from_mid_agipd_source
+from .utils.agipd import mid_agipd_source_from_agipd_module_name
+from .utils.agipd import agipd_module_name_from_agipd_module_number
+from .utils.agipd import agipd_module_number_from_agipd_module_name
+from .utils.agipd import generate_test_data
+#
+from utils.karabo import dict_from_hash
+from utils.karabo import hash_from_dict
+from utils.karabo import Periodic_Output
+
+from utils.online_data_structures import NpVec
+
+
+
diff --git a/python_package/src/onlinemid/utils/image_plotting.py b/python_package/src/onlinemid/utils/image_plotting.py
index 704b945..0f71e4a 100644
--- a/python_package/src/onlinemid/utils/image_plotting.py
+++ b/python_package/src/onlinemid/utils/image_plotting.py
@@ -92,3 +92,33 @@ class OnlineImagePlotter():
             pass
 
         return current_image
+
+
+class DataToPlotToHash():
+
+    def __init__(self):
+
+        self.hash_keys_and_plotter = dict()
+
+    def add_to_hash(self, hash, hash_key, Plotter, *args):
+        
+        if hash_key not in self.hash_keys_and_plotter:
+            print(f"key {hash_key} does not exist, init plotter")
+            self.hash_keys_and_plotter[hash_key] = Plotter()
+        else:
+            print(f"self.hash_keys_and_plotter[{hash_key}] is of type", type(Plotter))
+            if not isinstance(self.hash_keys_and_plotter[hash_key], Plotter):
+                print("wrong type, initializing!")
+                self.hash_keys_and_plotter[hash_key] = Plotter()
+            print(f"self.hash_keys_and_plotter[{hash_key}] is of type", type(Plotter))
+        
+        # plotter is initialized, give plotter data
+        current_img = self.hash_keys_and_plotter[hash_key].plot(args)
+
+
+        from karabo.bound import Hash, Schema, ImageData
+
+        if current_img is not None:
+            hash[hash_key] = ImageData(current_img)
+        
+        return hash
-- 
GitLab