# Read pairwise residue scores from a ModelCIF file and plot them in ChimeraX.

def modelcif_pae(session, structure, json_output_path = None, metric_id = None, default_score = 100):
    """Plot the pairwise residue scores of a ModelCIF structure as a PAE plot.

    The scores are read with read_pairwise_scores(), written to a JSON file
    with write_json_pae_file() and shown by running the "alphafold pae"
    command on that file.

    Parameters:
      session          -- session the "alphafold pae" command is run in.
      structure        -- structure whose file holds the scores; its id_string
                          is used in the command.
      json_output_path -- where to write the JSON file.  If None, a temporary
                          file is used and removed again once the command has
                          returned.
      metric_id        -- passed through to read_pairwise_scores().
      default_score    -- passed through to read_pairwise_scores().

    Errors raised by read_pairwise_scores() or by the command propagate.
    """
    import os
    import tempfile

    matrix = read_pairwise_scores(structure, metric_id = metric_id, default_score = default_score)

    temp_path = None
    if json_output_path is None:
        # Create the file and close our handle right away.  Keeping a
        # NamedTemporaryFile open while the file is reopened by name (for
        # writing below and for reading by the command) fails on Windows.
        fd, temp_path = tempfile.mkstemp(prefix = 'modelcif_pae_', suffix = '.json')
        os.close(fd)
        json_output_path = temp_path

    try:
        write_json_pae_file(json_output_path, matrix)

        # Open PAE plot
        from chimerax.core.commands import run, quote_if_necessary
        open_cmd = f'alphafold pae #{structure.id_string} file {quote_if_necessary(json_output_path)}'
        run(session, open_cmd)
    finally:
        # Only a file created here is removed; a user supplied path is kept.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup, never mask the real outcome.
                pass

def read_pairwise_scores(structure, metric_id = None, default_score = 100):
    """Return a residue-by-residue score matrix read from the structure's ModelCIF file.

    Parameters:
      structure     -- structure with a 'filename' attribute naming the
                       ModelCIF file, plus 'residues' and 'num_residues'.
      metric_id     -- id of the local-pairwise metric to use.  If None, the
                       first metric of type "PAE" is used, and if there is
                       none, the first available local-pairwise metric.
      default_score -- value for residue pairs that have no score in the file.

    Returns a numpy float32 array of shape (num_residues, num_residues),
    indexed by the position of each residue in structure.residues.

    Raises UserError if the structure has no file, the file has no pairwise
    scores, the metric id is unknown, or there are no scores for that metric.
    A score that refers to a (chain id, residue number) pair not present in
    the structure raises KeyError.

    The selected metric is logged through the module-level 'session'.
    """
    from chimerax.core.errors import UserError

    if getattr(structure, 'filename', None) is None:
        raise UserError(f'Structure {structure} has no associated file')

    # fetch data from ModelCIF
    values, metrics = read_ma_qa_metric_local_pairwise_table(structure.filename)
    if values is None:
        raise UserError(f'Structure file {structure.filename} contains no pairwise residue scores (i.e. no table "ma_qa_metric_local_pairwise")')

    # use only the scores with the given metric id.
    metrics_dict = {
        met_id: (met_name, met_type)
        for met_id, met_name, met_type in metrics
    }
    if metric_id is None and len(metrics) > 0:
        # look for PAE type
        for met_id, _, met_type in metrics:
            if met_type == "PAE":
                metric_id = met_id
                break
        else:
            # no PAE metric found: fall back to whatever is the first one
            metric_id = metrics[0][0]
    if metric_id not in metrics_dict:
        raise UserError(f'Structure file {structure.filename} has no metric for metric id "{metric_id}"')
    met_name, met_type = metrics_dict[metric_id]
    session.logger.info(
        f"Displaying local-pairwise metric with ID {metric_id} "
        f"with type '{met_type}' and name '{met_name}'"
    )
    values = [v for v in values if v[5] == metric_id]
    if len(values) == 0:
        raise UserError(f'Structure file {structure.filename} has no scores for metric id "{metric_id}"')

    # fill matrix
    matrix_index = {(r.chain_id,r.number):ri for ri,r in enumerate(structure.residues)}

    nr = structure.num_residues
    from numpy import empty, float32
    matrix = empty((nr,nr), float32)
    matrix[:] = default_score

    # Rows are (model_id, chain_id_1, res_num_1, chain_id_2, res_num_2,
    # metric_id, metric_value); the model id and the (already filtered)
    # metric id are not needed here.
    for _model_id, chain_id_1, res_num_1, chain_id_2, res_num_2, _metric, metric_value in values:
        r1 = matrix_index[(chain_id_1, int(res_num_1))]
        r2 = matrix_index[(chain_id_2, int(res_num_2))]
        matrix[r1,r2] = float(metric_value)

    return matrix

