#!/usr/bin/env python # encoding: utf-8 from __future__ import (absolute_import, division, print_function, unicode_literals) import os, sys, re import logging import argparse import collections import numpy import time import TNode logger = logging.getLogger(__name__) class TGraph: def __init__(self, gene_id): self.node_cache = dict() self.gene_id = gene_id def get_node(self, transcript_id, loc_node_id, node_seq): """ Instantiates Node objects, and stores them in a graph. *** use this method for instantiating all Node objects *** use clear_node_cache() to clear the graph """ logger.debug("{}\t{}".format(loc_node_id, node_seq)) if len(node_seq) == 0: raise RuntimeError("Error, non-zero length node_seq required for parameter") if loc_node_id in self.node_cache: node_obj = self.node_cache[ loc_node_id ] node_obj.add_transcripts(transcript_id) if node_obj.seq != node_seq: errmsg = "ERROR: have conflicting node sequences for {} node_id: {}\n".format(self.get_gene_id(), loc_node_id) + "{}\n vs. \n{}".format(node_obj.seq, node_seq) logger.critical(errmsg) raise RuntimeError(errmsg) else: return node_obj else: # instantiate a new one node_obj = TNode.TNode(self, transcript_id, loc_node_id, node_seq) self.node_cache[ loc_node_id ] = node_obj return node_obj def get_all_nodes(self): return sorted(self.node_cache.values()) def clear_node_cache(self): """ clears the graph """ self.node_cache.clear() def clear_touch_settings(self): """ clear the touch settings for each of the nodes """ for node in self.get_all_nodes(): node.clear_touch() def add_edges(self, from_nodes_list, to_nodes_list): for from_node in from_nodes_list: for to_node in to_nodes_list: from_node.add_next_node(to_node) to_node.add_prev_node(from_node) def prune_edges(self, from_nodes_list, to_nodes_list): for from_node in from_nodes_list: for to_node in to_nodes_list: from_node.remove_next_node(to_node) to_node.remove_prev_node(from_node) def prune_node(self, node): logger.debug("pruning node: {}".format(node)) self.prune_edges(node.get_prev_nodes(), [node]) self.prune_edges([node], node.get_next_nodes()) node.dead = True self.node_cache.pop(node.get_loc_id()) def retrieve_node(self, node_id): """ does not instantiate, only retrieves. If loc_node_id is not in the graph, returns None """ if node_id in self.node_cache: return self.node_cache[node_id] else: return None def get_gene_id(self): return self.gene_id def draw_graph(self, filename): logger.debug("drawing graph: {}".format(filename)) ofh = open(filename, 'w') ofh.write("digraph G {\n") for node_id in sorted(self.node_cache): node = self.node_cache[node_id] node_seq = node.get_seq() gene_node_id = node.get_gene_node_id() next_nodes = node.get_next_nodes() node_seq_len = len(node_seq) max_len_show = 50 max_len_show_half = int(max_len_show/2) if node_seq_len > max_len_show: node_seq = node_seq[0:max_len_show_half] + "..." + node_seq[(node_seq_len-max_len_show_half):node_seq_len] ofh.write("{} [label=\"{}:Len{}:T{}:{}\"]\n".format(node.get_id(), gene_node_id, node_seq_len, node.get_topological_order(), node_seq)) for next_node in next_nodes: ofh.write("{}->{}\n".format(node.get_id(), next_node.get_id())) ofh.write("}\n") # close it ofh.close() def __repr__(self): txt = "" for node in self.get_all_nodes(): txt += node.toString() + "\n" return txt