#! /usr/bin/env python3

# Copyright 2010, 2011, 2012, 2014 Martin C. Frith

# Read query-genome alignments: write them along with the probability
# that each alignment is not the true mapping of its query.  These
# probabilities make the risky assumption that one of the alignments
# reported for each query is correct.

from __future__ import print_function

import gzip
import sys, os, math, optparse, signal

def myOpen(fileName):
    """Return a readable text stream for fileName.

    "-" selects standard input, a name ending in ".gz" is opened through
    gzip in text mode, and anything else is opened as an ordinary file.
    """
    if fileName == "-":
        stream = sys.stdin
    elif fileName.endswith(".gz"):
        # Text mode ("rt") for gzip.open is dubious under Python 2.
        stream = gzip.open(fileName, "rt")
    else:
        stream = open(fileName)
    return stream

def logsum(numbers, temperature):
    """Return log(sum(exp(x / temperature) for x in numbers)).

    The largest scaled value is subtracted before exponentiating and
    added back afterwards, so large inputs do not overflow exp().
    numbers must be non-empty and iterable more than once.
    """
    peak = max(numbers) / temperature
    shifted = [value / temperature - peak for value in numbers]
    total = sum(map(math.exp, shifted))
    return peak + math.log(total)

def mafScore(aLine):
    """Return the value of the first "score=" word in aLine, as a float.

    Raises Exception if no whitespace-separated word starts with
    "score=".
    """
    prefix = "score="
    scoreWords = [w for w in aLine.split() if w.startswith(prefix)]
    if not scoreWords:
        raise Exception("found an alignment without a score")
    return float(scoreWords[0][len(prefix):])

def doOneBatch(lines, scores, opts, temperature):
    """Print one batch of alignment lines, annotated with mismap values.

    lines:  the text lines of the batch (MAF paragraphs and/or tabular
            alignment lines, plus any other lines to pass through).
    scores: one score per alignment, in the order the alignments occur
            in lines.
    opts:   object with attributes "score" (minimum score to keep) and
            "mismap" (maximum mismap probability to keep).
    temperature: scale by which scores are divided before exponentiation.

    Each alignment's mismap probability is 1 - exp(score/temperature)
    divided by the sum of exp(score/temperature) over the batch.
    Alignments failing either threshold are not printed: for MAF, the
    whole paragraph (up to and including its terminating blank line) is
    suppressed; for tabular format, just the one line is skipped.
    """
    if scores:
        logTotal = logsum(scores, temperature)

    printing = True
    index = 0
    for line in lines:
        firstChar = line[0]
        isMaf = firstChar == "a"
        if isMaf or firstChar.isdigit():  # MAF "a" line, or tabular line
            score = scores[index]
            prob = 1.0 - math.exp(score / temperature - logTotal)
            index += 1
            rejected = score < opts.score or prob > opts.mismap
            if rejected:
                if not isMaf:
                    continue  # tabular: drop only this line
                printing = False  # MAF: drop the rest of the paragraph too
            else:
                suffix = " mismap=%.3g\n" if isMaf else "\tmismap=%.3g\n"
                line = line.rstrip() + suffix % prob
        if printing:
            print(line, end="")
        if line.isspace():
            printing = True  # a blank line ends the MAF paragraph

def readHeaderOrDie(lines):
    """Consume and echo header lines until both t= and e= have been seen.

    Lines starting with "#" and whitespace-only lines are printed
    unchanged and scanned for words of the form "t=<float>" and
    "e=<int>".  Reading stops right after the line that completes the
    pair (t > 0 and e >= 0), leaving the rest of the iterator untouched.

    Returns the tuple (t, e).  If the input ends first, whatever was
    found so far is returned, with 0.0 and -1 as the "not found" values.
    Raises Exception if any other kind of line is met before both values
    are known.
    """
    temperature = 0.0
    threshold = -1
    for line in lines:
        isHeader = line.startswith("#") or line.isspace()
        if not isHeader:
            raise Exception("I need a header with t= and e=")
        print(line, end="")
        for word in line.split():
            if word.startswith("t="):
                temperature = float(word[2:])
            elif word.startswith("e="):
                threshold = int(word[2:])
        if temperature > 0 and threshold >= 0:
            break
    return temperature, threshold

