aditi2222 committed on
Commit
41135d2
·
1 Parent(s): e3404ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -71
app.py CHANGED
@@ -1,71 +1,12 @@
1
- import streamlit as st
2
- import os
3
- import io
4
- from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
5
- import time
6
- import json
7
- from typing import List
8
- import torch
9
- import random
10
- import logging
11
-
12
- if torch.cuda.is_available():
13
- device = torch.device("cuda:0")
14
- else:
15
- device = torch.device("cpu")
16
- logging.warning("GPU not found, using CPU, translation will be very slow.")
17
-
18
- st.cache(suppress_st_warning=True, allow_output_mutation=True)
19
- st.set_page_config(page_title="M2M100 Translator")
20
-
21
- lang_id = {
22
- "English": "en",
23
- "French": "fr",
24
- }
25
-
26
-
27
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
28
- def load_model(
29
- pretrained_model: str = "facebook/m2m100_418M",
30
- cache_dir: str = "models/",
31
- ):
32
- tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
33
- model = M2M100ForConditionalGeneration.from_pretrained(
34
- pretrained_model, cache_dir=cache_dir
35
- ).to(device)
36
- model.eval()
37
- return tokenizer, model
38
-
39
-
40
- st.title("M2M100 Translator")
41
-
42
-
43
- user_input: str = st.text_area(
44
- "Input text",
45
- height=200,
46
- max_chars=5120,
47
- )
48
-
49
- source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
50
- target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
51
-
52
- if st.button("Run"):
53
- time_start = time.time()
54
- tokenizer, model = load_model()
55
-
56
- src_lang = lang_id[source_lang]
57
- trg_lang = lang_id[target_lang]
58
- tokenizer.src_lang = src_lang
59
- with torch.no_grad():
60
- encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
61
- generated_tokens = model.generate(
62
- **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
63
- )
64
- translated_text = tokenizer.batch_decode(
65
- generated_tokens, skip_special_tokens=True
66
- )[0]
67
-
68
- time_end = time.time()
69
- st.success(translated_text)
70
-
71
- st.write(f"Computation time: {round((time_end-time_start),3)} segs")
 
from transformers import pipeline
import gradio as gr

# Build the translation pipeline once at module load; the model download and
# CPU initialization are slow, so this must not happen per request.
pipe = pipeline("text2text-generation", model="facebook/m2m100_418M")
4
def generate_text(inp):
    """Translate *inp* to English using the module-level M2M100 pipeline.

    Parameters
    ----------
    inp : str
        Source text to translate.

    Returns
    -------
    str
        The translated text (the pipeline's ``generated_text`` field).

    Raises
    ------
    KeyError
        If the pipeline output does not contain ``generated_text``.
    IndexError
        If the pipeline returns an empty result list.
    """
    output = pipe(inp, forced_bos_token_id=pipe.tokenizer.get_lang_id("en"))
    # The pipeline returns a list of dicts. The original code iterated over the
    # dict's keys and kept the last value seen, which depended on key order and
    # left `result` unbound for an empty dict; index the documented key instead.
    return output[0]["generated_text"]
10
+ #Gradio Interface
11
+ output_text = gr.outputs.Textbox()
12
+ gr.Interface(generate_text,"textbox", output_text).launch(inline=False)