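"""Merge pre-clustered, per-date topics into global topic clusters.

Each incoming topic is represented by the title and snippet of its first
document; those labels are embedded with SBERT and grouped by agglomerative
clustering, then the underlying topics are merged, de-duplicated by
(domain, title), capped, and written to disk.
"""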
import json
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # uncomment to force CPU-only inference

from function.topic_clustering import model, AgglomerativeClustering
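# Assumption: `model` is a sentence-transformers SentenceTransformer and
# `AgglomerativeClustering` is scikit-learn's class, both re-exported by
# function.topic_clustering (not shown here).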

def check_duplicate_title_domain(docs):
    """Keep only the first document seen for each (domain, title) pair."""
    seen = set()
    lst_filter_docs = []
    for doc in docs:
        key = f"{doc.get('domain', '')} {doc.get('title', '')}"
        if key not in seen:
            seen.add(key)
            lst_filter_docs.append(doc)
    return lst_filter_docs
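
# Expected request shape, inferred from the field accesses in main() below
# (illustrative, not a schema):
# {
#     "type": "monthly",                     # anything else lowers the cluster cap
#     "preprocess": [
#         {"topic": {"<topic_id>": [         # topic_id -> list of documents
#             {"title": ..., "snippet": ..., "domain": ...,
#              "created_time": ..., "num_docs": ...},
#         ]}},
#     ]
# }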

def main(req):
    """Merge the pre-clustered topics in `req` into at most MAX_CLUSTER clusters."""
    req_type = req['type']
    if req_type == 'monthly':
        MAX_CLUSTER = 50
    else:
        MAX_CLUSTER = 20

    MAX_NUM_DOC_PER_CLUSTER = 50

    # threshold = req.get('threshold', 0.3)
    threshold = 0.4

    preprocess = req.get('preprocess', [])
    lst_labels = []  # one representative label per topic
    lst_topics = []  # the topics themselves (lists of docs), aligned with lst_labels
    for date_clusters in preprocess:
        topic = date_clusters.get('topic', {})
        for topic_id in topic:
            topic_docs = topic[topic_id]
            if not topic_docs:
                continue
            lst_topics.append(topic_docs)
            # Represent each topic by the title and snippet of its first document.
            label = '. '.join([topic_docs[0].get('title', ''), topic_docs[0].get('snippet', '')])
            lst_labels.append(label)
    
    final_clusters = []
    # Cluster the topic labels; return_ids=True yields indices into lst_topics.
    label_clusters = sbert_clustering(lst_labels, distance_threshold=threshold, return_ids=True)

    if label_clusters:
        for id_label_cluster in label_clusters:
            merge_clusters = []
            num_docs = 0
            for topic_id in label_clusters[id_label_cluster]:
                topic = lst_topics[topic_id]
                num_docs += topic[0].get('num_docs', 1)
                merge_clusters.extend(topic)

            # Newest first, then drop duplicate (domain, title) pairs and cap the size.
            merge_clusters = sorted(merge_clusters, key=lambda x: -x.get('created_time', 0))
            merge_clusters = check_duplicate_title_domain(merge_clusters)
            merge_clusters = merge_clusters[:MAX_NUM_DOC_PER_CLUSTER]

            for doc in merge_clusters:
                doc['num_docs'] = num_docs
            final_clusters.append(merge_clusters)

    # Keep only the largest clusters, measured by merged document count.
    final_clusters = sorted(final_clusters, key=lambda x: -x[0]['num_docs'])
    final_clusters = final_clusters[:MAX_CLUSTER]

    final_result = {i: cluster for i, cluster in enumerate(final_clusters)}
    # ensure_ascii=False emits raw Unicode, so pin the file encoding explicitly.
    with open('zzz.json', 'w', encoding='utf-8') as f:
        json.dump(final_result, f, ensure_ascii=False)
    return final_result

def get_sbert_embedding(lst_sentence):
    """Encode sentences into dense vectors with the shared SBERT model."""
    embs = model.encode(lst_sentence)
    return embs

def sbert_clustering(lst_sentence, distance_threshold=0.25, return_ids=False):
    """Cluster sentences by cosine distance over SBERT embeddings.

    Returns {cluster_id: [sentence, ...]} (or {cluster_id: [index, ...]} when
    return_ids=True), with clusters sorted largest first.
    """
    lst_sentence = [sen.replace('_', ' ') for sen in lst_sentence]
    if len(lst_sentence) == 0:
        return {}
    if len(lst_sentence) == 1:
        # A single sentence is trivially its own cluster.
        if return_ids:
            return {0: [0]}
        return {0: lst_sentence}

    embs = get_sbert_embedding(lst_sentence)

    # Complete linkage on cosine distance; compute_full_tree must be True when
    # cutting by distance_threshold instead of a fixed n_clusters.
    # NOTE: scikit-learn >= 1.2 renames `affinity` to `metric` (removed in 1.4).
    clusterer = AgglomerativeClustering(n_clusters=None, compute_full_tree=True,
                                        affinity='cosine', linkage='complete',
                                        distance_threshold=distance_threshold)
    clusterer.fit(embs)

    # Group sentences (and their indices) by assigned cluster label.
    dict_result = {}
    dict_ids = {}
    for j, label in enumerate(clusterer.labels_):
        label = int(label)
        dict_result.setdefault(label, []).append(lst_sentence[j])
        dict_ids.setdefault(label, []).append(j)

    output = dict_ids if return_ids else dict_result
    # Sort clusters by size, largest first.
    return dict(sorted(output.items(), key=lambda kv: -len(kv[1])))
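
# Illustrative behaviour (hypothetical inputs; cluster ids are arbitrary and the
# grouping depends on the loaded SBERT model):
#   sbert_clustering(["gold price rises", "gold price climbs", "storm hits coast"])
#   -> {0: ["gold price rises", "gold price climbs"], 1: ["storm hits coast"]}
#   ... with return_ids=True -> {0: [0, 1], 1: [2]}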

if __name__ == '__main__':
    with open("input_merge.json", 'r', encoding='utf-8') as f:
        req = json.load(f)
    main(req)