Sriv890 commited on
Commit
bd5d4e1
·
verified ·
1 Parent(s): 7cd2bf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -49
app.py CHANGED
@@ -21,23 +21,60 @@ from safetensors.torch import load_file
21
 
22
 
23
  # Replace with your actual API key
24
- api_key = "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"
25
- client = Groq(api_key=api_key)
 
 
 
26
 
27
  # Load the custom model for image generation
28
- base = "stabilityai/stable-diffusion-xl-base-1.0"
29
- repo = "ByteDance/SDXL-Lightning"
30
- ckpt = "sdxl_lightning_4step_unet.safetensors" # Ensure the correct checkpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # Load the custom UNet and set up the pipeline
33
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cpu", torch.float16)
34
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cpu"))
35
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cpu")
36
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
37
 
38
- os.environ['HF_API_KEY']
39
- api_key = os.getenv('HF_API_KEY')
40
- API_URL = "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Function to transcribe, translate, and generate an image
43
  def process_audio(audio_path, generate_image):
@@ -68,7 +105,8 @@ def process_audio(audio_path, generate_image):
68
  if generate_image:
69
  try:
70
  # Use the custom model and pipeline to generate an image
71
- img = pipe(translation, num_inference_steps=4, guidance_scale=0).images[0]
 
72
  return tamil_text, translation, img
73
  except Exception as e:
74
  return tamil_text, translation, f"An error occurred during image generation: {str(e)}"
@@ -76,48 +114,27 @@ def process_audio(audio_path, generate_image):
76
  return tamil_text, translation, None
77
 
78
 
79
- def query(payload, max_retries=5):
80
- for attempt in range(max_retries):
81
- response = requests.post(API_URL, headers=headers, json=payload)
82
-
83
- if response.status_code == 503:
84
- print(f"Model is still loading, retrying... Attempt {attempt + 1}/{max_retries}")
85
- estimated_time = min(response.json().get("estimated_time", 60), 60)
86
- time.sleep(estimated_time)
87
- continue
88
 
89
- if response.status_code != 200:
90
- print(f"Error: Received status code {response.status_code}")
91
- print(f"Response: {response.text}")
92
- return None
93
 
94
- return response.content
 
 
 
95
 
96
- print(f"Failed to generate image after {max_retries} attempts.")
97
- return None
98
 
99
  # Function for direct prompt to image generation
100
- def generate_image_from_prompt(prompt):
101
- image_bytes = query({"inputs": prompt})
102
 
103
- if image_bytes is None:
104
- error_img = Image.new('RGB', (300, 300), color=(255, 0, 0))
105
- d = ImageDraw.Draw(error_img)
106
- d.text((10, 150), "Image Generation Failed", fill=(255, 255, 255))
107
- return error_img
108
 
109
- # Debug: Check the type of image_bytes
110
- print(f"Received image_bytes type: {type(image_bytes)}")
111
-
112
- try:
113
- image = Image.open(io.BytesIO(image_bytes))
114
- return image
115
- except Exception as e:
116
- print(f"Error while opening image: {e}")
117
- error_img = Image.new('RGB', (300, 300), color=(255, 0, 0))
118
- d = ImageDraw.Draw(error_img)
119
- d.text((10, 150), "Invalid Image Data", fill=(255, 255, 255))
120
- return error_img
121
 
122
 
123
  # Assuming your 'process_audio' and 'generate_image_from_prompt' functions are defined elsewhere
