def new_residue_from_template(model, template, chain_id, center,
        residue_number=None, insert_code=' ', b_factor=50, precedes=None):
    '''
    Create a new residue based on a template, and add it to the model.

    One new atom is created per template atom (same name and element). The
    template coordinates are translated as a rigid body so that their
    centroid lies at `center`. New atoms are bonded wherever the
    corresponding template atoms are neighbours.

    Args:
        model: structure the residue is added to.
        template: template residue supplying names, elements, coordinates
            and connectivity (e.g. an mmCIF component template).
        chain_id: chain ID of the new residue.
        center: (x, y, z) position for the centroid of the new atoms.
        residue_number: residue number. If None, a number is suggested by
            suggest_new_residue_number_for_ligand() when `chain_id` already
            exists in the model; otherwise 0 is used.
        insert_code: insertion code of the new residue.
        b_factor: B-factor assigned to every new atom.
        precedes: passed through to model.new_residue(precedes=...) to
            control where the residue is placed in the model.

    Returns:
        The newly created residue.
    '''
    import numpy
    from chimerax.atomic.struct_edit import add_bond
    if residue_number is None:
        if chain_id in model.residues.chain_ids:
            residue_number = suggest_new_residue_number_for_ligand(model, chain_id)
        else:
            residue_number = 0
    t_atoms = template.atoms
    t_coords = numpy.array([a.coord for a in t_atoms])
    # Rigid translation: move the template centroid onto `center`.
    t_coords += numpy.asarray(center) - t_coords.mean(axis=0)
    r = model.new_residue(template.name, chain_id, residue_number,
        insert=insert_code, precedes=precedes)
    tatom_to_atom = {}
    for ta, coord in zip(t_atoms, t_coords):
        a = tatom_to_atom[ta] = model.new_atom(ta.name, ta.element)
        a.coord = coord
        a.bfactor = b_factor
        r.add_atom(a)
        # Only bond to neighbours that already exist, so each template bond
        # is created exactly once (when its second atom is made).
        for tn in ta.neighbors:
            n = tatom_to_atom.get(tn)
            if n is not None:
                add_bond(a, n)
    return r


