File size: 4,494 Bytes
3e7d217
33c27ec
 
3f00a5c
33c27ec
 
9184aae
30b12a2
 
9184aae
e02283a
 
 
9184aae
 
 
33ddfbe
33c27ec
9184aae
680ecae
9184aae
 
 
33c27ec
9184aae
 
 
 
 
 
 
 
 
 
bad15c0
3e7d217
bad15c0
3e7d217
 
30b12a2
bad15c0
 
30b12a2
 
3e7d217
 
30b12a2
3e7d217
 
30b12a2
3e7d217
 
30b12a2
 
 
 
 
 
 
 
3e7d217
30b12a2
33c27ec
9184aae
5c84876
 
 
 
 
 
 
 
3e7d217
5c84876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f00a5c
 
5c84876
 
 
 
3f00a5c
 
 
 
 
 
 
33c27ec
3f00a5c
5c84876
3e7d217
 
 
 
 
 
 
 
 
 
bad15c0
74c2727
3e7d217
 
3f00a5c
3e7d217
3f00a5c
33c27ec
 
3f00a5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import streamlit as st
import os
import requests
import base64
from PIL import Image
from io import BytesIO
import replicate
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# Pillow limits the pixel count of images it opens as a decompression-bomb
# guard; None disables that check. NOTE(review): presumably needed because the
# upscaled outputs are very large — confirm, since it also removes protection
# against maliciously oversized images.
Image.MAX_IMAGE_PIXELS = None


# API credentials are read from environment variables instead of being
# hard-coded: secrets committed to source control are readable by anyone with
# access to the repository. The keys previously embedded here should be
# treated as leaked and revoked/rotated with each provider.
CLIPDROP_API_KEY = os.environ.get('CLIPDROP_API_KEY', '')
STABLE_DIFFUSION_API_KEY = os.environ.get('STABLE_DIFFUSION_API_KEY', '')
# The original configuration used the same key for ESRGAN and Stable Diffusion,
# so fall back to that key when no dedicated ESRGAN key is exported.
ESRGAN_API_KEY = os.environ.get('ESRGAN_API_KEY', STABLE_DIFFUSION_API_KEY)

# The Replicate client reads REPLICATE_API_TOKEN from the environment.
# setdefault keeps any value the deployer exported and guarantees the variable
# exists for code that looks it up by key.
os.environ.setdefault('REPLICATE_API_TOKEN', '')

def generate_image_from_text(prompt, timeout=120):
    """Generate an image from a text prompt via the ClipDrop text-to-image API.

    Args:
        prompt: Text description of the image to generate.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` can block indefinitely.

    Returns:
        The raw response body returned by ClipDrop (the generated image bytes).

    Raises:
        requests.HTTPError: If ClipDrop answers with a 4xx/5xx status.
        requests.Timeout: If the request exceeds ``timeout``.
    """
    response = requests.post(
        'https://clipdrop-api.co/text-to-image/v1',
        # A (None, value, content_type) tuple sends a plain multipart form
        # field rather than a file upload.
        files={'prompt': (None, prompt, 'text/plain')},
        headers={'x-api-key': CLIPDROP_API_KEY},
        timeout=timeout,
    )
    # Equivalent to the old `if r.ok: ... else: r.raise_for_status()`, since
    # `Response.ok` is defined in terms of raise_for_status().
    response.raise_for_status()
    return response.content

def upscale_image_esrgan(image_bytes):
    """Upscale an image with Stability AI's ``esrgan-v1-x2plus`` engine.

    Args:
        image_bytes: Encoded image data in any format Pillow can open.

    Returns:
        PNG-encoded bytes of the upscaled image. If the response stream
        contains several image artifacts, the last one is returned.

    Raises:
        RuntimeError: If the API response contains no image artifact, instead
            of returning ``None`` and failing later in an unrelated place.
    """
    # Pass the key directly rather than exporting it through os.environ, which
    # would also expose it to any child process.
    stability_api = client.StabilityInference(
        key=ESRGAN_API_KEY,
        upscale_engine="esrgan-v1-x2plus",
        verbose=True,
    )

    source_img = Image.open(BytesIO(image_bytes))

    upscaled_png = None
    for resp in stability_api.upscale(init_image=source_img):
        for artifact in resp.artifacts:
            if artifact.type == generation.ARTIFACT_IMAGE:
                # Decode and re-encode so the result is always PNG.
                buffer = BytesIO()
                Image.open(BytesIO(artifact.binary)).save(buffer, format='PNG')
                upscaled_png = buffer.getvalue()

    if upscaled_png is None:
        raise RuntimeError("ESRGAN upscale returned no image artifact")
    return upscaled_png

