Dash-inc committed
Commit 1ba2980 · verified · 1 Parent(s): db93af0

Update main.py

Files changed (1): main.py (+103, -124)
main.py CHANGED
@@ -1,146 +1,125 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
+import os
+import tempfile
+from fastapi import FastAPI, HTTPException, BackgroundTasks
 from pydantic import BaseModel
-from fastapi.middleware.cors import CORSMiddleware
-import uuid
-from concurrent.futures import ThreadPoolExecutor
-from pymongo import MongoClient
-from urllib.parse import quote_plus
-from langchain_groq import ChatGroq
-from aura_sr import AuraSR
-from io import BytesIO
 from PIL import Image
+import io
 import requests
-import os
+from typing import Optional
+import threading
+import uuid

 app = FastAPI()

-# Middleware for CORS
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Globals
-executor = ThreadPoolExecutor(max_workers=10)
-llm = None
-upscale_model = None
-client = MongoClient(f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/")
-db = client["Flux"]
-collection = db["chat_histories"]
-chat_sessions = {}
-image_storage_dir = "./images"  # Directory to save images locally
-
-# Ensure the image storage directory exists
-os.makedirs(image_storage_dir, exist_ok=True)
-
-@app.on_event("startup")
-async def startup():
-    global llm, upscale_model
-    llm = ChatGroq(
-        model="llama-3.3-70b-versatile",
-        temperature=0.7,
-        max_tokens=1024,
-        api_key="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",
-    )
-    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
-
-@app.on_event("shutdown")
-def shutdown():
-    client.close()
-    executor.shutdown()
-
-# Pydantic models
-class ImageRequest(BaseModel):
+# Use a temporary directory for storing images
+temp_dir = tempfile.mkdtemp()
+print(f"Temporary directory for images: {temp_dir}")
+
+# Ensure that the temporary directory exists
+os.makedirs(temp_dir, exist_ok=True)
+
+
+# Endpoint models
+class NewChatRequest(BaseModel):
+    user_name: str
+
+
+class GenerateImageRequest(BaseModel):
+    prompt: str
+    width: int = 1024
+    height: int = 1024
+    inference_steps: int = 50
+    guidance_scale: Optional[float] = 7.5
+
+
+class UpscaleImageRequest(BaseModel):
+    image_id: str
+
+
+class SetPromptRequest(BaseModel):
     subject: str
     style: str
     color_theme: str
     elements: str
-    color_mode: str
     lighting_conditions: str
     framing_style: str
     material_details: str
     text: str
     background_details: str
-    user_prompt: str
     chat_id: str
+    user_prompt: str
+

-# Helper Functions
-def generate_chat_id():
+# Dictionary to store images temporarily
+image_store = {}
+
+
+@app.post("/new-chat")
+def new_chat(request: NewChatRequest):
     chat_id = str(uuid.uuid4())
-    chat_sessions[chat_id] = collection
-    return chat_id
-
-def get_chat_history(chat_id):
-    messages = collection.find({"session_id": chat_id})
-    return "\n".join(
-        f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
-        for msg in messages
-    )
-
-def save_image_locally(image, filename):
-    filepath = os.path.join(image_storage_dir, filename)
-    image.save(filepath, format="PNG")
-    return filepath
-
-# Endpoints
-@app.post("/new-chat", response_model=dict)
-async def new_chat():
-    chat_id = generate_chat_id()
-    return {"chat_id": chat_id}
-
-@app.post("/generate-image", response_model=dict)
-async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
-    def process_request():
-        chat_history = get_chat_history(request.chat_id)
-        prompt = f"""
-        Subject: {request.subject}
-        Style: {request.style}
-        ...
-        Chat History: {chat_history}
-        User Prompt: {request.user_prompt}
-        """
-        refined_prompt = llm.invoke(prompt).content.strip()
-        collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
-        collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
+    # Simulate creating a chat session
+    return {"chat_id": chat_id, "message": f"New chat created for user: {request.user_name}"}

+
+@app.post("/generate-image")
+def generate_image(request: GenerateImageRequest, background_tasks: BackgroundTasks):
+    try:
         # Simulate image generation
-        response = requests.post(
-            "https://api.bfl.ml/v1/flux-pro-1.1",
-            json={"prompt": refined_prompt}
-        ).json()
-        image_url = response["result"]["sample"]
-
-        # Download and save the image locally
-        image_response = requests.get(image_url)
-        img = Image.open(BytesIO(image_response.content))
-        filename = f"generated_{uuid.uuid4()}.png"
-        filepath = save_image_locally(img, filename)
-        return filepath
-
-    task = executor.submit(process_request)
-    background_tasks.add_task(task)
-    return {"status": "Processing"}
-
-@app.post("/upscale-image", response_model=dict)
-async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
-    def process_image():
-        response = requests.get(image_url)
-        img = Image.open(BytesIO(response.content))
-        upscaled_image = upscale_model.upscale_4x_overlapped(img)
-        filename = f"upscaled_{uuid.uuid4()}.png"
-        filepath = save_image_locally(upscaled_image, filename)
-        return filepath
-
-    task = executor.submit(process_image)
-    background_tasks.add_task(task)
-    return {"status": "Processing"}
-
-@app.post("/set-prompt", response_model=dict)
-async def set_prompt(chat_id: str, user_prompt: str):
-    chat_history = get_chat_history(chat_id)
-    refined_prompt = llm.invoke(f"{chat_history}\nUser Prompt: {user_prompt}").content.strip()
-    collection.insert_one({"session_id": chat_id, "role": "user", "content": user_prompt})
-    collection.insert_one({"session_id": chat_id, "role": "ai", "content": refined_prompt})
-    return {"refined_prompt": refined_prompt}
+        image_id = str(uuid.uuid4())
+        image_path = os.path.join(temp_dir, f"{image_id}.png")
+
+        # Mock image generation (replace with actual logic)
+        img = Image.new("RGB", (request.width, request.height), color=(255, 255, 255))
+        img.save(image_path)
+
+        # Store image path
+        image_store[image_id] = image_path
+        return {"image_id": image_id, "message": "Image generation started"}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/upscale-image")
+def upscale_image(request: UpscaleImageRequest):
+    image_path = image_store.get(request.image_id)
+    if not image_path:
+        raise HTTPException(status_code=404, detail="Image not found")
+
+    try:
+        # Simulate image upscaling
+        img = Image.open(image_path)
+        upscaled_img = img.resize((img.width * 2, img.height * 2), Image.LANCZOS)
+        upscaled_image_path = os.path.join(temp_dir, f"{request.image_id}_upscaled.png")
+        upscaled_img.save(upscaled_image_path)
+
+        # Update store
+        image_store[request.image_id] = upscaled_image_path
+        return {"image_id": request.image_id, "message": "Image upscaled"}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/set-prompt")
+def set_prompt(request: SetPromptRequest):
+    try:
+        # Simulate prompt generation (replace with actual LLM call)
+        refined_prompt = (
+            f"Subject: {request.subject}, Style: {request.style}, "
+            f"Color Theme: {request.color_theme}, Elements: {request.elements}, "
+            f"Background: {request.background_details}"
+        )
+        return {"refined_prompt": refined_prompt}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# Clean up the temporary directory on shutdown
+@app.on_event("shutdown")
+def cleanup():
+    print("Cleaning up temporary directory...")
+    try:
+        for file in os.listdir(temp_dir):
+            os.remove(os.path.join(temp_dir, file))
+        os.rmdir(temp_dir)
+    except Exception as e:
+        print(f"Error cleaning up temporary directory: {e}")