| 1 | # Plot per-residue DMS scores for several assays in one figure, one row per assay, where the
|
|---|
| 2 | # horizontal axis is residue number. For each score, sum the DMS scores at least 2 SD above the
|
|---|
| 3 | # mean synonymous (wild-type) score for one trace, and at least 2 SD below it for another trace.
|
|---|
| 4 | # This is a prototype. If it looks useful, the plot could be added to ChimeraX, where
|
|---|
| 5 | # mousing over or clicking could highlight or color associated residues on a structure.
|
|---|
| 6 | #
|
|---|
| 7 | # Open this script in ChimeraX after opening the DMS data to make the plot.
|
|---|
| 8 |
|
|---|
def create_trace_plot(session, mutation_set_name = None, sdev = 2, score_prefix = 'effect_'):
    '''
    Open a tool window plotting per-residue score sums for each score whose name
    starts with score_prefix.

    session            ChimeraX session.
    mutation_set_name  name of the mutation set, passed unchanged to mutation_scores().
    sdev               number of standard deviations from the synonymous-mutation mean
                       that a score must reach to be included in a sum.
    score_prefix       only scores whose names start with this prefix are plotted; the
                       prefix is stripped to form each row's label.

    Returns the TracePlot, or None (after logging a warning) if no score name
    starts with score_prefix.
    '''
    from chimerax.mutation_scores.ms_data import mutation_scores
    mset = mutation_scores(session, mutation_set_name)
    res_types = {}  # Filled by compute_score_sums(): residue number -> wild-type 1-letter code.
    nprefix = len(score_prefix)
    score_sums = {score_name[nprefix:]: compute_score_sums(mset, score_name, res_types, sdev)
                  for score_name in mset.score_names() if score_name.startswith(score_prefix)}
    if not score_sums:
        # Plotting zero rows would fail inside matplotlib, so report the problem instead.
        session.logger.warning(f'No mutation scores with names starting with "{score_prefix}" to plot.')
        return None
    return TracePlot(score_sums, res_types)
|
|---|
| 17 |
|
|---|
def compute_score_sums(mset, score_name, res_types, sdev):
    '''
    Sum, per residue, the scores that lie far from the synonymous-mutation scores.

    The thresholds are the mean of the synonymous mutation scores (to_aa == from_aa)
    plus or minus sdev times their standard deviation. For each residue number, scores
    >= the upper threshold are added to psum and scores <= the lower threshold are
    added to nsum. A residue with no score beyond a threshold is absent from that dict.

    mset        mutation set providing score_values(score_name).
    score_name  name of the score to sum.
    res_types   dict updated in place, mapping residue number to its from_aa code.
    sdev        number of standard deviations that defines the thresholds.

    Returns (psum, nsum), two dicts mapping residue number to summed score. If there
    are no synonymous mutations, the thresholds are undefined and both dicts are
    empty (res_types is still filled in).
    '''
    score_values = mset.score_values(score_name)
    syn_values = [value for res_num, from_aa, to_aa, value in score_values.all_values()
                  if to_aa == from_aa]

    # Without synonymous values, mean/std of an empty list would be NaN and warn;
    # handle that case explicitly instead.
    have_thresholds = len(syn_values) > 0
    if have_thresholds:
        from numpy import mean, std
        m, d = mean(syn_values), std(syn_values)
        smin, smax = m - sdev*d, m + sdev*d

    by_residue = score_values.values_by_residue_number  # res number -> list of (from_aa, to_aa, value)
    psum, nsum = {}, {}
    for res_num, values in by_residue.items():
        for from_aa, to_aa, value in values:
            res_types[res_num] = from_aa
            if not have_thresholds:
                continue
            if value >= smax:
                psum[res_num] = psum.get(res_num, 0) + value
            elif value <= smin:
                nsum[res_num] = nsum.get(res_num, 0) + value
    return psum, nsum
|
|---|
| 34 |
|
|---|
| 35 | from chimerax.interfaces.graph import Graph
|
|---|
class TracePlot(Graph):
    '''
    Tool window with one bar-chart row per score, stacked vertically and sharing the
    residue-number axis. Red bars show each residue's sum of scores above the upper
    threshold and blue bars its sum of scores below the lower threshold.
    '''
    def __init__(self, score_sums, res_types, session = None):
        '''
        score_sums  dict mapping row label to (psum, nsum); each is a dict of
                    residue number -> summed score.
        res_types   dict mapping residue number to wild-type amino acid code; its keys
                    are the residues placed on the x axis.
        session     ChimeraX session. If None, the global 'session' that ChimeraX
                    defines when running this file as a script is used.
        '''
        if session is None:
            session = globals().get('session')
            if session is None:
                raise ValueError('TracePlot requires a ChimeraX session')
        # No graph nodes or edges: only Graph's tool window and matplotlib figure are used.
        nodes, edges = [], []
        Graph.__init__(self, session, nodes, edges,
                       tool_name = 'Mutation score traces', title = 'Mutation score traces',
                       hide_ticks = True, zoom_axes = 'x', translate_axes = 'x')
        # NOTE(review): presumably disables Graph's node/edge context menu — confirm.
        self.tool_window.fill_context_menu = None
        self._draw_plot(score_sums, res_types)

    def _draw_plot(self, score_sums, res_types):
        '''Draw one row per score: red bars for psum and blue bars for nsum at each residue number.'''
        res_nums = sorted(res_types.keys())
        fig = self.figure
        # The rows are new subplots, so hide the ticks and scales of the axes Graph created.
        self.axes.get_xaxis().set_visible(False)
        self.axes.get_yaxis().set_visible(False)
        if not score_sums or not res_nums:
            # Figure.subplots() rejects nrows = 0 and there is nothing to put on the x axis.
            self.canvas.draw()
            return

        nrows = len(score_sums)
        # squeeze = False makes subplots() return a 2-D array even for a single row,
        # so a single score is indexed the same way as several.
        axes = fig.subplots(nrows = nrows, ncols = 1, sharex = True, squeeze = False)[:, 0]
        fig.suptitle('Per-residue sum of scores above (red) and below (blue) wild-type', fontsize=12, y=0.95)
        fig.subplots_adjust(right = .96)

        # Bars sit at their residue numbers so the tick values match the 'Residue number'
        # label. Width-1 bars are centered on x, hence the half-residue margins.
        xmin, xmax = res_nums[0] - 0.5, res_nums[-1] + 0.5
        for row, (ax, (score_name, (psum, nsum))) in enumerate(zip(axes, score_sums.items())):
            ax.set_frame_on(False)
            ax.set_xlim(xmin, xmax)
            if row == nrows - 1:
                ax.set_xlabel('Residue number')
            else:
                ax.get_xaxis().set_visible(False)  # Only the bottom row shows the residue axis.
            ax.set_ylabel(score_name, rotation=0, loc='top', labelpad=5)
            ax.yaxis.label.set_position((0, .3))
            ax.tick_params(axis='y', which='both', left=False, labelleft=False)
            ax.bar(res_nums, [psum.get(r, 0) for r in res_nums], color = 'red', width = 1)
            ax.bar(res_nums, [nsum.get(r, 0) for r in res_nums], color = 'blue', width = 1)
        self.canvas.draw()
|
|---|
| 75 |
|
|---|
# Script entry point: 'session' is not defined in this file; ChimeraX supplies it
# as a global when it opens this file as a Python script.
create_trace_plot(session = session)
|
|---|