Dash-inc committed
Commit db93af0 · verified · 1 Parent(s): d3e4644

Create main.py

Files changed (1)
  1. main.py +146 -0
main.py ADDED
@@ -0,0 +1,146 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uuid
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os

app = FastAPI()

# Middleware for CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
executor = ThreadPoolExecutor(max_workers=10)
llm = None
upscale_model = None
client = MongoClient(f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/")
db = client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}
image_storage_dir = "./images"  # Directory to save images locally

# Ensure the image storage directory exists
os.makedirs(image_storage_dir, exist_ok=True)

@app.on_event("startup")
async def startup():
    # Load the Groq LLM and the AuraSR upscaler once at application startup
    global llm, upscale_model
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=1024,
        api_key="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",
    )
    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")

@app.on_event("shutdown")
def shutdown():
    client.close()
    executor.shutdown()

# Pydantic models
class ImageRequest(BaseModel):
    subject: str
    style: str
    color_theme: str
    elements: str
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    text: str
    background_details: str
    user_prompt: str
    chat_id: str

# Helper Functions
def generate_chat_id():
    chat_id = str(uuid.uuid4())
    # Register the session; all sessions share the same Mongo collection
    chat_sessions[chat_id] = collection
    return chat_id

def get_chat_history(chat_id):
    messages = collection.find({"session_id": chat_id})
    return "\n".join(
        f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
        for msg in messages
    )

def save_image_locally(image, filename):
    filepath = os.path.join(image_storage_dir, filename)
    image.save(filepath, format="PNG")
    return filepath

# Endpoints
@app.post("/new-chat", response_model=dict)
async def new_chat():
    chat_id = generate_chat_id()
    return {"chat_id": chat_id}

@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
    def process_request():
        chat_history = get_chat_history(request.chat_id)
        prompt = f"""
        Subject: {request.subject}
        Style: {request.style}
        ...
        Chat History: {chat_history}
        User Prompt: {request.user_prompt}
        """
        refined_prompt = llm.invoke(prompt).content.strip()
        collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
        collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})

        # Request image generation from the BFL Flux API
        response = requests.post(
            "https://api.bfl.ml/v1/flux-pro-1.1",
            json={"prompt": refined_prompt}
        ).json()
        image_url = response["result"]["sample"]

        # Download and save the image locally
        image_response = requests.get(image_url)
        img = Image.open(BytesIO(image_response.content))
        filename = f"generated_{uuid.uuid4()}.png"
        filepath = save_image_locally(img, filename)
        return filepath

    # add_task expects a callable; FastAPI runs this sync worker in its
    # threadpool after the response has been sent
    background_tasks.add_task(process_request)
    return {"status": "Processing"}

@app.post("/upscale-image", response_model=dict)
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
    def process_image():
        response = requests.get(image_url)
        img = Image.open(BytesIO(response.content))
        upscaled_image = upscale_model.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        filepath = save_image_locally(upscaled_image, filename)
        return filepath

    # As above, hand the callable itself to the background task queue
    background_tasks.add_task(process_image)
    return {"status": "Processing"}

@app.post("/set-prompt", response_model=dict)
async def set_prompt(chat_id: str, user_prompt: str):
    chat_history = get_chat_history(chat_id)
    refined_prompt = llm.invoke(f"{chat_history}\nUser Prompt: {user_prompt}").content.strip()
    collection.insert_one({"session_id": chat_id, "role": "user", "content": user_prompt})
    collection.insert_one({"session_id": chat_id, "role": "ai", "content": refined_prompt})
    return {"refined_prompt": refined_prompt}
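
For reference, a minimal client sketch against these endpoints, assuming the app is served locally with uvicorn main:app --port 8000; the base URL and every sample field value below are illustrative assumptions, not part of this commit:

# client_example.py -- hypothetical smoke test for the API above
# (assumes the service is running locally: uvicorn main:app --port 8000)
import requests

BASE = "http://localhost:8000"  # assumed host and port

# Open a chat session
chat_id = requests.post(f"{BASE}/new-chat").json()["chat_id"]

# Queue an image-generation job; every ImageRequest field is required
payload = {
    "subject": "a lighthouse at dusk",
    "style": "oil painting",
    "color_theme": "warm",
    "elements": "waves, seagulls",
    "color_mode": "full color",
    "lighting_conditions": "golden hour",
    "framing_style": "wide shot",
    "material_details": "thick brush strokes",
    "text": "",
    "background_details": "stormy sky",
    "user_prompt": "Make it feel serene despite the storm",
    "chat_id": chat_id,
}
print(requests.post(f"{BASE}/generate-image", json=payload).json())  # {'status': 'Processing'}

# /set-prompt takes plain query parameters, not a JSON body
print(requests.post(f"{BASE}/set-prompt",
                    params={"chat_id": chat_id, "user_prompt": "more fog"}).json())

Note that /generate-image and /upscale-image only return {"status": "Processing"}; the generated files land in ./images on the server, and the code as committed exposes no endpoint for retrieving them.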