Hatman committed on
Commit 3ec06f1 · verified · 1 Parent(s): 2cf216d

Delete main.py

Files changed (1)
  1. main.py +0 -380
main.py DELETED
@@ -1,380 +0,0 @@
- import asyncio
- import base64
- import json
- import os
- import random
- import re
- import time
- from datetime import datetime
- from io import BytesIO
- from typing import Optional
-
- import aiohttp
- import boto3
- import requests
- from dotenv import load_dotenv
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import FileResponse
- from fastapi.staticfiles import StaticFiles
- from gradio_client import Client, file
- from groq import Groq
- from huggingface_hub import InferenceClient, login
- from PIL import Image
- from pydantic import BaseModel
- from starlette.responses import StreamingResponse
-
- app = FastAPI()
-
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Load environment variables before reading any of them.
- load_dotenv()
- token = os.environ.get("HF_TOKEN")
- login(token)
-
- groqClient = Groq(api_key=os.environ.get("GROQ_API_KEY"))
-
- prompt_model = "llama-3.1-8b-instant"
- magic_prompt_model = "Gustavosta/MagicPrompt-Stable-Diffusion"
- options = {"use_cache": False, "wait_for_model": True}
- parameters = {"return_full_text": False, "max_new_tokens": 300}
- headers = {"Authorization": f"Bearer {token}", "x-use-cache": "0", "Content-Type": "application/json"}
- API_URL = "https://api-inference.huggingface.co/models/"
- perm_negative_prompt = "watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry"
- cwd = os.getcwd()
- pictures_directory = os.path.join(cwd, 'pictures')
- last_two_models = []
-
- class Item(BaseModel):
-     prompt: str
-     steps: int
-     guidance: float
-     modelID: str
-     modelLabel: str
-     image: Optional[str] = None
-     target: str
-     control: float
-
- class Core(BaseModel):
-     itemString: str
-
- @app.get("/core")
- async def core():
-     if not os.path.exists(pictures_directory):
-         os.makedirs(pictures_directory)
-
-     async def generator():
-         # Stream a JSON array of saved pictures without loading them all into memory.
-         yield '['
-         first = True
-         for filename in os.listdir(pictures_directory):
-             if filename.endswith('.json'):
-                 file_path = os.path.join(pictures_directory, filename)
-                 with open(file_path, 'r') as file:
-                     data = json.load(file)
-                 # Only the first item is emitted without a preceding comma.
-                 if first:
-                     first = False
-                 else:
-                     yield ','
-                 yield json.dumps({"base64": data["base64image"], "prompt": data["returnedPrompt"]})
-         yield ']'
-
-     return StreamingResponse(generator(), media_type="application/json")
-
-
- def getPrompt(prompt, modelID, attempts=1):
-     response = {}
-     print(modelID)
-     try:
-         if modelID != magic_prompt_model:
-             chat = [
-                 {"role": "user", "content": prompt_base},
-                 {"role": "assistant", "content": prompt_assistant},
-                 {"role": "user", "content": prompt},
-             ]
-             response = groqClient.chat.completions.create(messages=chat, temperature=1, max_tokens=2048, top_p=1, stream=False, stop=None, model=modelID)
-         else:
-             apiData = {"inputs": prompt, "parameters": parameters, "options": options, "timeout": 45}
-             response = requests.post(API_URL + modelID, headers=headers, data=json.dumps(apiData))
-             return response.json()
-     except Exception as e:
-         print(f"An error occurred: {e}")
-         if attempts < 3:
-             return getPrompt(prompt, modelID, attempts + 1)
-     return response
-
- @app.post("/inferencePrompt")
- def inferencePrompt(item: Core):
-     print("Start API Inference Prompt")
-     try:
-         plain_response_data = getPrompt(item.itemString, prompt_model)
-         magic_response_data = getPrompt(item.itemString, magic_prompt_model)
-         returnJson = {"plain": plain_response_data.choices[0].message.content, "magic": item.itemString + magic_response_data[0]["generated_text"]}
-         print(f'Return Json {returnJson}')
-         return returnJson
-     except Exception as e:
-         returnJson = {"plain": f'An error occurred: {e}', "magic": f'An error occurred: {e}'}
-         return returnJson
-
- async def wake_model(modelID):
-     data = {"inputs": "wake up call", "options": options}
-     headers = {"Authorization": f"Bearer {token}"}
-     api_data = json.dumps(data)
-     try:
-         timeout = aiohttp.ClientTimeout(total=60)  # 60-second timeout
-         async with aiohttp.ClientSession(timeout=timeout) as session:
-             async with session.post(API_URL + modelID, headers=headers, data=api_data) as response:
-                 pass
-         print('Model Waking')
-     except Exception as e:
-         print(f"An error occurred: {e}")
-
- def formatReturn(result):
-     # Convert a Gradio file result to a base64-encoded PNG string.
-     # Also write response.png, which nsfw_check() reads afterwards.
-     img = Image.open(result)
-     img.save("response.png")
-     img_byte_arr = BytesIO()
-     img.save(img_byte_arr, format='PNG')
-     img_byte_arr = img_byte_arr.getvalue()
-     base64_img = base64.b64encode(img_byte_arr).decode('utf-8')
-     return base64_img
-
- def save_image(base64image, item, model, NSFW):
-     # Persist the generated image and its metadata unless it was flagged NSFW.
-     if not NSFW:
-         data = {"base64image": "data:image/png;base64," + base64image, "returnedPrompt": "Model:\n" + model + "\n\nPrompt:\n" + item.prompt, "prompt": item.prompt, "steps": item.steps, "guidance": item.guidance, "control": item.control, "target": item.target}
-         timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
-         file_path = os.path.join(pictures_directory, f'{timestamp}.json')
-         with open(file_path, 'w') as json_file:
-             json.dump(data, json_file)
-
- def gradioSD3(item):
-     client = Client(item.modelID, hf_token=token)
-     result = client.predict(
-         prompt=item.prompt,
-         negative_prompt=perm_negative_prompt,
-         guidance_scale=item.guidance,
-         num_inference_steps=item.steps,
-         api_name="/infer"
-     )
-     return formatReturn(result[0])
-
- def gradioAuraFlow(item):
-     client = Client("multimodalart/AuraFlow")
-     result = client.predict(
-         prompt=item.prompt,
-         negative_prompt=perm_negative_prompt,
-         randomize_seed=True,
-         guidance_scale=item.guidance,
-         num_inference_steps=item.steps,
-         api_name="/infer"
-     )
-     print(result[0])
-     return formatReturn(result[0]["value"])
-
- def gradioHatmanInstantStyle(item):
-     client = Client("Hatman/InstantStyle")
-     # Decode the data-URL image payload and hand it to the Space as a file.
-     image_stream = BytesIO(base64.b64decode(item.image.split("base64,")[1]))
-     image = Image.open(image_stream)
-     image.save("style.png")
-     result = client.predict(
-         image_pil=file("style.png"),
-         prompt=item.prompt,
-         n_prompt=perm_negative_prompt,
-         scale=1,
-         control_scale=item.control,
-         guidance_scale=item.guidance,
-         num_inference_steps=item.steps,
-         seed=1,
-         target=item.target,
-         api_name="/create_image"
-     )
-     return formatReturn(result)
-
- def lambda_image(prompt, modelID):
-     data = {
-         "prompt": prompt,
-         "modelID": modelID
-     }
-     serialized_data = json.dumps(data)
-     aws_id = os.environ.get("AWS_ID")
-     aws_secret = os.environ.get("AWS_SECRET")
-     aws_region = os.environ.get("AWS_REGION")
-     try:
-         session = boto3.Session(aws_access_key_id=aws_id, aws_secret_access_key=aws_secret, region_name=aws_region)
-         lambda_client = session.client('lambda')
-         response = lambda_client.invoke(
-             FunctionName='pixel_prompt_lambda',
-             InvocationType='RequestResponse',
-             Payload=serialized_data
-         )
-         response_payload = response['Payload'].read()
-         response_data = json.loads(response_payload)
-         return response_data['body']
-     except Exception as e:
-         print(f"An error occurred: {e}")
-         return f"An error occurred: {e}"
-
- def inferenceAPI(model, item, attempts=1):
-     print(f'Inference model {model}')
-     if attempts > 5:
-         return model, 'An error occurred while processing'
-     prompt = item.prompt
-     if "dallinmackay" in model:
-         prompt = "lvngvncnt, " + item.prompt
-     data = {"inputs": prompt, "negative_prompt": perm_negative_prompt, "options": options, "timeout": 45}
-     api_data = json.dumps(data)
-     try:
-         response = requests.request("POST", API_URL + model, headers=headers, data=api_data)
-         response.raise_for_status()
-         print(response.content[0:200])
-         image_stream = BytesIO(response.content)
-         image = Image.open(image_stream)
-         image.save("response.png")
-         with open('response.png', 'rb') as f:
-             base64_img = base64.b64encode(f.read()).decode('utf-8')
-         return model, base64_img
-     except Exception as e:
-         print(f'Error When Processing Image: {e}')
-         # Retry with a different deployed model.
-         activeModels = InferenceClient().list_deployed_models()
-         model = get_random_model(activeModels['text-to-image'])
-         pattern = r'^(.{1,30})\/(.{1,50})$'
-         if not re.match(pattern, model):
-             return model, "error: model not valid"
-         return inferenceAPI(model, item, attempts + 1)
-
-
- def get_random_model(models):
-     # Pick a deployed model, preferring the priority list and avoiding recent picks.
-     global last_two_models
-     model = None
-     priorities = [
-         "stabilityai/stable-diffusion-3.5-large-turbo",
-         "stabilityai/stable-diffusion-3.5-large",
-         "black-forest-labs",
-         "kandinsky-community",
-         "Kolors-diffusers",
-         "Juggernaut",
-         "insaneRealistic",
-         "MajicMIX",
-         "digiautogpt3",
-         "fluently"
-     ]
-     for priority in priorities:
-         for model_name in models:
-             if priority in model_name and model_name not in last_two_models:
-                 model = model_name
-                 break
-         if model is not None:
-             break
-     if model is None:
-         print("Choosing randomly")
-         model = random.choice(models)
-     last_two_models.append(model)
-     last_two_models = last_two_models[-5:]  # despite the name, this keeps the last five
-     return model
-
- def nsfw_check(item, attempts=1):
-     try:
-         API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
-         with open('response.png', 'rb') as f:
-             data = f.read()
-         response = requests.request("POST", API_URL, headers=headers, data=data)
-         decoded_response = response.content.decode("utf-8")
-         print(item.prompt)
-         print(decoded_response)
-         json_response = json.loads(decoded_response)
-         if "error" in json_response:
-             # The model is still loading; wait and retry.
-             time.sleep(json_response["estimated_time"])
-             return nsfw_check(item, attempts + 1)
-         scores = {entry['label']: entry['score'] for entry in json_response}
-         return scores.get('nsfw', 0) > .1
-     except json.JSONDecodeError as e:
-         print(f'JSON Decoding Error: {e}')
-         return True
-     except Exception as e:
-         print(f'NSFW Check Error: {e}')
-         if attempts > 30:
-             return True
-         return nsfw_check(item, attempts + 1)
-
-
- @app.post("/api")
- async def inference(item: Item):
-     print("Start API Inference")
-     activeModels = InferenceClient().list_deployed_models()
-     base64_img = ""
-     model = item.modelID
-     print(f'Start Model {model}')
-     NSFW = False
-     try:
-         if item.image:
-             model = "stabilityai/stable-diffusion-xl-base-1.0"
-             base64_img = gradioHatmanInstantStyle(item)
-         elif "AuraFlow" in item.modelID:
-             base64_img = gradioAuraFlow(item)
-         elif "Random" in item.modelID:
-             model = get_random_model(activeModels['text-to-image'])
-             pattern = r'^(.{1,30})\/(.{1,50})$'
-             if not re.match(pattern, model):
-                 raise ValueError("Model not Valid")
-             model, base64_img = inferenceAPI(model, item)
-         elif "stable-diffusion-3" in item.modelID:
-             base64_img = gradioSD3(item)
-         elif "Voxel" in item.modelID or "pixel" in item.modelID:
-             prompt = item.prompt
-             if "Voxel" in item.modelID:
-                 prompt = "voxel style, " + item.prompt
-             base64_img = lambda_image(prompt, item.modelID)
-         elif item.modelID not in activeModels['text-to-image']:
-             # Fire-and-forget wake-up call; the client is told to retry later.
-             asyncio.create_task(wake_model(item.modelID))
-             return {"output": "Model Waking"}
-         else:
-             model, base64_img = inferenceAPI(item.modelID, item)
-         if 'error' in base64_img:
-             return {"output": base64_img, "model": model}
-         NSFW = nsfw_check(item)
-         save_image(base64_img, item, model, NSFW)
-     except Exception as e:
-         print(f"An error occurred: {e}")
-         base64_img = f"An error occurred: {e}"
-     return {"output": base64_img, "model": model, "NSFW": NSFW}
-
- prompt_base = '''Instructions:
-
- 1. Take the provided seed string as inspiration.
- 2. Generate a prompt that is clear, vivid, and imaginative.
- 3. This is a visual image, so avoid any reference to senses other than sight.
- 4. Ensure the prompt is between 90 and 100 tokens.
- 5. Return only the prompt.
- Format your response as follows:
- Stable Diffusion Prompt: [Your prompt here]
-
- Remember:
-
- - The prompt should be descriptive.
- - Avoid overly complex or abstract phrases.
- - Make sure the prompt evokes strong imagery and can guide the creation of visual content.
- - Make sure the prompt is between 90 and 100 tokens.'''
-
- prompt_assistant = "I am ready to return a prompt that is between 90 and 100 tokens. What is your seed string?"
-
- app.mount("/", StaticFiles(directory="web-build", html=True), name="build")
-
- # Note: the mount above already serves web-build/index.html for "/", and it is
- # registered first, so this route is not reached in practice.
- @app.get('/')
- def homepage() -> FileResponse:
-     return FileResponse(path="/app/build/index.html", media_type="text/html")
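
For context, the deleted main.py exposed three endpoints: GET /core streams the saved gallery as a JSON array, POST /inferencePrompt expands a seed string into image prompts, and POST /api generates an image. Below is a minimal client sketch for /api; the base URL, the prompt values, and the target string are illustrative assumptions, not values recorded in this commit.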
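
import base64
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port; not recorded in this commit

# POST /api takes the Item schema defined in main.py.
payload = {
    "prompt": "a lighthouse at dawn, oil painting",  # illustrative
    "steps": 30,
    "guidance": 7.0,
    "modelID": "stabilityai/stable-diffusion-xl-base-1.0",
    "modelLabel": "SDXL",
    "image": None,  # set to a base64 data URL to trigger the InstantStyle path
    "target": "Load only style blocks",  # only read by the InstantStyle path; assumed value
    "control": 0.6,
}
result = requests.post(f"{BASE_URL}/api", json=payload).json()

# "output" is a base64 PNG on success, "Model Waking" while the model loads,
# or an error message; "model" reports which model actually ran.
if result["output"] != "Model Waking" and "error" not in result["output"]:
    with open("generated.png", "wb") as f:
        f.write(base64.b64decode(result["output"]))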
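
The "Model Waking" response mirrors the server's asyncio.create_task(wake_model(...)) branch: a cold model gets a fire-and-forget wake-up request, and the client is expected to poll /api again rather than hold the connection open.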