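"""Flask API for a Hugging Face Space.

Loads a 4-bit quantized causal language model once per process and exposes a
/send_message endpoint that returns a short, single-line reply from the model.
"""
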
import os

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

# Set the HF_HOME environment variable to a writable directory inside the Space.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for the front-end origins that are allowed to call this API.
# The pattern covers all routes exposed here ("/" and "/send_message").
CORS(app, resources={r"/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Device used for the input tensors; the model itself is placed via device_map="auto".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables so the model and tokenizer are loaded only once per process.
model = None
tokenizer = None

def get_model_and_tokenizer(model_id):
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            # The model has no dedicated pad token, so reuse the EOS token.
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id} on {device}")
            # Load the model with 4-bit NF4 quantization so it fits on a small GPU.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype="float16",
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )
            model.config.use_cache = False
            model.config.pretraining_tp = 1
        except Exception as e:
            print(f"Error loading model: {e}")
            raise  # Re-raise so the POST handler can return a 500 response.
    else:
        print(f"Model and tokenizer for {model_id} are already loaded.")

def generate_response(user_input, model_id):
    # Ensure the model and tokenizer are loaded before generating.
    get_model_and_tokenizer(model_id)

    prompt = user_input
    generation_config = GenerationConfig(
        penalty_alpha=0.6,
        do_sample=True,
        top_p=0.2,
        top_k=50,
        temperature=0.3,
        repetition_penalty=1.2,
        max_new_tokens=60,
        pad_token_id=tokenizer.eos_token_id,
        # stop_strings needs a recent transformers release and the tokenizer
        # passed to generate(); it cuts generation at these markers.
        stop_strings=["User:", "Assistant:", "\n"],
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, generation_config=generation_config, tokenizer=tokenizer)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip role markers and keep only the first line of the reply.
    cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()
    return cleaned_response.split("\n")[0]

@app.route("/", methods=["GET"])
def handle_get_request():
message = request.args.get("message", "No message provided.")
return jsonify({"message": message, "status": "GET request successful!"})
@app.route("/send_message", methods=["POST"])
def handle_post_request():
data = request.get_json()
if data is None:
return jsonify({"error": "No JSON data provided"}), 400
message = data.get("inputs", "No message provided.")
model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin") # Default model if not provided
try:
print(f"Loading")
# Generate a response from the model
model_response = generate_response(message, model_id)
return jsonify({
"received_message": model_response,
"model_id": model_id,
"status": "POST request successful!"
})
except Exception as e:
print(f"Error handling POST request: {e}")
return jsonify({"error": "An error occurred while processing your request."}), 500
if __name__ == '__main__':
    # Hugging Face Spaces routes external traffic to port 7860.
    app.run(host='0.0.0.0', port=7860)