# vim: set expandtab shiftwidth=4 softtabstop=4:

# === UCSF ChimeraX Copyright ===
# Copyright 2016 Regents of the University of California.
# All rights reserved.  This software provided pursuant to a
# license agreement containing restrictions on its disclosure,
# duplication and use.  For details see:
# http://www.rbvi.ucsf.edu/chimerax/docs/licensing.html
# This notice must be embedded in or attached to all copies,
# including partial copies, of the software or any revisions
# or derivations thereof.
# === UCSF ChimeraX Copyright ===

def mlp(session, atoms=None, method="fauchere", spacing=1.0, max_distance=5.0, nexp=3.0,
        color=True, palette=None, range=None, map=False):
    '''Display Molecular Lipophilic Potential for the given atoms.

    Parameters
    ----------
    atoms : Atoms
        Color surfaces for these atoms using MLP map.  Only amino acid residues are used.
        Default is all atoms in the session.
    method : 'dubost','fauchere','brasseur','buckingham','type5'
        Distance dependent function to use for calculation.  Default 'fauchere'.
    spacing : float
        Grid spacing, default 1 Angstrom.
    max_distance : float
        Maximum distance from atom to sum lipophilicity.  Default 5 Angstroms.
    nexp : float
        The buckingham method uses this numerical exponent.  Default 3.
    color : bool
        Whether to color molecular surfaces. They are created if they don't yet exist.
        When true, a separate map is computed for the atoms of each surface.
    palette : Colormap
        Color palette for coloring surfaces.
        Default is lipophilicity colormap (orange lipophilic, blue lipophobic).
    range : 2-tuple of float
        Range of lipophilicity values defining ends of color map.  Default is -20,20
        when the palette does not specify its own values.
    map : bool
        Whether to open a volume model of lipophilicity values.

    Raises
    ------
    UserError
        If the specified atoms contain no amino acid residues.
    '''
    if atoms is None:
        from chimerax.atomic import all_atoms
        atoms = all_atoms(session)

    from chimerax.atomic import Residue
    patoms = atoms[atoms.residues.polymer_types == Residue.PT_AMINO]
    if len(patoms) == 0:
        from chimerax.core.errors import UserError
        raise UserError('mlp: no amino acids specified')

    if palette is None:
        from chimerax.core.colors import BuiltinColormaps
        cmap = BuiltinColormaps['lipophilicity']
    else:
        cmap = palette
    # Only impose a default range when the colormap has no explicit values.
    if range is None and not cmap.values_specified:
        range = (-20,20)

    if color:
        # Color surfaces by lipophilicity, computing surfaces if not already created.
        from chimerax.surface import surface, color_surfaces_by_map_value
        for s in surface(session, patoms):
            satoms = s.atoms
            # Map name uses the first whitespace-separated word of the surface name.
            name = 'mlp ' + s.name.split(maxsplit=1)[0]
            v = mlp_map(session, satoms, method, spacing, max_distance, nexp, name, open_map = map)
            color_surfaces_by_map_value(satoms, map = v, palette = cmap, range = range)
    else:
        # No coloring: one map over all selected amino acid atoms.
        mlp_map(session, patoms, method, spacing, max_distance, nexp, 'mlp map', open_map = map)
            

def register_mlp_command(logger):
    '''Register the "mlp" command, which calls mlp(); logger is passed to register().'''
    from chimerax.core.commands import register, CmdDesc, FloatArg, EnumOf, BoolArg, ColormapArg, ColormapRangeArg
    from chimerax.atomic import AtomsArg
    # Keyword names match the parameter names of mlp(); the method choices
    # match the names dispatched on by mlp_sum.
    desc = CmdDesc(optional=[('atoms', AtomsArg)],
                   keyword=[('spacing', FloatArg),
                            ('max_distance', FloatArg),
                            ('method', EnumOf(['dubost','fauchere','brasseur','buckingham','type5'])),
                            ('nexp', FloatArg),
                            ('color', BoolArg),
                            ('palette', ColormapArg),
                            ('range', ColormapRangeArg),
                            ('map', BoolArg),
                            ],
                   synopsis='display molecular lipophilic potential for selected models')
    register('mlp', desc, mlp, logger=logger)

