Ticket #16088: modelcif_pae.py

File modelcif_pae.py, 10.7 KB (added by Tom Goddard, 13 months ago)

Contributed code from Gerardo Tauriello

Line 
# Read pairwise residue scores from a ModelCIF file and plot them in ChimeraX.
2
def modelcif_pae(session, structure, json_output_path = None, metric_id = None, default_score = 100):
    """Plot the pairwise residue scores of a ModelCIF structure as a PAE plot.

    The scores are read with read_pairwise_scores(), written to a JSON file
    with write_json_pae_file() and shown by running the "alphafold pae"
    command on that file.

    Parameters:
      session           ChimeraX session used to run the plot command.
      structure         structure whose ModelCIF file holds the scores.
      json_output_path  where to write the JSON file; if None a temporary
                        file is used and removed again before returning.
      metric_id         id of the pairwise metric to plot (passed on to
                        read_pairwise_scores()).
      default_score     value used for residue pairs without a score.

    Raises whatever read_pairwise_scores() raises (UserError when the
    structure has no file or no matching scores).
    """
    matrix = read_pairwise_scores(structure, metric_id = metric_id, default_score = default_score)

    temp_path = None
    if json_output_path is None:
        import tempfile
        # Only reserve a unique name here and close the handle right away:
        # a still-open delete-on-close temporary file cannot be reopened by
        # name on Windows, which is what write_json_pae_file() does.
        with tempfile.NamedTemporaryFile(prefix = 'modelcif_pae_', suffix = '.json',
                                         delete = False) as temp:
            temp_path = temp.name
        json_output_path = temp_path

    try:
        write_json_pae_file(json_output_path, matrix)

        # Open PAE plot
        from chimerax.core.commands import run, quote_if_necessary
        open_cmd = f'alphafold pae #{structure.id_string} file {quote_if_necessary(json_output_path)}'
        run(session, open_cmd)
    finally:
        # The temporary file is only needed while the command runs.
        if temp_path is not None:
            import os
            try:
                os.remove(temp_path)
            except OSError:
                # Best effort: a leftover temp file must not hide the real result/error.
                pass
18
def read_pairwise_scores(structure, metric_id = None, default_score = 100):
    """Return the pairwise residue score matrix stored in a structure's ModelCIF file.

    Parameters:
      structure      structure with a 'filename' attribute naming its ModelCIF file.
      metric_id      id of the local-pairwise metric to use.  If None, the first
                     metric of type "PAE" is used, or the first listed metric
                     if there is no PAE metric.
      default_score  value for residue pairs that have no score in the file.

    Returns a float32 numpy array of shape (num_residues, num_residues) indexed
    by the position of each residue in structure.residues.

    Raises UserError if the structure has no file, the file has no pairwise
    score table, the metric id is unknown, or no scores use that metric id.
    """
    from chimerax.core.errors import UserError

    if not hasattr(structure, 'filename'):
        raise UserError(f'Structure {structure} has no associated file')

    # fetch data from ModelCIF
    values, metrics = read_ma_qa_metric_local_pairwise_table(structure.filename)
    if values is None:
        raise UserError(f'Structure file {structure.filename} contains no pairwise residue scores (i.e. no table "ma_qa_metric_local_pairwise")')

    # use only the scores with the given metric id.
    metrics_dict = {
        met_id: (met_name, met_type)
        for met_id, met_name, met_type in metrics
    }
    if metric_id is None and len(metrics) > 0:
        # look for PAE type
        for met_id, _, met_type in metrics:
            if met_type == "PAE":
                metric_id = met_id
                break
        else:
            # fall back to whatever is the first one; the for/else makes sure
            # this only happens when the loop found no PAE metric
            metric_id = metrics[0][0]
    if metric_id in metrics_dict:
        met_name, met_type = metrics_dict[metric_id]
        # NOTE: 'session' is not a parameter; this relies on the module-level
        # 'session' name that this script also uses when registering commands.
        session.logger.info(
            f"Displaying local-pairwise metric with ID {metric_id} "
            f"with type '{met_type}' and name '{met_name}'"
        )
    else:
        raise UserError(f'Structure file {structure.filename} has no metric for metric id "{metric_id}"')
    values = [v for v in values if v[5] == metric_id]
    if len(values) == 0:
        raise UserError(f'Structure file {structure.filename} has no scores for metric id "{metric_id}"')

    # fill matrix
    matrix_index = {(r.chain_id, r.number): ri for ri, r in enumerate(structure.residues)}

    nr = structure.num_residues
    from numpy import empty, float32
    matrix = empty((nr, nr), float32)
    matrix[:] = default_score

    # NOTE(review): the lookup below raises KeyError if a (chain id, residue
    # number) pair from the file does not match a residue of the structure.
    for _model_id, chain_id_1, res_num_1, chain_id_2, res_num_2, _met_id, metric_value in values:
        r1 = matrix_index[(chain_id_1, int(res_num_1))]
        r2 = matrix_index[(chain_id_2, int(res_num_2))]
        matrix[r1, r2] = float(metric_value)

    return matrix
