karthikqnq committed on
Commit
50e5e08
·
verified ·
1 Parent(s): d839a73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # Load the model
5
- model = pipeline("text-generation", model="karthikqnq/qnqgpt2")
 
 
6
 
7
  def respond(
8
  message,
@@ -19,17 +21,23 @@ def respond(
19
  prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
20
  prompt += f"User: {message}\nAssistant: "
21
 
 
 
 
22
  # Generate response
23
- response = model(
24
- prompt,
25
  max_length=max_tokens,
26
  temperature=temperature,
27
  top_p=top_p,
28
  do_sample=True,
29
  num_return_sequences=1
30
- )[0]['generated_text']
 
 
 
31
 
32
- # Extract only the assistant's response
33
  try:
34
  assistant_response = response.split("Assistant: ")[-1].strip()
35
  except:
@@ -72,4 +80,4 @@ demo = gr.ChatInterface(
72
  )
73
 
74
  if __name__ == "__main__":
75
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ # Load the model and tokenizer
5
+ model_name = "karthikqnq/qnqgpt2"
6
+ model = AutoModelForCausalLM.from_pretrained(model_name)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
  def respond(
10
  message,
 
21
  prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
22
  prompt += f"User: {message}\nAssistant: "
23
 
24
+ # Tokenize the input prompt
25
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
26
+
27
  # Generate response
28
+ outputs = model.generate(
29
+ **inputs,
30
  max_length=max_tokens,
31
  temperature=temperature,
32
  top_p=top_p,
33
  do_sample=True,
34
  num_return_sequences=1
35
+ )
36
+
37
+ # Decode the generated token IDs back into text (the assistant's reply is extracted below)
38
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
 
40
+ # Extract the assistant's reply after "Assistant:"
41
  try:
42
  assistant_response = response.split("Assistant: ")[-1].strip()
43
  except:
 
80
  )
81
 
82
  if __name__ == "__main__":
83
+ demo.launch()