import os
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import re
# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache" # Change this to a writable path in your space
app = Flask(__name__)
# Enable CORS for specific origins
CORS(app, resources={r"/api/predict/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
# Global variables for model and tokenizer
model = None
tokenizer = None
def get_model_and_tokenizer(model_id):
    global model, tokenizer
    try:
        print(f"Loading tokenizer for model_id: {model_id}")
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        print(f"Loading model for model_id: {model_id}")
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(model_id)  #, device_map="auto")
        model.config.use_cache = False
    except Exception as e:
        print(f"Error loading model: {e}")

def extract_relevant_text(response):
    """
    Extract the first complete 'user' and 'assistant' blocks between
    <|im_start|> and <|im_end|> in the generated response.
    If the tags are corrupted, return the text up to the first <|im_end|> tag.
    """
    # Regex to match content between <|im_start|> and <|im_end|> tags
    pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
    matches = pattern.findall(response)

    # Debugging: print the matches found
    print("Matches found:", matches)

    # If complete matches are found, extract them
    if len(matches) >= 2:
        user_message = matches[0].strip()       # First <|im_start|> block
        assistant_message = matches[1].strip()  # Second <|im_start|> block
        return f"user: {user_message}\nassistant: {assistant_message}"

    # If no complete blocks are found, fall back to a partial extraction
    if '<|im_end|>' in response:
        # Extract everything before the first <|im_end|>
        partial_response = response.split('<|im_end|>')[0].strip()
        return f"{partial_response}"

    return "No complete blocks found. Please check the format of the response."
def generate_response(user_input, model_id):
    prompt = formatted_prompt(user_input)
    global model, tokenizer

    # Load the model and tokenizer if they are not already loaded or if the model_id has changed
    if model is None or tokenizer is None or (model.config._name_or_path != model_id):
        get_model_and_tokenizer(model_id)  # Load model and tokenizer

    # Prepare the input tensors
    inputs = tokenizer(prompt, return_tensors="pt")

    generation_config = GenerationConfig(
        penalty_alpha=0.6,                    # Balance exploration and exploitation
        do_sample=True,                       # Sample to allow variability in responses
        top_k=3,                              # Narrow the pool of candidate tokens
        temperature=0.7,                      # Moderate randomness
        repetition_penalty=1.2,               # Discourage repeated phrases
        max_new_tokens=60,                    # Limit the length of the generated output
        pad_token_id=tokenizer.eos_token_id,
        truncation=True,                      # Enable truncation for longer prompts
    )

    try:
        # Generate the response
        outputs = model.generate(**inputs, generation_config=generation_config)
        # Decode only the newly generated tokens, slicing off the prompt
        response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
        return extract_relevant_text(response)
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Error generating response."

def formatted_prompt(question) -> str:
    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
@app.route("/", methods=["GET"])
def handle_get_request():
message = request.args.get("message", "No message provided.")
return jsonify({"message": message, "status": "GET request successful!"})
@app.route("/send_message", methods=["POST"])
def handle_post_request():
data = request.get_json()
if data is None:
return jsonify({"error": "No JSON data provided"}), 400
message = data.get("inputs", "No message provided.")
model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin") # Default model if not provided
try:
# Generate a response from the model
model_response = generate_response(message, model_id)
return jsonify({
"received_message": model_response,
"model_id": model_id,
"status": "POST request successful!"
})
except Exception as e:
print(f"Error handling POST request: {e}")
return jsonify({"error": "An error occurred while processing your request."}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
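
# Example request (illustrative; assumes the server is running locally on port 7860):
#     curl -X POST http://localhost:7860/send_message \
#          -H "Content-Type: application/json" \
#          -d '{"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"}'
# The JSON response contains "received_message", "model_id", and "status".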