# fastapi / main.py
# Author: ka1kuk — commit aeca129 ("Rename app.py to main.py"), 923 bytes.
from flask import request
from diffusers import StableDiffusionPipeline
import torch
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
# The app must be a FastAPI instance: `Flask` is never imported in this file,
# and `add_middleware` / the route decorator below are FastAPI (Starlette) APIs.
app = FastAPI()
app.add_middleware(  # add the middleware
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with credentials lets any site
    # make credentialed requests; this endpoint reads no cookies, so
    # allow_credentials could likely be False — confirm before changing.
    allow_credentials=True,  # allow credentials
    allow_origins=["*"],  # allow all origins
    allow_methods=["*"],  # allow all methods
    allow_headers=["*"],  # allow all headers
)

model_id = "runwayml/stable-diffusion-v1-5"
# Device the pipeline runs on. Half precision is only used on CUDA: many CPU
# kernels have no float16 implementation, so float16 weights run on the CPU
# either fail at inference time or are extremely slow.
device = "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)


def dummy(images, **kwargs):
    """Stand-in safety checker that passes every image through unfiltered.

    Returns the images unchanged together with one ``False`` ("not flagged")
    entry per image. The pipeline consumes the second value as a per-image
    sequence of flags, so a single bare ``False`` is not iterable and breaks
    post-processing.
    """
    return images, [False] * len(images)


# Replacing the checker deliberately disables the pipeline's NSFW filtering.
pipe.safety_checker = dummy


@app.get("/")
def generate_image(prompt: str):
    """Generate one image for ``prompt`` and return its raw pixel data as JSON.

    ``prompt`` is a required query parameter (``GET /?prompt=...``); FastAPI
    answers 422 when it is missing instead of passing ``None`` to the pipeline.

    Returns a dict with:
      - ``image_data``: hex string of the image's raw pixel bytes
        (``Image.tobytes()``), unchanged from the original response.
      - ``width``, ``height``, ``mode``: what a client needs to rebuild the
        image from those raw bytes (e.g. ``Image.frombytes(mode, (w, h), data)``).
    """
    image = pipe(prompt).images[0]
    return {
        "image_data": image.tobytes().hex(),
        "width": image.width,
        "height": image.height,
        "mode": image.mode,
    }