# tadmztxi/myapp.py — Flask API serving a Hugging Face transformer model.
# (Hugging Face Hub page residue commented out so the file parses:
#  "Geek7's picture / Update myapp.py / 06d2ce8 verified")
from flask import Flask, request, jsonify
import torch
from transformers import AutoModel, AutoTokenizer
from fastsafetensors import safe_load
# Initialize the Flask app
myapp = Flask(__name__)

# Model repository on the Hugging Face Hub.
# NOTE(review): the original code passed an HTTPS URL
# (https://huggingface.co/.../safety_checker/model.safetensors) to safe_load(),
# which expects a local file path and fails at startup. It then injected those
# *safety_checker* weights as the state_dict of the main model, which do not
# match the AutoModel architecture. from_pretrained() already downloads and
# loads the correct weights for the repo, so the manual safetensors step is
# removed entirely.
model_name = "prompthero/openjourney-v4"  # Replace with your model name

# Load the tokenizer and model once at startup; inference is pinned to CPU.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to("cpu")
model.eval()  # disable dropout etc. for deterministic inference
@myapp.route('/')
def index():
    """Landing/health-check endpoint for the API root."""
    greeting = "Welcome to the AI Model API!"
    return greeting
@myapp.route('/generate', methods=['POST'])
def generate_output():
    """Run the model on a JSON-supplied prompt and return its hidden states.

    Expects a JSON body like ``{"prompt": "..."}``; falls back to a default
    prompt when the body is missing, malformed, or lacks a ``prompt`` key.
    """
    # silent=True returns None instead of raising a 400 error when the body
    # is absent or not valid JSON, so the default prompt still applies.
    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt', 'Hello, world!')

    # Tokenize input prompt
    inputs = tokenizer(prompt, return_tensors="pt")

    # Forward pass only — AutoModel is an encoder, not a text generator,
    # so the response is the encoder output, not generated text.
    with torch.no_grad():
        outputs = model(**inputs)

    # jsonify() cannot serialize torch tensors (the original jsonify(outputs)
    # raised a TypeError); convert the tensor to nested Python lists first.
    return jsonify({"last_hidden_state": outputs.last_hidden_state.tolist()})
# Start the development server only when run as a script, not on import.
if __name__ == "__main__":
    myapp.run(host="0.0.0.0", port=5000)