From 8a5704e2b10307a314494419fcfed141a41c9c66 Mon Sep 17 00:00:00 2001
From: Steffen Hauf <haufs@max-exfl073.desy.de>
Date: Thu, 31 May 2018 09:23:24 +0200
Subject: [PATCH] Speed inprovements and only available memory cells

---
 LPD/correct_lpd_batch.py | 144 ++++++++++++++++++++++++++-------------
 1 file changed, 96 insertions(+), 48 deletions(-)

diff --git a/LPD/correct_lpd_batch.py b/LPD/correct_lpd_batch.py
index 1247636c8..2ee1dfaa8 100644
--- a/LPD/correct_lpd_batch.py
+++ b/LPD/correct_lpd_batch.py
@@ -68,7 +68,8 @@ cells = np.arange(max_cells)
 QUADRANTS = 4
 MODULES_PER_QUAD = 4
 DET_FILE_INSET = "LPD"
-
+CHUNK_SIZE = 512
+MAX_PAR = 32
 
 if in_folder[-1] == "/":
     in_folder = in_folder[:-1]
@@ -80,6 +81,7 @@ if not os.path.exists(out_folder):
 elif not overwrite:
     raise AttributeError("Output path exists! Exiting")
 
+max_cells_db = 128
 
 # In[42]:
 
@@ -202,7 +204,7 @@ if True:
         metadata.calibration_constant = offset
 
         # set the operating condition
-        condition = Conditions.Dark.LPD(memory_cells=max_cells, bias_voltage=bias_voltage)
+        condition = Conditions.Dark.LPD(memory_cells=max_cells_db, bias_voltage=bias_voltage)
         metadata.detector_condition = condition
 
         # specify the a version for this constant
@@ -216,9 +218,11 @@ if True:
             try:
                 metadata.retrieve(cal_db_interface)
                 offsets.append(copy.copy(offset.data))
-            except:
+            except Exception as e:
+                print("Could not retrieve offset from db for {}: {}".format(qm, e))
                 offsets.append(np.zeros((256,256,max_cells,3)))
         else:
