opsgenius3 / app.py
YALCINKAYA's picture
stop_sequences User: and Assistant:
f4c3c98
raw
history blame
4.88 kB
import os
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import re
# Set the HF_HOME environment variable to a writable directory.
# NOTE(review): huggingface_hub resolves HF_HOME when it is first imported, and
# `transformers` (which imports it) is imported above this line, so this
# assignment most likely does not change the cache location. Move it above the
# transformers import (or pass cache_dir= to from_pretrained) — confirm.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your space

app = Flask(__name__)

# Browser origins allowed to call the API endpoints.
ALLOWED_ORIGINS = ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]

# Enable CORS for specific origins. Flask-CORS matches each resource pattern
# with re.match against request.path, which always begins with "/", so every
# pattern must include the leading slash. /send_message is the route the
# frontend actually POSTs to; /api/predict/* is kept for future endpoints.
CORS(app, resources={
    r"/send_message": {"origins": ALLOWED_ORIGINS},
    r"/api/predict/*": {"origins": ALLOWED_ORIGINS},
})

# Global model and tokenizer; both stay None until get_model_and_tokenizer()
# loads them.
model = None
tokenizer = None
def get_model_and_tokenizer(model_id):
    """Load the tokenizer and causal-LM model for *model_id* into the globals.

    On success the module globals ``tokenizer`` and ``model`` are replaced
    together and ``None`` is returned. On failure the error is printed, the
    previously loaded tokenizer/model (if any) are left untouched, and an
    error message string is returned instead of raising.
    """
    global model, tokenizer
    try:
        print(f"Loading tokenizer for model_id: {model_id}")
        new_tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Always use EOS as the pad token (overrides any existing pad token).
        new_tokenizer.pad_token = new_tokenizer.eos_token

        print(f"Loading model for model_id: {model_id}")
        new_model = AutoModelForCausalLM.from_pretrained(model_id)  # , device_map="auto")
        # NOTE(review): this disables the KV cache, which slows generate();
        # it looks like a fine-tuning leftover — confirm before re-enabling.
        new_model.config.use_cache = False
    except Exception as e:
        print(f"Error loading model: {e}")
        return f"Error loading model '{model_id}': {e}"

    # Swap both globals only after everything loaded, so a failed model load
    # can never leave a tokenizer paired with a different model.
    tokenizer, model = new_tokenizer, new_model
    return None
# max_new_tokens=100,
# min_length=5,
# do_sample=False,
# num_beams=1,
# pad_token_id=tokenizer.eos_token_id,
# truncation=True
#penalty_alpha=0.6,
#do_sample = True,
#top_k=5,
#temperature=0.5,
#repetition_penalty=1.2,
#max_new_tokens=60,
#pad_token_id=tokenizer.eos_token_id,
#truncation=True,
#penalty_alpha=0.6, # Keep this to balance exploration and exploitation
#do_sample=True, # Keep sampling to allow for variability in responses
#top_k=20, # Increase top_k to give more options for sampling
#temperature=0.3, # Lower temperature to make outputs more deterministic and focused
#repetition_penalty=1.5, # Increase repetition penalty to discourage repeated phrases
#max_new_tokens=60, # Keep this as is, depending on your expected output length
#pad_token_id=tokenizer.eos_token_id,
#truncation=True, # Enable truncation for input sequences
#penalty_alpha=0.6, # Maintain this for balance
#do_sample=True, # Allow sampling for variability
#top_k=3, # Reduce top_k to narrow down options
#temperature=0.7, # Moderate value; lower it for more deterministic responses
#repetition_penalty=1.2, # Keep this moderate to avoid repetitive responses
#max_new_tokens=60, # Maintain this limit
#pad_token_id=tokenizer.eos_token_id,
#truncation=True, # Enable truncation for longer prompts
#
# Markers of a new conversational turn; the generated reply is cut at the
# first one. (GenerationConfig has no `stop_sequences` option, so stopping is
# done by post-processing the decoded text.)
_STOP_SEQUENCES = ("User:", "Assistant:")

# model_id most recently loaded successfully by generate_response (None until then).
_loaded_model_id = None


def truncate_at_stop_sequences(text, stop_sequences=_STOP_SEQUENCES):
    """Return *text* cut before its first stop sequence, whitespace-stripped.

    A leading "Assistant:" (the model echoing the role tag the prompt ends
    with) is dropped first so it does not truncate the entire reply.
    """
    text = text.strip()
    if text.startswith("Assistant:"):
        text = text[len("Assistant:"):]
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut].strip()


def generate_response(user_input, model_id=None):
    """Generate the assistant's reply to *user_input*.

    Args:
        user_input: The user's message; it is wrapped by formatted_prompt().
        model_id: Optional Hugging Face model id. When given and different
            from the model this function last loaded, that model is loaded
            via get_model_and_tokenizer() before generating.

    Returns:
        The decoded continuation, truncated at the first "User:" /
        "Assistant:" marker and stripped of surrounding whitespace.

    Raises:
        RuntimeError: if loading *model_id* fails or no model is loaded.
    """
    global _loaded_model_id
    if model_id is not None and model_id != _loaded_model_id:
        error = get_model_and_tokenizer(model_id)
        if error:
            raise RuntimeError(error)
        _loaded_model_id = model_id
    if model is None or tokenizer is None:
        raise RuntimeError("Model is not loaded; pass a model_id to load one.")

    prompt = formatted_prompt(user_input)
    # Keep the input tensors on the same device as the model weights.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    generation_config = GenerationConfig(
        # NOTE(review): penalty_alpha only affects contrastive search
        # (do_sample=False); with sampling enabled it is ignored.
        penalty_alpha=0.6,
        do_sample=True,
        top_k=5,
        temperature=0.6,
        repetition_penalty=1.2,
        max_new_tokens=30,  # Adjust as necessary
        pad_token_id=tokenizer.eos_token_id,
    )
    outputs = model.generate(**inputs, generation_config=generation_config)
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    return truncate_at_stop_sequences(response)
def formatted_prompt(question) -> str:
    """Wrap *question* in the single-turn chat template the model expects.

    The result ends with "Assistant:" so the model continues with its reply.
    """
    template = "<|startoftext|>User: {}\nAssistant:"
    return template.format(question)
@app.route("/", methods=["GET"])
def handle_get_request():
    """Echo the ``message`` query parameter back as JSON with a status string."""
    echoed = request.args.get("message", "No message provided.")
    payload = {
        "message": echoed,
        "status": "GET request successful!",
    }
    return jsonify(payload)
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    """Generate a model reply for a JSON body ``{"inputs": ..., "model_id": ...}``.

    Returns:
        200 with ``received_message`` (the model reply), ``model_id`` and a
        status string; 400 when the body is missing, not valid JSON, or not a
        JSON object; 500 when generation fails.
    """
    # silent=True: a missing/wrong Content-Type or malformed JSON yields None
    # here (-> our JSON 400) instead of Flask aborting with its own HTML error.
    data = request.get_json(silent=True)
    if data is None:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    message = data.get("inputs", "No message provided.")
    model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")  # Default model if not provided
    try:
        # Generate a response from the model
        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception:
        # Log the full traceback server-side; clients get a generic message.
        app.logger.exception("Error handling POST request")
        return jsonify({"error": "An error occurred while processing your request."}), 500
if __name__ == "__main__":
    # Listen on all interfaces; 7860 is presumably the port the hosting
    # environment forwards to — confirm before changing.
    run_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**run_options)