def add_amino_acid_residue(model, resname, prev_res=None, next_res=None,
        chain_id=None, number=None, center=None, insertion_code=' ', b_factor=50,
        occupancy=1, phi=-135, psi=135):
    '''
    Add a new amino acid residue, built from the mmCIF template for
    `resname`, to `model`.

    If `prev_res` is given, the new residue is bonded to its C atom and
    placed after it. If `next_res` is given, the new residue's C is bonded
    to next_res's N atom and placed before it. With neither anchor, the
    residue is left unbonded, with its atom centroid at `center`.

    Args:
        model: structure to add the residue to.
        resname: residue name to look up with mmcif.find_template_residue.
        prev_res: existing residue whose C terminus the new residue extends.
        next_res: existing residue whose N terminus the new residue extends.
        chain_id, number, center: required when no anchor is given. With an
            anchor, chain_id is taken from the anchor, and number defaults to
            the anchor's number + 1 (prev_res) or - 1 (next_res).
        insertion_code: insertion code for the new residue.
        b_factor, occupancy: assigned to every atom of the new residue.
        phi: dihedral (degrees) passed to _find_next_C_position when
            extending from prev_res.
        psi: dihedral (degrees) passed to _find_prev_C_position when
            extending from next_res. NOTE(review): that helper measures the
            dihedral against next_res's N, CA and C atoms. Assuming find_pt's
            usual convention, this sets next_res's phi rather than the new
            residue's psi — confirm.

    Returns:
        The new residue.

    Raises:
        TypeError: bad combination of arguments, or the anchor residue is
            already bonded at the relevant terminus or lacks its C/N atom.
        ValueError: no template could be found for `resname`.
    '''
    import numpy
    session = model.session
    if prev_res is None and next_res is None and (
            not chain_id or number is None or center is None):
        raise TypeError('If no anchor residues are specified, chain ID, '
            'number and center must be provided!')
    if prev_res is not None and next_res is not None:
        raise TypeError('Cannot specify both previous and next residues!')
    insertion_point = None

    if prev_res is not None:
        residues = model.residues
        pri = residues.index(prev_res)
        # The new residue should directly follow prev_res, including when
        # prev_res is the very first residue. If prev_res is the last residue
        # (or is not found), leave insertion_point as None.
        if 0 <= pri < len(residues) - 1:
            insertion_point = residues[pri + 1]
        catom = prev_res.find_atom('C')
        if catom is None:
            raise TypeError('Previous residue has no C atom!')
        for n in catom.neighbors:
            if n.residue != prev_res:
                raise TypeError('This residue already has another bonded to its '
                    'C terminus!')
        chain_id = prev_res.chain_id
        oxt = prev_res.find_atom('OXT')
        if oxt is not None:
            # Also remove any hydrogen attached to OXT so it is not orphaned.
            for h in [a for a in oxt.neighbors if a.element.name == 'H']:
                h.delete()
            oxt.delete()
    elif next_res is not None:
        insertion_point = next_res
        natom = next_res.find_atom('N')
        if natom is None:
            raise TypeError('Next residue has no N atom!')
        for n in natom.neighbors:
            if n.residue != next_res:
                raise TypeError('This residue already has another bonded to its '
                    'N terminus!')
        chain_id = next_res.chain_id
        # next_res's N becomes a peptide nitrogen. Drop the extra terminal
        # hydrogens and rename H1 -> H; proline keeps no H on N.
        for hname in ('H2', 'H3'):
            h = next_res.find_atom(hname)
            if h is not None:
                h.delete()
        h = next_res.find_atom('H1')
        if h is not None:
            h.name = 'H'
        if next_res.name == 'PRO':
            h = next_res.find_atom('H')
            if h is not None:
                h.delete()
    if number is None:
        # Only reachable with an anchor; validated above.
        if prev_res is not None:
            number = prev_res.number + 1
        else:
            number = next_res.number - 1

    from chimerax import mmcif
    tmpl = mmcif.find_template_residue(session, resname)
    if tmpl is None:
        raise ValueError('No template found for residue name "{}"!'.format(resname))
    r = new_residue_from_template(model, tmpl, chain_id, [0,0,0], number,
            insert_code=insertion_code, b_factor=b_factor, precedes=insertion_point)
    # Delete terminal atoms that a residue inside a chain should not carry.
    r.atoms[numpy.isin(r.atoms.names, ['OXT', 'HXT', 'H2', 'H1', 'HN1', 'HN2'])].delete()

    # Translate and rotate residue to (roughly) match the desired position
    if prev_res is None and next_res is None:
        atoms = r.atoms
        atoms.coords = atoms.coords + (numpy.array(center) - atoms.coords.mean(axis=0))
    else:
        from chimerax.geometry import align_points
        from chimerax.atomic.struct_edit import add_bond
        if prev_res is not None:
            add_bond(r.find_atom('N'), prev_res.find_atom('C'))
            n_pos = _find_next_N_position(prev_res)
            ca_pos = _find_next_CA_position(n_pos, prev_res)
            c_pos = _find_next_C_position(ca_pos, n_pos, prev_res, phi)
            target_coords = numpy.array([n_pos, ca_pos, c_pos])
            align_coords = numpy.array([r.find_atom(a).coord for a in ['N', 'CA', 'C']])
        else:
            add_bond(r.find_atom('C'), next_res.find_atom('N'))
            c_pos = _find_prev_C_position(next_res, psi)
            ca_pos = _find_prev_CA_position(c_pos, next_res)
            o_pos = _find_prev_O_position(c_pos, next_res)
            target_coords = numpy.array([c_pos, ca_pos, o_pos])
            align_coords = numpy.array([r.find_atom(a).coord for a in ['C', 'CA', 'O']])

        # Least-squares fit of the three template backbone atoms onto targets.
        tf = align_points(align_coords, target_coords)[0]
        r.atoms.coords = tf*r.atoms.coords
    if r.name in ('GLU', 'ASP'):
        fix_amino_acid_protonation_state(r)
    if r.name == 'PRO':
        r.atoms[r.atoms.names=='H'].delete()

    r.atoms.bfactors = b_factor
    r.atoms.occupancies = occupancy

    return r



def _find_next_N_position(prev_res):
    '''
    Estimate the backbone N position of a residue that will follow `prev_res`.

    Calls find_pt with prev_res's C, CA and O coordinates as reference points,
    a 1.34 Å bond length, a 120° angle and a 180° dihedral. Presumably this
    puts N bonded to C, in the C/CA/O plane and opposite O; confirm against
    find_pt's argument convention.
    '''
    from chimerax.atomic.struct_edit import find_pt
    ref_coords = [prev_res.find_atom(name).coord for name in ('C', 'CA', 'O')]
    return find_pt(ref_coords[0], ref_coords[1], ref_coords[2], 1.34, 120, 180)

def _find_next_CA_position(n_pos, prev_res):
    '''
    Estimate the CA position of the residue following `prev_res`, given its
    N position `n_pos`.

    find_pt is called with n_pos, prev_res's C and prev_res's CA as reference
    points, a 1.48 Å bond, a 124° angle and an omega dihedral of 180°
    (a trans peptide).
    '''
    from chimerax.atomic.struct_edit import find_pt
    prev_c, prev_ca = (prev_res.find_atom(name) for name in ('C', 'CA'))
    return find_pt(n_pos, prev_c.coord, prev_ca.coord, 1.48, 124, 180)

