ramalMr committed on
Commit
97425d1
·
verified ·
1 Parent(s): c655130

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -11
app.py CHANGED
@@ -2,26 +2,93 @@ from huggingface_hub import InferenceClient
2
  import gradio as gr
3
  import random
4
  import pandas as pd
 
5
  import csv
 
 
6
  import tempfile
7
  import re
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
10
 
11
- def extract_text_from_excel(file, column_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  df = pd.read_excel(file)
13
- text = ' '.join(df[column_name].astype(str))
14
  return text
15
 
16
- def generate(file, column_name, temperature, max_new_tokens, top_p, repetition_penalty, num_similar_sentences):
17
- text = extract_text_from_excel(file, column_name)
 
 
 
 
 
18
  sentences = text.split('.')
19
  random.shuffle(sentences) # Shuffle sentences
20
 
21
  with tempfile.NamedTemporaryFile(mode='w', newline='', delete=False, suffix='.csv') as tmp:
22
  fieldnames = ['Original Sentence', 'Generated Sentence']
23
  writer = csv.DictWriter(tmp, fieldnames=fieldnames)
24
- writer.writeheader()
25
 
26
  for sentence in sentences:
27
  sentence = sentence.strip()
@@ -38,10 +105,10 @@ def generate(file, column_name, temperature, max_new_tokens, top_p, repetition_p
38
  }
39
 
40
  try:
41
- stream = client.text_generation(sentence, **generate_kwargs, stream=True, return_full_text=False)
42
  output = ""
43
  for response in stream:
44
- output += response.text
45
 
46
  generated_sentences = re.split(r'(?<=[\.\!\?:])[\s\n]+', output)
47
  generated_sentences = [s.strip() for s in generated_sentences if s.strip() and s != '.']
@@ -50,7 +117,28 @@ def generate(file, column_name, temperature, max_new_tokens, top_p, repetition_p
50
  if not generated_sentences:
51
  break
52
  generated_sentence = generated_sentences.pop(random.randrange(len(generated_sentences)))
53
- writer.writerow({'Original Sentence': sentence, 'Generated Sentence': generated_sentence})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  except Exception as e:
56
  print(f"Error generating data for sentence '{sentence}': {e}")
@@ -59,17 +147,17 @@ def generate(file, column_name, temperature, max_new_tokens, top_p, repetition_p
59
 
60
  return tmp_path
61
 
62
- gr.Interface(
63
  fn=generate,
64
  inputs=[
65
  gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx"]),
66
- gr.TextAreaInput(label="Column Name", placeholder="Enter the column name"),
67
  gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
68
  gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
69
  gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
70
  gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
71
  gr.Slider(label="Number of similar sentences", value=10, minimum=1, maximum=20, step=1, interactive=True, info="Number of similar sentences to generate for each original sentence"),
72
- ],
 
73
  outputs=gr.File(label="Synthetic Data "),
74
  title="SDG",
75
  description="AYE QABIL.",
 
2
  import gradio as gr
3
  import random
4
  import pandas as pd
5
+ from io import BytesIO
6
  import csv
7
+ import os
8
+ import io
9
  import tempfile
10
  import re
11
+ import streamlit as st
12
+ import torch
13
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
14
+ import time
15
+ import logging
16
+
17
+ if torch.cuda.is_available():
18
+ device = torch.device("cuda:0")
19
+ else:
20
+ device = torch.device("cpu")
21
+ logging.warning("GPU not found, using CPU, translation will be very slow.")
22
 
23
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
24
 
25
+ lang_id = {
26
+ "Afrikaans": "af",
27
+ "Amharic": "am",
28
+ "Arabic": "ar",
29
+ "Asturian": "ast",
30
+ "Azerbaijani": "az",
31
+ "Bashkir": "ba",
32
+ "Belarusian": "be",
33
+ "Bulgarian": "bg",
34
+ "Bengali": "bn",
35
+ "Breton": "br",
36
+ "Bosnian": "bs",
37
+ "Catalan": "ca",
38
+ "Cebuano": "ceb",
39
+ "Czech": "cs",
40
+ "Welsh": "cy",
41
+ "Danish": "da",
42
+ "German": "de",
43
+ "Greeek": "el",
44
+ "English": "en",
45
+ "Spanish": "es",
46
+ "Estonian": "et",
47
+ "Persian": "fa",
48
+ "Fulah": "ff",
49
+ "Finnish": "fi",
50
+ "French": "fr",
51
+ "Western Frisian": "fy",
52
+ "Irish": "ga",
53
+ "Gaelic": "gd",
54
+ "Galician": "gl",
55
+ "Gujarati": "gu",
56
+ "Hausa": "ha",
57
+ "Hebrew": "he",
58
+ "Hindi": "hi",
59
+ "Croatian": "hr",
60
+ "Haitian": "ht",
61
+ "Hungarian": "hu",
62
+ "Armenian": "hy",
63
+ "Indonesian": "id"
64
+ }
65
+
66
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
67
+ def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
68
+ tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
69
+ model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model, cache_dir=cache_dir).to(device)
70
+ model.eval()
71
+ return tokenizer, model
72
+
73
+ def extract_text_from_excel(file):
74
  df = pd.read_excel(file)
75
+ text = ' '.join(df['Unnamed: 1'].astype(str))
76
  return text
77
 
78
+ def save_to_csv(sentence, output, filename="synthetic_data.csv"):
79
+ with open(filename, mode='a', newline='', encoding='utf-8') as file:
80
+ writer = csv.writer(file)
81
+ writer.writerow([sentence, output])
82
+
83
+ def generate(file, temperature, max_new_tokens, top_p, repetition_penalty, num_similar_sentences):
84
+ text = extract_text_from_excel(file)
85
  sentences = text.split('.')
86
  random.shuffle(sentences) # Shuffle sentences
87
 
88
  with tempfile.NamedTemporaryFile(mode='w', newline='', delete=False, suffix='.csv') as tmp:
89
  fieldnames = ['Original Sentence', 'Generated Sentence']
90
  writer = csv.DictWriter(tmp, fieldnames=fieldnames)
91
+ writer.writeheader()
92
 
93
  for sentence in sentences:
94
  sentence = sentence.strip()
 
105
  }
106
 
107
  try:
108
+ stream = client.text_generation(sentence, **generate_kwargs, stream=True, details=True, return_full_text=False)
109
  output = ""
110
  for response in stream:
111
+ output += response.token.text
112
 
113
  generated_sentences = re.split(r'(?<=[\.\!\?:])[\s\n]+', output)
114
  generated_sentences = [s.strip() for s in generated_sentences if s.strip() and s != '.']
 
117
  if not generated_sentences:
118
  break
119
  generated_sentence = generated_sentences.pop(random.randrange(len(generated_sentences)))
120
+
121
+ # Translate generated sentence to English
122
+ tokenizer, model = load_model()
123
+ src_lang = lang_id[language]
124
+ trg_lang = lang_id["English"]
125
+ tokenizer.src_lang = src_lang
126
+ with torch.no_grad():
127
+ encoded_input = tokenizer(generated_sentence, return_tensors="pt").to(device)
128
+ generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang))
129
+ translated_sentence = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
130
+
131
+ # Translate original sentence to Azerbaijani
132
+ tokenizer, model = load_model()
133
+ src_lang = lang_id["English"]
134
+ trg_lang = lang_id["Azerbaijani"]
135
+ tokenizer.src_lang = src_lang
136
+ with torch.no_grad():
137
+ encoded_input = tokenizer(sentence, return_tensors="pt").to(device)
138
+ generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang))
139
+ translated_sentence_az = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
140
+
141
+ writer.writerow({'Original Sentence': translated_sentence_az, 'Generated Sentence': translated_sentence})
142
 
143
  except Exception as e:
144
  print(f"Error generating data for sentence '{sentence}': {e}")
 
147
 
148
  return tmp_path
149
 
150
+ gr.Interface(
151
  fn=generate,
152
  inputs=[
153
  gr.File(label="Upload Excel File", file_count="single", file_types=[".xlsx"]),
 
154
  gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
155
  gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=5120, step=64, interactive=True, info="The maximum numbers of new tokens"),
156
  gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
157
  gr.Slider(label="Repetition penalty", value=1.0, minimum=1.0, maximum=2.0, step=0.1, interactive=True, info="Penalize repeated tokens"),
158
  gr.Slider(label="Number of similar sentences", value=10, minimum=1, maximum=20, step=1, interactive=True, info="Number of similar sentences to generate for each original sentence"),
159
+ gr.Dropdown(label="Language of the input data", choices=list(lang_id.keys()), value="English")
160
+ ],
161
  outputs=gr.File(label="Synthetic Data "),
162
  title="SDG",
163
  description="AYE QABIL.",