Dash-inc committed
Commit df5fc1f · verified · 1 Parent(s): 1ba2980

Update main.py

Files changed (1)
  main.py +136 -100
main.py CHANGED
@@ -1,125 +1,161 @@
- import os
- import tempfile
- from fastapi import FastAPI, HTTPException, BackgroundTasks
  from pydantic import BaseModel
  from PIL import Image
- import io
  import requests
- from typing import Optional
- import threading
- import uuid

  app = FastAPI()

- # Use a temporary directory for storing images
- temp_dir = tempfile.mkdtemp()
- print(f"Temporary directory for images: {temp_dir}")
-
- # Ensure that the temporary directory exists
- os.makedirs(temp_dir, exist_ok=True)
-

- # Endpoint models
- class NewChatRequest(BaseModel):
-     user_name: str
-
-
- class GenerateImageRequest(BaseModel):
-     prompt: str
-     width: int = 1024
-     height: int = 1024
-     inference_steps: int = 50
-     guidance_scale: Optional[float] = 7.5


- class UpscaleImageRequest(BaseModel):
-     image_id: str


- class SetPromptRequest(BaseModel):
      subject: str
      style: str
      color_theme: str
      elements: str
      lighting_conditions: str
      framing_style: str
      material_details: str
      text: str
      background_details: str
-     chat_id: str
      user_prompt: str

-
- # Dictionary to store images temporarily
- image_store = {}
-
-
- @app.post("/new-chat")
- def new_chat(request: NewChatRequest):
      chat_id = str(uuid.uuid4())
-     # Simulate creating a chat session
-     return {"chat_id": chat_id, "message": f"New chat created for user: {request.user_name}"}
-

- @app.post("/generate-image")
- def generate_image(request: GenerateImageRequest, background_tasks: BackgroundTasks):
-     try:
          # Simulate image generation
-         image_id = str(uuid.uuid4())
-         image_path = os.path.join(temp_dir, f"{image_id}.png")
-
-         # Mock image generation (replace with actual logic)
-         img = Image.new("RGB", (request.width, request.height), color=(255, 255, 255))
-         img.save(image_path)
-
-         # Store image path
-         image_store[image_id] = image_path
-         return {"image_id": image_id, "message": "Image generation started"}
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- @app.post("/upscale-image")
- def upscale_image(request: UpscaleImageRequest):
-     image_path = image_store.get(request.image_id)
-     if not image_path:
-         raise HTTPException(status_code=404, detail="Image not found")
-
-     try:
-         # Simulate image upscaling
-         img = Image.open(image_path)
-         upscaled_img = img.resize((img.width * 2, img.height * 2), Image.LANCZOS)
-         upscaled_image_path = os.path.join(temp_dir, f"{request.image_id}_upscaled.png")
-         upscaled_img.save(upscaled_image_path)
-
-         # Update store
-         image_store[request.image_id] = upscaled_image_path
-         return {"image_id": request.image_id, "message": "Image upscaled"}
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- @app.post("/set-prompt")
- def set_prompt(request: SetPromptRequest):
-     try:
-         # Simulate prompt generation (replace with actual LLM call)
-         refined_prompt = (
-             f"Subject: {request.subject}, Style: {request.style}, "
-             f"Color Theme: {request.color_theme}, Elements: {request.elements}, "
-             f"Background: {request.background_details}"
-         )
-         return {"refined_prompt": refined_prompt}
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- # Clean up the temporary directory on shutdown
- @app.on_event("shutdown")
- def cleanup():
-     print("Cleaning up temporary directory...")
-     try:
-         for file in os.listdir(temp_dir):
-             os.remove(os.path.join(temp_dir, file))
-         os.rmdir(temp_dir)
-     except Exception as e:
-         print(f"Error cleaning up temporary directory: {e}")
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
  from pydantic import BaseModel
+ from fastapi.middleware.cors import CORSMiddleware
+ import uuid
+ import tempfile
+ from concurrent.futures import ThreadPoolExecutor
+ from pymongo import MongoClient
+ from urllib.parse import quote_plus
+ from langchain_groq import ChatGroq
+ from aura_sr import AuraSR
+ from io import BytesIO
  from PIL import Image
  import requests
