# -*- coding: utf-8 -*-
""" TZPG simple calculator for SCS.

    Interactive widget to calculate beam sizes and positions at the sample and
    detector planes of the SCS instrument.

    Copyright (2019-2021) SCS Team.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import PathPatch, Rectangle, Polygon
from matplotlib.colors import hsv_to_rgb
from matplotlib.path import Path

import ipywidgets as widgets
from ipywidgets import HBox, VBox
from IPython.display import display

from TZPGcalc.GeoBeams import GeoBeams


# Number of membranes per row and per column of the displayed membranes
# array (SampleN x SampleN membranes are drawn in the sample plane).
SampleN = 7

# Database of existing zone plate parameters, keyed by the name shown in the
# 'Type' dropdown. Units follow the corresponding widgets: design energy in
# [eV], zone plate width/height and off-axis distance in [mm], grating angle
# in [mrad], focal lengths in [m].
TZPG_db = {
    'Custom': {
        'design_nrj': 860, 'TZPGwH': 1, 'TZPGwV': 1, 'TZPGoffaxis': 0.75,
        'grating': 3.8, 'F_x': 0.25, 'F_y': 0.25, '3beams': True,
    },
    # element-named zone plates share the same geometry and differ only by
    # their design energy
    **{
        element: {
            'design_nrj': design_nrj, 'TZPGwH': 0.8, 'TZPGwV': 0.8,
            'TZPGoffaxis': 0.55, 'grating': 3.1, 'F_x': 0.25, 'F_y': 0.25,
            '3beams': True,
        }
        for element, design_nrj in [('O', 530), ('Fe', 715), ('Co', 785),
                                    ('Ni', 860), ('Cu', 927), ('Gd', 1210)]
    },
}


class TZPGcalc():
    """Interactive calculator for the SCS TZPG beam-splitting setup.

    Building an instance creates the ipywidgets controls and a matplotlib
    figure with the sample and detector planes. It then draws the beams,
    detector and sample for the default settings and displays the controls.
    The drawing is recomputed when the 'Update' button is clicked, or when a
    zone plate is chosen in the 'Type' dropdown.
    """

    # (zone plate order, grating order) of the six beams, in the row order
    # of the spot size / spot center arrays and tables
    _ORDERS = [(0, -1), (0, 0), (0, 1), (1, -1), (1, 0), (1, 1)]
    # row labels of the spot size and spot center tables, matching _ORDERS
    _BEAM_NAMES = ['F0G-1', 'F0G0', 'F0G1', 'F1G-1', 'F1G0', 'F1G1']

    def __init__(self):
        self.geo_beams = GeoBeams()
        self.initWidgets()
        self.initFig()
        self.init_beam_transport()

        # spot sizes and spot centers of all beams, indexed as
        # [plane 'det'/'sam'][energy 'L'/'H'][beam row, H/V column]
        self.SpotSizes = self._zero_spot_tables()
        self.SpotCenters = self._zero_spot_tables()

        self.UpdateFig()
        display(self.control)

    @staticmethod
    def _zero_spot_tables():
        """Return a fresh {'det'|'sam': {'L'|'H': (6, 2) zeros}} dict."""
        return {plane: {nrj: np.zeros((6, 2)) for nrj in ['L', 'H']}
                for plane in ['det', 'sam']}

    def init_beam_transport(self):
        """Reset the beam transport widgets to the GeoBeams defaults.

        The mirror focal lengths are shown in [m]. The exit slit and IHF
        widths are stored in [m] by GeoBeams and shown in [um].
        """
        defaults = GeoBeams().elems
        for v in ['fVFM', 'fHFM']:
            self.widgets[v].value = 1.0*defaults[v]
        # the slit widgets hold integers: round explicitly instead of letting
        # the widget truncate e.g. 99.99999999 down to 99
        for v in ['EXw', 'IHFw']:
            self.widgets[v].value = int(round(1e6*defaults[v]))

    def _add_beam_patches(self, ax, zero_color, side_color):
        """Add the six beam Polygons to `ax` and return them keyed by name.

        The zeroth zone plate order beams (F0*) are translucent black. The
        first zone plate order beams use `zero_color` for the zeroth grating
        order and `side_color` for the +1/-1 grating orders.
        """
        styles = {
            'F0G0': ("black", 0.4),
            'F0G1': ("black", 0.4),
            'F0G-1': ("black", 0.4),
            'F1G0': (zero_color, 0.7),
            'F1G1': (side_color, 0.7),
            'F1G-1': (side_color, 0.7),
        }
        return {
            name: ax.add_patch(
                Polygon([(0, 0)], facecolor=color, alpha=alpha, lw=None))
            for name, (color, alpha) in styles.items()
        }

    def initFig(self):
        "Creates a figure for the sample plane and detector plane images."

        plt.close('TZPGcalc')
        self.fig, (self.ax_sam, self.ax_det) = plt.subplots(
            1, 2, num='TZPGcalc', figsize=(6, 3))

        # display scale
        self.scale = 1e3  # displayed distances in [mm]

        self.ax_sam.set_title('Sample plane')
        self.ax_det.set_title('Detector plane')

        for ax in (self.ax_sam, self.ax_det):
            ax.set_aspect('equal')
        self.ax_sam.set_xlim([-2, 2])
        self.ax_sam.set_ylim([-2, 2])
        self.ax_det.set_xlim([-35, 35])
        self.ax_det.set_ylim([-20, 50])

        # red (low energy) and blue (high energy) shifted colors of the beams
        c_rr = hsv_to_rgb([10/360, 100/100, 70/100])
        c_rb = hsv_to_rgb([220/360, 100/100, 70/100])
        c_gr = hsv_to_rgb([10/360, 100/100, 100/100])
        c_gb = hsv_to_rgb([220/360, 100/100, 100/100])

        self.samBeamsL = self._add_beam_patches(self.ax_sam, c_rr, c_gr)
        self.detBeamsL = self._add_beam_patches(self.ax_det, c_rr, c_gr)
        self.samBeamsH = self._add_beam_patches(self.ax_sam, c_rb, c_gb)
        self.detBeamsH = self._add_beam_patches(self.ax_det, c_rb, c_gb)

        self.detLines = {
            'module': self.ax_det.add_patch(
                Rectangle((0, 0), 1, 1, fill=False, facecolor='k')),
            'Vfilter': self.ax_det.add_patch(
                Rectangle((0, 0), 1, 1, facecolor="green", alpha=0.4)),
            'Hfilter': self.ax_det.add_patch(
                Rectangle((0, 0), 1, 1, facecolor="green", alpha=0.4)),
            'diamond': self.ax_det.add_patch(
                Rectangle((-8, -8), 16, 16, facecolor="green", alpha=0.4,
                          angle=45))
        }

        # SampleN x SampleN membranes and their etched outlines
        self.sampleLines = {}
        self.etchLines = {}
        for k in range(SampleN*SampleN):
            self.sampleLines[k] = self.ax_sam.add_patch(
                Rectangle((0, 0), 1, 1, fill=False, facecolor='k'))
            self.etchLines[k] = self.ax_sam.add_patch(
                Rectangle((0, 0), 1, 1, fill=False, facecolor='k',
                          alpha=0.4, ls='--'))

        # Flat Liquid Jet
        self.FLJ_lines = {}
        self.FlatLiquidJet()

    def FlatLiquidJet(self):
        """Draw the outline of a Flat Liquid Jet in the sample plane.

        The outline is made of four cubic Bezier curves: two for one side,
        mirrored about the jet axis for the other side. The jet has length
        'FLJ_L' and reaches its full width 'FLJ_W' at the fraction 'FLJ_mf'
        of its length, measured from its bottom end. The width is squeezed by
        the cosine of the sample incidence angle. The jet is centered on
        ('samX', 'samY'). Any previously drawn outline is replaced. All
        coordinates are in [mm] (display units).
        """
        sw = self.widgets
        HW = 0.5*sw['FLJ_W'].value  # half width [mm]
        L = sw['FLJ_L'].value  # [mm]
        mf = sw['FLJ_mf'].value  # fraction of L at which the jet is widest
        incidence = np.deg2rad(sw['samIncidence'].value)  # [rad]
        ox = sw['samX'].value  # [mm]
        oy = -L/2 + sw['samY'].value  # [mm], bottom end of the jet

        # incidence angle squeezes sample
        ci = np.cos(incidence)

        # control points (P0, P1, P2, P3) of the two curves of one side
        right = [
            [(0*ci, L + oy),
             (HW/2*ci, L + oy),
             (HW*ci, 0.5*(1+mf)*L + oy),
             (HW*ci, mf*L + oy)],
            [(HW*ci, mf*L + oy),
             (HW*ci, 0.5*mf*L + oy),
             (0, 0 + oy),
             (0, 0 + oy)],
        ]
        # the other side is the mirror image about the jet axis
        left = [[(-x, y) for x, y in curve] for curve in right]

        codes = [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4]

        for k, curve in enumerate(right + left):
            # apply the horizontal offset
            path = Path([(x + ox, y) for x, y in curve], codes)
            if k in self.FLJ_lines:
                self.FLJ_lines[k].remove()

            self.FLJ_lines[k] = self.ax_sam.add_patch(
                PathPatch(path, alpha=0.4, facecolor='none', lw=2))

    def RectUpdate(self, rect, xLeft, yBottom, xRight, yTop):
        """Updates the position and size of the given Rectangle.

        Coordinates are in [m] and are converted to display units with
        `self.scale`. The corners may be given in either order: the
        rectangle always spans the interval between them.

        Inputs
        ------
            rect: Rectangle to update
            xLeft: x position of the left corner
            yBottom: y position of the bottom corner
            xRight: x position of the right corner
            yTop: y position of the top corner
        """
        x0 = min(xLeft, xRight)
        y0 = min(yBottom, yTop)
        xw = np.abs(xLeft - xRight)
        yw = np.abs(yTop - yBottom)

        rect.set_xy((self.scale*x0, self.scale*y0))
        rect.set_height(self.scale*yw)
        rect.set_width(self.scale*xw)

    def PolyUpdate(self, poly, xLeft, yBottom, xRight, yTop):
        """Updates the corner position of a Polygon.

        Coordinates are in [m] and are converted to display units with
        `self.scale`.

        Inputs
        ------
            poly: regular Polygon to update
            xLeft: x position of the left corner
            yBottom: y position of the bottom corner
            xRight: x position of the right corner
            yTop: y position of the top corner
        """
        corners = [[xLeft, yBottom],
                   [xLeft, yTop],
                   [xRight, yTop],
                   [xRight, yBottom]]
        poly.set_xy(self.scale*np.array(corners))

    def UpdateBeams(self, Beams, img, conf):
        """Updates the position and size of the beams.

        Also stores the extent and the center of each beam, in [um], in
        `self.SpotSizes[img['type']][conf['Energy']]` and
        `self.SpotCenters[img['type']][conf['Energy']]`.

        Inputs
        ------
            Beams: dictionary of f'F{f}G{g}' Polygon for f = 0 and 1 zone
                plate order and g = +1, 0 and -1 grating order
            img: dictionary of parameters for the imaging plane: 'type'
                ('sam' or 'det'), 'z' position [m] and 'incidence' [rad]
            conf: dictionary of parameters for optics: 'Energy' ('L' or
                'H'), focal lengths 'F_x' and 'F_y' [m] and grating angle
                'theta_grating' [rad]
        """
        # shortcut
        sge = self.geo_beams.elems

        sge['fBOZ_x'] = conf['F_x']
        sge['fBOZ_y'] = conf['F_y']
        sge['theta_grating'] = conf['theta_grating']

        # imaging plane: normal n rotated from the z axis by the incidence
        # angle in the x-z plane, going through p0 on the z axis
        n = np.array([np.sin(img['incidence']),
                      0,
                      np.cos(img['incidence'])])
        p0 = np.array([0, 0, img['z']])
        # zone plate off-axis offset (0th order offset)
        offset = np.array([0, sge['offaxis']])

        sizes = self.SpotSizes[img['type']][conf['Energy']]
        centers = self.SpotCenters[img['type']][conf['Energy']]
        for i, (f, g) in enumerate(self._ORDERS):
            beam = f'F{f}G{g}'
            corners = self.scale*(
                self.geo_beams.plane_image(p0, n, f, g) - offset)

            Beams[beam].set_xy(corners)
            # extent and center along H (column 0) and V (column 1):
            # corners are in [mm], tables in [um]
            vmax = np.max(corners, axis=0)
            vmin = np.min(corners, axis=0)
            sizes[i] = 1e3*(vmax - vmin)
            centers[i] = 1e3*0.5*(vmax + vmin)

        # 3 beams configuration: the zeroth grating order beams are only
        # shown when enabled
        b = self.widgets['3beams'].value
        Beams['F0G0'].set_visible(b)
        Beams['F1G0'].set_visible(b)

    def DetectorUpdate(self, Xoff, Yoff):
        """Draws DSSC detector module with filter mask.

        Inputs
        ------
            Xoff: x offset [m]
            Yoff: y offset [m]
        """
        # x module axis is vertical, y module axis is horizontal
        # the module 15 is +0.91 mm vertical from the beam and
        # 4.233 mm horizontal from the beam
        offset_h = 4.233e-3  # [m]
        offset_v = 0.91e-3  # [m]

        moduleHw = 256*0.236e-3  # [m]
        moduleVw = 128*0.204e-3  # [m]

        filterW = 7e-3  # [m]
        filterL = 160e-3  # [m]
        diamondW = 16e-3  # [m]

        self.RectUpdate(self.detLines['module'],
                        -moduleHw - offset_h + Xoff, offset_v + Yoff,
                        -offset_h + Xoff, moduleVw + offset_v + Yoff)
        self.RectUpdate(self.detLines['Vfilter'],
                        -filterW/2 + Xoff, -filterL/2 + Yoff,
                        filterW/2 + Xoff, filterL/2 + Yoff)
        self.RectUpdate(self.detLines['Hfilter'],
                        -filterL/2 + Xoff, -filterW/2 + Yoff,
                        filterL/2 + Xoff, filterW/2 + Yoff)

        # the diamond is a square rotated by 45 deg about its anchor corner:
        # placing that corner half a diagonal below (Xoff, Yoff) centers it
        self.detLines['diamond'].set_xy((
            self.scale*Xoff, self.scale*(Yoff - diamondW/2*np.sqrt(2))))

    def SampleUpdate(self):
        """Shows only the selected sample type and redraws it.

        Raises
        ------
            ValueError: if the 'SampleType' widget holds an unknown value.
        """
        sample_type = self.widgets['SampleType'].value
        if sample_type == 'Membranes Array':
            membranes = True
        elif sample_type == 'Flat Liquid Jet':
            membranes = False
        else:
            raise ValueError('Sample type must be either "Membranes Array" or '
                             '"Flat Liquid Jet"')

        for k in range(SampleN*SampleN):
            self.sampleLines[k].set_visible(membranes)
            self.etchLines[k].set_visible(membranes)

        for line in self.FLJ_lines.values():
            line.set_visible(not membranes)

        if membranes:
            self.MembraneSampleUpdate()
        else:
            self.FlatLiquidJet()

    def MembraneSampleUpdate(self):
        """Draws the SampleN x SampleN membranes array.

        Reads these widgets: membrane width 'samw' and pitch 'samp' [mm],
        sample offsets 'samX' and 'samY' [mm], substrate thickness
        'samthickness' [um], etching angle from the surface 'samEtchAngle'
        [deg] and sample normal angle 'samIncidence' [deg].

        Each membrane is drawn as a square of side w. The dashed square, of
        side w + 2*thickness/tan(etch angle), outlines the etched facets.
        Horizontal coordinates are squeezed by cos(incidence), and the
        etched outline is shifted horizontally by thickness*sin(incidence).
        """
        sw = self.widgets
        w = sw['samw'].value*1e-3  # [m]
        p = sw['samp'].value*1e-3  # [m]
        Xoff = sw['samX'].value*1e-3  # [m]
        Yoff = sw['samY'].value*1e-3  # [m]
        thickness = sw['samthickness'].value*1e-6  # [m]
        etch_angle = np.deg2rad(sw['samEtchAngle'].value)  # [rad]
        incidence = np.deg2rad(sw['samIncidence'].value)  # [rad]

        # Si etching angle
        wp = w + 2*thickness/np.tan(etch_angle)

        # incidence angle squeezes sample and etch lines
        # and induces an apparent shift off the etch lines
        ci = np.cos(incidence)
        thsi = thickness*np.sin(incidence)

        # NOTE(review): here Yoff is subtracted and Xoff is squeezed by ci,
        # whereas FlatLiquidJet adds samY and does not squeeze samX --
        # confirm which sample-stage convention is intended.
        half = (SampleN-1)//2
        j = 0
        for r in range(-half, half+1):
            for c in range(-half, half+1):
                self.RectUpdate(
                    self.sampleLines[j],
                    ci*(r*p - w/2 + Xoff), c*p - w/2 - Yoff,
                    ci*(r*p + w/2 + Xoff), c*p + w/2 - Yoff)
                self.RectUpdate(
                    self.etchLines[j],
                    ci*(r*p - wp/2 + Xoff)+thsi, c*p - wp/2 - Yoff,
                    ci*(r*p + wp/2 + Xoff)+thsi, c*p + wp/2 - Yoff)
                j += 1

    def UpdateFig(self):
        """Updates the figure and the tables with the current widget values.

        Recomputes the grating pitch and outer zone width labels for the
        design energy, then the beam footprints at the low and high energies
        in the sample and detector planes. It also refreshes the spot tables
        (sample plane), the detector and the sample drawing.
        """
        # shortcuts
        sw = self.widgets
        sge = self.geo_beams.elems

        # we calculate the optics for the central wavelength
        nrjL = sw['nrjL'].value  # [eV]
        nrjH = sw['nrjH'].value  # [eV]
        nrjD = sw['nrjD'].value  # [eV]
        wlL = 1240/nrjL*1e-9
        wlH = 1240/nrjH*1e-9
        wlD = 1240/nrjD*1e-9

        F_x = sw['F_x'].value  # [m] Nominal horiz. BOZ focal length
        F_y = sw['F_y'].value  # [m] Nominal vert. BOZ focal length
        theta_grating = sw['grating'].value*1e-3  # [rad]
        sampleZ = sw['SAMZ'].value*1e-3  # [m]
        samIncidence = np.deg2rad(sw['samIncidence'].value)  # [rad]
        detectorZ = sw['detZ'].value*1e-3  # [m]

        sge['WH'] = sw['TZPGwH'].value*1e-3  # [m]
        sge['WV'] = sw['TZPGwV'].value*1e-3  # [m]
        sge['offaxis'] = -sw['TZPGoffaxis'].value*1e-3  # [m]
        sge['EXw'] = sw['EXw'].value*1e-6  # [m]
        sge['IHFw'] = sw['IHFw'].value*1e-6  # [m]
        sge['fVFM'] = sw['fVFM'].value  # [m]
        sge['fHFM'] = sw['fHFM'].value  # [m]

        d_nominal = wlD/np.sin(theta_grating)
        sw['d_label'].value = (
            f'Grating Pitch:{int(np.round(d_nominal*1e9))} nm')

        # zone plate radius at the point further away from the optical axis
        rn = np.sqrt((sge['WH']/2.0)**2 +
                     (sge['WV']/2.0 + np.abs(sge['offaxis']))**2)
        dr_nominal_x = wlD * F_x / (2*rn)
        dr_nominal_y = wlD * F_y / (2*rn)
        sw['dr_label_x'].value = (
            f'Outer Zone Plate width dr for Horiz. '
            f'focus:{int(np.round(dr_nominal_x*1e9))} nm')
        sw['dr_label_y'].value = (
            f'for Vert. focus:{int(np.round(dr_nominal_y*1e9))} nm')

        # Optics properties (focal length and grating angle) for the
        # low energy and high energy photon
        confL = {'Energy': 'L',
                 'F_x': 2*rn*dr_nominal_x/wlL,
                 'F_y': 2*rn*dr_nominal_y/wlL,
                 'theta_grating': np.arcsin(wlL/d_nominal)}
        confH = {'Energy': 'H',
                 'F_x': 2*rn*dr_nominal_x/wlH,
                 'F_y': 2*rn*dr_nominal_y/wlH,
                 'theta_grating': np.arcsin(wlH/d_nominal)}

        # Configuration for imaging plane
        sam = {'type': 'sam', 'z': sampleZ, 'incidence': samIncidence}
        det = {'type': 'det', 'z': detectorZ, 'incidence': 0}

        # update the beams
        self.UpdateBeams(self.samBeamsL, sam, confL)
        self.UpdateBeams(self.detBeamsL, det, confL)
        self.UpdateBeams(self.samBeamsH, sam, confH)
        self.UpdateBeams(self.detBeamsH, det, confH)

        # update the spot size and spot center tables (sample plane)
        for v in ['L', 'H']:
            for name, values in [('SpotSize', self.SpotSizes),
                                 ('SpotCenter', self.SpotCenters)]:
                df = pd.DataFrame(values['sam'][v],
                                  index=self._BEAM_NAMES,
                                  columns=['H (um)', 'V (um)'])
                self.widgets[f'{name}{v}'].value = df.to_html(
                    float_format='{:.2f}'.format)

        # update the detector
        detXoff = self.widgets['detX'].value*1e-3  # [m]
        detYoff = self.widgets['detY'].value*1e-3  # [m]
        self.DetectorUpdate(detXoff, detYoff)

        # update the sample
        self.SampleUpdate()

    @staticmethod
    def _bounded_text(widget_cls, description, value, vmin, vmax, step):
        """Return a bounded text box of class `widget_cls` with common style.
        """
        return widget_cls(
            value=value,
            min=vmin,
            max=vmax,
            step=step,
            description=description,
            style={'description_width': 'initial'},
            layout={},  # e.g. {'max_width': '300px'}
        )

    def _spot_tables_box(self):
        """Create the spot size / spot center HTML tables; return their box.
        """
        columns = []
        for title, name in [('Spot Size', 'SpotSize'),
                            ('Spot center', 'SpotCenter')]:
            energies = []
            for label, v in [('Low energy', 'L'), ('High energy', 'H')]:
                self.widgets[f'{name}{v}'] = widgets.HTML()
                energies.append(VBox([widgets.Label(value=label),
                                      self.widgets[f'{name}{v}']]))
            columns.append(VBox([widgets.Label(value=title),
                                 HBox(energies)]))
        return HBox(columns)

    def _source_tab(self):
        """Create the source and KB mirror widgets; return their box."""
        sw = self.widgets
        box = self._bounded_text
        IntBox = widgets.BoundedIntText
        FloatBox = widgets.BoundedFloatText

        self.Reset = widgets.Button(description='Reset')
        self.Reset.on_click(lambda b: self.init_beam_transport())

        # arguments: widget class, description, value, min, max, step
        sw['EXw'] = box(IntBox, 'Exit Slit (um):', 100, 0, 2000, 1)
        sw['IHFw'] = box(IntBox, 'IHF width (um):', 200, 0, 2000, 1)
        sw['fVFM'] = box(FloatBox, 'VFM focal length (m):', 0, 0, 10, 0.01)
        sw['fHFM'] = box(FloatBox, 'HFM focal length (m):', 0, 0, 10, 0.01)

        return VBox([self.Reset, sw['EXw'], sw['IHFw'],
                     sw['fVFM'], sw['fHFM']])

    def _tzpg_tab(self):
        """Create the zone plate and energy widgets; return their box."""
        sw = self.widgets
        box = self._bounded_text
        IntBox = widgets.BoundedIntText
        FloatBox = widgets.BoundedFloatText
        style = {'description_width': 'initial'}

        sw['Type'] = widgets.Dropdown(
            options=list(TZPG_db),
            value='Custom',
            description='Type:',
            style=style,
            disabled=False
        )

        def TZPGtype(change):
            """Load the selected zone plate parameters and redraw."""
            v = TZPG_db[change.new]
            sw['nrjD'].value = v['design_nrj']
            for key in ['TZPGwH', 'TZPGwV', 'TZPGoffaxis', 'grating',
                        'F_x', 'F_y', '3beams']:
                sw[key].value = v[key]

            # necessary to recompute grating pitch and outer zone plate width
            self.UpdateFig()

        sw['Type'].observe(TZPGtype, names='value')

        # arguments: widget class, description, value, min, max, step
        # nominal zone plate focal lengths [m]
        sw['F_x'] = box(FloatBox, 'Focal length (m) Horiz.:',
                        0.25, 0, 1, 0.01)
        sw['F_y'] = box(FloatBox, 'Vert.:', 0.25, 0, 1, 0.01)

        # photon energies [eV]
        sw['nrjL'] = box(IntBox, 'Low:', 840, 450, 3200, 1)
        sw['nrjH'] = box(IntBox, 'High:', 880, 450, 3200, 1)
        sw['nrjD'] = box(IntBox, 'Design energy (eV):', 860, 450, 3200, 1)

        # zone plate geometry [mm] and grating angle [mrad]
        sw['TZPGwV'] = box(FloatBox, 'Height:', 1.0, .1, 3.0, 0.05)
        sw['TZPGwH'] = box(FloatBox, 'Width:', 1.0, .1, 3.0, 0.05)
        sw['TZPGoffaxis'] = box(FloatBox, 'Off axis (mm):',
                                0.75, .0, 2.0, 0.05)
        sw['grating'] = box(FloatBox, 'Grating angle (mrad):',
                            3.8, 1., 10.0, 0.05)

        sw['3beams'] = widgets.Checkbox(
            value=True,
            description='3 beams:',
            style=style,
            layout={}
        )

        # labels filled in by UpdateFig
        sw['dr_label_x'] = widgets.Label(value='dr_x')
        sw['dr_label_y'] = widgets.Label(value='dr_y')
        sw['d_label'] = widgets.Label(value='dr')

        return VBox([
            sw['Type'],
            HBox([sw['nrjD'], sw['F_x'], sw['F_y']]),
            HBox([sw['grating'], sw['TZPGoffaxis'], sw['3beams']]),
            HBox([sw['d_label']]),
            HBox([sw['dr_label_x'], sw['dr_label_y']]),
            HBox([widgets.Label(value='Optics (mm):'),
                  sw['TZPGwH'], sw['TZPGwV']]),
            HBox([widgets.Label(value='Energy range (eV):'),
                  sw['nrjL'], sw['nrjH']])
        ])

    def _sample_tab(self):
        """Create the sample widgets; return their box."""
        sw = self.widgets
        box = self._bounded_text
        FloatBox = widgets.BoundedFloatText

        sw['SampleType'] = widgets.Dropdown(
            options=['Membranes Array', 'Flat Liquid Jet'],
            value='Membranes Array',
            description='Sample type:',
            style={'description_width': 'initial'},
            layout={}
        )

        # arguments: widget class, description, value, min, max, step
        sw['SAMZ'] = box(FloatBox, 'Sample Z (mm):', 30., -10., 180.0, .1)

        # membranes
        sw['samw'] = box(FloatBox, 'width (mm):', .5, 0.01, 2.0, .01)
        sw['samp'] = box(FloatBox, 'pitch (mm):', 1.0, 0.01, 2.0, .01)
        sw['samthickness'] = box(FloatBox, 'Substrate thickness (um):',
                                 381, 1, 1000, 1)
        sw['samEtchAngle'] = box(FloatBox, 'etch angle (deg):',
                                 54.74, 0, 90, 0.01)

        # Flat Liquid Jet: width, length, and fraction of the length at
        # which the jet is widest
        sw['FLJ_W'] = box(FloatBox, 'Flat Liquid Jet width (mm):',
                          1.0, 0, 5, 0.1)
        sw['FLJ_L'] = box(FloatBox, 'of length (mm):', 4.6, 0, 15, 0.1)
        sw['FLJ_mf'] = box(FloatBox, 'at:', 0.75, 0, 1, 0.01)

        # sample offsets [mm] and orientation [deg]
        sw['samX'] = box(FloatBox, 'X:', 0., -10, 10, 0.01)
        sw['samY'] = box(FloatBox, 'Y:', 0., -10, 10, 0.01)
        sw['samIncidence'] = box(FloatBox, 'Sample normal (deg):',
                                 0, 0, 90, 1)

        return VBox([
            sw['SAMZ'],
            sw['SampleType'],
            HBox([widgets.Label(value='Membranes array, '),
                  sw['samw'], sw['samp']]),
            HBox([sw['samthickness'], sw['samEtchAngle']]),
            HBox([sw['FLJ_W'], sw['FLJ_mf'], sw['FLJ_L']]),
            HBox([widgets.Label(value='Sample Offset (mm), '),
                  sw['samX'], sw['samY']]),
            HBox([sw['samIncidence']])
        ])

    def _detector_tab(self):
        """Create the detector position widgets; return their box."""
        sw = self.widgets
        box = self._bounded_text
        FloatBox = widgets.BoundedFloatText

        # detector position [mm]; UpdateFig converts it to [m]
        sw['detZ'] = box(FloatBox, 'Z:', 2000., 1000, 5800, 1)
        sw['detX'] = box(FloatBox, 'X:', 34.5, -50, 50, 0.5)
        sw['detY'] = box(FloatBox, 'Y:', -2., -50, 50, 0.5)

        return VBox([HBox([widgets.Label(value='Detector (mm), '),
                           sw['detZ'], sw['detX'], sw['detY']])])

    def initWidgets(self):
        """ Creates the necessary interactive widget controls.

        Fills the `self.widgets` dictionary and builds `self.control`, the
        box of controls displayed by the constructor.
        """
        self.Update = widgets.Button(description='Update')
        self.Update.on_click(lambda b: self.UpdateFig())

        self.widgets = {}

        spot_box = self._spot_tables_box()
        tabs = [
            ('Source and KBS', self._source_tab()),
            ('Beam splitting Off axis Zone plate', self._tzpg_tab()),
            ('Sample', self._sample_tab()),
            ('Detector', self._detector_tab()),
            ('Spot sizes and centers', spot_box),
        ]

        accordions = []
        for title, content in tabs:
            accordion = widgets.Accordion([content])
            accordion.set_title(0, title)
            accordion.selected_index = 0
            accordions.append(accordion)

        self.control = VBox([self.Update, *accordions])