Singularity666's picture
Update main.py
e02283a
raw
history blame
4.49 kB
import streamlit as st
import os
import requests
import base64
from PIL import Image
from io import BytesIO
import replicate
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
# Remove Pillow's pixel-count limit for every Image.open call in this process.
# NOTE(review): per the Pillow docs this limit is the decompression-bomb guard;
# presumably it is lifted so that very large upscaled images can be opened --
# confirm that all image sources handled here are trusted.
Image.MAX_IMAGE_PIXELS = None
# API credentials are read from the environment instead of being hard-coded:
# secrets committed to source control are exposed to everyone who can read the
# repository and must be treated as compromised (rotate any key that was ever
# committed here). Export these variables before launching the app:
#   CLIPDROP_API_KEY, STABLE_DIFFUSION_API_KEY, ESRGAN_API_KEY,
#   REPLICATE_API_TOKEN
CLIPDROP_API_KEY = os.environ.get('CLIPDROP_API_KEY', '')
STABLE_DIFFUSION_API_KEY = os.environ.get('STABLE_DIFFUSION_API_KEY', '')
# The ESRGAN upscaler previously shared the Stable Diffusion key, so fall back
# to it when no dedicated ESRGAN key is provided.
ESRGAN_API_KEY = os.environ.get('ESRGAN_API_KEY', STABLE_DIFFUSION_API_KEY)
# The Replicate token is consumed from the environment rather than passed
# explicitly. setdefault keeps the variable defined (so later lookups such as
# os.environ['REPLICATE_API_TOKEN'] do not raise KeyError) without ever
# overwriting a real token supplied by the deployment.
os.environ.setdefault('REPLICATE_API_TOKEN', '')
def generate_image_from_text(prompt, timeout=120):
    """Generate an image from a text prompt via the ClipDrop text-to-image API.

    Args:
        prompt: Text description, sent as the ``prompt`` multipart form field.
        timeout: Seconds handed to ``requests.post`` as its timeout. Without
            one, a stalled connection would block the app indefinitely.

    Returns:
        The raw body (``bytes``) of the successful response -- presumably the
        encoded image; callers feed it straight to ``PIL.Image.open``.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.post(
        'https://clipdrop-api.co/text-to-image/v1',
        # A (None, value, content_type) tuple makes requests send a plain
        # multipart form field rather than a file upload.
        files={'prompt': (None, prompt, 'text/plain')},
        headers={'x-api-key': CLIPDROP_API_KEY},
        timeout=timeout,
    )
    # raise_for_status() is a no-op for successful responses, so one call
    # covers both the success and the failure path.
    response.raise_for_status()
    return response.content
def upscale_image_esrgan(image_bytes):
    """Upscale an image with the Stability ``esrgan-v1-x2plus`` engine.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        The upscaled image re-encoded as PNG ``bytes``. If the API returns
        several image artifacts, the last one wins.

    Raises:
        RuntimeError: If the API response contains no image artifact.
    """
    # Hand the key to the client directly; copying it into os.environ first
    # added nothing and needlessly exposed the secret to child processes.
    stability_api = client.StabilityInference(
        key=ESRGAN_API_KEY,
        upscale_engine="esrgan-v1-x2plus",
        verbose=True,
    )

    source_image = Image.open(BytesIO(image_bytes))
    answers = stability_api.upscale(init_image=source_image)

    upscaled_img_bytes = None
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.type != generation.ARTIFACT_IMAGE:
                continue
            # Decode and re-encode so the caller always receives PNG data,
            # whatever encoding the artifact used.
            png_buffer = BytesIO()
            Image.open(BytesIO(artifact.binary)).save(png_buffer, format='PNG')
            upscaled_img_bytes = png_buffer.getvalue()

    if upscaled_img_bytes is None:
        # Fail here with a clear message instead of returning None and letting
        # the next pipeline stage crash with an unrelated TypeError.
        raise RuntimeError("ESRGAN upscaling returned no image artifact")
    return upscaled_img_bytes
def further_upscale_image(image_bytes):
    """Run the image through the GFPGAN model on Replicate and build a download link.

    The input is written to ``temp.png``, sent to the model, and the result is
    downloaded and saved as ``upscaled.png`` in the current working directory.

    Args:
        image_bytes: Encoded image data to upscale.

    Returns:
        An HTML ``<a>`` element (``str``) embedding the upscaled image as a
        base64 data URI, downloadable as ``upscaled_image.png``.

    Raises:
        requests.HTTPError: If downloading the model output fails.
        Exception: Any error from the model run, the download, or saving the
            image is logged to stdout and re-raised unchanged.
    """
    def create_download_link(file, filename):
        """Return an HTML download link embedding *file* as a base64 data URI."""
        with open(file, 'rb') as f:
            payload = f.read()
        b64 = base64.b64encode(payload).decode()
        return f'<a href="data:file/octet-stream;base64,{b64}" download="{filename}">Download File</a>'

    # Report only whether a token is configured -- never the token itself,
    # since stdout typically ends up in shared logs.
    print("Replicate API token configured: ", bool(os.environ.get('REPLICATE_API_TOKEN')))

    # NOTE(review): fixed file names in the working directory mean concurrent
    # sessions would overwrite each other's files -- consider tempfile.
    temp_file_name = "temp.png"
    with open(temp_file_name, 'wb') as temp_file:
        temp_file.write(image_bytes)

    # Run the GFPGAN model
    try:
        print("Running GFPGAN model...")
        # The context manager closes the upload handle even if the call fails.
        with open(temp_file_name, "rb") as model_input:
            output = replicate.run(
                "tencentarc/gfpgan:9283608cc6b7be6b65a8e44983db012355fde4132009bf99d976b2f0896856a3",
                input={"img": model_input, "version": "v1.4", "scale": 16},
            )
        print("Model output: ", output)
    except Exception as e:
        print("Error running GFPGAN model: ", e)
        raise

    # Get the image data from the output URI
    try:
        print("Fetching image data from output URI...")
        response = requests.get(output, timeout=120)
        # Without this check an HTTP error page would be handed to PIL below
        # and surface as a confusing "cannot identify image file" error.
        response.raise_for_status()
    except Exception as e:
        print("Error fetching image data from output URI: ", e)
        raise

    # Open and save the image
    try:
        print("Saving upscaled image...")
        img = Image.open(BytesIO(response.content))
        output_file = "upscaled.png"
        img.save(output_file)
    except Exception as e:
        print("Error saving upscaled image: ", e)
        raise

    return create_download_link(output_file, "upscaled_image.png")
def main():
    """Render the Streamlit page and run the generate -> upscale pipeline.

    Nothing happens until the user submits a non-empty prompt. Then the image
    is generated, upscaled with ESRGAN, upscaled again with GFPGAN, and a
    download link for the final result is shown.
    """
    st.title("Image Generation and Upscaling")
    st.write("Enter a text prompt and an image will be generated and upscaled.")

    prompt = st.text_input("Enter a textual prompt to generate an image...")
    if not prompt:
        # Empty input: wait for the user to type something.
        return

    st.success("Generating image from text prompt...")
    generated = generate_image_from_text(prompt)

    st.success("Upscaling image with ESRGAN...")
    esrgan_result = upscale_image_esrgan(generated)

    st.success("Further upscaling image with GFPGAN...")
    link_html = further_upscale_image(esrgan_result)

    # The link is raw HTML produced by further_upscale_image, so Streamlit
    # must be told to render it unescaped.
    st.markdown(link_html, unsafe_allow_html=True)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()