diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py index 3cd9472f7b6ce7b5c09b28c047367f4c14d9d4a8..f118fb80e1b5cd49545d3f4b8429becc6aceea4f 100644 --- a/src/toolbox_scs/routines/boz.py +++ b/src/toolbox_scs/routines/boz.py @@ -342,11 +342,11 @@ def find_rois(data_mean, threshold): pY = data_mean.sum(axis=1) # along X - lowX = int(np.argmax(pX[:128] > threshold)) # 1st occurrence returned - highX = int(np.argmax(pX[128:] < threshold) + 128) # 1st occ. returned - midX = (lowX + highX)//2 - leftX = int(np.argmin(pX[(lowX+20):midX]) + lowX + 20) - rightX = int(np.argmin(pX[midX:highX-20]) + midX) + lowX = int(np.argmax(pX[:64] > threshold)) # 1st occurrence returned + highX = int(np.argmax(pX[192:] <= threshold) + 192) # 1st occ. returned + + leftX = int(np.argmin(pX[64:128]) + 64) + rightX = int(np.argmin(pX[128:192]) + 128) # along Y lowY = int(np.argmax(pY[:64] > threshold)) # 1st occurrence returned @@ -795,26 +795,24 @@ def inspect_plane_fitting(avg, rois, vmin=None, vmax=None): fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6)) img_rois = {} + centers = {} + + for k, r in enumerate(['n', '0', 'p']): + roi = rois[r] + centers[r] = np.array([(roi['yl'] + roi['yh'])//2, + (roi['xl'] + roi['xh'])//2]) + d = '0' + roi = rois[d] for k, r in enumerate(['n', '0', 'p']): - img_rois[r] = avg[rois[r]['yl']:rois[r]['yh'], - rois[r]['xl']:rois[r]['xh']] + img_rois[r] = np.roll(avg, tuple(centers[r] - centers[d]))[ + roi['yl']:roi['yh'], roi['xl']:roi['xh']] im = axs[0, k].imshow(img_rois[r], vmin=vmin, vmax=vmax) for k, r in enumerate(['n', '0', 'p']): - if img_rois[r].shape[1] != img_rois['0'].shape[1]: - if k == 0: - n1 = img_rois[r].shape[1] - n = img_rois['0'].shape[1] - v = img_rois[r][:, (n1-n):]/img_rois['0'] - else: - n1 = img_rois[r].shape[1] - n = img_rois['0'].shape[1] - v = img_rois[r][:, :-(n1-n)]/img_rois['0'] - else: - v = img_rois[r]/img_rois['0'] + v = img_rois[r]/img_rois['0'] im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r') cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal") @@ -1073,16 +1071,17 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, # drop saturated shots d = data.where(data['sat_sat'] == False, drop=True) - # calculated error from transmission of 1.0 - v = d['n'].values.flatten()/d['0'].values.flatten() - err1 = 1e8*np.nanmean((v - 1.0)**2) + v_1 = snr(d['n'].values.flatten(), d['0'].values.flatten(), + methods=['weighted']) + err_1 = 1e8*v_1['weighted']['s']**2 - err2 = np.sum((Fmodel-np.arange(2**9))**2) + v_2 = snr(d['p'].values.flatten(), d['0'].values.flatten(), + methods=['weighted']) + err_2 = 1e8*v_2['weighted']['s']**2 - # print(f'{err}: {p}') - # logging.info(f'{err}: {p}') + err_a = np.sum((Fmodel-np.arange(2**9))**2) - return (1.0 - alpha)*err1 + alpha*err2 + return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a def nl_fit(params, domain): @@ -1175,44 +1174,68 @@ def inspect_nl_fit(res_fit): return f -def snr(sig, ref, verbose=False): - """ Compute mean, std and SNR with and without weight from transmitted signal sig - and I0 signal ref +def snr(sig, ref, methods=None, verbose=False): + """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref. + + Inputs + ------ + sig: 1D signal samples + ref: 1D reference samples + methods: None by default or list of strings to select which methods to use. + Possible values are 'direct', 'weighted', 'diff'. In case of None, all + methods will be calculated. + verbose: booleand, if True prints calculated values + + Returns + ------- + dictionnary of [methods][value] where value is 'mu' for mean and 's' for + standard deviation. + """ + if methods is None: + methods = ['direct', 'weighted', 'diff'] + w = ref x = sig/ref mask = np.isfinite(x) & np.isfinite(sig) & np.isfinite(ref) - + w = w[mask] sig = sig[mask] ref = ref[mask] x = x[mask] - # direct mean and std - mu = np.mean(x) - s = np.std(x) - if verbose: - print(f'mu: {mu}, s: {s}, snr: {mu/s}') res = {} - res['direct'] = {'mu': mu, 's':s} + + # direct mean and std + if 'direct' in methods: + mu = np.mean(x) + s = np.std(x) + if verbose: + print(f'mu: {mu}, s: {s}, snr: {mu/s}') + + res['direct'] = {'mu': mu, 's':s} # weighted mean and std - wmu = np.sum(sig)/np.sum(ref) - v1 = np.sum(w) - v2 = np.sum(w**2) - ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1)) + if 'weighted' in methods: + wmu = np.sum(sig)/np.sum(ref) + v1 = np.sum(w) + v2 = np.sum(w**2) + ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1)) + + if verbose: + print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}') - if verbose: - print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}') - res['weighted'] = {'mu': wmu, 's':ws} + res['weighted'] = {'mu': wmu, 's':ws} # noise from diff - dmu = np.mean(x) - ds = np.std(np.diff(x))/np.sqrt(2) - if verbose: - print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}') - res['diff'] = {'mu': dmu, 's':ds} + if 'diff' in methods: + dmu = np.mean(x) + ds = np.std(np.diff(x))/np.sqrt(2) + if verbose: + print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}') + + res['diff'] = {'mu': dmu, 's':ds} return res @@ -1256,7 +1279,7 @@ def inspect_correction(params, gain=None): scale = 1e-6 - f, axs = plt.subplots(3, 3, figsize=(6, 6), sharex=True, sharey=True) + f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True) # nbins = np.linspace(0.01, 1.0, 100) @@ -1278,16 +1301,11 @@ def inspect_correction(params, gain=None): snr_v = snr(good_d[n].values.flatten(), good_d[r].values.flatten(), verbose=True) - if k == 0: - m = np.nanmean(good_d[n].values.flatten() - /good_d[r].values.flatten()) - else: - m = 1 - + m = snr_v['direct']['mu'] h, xedges, yedges, img = axs[l, k].hist2d( g*scale*good_d[r].values.flatten(), - good_d[n].values.flatten()/good_d[r].values.flatten()/m, - [photon_scale, np.linspace(0.95, 1.05, 150)], + good_d[n].values.flatten()/good_d[r].values.flatten(), + [photon_scale, np.linspace(0.95, 1.05, 150)*m], cmap='Blues', vmax=200, norm=LogNorm(), @@ -1295,13 +1313,14 @@ def inspect_correction(params, gain=None): ) h, xedges, yedges, img2 = axs[l, k].hist2d( g*scale*sat_d[r].values.flatten(), - sat_d[n].values.flatten()/sat_d[r].values.flatten()/m, - [photon_scale, np.linspace(0.95, 1.05, 150)], + sat_d[n].values.flatten()/sat_d[r].values.flatten(), + [photon_scale, np.linspace(0.95, 1.05, 150)*m], cmap='Reds', vmax=200, norm=LogNorm(), # alpha=0.5 # make the plot looks ugly with lots of white lines ) + v = snr_v['direct']['mu']/snr_v['direct']['s'] axs[l, k].text(0.4, 0.15, f'SNR: {v:.0f}', transform = axs[l, k].transAxes) @@ -1312,9 +1331,11 @@ def inspect_correction(params, gain=None): # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--') # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + axs[l, k].set_ylim([0.95*m, 1.05*m]) + for k in range(3): - for l in range(3): - axs[l, k].set_ylim([0.95, 1.05]) + #for l in range(3): + # axs[l, k].set_ylim([0.95, 1.05]) if gain: axs[2, k].set_xlabel('#ph (10$^6$)') else: