# Plot per-residue DMS scores for several assays on the same plot where the
# horizontal axis is residue number.  For each score, sum the DMS scores that are at
# least 2 SD above the wild-type (synonymous mutation) mean for one trace, and those
# at least 2 SD below that mean for another trace.
# This is a prototype.  If it looks useful the plot could be added to ChimeraX where
# mousing over or clicking could highlight or color associated residues on a structure.
#
# Open this script in ChimeraX after opening the DMS data to make the plot.

def create_trace_plot(session, mutation_set_name = None, sdev = 2):
    '''
    Build a TracePlot for the mutation set looked up by mutation_set_name.

    Every score whose name starts with 'effect_' contributes one entry: the
    (positive sums, negative sums) pair returned by compute_score_sums(), keyed
    by the score name with that prefix removed.  sdev is passed through to
    compute_score_sums() as the standard deviation multiplier.

    Returns the new TracePlot.
    '''
    from chimerax.mutation_scores.ms_data import mutation_scores
    mset = mutation_scores(session, mutation_set_name)

    prefix = 'effect_'
    # Residue number -> wild-type amino acid 1-letter code.  Filled in as a side
    # effect of each compute_score_sums() call, so it must be complete before plotting.
    wild_type_aa = {}
    sums_by_score = {}
    for name in mset.score_names():
        if name.startswith(prefix):
            sums_by_score[name[len(prefix):]] = compute_score_sums(mset, name, wild_type_aa, sdev)

    return TracePlot(sums_by_score, wild_type_aa)
    
def compute_score_sums(mset, score_name, res_types, sdev):
    '''
    Sum, per residue, the values of one score that lie far from the synonymous mean.

    The mean and standard deviation are taken over the synonymous entries
    (to_aa == from_aa) of the score named score_name.  A value at least sdev
    standard deviations above that mean is added to its residue's positive
    total; a value at least sdev standard deviations below it is added to the
    residue's negative total.  Values in between are ignored.

    res_types is updated in place: every residue number that has at least one
    value is mapped to the from_aa of its entries.

    Returns (psum, nsum), two dictionaries mapping residue number to the summed
    values.  Residues with no qualifying value are absent from the dictionary.
    '''
    import numpy
    scores = mset.score_values(score_name)

    # Synonymous mutations define the wild-type baseline and its spread.
    synonymous = [v for _, aa_from, aa_to, v in scores.all_values() if aa_from == aa_to]
    center = numpy.mean(synonymous)
    cutoff = sdev * numpy.std(synonymous)
    low, high = center - cutoff, center + cutoff

    above, below = {}, {}
    # values_by_residue_number maps residue number to a list of (from_aa, to_aa, value).
    for number, mutations in scores.values_by_residue_number.items():
        for aa_from, _, v in mutations:
            target = above if v >= high else below if v <= low else None
            if target is not None:
                target[number] = target.get(number, 0) + v
            res_types[number] = aa_from
    return above, below

from chimerax.interfaces.graph import Graph
class TracePlot(Graph):
    '''
    Graph tool window with one row of bars per mutation score.

    Each row plots, against residue number, the per-residue sum of scores above
    wild-type (red bars) and the per-residue sum of scores below wild-type
    (blue bars).
    '''

    def __init__(self, score_sums, res_types):
        '''
        score_sums maps a score name to a (psum, nsum) pair of dictionaries, each
        mapping residue number to a summed score.  res_types maps residue number
        to the wild-type amino acid; its keys determine which residues are plotted.
        '''
        # Two distinct lists so nodes and edges are never aliased to one object.
        nodes, edges = [], []
        # NOTE(review): session is not a parameter of this class; it is resolved as a
        # module-level global, presumably the one ChimeraX supplies when it opens
        # this script -- confirm before importing this file as a module.
        Graph.__init__(self, session, nodes, edges,
                       tool_name = 'Mutation score traces', title = 'Mutation score traces',
                       hide_ticks = True, zoom_axes = 'x', translate_axes = 'x')
        self.tool_window.fill_context_menu = None
        self._draw_plot(score_sums, res_types)

    def _draw_plot(self, score_sums, res_types):
        '''
        Create one subplot per entry of score_sums, draw its red and blue bars
        at the residue numbers found in res_types, and redraw the canvas.
        '''
        # Assumes residue numbers are numeric so they can be used as x positions.
        res_nums = sorted(res_types)
        nrows = len(score_sums)
        fig = self.figure
        if nrows > 0:
            # squeeze=False always returns a 2D array of axes.  With the default,
            # a single score yields a bare Axes object that cannot be indexed.
            axes = fig.subplots(nrows=nrows, ncols=1, sharex=True, squeeze=False)[:, 0]
        else:
            # Figure.subplots() rejects zero rows; with no scores only the title is shown.
            axes = []
        fig.suptitle('Per-residue sum of scores above (red) and below (blue) wild-type', fontsize=12, y=0.95)
        fig.subplots_adjust(right = .96)
        self.axes.get_xaxis().set_visible(False)    # Hide ticks and axis scale of the base axes
        self.axes.get_yaxis().set_visible(False)

        for row, (ax, (score_name, (psum, nsum))) in enumerate(zip(axes, score_sums.items())):
            ax.set_frame_on(False)
            if res_nums:
                # Bars are 1 wide and centered on the residue number, so pad the
                # limits by half a bar to keep the first and last bars fully visible.
                ax.set_xlim(res_nums[0] - 0.5, res_nums[-1] + 0.5)
            if row == nrows - 1:
                # Only the bottom row shows the shared x axis.
                ax.set_xlabel('Residue number')
            else:
                ax.get_xaxis().set_visible(False)    # Hide ticks and axis scales
            ax.set_ylabel(score_name, rotation=0, loc='top', labelpad=5)
            ax.yaxis.label.set_position((0, .3))
            ax.tick_params(axis='y', which='both', left=False, labelleft=False)
            # Plot at the actual residue numbers so the tick values match the
            # 'Residue number' axis label.
            heights = [psum.get(res_num, 0) for res_num in res_nums]
            ax.bar(res_nums, heights, color = 'red', width = 1)
            heights = [nsum.get(res_num, 0) for res_num in res_nums]
            ax.bar(res_nums, heights, color = 'blue', width = 1)
        self.canvas.draw()

# Create the trace plot with the default mutation set name and cutoff.
# NOTE(review): session is not defined anywhere in this file; presumably ChimeraX
# supplies it as a global to scripts it opens -- confirm.
create_trace_plot(session)
