YALCINKAYA committed
Commit f4c3c98 · Parent: cffec04

stop_sequences User: and Assistant:

Files changed (1)
  1. app.py +31 -61
app.py CHANGED
@@ -31,47 +31,10 @@ def get_model_and_tokenizer(model_id):
 
     except Exception as e:
         print(f"Error loading model: {e}")
-
-def extract_relevant_text(response):
-    """
-    This function extracts the first complete 'user' and 'assistant' blocks
-    between <|im_start|> and <|im_end|> in the generated response.
-    If the tags are corrupted, it returns the text up to the first <|im_end|> tag.
-    """
-    # Regex to match content between <|im_start|> and <|im_end|> tags
-    pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
-    matches = pattern.findall(response)
-
-    # Debugging: print the matches found
-    print("Matches found:", matches)
-
-    # If complete matches found, extract them
-    if len(matches) >= 2:
-        user_message = matches[0].strip()       # First <|im_start|> block
-        assistant_message = matches[1].strip()  # Second <|im_start|> block
-        return f"user: {user_message}\nassistant: {assistant_message}"
-
-    # If no complete blocks found, check for a partial extraction
-    if '<|im_end|>' in response:
-        # Extract everything before the first <|im_end|>
-        partial_response = response.split('<|im_end|>')[0].strip()
-        return f"{partial_response}"
-
+
     return "No complete blocks found. Please check the format of the response."
+
 
-def generate_response(user_input, model_id):
-    prompt = formatted_prompt(user_input)
-
-    global model, tokenizer
-
-    # Load the model and tokenizer if they are not already loaded or if the model_id has changed
-    if model is None or tokenizer is None or (model.config._name_or_path != model_id):
-        get_model_and_tokenizer(model_id)  # Load model and tokenizer
-
-    # Prepare the input tensors
-    inputs = tokenizer(prompt, return_tensors="pt")  # Move inputs to GPU if available
-
-    generation_config = GenerationConfig(
         # max_new_tokens=100,
         # min_length=5,
         # do_sample=False,
@@ -97,31 +60,38 @@ def generate_response(user_input, model_id):
         #pad_token_id=tokenizer.eos_token_id,
         #truncation=True, # Enable truncation for input sequences
 
-        penalty_alpha=0.6,       # Maintain this for balance
-        do_sample=True,          # Allow sampling for variability
-        top_k=3,                 # Reduce top_k to narrow down options
-        temperature=0.7,         # Keep this low for more deterministic responses
-        repetition_penalty=1.2,  # Keep this moderate to avoid repetitive responses
-        max_new_tokens=60,       # Maintain this limit
+        #penalty_alpha=0.6,       # Maintain this for balance
+        #do_sample=True,          # Allow sampling for variability
+        #top_k=3,                 # Reduce top_k to narrow down options
+        #temperature=0.7,         # Keep this low for more deterministic responses
+        #repetition_penalty=1.2,  # Keep this moderate to avoid repetitive responses
+        #max_new_tokens=60,       # Maintain this limit
+        #pad_token_id=tokenizer.eos_token_id,
+        #truncation=True,         # Enable truncation for longer prompts
+        #
+
+def generate_response(user_input):
+    prompt = formatted_prompt(user_input)
+    inputs = tokenizer([prompt], return_tensors="pt")
+
+    generation_config = GenerationConfig(
+        penalty_alpha=0.6,
+        do_sample=True,
+        top_k=5,
+        temperature=0.6,
+        repetition_penalty=1.2,
+        max_new_tokens=30,  # Adjust as necessary
         pad_token_id=tokenizer.eos_token_id,
-        truncation=True,  # Enable truncation for longer prompts
-    )
+        stop_sequences=["User:", "Assistant:"],
+    )
 
-    try:
-        # Generate response
-        #outputs = model.generate(**inputs, generation_config=generation_config)
-        outputs = model.generate(**inputs, generation_config=generation_config)
-
-        #response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        #use the slicing method
-        response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
-        return extract_relevant_text(response)
-    except Exception as e:
-        print(f"Error generating response: {e}")
-        return "Error generating response."
-
+    outputs = model.generate(**inputs, generation_config=generation_config)
+    response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
+    return response.strip().split("Assistant:")[-1].strip()  # Get the part after 'Assistant:'
+
 def formatted_prompt(question) -> str:
-    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
+    return f"<|startoftext|>User: {question}\nAssistant:"
+
 
 @app.route("/", methods=["GET"])
 def handle_get_request():
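
Note on the new stop sequences: stop_sequences is not a parameter that model.generate() is guaranteed to honor, so depending on the installed transformers version the "User:" / "Assistant:" markers may be ignored and generation will only stop at max_new_tokens or the EOS token. The sketch below shows one common alternative, a custom StoppingCriteria; it assumes the model and tokenizer globals this app loads, and StopOnStrings / generate_with_stops are illustrative names, not part of the commit.

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnStrings(StoppingCriteria):
    """Stop generation as soon as any literal stop string appears in the new text."""

    def __init__(self, stop_strings, tokenizer, prompt_len):
        self.stop_strings = stop_strings  # e.g. ["User:", "Assistant:"]
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len      # number of prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the freshly generated tokens and check for the markers.
        new_text = self.tokenizer.decode(input_ids[0, self.prompt_len:], skip_special_tokens=True)
        return any(s in new_text for s in self.stop_strings)

def generate_with_stops(prompt):
    inputs = tokenizer([prompt], return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    criteria = StoppingCriteriaList([StopOnStrings(["User:", "Assistant:"], tokenizer, prompt_len)])
    outputs = model.generate(
        **inputs,
        max_new_tokens=30,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=criteria,
    )
    return tokenizer.decode(outputs[:, prompt_len:][0], skip_special_tokens=True)

Even when a criterion fires, the decoded tail usually still contains the marker that triggered it, so the caller would trim it afterwards, much as the commit does with split("Assistant:")[-1].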