def _find_next_C_position(ca_pos, n_pos, prev_res, phi):
    '''
    Estimate the carbonyl C position of the residue following `prev_res`.

    find_pt is called with the new CA (`ca_pos`), the new N (`n_pos`) and
    prev_res's C as reference points, a 1.53 Å bond, a 120° angle and the
    dihedral `phi`.
    '''
    from chimerax.atomic.struct_edit import find_pt
    prev_c_coord = prev_res.find_atom('C').coord
    return find_pt(ca_pos, n_pos, prev_c_coord, 1.53, 120, phi)


def _find_prev_C_position(next_res, psi):
    '''
    Estimate the carbonyl C position of a residue that will precede `next_res`.

    find_pt is called with next_res's N, CA and C coordinates as reference
    points, a 1.34 Å bond, a 120° angle and the dihedral `psi`.

    NOTE(review): because the reference atoms all belong to next_res, the
    dihedral set here is presumably C(new)-N-CA-C of next_res, i.e. next_res's
    phi rather than the new residue's psi — confirm.
    '''
    from chimerax.atomic.struct_edit import find_pt
    n_coord, ca_coord, c_coord = [next_res.find_atom(name).coord
        for name in ('N', 'CA', 'C')]
    return find_pt(n_coord, ca_coord, c_coord, 1.34, 120, psi)

def _find_prev_CA_position(c_pos, next_res):
    '''
    Estimate the CA position of the residue preceding `next_res`, given its
    carbonyl C position `c_pos`.

    find_pt is called with c_pos, next_res's N and next_res's CA as reference
    points, a 1.53 Å bond, a 120° angle and an omega dihedral of 180°
    (a trans peptide).
    '''
    from chimerax.atomic.struct_edit import find_pt
    next_n, next_ca = (next_res.find_atom(name) for name in ('N', 'CA'))
    return find_pt(c_pos, next_n.coord, next_ca.coord, 1.53, 120, 180)

def _find_prev_O_position(c_pos, next_res):
    '''
    Estimate the carbonyl O position of the residue preceding `next_res`,
    given its carbonyl C position `c_pos`.

    find_pt is called with c_pos, next_res's N and next_res's CA as reference
    points, a 1.22 Å bond, a 120° angle and a 0° dihedral. The 0° dihedral
    keeps O cis to next_res's CA, consistent with the 180° omega used for
    the new CA.
    '''
    from chimerax.atomic.struct_edit import find_pt
    next_n, next_ca = (next_res.find_atom(name) for name in ('N', 'CA'))
    return find_pt(c_pos, next_n.coord, next_ca.coord, 1.22, 120, 0)

def add_aa_cmd(session, residue, resname):
    '''
    Implementation of the ``addaa`` command: add one amino acid of type
    `resname` at the free terminus of a terminal protein residue.

    The extension direction comes from the peptide bond itself, not from
    residue numbering:
    - If the selected residue is bonded through its N, the new residue is
      added after it (at its C terminus).
    - If it is bonded through its C, the new residue is added before it.

    Args:
        session: ChimeraX session (supplied by the command framework).
        residue: Residues collection; must hold exactly one amino acid residue.
        resname: name of the residue to add.

    Raises:
        UserError: if the selection is not a single terminal amino acid.
    '''
    from chimerax.core.errors import UserError
    from chimerax.atomic import Residue
    if len(residue) != 1:
        raise UserError('Please select a single residue!')
    residue = residue[0]
    if residue.polymer_type != Residue.PT_AMINO:
        raise UserError('Selection must be an amino acid from a protein chain!')

    # Names ('N' or 'C') of this residue's atoms that take part in backbone
    # peptide bonds to neighbouring amino acids. Matching atom names, not just
    # elements, excludes side-chain links such as Lys NZ-C isopeptide bonds.
    linked_via = []
    for n in residue.neighbors:
        if n.polymer_type != Residue.PT_AMINO:
            continue
        for b in residue.bonds_between(n):
            if {a.name for a in b.atoms} == {'C', 'N'}:
                own = [a for a in b.atoms if a.residue == residue]
                linked_via.append(own[0].name)
                break
    if len(linked_via) != 1:
        raise UserError('Selection is not a terminal residue!')
    if linked_via[0] == 'N':
        # The neighbour precedes this residue: extend at the C terminus.
        prev_res, next_res = residue, None
    else:
        # The neighbour follows this residue: extend at the N terminus.
        prev_res, next_res = None, residue
    add_amino_acid_residue(residue.structure, resname, prev_res=prev_res, next_res=next_res)

from chimerax.core.commands import (
    register, CmdDesc, 
    StringArg
)
from chimerax.atomic import ResiduesArg

# Command syntax: addaa <residue> <resname>
_addaa_required_args = [
    ('residue', ResiduesArg),
    ('resname', StringArg),
]
desc = CmdDesc(required=_addaa_required_args)
# NOTE(review): `session` is not defined in this file. This presumably relies
# on the file being executed as a ChimeraX script that provides a global
# `session` — confirm.
register('addaa', desc, add_aa_cmd, logger=session.logger)