+            print("Could not retrieve offset from db for {}".format(qm))
             offsets.append(np.zeros((256,256,max_cells,3)))
         """ 
         metadata = ConstantMetaData()
@@ -362,8 +366,9 @@ def map_modules_from_files(filelist):
                 
     return module_files, mod_ids
 
-dirlist = os.listdir(in_folder)
+dirlist = sorted(os.listdir(in_folder))
 file_list = []
+
 for entry in dirlist:
     #only h5 file
     abs_entry = "{}/{}".format(in_folder, entry)
@@ -378,13 +383,13 @@ for entry in dirlist:
                     
 mapped_files, mod_ids = map_modules_from_files(file_list)
 
-
+print(file_list)
 # In[45]:
 
 
 import copy
 from functools import partial
-def correct_module(cells, do_ff, index_v, inp):
+def correct_module(max_cells, do_ff, index_v, CHUNK_SIZE, inp):
     import numpy as np
     import copy
     import h5py
@@ -419,10 +424,16 @@ def correct_module(cells, do_ff, index_v, inp):
             last = np.squeeze(infile["/INDEX/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/last".format(channel)])
             last_index = int(last[status != 0][-1])
             first_index = int(last[status != 0][0])
-        im = np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/data".format(channel)][first_index:last_index, ...])
-        cells = np.squeeze(np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/cellId".format(channel)][first_index:last_index, ...]))
-
-        dont_copy = ["data",]
+        allcells = np.squeeze(np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/cellId".format(channel)][first_index:last_index, ...]))
+        single_image = np.array(np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/data".format(channel)][first_index, ...]))
+        can_calibrate = allcells < max_cells
+        if np.count_nonzero(can_calibrate) == 0:
+            return
+        allcells = allcells[can_calibrate]
+        firange = np.arange(first_index, last_index)
+        firange = firange[can_calibrate]
+
+        dont_copy = ["data", "cellId", "trainId", "pulseId", "status", "length"]
         dont_copy = ["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/{}".format(channel, do)
                     for do in dont_copy]
 
@@ -438,52 +449,86 @@ def correct_module(cells, do_ff, index_v, inp):
 
         infile.visititems(visitor)
         outfile.flush()
-
-
-        outfile.flush()
-        infile.close()
-
-        im, gain = splitOffGainLPD(im[:,0,...])
         
-        im = im.astype(np.float32)
-        im[gain > 2] = np.nan
-        gain[gain > 2] = 0
+        oshape = (firange.size, single_image.shape[2], single_image.shape[1])
         
-        im = np.rollaxis(im, 2)
-        im = np.rollaxis(im, 2, 1)
-
-        gain = np.rollaxis(gain, 2)
-        gain = np.rollaxis(gain, 2, 1)
-
-        om = offset[...,cells,:]
-        rc = rel_gain[...,cells,:]
-        rbc = rel_gain_b[...,cells,:]
-        og = np.choose(gain, (om[...,0], om[...,1], om[...,2]))
-        rg = np.choose(gain, (rc[...,0], rc[...,1], rc[...,2]))        
-        rgb = np.choose(gain, (rbc[...,0], rbc[...,1], rbc[...,2]))
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/data".format(channel)] = np.zeros(oshape, np.float32)
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/gain".format(channel)] = np.zeros(oshape, np.uint8)
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/mask".format(channel)] = np.zeros(oshape, np.uint32)
+
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/cellId".format(channel)] = np.zeros(firange.size, np.uint16)
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/trainId".format(channel)] = np.zeros(firange.size, np.uint64)
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/pulseId".format(channel)] = np.zeros(firange.size, np.uint64)
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/status".format(channel)] = np.zeros(firange.size, np.uint16)
         
-        mskg = mask[...,cells,:]
-        msk = np.choose(gain, (mskg[...,0], mskg[...,1], mskg[...,2]))
-        im -= og
+        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/length".format(channel)] = np.zeros(firange.size, np.uint32)
         
-        im = (im-rgb)/rg
-        if do_ff:
-            im /= flatfield[:,:,None]
+        #
+        cidx = 0
+        for irange in np.array_split(firange, firange.size//CHUNK_SIZE):
+
+            im = np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/data".format(channel)][irange, ...])
+            trainId = np.squeeze(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/trainId".format(channel)][irange, ...])
+            pulseId = np.squeeze(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/pulseId".format(channel)][irange, ...])
+            status = np.squeeze(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/status".format(channel)][irange, ...])
 
-        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/data".format(channel)] = np.rollaxis(np.rollaxis(im,1), 2)
-        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/gain".format(channel)] = np.rollaxis(np.rollaxis(gain,1), 2)
-        outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/mask".format(channel)] = np.rollaxis(np.rollaxis(msk,1), 2)
+            cells = np.squeeze(np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/cellId".format(channel)][irange, ...]))
+            
+            
+            length = np.squeeze(np.array(infile["/INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/length".format(channel)][irange, ...]))
+
+            
+
+            im, gain = splitOffGainLPD(im[:,0,...])
+
+            im = im.astype(np.float32)
+            im[gain > 2] = np.nan
+            gain[gain > 2] = 0
+
+            im = np.rollaxis(im, 2)
+            im = np.rollaxis(im, 2, 1)
+
+            gain = np.rollaxis(gain, 2)
+            gain = np.rollaxis(gain, 2, 1)
+
+            om = offset[...,cells,:]
+            rc = rel_gain[...,cells,:]
+            rbc = rel_gain_b[...,cells,:]
+            og = np.choose(gain, (om[...,0], om[...,1], om[...,2]))
+            rg = np.choose(gain, (rc[...,0], rc[...,1], rc[...,2]))        
+            rgb = np.choose(gain, (rbc[...,0], rbc[...,1], rbc[...,2]))
+
+            mskg = mask[...,cells,:]
+            msk = np.choose(gain, (mskg[...,0], mskg[...,1], mskg[...,2]))
+            im -= og
+
+            im = (im-rgb)/rg
+            if do_ff:
+                im /= flatfield[:,:,None]
+            nidx = int(cidx+irange.size)
+
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/data".format(channel)][cidx:nidx,...] = np.rollaxis(np.rollaxis(im,1), 2)
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/gain".format(channel)][cidx:nidx,...] = np.rollaxis(np.rollaxis(gain,1), 2)
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/mask".format(channel)][cidx:nidx,...] = np.rollaxis(np.rollaxis(msk,1), 2)
+
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/cellId".format(channel)][cidx:nidx] = cells
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/trainId".format(channel)][cidx:nidx] = trainId
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/pulseId".format(channel)][cidx:nidx] = pulseId
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/status".format(channel)][cidx:nidx] = status
+            outfile["INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/image/length".format(channel)][cidx:nidx] = length
+            cidx = nidx
         
         outfile.close()
-
+        infile.close()
     except Exception as e:
-       print(e)
-       pass    
+        print(e)
+        pass    
 
 done = False
 first_files = []
+inp = []
 while not done:
-    inp = []
+    
     dones = []
     first = True
     for i in range(16):
@@ -502,13 +547,16 @@ while not done:
                     rel_gains[i][...,:max_cells,:], bad_pixels[i][...,:max_cells,:],
                     flat_fields[i], rel_gains_b[i][...,:max_cells,:]))
     first = False
-    p = partial(correct_module, max_cells, do_ff, index_v)
+    if len(inp) > MAX_PAR:
+        print("Running {} tasks parallel".format(len(inp)))
+        p = partial(correct_module, max_cells, do_ff, index_v, CHUNK_SIZE)
+        r = view.map_sync(p, inp)
+        inp = []
+
     
-    r = view.map_sync(p, inp)
     #r = list(map(p, inp))
     done = all(dones)
-
-
+#r.wait()
 
 # In[46]:
 
-- 
GitLab