# DSSC Offline Correction #

Author: European XFEL Detector Group, Version: 1.0

Offline Calibration for the DSSC Detector

In [None]:
# Notebook parameters; the values below are defaults that the caller may override.
cluster_profile = "noDB" # The ipcluster profile to use
in_folder = "/gpfs/exfel/exp/SCS/202031/p900170/raw" # path to input data, required
out_folder = "/gpfs/exfel/data/scratch/samartse/test/DSSC" # path to output to, required
sequences = [-1] # sequence files to evaluate, [-1] for all
modules = [-1] # modules to correct, set to -1 for all, range allowed
run = 229 # run to process, required

karabo_id = "SCS_DET_DSSC1M-1" # karabo detector identifier, used to build the HDF5 paths
karabo_da = ['-1']  # a list of data aggregator names, default ['-1'] selects all data aggregators
receiver_id = "{}CH0" # receiver device name template; the "{}" is later filled with the module channel
path_template = 'RAW-R{:04d}-{}-S{:05d}.h5' # the template to use to access data
h5path = 'INSTRUMENT/{}/DET/{}:xtdf/image' # path in the HDF5 file to images
h5path_idx = '/INDEX/{}/DET/{}:xtdf/image' # path in the HDF5 file to the image index
slow_data_pattern = 'RAW-R{}-DA{}-S00000.h5' # file name pattern of the slow-data files passed to get_dssc_ctrl_data

use_dir_creation_date = True # use the creation date of the input dir for database queries
cal_db_interface = "tcp://max-exfl016:8020#8025" # the database interface to use
cal_db_timeout = 300000 # in milliseconds

mem_cells = 0 # number of memory cells used, set to 0 to automatically infer
overwrite = True # set to True if existing data should be overwritten
max_pulses = 800 # maximum number of pulses per train
bias_voltage = 100 # detector bias voltage
sequences_per_node = 1 # number of sequence files per cluster node if run as slurm job, set to 0 to not run SLURM parallel
chunk_size_idim = 1  # chunking size of imaging dimension, adjust if user software is sensitive to this.
mask_noisy_asic = 0.25 # set to a value other than 0 and below 1 to mask entire ADC if fraction of noisy pixels is above
mask_cold_asic = 0.25 # mask cold ASICs if the fraction of pixels with negligible standard deviation is larger than this value
noisy_pix_threshold = 1. # standard-deviation threshold above which a pixel is considered noisy
geo_file = "/gpfs/exfel/data/scratch/xcal/dssc_geo_june19.h5" # detector geometry file
dinstance = "DSSC1M1" # detector instance name
slow_data_aggregators = [1,2,3,4] # quadrant/aggregator

def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da):
    """Delegate sequence-to-node balancing to ``xfel_calibrate``.

    Thin wrapper around ``xfel_calibrate.calibrate.balance_sequences``; all
    arguments are forwarded unchanged, in the same order. Presumably called
    by the job launcher when splitting the run into cluster jobs.
    """
    from xfel_calibrate.calibrate import balance_sequences as _balance

    forwarded = (in_folder, run, sequences, sequences_per_node, karabo_da)
    return _balance(*forwarded)
    

In [None]:
# make sure a cluster is running with ipcluster start --n=32, give it a while to start
import os
import sys
from collections import OrderedDict

import h5py
import matplotlib
import numpy as np

matplotlib.use("agg")
import matplotlib.pyplot as plt
from ipyparallel import Client
from IPython.display import Latex, Markdown, display

print(f"Connecting to profile {cluster_profile}")
# A direct view over all engines of the running ipcluster profile.
_client = Client(profile=cluster_profile)
view = _client[:]
# Switch serialisation to dill (presumably needed for the notebook-defined
# functions sent to the engines via map_sync below).
view.use_dill()

from datetime import timedelta

from cal_tools.dssclib import get_dssc_ctrl_data, get_pulseid_checksum
from cal_tools.tools import (
    get_constant_from_db,
    get_dir_creation_date,
    get_notebook_name,
    map_modules_from_folder,
    parse_runs,
    run_prop_seq_from_path,
)
from dateutil import parser
from iCalibrationDB import Conditions, ConstantMetaData, Constants, Detectors, Versions

In [None]:
# Resolve run metadata, module selection and output location.
creation_time = None
if use_dir_creation_date:
    creation_time = get_dir_creation_date(in_folder, run)
    print(f"Using {creation_time} as creation time")

# [-1] selects all sequence files
if sequences[0] == -1:
    sequences = None

# Fill in karabo_id and receiver_id; the "{}" left over from receiver_id is
# filled with the module channel later on.
h5path = h5path.format(karabo_id, receiver_id)
h5path_idx = h5path_idx.format(karabo_id, receiver_id)


# Accept both the string '-1' (default) and an int -1 as "all aggregators".
if str(karabo_da[0]) == '-1':
    if modules[0] == -1:
        modules = list(range(16))
    karabo_da = ["DSSC{:02d}".format(i) for i in modules]
else:
    # module number is encoded in the last two characters, e.g. "DSSC07"
    modules = [int(x[-2:]) for x in karabo_da]
print("Process modules: ",
      ', '.join([f"Q{x // 4 + 1}M{x % 4 + 1}" for x in modules]))

CHUNK_SIZE = 512
MAX_PAR = 32

# Strip trailing slashes (safe for an empty string, unlike in_folder[-1]).
in_folder = in_folder.rstrip("/")
print(f"Outputting to {out_folder}")

if not os.path.exists(out_folder):
    os.makedirs(out_folder)
elif not overwrite:
    # An existing path is a filesystem condition, not an attribute error.
    raise FileExistsError("Output path exists! Exiting")

import warnings

warnings.filterwarnings('ignore')

print(f"Detector in use is {karabo_id}")

In [None]:
# set everything up filewise
# Map each module name (e.g. "Q1M1") to a queue of its sequence files.
mmf = map_modules_from_folder(in_folder, run, path_template, karabo_da,
                              sequences)
(mapped_files, mod_ids, total_sequences,
 sequences_qm, file_size) = mmf
# never run more parallel tasks than there are files
MAX_PAR = min(total_sequences, MAX_PAR)

## Processed Files ##

In [None]:
import copy

import tabulate
from IPython.display import HTML, Latex, Markdown, display

print(f"Processing a total of {total_sequences} sequence files in chunks of {MAX_PAR}")

# One row per sequence file; the module name only appears on the first row of
# each module.
# NOTE: copy.copy is shallow, so the queues are shared with mapped_files and
# are drained here -- which is why the mapping is rebuilt at the end.
table = []
row = 0
for module_name, file_queue in copy.copy(mapped_files).items():
    seq = 0
    while not file_queue.empty():
        seq_file = file_queue.get()
        label = module_name if seq == 0 else ""
        table.append((row, label, seq, seq_file))
        seq += 1
        row += 1
if table:
    md = display(Latex(tabulate.tabulate(table, tablefmt='latex', headers=["#", "module", "# module", "file"])))
# restore the queue
mmf = map_modules_from_folder(in_folder, run, path_template, karabo_da, sequences)
mapped_files, mod_ids, total_sequences, sequences_qm, file_size = mmf

In [None]:
import copy
from functools import partial


def correct_module(total_sequences, sequences_qm, karabo_id, dinstance, mask_noisy_asic, 
                   mask_cold_asic, noisy_pix_threshold, chunksize, mem_cells, bias_voltage,
                   cal_db_timeout, creation_time, cal_db_interface, h5path, h5path_idx, inp):
    """Offset-correct one DSSC sequence file and write the CORR file.

    Runs on an ipyparallel engine, so all dependencies are imported locally.

    total_sequences, sequences_qm, dinstance: not used by this function; kept
        so the positional ``partial`` call in the driver stays valid.
    karabo_id, cal_db_interface, cal_db_timeout, creation_time: passed to the
        calibration-constant lookup.
    mask_noisy_asic / mask_cold_asic: fraction thresholds for flagging a
        whole 64x64 region (0 disables the check).
    noisy_pix_threshold: std above which a pixel gets NOISE_OUT_OF_THRESHOLD.
    chunksize: HDF5 chunk length along the image axis of the output datasets.
    mem_cells: number of memory cells, 0 to infer it from the cellIds.
    bias_voltage: used for the constant-lookup condition.
    h5path, h5path_idx: data / index paths with one "{}" for the channel.
    inp: tuple (filename, filename_out, channel, karabo_da, qm, conditions);
        ``conditions`` holds 'acquisition_rate', 'target_gain', 'encoded_gain'.

    Returns (hists_signal_low, hists_signal_high, low_edges, high_edges,
    pulse_edges, when, qm, err). The histograms/edges are None if no train
    was processed; ``err`` is None on success, the caught exception on
    failure, or a message string when no offset constant was found.
    """
    import h5py
    import numpy as np
    from cal_tools.dssclib import get_pulseid_checksum
    from cal_tools.enums import BadPixels
    from cal_tools.tools import get_constant_from_db_and_time
    from iCalibrationDB import Conditions, Constants

    filename, filename_out, channel, karabo_da, qm, conditions = inp

    # DSSC correction requires path without the leading "/"
    h5path = h5path.lstrip("/").format(channel)
    h5path_idx = h5path_idx.lstrip("/").format(channel)

    low_edges = None
    hists_signal_low = None
    high_edges = None
    hists_signal_high = None
    pulse_edges = None
    err = None

    def get_num_cells(fname, h5path):
        """Return the smallest supported memory-cell count >= the max cellId.

        Falls back to the first option (100) if the max cellId exceeds all.
        """
        with h5py.File(fname, "r") as f:
            maxcell = np.max(f[f"{h5path}/cellId"][()])
        options = [100, 200, 400, 500, 600, 700, 800]
        dists = np.array([(o - maxcell) for o in options])
        dists[dists < 0] = 10000  # assure to always go higher
        return options[np.argmin(dists)]

    if mem_cells == 0:
        mem_cells = get_num_cells(filename, h5path)

    pulseid_checksum = get_pulseid_checksum(filename, h5path, h5path_idx)

    print(f"Memcells: {mem_cells}")

    condition = Conditions.Dark.DSSC(bias_voltage=bias_voltage, memory_cells=mem_cells,
                                     pulseid_checksum=pulseid_checksum,
                                     acquisition_rate=conditions['acquisition_rate'],
                                     target_gain=conditions['target_gain'],
                                     encoded_gain=conditions['encoded_gain'])

    offset, when = get_constant_from_db_and_time(karabo_id, karabo_da,
                                                 Constants.DSSC.Offset(),
                                                 condition,
                                                 None,
                                                 cal_db_interface,
                                                 creation_time=creation_time,
                                                 timeout=cal_db_timeout)
    offset_not_found = offset is None
    if offset_not_found:
        print("No offset found in the database")
    else:
        # Reorder axes so the memory-cell axis comes first and can be indexed
        # by cellId below (DB axis order assumed (x, y, cell) -- confirm).
        offset = np.moveaxis(np.moveaxis(offset[...], 2, 0), 2, 1)

    def copy_and_sanitize_non_cal_data(infile, outfile):
        """Copy every group/dataset except the image data to ``outfile``."""
        # the image data is rewritten below, so it is not copied here
        dont_copy = [f"{h5path}/data"]

        def visitor(k, item):
            if k in dont_copy:
                return  # None keeps visititems going
            if isinstance(item, h5py.Group):
                outfile.create_group(k)
            elif isinstance(item, h5py.Dataset):
                group = "/".join(str(k).split("/")[:-1])
                infile.copy(k, outfile[group])

        infile.visititems(visitor)

    try:
        with h5py.File(filename, "r", driver="core") as infile:
            with h5py.File(filename_out, "w") as outfile:
                copy_and_sanitize_non_cal_data(infile, outfile)
                raw_data = infile[f"{h5path}/data"]
                cell_dset = infile[f"{h5path}/cellId"]
                pulse_dset = infile[f"{h5path}/pulseId"]
                n_images = raw_data.shape[0]

                # Train t covers images [first_arr[t], last_arr[t]); the last
                # train runs to the end of the dataset (a -1 end would drop
                # the final image). np.int was removed in NumPy 1.24.
                first_arr = np.ravel(infile[f"{h5path_idx}/first"][()]).astype(np.int64)
                last_arr = np.append(first_arr[1:], n_images)

                oshape = list(raw_data.shape)
                if len(oshape) == 4:
                    # drop the second axis of 4-D raw data (assumed length 1)
                    oshape = [oshape[0]] + oshape[2:]
                frame_shape = tuple(oshape[1:])
                chunks = (chunksize,) + frame_shape
                ddset = outfile.create_dataset(f"{h5path}/data",
                                               oshape, chunks=chunks,
                                               dtype=np.float32,
                                               fletcher32=True)

                mdset = outfile.create_dataset(f"{h5path}/mask",
                                               oshape, chunks=chunks,
                                               dtype=np.uint32,
                                               compression="gzip",
                                               compression_opts=1,
                                               shuffle=True,
                                               fletcher32=True)

                for first, last in zip(first_arr, last_arr):
                    if last <= first:
                        continue  # empty train
                    # Explicit reshapes instead of np.squeeze, which would
                    # also collapse the image axis of a single-image train.
                    data = raw_data[first:last].astype(np.float32).reshape((-1,) + frame_shape)
                    cell_ids = np.ravel(cell_dset[first:last])
                    pulse_ids = np.ravel(pulse_dset[first:last])
                    if not offset_not_found:
                        data -= offset[cell_ids, ...]

                    # Histograms are filled from the first non-empty train only.
                    if hists_signal_low is None:
                        pulse_max = max(int(pulse_ids.max()), 1)  # bins must be >= 1
                        pulse_flat = np.broadcast_to(pulse_ids[:, None, None], data.shape).ravel()
                        data_flat = data.ravel()
                        bins = (55, pulse_max)
                        hists_signal_low, low_edges, pulse_edges = np.histogram2d(
                            data_flat, pulse_flat, bins=bins,
                            range=[[-5, 50], [0, pulse_max]])
                        hists_signal_high, high_edges, _ = np.histogram2d(
                            data_flat, pulse_flat, bins=bins,
                            range=[[-5, 300], [0, pulse_max]])
                    ddset[first:last, ...] = data

                # Bad-pixel map from the final image of the file.
                # NOTE(review): a single frame is used, so the per-pixel std
                # over axis 0 is 0 everywhere and the noise / cold checks
                # below can effectively never fire -- confirm which (dark)
                # frames were meant to be used here.
                dark = raw_data[-1, ...].astype(np.float32).reshape((-1,) + frame_shape)
                bpix = np.zeros(frame_shape, np.uint32)
                dark_std = np.std(dark, axis=0)
                bpix[dark_std > noisy_pix_threshold] = BadPixels.NOISE_OUT_OF_THRESHOLD.value

                # Tile the frame into 64x64 regions along both axes (rows
                # first), instead of 8 row blocks on the short axis.
                asic = 64
                for r0 in range(0, frame_shape[0], asic):
                    for c0 in range(0, frame_shape[1], asic):
                        rows = slice(r0, r0 + asic)
                        cols = slice(c0, c0 + asic)
                        count_noise = np.count_nonzero(bpix[rows, cols])
                        asic_std = np.std(dark[:, rows, cols])
                        if mask_noisy_asic:
                            if count_noise / (asic * asic) > mask_noisy_asic:
                                bpix[rows, cols] = BadPixels.NOISY_ADC.value

                        if mask_cold_asic:
                            # NOTE(review): asic_std is one scalar for the
                            # whole region, so count_cold is 0 or 1; a
                            # per-pixel std was probably intended.
                            count_cold = np.count_nonzero(asic_std < 0.5)
                            if count_cold / (asic * asic) > mask_cold_asic:
                                bpix[rows, cols] = BadPixels.ASIC_STD_BELOW_NOISE.value

                # The mask dataset was previously left unwritten (all zeros).
                # Write the same map for every image, in blocks to bound memory.
                block = 256
                for start in range(0, n_images, block):
                    stop = min(start + block, n_images)
                    mdset[start:stop, ...] = np.broadcast_to(bpix, (stop - start,) + frame_shape)

    except Exception as e:
        print(e)
        err = e

    if err is None and offset_not_found:
        err = "Offset not found in database!. No offset correction applied."

    return (hists_signal_low, hists_signal_high, low_edges, high_edges, pulse_edges, when, qm, err)
    
# Distribute sequence files over the engines in batches of MAX_PAR, then
# accumulate histograms and report which offset constant each module used.
done = False
first_files = {}  # channel -> (raw file, corrected file) of its first sequence
inp = []
left = total_sequences

# Accumulated signal histograms; start at 0 so the first returned array can
# simply be added.
hists_signal_low = 0
hists_signal_high = 0 

low_edges, high_edges, pulse_edges = None, None, None

# Control settings per module (indexed by "QxMy"), read from slow data.
tGain, encodedGain, operatingFreq = get_dssc_ctrl_data(
    in_folder + "/r{:04d}/".format(run),
    slow_data_pattern, slow_data_aggregators, run)

whens = []
qms = []
Errors = []
while not done:
    dones = []
    for i, k_da in zip(modules, karabo_da):
        qm = "Q{}M{}".format(i//4 +1, i % 4 + 1)

        if qm not in mapped_files:
            print(f"Skipping {qm}")
            continue
        if mapped_files[qm].empty():
            print(f"{qm} file is missing")
            continue
        fname_in = str(mapped_files[qm].get())
        dones.append(mapped_files[qm].empty())
        fout = os.path.abspath("{}/{}".format(
            out_folder, os.path.split(fname_in)[-1].replace("RAW", "CORR")))

        # Keep the FIRST sequence file of each module for the previews;
        # plain assignment would keep the last one.
        first_files.setdefault(i, (fname_in, fout))
        conditions = {
            'acquisition_rate': operatingFreq[qm],
            'target_gain': tGain[qm],
            'encoded_gain': encodedGain[qm],
        }
        inp.append((fname_in, fout, i, k_da, qm, conditions))

    # Submit once a full batch is collected or all remaining files are.
    if len(inp) >= min(MAX_PAR, left):
        print(f"Running {len(inp)} tasks parallel")
        p = partial(correct_module, total_sequences, sequences_qm,
                    karabo_id, dinstance, mask_noisy_asic, mask_cold_asic,
                    noisy_pix_threshold, chunk_size_idim, mem_cells,
                    bias_voltage, cal_db_timeout, creation_time, cal_db_interface,
                    h5path, h5path_idx)

        # for serial debugging use: r = list(map(p, inp))
        r = view.map_sync(p, inp)

        inp = []
        left -= MAX_PAR

        for rr in r:
            if rr is None:
                continue
            hl, hh, l_edges, h_edges, p_edges, when, qm, err = rr
            whens.append(when)
            qms.append(qm)
            Errors.append(err)
            if hl is not None:  # any one being None will also make the others None
                hists_signal_low += hl.astype(np.float64)
                hists_signal_high += hh.astype(np.float64)
                # Only take edges from results that carry histograms, so a
                # failed module cannot reset them to None.
                low_edges, high_edges, pulse_edges = l_edges, h_edges, p_edges

    done = all(dones)

# Report sorted by module name. Sort on the name only: `when` values may mix
# datetimes and None (not comparable), and errors must stay with their module.
for qm, when, err in sorted(zip(qms, whens, Errors), key=lambda entry: entry[0]):
    try:
        when = when.isoformat()
    except AttributeError:
        pass  # not a datetime (e.g. None); print as-is
    if err is not None:

        # Avoid writing wrong injection date if cons. not found.
        if "not found" in str(err):
            print(f"ERROR! {qm}: {err}")
        else:
            print(f"Offset for {qm} was injected on {when}, ERROR!: {err}")
    else:
        print(f"Offset for {qm} was injected on {when}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.ticker import FormatStrFormatter, LinearLocator
from mpl_toolkits.mplot3d import Axes3D

%matplotlib inline
def do_3d_plot(data, edges, x_axis, y_axis):
    """Draw a 2-D histogram as a 3-D surface.

    data: 2-D array; axis 0 is binned by edges[0], axis 1 by edges[1]
        (as returned by np.histogram2d).
    edges: pair of bin-edge arrays; the lower edge of each bin is used.
    x_axis, y_axis: labels for the edges[0] and edges[1] axes.
    """
    fig = plt.figure(figsize=(10,10))
    # Figure.gca() no longer accepts a projection (Matplotlib >= 3.6).
    ax = fig.add_subplot(111, projection='3d')

    # Grid of lower bin edges; meshgrid puts edges[1] along the rows, hence
    # the transposed data.
    X, Y = np.meshgrid(edges[0][:-1], edges[1][:-1])
    Z = data.T

    ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                    linewidth=0, antialiased=False)
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    ax.set_zlabel("Counts")

In [None]:
def do_2d_plot(data, edges, y_axis, x_axis):
    """Show a 2-D histogram as a log-scaled image.

    Rows of ``data`` (binned by edges[0]) run along the vertical axis, with
    the highest bin at the top; columns (edges[1]) run horizontally.
    Note the label order in the signature: y label first, then x label.
    """
    from matplotlib.colors import LogNorm

    row_edges, col_edges = edges[0], edges[1]
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111)
    extent = [
        np.min(col_edges), np.max(col_edges),
        np.min(row_edges), np.max(row_edges),
    ]
    norm = LogNorm(vmin=1, vmax=np.max(data))
    flipped = data[::-1, :]
    im = ax.imshow(flipped, extent=extent, aspect="auto", norm=norm)
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    colorbar = fig.colorbar(im)
    colorbar.set_label("Counts")
    
    

## Mean Intensity per Pulse ##

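The following plots show 2D histograms of the corrected signal versus pulse ID, for a detailed (-5 to 50 ADU) and an expanded (-5 to 300 ADU) signal range. For each processed sequence file, the histograms are filled from its first non-empty train and then summed over all files.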

In [None]:
# Edges stay None when the correction step returned no usable histograms; in
# that case there is nothing to plot.
if pulse_edges is None or low_edges is None or high_edges is None:
    print("No signal histograms available; skipping the per-pulse plots.")
else:
    for hist, signal_edges in ((hists_signal_low, low_edges),
                               (hists_signal_high, high_edges)):
        do_3d_plot(hist, [signal_edges, pulse_edges], "Signal (ADU)", "Pulse id")
        do_2d_plot(hist, [signal_edges, pulse_edges], "Signal (ADU)", "Pulse id")

In [None]:
# Load up to 128 images per module, starting at the first train, from the RAW
# and CORR files recorded in first_files; used by the preview plots below.
corrected = []
raw = []
mask = []
pulse_ids = []
train_ids = [] 
for channel, (raw_file, corr_file) in first_files.items():
    data_path = h5path.format(channel)
    index_path = h5path_idx.format(channel)
    try:
        # `with` closes each file even on error; the old finally-clause could
        # hit an unbound or stale handle.
        with h5py.File(raw_file, "r") as infile:
            first_idx = int(np.ravel(infile[f"{index_path}/first"][()])[0])
            raw_dset = infile[f"{data_path}/data"]
            # Use (at most) the first 128 images, never running past the end
            # of the dataset, so the slice length always equals plt_im.
            plt_im = min(128, raw_dset.shape[0] - first_idx)
            last_idx = first_idx + plt_im
            # read only the preview slice, not the whole dataset
            raw.append((channel, raw_dset[first_idx:last_idx, 0, ...]))

        with h5py.File(corr_file, "r") as infile:
            corrected.append((channel, infile[f"{data_path}/data"][first_idx:last_idx, ...]))
            mask.append((channel, infile[f"{data_path}/mask"][first_idx:last_idx, ...]))
            pulse_ids.append((channel, np.squeeze(infile[f"{data_path}/pulseId"][first_idx:last_idx, ...])))
            train_ids.append((channel, np.squeeze(infile[f"{data_path}/trainId"][first_idx:last_idx, ...])))

    except Exception as e:
        print(e)

In [None]:
def combine_stack(d, sdim):
    """Place per-module image stacks on a common (sdim, 1300, 1300) canvas.

    d: list of (channel, stack) pairs; each stack is expected to have shape
        (sdim, 128, 512), i.e. two 128x256 tiles side by side (as implied by
        the tile slicing below).
    sdim: number of frames per stack.

    Module and tile positions are read from the geometry file ``geo_file``
    (notebook parameter). Returns a float32 array that is 0 where no module
    lies.
    """
    combined = np.zeros((sdim, 1300, 1300), np.float32)

    # Per-quadrant (y, x) offsets, in the same length unit as px / py.
    quad_pos = [
        (0, 145),
        (130, 140),
        (125, 15),
        (0, 15),
    ]

    # pixel size along x / y; positions divided by these give pixel indices
    px = 0.236
    py = 0.204
    with h5py.File(geo_file, "r") as gf:
        for ch, stack in d:
            # Quadrant from the channel number, not the list position, so a
            # missing module cannot shift later ones into the wrong quadrant.
            quad = ch // 4
            mi = 3 - (ch % 4)
            mp = gf["Q{}/M{}/Position".format(quad + 1, mi % 4 + 1)][()]
            t1 = gf["Q{}/M{}/T01/Position".format(quad + 1, ch % 4 + 1)][()]
            t2 = gf["Q{}/M{}/T02/Position".format(quad + 1, ch % 4 + 1)][()]

            # Quadrants 1-2 are flipped vertically with swapped tiles,
            # quadrants 3-4 flipped horizontally.
            if quad < 2:
                t1, t2 = t2, t1
                td = stack[:, ::-1, :]
            else:
                td = stack[:, :, ::-1]

            t1d = td[:, :, :256]
            t2d = td[:, :, 256:]

            x_shift = int(quad_pos[quad][1] / px)
            y_shift = int(quad_pos[quad][0] / py) + combined.shape[1] // 16
            x0t1 = int((t1[0] + mp[0]) / px) + x_shift
            y0t1 = int((t1[1] + mp[1]) / py) + y_shift
            x0t2 = int((t2[0] + mp[0]) / px) + x_shift
            y0t2 = int((t2[1] + mp[1]) / py) + y_shift

            combined[:, y0t1:y0t1 + 128, x0t1:x0t1 + 256] = t1d
            combined[:, y0t2:y0t2 + 128, x0t2:x0t2 + 256] = t2d

    return combined

In [None]:
# Assemble corrected, raw and mask previews on the detector canvas.
n_preview = last_idx - first_idx
combined = combine_stack(corrected, n_preview)
combined_raw = combine_stack(raw, n_preview)
combined_mask = combine_stack(mask, n_preview)

### Mean RAW Preview ###



In [None]:
display(Markdown(f"The per pixel mean of the first {plt_im} images of the RAW data"))

In [None]:
%matplotlib inline
# Per-pixel mean of the raw preview; color limits derived from the median of
# the positive values (computed once).
raw_median = np.median(combined_raw[combined_raw > 0])
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
im = ax.imshow(np.mean(combined_raw, axis=0),
               vmin=min(0.75 * raw_median, -5),
               vmax=max(1.5 * raw_median, 50), cmap="jet")
cb = fig.colorbar(im, ax=ax)

### Single Shot Preview ###

A single-shot image: frame index 2 of the corrected preview stack, which starts at the first train.

In [None]:
# Third frame (index 2) of the corrected preview; `dim` is reused below.
dim = combined[2, ...]
upper = max(1.5 * np.median(dim[dim > 0]), 50)
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
im = ax.imshow(dim, vmin=0, vmax=upper, cmap="jet", interpolation="nearest")
cb = fig.colorbar(im, ax=ax)

In [None]:
# Pixel-value distribution of the single-shot frame.
fig = plt.figure(figsize=(20, 10))
ax = fig.add_subplot(111)
h = ax.hist(np.ravel(dim), bins=100, range=(0, 100))

### Mean CORRECTED Preview ###

In [None]:
display(Markdown(f"The per pixel mean of the first {plt_im} images of the CORRECTED data"))

In [None]:
# Per-pixel mean of the corrected preview.
corr_upper = max(1.5 * np.median(combined[combined > 0]), 10)
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
im = ax.imshow(combined.mean(axis=0), vmin=0,
               vmax=corr_upper, cmap="jet", interpolation="nearest")
cb = fig.colorbar(im, ax=ax)

### Max CORRECTED Preview ###

The per-pixel maximum of the first (up to 128) images of the CORRECTED data

In [None]:
# Per-pixel maximum of the corrected preview.
max_upper = max(100 * np.median(combined[combined > 0]), 20)
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
im = ax.imshow(combined.max(axis=0), vmin=0,
               vmax=max_upper, cmap="jet", interpolation="nearest")
cb = fig.colorbar(im, ax=ax)

In [None]:
# Log-scale distribution of corrected values; note this clips `combined`
# in place (non-positive values become 0).
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
nonpositive = combined <= 0
combined[nonpositive] = 0
h = ax.hist(np.ravel(combined), bins=100, range=(-5, 100), log=True)

## Bad Pixels ##
The mask contains dedicated entries for all pixels and memory cells as well as all three gain stages. Each mask entry is encoded in 32 bits as:

In [None]:
import tabulate
from cal_tools.enums import BadPixels
from IPython.display import HTML, Latex, Markdown, display

# One row per bad-pixel flag with its bit pattern, printed as 32 bits to match
# the text above and the uint32 mask dataset.
table = [(item.name, "{:032b}".format(item.value)) for item in BadPixels]
md = display(Latex(tabulate.tabulate(table, tablefmt='latex', headers=["Bad pixel type", "Bit mask"])))

### Full Train Bad Pixels ###

In [None]:
# log2 of the largest mask value per pixel over the preview frames
# (pixels with mask 0 give -inf and stay uncolored).
worst = np.max(combined_mask, axis=0)
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
im = ax.imshow(np.log2(worst), vmin=0, vmax=32, cmap="jet")
cb = fig.colorbar(im, ax=ax)

### Full Train Bad Pixels - Only Dark Char. Related ###

In [None]:
# Pixels carrying any of the dark-derived flags (noisy pixel, noisy ASIC,
# cold ASIC) in at least one preview frame; previously only NOISY_ADC was shown.
dark_flags = (BadPixels.NOISE_OUT_OF_THRESHOLD.value
              | BadPixels.NOISY_ADC.value
              | BadPixels.ASIC_STD_BELOW_NOISE.value)
flagged = (combined_mask.astype(np.uint32) & dark_flags) != 0
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
im = ax.imshow(np.max(flagged, axis=0), vmin=0,
               vmax=1, cmap="jet")
cb = fig.colorbar(im, ax=ax)