Singularity666 commited on
Commit
7521b47
·
1 Parent(s): d9de62d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +55 -105
main.py CHANGED
@@ -1,132 +1,82 @@
1
  import streamlit as st
2
- import os
3
  import requests
4
- import base64
5
  from PIL import Image
6
  from io import BytesIO
7
- import replicate
 
8
  from stability_sdk import client
9
  import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
 
10
 
11
- Image.MAX_IMAGE_PIXELS = None
12
-
13
-
14
- # Configure your API keys here
15
- CLIPDROP_API_KEY = '1143a102dbe21628248d4bb992b391a49dc058c584181ea72e17c2ccd49be9ca69ccf4a2b97fc82c89ff1029578abbea'
16
- STABLE_DIFFUSION_API_KEY = 'sk-GBmsWR78MmCSAWGkkC1CFgWgE6GPgV00pNLJlxlyZWyT3QQO'
17
- ESRGAN_API_KEY = 'sk-GBmsWR78MmCSAWGkkC1CFgWgE6GPgV00pNLJlxlyZWyT3QQO'
18
 
19
- # Set up environment variable for Replicate API Token
20
- os.environ['REPLICATE_API_TOKEN'] = 'r8_Tm3LQMS81QaGXzzdGVRyUCOQ3cuNd1i1sJlqp' # Replace with your actual API token
21
 
22
- def generate_image_from_text(prompt):
23
- r = requests.post('https://clipdrop-api.co/text-to-image/v1',
24
- files = {
25
- 'prompt': (None, prompt, 'text/plain')
26
- },
27
- headers = { 'x-api-key': CLIPDROP_API_KEY }
28
- )
29
-
30
- if r.ok:
31
- return r.content
 
 
 
 
 
32
  else:
33
- r.raise_for_status()
34
-
35
- def upscale_image_esrgan(image_bytes):
36
- # Set up environment variables
37
- os.environ['ESRGAN_API_KEY'] = ESRGAN_API_KEY
38
-
39
- # Set up the connection to the API
40
- stability_api = client.StabilityInference(
41
- key=os.environ['ESRGAN_API_KEY'],
42
- upscale_engine="esrgan-v1-x2plus",
43
- verbose=True,
44
- )
45
 
46
- # Open the image from bytes
47
- img = Image.open(BytesIO(image_bytes))
48
-
49
- # Call the upscale API
50
  answers = stability_api.upscale(init_image=img)
51
 
52
- # Process the response
53
- upscaled_img_bytes = None
54
  for resp in answers:
55
  for artifact in resp.artifacts:
 
 
 
 
56
  if artifact.type == generation.ARTIFACT_IMAGE:
57
- upscaled_img = Image.open(BytesIO(artifact.binary))
58
- upscaled_img_bytes = BytesIO()
59
- upscaled_img.save(upscaled_img_bytes, format='PNG')
60
- upscaled_img_bytes = upscaled_img_bytes.getvalue()
61
-
62
- return upscaled_img_bytes
63
-
64
- def further_upscale_image(image_bytes):
65
- # Ensure environment variable is set correctly
66
- print("Replicate API token: ", os.environ['REPLICATE_API_TOKEN'])
67
 
68
- # Save the image bytes to a temporary file
69
- temp_file_name = "temp.png"
70
- with open(temp_file_name, 'wb') as temp_file:
71
- temp_file.write(image_bytes)
72
-
73
- # Run the GFPGAN model
74
- try:
75
- print("Running GFPGAN model...")
76
  output = replicate.run(
77
  "tencentarc/gfpgan:9283608cc6b7be6b65a8e44983db012355fde4132009bf99d976b2f0896856a3",
78
- input={"img": open(temp_file_name, "rb"), "version": "v1.4", "scale": 16}
79
  )
80
- print("Model output: ", output)
81
- except Exception as e:
82
- print("Error running GFPGAN model: ", e)
83
- raise e
84
-
85
- # Get the image data from the output URI
86
- try:
87
- print("Fetching image data from output URI...")
88
  response = requests.get(output)
89
- except Exception as e:
90
- print("Error fetching image data from output URI: ", e)
91
- raise e
92
 
93
- # Open and save the image
94
- try:
95
- print("Saving upscaled image...")
96
- img = Image.open(BytesIO(response.content))
97
- output_file = "upscaled.png"
98
- img.save(output_file) # Save the upscaled image
99
- except Exception as e:
100
- print("Error saving upscaled image: ", e)
101
- raise e
102
 
103
- # Create a function to make download link
104
- def create_download_link(file, filename):
105
- with open(file, 'rb') as f:
106
- bytes = f.read()
107
- b64 = base64.b64encode(bytes).decode()
108
- href = f'<a href="data:file/octet-stream;base64,{b64}" download="{filename}">Download File</a>'
109
- return href
110
 
