File size: 13,748 Bytes
17191f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import os
import pandas as pd
import math
import pickle
import pprint
pp = pprint.PrettyPrinter(indent=4)

# For phylogeny parsing
# !pip install opentree
from opentree import OT
# !pip install ete3
from ete3 import Tree, PhyloTree

# Constants
# When True, Phylogeny.get_tree calls Phylogeny.fix_tree on a freshly
# downloaded tree file before loading it.
Fix_Tree = True
# Value passed as the `format` argument to PhyloTree(...) and tree.write(...).
# NOTE(review): the trailing "8" looks like a previously used alternative
# value -- confirm before changing.
format_ = 1 #8

class Phylogeny:
    """Phylogeny helper for the Fish dataset.

    Species names are mapped to ``ott_id`` values through ``OT.tnrs_match``
    and the tree is loaded with ``PhyloTree``. Both the name mapping
    (``name_conversion.pkl``) and the tree (``cleaned_metadata.tre``) are
    cached as files under ``filePath`` and reused when they already exist.

    If ``node_ids`` is None, the cached files are assumed to exist already.
    Otherwise ``node_ids`` must be the list of species names.
    """

    def __init__(self, filePath, node_ids=None, verbose=False):
        """Load (or build and cache) the name mapping and the tree.

        Args:
            filePath: directory holding / receiving the cache files.
            node_ids: list of species names, or None to rely on the caches.
            verbose: print diagnostic output while resolving names.
        """
        # filenames for phylo tree and cached mapping ottid-speciesname
        cleaned_fine_tree_fileName = "cleaned_metadata.tre"
        name_conversion_file = "name_conversion.pkl"
        self.ott_ids = []
        self.ott_id_dict = {}
        self.node_ids = node_ids
        self.treeFileNameAndPath = os.path.join(filePath, cleaned_fine_tree_fileName)
        self.conversionFileNameAndPath = os.path.join(filePath, name_conversion_file)
        self.total_distance = -1  # -1 means we never calculated it before.

        # distance_matrix[a][b] caches pairwise distances; -1 marks "not computed yet".
        self.distance_matrix = {}
        # relative_distance -> list of species groups (each group is a list of names).
        self.species_groups_within_relative_distance = {}

        self.get_ott_ids(node_ids, verbose=verbose)
        self.get_tree(self.treeFileNameAndPath)
        self.get_total_distance()

    def get_distance(self, species1, species2):
        """Return the phylogenetic distance between two species names.

        Results are memoised in ``self.distance_matrix``. Because the
        distance is symmetric, both ``[species1][species2]`` and
        ``[species2][species1]`` are filled on a cache miss.

        Raises:
            KeyError: if either name is not part of the distance matrix.
        """
        cached = self.distance_matrix[species1][species2]
        if cached != -1:
            return cached

        if species1 == species2:
            return 0

        ott_id1 = 'ott' + str(self.ott_id_dict[species1])
        ott_id2 = 'ott' + str(self.ott_id_dict[species2])
        d = self.tree.get_distance(ott_id1, ott_id2)

        self.distance_matrix[species1][species2] = d
        self.distance_matrix[species2][species1] = d
        return d

    # relative_distance = 0 => species node itself
    # relative_distance = 1 => all species
    def get_siblings_by_name(self, species, relative_distance, verbose=False):
        """Return the group of species that shares an ancestor with ``species``.

        Groups are the ones produced by ``get_species_groups`` for the given
        ``relative_distance``.

        Raises:
            ValueError: if ``species`` does not appear in any group.
        """
        groups = self.get_species_groups(relative_distance, verbose)
        for species_group in groups:
            if species in species_group:
                return species_group

        # The original code tried to ``raise`` a plain string, which is not a
        # valid exception; raise a real one carrying the intended message.
        raise ValueError("{} was not found in {}".format(species, groups))

    def get_parent_by_name(self, species, relative_distance, verbose=False):
        """Return the ancestor node of ``species`` at ``relative_distance``."""
        ott_id = 'ott' + str(self.ott_id_dict[species])
        return self.get_parent_by_ottid(ott_id, relative_distance, verbose)

    def get_distance_between_parents(self, species1, species2, relative_distance):
        """Return the tree distance between the two species' ancestors."""
        parent1 = self.get_parent_by_name(species1, relative_distance)
        parent2 = self.get_parent_by_name(species2, relative_distance)
        return self.tree.get_distance(parent1, parent2)

    def get_species_groups(self, relative_distance, verbose=False):
        """Group all labels by their ancestor at ``relative_distance``.

        The result is computed once per ``relative_distance`` and cached.

        Returns:
            A list of groups, each group being a list of species names.
        """
        if relative_distance not in self.species_groups_within_relative_distance:
            groups = {}

            for species in self.getLabelList():
                parent_node = self.get_parent_by_name(species, relative_distance, verbose)
                groups.setdefault(parent_node.name, []).append(species)

            # Store a concrete list rather than a live dict view.
            self.species_groups_within_relative_distance[relative_distance] = list(groups.values())

            if verbose:
                print("At relative_distance", relative_distance, ", the groups are:", groups.values())

        return self.species_groups_within_relative_distance[relative_distance]

    def getLabelList(self):
        """Return the species names as a new list."""
        return list(self.node_ids)

    # ------- private functions

    def get_total_distance(self):
        """Compute (once) and return the leaf-to-root distance of the tree.

        Also fills in ``self.node_ids`` from the name mapping when it was not
        given, and (re)initialises the distance matrix.
        """
        if self.node_ids is None:
            self.node_ids = self.ott_id_dict.keys()

        self.init_distance_matrix()

        # For one time, measure distance from all leaves down to root. They all should be equal.
        # Save the value and reuse it.
        if self.total_distance == -1:
            for leaf in self.tree.iter_leaves():
                total_distance = self.tree.get_distance(leaf)  # distance to root
                assert self.total_distance == -1 or math.isclose(self.total_distance, total_distance)
                self.total_distance = total_distance

        return self.total_distance

    def init_distance_matrix(self):
        """Reset the pairwise distance cache to -1 ("not computed") everywhere."""
        for i in self.node_ids:
            self.distance_matrix[i] = {}
            for j in self.node_ids:
                self.distance_matrix[i][j] = -1

    def get_parent_by_ottid(self, ott_id, relative_distance, verbose=False):
        """Walk up from the node named ``ott_id`` to its ancestor.

        Climbs until the distance from the starting node is at least
        ``relative_distance * self.total_distance`` or the root is reached.

        Raises:
            IndexError: if no node named ``ott_id`` exists in the tree.
        """
        abs_distance = relative_distance * self.total_distance
        species_node = self.tree.search_nodes(name=ott_id)[0]
        if verbose:
            print('distance to ancestor: ', abs_distance, ". relaive distance: ", relative_distance)

        # keep going up till distance reaches abs_distance
        distance = 0
        parent = species_node
        while distance < abs_distance:
            if parent.up is None:
                break
            parent = parent.up
            distance = self.tree.get_distance(parent, species_node)

        return parent

    # node_ids: list of taxa
    # sets: self.ott_ids (list of ott_ids) and self.ott_id_dict (name -> ott_id)
    def get_ott_ids(self, node_ids, verbose=False):
        """Resolve species names to ott_ids, using the pickle cache if present.

        Raises:
            TypeError: if there is no cache file and ``node_ids`` is None.
            AssertionError: if a name is unmatched or matched ambiguously.
        """
        if not os.path.exists(self.conversionFileNameAndPath):
            if node_ids is None:
                raise TypeError('No existing ottid-speciesnames found. node_ids should be a list of species names.')
            if verbose:
                print('Included taxonomy: ', node_ids, len(node_ids))

            # Get the matches
            resp = OT.tnrs_match(node_ids, do_approximate_matching=True)
            matches = resp.response_dict['results']
            unmatched_names = resp.response_dict['unmatched_names']

            # Get the corresponding ott_ids
            ott_ids = set()
            ott_id_dict = {}
            comparison_rows = []  # only filled when verbose
            assert len(unmatched_names) == 0, "unmatched names: " + str(unmatched_names)  # everything is matched!
            for match_array in matches:
                match_array_matches = match_array['matches']
                assert len(match_array_matches) == 1, match_array['name'] + " has too many matches" + str(list(map(lambda x: x['matched_name'], match_array_matches)))  # we have a single unambiguous match!
                first_match = match_array_matches[0]
                ott_id = first_match['taxon']['ott_id']
                ott_ids.add(ott_id)
                if verbose:
                    # some original and matched names are not exactly the same. Not a bug
                    comparison_rows.append({
                        'in csv': match_array['name'],
                        'in response': first_match['matched_name'],
                        'Same?': match_array['name'] == first_match['matched_name'],
                    })
                ott_id_dict[match_array['name']] = ott_id
            ott_ids = list(ott_ids)

            if verbose:
                # Build the frame in one go; DataFrame.append no longer exists in pandas 2.x.
                df2 = pd.DataFrame(comparison_rows, columns=['in csv', 'in response', 'Same?'])
                print(df2[df2['Same?'] == False])
                pp.pprint(ott_id_dict)

            with open(self.conversionFileNameAndPath, 'wb') as f:
                pickle.dump([ott_ids, ott_id_dict], f)
        else:
            # NOTE: pickle.load executes arbitrary code from the file. This is
            # only acceptable because the cache is written by this class;
            # never point filePath at an untrusted directory.
            with open(self.conversionFileNameAndPath, 'rb') as f:
                ott_ids, ott_id_dict = pickle.load(f)

        self.ott_ids = ott_ids
        self.ott_id_dict = ott_id_dict
        if verbose:
            print(self.ott_id_dict)

    def fix_tree(self, treeFileNameAndPath):
        """Rename one known node in the tree file and write the file back.

        Raises:
            IndexError: if the expected node is not present in the tree.
        """
        tree = PhyloTree(treeFileNameAndPath, format=format_)

        # Special case for Fish dataset: Fix Esox Americanus.
        D = tree.search_nodes(name="mrcaott47023ott496121")[0]
        D.name = "ott496115"
        tree.write(format=format_, outfile=treeFileNameAndPath)

    def get_tree(self, treeFileNameAndPath):
        """Load the tree from disk, fetching and caching it first if missing."""
        if not os.path.exists(treeFileNameAndPath):
            output = OT.synth_induced_tree(ott_ids=self.ott_ids, ignore_unknown_ids=False, label_format='id')

            output.tree.write(path=treeFileNameAndPath, schema="newick")

            if Fix_Tree:
                self.fix_tree(treeFileNameAndPath)

        self.tree = PhyloTree(treeFileNameAndPath, format=format_)