def mlp_map(session, atoms, method, spacing, max_dist, nexp, name, open_map):
    '''Compute an MLP grid for atoms and wrap it in a volume model.

    Parameters
    ----------
    session : Session
    atoms : Atoms
        Atoms contributing to the lipophilic potential.
    method : str
        Distance function name (see mlp).
    spacing : float
        Grid spacing in Angstroms.
    max_dist : float
        Maximum atom distance summed (see mlp's max_distance).
    nexp : float
        Exponent used by the buckingham method.
    name : str
        Name given to the grid data / volume.
    open_map : bool
        Whether to add the volume to the session and show its dialog.

    Returns
    -------
    Volume
        The volume holding the lipophilicity values.
    '''
    data, bounds = calculatefimap(atoms, method, spacing, max_dist, nexp)

    # data is a 3-d array indexed (z, y, x), as allocated in calculatefimap.
    # The grid origin is the lower end of each axis range in bounds.
    origin = tuple(xmin for xmin,xmax in bounds)
    step = (spacing, spacing, spacing)
    from chimerax.map.data import ArrayGridData
    g = ArrayGridData(data, origin, step, name = name)
    g.polar_values = True
    from chimerax.map import volume_from_grid_data
    v = volume_from_grid_data(g, session, open_model = open_map, show_dialog = open_map)
    v.update_drawings()  # Compute surface levels
    # Dark cyan and dark goldenrod surface colors (RGBA, 0-1 scale).
    v.set_parameters(surface_colors = [(0, 139/255, 139/255, 1), (184/255, 134/255, 11/255, 1)])
    return v

#
# Code below is a modified version of pyMLP, eliminating most of the code
# (unneeded PDB file parsing, dx file writing, ...) and optimizing the calculation speed.
#

class Defaults(object):
    """Constants for the MLP calculation.

    Both attributes are rebuilt every time the class is instantiated.

    gridmargin : float
        Padding in Angstroms added on each side of the atom coordinate
        range when sizing the grid (read by calculatefimap).
    fidatadefault : dict
        Maps residue name -> {atom name: fi}, where fi is the lipophilic
        atomic potential of that atom (looked up by assignfi).  Values
        presumably come from the pyMLP tables this code derives from —
        confirm against pyMLP before editing.
    """

    def __init__(self):
        # Grid padding in Angstroms around the atom coordinates.
        self.gridmargin = 10.0
        self.fidatadefault = {                    #Default fi table
 'ALA': {'CB': 0.5595,    #fi : lipophilic atomic potential
         'C': -0.0702,
         'CA': -0.0971,
         'O': 0.0067,
         'N': -0.5549},
 'ARG': {'C': -0.0702,
         'CA': -0.0971,
         'CB':  0.4112,
         'CG':  0.4112,
         'CD':  0.1016,
         'CZ':  0.5442,
         'N': -0.5549,
         'NE': -0.0825,
         'NH1': -0.5055,
         'NH2': -0.5055,
         'O': 0.0067},
 'ASN': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.1248,
         'CG': -0.0702,
         'N': -0.5549,
         'ND2': -0.6285,
         'O': 0.0067,
         'OD1': 0.0067},
 'ASP': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.1248,
         'CG': -0.0702,
         'N': -0.5549,
         'O': 0.0067,
         'OD1': -0.3787,
         'OD2': -0.3787},
 'CYS': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.1016,
         'N': -0.5549,
         'O': 0.0067,
         'SG': 0.5710},
 'GLN': {'C': -0.0702,
         'CA': -0.0971,
         'CB':  0.4112,
         'CG':  0.1248,
         'CD': -0.0702,
         'N': -0.5549,
         'NE2': -0.6285,
         'O': 0.0067,
         'OE1': 0.0067},
 'GLU': {'C': -0.0702,
         'CA': -0.0971,
         'CB':  0.4112,
         'CG':  0.1248,
         'CD': -0.0702,
         'N': -0.5549,
         'O': 0.0067,
         'OE1': -0.3787,
         'OE2': -0.3787},
 'GLY': {'C': -0.0702,
         'CA': -0.1118,
         'O': 0.0067,
         'N': -0.5549},
 'HIS': {'C': -0.0702,
         'CA': -0.0971,
         'CB':  0.1248,
         'CG': 0.2661,
         'CD2': 0.5785,
         'CE1': 0.2043,
         'N': -0.5549,
         'ND1': -0.2060,
         'NE2': -0.2060,
         'O': 0.0067},
 'HYP': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': -0.0096,
         'CD': 0.1016,
         'N': -0.4813,
         'O': 0.0067,
         'OD1': -0.4003},
 'ILE': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.0585,
         'CG1': 0.5462,
         'CG2': 0.7620,
         'CD1': 0.7620,
         'N': -0.5549,
         'O': 0.0067},
 'LEU': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.0660,
         'CD1': 0.7620,
         'CD2': 0.7620,
         'N': -0.5549,
         'O': 0.0067},
 'LYS': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.5462,
         'CD': 0.5462,
         'CE': 0.1016,
         'NZ': -0.7335,
         'N': -0.5549,
         'O': 0.0067},
 'MET': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.1016,
         'CE': 0.2223,
         'N': -0.5549,
         'O': 0.0067,
         'SD': 0.6206},
 'MSE': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.1016,
         'CE': 0.2223,
         'N': -0.5549,
         'O': 0.0067,
         'SE': 0.6901},
 'UNK': {'C': -0.0702,
         'CA': -0.0971,
         'N': -0.5549,
         'O':  0.0067},
 'ACE': {'C': -0.0702,
         'CH3': 0.1299,
         'O': 0.0067},
 'NME': {'N': -0.5549,
         'C': 0.2223},
 'NH2': {'N': -0.6285},
 'PCA': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.1248,
         'CD': -0.0702,
         'N': -0.5549,
         'O': 0.0067,
         'OE': 0.0067},
 'PHE': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.1792,
         'CD1': 0.3650,
         'CD2': 0.3650,
         'CE1': 0.3650,
         'CE2': 0.3650,
         'CZ': 0.3650,
         'N': -0.5549,
         'O': 0.0067},
 'PRO': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.4112,
         'CD': 0.1016,
         'N': -0.4813,
         'O': 0.0067},
 'SER': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.1016,
         'N': -0.5549,
         'O': 0.0067,
         'OG': -0.4003},
 'THR': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.0086,
         'CG2': 0.5595,
         'N': -0.5549,
         'O': 0.0067,
         'OG1': -0.4003},
 'TRP': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.1248,
         'CG': 0.1792,
         'CD1': 0.5185,
         'CD2': 0.1792,
         'CE2': 0.1839,
         'CE3': 0.3650,
         'CH2': 0.3650,
         'CZ2': 0.3650,
         'CZ3': 0.3650,
         'N': -0.5549,
         'NE1': 0.0823,
         'O': 0.0067},
 'TYR': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.4112,
         'CG': 0.1792,
         'CD1': 0.3650,
         'CD2': 0.3650,
         'CE1': 0.3650,
         'CE2': 0.3650,
         'CZ': 0.1839,
         'N': -0.5549,
         'O': 0.0067,
         'OH': -0.0563},
 'VAL': {'C': -0.0702,
         'CA': -0.0971,
         'CB': 0.0585,
         'CG1': 0.7620,
         'CG2': 0.7620,
         'N': -0.5549,
         'O': 0.0067}}