72
def read_ma_qa_metric_local_pairwise_table(path):
    """Get relevant data from ModelCIF file.

    Scores are collected from three places: the ma_qa_metric_local_pairwise
    table of the file itself, associated files listed in
    ma_entry_associated_files, and files inside associated ZIP archives listed
    in ma_associated_archive_file_details.  Associated files are read by
    calling this function again on them.

    Returns tuple (values, metrics) with
    - values = list of pairwise metric values stored as tuple
      (model_id, chain_id_1, res_num_1, chain_id_2, res_num_2, metric_id, metric_value)
    - metrics = list of available pairwise metrics stored as tuple
      (metric_id, metric_name, metric_type)
    Returns (None, None) if the tables cannot be read from the file.
    """
    import os
    from chimerax.mmcif import get_cif_tables
    table_names = [
        'ma_qa_metric_local_pairwise',
        'ma_entry_associated_files',
        'ma_associated_archive_file_details',
        'ma_qa_metric'
    ]
    try:
        tables = get_cif_tables(path, table_names)
    except TypeError:
        # Bug in get_cif_tables() results in TypeError if table not present.  Ticket #16054
        return None, None
    if len(tables) != 4:
        return None, None

    # check different ways of storing QE
    values = []
    (ma_qa_metric_local_pairwise,
     ma_entry_associated_files,
     ma_associated_archive_file_details,
     ma_qa_metric) = tables

    # get info on available pairwise metrics
    if ma_qa_metric.num_rows() > 0:
        field_names = ['id', 'mode', 'name', 'type']
        all_metrics = ma_qa_metric.fields(field_names)
        metrics = [
            (metric_id, metric_name, metric_type)
            for metric_id, metric_mode, metric_name, metric_type in all_metrics
            if metric_mode == "local-pairwise"
        ]
    else:
        # no metrics here
        metrics = []

    # option 1: it's directly in the file
    if ma_qa_metric_local_pairwise.num_rows() > 0:
        field_names = ['model_id',
                       'label_asym_id_1', 'label_seq_id_1',
                       'label_asym_id_2', 'label_seq_id_2',
                       'metric_id', 'metric_value']
        values.extend(ma_qa_metric_local_pairwise.fields(field_names))

    # (file path, is temporary file) pairs; read at the end, temporaries are
    # removed in the finally clause even if something fails on the way
    qa_files_to_load = []
    try:
        # option 2: it's in ma_entry_associated_files
        associated_files = []
        if ma_entry_associated_files.num_rows() > 0 \
           and ma_entry_associated_files.has_field("file_content"):
            field_names = ['id', 'file_url', 'file_content']
            associated_files = ma_entry_associated_files.fields(field_names)
        zip_files = {}
        for file_id, file_url, file_content in associated_files:
            if file_content == "local pairwise QA scores":
                tmp_file_path, to_delete = fetch_file_url(file_url, path, "_tmp.cif")
                if tmp_file_path is not None:
                    qa_files_to_load.append((tmp_file_path, to_delete))
            elif file_content == "archive with multiple files":
                zip_files[file_id] = file_url

        # option 3: it's listed in ma_associated_archive_file_details
        associated_qa_files = []
        if ma_associated_archive_file_details.num_rows() > 0 \
           and ma_associated_archive_file_details.has_field("file_content"):
            field_names = ['archive_file_id', 'file_path', 'file_content']
            row_fields = ma_associated_archive_file_details.fields(field_names)
            for archive_file_id, file_path, file_content in row_fields:
                if file_content == "local pairwise QA scores":
                    if archive_file_id in zip_files:
                        associated_qa_files.append((zip_files[archive_file_id], file_path))
                    else:
                        # NOTE: relies on the module-level 'session' name of this script
                        session.logger.warning(f'Structure file {path} has faulty archive_file_id for {file_path}.')
        for zip_file_url, file_path in associated_qa_files:
            tmp_zip_file_path, to_delete = fetch_file_url(zip_file_url, path, "_tmp.zip")
            if tmp_zip_file_path is None:
                continue
            import zipfile
            import tempfile
            try:
                with zipfile.ZipFile(tmp_zip_file_path, 'r') as zip_ref:
                    if file_path in zip_ref.namelist():
                        temp_file = tempfile.NamedTemporaryFile(suffix="_tmp.cif", delete=False)
                        # register before writing so a failed write is still cleaned up
                        qa_files_to_load.append((temp_file.name, True))
                        with temp_file as f:
                            f.write(zip_ref.read(file_path))
                    else:
                        session.logger.warning(f'Could not find {file_path} in ZIP file from {zip_file_url}.')
            finally:
                # the archive is no longer needed once the member is extracted
                if to_delete:
                    os.remove(tmp_zip_file_path)

        # load from temp. files
        for tmp_file_path, _to_delete in qa_files_to_load:
            new_values, _ = read_ma_qa_metric_local_pairwise_table(tmp_file_path)
            if new_values is not None:
                values.extend(new_values)
    finally:
        for tmp_file_path, to_delete in qa_files_to_load:
            if to_delete:
                os.remove(tmp_file_path)

    return values, metrics