111
- return create_download_link(output_file, "upscaled_image.png")
 
 
112
 
113
- def main():
114
- st.title("Image Generation and Upscaling")
115
- st.write("Enter a text prompt and an image will be generated and upscaled.")
116
 
117
- prompt = st.text_input("Enter a textual prompt to generate an image...")
118
-
119
- if prompt:
120
- st.success("Generating image from text prompt...")
121
- image_bytes = generate_image_from_text(prompt)
122
-
123
- st.success("Upscaling image with ESRGAN...")
124
- upscaled_image_bytes = upscale_image_esrgan(image_bytes)
125
-
126
- st.success("Further upscaling image with GFPGAN...")
127
- download_link = further_upscale_image(upscaled_image_bytes)
128
-
129
- st.markdown(download_link, unsafe_allow_html=True)
130
 
131
- if __name__ == "__main__":
132
- main()
 
 
 
 
1
  import streamlit as st
 
2
  import requests
 
3
  from PIL import Image
4
  from io import BytesIO
5
+ import getpass, os
6
+ import warnings
7
  from stability_sdk import client
8
  import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
9
+ import replicate
10
 
11
+ # API keys
12
+ api_key = 'YOUR_API_KEY'
13
+ os.environ['STABILITY_KEY'] = 'YOUR_API_KEY'
14
+ os.environ['REPLICATE_API_TOKEN'] = 'REPLICATE_API_TOKEN' # Replace with your actual API token
 
 
 
15
 
16
+ # Increase the pixel limit
17
+ Image.MAX_IMAGE_PIXELS = None
18
 
19
+ # Establish connection to Stability API
20
+ stability_api = client.StabilityInference(
21
+ key=os.environ['STABILITY_KEY'],
22
+ upscale_engine="esrgan-v1-x2plus",
23
+ verbose=True,
24
+ )
25
+
26
+ # ClipDrop API function
27
+ def generate_image(prompt):
28
+ headers = {'x-api-key': api_key}
29
+ body_params = {'prompt': (None, prompt, 'text/plain')}
30
+ response = requests.post('https://clipdrop-api.co/text-to-image/v1', files=body_params, headers=headers)
31
+
32
+ if response.status_code == 200:
33
+ return Image.open(BytesIO(response.content))
34
  else:
35
+ st.write(f"Request failed with status code {response.status_code}")
36
+ return None
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # Stability API function
39
+ def upscale_image_stability(img):
 
 
40
  answers = stability_api.upscale(init_image=img)
41
 
 
 
42
  for resp in answers:
43
  for artifact in resp.artifacts:
44
+ if artifact.finish_reason == generation.FILTER:
45
+ warnings.warn(
46
+ "Your request activated the API's safety filters and could not be processed."
47
+ "Please submit a different image and try again.")
48
  if artifact.type == generation.ARTIFACT_IMAGE:
49
+ return Image.open(io.BytesIO(artifact.binary))
 
 
 
 
 
 
 
 
 
50
 
51
+ # GFPGAN function
52
+ def upscale_image_gfpgan(image_path):
53
+ with open(image_path, "rb") as img_file:
 
 
 
 
 
54
  output = replicate.run(
55
  "tencentarc/gfpgan:9283608cc6b7be6b65a8e44983db012355fde4132009bf99d976b2f0896856a3",
56
+ input={"img": img_file, "version": "v1.4", "scale": 16}
57
  )
 
 
 
 
 
 
 
 
58
  response = requests.get(output)
59
+ return Image.open(BytesIO(response.content))
 
 
60
 
61
+ # Streamlit UI
62
+ st.title("Image Generator and Upscaler")
 
 
 
 
 
 
 
63
 
64
+ prompt = st.text_input("Enter a prompt for the image generation")
 
 
 
 
 
 
65
 
66
+ if st.button("Generate and Upscale"):
67
+ if prompt:
68
+ img1 = generate_image(prompt)
69
 
70
+ if img1:
71
+ st.image(img1, caption="Generated Image", use_column_width=True)
72
+ img1.save('generated_image.png')
73
 
74
+ img2 = upscale_image_stability(img1)
75
+ st.image(img2, caption="Upscaled Image (Stability API)", use_column_width=True)
76
+ img2.save('upscaled_image_stability.png')
 
 
 
 
 
 
 
 
 
 
77
 
78
+ img3 = upscale_image_gfpgan('upscaled_image_stability.png')
79
+ st.image(img3, caption="Upscaled Image (GFPGAN)", use_column_width=True)
80
+ img3.save('upscaled_image_gfpgan.png')
81
+ else:
82
+ st.write("Please enter a prompt")