Traceback (most recent call last):
  File "<ipython-input-1-dc30e8128641>", line 577, in <module>
    run_prediction(sequence)
  File "<ipython-input-1-dc30e8128641>", line 550, in run_prediction
    predict_structure(sequence, alignments, deletions, model_name, output_dir)
  File "<ipython-input-1-dc30e8128641>", line 375, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 133, in predict
    result = self.apply(self.params, jax.random.PRNGKey(0), feat)
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/api.py", line 435, in cache_miss
    assert len(avals) == len(out_flat)
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1709, in bind
    class _TempAxisName:
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1721, in call_bind
    return type(other) is _TempAxisName and self.id == other.id
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 614, in process_call
    process_map = process_call
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
  File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 272, in memoized_fun
    thread_local.most_recent_entry = None
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py", line 198, in lower_xla_callable
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
  File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1798, in trace_to_jaxpr_final
  File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
  File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 125, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 313, in apply_fn
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 63, in _forward_fn
    ensemble_representations=True)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 379, in __call__
    (0, prev))
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 610, in while_loop
    val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 312, in while_loop
    outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 296, in _create_jaxpr
    # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 210, in wrapper
    return out
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 203, in cached
    wrapper.cache_info = memoized.cache_info
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 84, in _initial_style_jaxpr
    # When staging the branches of a conditional into jaxprs, constants are
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 210, in wrapper
    return out
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/util.py", line 203, in cached
    wrapper.cache_info = memoized.cache_info
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/lax/control_flow.py", line 77, in _initial_style_open_jaxpr
    fun, in_tree, in_avals, primitive_name)
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py", line 206, in wrapper
  File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1739, in trace_to_jaxpr_dynamic
  File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
  File "/usr/local/lib/python3.7/dist-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 605, in pure_body_fun
    val = body_fun(val)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 370, in <lambda>
    compute_loss=False)))
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 342, in do_call
    ensemble_representations=ensemble_representations)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 161, in __call__
    representations = evoformer_module(batch0, is_training)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 1733, in __call__
    msa_activations = jax.ops.index_add(msa_activations, 0,
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'jax.ops' has no attribute 'index_add'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<ipython-input-1-dc30e8128641>", line 550, in run_prediction
    predict_structure(sequence, alignments, deletions, model_name, output_dir)
  File "<ipython-input-1-dc30e8128641>", line 375, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 133, in predict
    result = self.apply(self.params, jax.random.PRNGKey(0), feat)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 125, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py", line 313, in apply_fn
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/model.py", line 63, in _forward_fn
    ensemble_representations=True)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 379, in __call__
    (0, prev))
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 610, in while_loop
    val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/stateful.py", line 605, in pure_body_fun
    val = body_fun(val)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 370, in <lambda>
    compute_loss=False)))
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 342, in do_call
    ensemble_representations=ensemble_representations)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 161, in __call__
    representations = evoformer_module(batch0, is_training)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/alphafold/model/modules.py", line 1733, in __call__
    msa_activations = jax.ops.index_add(msa_activations, 0,
AttributeError: module 'jax.ops' has no attribute 'index_add'
