Commit · 8c39757
1 Parent(s): 188010c
extract_relevant_text
app.py CHANGED
@@ -2,6 +2,7 @@ import os
 from flask import Flask, jsonify, request
 from flask_cors import CORS
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+import re
 
 # Set the HF_HOME environment variable to a writable directory
 os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your space
@@ -20,7 +21,7 @@ def get_model_and_tokenizer(model_id):
     try:
         print(f"Loading tokenizer for model_id: {model_id}")
         # Load the tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(model_id
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
 
         print(f"Loading model and for model_id: {model_id}")
@@ -31,6 +32,25 @@ def get_model_and_tokenizer(model_id):
     except Exception as e:
         print(f"Error loading model: {e}")
 
+def extract_relevant_text(response):
+    """
+    This function extracts the first 'user' and 'assistant' blocks between
+    <|im_start|> and <|im_end|> in the generated response.
+    """
+    # Regex to match content between <|im_start|> and <|im_end|> tags
+    pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
+    matches = pattern.findall(response)
+
+    if len(matches) < 2:
+        return "Unable to extract sufficient data from the response."
+
+    # Assuming the first match is user and the second match is assistant
+    user_message = matches[0].strip()  # First <|im_start|> block
+    assistant_message = matches[1].strip()  # Second <|im_start|> block
+
+    # Format the extracted result
+    return f"user: {user_message}\nassistant: {assistant_message}"
+
 def generate_response(user_input, model_id):
     prompt = formatted_prompt(user_input)
 
@@ -56,7 +76,7 @@ def generate_response(user_input, model_id):
         # Generate response
         outputs = model.generate(**inputs, generation_config=generation_config)
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response
+        return extract_relevant_text(response)
     except Exception as e:
         print(f"Error generating response: {e}")
         return "Error generating response."
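
For illustration only: a minimal, self-contained sketch of how the extract_relevant_text helper added in this commit behaves on a raw generation that wraps segments in <|im_start|> / <|im_end|> markers. The standalone copy of the function mirrors the diff above; the sample string is a hypothetical stand-in for the decoded model output, not part of the commit.

    # Illustrative sketch (not part of the commit)
    import re

    def extract_relevant_text(response):
        # Capture every segment between <|im_start|> and <|im_end|>
        pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
        matches = pattern.findall(response)
        if len(matches) < 2:
            return "Unable to extract sufficient data from the response."
        # The first block is treated as the user turn, the second as the assistant turn
        user_message = matches[0].strip()
        assistant_message = matches[1].strip()
        return f"user: {user_message}\nassistant: {assistant_message}"

    # Hypothetical decoded output from model.generate / tokenizer.decode
    sample = (
        "<|im_start|> Hello, who are you? <|im_end|>"
        "<|im_start|> I am a helpful assistant. <|im_end|>"
    )
    print(extract_relevant_text(sample))
    # user: Hello, who are you?
    # assistant: I am a helpful assistant.

If the generation contains fewer than two delimited blocks, the helper returns the fallback string instead of a formatted transcript, which is why the Flask endpoint can always return a plain string to the client.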