class PhylogenyCUB:
    """Phylogeny helper for the CUB dataset.

    The tree is read with ``PhyloTree`` from a fixed file name under
    ``filePath``; nothing is downloaded. Tree nodes are looked up directly by
    species name.
    """

    def __init__(self, filePath, node_ids=None, verbose=False):
        """Load the tree and prepare the distance caches.

        Args:
            filePath: directory containing the tree file.
            node_ids: list of species names; when None, the sorted leaf names
                of the tree are used.
            verbose: accepted for interface parity; not used here.
        """
        # cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-AllSpecies.phy"
        # cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-AllSpecies-cub-names.phy"
        cleaned_fine_tree_fileName = "1_tree-consensus-Hacket-27Species-cub-names.phy"
        self.node_ids = node_ids
        self.treeFileNameAndPath = os.path.join(filePath, cleaned_fine_tree_fileName)
        self.total_distance = -1  # -1 means we never calculated it before.

        # distance_matrix[a][b] caches pairwise distances; -1 marks "not computed yet".
        self.distance_matrix = {}
        # relative_distance -> list of species groups (each group is a list of names).
        self.species_groups_within_relative_distance = {}

        self.get_tree(self.treeFileNameAndPath)
        self.get_total_distance()

    def get_distance(self, species1, species2):
        """Return the phylogenetic distance between two species names.

        Results are memoised in ``self.distance_matrix``. Because the
        distance is symmetric, both ``[species1][species2]`` and
        ``[species2][species1]`` are filled on a cache miss.

        Raises:
            KeyError: if either name is not part of the distance matrix.
        """
        cached = self.distance_matrix[species1][species2]
        if cached != -1:
            return cached

        if species1 == species2:
            return 0

        d = self.tree.get_distance(species1, species2)

        self.distance_matrix[species1][species2] = d
        self.distance_matrix[species2][species1] = d
        return d

    # relative_distance = 0 => species node itself
    # relative_distance = 1 => all species
    def get_siblings_by_name(self, species, relative_distance, verbose=False):
        """Return the group of species that shares an ancestor with ``species``.

        Raises:
            ValueError: if ``species`` does not appear in any group.
        """
        # NOTE: siblings are taken from get_species_groups on purpose. An
        # earlier implementation based on parent.get_leaves() was not
        # equivalent to get_species_groups and caused inconsistencies.
        groups = self.get_species_groups(relative_distance, verbose)
        for species_group in groups:
            if species in species_group:
                return species_group

        # The original code tried to ``raise`` a plain string, which is not a
        # valid exception; raise a real one carrying the intended message.
        raise ValueError("{} was not found in {}".format(species, groups))

    def get_parent_by_name(self, species, relative_distance, verbose=False):
        """Walk up from the node named ``species`` to its ancestor.

        Climbs until the distance from the starting node is at least
        ``relative_distance * self.total_distance`` or the root is reached.

        Raises:
            IndexError: if no node named ``species`` exists in the tree.
        """
        abs_distance = relative_distance * self.total_distance
        species_node = self.tree.search_nodes(name=species)[0]
        if verbose:
            print('distance to ancestor: ', abs_distance, ". relaive distance: ", relative_distance)

        # keep going up till distance reaches abs_distance
        distance = 0
        parent = species_node
        while distance < abs_distance:
            if parent.up is None:
                break
            parent = parent.up
            distance = self.tree.get_distance(parent, species_node)

        return parent

    def get_distance_between_parents(self, species1, species2, relative_distance):
        """Return the tree distance between the two species' ancestors."""
        parent1 = self.get_parent_by_name(species1, relative_distance)
        parent2 = self.get_parent_by_name(species2, relative_distance)
        return self.tree.get_distance(parent1, parent2)

    def get_species_groups(self, relative_distance, verbose=False):
        """Group all labels by their ancestor at ``relative_distance``.

        The result is computed once per ``relative_distance`` and cached.

        Returns:
            A list of groups, each group being a list of species names.
        """
        if relative_distance not in self.species_groups_within_relative_distance:
            groups = {}

            for species in self.getLabelList():
                parent_node = self.get_parent_by_name(species, relative_distance, verbose)
                groups.setdefault(parent_node.name, []).append(species)

            # Store a concrete list rather than a live dict view.
            self.species_groups_within_relative_distance[relative_distance] = list(groups.values())

            if verbose:
                print("At relative_distance", relative_distance, ", the groups are:", groups.values())

        return self.species_groups_within_relative_distance[relative_distance]

    def getLabelList(self):
        """Return the species names as a new list."""
        return list(self.node_ids)

    # ------- private functions

    def get_total_distance(self):
        """Compute and return the largest leaf-to-root distance of the tree.

        Also fills in ``self.node_ids`` from the sorted leaf names when it was
        not given, and (re)initialises the distance matrix.
        """
        if self.node_ids is None:
            self.node_ids = sorted(leaf.name for leaf in self.tree.iter_leaves())

        self.init_distance_matrix()

        # maximum distance between root and leaf node taken as total distance
        self.total_distance = max(self.tree.get_distance(leaf) for leaf in self.tree.iter_leaves())

        return self.total_distance

    def init_distance_matrix(self):
        """Reset the pairwise distance cache to -1 ("not computed") everywhere."""
        for i in self.node_ids:
            self.distance_matrix[i] = {}
            for j in self.node_ids:
                self.distance_matrix[i][j] = -1

    def get_tree(self, treeFileNameAndPath):
        """Load the tree file and give every unnamed node a placeholder name.

        Unlike ``Phylogeny.get_tree`` this never downloads anything; the file
        must already exist.
        """
        self.tree = PhyloTree(treeFileNameAndPath, format=format_)

        # setting a dummy name (its postorder index) on nodes that are unnamed,
        # so that grouping by parent name can tell them apart
        for i, node in enumerate(self.tree.traverse("postorder")):
            if not node.name:
                node.name = str(i)