Spaces:
Running
Running
Delete main.py
Browse files
main.py
DELETED
@@ -1,380 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
from fastapi import FastAPI
|
3 |
-
from fastapi.staticfiles import StaticFiles
|
4 |
-
from fastapi.responses import FileResponse
|
5 |
-
from fastapi.middleware.cors import CORSMiddleware
|
6 |
-
from huggingface_hub import InferenceClient, login
|
7 |
-
from transformers import AutoTokenizer
|
8 |
-
from pydantic import BaseModel
|
9 |
-
from gradio_client import Client, file
|
10 |
-
from starlette.responses import StreamingResponse
|
11 |
-
import re
|
12 |
-
from datetime import datetime
|
13 |
-
import json
|
14 |
-
import requests
|
15 |
-
import base64
|
16 |
-
import os
|
17 |
-
import time
|
18 |
-
from PIL import Image
|
19 |
-
from io import BytesIO
|
20 |
-
import aiohttp
|
21 |
-
import asyncio
|
22 |
-
from typing import Optional
|
23 |
-
from dotenv import load_dotenv
|
24 |
-
import boto3
|
25 |
-
from groq import Groq
|
26 |
-
|
27 |
-
app = FastAPI()
|
28 |
-
|
29 |
-
app.add_middleware(
|
30 |
-
CORSMiddleware,
|
31 |
-
allow_origins=["*"],
|
32 |
-
allow_credentials=True,
|
33 |
-
allow_methods=["*"],
|
34 |
-
allow_headers=["*"],
|
35 |
-
)
|
36 |
-
|
37 |
-
groqClient = Groq (api_key=os.environ.get("GROQ_API_KEY"))
|
38 |
-
|
39 |
-
load_dotenv()
|
40 |
-
token = os.environ.get("HF_TOKEN")
|
41 |
-
login(token)
|
42 |
-
|
43 |
-
prompt_model = "llama-3.1-8b-instant"
|
44 |
-
magic_prompt_model = "Gustavosta/MagicPrompt-Stable-Diffusion"
|
45 |
-
options = {"use_cache": False, "wait_for_model": True}
|
46 |
-
parameters = {"return_full_text":False, "max_new_tokens":300}
|
47 |
-
headers = {"Authorization": f"Bearer {token}", "x-use-cache":"0", 'Content-Type' :'application/json'}
|
48 |
-
API_URL = f'https://api-inference.huggingface.co/models/'
|
49 |
-
perm_negative_prompt = "watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry"
|
50 |
-
cwd = os.getcwd()
|
51 |
-
pictures_directory = os.path.join(cwd, 'pictures')
|
52 |
-
last_two_models = []
|
53 |
-
|
54 |
-
class Item(BaseModel):
|
55 |
-
prompt: str
|
56 |
-
steps: int
|
57 |
-
guidance: float
|
58 |
-
modelID: str
|
59 |
-
modelLabel: str
|
60 |
-
image: Optional[str] = None
|
61 |
-
target: str
|
62 |
-
control: float
|
63 |
-
|
64 |
-
class Core(BaseModel):
|
65 |
-
itemString: str
|
66 |
-
|
67 |
-
@app.get("/core")
|
68 |
-
async def core():
|
69 |
-
if not os.path.exists(pictures_directory):
|
70 |
-
os.makedirs(pictures_directory)
|
71 |
-
async def generator():
|
72 |
-
# Start JSON array
|
73 |
-
yield '['
|
74 |
-
first = True
|
75 |
-
for filename in os.listdir(pictures_directory):
|
76 |
-
if filename.endswith('.json'):
|
77 |
-
file_path = os.path.join(pictures_directory, filename)
|
78 |
-
with open(file_path, 'r') as file:
|
79 |
-
data = json.load(file)
|
80 |
-
|
81 |
-
# For JSON formatting, ensure only the first item doesn't have a preceding comma
|
82 |
-
if first:
|
83 |
-
first = False
|
84 |
-
else:
|
85 |
-
yield ','
|
86 |
-
yield json.dumps({"base64": data["base64image"], "prompt": data["returnedPrompt"]})
|
87 |
-
# End JSON array
|
88 |
-
yield ']'
|
89 |
-
|
90 |
-
return StreamingResponse(generator(), media_type="application/json")
|
91 |
-
|
92 |
-
|
93 |
-
def getPrompt(prompt, modelID, attempts=1):
|
94 |
-
response = {}
|
95 |
-
print(modelID)
|
96 |
-
try:
|
97 |
-
if modelID != magic_prompt_model:
|
98 |
-
chat = [
|
99 |
-
{"role": "user", "content": prompt_base},
|
100 |
-
{"role": "assistant", "content": prompt_assistant},
|
101 |
-
{"role": "user", "content": prompt},
|
102 |
-
]
|
103 |
-
response = groqClient.chat.completions.create(messages=chat, temperature=1, max_tokens=2048, top_p=1, stream=False, stop=None, model=modelID)
|
104 |
-
else:
|
105 |
-
apiData={"inputs":prompt, "parameters": parameters, "options": options, "timeout": 45}
|
106 |
-
response = requests.post(API_URL + modelID, headers=headers, data=json.dumps(apiData))
|
107 |
-
return response.json()
|
108 |
-
except Exception as e:
|
109 |
-
print(f"An error occurred: {e}")
|
110 |
-
if attempts < 3:
|
111 |
-
getPrompt(prompt, modelID, attempts + 1)
|
112 |
-
return response
|
113 |
-
|
114 |
-
@app.post("/inferencePrompt")
|
115 |
-
def inferencePrompt(item: Core):
|
116 |
-
print("Start API Inference Prompt")
|
117 |
-
try:
|
118 |
-
plain_response_data = getPrompt(item.itemString, prompt_model)
|
119 |
-
magic_response_data = getPrompt(item.itemString, magic_prompt_model)
|
120 |
-
returnJson = {"plain": plain_response_data.choices[0].message.content, "magic": item.itemString + magic_response_data[0]["generated_text"]}
|
121 |
-
print(f'Return Json {returnJson}')
|
122 |
-
return returnJson
|
123 |
-
except Exception as e:
|
124 |
-
returnJson = {"plain": f'An Error occured: {e}', "magic": f'An Error occured: {e}'}
|
125 |
-
|
126 |
-
async def wake_model(modelID):
|
127 |
-
data = {"inputs":"wake up call", "options":options}
|
128 |
-
headers = {"Authorization": f"Bearer {token}"}
|
129 |
-
api_data = json.dumps(data)
|
130 |
-
try:
|
131 |
-
timeout = aiohttp.ClientTimeout(total=60) # Set timeout to 60 seconds
|
132 |
-
async with aiohttp.ClientSession(timeout=timeout) as session:
|
133 |
-
async with session.post(API_URL + modelID, headers=headers, data=api_data) as response:
|
134 |
-
pass
|
135 |
-
print('Model Waking')
|
136 |
-
|
137 |
-
except Exception as e:
|
138 |
-
print(f"An error occurred: {e}")
|
139 |
-
|
140 |
-
def formatReturn(result):
|
141 |
-
img = Image.open(result)
|
142 |
-
img.save("test.png")
|
143 |
-
img_byte_arr = BytesIO()
|
144 |
-
img.save(img_byte_arr, format='PNG')
|
145 |
-
img_byte_arr = img_byte_arr.getvalue()
|
146 |
-
base64_img = base64.b64encode(img_byte_arr).decode('utf-8')
|
147 |
-
|
148 |
-
return base64_img
|
149 |
-
|
150 |
-
def save_image(base64image, item, model, NSFW):
|
151 |
-
if not NSFW:
|
152 |
-
data = {"base64image": "data:image/png;base64," + base64image, "returnedPrompt": "Model:\n" + model + "\n\nPrompt:\n" + item.prompt, "prompt": item.prompt, "steps": item.steps, "guidance": item.guidance, "control": item.control, "target": item.target}
|
153 |
-
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
154 |
-
file_path = os.path.join(pictures_directory, f'{timestamp}.json')
|
155 |
-
with open(file_path, 'w') as json_file:
|
156 |
-
json.dump(data, json_file)
|
157 |
-
|
158 |
-
def gradioSD3(item):
|
159 |
-
client = Client(item.modelID, hf_token=token)
|
160 |
-
result = client.predict(
|
161 |
-
prompt=item.prompt,
|
162 |
-
negative_prompt=perm_negative_prompt,
|
163 |
-
guidance_scale=item.guidance,
|
164 |
-
num_inference_steps=item.steps,
|
165 |
-
api_name="/infer"
|
166 |
-
)
|
167 |
-
return formatReturn(result[0])
|
168 |
-
|
169 |
-
def gradioAuraFlow(item):
|
170 |
-
client = Client("multimodalart/AuraFlow")
|
171 |
-
result = client.predict(
|
172 |
-
prompt=item.prompt,
|
173 |
-
negative_prompt=perm_negative_prompt,
|
174 |
-
randomize_seed=True,
|
175 |
-
guidance_scale=item.guidance,
|
176 |
-
num_inference_steps=item.steps,
|
177 |
-
api_name="/infer"
|
178 |
-
)
|
179 |
-
print(result[0])
|
180 |
-
return formatReturn(result[0]["value"])
|
181 |
-
|
182 |
-
def gradioHatmanInstantStyle(item):
|
183 |
-
client = Client("Hatman/InstantStyle")
|
184 |
-
image_stream = BytesIO(base64.b64decode(item.image.split("base64,")[1]))
|
185 |
-
image = Image.open(image_stream)
|
186 |
-
image.save("style.png")
|
187 |
-
result = client.predict(
|
188 |
-
image_pil=file("style.png"),
|
189 |
-
prompt=item.prompt,
|
190 |
-
n_prompt=perm_negative_prompt,
|
191 |
-
scale=1,
|
192 |
-
control_scale=item.control,
|
193 |
-
guidance_scale=item.guidance,
|
194 |
-
num_inference_steps=item.steps,
|
195 |
-
seed=1,
|
196 |
-
target=item.target,
|
197 |
-
api_name="/create_image"
|
198 |
-
)
|
199 |
-
return formatReturn(result)
|
200 |
-
|
201 |
-
def lambda_image(prompt, modelID):
|
202 |
-
data = {
|
203 |
-
"prompt": prompt,
|
204 |
-
"modelID": modelID
|
205 |
-
}
|
206 |
-
serialized_data = json.dumps(data)
|
207 |
-
aws_id = os.environ.get("AWS_ID")
|
208 |
-
aws_secret = os.environ.get("AWS_SECRET")
|
209 |
-
aws_region = os.environ.get("AWS_REGION")
|
210 |
-
try:
|
211 |
-
session = boto3.Session(aws_access_key_id=aws_id, aws_secret_access_key=aws_secret, region_name=aws_region)
|
212 |
-
lambda_client = session.client('lambda')
|
213 |
-
response = lambda_client.invoke(
|
214 |
-
FunctionName='pixel_prompt_lambda',
|
215 |
-
InvocationType='RequestResponse',
|
216 |
-
Payload=serialized_data
|
217 |
-
)
|
218 |
-
response_payload = response['Payload'].read()
|
219 |
-
response_data = json.loads(response_payload)
|
220 |
-
except Exception as e:
|
221 |
-
print(f"An error occurred: {e}")
|
222 |
-
|
223 |
-
return response_data['body']
|
224 |
-
|
225 |
-
def inferenceAPI(model, item, attempts = 1):
|
226 |
-
print(f'Inference model {model}')
|
227 |
-
if attempts > 5:
|
228 |
-
return 'An error occured when Processing', model
|
229 |
-
prompt = item.prompt
|
230 |
-
if "dallinmackay" in model:
|
231 |
-
prompt = "lvngvncnt, " + item.prompt
|
232 |
-
data = {"inputs":prompt, "negative_prompt": perm_negative_prompt, "options":options, "timeout": 45}
|
233 |
-
api_data = json.dumps(data)
|
234 |
-
try:
|
235 |
-
response = requests.request("POST", API_URL + model, headers=headers, data=api_data)
|
236 |
-
if response is None:
|
237 |
-
inferenceAPI(get_random_model(activeModels['text-to-image']), item, attempts+1)
|
238 |
-
print(response.content[0:200])
|
239 |
-
image_stream = BytesIO(response.content)
|
240 |
-
image = Image.open(image_stream)
|
241 |
-
image.save("response.png")
|
242 |
-
with open('response.png', 'rb') as f:
|
243 |
-
base64_img = base64.b64encode(f.read()).decode('utf-8')
|
244 |
-
return model, base64_img
|
245 |
-
except Exception as e:
|
246 |
-
print(f'Error When Processing Image: {e}')
|
247 |
-
activeModels = InferenceClient().list_deployed_models()
|
248 |
-
model = get_random_model(activeModels['text-to-image'])
|
249 |
-
pattern = r'^(.{1,30})\/(.{1,50})$'
|
250 |
-
if not re.match(pattern, model):
|
251 |
-
return "error model not valid", model
|
252 |
-
return inferenceAPI(model, item, attempts+1)
|
253 |
-
|
254 |
-
|
255 |
-
def get_random_model(models):
|
256 |
-
global last_two_models
|
257 |
-
model = None
|
258 |
-
priorities = [
|
259 |
-
"stabilityai/stable-diffusion-3.5-large-turbo",
|
260 |
-
"stabilityai/stable-diffusion-3.5-large",
|
261 |
-
"black-forest-labs",
|
262 |
-
"kandinsky-community",
|
263 |
-
"Kolors-diffusers",
|
264 |
-
"Juggernaut",
|
265 |
-
"insaneRealistic",
|
266 |
-
"MajicMIX",
|
267 |
-
"digiautogpt3",
|
268 |
-
"fluently"
|
269 |
-
]
|
270 |
-
|
271 |
-
for priority in priorities:
|
272 |
-
for i, model_name in enumerate(models):
|
273 |
-
if priority in model_name and model_name not in last_two_models:
|
274 |
-
model = models[i]
|
275 |
-
break
|
276 |
-
if model is not None:
|
277 |
-
break
|
278 |
-
if model is None:
|
279 |
-
print("Choosing randomly")
|
280 |
-
model = random.choice(models)
|
281 |
-
last_two_models.append(model)
|
282 |
-
last_two_models = last_two_models[-5:]
|
283 |
-
|
284 |
-
return model
|
285 |
-
|
286 |
-
def nsfw_check(item, attempts=1):
|
287 |
-
try:
|
288 |
-
API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
|
289 |
-
with open('response.png', 'rb') as f:
|
290 |
-
data = f.read()
|
291 |
-
response = requests.request("POST", API_URL, headers=headers, data=data)
|
292 |
-
decoded_response = response.content.decode("utf-8")
|
293 |
-
print(item.prompt)
|
294 |
-
print(decoded_response)
|
295 |
-
|
296 |
-
json_response = json.loads(decoded_response)
|
297 |
-
|
298 |
-
if "error" in json_response:
|
299 |
-
time.sleep(json_response["estimated_time"])
|
300 |
-
return nsfw_check(item, attempts+1)
|
301 |
-
|
302 |
-
scores = {item['label']: item['score'] for item in json_response}
|
303 |
-
error_msg = scores.get('nsfw', 0) > .1
|
304 |
-
return error_msg
|
305 |
-
except json.JSONDecodeError as e:
|
306 |
-
print(f'JSON Decoding Error: {e}')
|
307 |
-
return True
|
308 |
-
except Exception as e:
|
309 |
-
print(f'NSFW Check Error: {e}')
|
310 |
-
if attempts > 30:
|
311 |
-
return True
|
312 |
-
return nsfw_check(item, attempts+1)
|
313 |
-
|
314 |
-
|
315 |
-
@app.post("/api")
|
316 |
-
async def inference(item: Item):
|
317 |
-
print("Start API Inference")
|
318 |
-
activeModels = InferenceClient().list_deployed_models()
|
319 |
-
base64_img = ""
|
320 |
-
model = item.modelID
|
321 |
-
print(f'Start Model {model}')
|
322 |
-
NSFW = False
|
323 |
-
try:
|
324 |
-
if item.image:
|
325 |
-
model = "stabilityai/stable-diffusion-xl-base-1.0"
|
326 |
-
base64_img = gradioHatmanInstantStyle(item)
|
327 |
-
elif "AuraFlow" in item.modelID:
|
328 |
-
base64_img = gradioAuraFlow(item)
|
329 |
-
elif "Random" in item.modelID:
|
330 |
-
model = get_random_model(activeModels['text-to-image'])
|
331 |
-
pattern = r'^(.{1,30})\/(.{1,50})$'
|
332 |
-
if not re.match(pattern, model):
|
333 |
-
raise ValueError("Model not Valid")
|
334 |
-
model, base64_img= inferenceAPI(model, item)
|
335 |
-
elif "stable-diffusion-3" in item.modelID:
|
336 |
-
base64_img = gradioSD3(item)
|
337 |
-
elif "Voxel" in item.modelID or "pixel" in item.modelID:
|
338 |
-
prompt = item.prompt
|
339 |
-
if "Voxel" in item.modelID:
|
340 |
-
prompt = "voxel style, " + item.prompt
|
341 |
-
base64_img = lambda_image(prompt, item.modelID)
|
342 |
-
elif item.modelID not in activeModels['text-to-image']:
|
343 |
-
asyncio.create_task(wake_model(item.modelID))
|
344 |
-
return {"output": "Model Waking"}
|
345 |
-
else:
|
346 |
-
base64_img, model = inferenceAPI(item.modelID, item)
|
347 |
-
if 'error' in base64_img:
|
348 |
-
return {"output": base64_img, "model": model}
|
349 |
-
NSFW = nsfw_check(item)
|
350 |
-
|
351 |
-
save_image(base64_img, item, model, NSFW)
|
352 |
-
except Exception as e:
|
353 |
-
print(f"An error occurred: {e}")
|
354 |
-
base64_img = f"An error occurred: {e}"
|
355 |
-
return {"output": base64_img, "model": model, "NSFW": NSFW}
|
356 |
-
|
357 |
-
prompt_base = 'Instructions:\
|
358 |
-
\
|
359 |
-
1. Take the provided seed string as inspiration.\
|
360 |
-
2. Generate a prompt that is clear, vivid, and imaginative.\
|
361 |
-
3. This is a visual image so any reference to senses other than sight should be avoided.\
|
362 |
-
4. Ensure the prompt is between 90 and 100 tokens.\
|
363 |
-
5. Return only the prompt.\
|
364 |
-
Format your response as follows:\
|
365 |
-
Stable Diffusion Prompt: [Your prompt here]\
|
366 |
-
\
|
367 |
-
Remember:\
|
368 |
-
\
|
369 |
-
- The prompt should be descriptive.\
|
370 |
-
- Avoid overly complex or abstract phrases.\
|
371 |
-
- Make sure the prompt evokes strong imagery and can guide the creation of visual content.\
|
372 |
-
- Make sure the prompt is between 90 and 100 tokens.'
|
373 |
-
|
374 |
-
prompt_assistant = "I am ready to return a prompt that is between 90 and 100 tokens. What is your seed string?"
|
375 |
-
|
376 |
-
app.mount("/", StaticFiles(directory="web-build", html=True), name="build")
|
377 |
-
|
378 |
-
@app.get('/')
|
379 |
-
def homepage() -> FileResponse:
|
380 |
-
return FileResponse(path="/app/build/index.html", media_type="text/html")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|