Ticket #6469: model_1_error

File model_1_error, 9.1 KB (added by davis797@…, 4 years ago)

Added by email2trac

Line 
1Traceback (most recent call last):
2 File "<ipython-input-1-dc30e8128641>", line 577, in <module>
3 run_prediction(sequence)
4 File "<ipython-input-1-dc30e8128641>", line 550, in run_prediction
5 predict_structure(sequence, alignments, deletions, model_name, output_dir)
6 File "<ipython-input-1-dc30e8128641>", line 375, in predict_structure
7 prediction_result = model_runner.predict(processed_feature_dict)
8 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 133, in predict
9 result = self.apply(self.params, jax.random.PRNGKey(0), feat)
10 File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
11 return fun(*args, **kwargs)
12 File "/usr/local/lib/python3.7/dist-packages/jax/_src/api.py", line 435, in cache_miss
13 assert len(avals) == len(out_flat)
14 File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1709, in bind
15 class _TempAxisName:
16 File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1721, in call_bind
17 return type(other) is _TempAxisName and self.id == other.id
18 File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 614, in process_call
19 process_map = process_call
20 File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
21 File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 272, in memoized_fun
22 thread_local.most_recent_entry = None
23 File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
24 File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
25 File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 198, in lower_xla_callable
26 File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
27 File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1798, in trace_to_jaxpr_final
28 File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
29 File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 166, in call_wrapped
30 ans = self.f(*args, **dict(self.params, **kwargs))
31 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 125, in apply_fn
32 out, state = f.apply(params, {}, *args, **kwargs)
33 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 313, in apply_fn
34 out = f(*args, **kwargs)
35 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 63, in _forward_fn
36 ensemble_representations=True)
37 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
38 out = f(*args, **kwargs)
39 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
40 return bound_method(*args, **kwargs)
41 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 379, in __call__
42 (0, prev))
43 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 610, in while_loop
44 val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
45 File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
46 return fun(*args, **kwargs)
47 File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 312, in while_loop
48 outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
49 File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 296, in _create_jaxpr
50 # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
51 File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 210, in wrapper
52 return out
53 File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 203, in cached
54 wrapper.cache_info = memoized.cache_info
55 File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 84, in _initial_style_jaxpr
56 # When staging the branches of a conditional into jaxprs, constants are
57 File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 210, in wrapper
58 return out
59 File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 203, in cached
60 wrapper.cache_info = memoized.cache_info
61 File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 77, in _initial_style_open_jaxpr
62 fun, in_tree, in_avals, primitive_name)
63 File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
64 File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1739, in trace_to_jaxpr_dynamic
65 File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
66 File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 166, in call_wrapped
67 ans = self.f(*args, **dict(self.params, **kwargs))
68 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 605, in pure_body_fun
69 val = body_fun(val)
70 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 370, in <lambda>
71 compute_loss=False)))
72 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 342, in do_call
73 ensemble_representations=ensemble_representations)
74 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
75 out = f(*args, **kwargs)
76 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
77 return bound_method(*args, **kwargs)
78 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 161, in __call__
79 representations = evoformer_module(batch0, is_training)
80 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
81 out = f(*args, **kwargs)
82 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
83 return bound_method(*args, **kwargs)
84 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 1733, in __call__
85 msa_activations = jax.ops.index_add(msa_activations, 0,
86jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'jax.ops' has no attribute 'index_add'
87
88The stack trace below excludes JAX-internal frames.
89The preceding is the original exception that occurred, unmodified.
90
91--------------------
92
93The above exception was the direct cause of the following exception:
94
95Traceback (most recent call last):
96 File "<ipython-input-1-dc30e8128641>", line 550, in run_prediction
97 predict_structure(sequence, alignments, deletions, model_name, output_dir)
98 File "<ipython-input-1-dc30e8128641>", line 375, in predict_structure
99 prediction_result = model_runner.predict(processed_feature_dict)
100 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 133, in predict
101 result = self.apply(self.params, jax.random.PRNGKey(0), feat)
102 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 125, in apply_fn
103 out, state = f.apply(params, {}, *args, **kwargs)
104 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 313, in apply_fn
105 out = f(*args, **kwargs)
106 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 63, in _forward_fn
107 ensemble_representations=True)
108 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
109 out = f(*args, **kwargs)
110 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
111 return bound_method(*args, **kwargs)
112 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 379, in __call__
113 (0, prev))
114 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 610, in while_loop
115 val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
116 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 605, in pure_body_fun
117 val = body_fun(val)
118 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 370, in <lambda>
119 compute_loss=False)))
120 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 342, in do_call
121 ensemble_representations=ensemble_representations)
122 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
123 out = f(*args, **kwargs)
124 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
125 return bound_method(*args, **kwargs)
126 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 161, in __call__
127 representations = evoformer_module(batch0, is_training)
128 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
129 out = f(*args, **kwargs)
130 File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
131 return bound_method(*args, **kwargs)
132 File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 1733, in __call__
133 msa_activations = jax.ops.index_add(msa_activations, 0,
134AttributeError: module 'jax.ops' has no attribute 'index_add'