Hotair8914 committed
Commit b3d6461 · verified · 1 Parent(s): 526e0bd

Create app.py

Files changed (1)
  1. app.py +31 -18
app.py CHANGED
@@ -1,34 +1,47 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
 
 # Load pre-trained model and tokenizer
 def load_model(model_name):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = model.to(device)
-    return tokenizer, model, device
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = model.to(device)
+        return tokenizer, model, device
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        return None, None, None
 
 # Function to generate chat responses
 def chat_with_niti(message, history):
     tokenizer, model, device = load_model("facebook/mbart-large-50")
-    input_ids = tokenizer.encode(message, return_tensors="pt").to(device)
-    output = model.generate(
-        input_ids,
-        max_length=100,
-        temperature=0.7,
-        num_return_sequences=1,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-    return response
+    if tokenizer is None or model is None:
+        return "Sorry, I'm having trouble loading the model. Please try again later."
+
+    try:
+        # Add a prompt for better responses
+        prompt = f"User: {message}\nChatNiti:"
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+        output = model.generate(
+            input_ids,
+            max_length=100,
+            temperature=0.7,
+            num_return_sequences=1,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        response = tokenizer.decode(output[0], skip_special_tokens=True)
+        return response.split("ChatNiti:")[-1].strip()  # Extract ChatNiti's response
+    except Exception as e:
+        print(f"Error generating response: {e}")
+        return "Sorry, I encountered an error while generating a response."
 
 # Create Gradio chat interface
 demo = gr.ChatInterface(
     fn=chat_with_niti,
-    title="Niti - Your AI Chatbot",
-    description="Ask Niti anything in Hindi, Hinglish, or English!"
+    title="ChatNiti - Your AI Chatbot",
+    description="Ask ChatNiti anything in Hindi, Hinglish, or English!"
 )
 
 # Launch the interface
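Two caveats are worth noting alongside the diff (observations on the committed code, not part of the commit): load_model is still called inside chat_with_niti, so the full mBART checkpoint is re-instantiated on every chat turn, and in recent transformers releases temperature has no effect on generate unless do_sample=True is also passed. A minimal sketch of the same app with the model cached at module level, reusing the commit's names for illustration:

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load once at import time instead of on every message (sketch, not the commit's code).
MODEL_NAME = "facebook/mbart-large-50"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def chat_with_niti(message, history):
    # Reuse the cached model/tokenizer rather than reloading them each turn.
    prompt = f"User: {message}\nChatNiti:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,  # temperature is only applied when sampling is enabled
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response.split("ChatNiti:")[-1].strip()

demo = gr.ChatInterface(fn=chat_with_niti)
demo.launch()

Note also that facebook/mbart-large-50 is a multilingual sequence-to-sequence model pretrained for translation-style tasks, not an instruction-tuned chat model, so even with the prompt template it will tend to paraphrase or translate the input rather than answer it; caching fixes the per-turn latency, not that mismatch.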