nktssk committed
Commit 1f706a9 · verified · 1 Parent(s): a4f756a

Update app.py

Files changed (1):
  app.py (+99 -126)
app.py CHANGED
@@ -1,177 +1,150 @@
  import os
- import io
  import torch
  import gradio as gr
  import wikipediaapi
- import re
- import inflect
- import soundfile as sf
- import unicodedata
- import num2words
- import requests
- import json
  from PIL import Image
- from num2words import num2words
- from google.cloud import vision
- from datasets import load_dataset
- from scipy.io.wavfile import write
- from transformers import VitsModel, AutoTokenizer
- from transformers import pipeline
- from transformers import CLIPProcessor, CLIPModel
- from transformers import T5ForConditionalGeneration, T5Tokenizer
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

  def load_attractions_json(url):
-     response = requests.get(url)
-     response.raise_for_status()
-     json_text = response.text
-     data = json.loads(json_text)
-     return data

- url = "https://raw.githubusercontent.com/nktssk/tourist-helper/refs/heads/main/landmarks.json"
  landmark_titles = load_attractions_json(url)

- print(landmark_titles)
-
- # HELPERS
  def clean_text(text):
      text = re.sub(r'МФА:?\s?\[.*?\]', '', text)
      text = re.sub(r'\[.*?\]', '', text)
-
-     def remove_diacritics(char):
-         if unicodedata.category(char) == 'Mn':
-             return ''
-         return char
-
      text = unicodedata.normalize('NFD', text)
-     text = ''.join(remove_diacritics(char) for char in text)
      text = unicodedata.normalize('NFC', text)
-
      text = re.sub(r'\s+', ' ', text)
      text = re.sub(r'[^\w\s.,!?-]', '', text)
-
      return text.strip()

- def replace_numbers_with_text(input_string):
-     def convert_number(match):
-         number = match.group(0)
-         try:
-             return num2words(float(number) if '.' in number else int(number), lang='ru')
-         except Exception:
-             return number
-     return re.sub(r'\d+(\.\d+)?', convert_number, input_string)
-
- # MODELS
- summarization_model = pipeline("summarization", model="facebook/bart-large-cnn")
- wiki = wikipediaapi.Wikipedia("Nikita", "en")
- embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
- t2s_pipe = pipeline("text-to-speech", model="facebook/mms-tts-rus")
  translator = pipeline("translation_en_to_ru", model="Helsinki-NLP/opus-mt-en-ru")

  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

- text_inputs = clip_processor(
-     text=landmark_titles,
-     images=None,
-     return_tensors="pt",
-     padding=True
- )
  with torch.no_grad():
      text_embeds = clip_model.get_text_features(**text_inputs)
      text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

- # TEXT-TO-SPEECH
- def text_to_speech(text, output_path="speech.wav"):
-     text = replace_numbers_with_text(text)
-     model = VitsModel.from_pretrained("facebook/mms-tts-rus")
-     tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
-
-     inputs = tokenizer(text, return_tensors="pt")
-
-     with torch.no_grad():
-         output = model(**inputs).waveform.squeeze().numpy()
-
-     sf.write(output_path, output, samplerate=model.config.sampling_rate)

-     return output_path

- # WIKI
  def fetch_wikipedia_summary(landmark):
      page = wiki.page(landmark)
-     if page.exists():
-         return clean_text(page.summary)
-     else:
-         return "Found error!"

- # CLIP
  def recognize_landmark_clip(image):
      if not isinstance(image, Image.Image):
          image = Image.fromarray(image)
-
-     image_inputs = clip_processor(images=image, return_tensors="pt")
      with torch.no_grad():
-         image_embed = clip_model.get_image_features(**image_inputs)
-         image_embed = image_embed / image_embed.norm(p=2, dim=-1, keepdim=True)
-
-     similarity = (image_embed @ text_embeds.T).squeeze(0)
-     best_idx = similarity.argmax().item()
-     best_score = similarity[best_idx].item()
-     recognized_landmark = landmark_titles[best_idx]
-     return recognized_landmark, best_score
-
- # DEMO
- def tourist_helper_with_russian(landmark):
-     wiki_text = fetch_wikipedia_summary(landmark)
-     if wiki_text == "Found error!":
          return None
-
-     print(wiki_text)
-     summarized_text = summarization_model(wiki_text, min_length=20, max_length=210)[0]["summary_text"]
-     print(summarized_text)
-
-     translated = translator(summarized_text, max_length=1000)[0]["translation_text"]
-     print(translated)
-
-     audio_path = text_to_speech(translated)
-     return audio_path

  def process_image_clip(image):
      recognized, score = recognize_landmark_clip(image)
-     print(f"[CLIP] Распознано: {recognized}, score={score:.2f}")
-     audio_path = tourist_helper_with_russian(recognized)
-     return audio_path

  def process_text_clip(landmark):
-     return tourist_helper_with_russian(landmark)

  with gr.Blocks() as demo:
      gr.Markdown("## Помощь туристу")
-
      with gr.Tabs():
          with gr.Tab("CLIP + Sum + Translate + T2S"):
-             gr.Markdown("### Распознавание (CLIP) и перевод на русский")
-
              with gr.Row():
-                 image_input_c = gr.Image(label="Загрузите фото", type="pil")
-                 text_input_c = gr.Textbox(label="Или введите название")
-
-             audio_output_c = gr.Audio(label="Результатт")
-
              with gr.Row():
-                 btn_recognize_c = gr.Button("Распознать и перевести на русский")
-                 btn_text_c = gr.Button("Поиск по тексту")
-
-             btn_recognize_c.click(
-                 fn=process_image_clip,
-                 inputs=image_input_c,
-                 outputs=audio_output_c
-             )
-             btn_text_c.click(
-                 fn=process_text_clip,
-                 inputs=text_input_c,
-                 outputs=audio_output_c
-             )
-
- demo.launch(debug=True)
  import os
+ import re
+ import json
  import torch
+ import requests
+ import unicodedata
+ import soundfile as sf
+ import pymorphy2
+
  import gradio as gr
  import wikipediaapi
  from PIL import Image
+ from transformers import pipeline, CLIPProcessor, CLIPModel
+ from num2words import num2words
+
+ morph = pymorphy2.MorphAnalyzer()

  def load_attractions_json(url):
+     r = requests.get(url)
+     r.raise_for_status()
+     return json.loads(r.text)

+ url = "https://raw.githubusercontent.com/nktssk/tourist-helper/refs/heads/main/landmarks.json"
  landmark_titles = load_attractions_json(url)

  def clean_text(text):
      text = re.sub(r'МФА:?\s?\[.*?\]', '', text)
      text = re.sub(r'\[.*?\]', '', text)
+     def rm_diacritics(c):
+         return '' if unicodedata.category(c) == 'Mn' else c
      text = unicodedata.normalize('NFD', text)
+     text = ''.join(rm_diacritics(c) for c in text)
      text = unicodedata.normalize('NFC', text)
      text = re.sub(r'\s+', ' ', text)
      text = re.sub(r'[^\w\s.,!?-]', '', text)
      return text.strip()

+ # Simplified case inference: pick a grammatical case from the
+ # preceding preposition (values are pymorphy2/OpenCorpora grammemes)
+ def get_case_for_preposition(prep):
+     d = {
+         'в': 'loct', 'на': 'loct', 'о': 'loct', 'об': 'loct', 'обо': 'loct',
+         'к': 'datv',
+         'с': 'ablt', 'со': 'ablt', 'над': 'ablt', 'под': 'ablt',
+         'из': 'gent', 'от': 'gent', 'у': 'gent', 'до': 'gent', 'для': 'gent'
+     }
+     return d.get(prep.lower(), 'nomn')
+
+ def replace_numbers_with_text_in_context(text):
+     tokens = text.split()
+     result = []
+     for i, token in enumerate(tokens):
+         if re.match(r'^\d+(\.\d+)?$', token):
+             cse = 'nomn'
+             if i > 0:
+                 cse = get_case_for_preposition(tokens[i - 1])
+             # First spell the number out in words (nominative form)
+             number_as_words = num2words(float(token) if '.' in token else int(token), lang='ru')
+             number_as_words = number_as_words.replace('-', ' ')
+             subtokens = number_as_words.split()
+             inflected_subtokens = []
+             for st in subtokens:
+                 # Inflect each word to the inferred case; keep the
+                 # original form when no inflection exists
+                 p = morph.parse(st)
+                 form = p[0].inflect({cse}) if p else None
+                 inflected_subtokens.append(form.word if form else st)
+             result.append(' '.join(inflected_subtokens))
+         else:
+             result.append(token)
+     return ' '.join(result)
+
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
  translator = pipeline("translation_en_to_ru", model="Helsinki-NLP/opus-mt-en-ru")
+ wiki = wikipediaapi.Wikipedia("Nikita", "en")

  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

+ text_inputs = clip_processor(text=landmark_titles, images=None, return_tensors="pt", padding=True)
  with torch.no_grad():
      text_embeds = clip_model.get_text_features(**text_inputs)
      text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

+ language = 'ru'
+ model_id = 'v3_1_ru'
+ sample_rate = 48000
+ speaker = 'eugene'
+ silero_model, _ = torch.hub.load(
+     repo_or_dir='snakers4/silero-models',
+     model='silero_tts',
+     language=language,
+     speaker=model_id
+ )

+ def text_to_speech(text, out_path="speech.wav"):
+     text = replace_numbers_with_text_in_context(text)
+     # apply_tts returns a torch tensor; convert before writing the WAV
+     audio = silero_model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
+     sf.write(out_path, audio.numpy(), sample_rate)
+     return out_path

  def fetch_wikipedia_summary(landmark):
      page = wiki.page(landmark)
+     return clean_text(page.summary) if page.exists() else "Found error!"

  def recognize_landmark_clip(image):
      if not isinstance(image, Image.Image):
          image = Image.fromarray(image)
+     img_in = clip_processor(images=image, return_tensors="pt")
      with torch.no_grad():
+         img_embed = clip_model.get_image_features(**img_in)
+         img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)
+     sim = (img_embed @ text_embeds.T).squeeze(0)
+     best_idx = sim.argmax().item()
+     return landmark_titles[best_idx], sim[best_idx].item()
+
+ def process_landmark(landmark):
+     txt = fetch_wikipedia_summary(landmark)
+     if txt == "Found error!":
          return None
+     sm = summarizer(txt, min_length=20, max_length=210)[0]["summary_text"]
+     tr = translator(sm, max_length=1000)[0]["translation_text"]
+     return text_to_speech(tr)

  def process_image_clip(image):
      recognized, score = recognize_landmark_clip(image)
+     return process_landmark(recognized)

  def process_text_clip(landmark):
+     return process_landmark(landmark)

  with gr.Blocks() as demo:
      gr.Markdown("## Помощь туристу")
      with gr.Tabs():
          with gr.Tab("CLIP + Sum + Translate + T2S"):
              with gr.Row():
+                 image_input = gr.Image(label="Загрузите фото", type="pil")
+                 text_input = gr.Textbox(label="Или введите название")
+             audio_output = gr.Audio(label="Результат")
              with gr.Row():
+                 btn_img = gr.Button("Распознать и перевести")
+                 btn_txt = gr.Button("Поиск по названию")
+             btn_img.click(fn=process_image_clip, inputs=image_input, outputs=audio_output)
+             btn_txt.click(fn=process_text_clip, inputs=text_input, outputs=audio_output)

+ demo.launch(debug=True)
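
The case-aware number expansion added in this commit can be sanity-checked in isolation. A minimal sketch, assuming num2words and pymorphy2 are installed (`expand` is an illustrative helper, not a function from the commit):

```python
# Standalone sketch of the number-expansion idea in app.py: spell a
# number out in the nominative with num2words, then inflect each word
# to a target case with pymorphy2. `expand` is hypothetical.
import pymorphy2
from num2words import num2words

morph = pymorphy2.MorphAnalyzer()

def expand(number, case):
    # num2words yields the nominative form, e.g. 5 -> "пять"
    words = num2words(number, lang='ru').replace('-', ' ').split()
    out = []
    for w in words:
        # inflect() returns None when no such form exists;
        # fall back to the uninflected word
        form = morph.parse(w)[0].inflect({case})
        out.append(form.word if form else w)
    return ' '.join(out)

print(expand(5, 'ablt'))   # instrumental, as after "с": "пятью"
print(expand(40, 'gent'))  # genitive, as after "до": "сорока"
```

Note that pymorphy2 expects OpenCorpora grammemes ('nomn', 'gent', 'datv', 'ablt', 'loct') rather than short names like 'nom' or 'gen', which is why the preposition table above maps to those values.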