Naman Pundir commited on
Commit
3d71508
·
1 Parent(s): c9bb7a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -19
app.py CHANGED
@@ -1,45 +1,39 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
- # Define the models and their corresponding names
5
  models = {
6
- "Model 1 (lmsys/vicuna-13b-v1.3)": {
7
- "model_name": "lmsys/vicuna-13b-v1.3",
8
- "description": "Model 1: Foundation 13B Vicuna Model",
9
  },
10
- "Model 2 (Aiyan99/theus_concepttagger)": {
11
- "model_name": "Aiyan99/theus_concepttagger",
12
- "description": "Model 2: My finetuned model",
13
  },
14
  }
15
 
16
- # Define the Gradio interface
17
  def summarize_text(input_text, selected_model):
18
- # Get the selected model and its tokenizer
19
  model_info = models[selected_model]
20
  tokenizer = AutoTokenizer.from_pretrained(model_info["model_name"])
21
  model = AutoModelForSeq2SeqLM.from_pretrained(model_info["model_name"])
22
 
23
- # Tokenize and generate summary
24
  input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=1024, truncation=True)
25
  summary_ids = model.generate(input_ids, max_length=10, min_length=1, length_penalty=1.0, num_beams=4, early_stopping=True)
26
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
27
  return summary
28
 
29
- # Define a custom theme for the Gradio interface
30
- custom_theme = gr.theme(
31
- page_title="MLE - Project (Tuning and Infra Project)1MLE - Project (Tuning and Infra Project)",
32
- layout="wide",
33
- page_bgcolor="black",
34
- )
35
 
36
  iface = gr.Interface(
37
  fn=summarize_text,
38
  inputs=[gr.inputs.Textbox(label="Input Text"), gr.inputs.Radio(list(models.keys()), label="Select Model")],
39
  outputs="text",
40
- title="Text Summarization App",
41
- description="Choose a model for text summarization and enter the text to summarize.",
42
- theme=custom_theme, # Apply the custom theme
43
  )
44
 
45
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
 
4
  models = {
5
+ "Model 1 (facebook/bart-large-cnn)": {
6
+ "model_name": "facebook/bart-large-cnn",
7
+ "description": "Model 1",
8
  },
9
+ "Model 2 (google/pegasus-multi_news)": {
10
+ "model_name": "google/pegasus-multi_news",
11
+ "description": "Model 2",
12
  },
13
  }
14
 
 
15
  def summarize_text(input_text, selected_model):
 
16
  model_info = models[selected_model]
17
  tokenizer = AutoTokenizer.from_pretrained(model_info["model_name"])
18
  model = AutoModelForSeq2SeqLM.from_pretrained(model_info["model_name"])
19
 
 
20
  input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=1024, truncation=True)
21
  summary_ids = model.generate(input_ids, max_length=10, min_length=1, length_penalty=1.0, num_beams=4, early_stopping=True)
22
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
23
  return summary
24
 
25
+ custom_title = """
26
+ <div style="color: white; text-align: center; background-color: black; padding: 20px;">
27
+ <h1>MLE - Project (Tuning and Infra Project)1MLE - Project (Tuning and Infra Project)</h1>
28
+ </div>
29
+ """
 
30
 
31
  iface = gr.Interface(
32
  fn=summarize_text,
33
  inputs=[gr.inputs.Textbox(label="Input Text"), gr.inputs.Radio(list(models.keys()), label="Select Model")],
34
  outputs="text",
35
+ title=custom_title,
36
+ description="Choose a model for Concept Assignation and enter the text to summarize.",
 
37
  )
38
 
39
  if __name__ == "__main__":