mukdhesh commited on
Commit
66cee24
1 Parent(s): 0b7a0a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -27
app.py CHANGED
@@ -4,7 +4,11 @@ from transformers import (
4
  GPT2LMHeadModel, GPT2Tokenizer,
5
  pipeline
6
  )
 
 
7
  st.title("Multi Chatbot")
 
 
8
  models = {
9
  "English to French": {
10
  "name": "Helsinki-NLP/opus-mt-en-fr",
@@ -19,39 +23,62 @@ models = {
19
  "description": "Generate creative stories based on input."
20
  }
21
  }
 
 
22
  st.sidebar.header("Choose a Model")
23
  selected_model_key = st.sidebar.radio("Select a Model:", list(models.keys()))
24
  model_name = models[selected_model_key]["name"]
25
  model_description = models[selected_model_key]["description"]
 
26
  st.sidebar.markdown(f"### Model Description\n{model_description}")
27
- try:
28
- if selected_model_key == "English to French":
29
- st.write("Loading English to French model...")
30
- tokenizer = MarianTokenizer.from_pretrained(model_name)
31
- model = MarianMTModel.from_pretrained(model_name)
32
- st.write("English to French model loaded successfully.")
33
- elif selected_model_key == "Sentiment Analysis":
34
- st.write("Loading Sentiment Analysis model...")
35
- sentiment_analyzer = pipeline("sentiment-analysis", model=model_name)
36
- st.write("Sentiment Analysis model loaded successfully.")
37
- elif selected_model_key == "Story Generator":
38
- st.write("Loading Story Generator model...")
39
- tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
40
- model = GPT2LMHeadModel.from_pretrained("distilgpt2")
41
- tokenizer.pad_token = tokenizer.eos_token
42
- st.write("Story Generator model loaded successfully.")
43
- except Exception as e:
44
- st.error(f"Failed to load the model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  user_input = st.text_input("Enter your query:")
 
46
  if user_input:
47
  if selected_model_key == "English to French":
48
  try:
49
- inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
50
- outputs = model.generate(inputs["input_ids"], max_length=150, num_return_sequences=1, no_repeat_ngram_size=2)
51
- bot_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
- st.write(f"Translated Text: {bot_response}")
53
  except Exception as e:
54
  st.error(f"Error during translation: {e}")
 
55
  elif selected_model_key == "Sentiment Analysis":
56
  try:
57
  result = sentiment_analyzer(user_input)[0]
@@ -59,12 +86,19 @@ if user_input:
59
  st.write(f"Confidence: {result['score']:.2f}")
60
  except Exception as e:
61
  st.error(f"Error during sentiment analysis: {e}")
 
62
  elif selected_model_key == "Story Generator":
63
  try:
64
- inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
65
- outputs = model.generate(inputs["input_ids"], max_length=500, num_return_sequences=1, no_repeat_ngram_size=2, temperature=0.7)
66
- bot_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
67
- st.write(f"Generated Story: {bot_response}")
 
 
 
 
 
 
 
68
  except Exception as e:
69
  st.error(f"Error during story generation: {e}")
70
-
 
4
  GPT2LMHeadModel, GPT2Tokenizer,
5
  pipeline
6
  )
7
+
8
+ # App title
9
  st.title("Multi Chatbot")
10
+
11
+ # Define models and descriptions
12
  models = {
13
  "English to French": {
14
  "name": "Helsinki-NLP/opus-mt-en-fr",
 
23
  "description": "Generate creative stories based on input."
24
  }
25
  }
26
+
27
+ # Sidebar: Model selection
28
  st.sidebar.header("Choose a Model")
29
  selected_model_key = st.sidebar.radio("Select a Model:", list(models.keys()))
30
  model_name = models[selected_model_key]["name"]
31
  model_description = models[selected_model_key]["description"]
32
+
33
  st.sidebar.markdown(f"### Model Description\n{model_description}")
34
+
35
+ # Cache model loading for efficiency
36
+ @st.cache_resource
37
+ def load_english_to_french():
38
+ tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
39
+ model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
40
+ return tokenizer, model
41
+
42
+ @st.cache_resource
43
+ def load_sentiment_analysis():
44
+ return pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
45
+
46
+ @st.cache_resource
47
+ def load_story_generator():
48
+ tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
49
+ model = GPT2LMHeadModel.from_pretrained("distilgpt2")
50
+ tokenizer.pad_token = tokenizer.eos_token # Set pad token to EOS token
51
+ return tokenizer, model
52
+
53
+ # Load the selected model
54
+ if selected_model_key == "English to French":
55
+ st.write("Loading English to French model...")
56
+ en_fr_tokenizer, en_fr_model = load_english_to_french()
57
+ st.write("English to French model loaded successfully.")
58
+
59
+ elif selected_model_key == "Sentiment Analysis":
60
+ st.write("Loading Sentiment Analysis model...")
61
+ sentiment_analyzer = load_sentiment_analysis()
62
+ st.write("Sentiment Analysis model loaded successfully.")
63
+
64
+ elif selected_model_key == "Story Generator":
65
+ st.write("Loading Story Generator model...")
66
+ story_gen_tokenizer, story_gen_model = load_story_generator()
67
+ st.write("Story Generator model loaded successfully.")
68
+
69
+ # User input
70
  user_input = st.text_input("Enter your query:")
71
+
72
  if user_input:
73
  if selected_model_key == "English to French":
74
  try:
75
+ inputs = en_fr_tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
76
+ outputs = en_fr_model.generate(inputs["input_ids"], max_length=150, num_return_sequences=1)
77
+ translated_text = en_fr_tokenizer.decode(outputs[0], skip_special_tokens=True)
78
+ st.write(f"Translated Text: {translated_text}")
79
  except Exception as e:
80
  st.error(f"Error during translation: {e}")
81
+
82
  elif selected_model_key == "Sentiment Analysis":
83
  try:
84
  result = sentiment_analyzer(user_input)[0]
 
86
  st.write(f"Confidence: {result['score']:.2f}")
87
  except Exception as e:
88
  st.error(f"Error during sentiment analysis: {e}")
89
+
90
  elif selected_model_key == "Story Generator":
91
  try:
92
+ inputs = story_gen_tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
93
+ outputs = story_gen_model.generate(
94
+ inputs["input_ids"],
95
+ attention_mask=inputs["attention_mask"], # Pass the attention mask
96
+ max_length=200,
97
+ num_return_sequences=1,
98
+ temperature=0.7,
99
+ no_repeat_ngram_size=2
100
+ )
101
+ story = story_gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ st.write(f"Generated Story: {story}")
103
  except Exception as e:
104
  st.error(f"Error during story generation: {e}")