Dimitre committed on
Commit
d6e5859
·
1 Parent(s): 90de23d

deleting empty folder

Browse files
Files changed (3) hide show
  1. src/app.py +0 -232
  2. src/common.py +0 -55
  3. src/hint.py +0 -149
src/app.py DELETED
@@ -1,232 +0,0 @@
1
- import logging
2
- import os
3
- from typing import Any
4
-
5
- import pandas as pd
6
- import streamlit as st
7
- from countryinfo import CountryInfo
8
- from dotenv import load_dotenv
9
-
10
- from common import HintType, configs, get_distance
11
- from hint import AudioHint, ImageHint, TextHint
12
-
13
-
def setup_models(_cache: Any, configs: dict) -> None:
    """Load every selected hint model that has not been loaded yet.

    Args:
        _cache (st.session_state): Streamlit cache object holding the
            selected ``hint_types`` and the per-type ``model`` slots
        configs (dict): Configurations used by the models
    """
    # Maps each hint type label to the function that builds its model.
    loaders = {
        HintType.TEXT.value: setup_text_hint,
        HintType.IMAGE.value: setup_image_hint,
        HintType.AUDIO.value: setup_audio_hint,
    }

    for model_type in _cache["hint_types"]:
        # A slot that is already filled is left untouched.
        if _cache["model"][model_type] is not None:
            continue

        loader = loaders.get(model_type)
        if loader is not None:
            _cache["model"][model_type] = loader(configs)
29
-
30
-
@st.cache_resource()
def setup_text_hint(configs: dict) -> TextHint:
    """Setups the text hint model.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        TextHint: Hint model

    Raises:
        KeyError: If the ``HF_ACCESS_TOKEN`` environment variable is not set
    """
    with st.spinner("Loading text model..."):
        # Build a private copy instead of writing the token into the shared
        # `configs` dict: the caller's configuration stays untouched, and the
        # argument Streamlit uses to key this cached function does not change
        # between calls.
        model_configs = {
            **configs["local"][HintType.TEXT.value.lower()],
            "hf_access_token": os.environ["HF_ACCESS_TOKEN"],
        }
        text_hint = TextHint(configs=model_configs)
        text_hint.initialize()
    return text_hint
47
-
48
-
@st.cache_resource()
def setup_image_hint(configs: dict) -> ImageHint:
    """Build and initialize the image hint model.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        ImageHint: Initialized hint model
    """
    # The local model settings are keyed by the lower-cased hint type label.
    settings_key = HintType.IMAGE.value.lower()
    with st.spinner("Loading image model..."):
        image_hint = ImageHint(configs=configs["local"][settings_key])
        image_hint.initialize()
    return image_hint
64
-
65
-
@st.cache_resource()
def setup_audio_hint(configs: dict) -> AudioHint:
    """Build and initialize the audio hint model.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        AudioHint: Initialized hint model
    """
    # The local model settings are keyed by the lower-cased hint type label.
    settings_key = HintType.AUDIO.value.lower()
    with st.spinner("Loading audio model..."):
        audio_hint = AudioHint(configs=configs["local"][settings_key])
        audio_hint.initialize()
    return audio_hint
81
-
82
-
@st.cache_resource()
def get_country_list() -> pd.DataFrame:
    """Builds a database of countries and metadata.

    Countries whose area cannot be looked up are skipped (and logged) so a
    single bad entry does not prevent the database from being built.

    Returns:
        pd.DataFrame: Country database with the columns "country" and "area"
    """
    areas = {}
    for country in CountryInfo().all():
        try:
            areas[country] = CountryInfo(country).area()
        except Exception as err:
            # Best-effort on purpose, but no longer a bare `except`:
            # KeyboardInterrupt/SystemExit propagate and the skip is visible.
            logging.getLogger(__name__).warning(
                "Skipping country '%s': area lookup failed (%s)", country, err
            )

    return pd.DataFrame(areas.items(), columns=["country", "area"])
