redael committed on
Commit 11cda4d · verified · 1 Parent(s): da6d98b

Update app.py

Files changed (1)
  1. app.py +11 -21
app.py CHANGED
@@ -13,36 +13,24 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 # Function to generate response
-def generate_response(message, history, system_message, max_tokens, temperature, top_p):
-    # Prepare the conversation history
-    messages = [{"role": "system", "content": system_message}]
-
-    for user_msg, bot_msg in history:
-        if user_msg:
-            messages.append({"role": "user", "content": user_msg})
-        if bot_msg:
-            messages.append({"role": "assistant", "content": bot_msg})
-
-    messages.append({"role": "user", "content": message})
-
-    # Tokenize and prepare the input
-    prompt = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages])
+def generate_response(prompt, model, tokenizer, max_length=100, num_beams=5, temperature=0.5, top_p=0.9, repetition_penalty=4.0):
+    # Prepare the prompt
+    prompt = f"User: {prompt}\nAssistant:"
     inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
-
-    # Generate the response
     outputs = model.generate(
         inputs['input_ids'],
-        max_length=max_tokens,
+        max_length=max_length,
         num_return_sequences=1,
         pad_token_id=tokenizer.eos_token_id,
+        num_beams=num_beams,
         temperature=temperature,
         top_p=top_p,
-        early_stopping=True,
-        do_sample=True  # Enable sampling
+        repetition_penalty=repetition_penalty,
+        early_stopping=True
     )
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Clean up the response
+    # Post-processing to clean up the response
     response = response.split("Assistant:")[-1].strip()
     response_lines = response.split('\n')
     clean_response = []
@@ -51,7 +39,9 @@ def generate_response(message, history, system_message, max_tokens, temperature,
         clean_response.append(line)
     response = ' '.join(clean_response)
 
-    return [(message, response)]
+    return
+
+    return [(prompt, response.strip())]
 
 # Create the Gradio chat interface
 demo = gr.ChatInterface(
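
Two things stand out in the committed version, illustrated by the minimal sketch below (not part of the commit). First, the bare return placed before return [(prompt, response.strip())] exits the function early, so the new generate_response returns None and the final tuple is unreachable. Second, with num_beams=5 and no do_sample=True, transformers treats temperature and top_p as sampling-only settings, so those knobs in the new signature may have no effect as committed. The sketch assumes model, tokenizer, and device are set up earlier in app.py exactly as the diff shows; chat_fn is a hypothetical wrapper name introduced here for illustration, not something in the commit.

# Minimal sketch (assumption, not the committed code): adapting the reworked
# generate_response signature to the fn(message, history) shape that
# gr.ChatInterface calls.
def chat_fn(message, history):
    result = generate_response(message, model, tokenizer)
    # Without the stray early `return`, generate_response yields
    # [(prompt, response)]; unwrap the response text for ChatInterface.
    # With the early `return` as committed, result is None, so fall back
    # to an empty string instead of crashing.
    if result:
        return result[0][1]
    return ""

demo = gr.ChatInterface(fn=chat_fn)

A wrapper like this (or moving the final return back above the bare one) would be needed for the chat box to display anything, since gr.ChatInterface expects the handler to hand back the reply text rather than None.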