def doOneFile(opts, f):
    """Process one stream of alignments, printing the annotated result.

    opts: object with attributes "score" and "mismap"; a negative
          opts.score means "use the default", which is computed from
          this stream's header as e + round(t * ln(1000)).
    f:    iterable of text lines: a header containing t= and e=,
          followed by alignments in MAF or tabular format.

    Consecutive alignments with the same query name are gathered into a
    batch, and each batch is handed to doOneBatch.  The caller's opts
    object is never modified, so the default score threshold is derived
    afresh from the header of every stream.
    """
    import copy  # needed only here, to avoid mutating the caller's opts

    temperature, e = readHeaderOrDie(f)
    if opts.score < 0:
        # Apply the per-file default to a private copy.  Writing it into
        # the shared opts would make the first file's threshold stick
        # for every later file, whatever their own headers say.
        opts = copy.copy(opts)
        opts.score = e + round(temperature * math.log(1000))
    lines = []
    scores = []
    queryName = ""

    for line in f:
        if line[0] == "a":
            # A new MAF alignment starts here: lines buffered so far
            # belong to earlier alignments.
            batchEnd = len(lines)
            score = mafScore(line)
            sLineCount = 0
        elif line[0] == "s":
            sLineCount += 1
            if sLineCount == 2:
                # The name on the 2nd "s" line is taken as the query name.
                qName = line.split(None, 2)[1]
                if qName != queryName:
                    # New query: flush the previous query's alignments,
                    # keeping the lines of the current alignment.
                    doOneBatch(lines[:batchEnd], scores, opts, temperature)
                    lines = lines[batchEnd:]
                    scores = []
                    queryName = qName
                scores.append(score)
        elif line[0].isdigit():  # we have an alignment in tabular format
            fields = line.split(None, 7)
            qName = fields[6]  # 7th column is used as the query name
            if qName != queryName:
                doOneBatch(lines, scores, opts, temperature)
                lines = []
                scores = []
                queryName = qName
            scores.append(float(fields[0]))  # 1st column is the score
        lines.append(line)

    # Flush the final batch.
    doOneBatch(lines, scores, opts, temperature)

def lastMapProbs(opts, args):
    """Process every named input file, or standard input if none is given.

    opts: the parsed command-line options, passed on to doOneFile.
    args: list of input file names ("-" means standard input, and names
          ending in ".gz" are read through gzip).

    Each file opened here is closed as soon as it has been processed,
    even if processing raises; standard input is left open.
    """
    if not args:
        doOneFile(opts, sys.stdin)
        return
    for fileName in args:
        f = myOpen(fileName)
        try:
            doOneFile(opts, f)
        finally:
            # myOpen returns sys.stdin itself for "-": don't close that.
            if f is not sys.stdin:
                f.close()

if __name__ == "__main__":
    # Command-line entry point: parse the options, then annotate the
    # alignments from the named files (or from standard input).

    # Restore the default SIGPIPE action, so that writing to a closed
    # pipe (e.g. "| head") ends the program quietly.  SIGPIPE is only
    # defined on Unix, so skip this step where the signal is missing.
    if hasattr(signal, "SIGPIPE"):
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message

    usage = """
  %prog --help
  %prog [options] lastal-alignments"""

    description = "Calculate a mismap probability for each alignment.  This is the probability that the alignment does not reflect the origin of the query sequence, assuming that one reported alignment does reflect the origin of each query."

    op = optparse.OptionParser(usage=usage, description=description)
    op.add_option("-m", "--mismap", type="float", default=0.01, metavar="M",
                  help="don't write alignments with mismap probability > M (default: %default)")
    # A negative score means "not set": the default is then derived from
    # the e= and t= values in the input header.
    op.add_option("-s", "--score", type="float", default=-1, metavar="S",
                  help="don't write alignments with score < S (default: e+t*ln[1000])")
    (opts, args) = op.parse_args()
    # With no files and nothing piped in, show the help text and exit
    # rather than waiting on the terminal.
    if not args and sys.stdin.isatty():
        op.print_help()
        op.exit()

    try: lastMapProbs(opts, args)
    except KeyboardInterrupt: pass  # avoid silly error message
    except Exception as e:
        # Report any failure as a one-line message and a non-zero exit
        # status, instead of a traceback.
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))