File size: 3,413 Bytes
1211c23
87191d2
 
1211c23
87191d2
 
 
 
 
 
 
1211c23
87191d2
 
e2e8031
 
 
1211c23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87191d2
 
 
 
 
 
 
 
 
 
 
 
bb13100
87191d2
 
 
 
 
1211c23
 
 
87191d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5047d85
87191d2
 
 
 
 
 
 
 
 
1211c23
87191d2
4cb2ba5
 
 
1211c23
4cb2ba5
 
 
87191d2
 
 
 
5bdcd79
87191d2
 
 
 
5047d85
1211c23
 
157450c
87191d2
1211c23
 
 
 
 
 
 
 
 
 
87191d2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from flask import Flask, request, jsonify, send_file
from PIL import Image
import base64
import io
import random
import uuid
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# Create the Flask application; every @app.route handler below registers on it.
app = Flask(__name__)

def clear_gpu_memory():
    """Release memory cached by PyTorch's CUDA allocator.

    Outstanding CUDA work is synchronized first so pending kernels have
    finished before the cache is emptied. On hosts without CUDA this is a
    no-op instead of raising (``torch.cuda.synchronize`` fails when CUDA is
    unavailable).
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Module-level diffusion pipeline; stays None until load_model() builds it.
pipe = None

def load_model():
    """Build the SDXL pipeline once and store it in the global ``pipe``.

    The pipeline is loaded from "fluently/Fluently-XL-v2" in float16. Its
    scheduler is swapped for EulerAncestralDiscreteScheduler (reusing the
    original scheduler config), and the "dalle" LoRA adapter is loaded and
    activated. The pipeline moves to "cuda" when CUDA is available.
    Calling this again after ``pipe`` is set does nothing.
    """
    global pipe
    if pipe is not None:
        return

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "fluently/Fluently-XL-v2",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    base_scheduler_config = pipe.scheduler.config
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(base_scheduler_config)

    pipe.load_lora_weights(
        "ehristoforu/dalle-3-xl-v2",
        weight_name="dalle-3-xl-lora-v2.safetensors",
        adapter_name="dalle",
    )
    pipe.set_adapters("dalle")

    if torch.cuda.is_available():
        pipe.to("cuda")

# Load the model at import time: generate() uses the module-level `pipe`
# directly and never calls load_model() itself, so it must be ready before
# the first request arrives.
load_model()

def save_image(img):
    """Save *img* under a fresh random ``<uuid4>.png`` name.

    The file is written relative to the current working directory. Returns
    the file name that was used.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename

# Upper bound (inclusive) for randomly drawn seeds: the largest int32.
MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, MAX_SEED]`` if requested, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed

@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    num_images_per_prompt: int = 1,
    width: int = 512,  # Reduced image width
    height: int = 512,  # Reduced image height
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    num_inference_steps: int = 25,
):
    """Run the shared diffusion pipeline and save the resulting images.

    Args:
        prompt: Text prompt for the image.
        negative_prompt: Negative prompt. It is only used when
            *use_negative_prompt* is true; otherwise an empty string is sent.
        use_negative_prompt: Whether to honour *negative_prompt*.
        seed: Seed for the sampler's noise. It is replaced by a random value
            when *randomize_seed* is true.
        num_images_per_prompt: Number of images to generate.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength.
        randomize_seed: Draw a fresh random seed instead of using *seed*.
        num_inference_steps: Number of denoising steps (default 25).

    Returns:
        A ``(image_paths, seed)`` tuple. ``image_paths`` holds the file names
        written by ``save_image``; ``seed`` is the seed actually used, which
        reproduces the images when passed back with ``randomize_seed=False``.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    # The returned seed is only meaningful if it drives the sampler's noise,
    # so feed it to the pipeline through an explicitly seeded generator.
    # A CPU generator keeps results reproducible regardless of the device.
    generator = torch.Generator().manual_seed(seed)

    if not use_negative_prompt:
        negative_prompt = ""

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        # diffusers applies this "scale" to the loaded LoRA layers.
        cross_attention_kwargs={"scale": 0.65},
        output_type="pil",
    ).images
    image_paths = [save_image(img) for img in images]
    print(image_paths)
    return image_paths, seed

@app.route("/", methods=["GET"])
def root():
    """Return a static plain-text welcome message for the service root."""
    greeting = "Welcome to the Fashion Outfit"
    return greeting

def _is_generated_image_name(name):
    """Return True if *name* has exactly the ``<uuid>.png`` form save_image writes."""
    suffix = ".png"
    if not name.endswith(suffix):
        return False
    stem = name[: -len(suffix)]
    try:
        # Require the canonical lowercase, hyphenated spelling that
        # str(uuid.uuid4()) produces, so nothing else can slip through.
        return str(uuid.UUID(stem)) == stem
    except ValueError:
        return False


@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Serve a PNG previously written by save_image.

    Only names of the form ``<uuid>.png`` are accepted. Any other value is
    reported as not found, so the endpoint cannot be used to download
    arbitrary files (e.g. the application source) from the server.

    Returns:
        The image file, or ``({"error": "Image not found"}, 404)``.
    """
    if not _is_generated_image_name(image_id):
        return jsonify({'error': 'Image not found'}), 404
    try:
        return send_file(image_id, mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404

@app.route('/api/run', methods=['POST'])
def run():
    """Generate images from a JSON request and return their file names.

    The body must be a JSON object containing ``prompt``. These keys are
    optional and default to generate()'s own defaults:
    ``negative_prompt`` (""), ``use_negative_prompt`` (False), ``seed`` (0),
    ``num_images_per_prompt`` (1), ``width`` (512), ``height`` (512),
    ``guidance_scale`` (3), ``randomize_seed`` (False).

    Returns:
        ``{"out": [image_paths, seed]}`` on success. On failure, a JSON error
        with HTTP 400 when the body is not a JSON object or lacks ``prompt``.
    """
    # silent=True yields None (instead of raising) for missing/invalid JSON,
    # so a bad request gets a clear 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    print(data)
    if 'prompt' not in data:
        return jsonify({'error': "Missing required field 'prompt'"}), 400

    # NOTE: clear_gpu_memory() can be called here if VRAM pressure between
    # requests becomes a problem; it is intentionally not called today.
    result = generate(
        data['prompt'],
        data.get('negative_prompt', ""),
        data.get('use_negative_prompt', False),
        data.get('seed', 0),
        data.get('num_images_per_prompt', 1),
        data.get('width', 512),
        data.get('height', 512),
        data.get('guidance_scale', 3),
        data.get('randomize_seed', False),
    )
    return jsonify({'out': result})

if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces at port 7860 when the file
    # is executed directly (port presumably matches the hosting platform's
    # expected app port — confirm).
    app.run(host="0.0.0.0", port=7860)