def _run_gfpgan(image_bytes):
    """Send *image_bytes* to the GFPGAN model on Replicate and return its output."""
    import tempfile

    # A private temporary directory replaces the fixed "temp.png" in the CWD:
    # concurrent calls (e.g. several Streamlit sessions) would otherwise
    # overwrite each other's input, and the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, "input.png")
        with open(input_path, 'wb') as input_file:
            input_file.write(image_bytes)
        # Keep the upload handle inside `with` so it is closed once the call
        # returns (previously this file descriptor was never closed).
        with open(input_path, 'rb') as input_file:
            return replicate.run(
                "tencentarc/gfpgan:9283608cc6b7be6b65a8e44983db012355fde4132009bf99d976b2f0896856a3",
                input={"img": input_file, "version": "v1.4", "scale": 16},
            )


def _create_download_link(data, filename, mime_type='application/octet-stream'):
    """Return an HTML ``<a>`` tag that downloads *data* as *filename*.

    The payload is embedded as a base64 ``data:`` URI. *filename* is
    HTML-escaped because the result is rendered as raw HTML.
    """
    import html

    b64 = base64.b64encode(data).decode()
    safe_name = html.escape(filename, quote=True)
    return f'<a href="data:{mime_type};base64,{b64}" download="{safe_name}">Download File</a>'


def further_upscale_image(image_bytes, timeout=120):
    """Run GFPGAN on Replicate over *image_bytes* and return an HTML download link.

    Args:
        image_bytes: Encoded input image (written to a temporary ``.png`` file
            for upload).
        timeout: Seconds to wait when downloading the model's output image.

    Returns:
        An HTML ``<a>`` element whose ``data:`` URI holds the result
        re-encoded as PNG, offered for download as ``upscaled_image.png``.

    Raises:
        Any exception from the Replicate call, the download (including
        ``requests.HTTPError`` for a 4xx/5xx response) or image decoding is
        printed to stdout and re-raised unchanged.
    """
    # Report whether the token is configured without printing the secret.
    if not os.environ.get('REPLICATE_API_TOKEN'):
        print("Warning: REPLICATE_API_TOKEN is not set")

    try:
        print("Running GFPGAN model...")
        output = _run_gfpgan(image_bytes)
        print("Model output: ", output)
    except Exception as e:
        print("Error running GFPGAN model: ", e)
        raise

    # NOTE(review): assumes `output` is a URL string, as returned by the
    # replicate client version this was written against; newer clients may
    # return file-output objects or lists — confirm for the installed version.
    try:
        print("Fetching image data from output URI...")
        response = requests.get(output, timeout=timeout)
        # Surface HTTP errors here rather than as a confusing PIL decode error.
        response.raise_for_status()
    except Exception as e:
        print("Error fetching image data from output URI: ", e)
        raise

    # Re-encode to PNG in memory. The former fixed "upscaled.png" path in the
    # CWD could be overwritten by a concurrent call between write and read-back.
    try:
        print("Encoding upscaled image...")
        buffer = BytesIO()
        Image.open(BytesIO(response.content)).save(buffer, format='PNG')
    except Exception as e:
        print("Error encoding upscaled image: ", e)
        raise

    return _create_download_link(buffer.getvalue(), "upscaled_image.png", mime_type='image/png')

def main():
    """Render the Streamlit page: prompt -> generate -> ESRGAN -> GFPGAN -> link.

    Nothing past the text box is executed until the user has entered a
    non-empty prompt.
    """
    st.title("Image Generation and Upscaling")
    st.write("Enter a text prompt and an image will be generated and upscaled.")

    user_prompt = st.text_input("Enter a textual prompt to generate an image...")
    if not user_prompt:
        return

    st.success("Generating image from text prompt...")
    generated = generate_image_from_text(user_prompt)

    st.success("Upscaling image with ESRGAN...")
    esrgan_output = upscale_image_esrgan(generated)

    st.success("Further upscaling image with GFPGAN...")
    link_html = further_upscale_image(esrgan_output)

    # The helper returns raw HTML, so markdown must be allowed to render it.
    st.markdown(link_html, unsafe_allow_html=True)

# Run the app when the file is executed as a script. NOTE(review): `streamlit
# run` presumably also executes this module with __name__ == "__main__" —
# confirm for the Streamlit version in use.
if __name__ == "__main__":
    main()