def read_ma_qa_metric_local_pairwise_table(path):
    """Get relevant data from ModelCIF file.
    Returns tuple (values, metrics) with
    - values = list of pairwise metric values stored as tuple
      (model_id, chain_id_1, res_num_1, chain_id_2, res_num_2, metric_id, metric_value)
    - metrics = list of available pairwise metrics stored as tuple
      (metric_id, metric_name, metric_type)
    Returns (None, None) if the tables could not be read from the file.

    Scores are collected from three places:
    1. the ma_qa_metric_local_pairwise table of the file itself,
    2. associated files listed in ma_entry_associated_files with content
       "local pairwise QA scores",
    3. files with that content inside associated ZIP archives, as listed in
       ma_associated_archive_file_details.
    Associated files are located with fetch_file_url() and read by calling
    this function recursively; only their values are used, not their metrics.
    Temporary files are removed again, also if reading them fails.

    Warnings are logged through the module-level 'session'.
    """
    import os
    import tempfile
    import zipfile
    from chimerax.mmcif import get_cif_tables
    table_names = [
        'ma_qa_metric_local_pairwise',
        'ma_entry_associated_files',
        'ma_associated_archive_file_details',
        'ma_qa_metric'
    ]
    try:
        tables = get_cif_tables(path, table_names)
    except TypeError:
        # Bug in get_cif_tables() results in TypeError if table not present.  Ticket #16054
        return None, None
    if len(tables) != 4:
        return None, None

    # check different ways of storing the scores
    values = []
    ma_qa_metric_local_pairwise = tables[0]
    ma_entry_associated_files = tables[1]
    ma_associated_archive_file_details = tables[2]
    ma_qa_metric = tables[3]

    # get info on available pairwise metrics
    if ma_qa_metric.num_rows() > 0:
        field_names = ['id', 'mode', 'name', 'type']
        all_metrics = ma_qa_metric.fields(field_names)
        metrics = [
            (metric_id, metric_name, metric_type)
            for metric_id, metric_mode, metric_name, metric_type in all_metrics
            if metric_mode == "local-pairwise"
        ]
    else:
        # no metrics here
        metrics = []

    # option 1: it's directly in the file
    if ma_qa_metric_local_pairwise.num_rows() > 0:
        field_names = ['model_id',
                       'label_asym_id_1', 'label_seq_id_1',
                       'label_asym_id_2', 'label_seq_id_2',
                       'metric_id', 'metric_value']
        values.extend(ma_qa_metric_local_pairwise.fields(field_names))

    # option 2: it's in ma_entry_associated_files
    associated_files = []
    if ma_entry_associated_files.num_rows() > 0 \
       and ma_entry_associated_files.has_field("file_content"):
        field_names = ['id', 'file_url', 'file_content']
        associated_files = ma_entry_associated_files.fields(field_names)
    qa_files_to_load = [] # (path, delete afterwards) pairs, read further below
    zip_files = {}        # archive file id -> archive url
    for file_id, file_url, file_content in associated_files:
        if file_content == "local pairwise QA scores":
            tmp_file_path, to_delete = fetch_file_url(file_url, path, "_tmp.cif")
            if tmp_file_path is not None:
                qa_files_to_load.append((tmp_file_path, to_delete))
        elif file_content == "archive with multiple files":
            zip_files[file_id] = file_url

    # option 3: it's listed in ma_associated_archive_file_details
    associated_qa_files = []
    if ma_associated_archive_file_details.num_rows() > 0 \
       and ma_associated_archive_file_details.has_field("file_content"):
        field_names = ['archive_file_id', 'file_path', 'file_content']
        row_fields = ma_associated_archive_file_details.fields(field_names)
        for archive_file_id, file_path, file_content in row_fields:
            if file_content == "local pairwise QA scores":
                if archive_file_id in zip_files:
                    associated_qa_files.append((zip_files[archive_file_id], file_path))
                else:
                    session.logger.warning(f'Structure file {path} has faulty archive_file_id for {file_path}.')
    for zip_file_url, file_path in associated_qa_files:
        tmp_zip_file_path, to_delete = fetch_file_url(zip_file_url, path, "_tmp.zip")
        if tmp_zip_file_path is None:
            continue
        try:
            with zipfile.ZipFile(tmp_zip_file_path, 'r') as zip_ref:
                if file_path in zip_ref.namelist():
                    # Read the member first so no temp file is created if that fails.
                    member_data = zip_ref.read(file_path)
                    temp_file = tempfile.NamedTemporaryFile(suffix="_tmp.cif", delete=False)
                    with temp_file as f:
                        f.write(member_data)
                    qa_files_to_load.append((temp_file.name, True))
                else:
                    session.logger.warning(f'Could not find {file_path} in ZIP file from {zip_file_url}.')
        finally:
            # Remove a downloaded archive even if it could not be read.
            if to_delete:
                os.remove(tmp_zip_file_path)

    # load from temp. files
    for tmp_file_path, to_delete in qa_files_to_load:
        try:
            new_values, _ = read_ma_qa_metric_local_pairwise_table(tmp_file_path)
        finally:
            # Remove the temporary file even if parsing it raised.
            if to_delete:
                os.remove(tmp_file_path)
        if new_values is not None:
            values.extend(new_values)

    return values, metrics

