| 1 | Traceback (most recent call last):
|
|---|
| 2 | File "<ipython-input-1-dc30e8128641>", line 577, in <module>
|
|---|
| 3 | run_prediction(sequence)
|
|---|
| 4 | File "<ipython-input-1-dc30e8128641>", line 550, in run_prediction
|
|---|
| 5 | predict_structure(sequence, alignments, deletions, model_name, output_dir)
|
|---|
| 6 | File "<ipython-input-1-dc30e8128641>", line 375, in predict_structure
|
|---|
| 7 | prediction_result = model_runner.predict(processed_feature_dict)
|
|---|
| 8 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 133, in predict
|
|---|
| 9 | result = self.apply(self.params, jax.random.PRNGKey(0), feat)
|
|---|
| 10 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
|
|---|
| 11 | return fun(*args, **kwargs)
|
|---|
| 12 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/api.py", line 435, in cache_miss
|
|---|
| 13 | assert len(avals) == len(out_flat)
|
|---|
| 14 | File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1709, in bind
|
|---|
| 15 | class _TempAxisName:
|
|---|
| 16 | File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1721, in call_bind
|
|---|
| 17 | return type(other) is _TempAxisName and self.id == other.id
|
|---|
| 18 | File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 614, in process_call
|
|---|
| 19 | process_map = process_call
|
|---|
| 20 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
|
|---|
| 21 | File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 272, in memoized_fun
|
|---|
| 22 | thread_local.most_recent_entry = None
|
|---|
| 23 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
|
|---|
| 24 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
|
|---|
| 25 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 198, in lower_xla_callable
|
|---|
| 26 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
|
|---|
| 27 | File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1798, in trace_to_jaxpr_final
|
|---|
| 28 | File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
|
|---|
| 29 | File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 166, in call_wrapped
|
|---|
| 30 | ans = self.f(*args, **dict(self.params, **kwargs))
|
|---|
| 31 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 125, in apply_fn
|
|---|
| 32 | out, state = f.apply(params, {}, *args, **kwargs)
|
|---|
| 33 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 313, in apply_fn
|
|---|
| 34 | out = f(*args, **kwargs)
|
|---|
| 35 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 63, in _forward_fn
|
|---|
| 36 | ensemble_representations=True)
|
|---|
| 37 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
|
|---|
| 38 | out = f(*args, **kwargs)
|
|---|
| 39 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
|
|---|
| 40 | return bound_method(*args, **kwargs)
|
|---|
| 41 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 379, in __call__
|
|---|
| 42 | (0, prev))
|
|---|
| 43 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 610, in while_loop
|
|---|
| 44 | val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
|
|---|
| 45 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
|
|---|
| 46 | return fun(*args, **kwargs)
|
|---|
| 47 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 312, in while_loop
|
|---|
| 48 | outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
|
|---|
| 49 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 296, in _create_jaxpr
|
|---|
| 50 | # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
|
|---|
| 51 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 210, in wrapper
|
|---|
| 52 | return out
|
|---|
| 53 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 203, in cached
|
|---|
| 54 | wrapper.cache_info = memoized.cache_info
|
|---|
| 55 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 84, in _initial_style_jaxpr
|
|---|
| 56 | # When staging the branches of a conditional into jaxprs, constants are
|
|---|
| 57 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 210, in wrapper
|
|---|
| 58 | return out
|
|---|
| 59 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 203, in cached
|
|---|
| 60 | wrapper.cache_info = memoized.cache_info
|
|---|
| 61 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 77, in _initial_style_open_jaxpr
|
|---|
| 62 | fun, in_tree, in_avals, primitive_name)
|
|---|
| 63 | File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
|
|---|
| 64 | File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1739, in trace_to_jaxpr_dynamic
|
|---|
| 65 | File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
|
|---|
| 66 | File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 166, in call_wrapped
|
|---|
| 67 | ans = self.f(*args, **dict(self.params, **kwargs))
|
|---|
| 68 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 605, in pure_body_fun
|
|---|
| 69 | val = body_fun(val)
|
|---|
| 70 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 370, in <lambda>
|
|---|
| 71 | compute_loss=False)))
|
|---|
| 72 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 342, in do_call
|
|---|
| 73 | ensemble_representations=ensemble_representations)
|
|---|
| 74 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
|
|---|
| 75 | out = f(*args, **kwargs)
|
|---|
| 76 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
|
|---|
| 77 | return bound_method(*args, **kwargs)
|
|---|
| 78 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 161, in __call__
|
|---|
| 79 | representations = evoformer_module(batch0, is_training)
|
|---|
| 80 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
|
|---|
| 81 | out = f(*args, **kwargs)
|
|---|
| 82 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
|
|---|
| 83 | return bound_method(*args, **kwargs)
|
|---|
| 84 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 1733, in __call__
|
|---|
| 85 | msa_activations = jax.ops.index_add(msa_activations, 0,
|
|---|
| 86 | jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'jax.ops' has no attribute 'index_add'
|
|---|
| 87 |
|
|---|
| 88 | The stack trace below excludes JAX-internal frames.
|
|---|
| 89 | The preceding is the original exception that occurred, unmodified.
|
|---|
| 90 |
|
|---|
| 91 | --------------------
|
|---|
| 92 |
|
|---|
| 93 | The above exception was the direct cause of the following exception:
|
|---|
| 94 |
|
|---|
| 95 | Traceback (most recent call last):
|
|---|
| 96 | File "<ipython-input-1-dc30e8128641>", line 550, in run_prediction
|
|---|
| 97 | predict_structure(sequence, alignments, deletions, model_name, output_dir)
|
|---|
| 98 | File "<ipython-input-1-dc30e8128641>", line 375, in predict_structure
|
|---|
| 99 | prediction_result = model_runner.predict(processed_feature_dict)
|
|---|
| 100 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 133, in predict
|
|---|
| 101 | result = self.apply(self.params, jax.random.PRNGKey(0), feat)
|
|---|
| 102 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 125, in apply_fn
|
|---|
| 103 | out, state = f.apply(params, {}, *args, **kwargs)
|
|---|
| 104 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 313, in apply_fn
|
|---|
| 105 | out = f(*args, **kwargs)
|
|---|
| 106 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 63, in _forward_fn
|
|---|
| 107 | ensemble_representations=True)
|
|---|
| 108 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
|
|---|
| 109 | out = f(*args, **kwargs)
|
|---|
| 110 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
|
|---|
| 111 | return bound_method(*args, **kwargs)
|
|---|
| 112 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 379, in __call__
|
|---|
| 113 | (0, prev))
|
|---|
| 114 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 610, in while_loop
|
|---|
| 115 | val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
|
|---|
| 116 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 605, in pure_body_fun
|
|---|
| 117 | val = body_fun(val)
|
|---|
| 118 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 370, in <lambda>
|
|---|
| 119 | compute_loss=False)))
|
|---|
| 120 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 342, in do_call
|
|---|
| 121 | ensemble_representations=ensemble_representations)
|
|---|
| 122 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
|
|---|
| 123 | out = f(*args, **kwargs)
|
|---|
| 124 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
|
|---|
| 125 | return bound_method(*args, **kwargs)
|
|---|
| 126 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 161, in __call__
|
|---|
| 127 | representations = evoformer_module(batch0, is_training)
|
|---|
| 128 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
|
|---|
| 129 | out = f(*args, **kwargs)
|
|---|
| 130 | File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
|
|---|
| 131 | return bound_method(*args, **kwargs)
|
|---|
| 132 | File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 1733, in __call__
|
|---|
| 133 | msa_activations = jax.ops.index_add(msa_activations, 0,
|
|---|
| 134 | AttributeError: module 'jax.ops' has no attribute 'index_add'
|
|---|