def assignfi(fidata, atoms):
    """Look up the fi (lipophilic atomic potential) parameter of each atom.

    Parameters
    ----------
    fidata : dict
        Maps residue name -> {atom name: fi value}, e.g. Defaults().fidatadefault.
    atoms : Atoms
        Atoms whose ``residues.names`` and ``names`` are looked up in fidata.

    Returns
    -------
    numpy float32 array of length len(atoms).  Atoms whose residue name or
    atom name is not in fidata get 0, so they do not contribute to the potential.
    """
    from numpy import zeros, float32
    # zeros rather than empty: atoms in residues missing from the table
    # (e.g. modified amino acids) must get 0, not uninitialized memory.
    fi = zeros((len(atoms),), float32)
    for i, (rname, aname) in enumerate(zip(atoms.residues.names, atoms.names)):
        rfidata = fidata.get(rname)
        if rfidata:
            fi[i] = rfidata.get(aname, 0)
    return fi

def _griddimcalc(listcoord, spacing, gridmargin):
    """Determination of the grid dimension"""
    coordmin = min(listcoord) - gridmargin
    coordmax = max(listcoord) + gridmargin
    adjustment = ((spacing - (coordmax - coordmin)) % spacing) / 2.
    coordmin = coordmin - adjustment
    coordmax = coordmax + adjustment
    ngrid = int(round((coordmax - coordmin) / spacing))
    return coordmin, coordmax, ngrid

