File size: 2,065 Bytes
e8aad19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from wordcloud import WordCloud
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List


class SimpleGroupedColorFunc:
    """Callable color function that assigns EXACT colors to certain words.

    Calling an instance with a word returns the color registered for that
    word, or ``default_color`` when the word belongs to no group.

    Parameters
    ----------
    color_to_words : dict(str -> list(str))
      A dictionary that maps a color to the list of words.

    default_color : str
      Color that will be assigned to a word that's not a member
      of any value from color_to_words.
    """

    def __init__(
        self,
        color_to_words: Dict[str, List[str]],
        default_color: str
    ) -> None:
        # Invert the color -> words mapping once so that every lookup in
        # __call__ is a single dict access. When a word is listed under
        # several colors, the color iterated last overwrites earlier ones.
        self.word_to_color = {
            word: color
            for (color, words) in color_to_words.items()
            for word in words
        }

        self.default_color = default_color

    def __call__(self, word: str, **kwargs) -> str:
        """Return the color for ``word``, or the default color if unknown.

        Any additional keyword arguments are accepted and ignored.
        """
        return self.word_to_color.get(word, self.default_color)


class SegmentedWordCloud:
    """Word cloud whose words are colored according to their group.

    The cloud is generated from ``freq_dic``. Words listed in
    ``less_group`` and ``greater_group`` each receive a fixed group color;
    every other word receives the 'salient' color.

    Parameters
    ----------
    freq_dic : dict(str -> int)
      Word frequencies passed to ``WordCloud.generate_from_frequencies``.

    less_group : list(str)
      Words drawn in the 'less' color.

    greater_group : list(str)
      Words drawn in the 'greater' color.
    """

    # Fixed palette shared by all instances; keys name the word groups.
    _COLORS = {
        'less': '#529ef3',
        'salient': '#d35400',
        'greater': '#5d6d7e',
    }

    def __init__(
        self,
        freq_dic: Dict[str, int],
        less_group: List[str],
        greater_group: List[str]
    ) -> None:
        colors = self._COLORS

        # 'greater' is registered before 'less', so a word present in both
        # groups ends up with the 'less' color (later entries overwrite).
        color_to_words = {
            colors['greater']: greater_group,
            colors['less']: less_group,
        }

        # Words in neither group fall back to the 'salient' color.
        grouped_color_func = SimpleGroupedColorFunc(
            color_to_words=color_to_words,
            default_color=colors['salient']
        )

        self.wc = WordCloud(
            background_color="white",
            width=900,
            height=300,
            random_state=None).generate_from_frequencies(freq_dic)

        # Replace the colors chosen during generation with the group colors.
        self.wc.recolor(color_func=grouped_color_func)

    def plot(
        self,
        figsize: Tuple[float, float]
    ) -> plt.Figure:
        """Draw the word cloud on a new figure and return that figure.

        Parameters
        ----------
        figsize : tuple(float, float)
          Figure size forwarded to ``plt.subplots``.

        Returns
        -------
        matplotlib Figure holding the rendered cloud, with the axes
        hidden and a tight layout applied.
        """
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(self.wc, interpolation="bilinear")
        ax.axis("off")
        fig.tight_layout()
        return fig