# Check whether add_coordsets()/add_coordset() assign the correct atom
# coordinates, including after an atom has been deleted and a replacement added.

import numpy as np

from chimerax.atomic import AtomicStructure

# Element symbols and atom names for the four test atoms, in matching order.
elements = ["O", "C", "H", "H"]
atom_names = ["O1", "C1", "H1", "H2"]

# One (x, y, z) row per atom, in the same order as `elements`/`atom_names`.
# The enclosing brackets already continue the expression across lines,
# so no backslash continuations are needed.
coords_array = np.array([[0.00000,    0.00000,    0.68769],
                         [0.00000,    0.00000,   -0.53966],
                         [0.00000,    0.93947,   -1.13180],
                         [0.00000,   -0.93947,   -1.13180]])

# Second frame: every coordinate shifted by +1 so the two frames differ.
trans_coords = coords_array + 1

# Shape (2, n_atoms, 3): frame 0 is the original geometry, frame 1 the shifted one.
coordsets_array = np.array([coords_array, trans_coords])

# NOTE(review): `session` is not defined in this file; presumably it is the
# ChimeraX session global available when the script runs inside ChimeraX.
mdl = AtomicStructure(session, name="test")
res = mdl.new_residue("res", "a", 1)

# Create each atom, give it its initial position, and put it in the residue.
for ele, name, xyz in zip(elements, atom_names, coords_array):
    atom = mdl.new_atom(name, ele)
    atom.coord = xyz
    res.add_atom(atom)

# Baseline (this works): replace all coordsets while the atom list is still
# the original one. Then, for every coordset, print its stored xyzs followed
# by the per-atom coords reported once that coordset is made active.
mdl.add_coordsets(coordsets_array, replace=True)
for cs_id in mdl.coordset_ids:
    print("coordset", cs_id)
    frame = mdl.coordset(cs_id)
    print(frame.xyzs)
    print("atom coords:")
    mdl.active_coordset_id = cs_id
    # `atom` is deliberately kept as the loop name: later top-level code reads it.
    for atom in mdl.atoms:
        print(atom.coord)
 
# The issue happens when removing and adding atoms: create a replacement for
# the last atom, delete the original, then reload the coordsets.
mdl.active_coordset_id = 1
new_atom = mdl.new_atom(atom_names[-1], elements[-1])
# Attach the *new* atom to the residue. This previously passed `atom`, the
# leftover top-level loop variable, which is an atom already in `res`.
# That left `new_atom` without a residue.
res.add_atom(new_atom)
new_atom.coord = coords_array[3]

# Delete the atom at index 3 of mdl.atoms. Presumably this is the original
# H2, since new_atom was created after it; confirm atom ordering.
mdl.atoms[3].delete()

print("replaced last atom")
mdl.add_coordsets(coordsets_array, replace=True)
for cid in mdl.coordset_ids:
    print("coordset", cid)
    print(mdl.coordset(cid).xyzs)
    print("atom coords:")
    mdl.active_coordset_id = cid
    for atom in mdl.atoms:
        print(atom.coord)

#XXX these values differ
print("coordinates of last atom from atom attribute:", mdl.atoms[-1].coord)
print("coordinates of last atom from coordset:", mdl.coordset(mdl.active_coordset_id).xyzs[-1])

# Make a new coordset array with a position for the deleted atom.
# Each atom's row is placed at its coord index rather than at its position in
# mdl.atoms. After the delete/re-add above, the indices presumably have a gap
# (e.g. [0, 1, 2, 4]). So each frame needs max(index) + 1 rows, and rows not
# owned by any live atom stay zero.
coordset_indices = mdl.atoms.coord_indices

n_frames = len(coordsets_array)
# initial=-1 makes an empty index array yield 0 slots instead of raising.
n_slots = int(np.max(coordset_indices, initial=-1)) + 1
new_coordsets_array = np.zeros((n_frames, n_slots, 3))
# Scatter every frame's per-atom rows into their coord-index slots in one step
# (same result as assigning frame by frame).
new_coordsets_array[:, coordset_indices] = coordsets_array

print("trying to add these coordsets with an index for the deleted atom:")
for cs_id, coordset in enumerate(new_coordsets_array, start=1):
    print(cs_id)
    print(coordset)

#XXX mdl.add_coordsets is not allowed: these coordsets must be n_slots (here 5)
# long to cover every coord index, but mdl.num_atoms == 4.
#mdl.add_coordsets(new_coordsets_array, replace=True)

# Workaround: drop all coordsets and add the re-indexed frames one at a time,
# using 1-based coordset ids as before.
mdl.remove_coordsets()
for cs_id, coordset in enumerate(new_coordsets_array, start=1):
    mdl.add_coordset(cs_id, coordset)

print("coordsets after coordset change:")
for cid in mdl.coordset_ids:
    print("coordset", cid)
    print(mdl.coordset(cid).xyzs)
    print("atom coords:")
    mdl.active_coordset_id = cid
    for atom in mdl.atoms:
        print(atom.coord)