@@ -171,5 +188,21 @@ with gr.Blocks(css="""
171
  # Bind the correct function that returns an image
172
  btn_image.click(fn=generate_image_from_prompt, inputs=prompt_input, outputs=image_output)
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  # Launch the interface
175
  iface.launch(server_name="0.0.0.0")
 
21
 
22
 
23
  # Replace with your actual API key
24
+ os.environ['hugging']
25
+ H_key = os.getenv('hugging')
26
+ API_URL = "https://api-inference.huggingface.co/models/Artples/LAI-ImageGeneration-vSDXL-2"
27
+ headers = {"Authorization": f"Bearer {H_key}"}
28
+
29
 
30
  # Load the custom model for image generation
31
+ # base = "stabilityai/stable-diffusion-xl-base-1.0"
32
+ # repo = "ByteDance/SDXL-Lightning"
33
+ # ckpt = "sdxl_lightning_4step_unet.safetensors" # Ensure the correct checkpoint
34
+
35
+ # # Load the custom UNet and set up the pipeline
36
+ # unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cpu", torch.float16)
37
+ # unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cpu"))
38
+ # pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cpu")
39
+ # pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
40
+
41
+ #key groq
42
+ os.environ['groq']
43
+ api_key = os.getenv('groq')
44
+ client = Groq(api_key=api_key)
45
+
46
+ def query(payload, max_retries=5):
47
+ for attempt in range(max_retries):
48
+ response = requests.post(API_URL, headers=headers, json=payload)
49
 
50
+ if response.status_code == 503:
51
+ print(f"Model is still loading, retrying... Attempt {attempt + 1}/{max_retries}")
52
+ estimated_time = min(response.json().get("estimated_time", 60), 60)
53
+ time.sleep(estimated_time)
54
+ continue
55
 
56
+ if response.status_code != 200:
57
+ print(f"Error: Received status code {response.status_code}")
58
+ print(f"Response: {response.text}")
59
+ return None
60
+
61
+ return response.content
62
+
63
+ print(f"Failed to generate image after {max_retries} attempts.")
64
+ return None
65
+
66
+ def generate_image_from_prompt(prompt):
67
+ image_bytes = query({"inputs": prompt})
68
+
69
+ if image_bytes is None:
70
+ return None
71
+
72
+ try:
73
+ image = Image.open(io.BytesIO(image_bytes)) # Opening the image from bytes
74
+ return image
75
+ except Exception as e:
76
+ print(f"Error: {e}")
77
+ return None
78
 
79
  # Function to transcribe, translate, and generate an image
80
  def process_audio(audio_path, generate_image):
 
105
  if generate_image:
106
  try:
107
  # Use the custom model and pipeline to generate an image
108
+ #img = pipe(translation, num_inference_steps=4, guidance_scale=0).images[0]
109
+ img=generate_image_from_prompt(translation)
110
  return tamil_text, translation, img
111
  except Exception as e:
112
  return tamil_text, translation, f"An error occurred during image generation: {str(e)}"
 
114
  return tamil_text, translation, None
115
 
116
 
117
+ def chatbox(prompt):
118
+ try:
119
+ chat_completion = client.chat.completions.create(
120
+ messages=[{"role": "user", "content": prompt}],
121
+ model="llama-3.2-90b-text-preview"
122
+ )
123
+ chatbot_response = chat_completion.choices[0].message.content
 
 
124
 
125
+ except Exception as e:
126
+ return f"An error occurred during chatbot interaction: {str(e)}", None
 
 
127
 
128
+ try:
129
+ img=generate_image_from_prompt(prompt)
130
+ except Exception as e:
131
+ return chatbot_response, None
132
 
133
+ return chatbot_response, img
 
134
 
135
  # Function for direct prompt to image generation
 
 
136
 
 
 
 
 
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
  # Assuming your 'process_audio' and 'generate_image_from_prompt' functions are defined elsewhere
 
188
  # Bind the correct function that returns an image
189
  btn_image.click(fn=generate_image_from_prompt, inputs=prompt_input, outputs=image_output)
190
 
191
+ #third tab: Direct prompt
192
+ with gr.Tab("Chatbot - image generation"):
193
+ gr.Markdown("<h2 style='text-align: center; color:black;'>Input a prompt and generate an image</h2>")
194
+
195
+ prompt_input=gr.Textbox(label="Enter Prompt", placeholder="Enter the scene description here...", lines=2)
196
+ # Image output
197
+ output = [
198
+ gr.Textbox(label="Chatbot - response"),
199
+ gr.Image(label="Generated Image") # Expecting an image output
200
+ ]
201
+ # Expecting an image output
202
+ # chatbox_output =
203
+ btn_image = gr.Button("Chatbot Response Generation", elem_classes="btn-red")
204
+ # Bind the correct function that returns an image
205
+ btn_image.click(fn=chatbox, inputs=prompt_input, outputs=output)
206
+
207
  # Launch the interface
208
  iface.launch(server_name="0.0.0.0")