# opsgenius3 / app.py
import os
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import re

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your Space

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"api/predict/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global variables for the model and tokenizer
model = None
tokenizer = None
def get_model_and_tokenizer(model_id):
    global model, tokenizer
    try:
        print(f"Loading tokenizer for model_id: {model_id}")
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        print(f"Loading model for model_id: {model_id}")
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(model_id)  # optionally: device_map="auto"
        model.config.use_cache = False
    except Exception as e:
        print(f"Error loading model: {e}")
def extract_relevant_text(response):
    """
    Extract the first complete 'user' and 'assistant' blocks found between
    <|im_start|> and <|im_end|> tags in the generated response.
    If the tags are corrupted, return the text up to the first <|im_end|> tag.
    """
    # Regex to match content between <|im_start|> and <|im_end|> tags
    pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
    matches = pattern.findall(response)

    # Debugging: print the matches found
    print("Matches found:", matches)

    # If complete matches were found, extract the first two blocks
    if len(matches) >= 2:
        user_message = matches[0].strip()       # First <|im_start|> block
        assistant_message = matches[1].strip()  # Second <|im_start|> block
        return f"user: {user_message}\nassistant: {assistant_message}"

    # Otherwise, fall back to a partial extraction
    if '<|im_end|>' in response:
        # Return everything before the first <|im_end|>
        return response.split('<|im_end|>')[0].strip()

    return "No complete blocks found. Please check the format of the response."
def generate_response(user_input, model_id):
    global model, tokenizer

    prompt = formatted_prompt(user_input)

    # Load the model and tokenizer if they are not already loaded or if the model_id has changed
    if model is None or tokenizer is None or (model.config._name_or_path != model_id):
        get_model_and_tokenizer(model_id)  # Load model and tokenizer

    # Prepare the input tensors (kept on CPU here)
    inputs = tokenizer(prompt, return_tensors="pt")
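    # A possible extension (not enabled in this app): run on GPU when one is available.
    # This is only a sketch; it assumes torch is importable (it already is, as a
    # transformers dependency).
    #
    #   import torch
    #   device = "cuda" if torch.cuda.is_available() else "cpu"
    #   model.to(device)
    #   inputs = {k: v.to(device) for k, v in inputs.items()}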
    generation_config = GenerationConfig(
        # Earlier (commented-out) settings tried: max_new_tokens=100, do_sample=False,
        # top_k=5/20, temperature=0.3/0.5, repetition_penalty=1.5.
        penalty_alpha=0.6,                    # Balance exploration and exploitation
        do_sample=True,                       # Sample to allow some variability in responses
        top_k=3,                              # Narrow the candidate pool at each step
        temperature=0.7,                      # Moderate randomness
        repetition_penalty=1.2,               # Discourage repeated phrases
        max_new_tokens=60,                    # Limit the length of the generated output
        pad_token_id=tokenizer.eos_token_id,
        truncation=True,                      # Truncate overly long prompts
    )
    try:
        # Generate the response
        outputs = model.generate(**inputs, generation_config=generation_config)

        # Decode only the newly generated tokens by slicing off the prompt
        response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
        return extract_relevant_text(response)
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Error generating response."
def formatted_prompt(question) -> str:
    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
@app.route("/", methods=["GET"])
def handle_get_request():
message = request.args.get("message", "No message provided.")
return jsonify({"message": message, "status": "GET request successful!"})
@app.route("/send_message", methods=["POST"])
def handle_post_request():
data = request.get_json()
if data is None:
return jsonify({"error": "No JSON data provided"}), 400
message = data.get("inputs", "No message provided.")
model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin") # Default model if not provided
try:
# Generate a response from the model
model_response = generate_response(message, model_id)
return jsonify({
"received_message": model_response,
"model_id": model_id,
"status": "POST request successful!"
})
except Exception as e:
print(f"Error handling POST request: {e}")
return jsonify({"error": "An error occurred while processing your request."}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
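# Example client call (a sketch, not executed by this app). It assumes the server is
# reachable at http://localhost:7860 and that the `requests` package is installed:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/send_message",
#       json={"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"},
#   )
#   print(resp.json()["received_message"])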