184
def fetch_file_url(file_url, structure_path, save_name):
    """Locate a file referenced by a structure file, locally or on the web.

    file_url is first tried as a path relative to the directory of
    structure_path.  If no such file exists it is fetched with ChimeraX's
    fetch_file() using save_name as the name to save it under.

    Returns (path, downloaded): the path to the file and a flag that is True
    if the file was fetched (and so may be deleted by the caller after use).
    Returns (None, False) and logs a warning if fetching fails.
    """
    import os
    dir_name = os.path.dirname(structure_path)
    file_path = os.path.join(dir_name, file_url)
    if os.path.exists(file_path):
        return file_path, False

    # try to get from the web
    from chimerax.core.fetch import fetch_file
    try:
        # NOTE: 'session' is not a parameter; this relies on the module-level
        # 'session' name that this script also uses when registering commands.
        tmp_file_path = fetch_file(
            session, file_url,
            name=f'remote associated file for {structure_path}',
            save_name=save_name, save_dir=None
        )
    except Exception:
        # Fetching is best effort, but do not swallow KeyboardInterrupt or
        # SystemExit the way a bare "except:" would.
        session.logger.warning(f"Failed to load {file_url} for {structure_path}")
        return None, False
    return tmp_file_path, True
206
def write_json_pae_file(json_output_path, matrix):
    """Write a square score matrix to a file in the JSON AlphaFold PAE format.

    The file content looks like
      {"pae": [[ 17.14, 18.75, ... ], [ 5.32, 8.23, ... ], ... ]}
    with every value formatted with two decimals.  The number of rows and of
    columns written is matrix.shape[0].
    """
    size = matrix.shape[0]
    row_texts = []
    for i in range(size):
        cells = ['%.2f' % matrix[i,j] for j in range(size)]
        row_texts.append('[ ' + ', '.join(cells) + ' ]')
    payload = '{"pae": [' + ', '.join(row_texts) + ']}'
    with open(json_output_path, 'w') as out:
        out.write(payload)
217
def modelarchive_open(session, ma_id):
    """Fetch a ModelArchive entry, open it and try to show its PAE plot.

    The mmCIF file for ma_id is fetched from www.modelarchive.org, opened with
    the "open" command, and "modelcif pae" is then run on the first opened
    model.  A UserError from the PAE command is ignored, since an entry does
    not need to have pairwise scores.
    """
    from chimerax.core.fetch import fetch_file
    from chimerax.core.commands import run, quote_if_necessary
    from chimerax.core.errors import UserError

    # fetch file
    session.logger.info(f"Fetching {ma_id} from ModelArchive...")
    cif_path = fetch_file(
        session,
        f"https://www.modelarchive.org/doi/10.5452/{ma_id}.cif",
        name=f'MA {ma_id}',
        save_name=f"{ma_id}.cif",
        save_dir="ModelArchive",
    )

    # open file
    opened_models = run(session, f'open {quote_if_necessary(cif_path)}')
    model_spec = opened_models[0].id_string

    # try to add PAE
    try:
        run(session, f'modelcif pae #{model_spec}')
    except UserError:
        # ok for it not to have PAE...
        pass
241
def register_command(logger):
    """Register the "modelcif pae" and "modelarchive open" commands.

    logger is passed on to ChimeraX's register() for both commands.
    """
    from chimerax.atomic import StructureArg
    from chimerax.core.commands import CmdDesc, register, StringArg, FloatArg, SaveFileNameArg

    # modelcif pae <structure> [metric_id ...] [default_score ...] [json_output_path ...]
    pae_keywords = [
        ('metric_id', StringArg),
        ('default_score', FloatArg),
        ('json_output_path', SaveFileNameArg),
    ]
    pae_desc = CmdDesc(
        required = [('structure', StructureArg)],
        keyword = pae_keywords,
        synopsis = 'Plot ModelCIF pairwise residue scores'
    )
    register('modelcif pae', pae_desc, modelcif_pae, logger=logger)

    # TEST: register extra command to remotely load MA file
    open_desc = CmdDesc(
        required = [('ma_id', StringArg)],
        synopsis = 'Load ModelArchive file with given ID (incl. PAE if available)'
    )
    register('modelarchive open', open_desc, modelarchive_open, logger=logger)
260
# Register the commands when this script is run; 'session' must already be
# defined at module level at that point.
register_command(logger=session.logger)