Spaces:
Sleeping
Sleeping
Commit
·
ed324ed
1
Parent(s):
b1d9e55
bug fix in extract_relevant_text
Browse files
app.py
CHANGED
@@ -34,22 +34,30 @@ def get_model_and_tokenizer(model_id):
|
|
34 |
|
35 |
def extract_relevant_text(response):
|
36 |
"""
|
37 |
-
This function extracts the first 'user' and 'assistant' blocks
|
38 |
-
<|im_start|> and <|im_end|> in the generated response.
|
|
|
39 |
"""
|
40 |
# Regex to match content between <|im_start|> and <|im_end|> tags
|
41 |
pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
|
42 |
matches = pattern.findall(response)
|
43 |
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
#
|
48 |
-
|
49 |
-
|
|
|
|
|
50 |
|
51 |
-
#
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
def generate_response(user_input, model_id):
|
55 |
prompt = formatted_prompt(user_input)
|
|
|
34 |
|
35 |
def extract_relevant_text(response):
|
36 |
"""
|
37 |
+
This function extracts the first complete 'user' and 'assistant' blocks
|
38 |
+
between <|im_start|> and <|im_end|> in the generated response.
|
39 |
+
If the tags are corrupted, it returns the text up to the first <|im_end|> tag.
|
40 |
"""
|
41 |
# Regex to match content between <|im_start|> and <|im_end|> tags
|
42 |
pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
|
43 |
matches = pattern.findall(response)
|
44 |
|
45 |
+
# Debugging: print the matches found
|
46 |
+
print("Matches found:", matches)
|
47 |
|
48 |
+
# If complete matches found, extract them
|
49 |
+
if len(matches) >= 2:
|
50 |
+
user_message = matches[0].strip() # First <|im_start|> block
|
51 |
+
assistant_message = matches[1].strip() # Second <|im_start|> block
|
52 |
+
return f"user: {user_message}\nassistant: {assistant_message}"
|
53 |
|
54 |
+
# If no complete blocks found, check for a partial extraction
|
55 |
+
if '<|im_end|>' in response:
|
56 |
+
# Extract everything before the first <|im_end|>
|
57 |
+
partial_response = response.split('<|im_end|>')[0].strip()
|
58 |
+
return f"Partial Response: {partial_response}"
|
59 |
+
|
60 |
+
return "No complete blocks found. Please check the format of the response."
|
61 |
|
62 |
def generate_response(user_input, model_id):
|
63 |
prompt = formatted_prompt(user_input)
|