def fetch_file_url(file_url, structure_path, save_name):
    """Locate an associated file, locally or on the web.

    file_url is first tried as a path relative to the directory of
    structure_path.  If no such file exists, it is fetched with
    chimerax.core.fetch.fetch_file() using the module-level 'session',
    with save_name as the name to save it under.

    Returns (path, is_temporary):
    - (local path, False) if the file exists next to the structure file,
    - (fetched path, True) if it was fetched,
    - (None, False) if fetching failed; a warning is logged in that case.
    """
    import os
    dir_name = os.path.dirname(structure_path)
    file_path = os.path.join(dir_name, file_url)
    if os.path.exists(file_path):
        return file_path, False

    # try to get from the web
    from chimerax.core.fetch import fetch_file
    try:
        tmp_file_path = fetch_file(
            session, file_url,
            name=f'remote associated file for {structure_path}',
            save_name=save_name, save_dir=None
        )
    except Exception:
        # Deliberately best-effort: a missing associated file is not fatal.
        # Only Exception is caught so that KeyboardInterrupt and SystemExit
        # still propagate.
        session.logger.warning(f"Failed to load {file_url} for {structure_path}")
        return None, False
    return tmp_file_path, True

def write_json_pae_file(json_output_path, matrix):
    """Write a square score matrix to a file in JSON AlphaFold PAE format.

    The output looks like
      {"pae": [[ 17.14, 18.75, ... ], [ 5.32, 8.23, ... ], ... ]}
    with every value formatted to two decimals.  The number of rows and
    columns written is matrix.shape[0].
    """
    size = matrix.shape[0]
    rows = []
    for i in range(size):
        cells = ['%.2f' % matrix[i,j] for j in range(size)]
        rows.append('[ ' + ', '.join(cells) + ' ]')
    text = '{"pae": [' + ', '.join(rows) + ']}'
    with open(json_output_path, 'w') as out:
        out.write(text)

def modelarchive_open(session, ma_id):
    """Fetch a ModelArchive entry, open it and try to show its PAE plot.

    The mmCIF file for ma_id is fetched from modelarchive.org (saved under
    the "ModelArchive" save directory) and opened with the "open" command.
    Then "modelcif pae" is run on the first opened model; a UserError from
    that command is ignored because an entry need not have pairwise scores.
    """
    from chimerax.core.fetch import fetch_file
    from chimerax.core.commands import run, quote_if_necessary
    from chimerax.core.errors import UserError

    # fetch file
    session.logger.info(f"Fetching {ma_id} from ModelArchive...")
    cif_path = fetch_file(
        session,
        f"https://www.modelarchive.org/doi/10.5452/{ma_id}.cif",
        name=f'MA {ma_id}',
        save_name=f"{ma_id}.cif",
        save_dir="ModelArchive",
    )

    # open file and remember the id of the first model
    opened = run(session, f'open {quote_if_necessary(cif_path)}')
    first_id = opened[0].id_string

    # try to add PAE
    try:
        run(session, f'modelcif pae #{first_id}')
    except UserError:
        # ok for it not to have PAE...
        pass

def register_command(logger):
    """Register the "modelcif pae" and "modelarchive open" commands.

    "modelcif pae" takes a structure plus the optional keywords metric_id,
    default_score and json_output_path and calls modelcif_pae().
    "modelarchive open" takes a ModelArchive id and calls modelarchive_open().
    logger is passed on to register().
    """
    from chimerax.atomic import StructureArg
    from chimerax.core.commands import CmdDesc, register, StringArg, FloatArg, SaveFileNameArg

    pae_keywords = [
        ('metric_id', StringArg),
        ('default_score', FloatArg),
        ('json_output_path', SaveFileNameArg),
    ]
    pae_desc = CmdDesc(
        required = [('structure', StructureArg)],
        keyword = pae_keywords,
        synopsis = 'Plot ModelCIF pairwise residue scores'
    )
    register('modelcif pae', pae_desc, modelcif_pae, logger=logger)

    # TEST: register extra command to remotely load MA file
    ma_desc = CmdDesc(
        required = [('ma_id', StringArg)],
        synopsis = 'Load ModelArchive file with given ID (incl. PAE if available)'
    )
    register('modelarchive open', ma_desc, modelarchive_open, logger=logger)
    #

# Register the commands as soon as this file is executed.
# NOTE(review): 'session' is not defined in this file -- presumably ChimeraX
# provides it as a global when the script is opened; confirm.
register_command(session.logger)
