FridayMaster commited on
Commit
95fd627
·
verified ·
1 Parent(s): 08415b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import torch
2
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
3
 
4
  # Load the custom model and tokenizer
5
- model_path = 'redael/model_udc'
6
  tokenizer = GPT2Tokenizer.from_pretrained(model_path)
7
  model = GPT2LMHeadModel.from_pretrained(model_path)
8
 
@@ -21,10 +22,10 @@ def generate_response(prompt, model, tokenizer, max_length=100, num_beams=1, tem
21
  max_length=max_length,
22
  num_return_sequences=1,
23
  pad_token_id=tokenizer.eos_token_id,
24
- num_beams=num_beams, # Use a lower number of beams
25
  temperature=temperature,
26
  top_p=top_p,
27
- repetition_penalty=repetition_penalty, # Increased repetition penalty
28
  early_stopping=True
29
  )
30
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -39,7 +40,7 @@ def generate_response(prompt, model, tokenizer, max_length=100, num_beams=1, tem
39
  response = ' '.join(clean_response)
40
  return response.strip()
41
 
42
- def respond(message, history: list[tuple[str, str]]):
43
  # Prepare the prompt from the history and the new message
44
  system_message = "You are a friendly chatbot."
45
  conversation = system_message + "\n"
@@ -56,10 +57,10 @@ def respond(message, history: list[tuple[str, str]]):
56
 
57
  return response
58
 
59
- # Gradio Chat Interface without customizable inputs
60
  demo = gr.ChatInterface(
61
  respond
62
  )
63
 
64
  if __name__ == "__main__":
65
- demo.launch()
 
1
  import torch
2
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
3
+ import gradio as gr
4
 
5
  # Load the custom model and tokenizer
6
+ model_path = 'redael/model_udc'
7
  tokenizer = GPT2Tokenizer.from_pretrained(model_path)
8
  model = GPT2LMHeadModel.from_pretrained(model_path)
9
 
 
22
  max_length=max_length,
23
  num_return_sequences=1,
24
  pad_token_id=tokenizer.eos_token_id,
25
+ num_beams=num_beams,
26
  temperature=temperature,
27
  top_p=top_p,
28
+ repetition_penalty=repetition_penalty,
29
  early_stopping=True
30
  )
31
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
40
  response = ' '.join(clean_response)
41
  return response.strip()
42
 
43
+ def respond(message, history):
44
  # Prepare the prompt from the history and the new message
45
  system_message = "You are a friendly chatbot."
46
  conversation = system_message + "\n"
 
57
 
58
  return response
59
 
60
+ # Gradio Chat Interface
61
  demo = gr.ChatInterface(
62
  respond
63
  )
64
 
65
  if __name__ == "__main__":
66
+ demo.launch()