redael committed · verified · Commit f95718f · 1 Parent(s): b12166e

Update app.py

Files changed (1): app.py (+6 -7)
app.py CHANGED
@@ -5,11 +5,8 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load your model and tokenizer from Hugging Face
-print("l.......")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
-print("done")
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
@@ -39,20 +36,22 @@ def generate_response(message, history, system_message, max_tokens, temperature,
         pad_token_id=tokenizer.eos_token_id,
         temperature=temperature,
         top_p=top_p,
-        early_stopping=True
+        early_stopping=True,
+        do_sample=True  # Enable sampling
     )
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Clean up the response
+    # Post-process the response
     response = response.split("Assistant:")[-1].strip()
     response_lines = response.split('\n')
     clean_response = []
     for line in response_lines:
         if "User:" not in line and "Assistant:" not in line:
             clean_response.append(line)
-    response = ' '.join(clean_response)
+    response = ' '.join(clean_response).strip()
 
-    return [(message, response)]
+    history.append((message, response))
+    return history, history
 
 # Create the Gradio chat interface
 demo = gr.ChatInterface(
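
For context, a minimal sketch of how app.py reads after this commit. The prompt construction and the model.generate call sit above the changed hunk and are not visible in the diff, so those parts are assumptions, as is the model_name placeholder. The key point of the change is do_sample=True: transformers defaults to greedy decoding, under which temperature and top_p are ignored, so sampling must be enabled for those sliders to have any effect.

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-model-id"  # hypothetical placeholder; the real value is defined elsewhere in app.py

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_response(message, history, system_message, max_tokens, temperature, top_p):
    # Prompt construction is not shown in the diff; this is one plausible
    # form, matching the "User:"/"Assistant:" markers the post-processing strips.
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,
        temperature=temperature,
        top_p=top_p,
        early_stopping=True,
        do_sample=True,  # temperature/top_p only take effect when sampling
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Post-process: keep only the assistant's final turn.
    response = response.split("Assistant:")[-1].strip()
    clean_response = [
        line for line in response.split("\n")
        if "User:" not in line and "Assistant:" not in line
    ]
    response = " ".join(clean_response).strip()

    # Returns the updated history twice, as in the commit; this is the
    # gr.Chatbot + gr.State pattern, whereas gr.ChatInterface itself
    # normally expects the function to return just the reply string.
    history.append((message, response))
    return history, history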