+ import os

  app = FastAPI()

+ # Middleware for CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Globals
+ executor = ThreadPoolExecutor(max_workers=10)
+ llm = None
+ upscale_model = None
+ client = MongoClient(f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/")
+ db = client["Flux"]
+ collection = db["chat_histories"]
+ chat_sessions = {}

+ # Use a temporary directory for storing images
+ image_storage_dir = tempfile.mkdtemp()
+ print(f"Temporary directory for images: {image_storage_dir}")

+ # Dictionary to store images temporarily
+ image_store = {}

+ @app.on_event("startup")
+ async def startup():
+     global llm, upscale_model
+     llm = ChatGroq(
+         model="llama-3.3-70b-versatile",
+         temperature=0.7,
+         max_tokens=1024,
+         api_key="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",
+     )
+     upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")

+ @app.on_event("shutdown")
+ def shutdown():
+     client.close()
+     executor.shutdown()

+ # Pydantic models
+ class ImageRequest(BaseModel):
      subject: str
      style: str
      color_theme: str
      elements: str
+     color_mode: str
      lighting_conditions: str
      framing_style: str
      material_details: str
      text: str
      background_details: str
      user_prompt: str
+     chat_id: str

+ # Helper Functions
+ def generate_chat_id():
      chat_id = str(uuid.uuid4())
+     chat_sessions[chat_id] = collection
+     return chat_id
+
+ def get_chat_history(chat_id):
+     messages = collection.find({"session_id": chat_id})
+     return "\n".join(
+         f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
+         for msg in messages
+     )
+
+ def save_image_locally(image, filename):
+     filepath = os.path.join(image_storage_dir, filename)
+     image.save(filepath, format="PNG")
+     return filepath
+
+ # Endpoints
+ @app.post("/new-chat", response_model=dict)
+ async def new_chat():
+     chat_id = generate_chat_id()
+     return {"chat_id": chat_id}
+
+ @app.post("/generate-image", response_model=dict)
+ async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
+     def process_request():
+         chat_history = get_chat_history(request.chat_id)
+         prompt = f"""
+ You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
+
+ Image Specifications:
+ - **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
+ - **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
+ - **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
+ - **Camera and Lighting**:
+   - Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
+ - **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
+ - **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
+ - **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
+ - **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
+ - Negative Prompt: Avoid grainy, blurry, or deformed outputs.
+ - **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
+ - Write the prompt only in the output. Do not include anything else except the prompt. Do not write Generated Image Prompt: in the output.
+
+ Chat History:
+ {chat_history}
+
+ User Prompt:
+ {request.user_prompt}
+
+ Generated Image Prompt:
+ """
+         refined_prompt = llm.invoke(prompt).content.strip()
+         collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
+         collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})

          # Simulate image generation
+         response = requests.post(
+             "https://api.bfl.ml/v1/flux-pro-1.1",
+             json={"prompt": refined_prompt}
+         ).json()
+         image_url = response["result"]["sample"]
+
+         # Download and save the image locally
+         image_response = requests.get(image_url)
+         img = Image.open(BytesIO(image_response.content))
+         filename = f"generated_{uuid.uuid4()}.png"
+         filepath = save_image_locally(img, filename)
+         return filepath
+
+     # Run the generation on the thread pool so the request can return immediately
+     executor.submit(process_request)
+     return {"status": "Processing"}
+
+ @app.post("/upscale-image", response_model=dict)
+ async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
+     def process_image():
+         response = requests.get(image_url)
+         img = Image.open(BytesIO(response.content))
+         upscaled_image = upscale_model.upscale_4x_overlapped(img)
+         filename = f"upscaled_{uuid.uuid4()}.png"
+         filepath = save_image_locally(upscaled_image, filename)
+         return filepath
+
+     # Run the upscaling on the thread pool so the request can return immediately
+     executor.submit(process_image)
+     return {"status": "Processing"}