/* * BioJava development code * * This code may be freely distributed and modified under the * terms of the GNU Lesser General Public Licence. This should * be distributed with the code. If you do not have a copy, * see: * * http://www.gnu.org/copyleft/lesser.html * * Copyright for this code is held jointly by the individual * authors. These should be listed in @author doc comments. * * For more information on the BioJava project and its aims, * or to join the biojava-l mailing list, visit the home page * at: * * http://www.biojava.org/ * */ package org.biojava.bio.dp.onehead; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import org.biojava.bio.BioError; import org.biojava.bio.BioException; import org.biojava.bio.dist.Distribution; import org.biojava.bio.dp.BackPointer; import org.biojava.bio.dp.DP; import org.biojava.bio.dp.DPMatrix; import org.biojava.bio.dp.DotState; import org.biojava.bio.dp.EmissionState; import org.biojava.bio.dp.IllegalTransitionException; import org.biojava.bio.dp.MagicalState; import org.biojava.bio.dp.MarkovModel; import org.biojava.bio.dp.ScoreType; import org.biojava.bio.dp.SimpleStatePath; import org.biojava.bio.dp.State; import org.biojava.bio.dp.StatePath; import org.biojava.bio.symbol.Alignment; import org.biojava.bio.symbol.AlphabetManager; import org.biojava.bio.symbol.DoubleAlphabet; import org.biojava.bio.symbol.GappedSymbolList; import org.biojava.bio.symbol.IllegalAlphabetException; import org.biojava.bio.symbol.IllegalSymbolException; import org.biojava.bio.symbol.SimpleAlignment; import org.biojava.bio.symbol.SimpleGappedSymbolList; import org.biojava.bio.symbol.SimpleSymbolList; import org.biojava.bio.symbol.Symbol; import org.biojava.bio.symbol.SymbolList; /** * An implementation of DP that aligns a single sequence against a single model. * * @author Matthew Pocock * @author Thomas Down * @author Samiul Hasan * @author Lukas Kall */ public class SingleDP extends DP implements Serializable { protected final HashMap emissionsProb; protected final HashMap emissionsOdds; protected final HashMap emissionsNull; public SingleDP(MarkovModel model) throws IllegalSymbolException, IllegalTransitionException, BioException { super(model); emissionsProb = new HashMap(); emissionsOdds = new HashMap(); emissionsNull = new HashMap(); } public void update() { // System.out.println("Updating emissions as underlying model has changed!"); super.update(); // workaround for bug in vm if(emissionsProb != null) { emissionsProb.clear(); } if(emissionsOdds != null) { emissionsOdds.clear(); } if(emissionsNull != null) { emissionsNull.clear(); } } /** * This method is public for the benefit of training algorithms, * and in the future we should look at a better way of exposing * the emissions cache. */ public double [] getEmission(Symbol sym, ScoreType scoreType) throws IllegalSymbolException { Map emissions; if(scoreType == ScoreType.PROBABILITY) { emissions = emissionsProb; } else if(scoreType == ScoreType.ODDS) { emissions = emissionsOdds; } else if(scoreType == ScoreType.NULL_MODEL) { emissions = emissionsNull; } else { throw new BioError("Unknown ScoreType object: " + scoreType); } double [] em = (double []) emissions.get(sym); if(em == null) { int dsi = getDotStatesIndex(); em = new double[dsi]; State [] states = getStates(); if(sym == AlphabetManager.getGapSymbol()) { em[0] = 0; } else { em[0] = Double.NEGATIVE_INFINITY; } for(int i = 1; i < dsi; i++) { EmissionState es = (EmissionState) states[i]; Distribution dis = es.getDistribution(); em[i] = Math.log(scoreType.calculateScore(dis, sym)); } emissions.put(sym, em); /*System.out.println("Emissions for " + sym); for(int i = 0; i < em.length; i++) { System.out.println("\t" + states[i] + "\t-> " + em[i]); }*/ } return em; } public double forward(SymbolList [] seq, ScoreType scoreType) throws IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException { if(seq.length != 1) { throw new IllegalArgumentException("seq must be 1 long, not " + seq.length); } lockModel(); DPCursor dpCursor = new SmallCursor( getStates(), seq[0], seq[0].iterator() ); double score = forward(dpCursor, scoreType); unlockModel(); return score; } public double backward(SymbolList [] seq, ScoreType scoreType) throws IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException { if(seq.length != 1) { throw new IllegalArgumentException("seq must be 1 long, not " + seq.length); } lockModel(); DPCursor dpCursor = new SmallCursor( getStates(), seq[0], new ReverseIterator(seq[0]) ); double score = backward(dpCursor, scoreType); unlockModel(); return score; } public DPMatrix forwardMatrix(SymbolList [] seq, ScoreType scoreType) throws IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException { if(seq.length != 1) { throw new IllegalArgumentException("seq must be 1 long, not " + seq.length); } lockModel(); SingleDPMatrix matrix = new SingleDPMatrix(this, seq[0]); DPCursor dpCursor = new MatrixCursor(matrix, seq[0].iterator(), +1); matrix.setScore(forward(dpCursor, scoreType)); unlockModel(); return matrix; } public DPMatrix backwardMatrix(SymbolList [] seq, ScoreType scoreType) throws IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException { if(seq.length != 1) { throw new IllegalArgumentException("seq must be 1 long, not " + seq.length); } lockModel(); SingleDPMatrix matrix = new SingleDPMatrix(this, seq[0]); DPCursor dpCursor = new MatrixCursor(matrix, new ReverseIterator(seq[0]), -1); matrix.setScore(backward(dpCursor, scoreType)); unlockModel(); return matrix; } public DPMatrix forwardMatrix(SymbolList [] seq, DPMatrix matrix, ScoreType scoreType) throws IllegalArgumentException, IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException { if(seq.length != 1) { throw new IllegalArgumentException("seq must be 1 long, not " + seq.length); } lockModel(); SingleDPMatrix sm = (SingleDPMatrix) matrix; DPCursor dpCursor = new MatrixCursor(sm, seq[0].iterator(), +1); sm.setScore(forward(dpCursor, scoreType)); unlockModel(); return sm; } public DPMatrix backwardMatrix(SymbolList [] seq, DPMatrix matrix, ScoreType scoreType) throws IllegalArgumentException, IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException { if(seq.length != 1) { throw new IllegalArgumentException("seq must be 1 long, not " + seq.length); } lockModel(); SingleDPMatrix sm = (SingleDPMatrix) matrix; DPCursor dpCursor = new MatrixCursor(sm, new ReverseIterator(seq[0]), -1); sm.setScore(backward(dpCursor, scoreType)); unlockModel(); return sm; } protected double forward(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { forward_initialize(dpCursor, scoreType); forward_recurse(dpCursor, scoreType); return forward_termination(dpCursor, scoreType); } protected double backward(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { backward_initialize(dpCursor, scoreType); backward_recurse(dpCursor, scoreType); return backward_termination(dpCursor, scoreType); } protected void forward_initialize(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { double [] v = dpCursor.currentCol(); State [] states = getStates(); for (int l = 0; l < getDotStatesIndex(); l++) { if(states[l] == getModel().magicalState()) { //prob 1 v[l] = 0.0; } else { //prob 0 v[l] = Double.NEGATIVE_INFINITY; } } int [][] transitions = getForwardTransitions(); double [][] transitionScore = getForwardTransitionScores(scoreType); double [] currentCol = dpCursor.currentCol(); //l over dots for (int l = getDotStatesIndex(); l < states.length; l++) { double score = 0.0; int [] tr = transitions[l]; double [] trs = transitionScore[l]; int ci = 0; while( ci < tr.length && ( currentCol[tr[ci]] == Double.NEGATIVE_INFINITY || currentCol[tr[ci]] == Double.NaN || currentCol[tr[ci]] == Double.POSITIVE_INFINITY ) ) { ci++; } double constant = (ci < tr.length) ? currentCol[tr[ci]] : 0.0; for(int kc = 0; kc < tr.length; kc++) { int k = tr[kc]; if( currentCol[k] != Double.NEGATIVE_INFINITY && currentCol[k] != Double.NaN && currentCol[k] != Double.POSITIVE_INFINITY ) { double t = trs[kc]; score += Math.exp(t + (currentCol[k] - constant)); } else { } } currentCol[l] = Math.log(score) + constant; } } protected void backward_initialize(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { double [] v = dpCursor.currentCol(); State [] states = getStates(); for (int l = 0; l < states.length; l++) { if(states[l] == getModel().magicalState()) { v[l] = 0.0; } else { v[l] = Double.NEGATIVE_INFINITY; } } } private void forward_recurse(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { State [] states = getStates(); int [][] transitions = getForwardTransitions(); double [][] transitionScore = getForwardTransitionScores(scoreType); // int _index = 0; while (dpCursor.canAdvance()) { // _index++; // System.out.println("\n*** Index=" + _index + " ***"); dpCursor.advance(); Symbol sym = dpCursor.currentRes(); double [] emissions = getEmission(sym, scoreType); // System.out.println("Consuming " + sym.getName()); double [] currentCol = dpCursor.currentCol(); double [] lastCol = dpCursor.lastCol(); for (int l = 0; l < getDotStatesIndex(); l++) { //any -> emission double weight = emissions[l]; if (weight == Double.NEGATIVE_INFINITY) { // System.out.println("*"); currentCol[l] = Double.NEGATIVE_INFINITY; } else { double score = 0.0; int [] tr = transitions[l]; double [] trs = transitionScore[l]; // System.out.println("l=" + states[l].getName()); int ci = 0; while ( ci < tr.length && (lastCol[tr[ci]] == Double.NEGATIVE_INFINITY || lastCol[tr[ci]] == Double.NaN || lastCol[tr[ci]] == Double.POSITIVE_INFINITY) ) { ci++; } double constant = (ci < tr.length) ? lastCol[tr[ci]] : 0.0; for (int kc = 0; kc < tr.length; kc++) { int k = tr[kc]; // System.out.println("k=" + states[k].getName()); if (lastCol[k] != Double.NEGATIVE_INFINITY) { double t = trs[kc]; if(states[l]== getModel().magicalState()) { // System.out.print("magic " + "lastCol[k]=" + lastCol[k] + " , "); // System.out.println("t=" + t); } score += Math.exp(t + (lastCol[k] - constant)); } else { // System.out.println("-"); } } // new_l = emission_l(sym) * sum_k(transition(k, l) * old_k) currentCol[l] = (weight + Math.log(score)) + constant; // System.out.println("currentCol[" + states[l].getName() + "]=" + currentCol[l]); if(states[l] == getModel().magicalState()) { // System.out.print("magic\t"); //System.out.print("Weight " + weight + "\t"); // System.out.print(", score " + score + " = " + Math.log(score) + "\t"); // System.out.println(", constant " + constant); } } } for(int l = getDotStatesIndex(); l < states.length; l++) { // all dot states from emissions double score = 0.0; int [] tr = transitions[l]; double [] trs = transitionScore[l]; int ci = 0; while( ci < tr.length && ( currentCol[tr[ci]] == Double.NEGATIVE_INFINITY || currentCol[tr[ci]] == Double.NaN || currentCol[tr[ci]] == Double.POSITIVE_INFINITY) ) { ci++; } double constant = (ci < tr.length) ? currentCol[tr[ci]] : 0.0; //System.out.println("constant: " + constant); //System.out.println("read from state: " + ((ci < tr.length) ? states[tr[ci]].getName() : "none")); for(int kc = 0; kc < tr.length; kc++) { int k = tr[kc]; if(currentCol[k] != Double.NEGATIVE_INFINITY && currentCol[k] !=Double.NaN && currentCol[k] != Double.POSITIVE_INFINITY) { double t = trs[kc]; score += Math.exp(t + (currentCol[k] - constant)); } else { } } currentCol[l] = Math.log(score) + constant; //System.out.println("currentCol[" + states[l].getName() + "]=" + currentCol[l]); } } } protected void backward_recurse(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { State [] states = getStates(); int stateCount = states.length; int [][] transitions = getBackwardTransitions(); double [][] transitionScore = getBackwardTransitionScores(scoreType); double [] prevScores = new double[getDotStatesIndex()]; while (dpCursor.canAdvance()) { dpCursor.advance(); Symbol sym = dpCursor.lastRes(); double [] emissions = getEmission(sym, scoreType); double [] currentCol = dpCursor.currentCol(); double [] lastCol = dpCursor.lastCol(); for(int k = getDotStatesIndex() - 1; k >= 0; k--) { prevScores[k] = emissions[k]; } //System.out.println(sym.getName()); for (int k = stateCount-1; k >= 0; k--) { //System.out.println("State " + k + " of " + stateCount + ", " + transitions.length); //System.out.println(states[k].getName()); int [] tr = transitions[k]; double [] trs = transitionScore[k]; double score = 0.0; int ci = 0; while ( ci < tr.length && lastCol[tr[ci]] == Double.NEGATIVE_INFINITY ) { ci++; } double constant = (ci < tr.length) ? lastCol[tr[ci]] : 0.0; //System.out.println("Chosen constant: " + constant); for (int lc = tr.length-1; lc >= 0; lc--) { // any->emission int l = tr[lc]; if(l >= getDotStatesIndex()) { continue; } //System.out.println(states[k].getName() + " -> " + states[l].getName()); double weight = prevScores[l]; //System.out.println("weight = " + weight); if ( lastCol[l] != Double.NEGATIVE_INFINITY && weight != Double.NEGATIVE_INFINITY ) { double t = trs[lc]; score += Math.exp(t + weight + (lastCol[l] - constant)); } } //System.out.println("Score = " + score); for(int lc = tr.length-1; lc >= 0; lc--) { // any->dot int l = tr[lc]; if(l < getDotStatesIndex() || l <= k) { break; } /*System.out.println( "Processing dot-state transition " + states[k].getName() + " -> " + states[l].getName() );*/ if(currentCol[l] != Double.NEGATIVE_INFINITY) { score += Math.exp(trs[lc] + (currentCol[l] - constant)); } } //System.out.println("Score = " + score); currentCol[k] = Math.log(score) + constant; //System.out.println("currentCol = " + currentCol[k]); } } } private double forward_termination(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { double [] scores = dpCursor.currentCol(); State [] states = getStates(); int l = 0; while (states[l] != getModel().magicalState()) l++; return scores[l]; } protected double backward_termination(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { double [] scores = dpCursor.currentCol(); State [] states = getStates(); int l = 0; while (states[l] != getModel().magicalState()) l++; return scores[l]; } public StatePath viterbi(SymbolList [] symList, ScoreType scoreType) throws IllegalSymbolException { SymbolList r = symList[0]; DPCursor dpCursor = new SmallCursor(getStates(), r, r.iterator()); return viterbi(dpCursor, scoreType); } private StatePath viterbi(DPCursor dpCursor, ScoreType scoreType) throws IllegalSymbolException { lockModel(); State [] states = getStates(); int [][] transitions = getForwardTransitions(); double [][] transitionScore = getForwardTransitionScores(scoreType); int stateCount = states.length; BackPointer [] oldPointers = new BackPointer[stateCount]; BackPointer [] newPointers = new BackPointer[stateCount]; // initialize { double [] vc = dpCursor.currentCol(); double [] vl = dpCursor.lastCol(); for (int l = 0; l < getDotStatesIndex(); l++) { if(states[l] == getModel().magicalState()) { //System.out.println("Initializing start state to 0.0"); vc[l] = vl[l] = 0.0; oldPointers[l] = newPointers[l] = new BackPointer(states[l]); } else { vc[l] = vl[l] = Double.NEGATIVE_INFINITY; } } for (int l = getDotStatesIndex(); l < stateCount; l++) { int [] tr = transitions[l]; double [] trs = transitionScore[l]; double transProb = Double.NEGATIVE_INFINITY; double trans = Double.NEGATIVE_INFINITY; int prev = -1; for (int kc = 0; kc < tr.length; kc++) { int k = tr[kc]; double t = trs[kc]; double s = vc[k]; double p = t + s; if (p > transProb) { transProb = p; prev = k; trans = t; } } if(prev != -1) { vc[l] = vl[l] = transProb; oldPointers[l] = newPointers[l] = new BackPointer( states[l], newPointers[prev], trans ); } else { vc [l] = vl[l] = Double.NEGATIVE_INFINITY; oldPointers[l] = newPointers[l] = null; } } } // viterbi while (dpCursor.canAdvance()) { // symbol i dpCursor.advance(); Symbol sym = dpCursor.currentRes(); double [] emissions = getEmission(sym, scoreType); //System.out.println(sym.getName()); double [] currentCol = dpCursor.currentCol(); double [] lastCol = dpCursor.lastCol(); for (int l = 0; l < states.length; l++) { // don't move from magical state double emission; if(l < getDotStatesIndex()) { emission = emissions[l]; } else { emission = 0.0; } int [] tr = transitions[l]; //System.out.println("Considering " + tr.length + " alternatives"); double [] trs = transitionScore[l]; if (emission == Double.NEGATIVE_INFINITY) { //System.out.println(states[l].getName() + ": impossible emission"); currentCol[l] = Double.NEGATIVE_INFINITY; newPointers[l] = null; } else { double transProb = Double.NEGATIVE_INFINITY; double trans = Double.NEGATIVE_INFINITY; int prev = -1; for (int kc = 0; kc < tr.length; kc++) { int k = tr[kc]; double t = trs[kc]; double s = (l < getDotStatesIndex()) ? lastCol[k] : currentCol[k]; double p = t + s; /*System.out.println("Looking at scores from " + states[k].getName()); System.out.println("Old = " + lastCol[k]); System.out.println("New = " + currentCol[k]); System.out.println( "Considering " + states[k].getName() + " -> " + states[l].getName() + ", " + t + " + " + s + " = " + p );*/ if (p > transProb) { transProb = p; prev = k; trans = t; } } if(prev != -1) { currentCol[l] = transProb + emission; /*System.out.println( states[prev].getName() + "->" + states[l].getName() + ", " + (trans + emission) );*/ newPointers[l] = new BackPointer( states[l], (l < getDotStatesIndex()) ? oldPointers[prev] : newPointers[prev], trans + emission ); /*System.out.println("Succesfully completed " + states[l].getName()); System.out.println("Old = " + lastCol[l]); System.out.println("New = " + currentCol[l]);*/ } else { //System.out.println(states[l].getName() + ": Nowhere to come from"); currentCol[l] = Double.NEGATIVE_INFINITY; newPointers[l] = null; } } } BackPointer [] bp = newPointers; newPointers = oldPointers; oldPointers = bp; } // find max in last row BackPointer best = null; double bestScore = Double.NaN; for (int l = 0; l < stateCount; l++) { if (states[l] == getModel().magicalState()) { best = oldPointers[l].back; bestScore = dpCursor.currentCol()[l]; break; } } int len = 0; BackPointer b2 = best; int dotC = 0; int emC = 0; // trace back ruit to check out size of path while(b2.back != b2) { len++; if(b2.state instanceof EmissionState) { emC++; } else { dotC++; } b2 = b2.back; }; Map aMap = new HashMap(); aMap.put(dpCursor.symList(), dpCursor.symList()); Alignment ali = new SimpleAlignment(aMap); GappedSymbolList symView = new SimpleGappedSymbolList(ali); double [] scores = new double[len]; List stateList = new ArrayList(len); for (int j = 0; j < len; j++) { stateList.add(null); } b2 = best; int ri = dpCursor.symList().length()+1; int lc = len; int gaps = 0; while(b2.back != b2) { lc--; //System.out.println("At " + lc + " state=" + b2.state.getName() + ", score=" + b2.score + ", back=" + b2.back); if(b2.state instanceof MagicalState) { b2 = b2.back; continue; } stateList.set(lc, b2.state); if(b2.state instanceof DotState) { symView.addGapInSource(ri); gaps++; } else { ri--; } scores[lc] = b2.score; b2 = b2.back; } /*System.out.println("Counted " + emC + " emissions and " + dotC + " dots"); System.out.println("Counted backpointers. Alignment of length " + len); System.out.println("Counted states " + stateList.size()); System.out.println("Input list had length " + dpCursor.symList().length()); System.out.println("Added gaps: " + gaps); System.out.println("Gapped view has length " + symView.length());*/ unlockModel(); return new SimpleStatePath( bestScore, symView, new SimpleSymbolList(getModel().stateAlphabet(), stateList), DoubleAlphabet.fromArray(scores) ); } }