Source code for transmission_models.classes.genetic_prior


from random import choice,randint,random,sample,choices
from scipy.special import gamma as GAMMA
from scipy.stats import nbinom, gamma, binom, expon, poisson
from .partial_sampled_utils import *
import numpy as np
from transmission_models import utils
from itertools import combinations
import networkx as nx

class genetic_prior_tree():
    """
    Genetic-distance prior over a transmission tree.

    Every contribution is the log-probability, under a Poisson distribution
    with rate ``mu * Dt``, of the genetic distance stored in
    ``distance_matrix`` for a pair of hosts, where ``Dt`` is a time span
    derived from the hosts' ``t_sample`` / ``t_inf`` attributes.

    Hosts are expected to expose ``sampled``, ``t_sample`` and ``t_inf`` and
    to be convertible with ``int(host)`` into a row/column index of
    ``distance_matrix``. A NaN on the diagonal entry of a host is treated as
    "no genetic data for this host".
    """

    def __init__(self, model, mu, distance_matrix):
        """
        Initialize the genetic prior tree object.

        Parameters
        ----------
        model : object
            The transmission model. This class reads ``model.T``,
            ``model.root_host`` and calls ``model.parent(host)``.
        mu : float
            The mutation rate parameter for the Poisson distribution.
        distance_matrix : numpy.ndarray
            Matrix containing pairwise genetic distances between hosts,
            indexed as ``distance_matrix[i, j]``.

        Notes
        -----
        ``prior_dist`` (a frozen ``poisson(mu)``) is kept as a public
        attribute for backward compatibility; the methods of this class no
        longer use it.
        """
        self.mu = mu
        self.distance_matrix = distance_matrix
        self.prior_dist = poisson(mu)
        self.model = model
        self.correction_LL = 0
        self.log_prior = 0

    def _log_poisson(self, Dt, n_mut):
        """
        Log-probability of ``n_mut`` mutations over an elapsed time ``Dt``.

        Uses ``poisson.logpmf`` instead of ``log(pmf)``: the pmf underflows
        to 0 (hence a log of -inf) for large distances, whereas the
        log-pmf stays finite.
        """
        return poisson.logpmf(n_mut, self.mu * Dt)

    @staticmethod
    def search_firsts_sampled_siblings(host, T, distance_matrix):
        """
        Find the first sampled hosts with genetic data below a host.

        Parameters
        ----------
        host : object
            The host whose descendants are searched.
        T : networkx.DiGraph
            The transmission tree (only ``T.successors`` is used).
        distance_matrix : numpy.ndarray
            Matrix containing pairwise genetic distances between hosts.

        Returns
        -------
        list
            For every downward path from ``host``, the first descendant
            that is sampled and whose diagonal entry in ``distance_matrix``
            is not NaN. The search does not continue below such a host.

        Notes
        -----
        Despite the name, the returned hosts are descendants of ``host``,
        not siblings.
        """
        sampled_hosts = []
        for h in T.successors(host):
            if not np.isnan(distance_matrix[int(h), int(h)]) and h.sampled:
                sampled_hosts.append(h)
            else:
                # Unsampled (or data-less) host: look further down its subtree.
                sampled_hosts.extend(
                    genetic_prior_tree.search_firsts_sampled_siblings(
                        h, T, distance_matrix
                    )
                )
        return sampled_hosts

    @staticmethod
    def search_first_sampled_parent(host, T, root):
        """
        Find the first sampled ancestor of a host in the transmission tree.

        Parameters
        ----------
        host : object
            The host for which to find the first sampled ancestor.
        T : networkx.DiGraph
            The transmission tree (only ``T.predecessors`` is used).
        root : object
            The root host of the transmission tree.

        Returns
        -------
        object or None
            The closest ancestor with ``sampled`` set, or None when ``host``
            is the root or the walk reaches an unsampled root.
        """
        current = host
        while current != root:
            current = next(T.predecessors(current))
            if current.sampled:
                return current
        return None

    @staticmethod
    def get_mut_time_dist(hp, hs):
        """
        Calculate the mutation time distance between two hosts.

        Parameters
        ----------
        hp : object
            The parent host.
        hs : object
            The other host.

        Returns
        -------
        float
            ``hs.t_sample + hp.t_sample - 2 * hp.t_inf``.
        """
        return hs.t_sample + hp.t_sample - 2 * hp.t_inf

    def get_closest_sampling_siblings(self, T=None, verbose=False):
        """
        Calculate the log-likelihood correction for the top of the tree.

        Parameters
        ----------
        T : networkx.DiGraph, optional
            The transmission tree. If None, uses ``self.model.T``.
        verbose : bool, optional
            If True, print detailed information during calculation.

        Returns
        -------
        float
            The log-likelihood correction value.

        Notes
        -----
        For every host returned by ``get_roots_data_subtrees`` the method
        walks up the tree until it finds an ancestor that has at least one
        other sampled child with genetic data (ancestors with a single
        child are skipped, except the root). Among those children the one
        with the smallest ``t_sample`` is paired with the host, using
        ``Dt = closest.t_sample + h.t_sample - 2 * ancestor.t_inf``.
        Hosts for which the walk reaches the root without finding such a
        child contribute nothing.
        """
        if T is None:
            T = self.model.T
        root = self.model.root_host
        dm = self.distance_matrix

        roots_subtrees = get_roots_data_subtrees(root, T, dm)
        LL_correction = 0
        if verbose:
            print("Top correction\n", "_" * 20)

        for h in roots_subtrees:
            relatives = []
            parent = h
            while not relatives:
                parent = next(T.predecessors(parent))
                # A non-root ancestor with a single child cannot provide
                # any relative other than the branch we came from.
                if parent != root and T.out_degree(parent) == 1:
                    continue
                relatives = [
                    h2 for h2 in T.successors(parent)
                    if h2.sampled
                    and h2 != h
                    and not np.isnan(dm[int(h2), int(h2)])
                ]
                if parent == root:
                    break
            if not relatives:
                continue

            closest = min(relatives, key=lambda h2: h2.t_sample)
            Dt = closest.t_sample + h.t_sample - 2 * parent.t_inf
            log_L = self._log_poisson(Dt, dm[int(h), int(closest)])
            LL_correction += log_L
            if verbose:
                print(f"\t\t{int(h),int(closest)},{Dt=} {self.distance_matrix[int(h), int(closest)]=} {log_L}")
        return LL_correction

    def prior_host(self, host, T, parent_dist=False):
        """
        Calculate the log prior for a specific host in the transmission tree.

        Parameters
        ----------
        host : object
            The host for which to calculate the log prior.
        T : networkx.DiGraph
            The transmission tree.
        parent_dist : bool, optional
            If True, also include the term linking ``host`` to its first
            sampled ancestor. Default is False.

        Returns
        -------
        float
            The log prior value for the host.

        Notes
        -----
        Each child of ``host`` contributes either directly (when sampled)
        or through the first sampled hosts found below it. Every term uses
        ``Dt = other.t_sample - host.t_sample``; the parent term uses
        ``Dt = host.t_sample - parent.t_sample``.
        """
        dm = self.distance_matrix
        log_prior = 0
        for h2 in T[host]:
            if h2.sampled:
                targets = [h2]
            else:
                targets = genetic_prior_tree.search_firsts_sampled_siblings(
                    h2, T, dm
                )
            for hs in targets:
                Dt = hs.t_sample - host.t_sample
                log_prior += self._log_poisson(Dt, dm[int(host), int(hs)])

        if parent_dist and host != self.model.root_host:
            parent = self.model.parent(host)
            if not parent.sampled:
                parent = genetic_prior_tree.search_first_sampled_parent(
                    host, T, self.model.root_host
                )
            if parent is not None:
                Dt = host.t_sample - parent.t_sample
                log_prior += self._log_poisson(Dt, dm[int(host), int(parent)])
        return log_prior

    def prior_pair(self, h1, h2):
        """
        Calculate the log prior for a pair of hosts.

        Parameters
        ----------
        h1 : object
            First host in the pair.
        h2 : object
            Second host in the pair.

        Returns
        -------
        float
            Log-probability of the stored distance between ``h1`` and ``h2``
            under a Poisson with rate ``mu * (h2.t_sample - h1.t_sample)``,
            or 0 if either host is not sampled.
        """
        if not h1.sampled or not h2.sampled:
            return 0
        Dt = h2.t_sample - h1.t_sample
        return self._log_poisson(Dt, self.distance_matrix[int(h1), int(h2)])

    def log_prior_host_list(self, host_list, T=None):
        """
        Calculate the total log prior for a list of hosts.

        Parameters
        ----------
        host_list : list
            List of hosts for which to calculate the log prior.
        T : networkx.DiGraph, optional
            The transmission tree. If None, uses ``self.model.T``.

        Returns
        -------
        float
            The sum of ``log_prior_host(host, T)`` over ``host_list``
            (0 for an empty list).
        """
        return sum(self.log_prior_host(host, T) for host in host_list)

    def log_prior_host(self, host, T=None):
        """
        Compute the log prior for a host.

        Parameters
        ----------
        host : object
            The host for which to compute the log prior.
        T : networkx.DiGraph, optional
            Transmission tree. If None, uses ``self.model.T``.

        Returns
        -------
        float
            Sum, over the first sampled hosts found below ``host``, of the
            log-probability of their distance to ``host`` with
            ``Dt = get_mut_time_dist(host, other)``.
        """
        if T is None:
            T = self.model.T
        dm = self.distance_matrix
        log_prior = 0
        for h2 in genetic_prior_tree.search_firsts_sampled_siblings(host, T, dm):
            Dt = self.get_mut_time_dist(host, h2)
            log_prior += self._log_poisson(Dt, dm[int(host), int(h2)])
        return log_prior

    def log_prior_T(self, T, update_up=True, verbose=False):
        """
        Calculate the total log prior for an entire transmission tree.

        Parameters
        ----------
        T : networkx.DiGraph
            The transmission tree.
        update_up : bool, optional
            If True, add the correction returned by
            ``get_closest_sampling_siblings``. Default is True.
        verbose : bool, optional
            If True, print detailed information during calculation.

        Returns
        -------
        float
            The total log prior value for the transmission tree.

        Notes
        -----
        Only sampled hosts with genetic data act as the source ``h`` of a
        term. For each child ``h2`` of ``h`` the target is ``h2`` itself
        (when it is sampled and has data) or the first sampled hosts found
        below it, and
        ``Dt = target.t_sample - h2.t_inf + |h.t_sample - h2.t_inf|``.
        The result is also stored in ``self.log_prior``; the correction is
        stored in ``self.correction_LL`` (0 when ``update_up`` is False).
        """
        dm = self.distance_matrix
        self.log_prior = 0
        suma = 0
        for h in T:
            # Skip hosts for which we have no genetic information.
            if np.isnan(dm[int(h), int(h)]) or not h.sampled:
                continue
            for h2 in T[h]:
                # Time from h's sampling back/forward to the infection of h2.
                branch = np.abs(h.t_sample - h2.t_inf)
                if not np.isnan(dm[int(h2), int(h2)]) and h2.sampled:
                    Dt = h2.t_sample - h2.t_inf + branch
                    log_L = self._log_poisson(Dt, dm[int(h), int(h2)])
                    if verbose:
                        print(f"{h}-->{h2} {Dt=} {log_L=}")
                    suma += log_L
                else:
                    # The helper already returns only sampled hosts with data.
                    siblings = genetic_prior_tree.search_firsts_sampled_siblings(
                        h2, T, dm
                    )
                    for hs in siblings:
                        Dt = hs.t_sample - h2.t_inf + branch
                        log_L = self._log_poisson(Dt, dm[int(h), int(hs)])
                        if verbose:
                            print(f"{h}-->{hs} (jumped) {Dt=} {log_L=}")
                        suma += log_L
        self.log_prior = suma
        if verbose:
            print(f"{suma=}")

        if update_up:
            LL_correction = self.get_closest_sampling_siblings(T)
            self.correction_LL = LL_correction
            self.log_prior += LL_correction
        else:
            self.correction_LL = 0
        return self.log_prior

    def _local_log_prior(self, host, T):
        """
        Log prior of the links between ``host`` and its sampled neighbours.

        Sums the term for the first sampled ancestor of ``host`` (if any)
        and one term per first sampled host found below ``host`` in ``T``.
        """
        # NOTE(review): this code path indexes the matrix with ``.index``
        # (as the original did) while the rest of the class uses
        # ``int(host)``; presumably both give the same value - confirm.
        dm = self.distance_matrix
        root = self.model.root_host
        total = 0
        if host != root:
            parent = genetic_prior_tree.search_first_sampled_parent(host, T, root)
            if parent is not None:
                D_time = host.t_sample - parent.t_sample
                total += self._log_poisson(D_time, dm[host.index, parent.index])
        for h in genetic_prior_tree.search_firsts_sampled_siblings(host, T, dm):
            D_time = h.t_sample - host.t_sample
            total += self._log_poisson(D_time, dm[host.index, h.index])
        return total

    def Delta_log_prior(self, host, T_end, T_ini):
        """
        Calculate the difference in log prior between two tree states.

        Parameters
        ----------
        host : object
            The host for which to calculate the log prior difference.
        T_end : networkx.DiGraph
            The final transmission tree state.
        T_ini : networkx.DiGraph or None
            The initial transmission tree state. If None, uses
            ``self.model.T``.

        Returns
        -------
        float
            Local log prior of ``host`` in ``T_end`` minus the one in
            ``T_ini``; 0 when ``host`` is not sampled.

        Notes
        -----
        The local log prior contains the link to the first sampled ancestor
        and the links to the first sampled hosts below ``host``. Each link
        is the log-probability of the genetic distance under a Poisson with
        rate ``mu * D_time``. Earlier versions evaluated a Poisson of rate
        ``mu`` at the product ``D_time * D_gen``, which disagrees with every
        other method of this class and yields -inf/NaN for non-integer
        products.
        """
        if not host.sampled:
            return 0
        if T_ini is None:
            T_ini = self.model.T
        return self._local_log_prior(host, T_end) - self._local_log_prior(host, T_ini)
def get_roots_data_subtrees(host, T, dist_matrix):
    """
    Collect the topmost sampled hosts with genetic data below a host.

    Parameters
    ----------
    host : object
        Host whose descendants are explored (the host itself is never
        returned).
    T : networkx.DiGraph
        The transmission tree; only ``T.successors`` is used.
    dist_matrix : numpy.ndarray
        Pairwise genetic distances, indexed with ``int(host)``. A NaN on a
        host's diagonal entry means no genetic data for that host.

    Returns
    -------
    list
        In depth-first order, every descendant that is sampled and has a
        non-NaN diagonal entry, with no such host above it on the path
        from ``host``. Subtrees below a returned host are not explored.
    """
    found = []
    # Explicit stack; children are pushed reversed so they are popped in
    # their original successor order (pre-order traversal).
    pending = list(T.successors(host))
    pending.reverse()
    while pending:
        node = pending.pop()
        idx = int(node)
        has_data = not np.isnan(dist_matrix[idx, idx])
        if has_data and node.sampled:
            found.append(node)
            continue
        children = list(T.successors(node))
        children.reverse()
        pending.extend(children)
    return found