102
-
103
-
def pick_country(country_df: pd.DataFrame) -> str:
    """Draw one country at random, weighted by the "area" column.

    Args:
        country_df (pd.DataFrame): Database of country and their metadata

    Returns:
        str: The selected country
    """
    # Single-row sample; rows with a larger area are more likely to be drawn.
    drawn_row = country_df.sample(n=1, weights="area")
    return drawn_row["country"].iloc[0]
115
-
116
-
def reset_cache() -> None:
    """Reset the Streamlit APP cache.

    Stores the list of country names, picks a new target country, clears the
    hint selection and empties every model slot in ``st.session_state``.
    """
    countries = get_country_list()
    state = st.session_state

    state["country_list"] = countries["country"].values.tolist()
    state["country"] = pick_country(countries)
    state["hint_types"] = []
    state["n_hints"] = 1
    state["game_started"] = False
    # One empty slot per hint type; filled when the game starts.
    state["model"] = {
        hint_type.value: None
        for hint_type in (HintType.TEXT, HintType.IMAGE, HintType.AUDIO)
    }
130
-
131
-
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

st.set_page_config(
    page_title="Gen AI GeoGuesser",
    page_icon="🌎",
)

# Short alias for the per-session state used throughout the script.
state = st.session_state

# An empty session state means this is the first run of the session.
if not state:
    load_dotenv()
    reset_cache()

st.title("Generative AI GeoGuesser 🌎")

st.markdown("### Guess the country based on hints generated by AI")

# --- Game settings: hint types and number of hints -------------------------
hint_types_col, n_hints_col = st.columns([2, 1])

with hint_types_col:
    state["hint_types"] = st.multiselect(
        "Chose which hint types you want",
        [hint_type.value for hint_type in HintType],
        default=state["hint_types"],
    )

with n_hints_col:
    state["n_hints"] = st.slider(
        "Number of hints",
        min_value=1,
        max_value=5,
        value=state["n_hints"],
    )

# --- Start of the game: load the models and generate the hints -------------
if st.button("Start game"):
    if state["hint_types"]:
        print(f'Chosen country "{state["country"]}"')

        setup_models(state, configs)

        for hint_type in state["hint_types"]:
            with st.spinner(f"Generating {hint_type} hint..."):
                state["model"][hint_type].generate_hint(
                    state["country"],
                    state["n_hints"],
                )

        state["game_started"] = True
    else:
        st.error("Pick at least one hint type")
        reset_cache()

# --- Guessing controls -------------------------------------------------------
if state["game_started"]:
    guess_col, guess_btn_col, reset_btn_col = st.columns([2, 1, 1])

    with guess_col:
        # The leading empty entry stands for "no country picked yet".
        guess = st.selectbox("Country guess", ([""] + state["country_list"]))
    with guess_btn_col:
        guess_btn = st.button("Make a guess")
    with reset_btn_col:
        reset_btn = st.button("Reset game")

    if guess_btn:
        if state["country"] == guess:
            st.success("Correct guess you won!")
            st.balloons()
        else:
            if guess:
                target_latlng = CountryInfo(state["country"]).latlng()
                guess_latlng = CountryInfo(guess).latlng()
                missed_by = int(get_distance(target_latlng, guess_latlng))
                st.error(
                    f"""
                    Wrong guess, you missed the correct country by {missed_by} KM.
                    The correct answer was {state["country"]}.
                    """
                )
            else:
                st.error("Pick a country.")

    if reset_btn:
        reset_cache()