def calculatefimap(atoms, method, spacing, max_dist, nexp):
    """Compute Molecular Lipophilic Potential values on a grid around atoms.

    Parameters
    ----------
    atoms : Atoms
        Atoms supplying scene coordinates and fi parameters.
    method : str
        Distance function name, passed to the compiled mlp_sum.
    spacing : float
        Grid spacing in Angstroms.
    max_dist : float
        Maximum atom distance summed (see mlp's max_distance); passed to mlp_sum.
    nexp : float
        Exponent used by the buckingham method; passed to mlp_sum.

    Returns
    -------
    pot : numpy float32 array of shape (nz+1, ny+1, nx+1), indexed (z, y, x)
    bounds : [[xmin, xmax], [ymin, ymax], [zmin, zmax]] grid extents
    """
    # Build the constant tables once and reuse them for margin and fi values.
    defaults = Defaults()
    gridmargin = defaults.gridmargin

    # Grid extents (Angstroms), padded by gridmargin on each side.
    xyz = atoms.scene_coords
    xmingrid, xmaxgrid, nxgrid = _griddimcalc(xyz[:,0], spacing, gridmargin)
    ymingrid, ymaxgrid, nygrid = _griddimcalc(xyz[:,1], spacing, gridmargin)
    zmingrid, zmaxgrid, nzgrid = _griddimcalc(xyz[:,2], spacing, gridmargin)
    bounds = [[xmingrid, xmaxgrid],
              [ymingrid, ymaxgrid],
              [zmingrid, zmaxgrid]]
    origin = (xmingrid, ymingrid, zmingrid)

    fi = assignfi(defaults.fidatadefault, atoms)

    # _griddimcalc returns interval counts, so each axis has n+1 points.
    from numpy import zeros, float32
    pot = zeros((nzgrid+1, nygrid+1, nxgrid+1), float32)
    from ._mlp import mlp_sum
    # mlp_sum writes its results into pot (its return value is unused).
    mlp_sum(xyz, fi, origin, spacing, max_dist, method, nexp, pot)

    return pot, bounds

def mlp_sum(xyz, fi, origin, spacing, max_dist, method, nexp, pot):
    '''Pure-Python reference implementation of the MLP grid sum.

    For each grid point, computes the distance to every atom, applies the
    chosen distance function weighted by fi, and stores the sum in pot.
    calculatefimap in this file uses the compiled mlp_sum from ._mlp instead.

    Parameters
    ----------
    xyz : (N, 3) array of atom coordinates.
    fi : (N,) array of per-atom lipophilic potentials.
    origin : (x0, y0, z0) coordinates of grid index (0, 0, 0).
    spacing : float
        Grid spacing.
    max_dist : float
        Accepted for signature compatibility; not used here (every atom
        contributes to every grid point).
    method : 'dubost','fauchere','brasseur','buckingham','type5'
    nexp : float
        Exponent passed to the distance function (used by buckingham).
    pot : 3-d array indexed (z, y, x), filled in place.

    Raises
    ------
    ValueError
        If method is not one of the known names.
    '''
    if method == 'dubost':
        computemethod = _dubost
    elif method == 'fauchere':
        computemethod = _fauchere
    elif method == 'brasseur':
        computemethod = _brasseur
    elif method == 'buckingham':
        computemethod = _buckingham
    elif method == 'type5':
        computemethod = _type5
    else:
        # Report the requested name; computemethod is unbound on this path.
        raise ValueError('Unknown lipophilicity method %s\n' % method)

    from numpy import empty, float32, subtract, sqrt
    grid_pt = empty((3,), float32)
    dxyz = empty((len(xyz),3), float32)
    nz,ny,nx = pot.shape
    x0,y0,z0 = origin
    for k in range(nz):
        grid_pt[2] = z0 + k * spacing
        for j in range(ny):
            grid_pt[1] = y0 + j * spacing
            for i in range(nx):
                # Evaluate the distance between the grid point and each atom.
                grid_pt[0] = x0 + i * spacing
                subtract(xyz, grid_pt, dxyz)
                dxyz *= dxyz
                # dist is a view of column 0 of dxyz; overwriting it is safe
                # because dxyz is recomputed for the next grid point.
                dist = dxyz[:,0]
                dist += dxyz[:,1]
                dist += dxyz[:,2]
                sqrt(dist, dist)
                pot[k,j,i] = computemethod(fi, dist, nexp)

def _dubost(fi, d, n):
    return (100 * fi / (1 + d)).sum()

def _fauchere(fi, d, n):
    from numpy import exp
    return (100 * fi * exp(-d)).sum()

def _brasseur(fi, d, n):
    #3.1 division is there to remove any units in the equation
    #3.1A is the average diameter of a water molecule (2.82 -> 3.2)
    from numpy import exp
    return (100 * fi * exp(-d/3.1)).sum()

def _buckingham(fi, d, n):
    return (100 * fi / (d**n)).sum()

def _type5(fi, d, n):
    from numpy import exp, sqrt
    return (100 * fi * exp(-sqrt(d))).sum()
