redael committed on
Commit a1e1bf3 · verified · 1 Parent(s): a3f73ec

Update app.py

Files changed (1)
  1. app.py +22 -11
app.py CHANGED
@@ -13,24 +13,36 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 # Function to generate response
-def generate_response(prompt, model, tokenizer, max_length=100, num_beams=5, temperature=0.5, top_p=0.9, repetition_penalty=4.0):
-    # Prepare the prompt
-    prompt = f"User: {prompt}\nAssistant:"
+def generate_response(message, history, system_message, max_tokens, temperature, top_p):
+    # Prepare the conversation history
+    messages = [{"role": "system", "content": system_message}]
+
+    for user_msg, bot_msg in history:
+        if user_msg:
+            messages.append({"role": "user", "content": user_msg})
+        if bot_msg:
+            messages.append({"role": "assistant", "content": bot_msg})
+
+    messages.append({"role": "user", "content": message})
+
+    # Tokenize and prepare the input
+    prompt = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages])
     inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
+
+    # Generate the response
     outputs = model.generate(
         inputs['input_ids'],
-        max_length=max_length,
+        max_length=max_tokens,
         num_return_sequences=1,
         pad_token_id=tokenizer.eos_token_id,
-        num_beams=num_beams,
         temperature=temperature,
         top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        early_stopping=True
+        early_stopping=True,
+        do_sample=True  # Enable sampling
     )
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Post-processing to clean up the response
+    # Clean up the response
     response = response.split("Assistant:")[-1].strip()
     response_lines = response.split('\n')
     clean_response = []
@@ -39,13 +51,12 @@ def generate_response(prompt, model, tokenizer, max_length=100, num_beams=5, tem
         clean_response.append(line)
     response = ' '.join(clean_response)
 
-    return
-
-    return [(prompt, response.strip())]
+    return [(message, response)]
 
 # Create the Gradio chat interface
 demo = gr.ChatInterface(
     fn=generate_response,
+
     title="Chatbot",
     description="Ask anything to the chatbot."
 )
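
Note on wiring: the new generate_response(message, history, system_message, max_tokens, temperature, top_p) signature matches what gr.ChatInterface calls with (message, history) followed by any additional_inputs, but the interface as committed declares no additional_inputs, and gr.ChatInterface expects the function to return the reply string rather than a [(message, response)] list. A minimal sketch of one way to close that gap is below; the component choices and default values are illustrative assumptions, not part of the commit.

import gradio as gr

# Sketch only: expose the extra generate_response parameters as UI controls,
# which Gradio passes to fn in order after (message, history).
demo = gr.ChatInterface(
    fn=generate_response,  # would also need to `return response` (a plain string)
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
    ],
    title="Chatbot",
    description="Ask anything to the chatbot.",
)

Two further caveats on the generate call as committed: max_length bounds prompt plus completion together, so max_new_tokens=max_tokens is the usual way to bound only the generated text; and early_stopping applies only to beam search, so it has no effect once do_sample=True stands in for the removed num_beams.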