# --- Hint display: one tab per selected hint type ---------------------------
if state["game_started"]:
    selected_types = state["hint_types"]
    tabs = st.tabs([f"{hint_type} hint" for hint_type in selected_types])

    for hint_type, tab in zip(selected_types, tabs):
        with tab:
            hint_model = state["model"][hint_type]
            # A type selected after the game started has no model; skip it.
            if hint_model:
                for hint_number, hint in enumerate(hint_model.hints, start=1):
                    st.markdown(f"#### Hint #{hint_number}")
                    if hint_type == HintType.TEXT.value:
                        st.write(hint["text"])
                    elif hint_type == HintType.IMAGE.value:
                        st.image(hint["image"])
                    elif hint_type == HintType.AUDIO.value:
                        st.audio(hint["audio"], sample_rate=hint["sample_rate"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/common.py DELETED
@@ -1,55 +0,0 @@
1
- import logging
2
- import pprint
3
- from enum import Enum
4
- from math import acos, cos, radians, sin
5
-
6
- import yaml
7
-
8
-
def parse_configs(configs_path: str) -> dict:
    """Parse configs from the YAML file.

    Args:
        configs_path (str): Path to the YAML file

    Returns:
        dict: Parsed configs

    Raises:
        OSError: If the file cannot be opened
        yaml.YAMLError: If the file is not valid YAML
    """
    # The context manager closes the file even if parsing raises.
    with open(configs_path, "r", encoding="utf-8") as configs_file:
        configs = yaml.safe_load(configs_file)
    logger.info(f"Configs: {pprint.pformat(configs)}")
    return configs
21
-
22
-
def get_distance(source_country: list[float], target_country: list[float]) -> float:
    """Calculate the distance between two countries.

    Applies the spherical law of cosines to the two coordinate pairs, using
    an Earth radius of 6371.01 km.

    Args:
        source_country (list[float]): Source country coordinates, as
            ``[latitude, longitude]`` in degrees
        target_country (list[float]): Target country coordinates, as
            ``[latitude, longitude]`` in degrees

    Returns:
        float: Distance in KM
    """
    earth_radius_km = 6371.01

    source_lat = radians(source_country[0])
    source_long = radians(source_country[1])
    target_lat = radians(target_country[0])
    target_long = radians(target_country[1])

    cos_angle = sin(source_lat) * sin(target_lat) + cos(source_lat) * cos(
        target_lat
    ) * cos(source_long - target_long)
    # Rounding can push the value slightly outside [-1, 1] (e.g. for identical
    # or antipodal points), which would make acos raise "math domain error".
    cos_angle = max(-1.0, min(1.0, cos_angle))

    return earth_radius_km * acos(cos_angle)
42
-
43
-
class HintType(Enum):
    """Kinds of hint the game can generate.

    Each value is the human-readable label of the hint type.
    """

    AUDIO = "Audio"
    TEXT = "Text"
    IMAGE = "Image"
48
-
49
-
# Path of the YAML configuration file; being relative, it is resolved against
# the current working directory.
CONFIGS_PATH = "configs.yaml"

logging.basicConfig(level=logging.INFO)
# Named after the module (not the file path), matching the other modules.
logger = logging.getLogger(__name__)

# Parsed once at import time and shared by every importer of this module.
configs = parse_configs(CONFIGS_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/hint.py DELETED
@@ -1,149 +0,0 @@
1
- import abc
2
- import logging
3
- import re
4
- from typing import Any
5
-
6
- import torch
7
- from diffusers import AudioLDM2Pipeline, AutoPipelineForText2Image
8
- from pydantic import BaseModel
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
10
-
# Configure logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Sample rate stored with every generated audio hint.
# NOTE(review): presumably in Hz and matching the audio model's output rate — confirm.
SAMPLE_RATE = 16000
16
-
17
-
class BaseHint(BaseModel, abc.ABC):
    """Common interface shared by all hint generators."""

    # Settings consumed by the concrete hint class (e.g. "model_id", "device").
    configs: dict
    # Generated hints; each concrete class stores one dict per hint.
    hints: list = []
    # Underlying model object, assigned by `initialize`.
    model: Any = None

    @abc.abstractmethod
    def initialize(self):
        """Initialize the hint model."""

    @abc.abstractmethod
    def generate_hint(self, country: str, n_hints: int):
        """Generate hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
37
-
38
-
class TextHint(BaseHint):
    """Hint generator that describes a country in text with a causal LM."""

    # Tokenizer paired with the model, assigned by `initialize`.
    tokenizer: Any = None

    def initialize(self):
        """Load the tokenizer and the language model named in the configs."""
        logger.info(
            f"""Initializing text hint with model '{self.configs["model_id"]}'"""
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.configs["model_id"],
            token=self.configs["hf_access_token"],
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.configs["model_id"],
            torch_dtype=torch.float16,
            token=self.configs["hf_access_token"],
        ).to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate text hints, replacing any hints from a previous call.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' text hints")

        generation_config = GenerationConfig(
            do_sample=True,
            max_new_tokens=self.configs["max_output_tokens"],
            top_k=self.configs["top_k"],
            top_p=self.configs["top_p"],
            temperature=self.configs["temperature"],
        )

        # One identical prompt per requested hint; sampling makes them differ.
        prompt = [
            f'Describe the country "{country}" without mentioning its name\n'
            for _ in range(n_hints)
        ]
        input_ids = self.tokenizer(prompt, return_tensors="pt")
        text_hints = self.model.generate(
            **input_ids.to(self.configs["device"]),
            generation_config=generation_config,
        )

        # Masks the country name in the output so the hint does not give the
        # answer away.
        country_pattern = re.compile(re.escape(country), flags=re.IGNORECASE)

        # Collect into a fresh list and assign it at the end: the same model
        # object may be reused for later games, and appending to `self.hints`
        # would keep hints from earlier calls.
        hints = []
        for hint_prompt, text_hint in zip(prompt, text_hints):
            text_hint = (
                self.tokenizer.decode(text_hint, skip_special_tokens=True)
                .strip()
                .replace(hint_prompt, "")
                .strip()
            )
            text_hint = country_pattern.sub("***", text_hint)
            hints.append({"text": text_hint})
        self.hints = hints

        logger.info(f"Text hints '{n_hints}' successfully generated")
92
-
93
-
class ImageHint(BaseHint):
    """Hint generator that produces images with a text-to-image pipeline."""

    def initialize(self):
        """Load the text-to-image pipeline named in the configs."""
        model_id = self.configs["model_id"]
        logger.info(f"""Initializing image hint with model '{model_id}'""")
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        self.model = pipeline.to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate image hints, replacing any previously stored hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' image hints")
        # The same prompt is repeated once per requested image.
        prompts = [f"An image related to the country {country}"] * n_hints
        output = self.model(
            prompt=prompts,
            num_inference_steps=self.configs["num_inference_steps"],
            guidance_scale=self.configs["guidance_scale"],
        )
        self.hints = [{"image": image} for image in output.images]
        logger.info(f"Image hints '{n_hints}' successfully generated")
116
-
117
-
class AudioHint(BaseHint):
    """Hint generator that produces audio clips with an AudioLDM2 pipeline."""

    def initialize(self):
        """Load the audio pipeline named in the configs."""
        logger.info(
            f"""Initializing audio hint with model '{self.configs["model_id"]}'"""
        )
        # NOTE: loaded with the default dtype; `torch_dtype=torch.float16` was
        # noted as not working with MacOS.
        self.model = AudioLDM2Pipeline.from_pretrained(
            self.configs["model_id"],
        ).to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate audio hints, replacing any hints from a previous call.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' audio hints")
        prompt = f"A sound that resembles the country of {country}"
        negative_prompt = "Low quality"

        audio_hints = self.model(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=self.configs["num_inference_steps"],
            audio_length_in_s=self.configs["audio_length_in_s"],
            num_waveforms_per_prompt=n_hints,
        ).audios

        # Assign a fresh list instead of appending: the same model object may
        # be reused for later games, and appending would keep earlier hints.
        self.hints = [
            {
                "audio": audio_hint,
                "sample_rate": SAMPLE_RATE,
            }
            for audio_hint in audio_hints
        ]
        logger.info(f"Audio hints '{n_hints}' successfully generated")