# File size: 3,629 Bytes
# 87a6c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import tempfile
import warnings
from io import BytesIO

import requests
import streamlit as st
from PIL import Image

# API credentials are read from the process environment (exported in the
# shell or injected by the hosting platform's secret store). They are
# deliberately NOT hard-coded here: keys committed to source control must be
# treated as leaked, so the keys that previously lived in this file should be
# revoked and re-issued. (The old code also set REPLICATE_API_TOKEN to the
# same value as CLIPDROP_API_KEY, which looks like a copy-paste error.)
API_KEY_ENV_VARS = ("CLIPDROP_API_KEY", "STABILITY_API_KEY", "REPLICATE_API_TOKEN")

# Names from API_KEY_ENV_VARS that are unset or empty at import time.
MISSING_API_KEYS = [name for name in API_KEY_ENV_VARS if not os.environ.get(name)]
if MISSING_API_KEYS:
    # Surface misconfiguration once at startup (server log) instead of only
    # at the first failing API call.
    warnings.warn(
        "Missing API credentials in environment: " + ", ".join(MISSING_API_KEYS),
        stacklevel=1,
    )

# Importing Replicate and Stability SDK libraries
import replicate
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

def upscale_image(image_path):
    """Restore/upscale the image at *image_path* with the GFPGAN model on Replicate.

    Args:
        image_path: Path to a local image file; it is opened in binary mode
            and passed to ``replicate.run`` as the model's ``img`` input.

    Returns:
        The processed image as a PIL ``Image``, decoded from the downloaded
        model output.

    Raises:
        requests.HTTPError: if downloading the model output returns an HTTP
            error status.
        requests.Timeout: if the download stalls past the timeout.
        Any exception raised by ``replicate.run`` itself.
    """
    # The file only needs to be open while replicate.run consumes it; close it
    # before the (potentially slow) download below.
    with open(image_path, "rb") as img_file:
        output = replicate.run(
            "tencentarc/gfpgan:9283608cc6b7be6b65a8e44983db012355fde4132009bf99d976b2f0896856a3",
            input={"img": img_file, "version": "v1.4", "scale": 16},
        )

    # Older replicate clients return the result URL as a plain string; newer
    # ones return a FileOutput object exposing it as ``.url``. Accept both.
    output_url = getattr(output, "url", output)

    # Bounded wait so a stalled connection cannot hang the app indefinitely.
    response = requests.get(output_url, timeout=60)
    # Fail loudly on HTTP errors instead of handing an error page to PIL,
    # which would otherwise surface as a confusing "cannot identify image".
    response.raise_for_status()
    return Image.open(BytesIO(response.content))

def generate_and_upscale_image(prompt):
    """Generate an image from *prompt* with ClipDrop, then upscale it with Stability.

    Args:
        prompt: Text prompt sent to the ClipDrop text-to-image endpoint.

    Returns:
        The upscaled image as a PIL ``Image``, or ``None`` if any step failed.
        Every failure path reports a message to the user via ``st.error``.
    """
    # StabilityInference lives in stability_sdk.client; the original code
    # looked it up on the ``replicate`` module, where it does not exist.
    from stability_sdk import client as stability_client

    clipdrop_key = os.environ.get('CLIPDROP_API_KEY')
    stability_key = os.environ.get('STABILITY_API_KEY')
    if not clipdrop_key or not stability_key:
        st.error('Missing CLIPDROP_API_KEY or STABILITY_API_KEY environment variable.')
        return None

    # ClipDrop expects the prompt as a multipart/form-data field; the
    # (None, value, content_type) tuple makes requests send a plain form field
    # (no filename) rather than a urlencoded body.
    url = 'https://clipdrop-api.co/text-to-image/v1'
    try:
        response = requests.post(
            url,
            headers={'x-api-key': clipdrop_key},
            files={'prompt': (None, prompt, 'text/plain')},
            timeout=120,  # generous bound so a stalled request can't hang the app
        )
    except requests.RequestException as exc:
        st.error(f'Failed to generate image from text prompt: {exc}')
        return None

    if response.status_code != 200:
        st.error('Failed to generate image from text prompt.')
        return None

    img = Image.open(BytesIO(response.content))

    upscale_api = stability_client.StabilityInference(
        key=stability_key,
        upscale_engine="stable-diffusion-x4-latent-upscaler",
    )

    # upscale() yields responses lazily, so testing the return value for
    # truthiness (as the old code did) was always true. Instead, return the
    # first image artifact found and report failure if none arrives.
    for resp in upscale_api.upscale(init_image=img):
        for artifact in resp.artifacts:
            if artifact.type == generation.ARTIFACT_IMAGE:
                return Image.open(BytesIO(artifact.binary))

    st.error('Failed to upscale the image.')
    return None

def main():
    """Render the Streamlit UI.

    If an image is uploaded, offer to restore/upscale it with GFPGAN.
    Otherwise, if a non-blank text prompt is entered, offer to generate an
    image from it and upscale the result.
    """
    st.title("Image Upscaling")
    st.write("Upload an image or enter a text prompt to generate and upscale an image.")

    uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
    text_prompt = st.text_input("Enter a text prompt:", max_chars=1000)

    if uploaded_file is not None:
        st.success("Uploaded image successfully!")

        if st.button("Upscale Image"):
            # Write the upload into a private temporary directory instead of
            # a fixed "temp_img.png" in the working directory. Concurrent
            # sessions would otherwise overwrite each other's file, and the
            # file would be left behind. Keep the uploaded file's extension
            # so a JPEG is not mislabelled as .png.
            suffix = os.path.splitext(uploaded_file.name)[1] or ".png"
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_path = os.path.join(tmp_dir, "upload" + suffix)
                with open(tmp_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                # Upscale the uploaded image using GFPGAN. The temp file only
                # has to exist for the duration of this call.
                img = upscale_image(tmp_path)
            st.image(img, caption='Upscaled Image (GFPGAN)', use_column_width=True)

    # Ignore whitespace-only prompts rather than sending them to the API.
    elif text_prompt.strip():
        if st.button("Generate and Upscale"):
            # Generate and upscale an image from the text prompt; None signals
            # a failure that has already been reported to the user.
            img = generate_and_upscale_image(text_prompt)
            if img:
                st.image(img, caption='Generated and Upscaled Image', use_column_width=True)
    
# Entry point: build the UI only when this file is executed as a script
# (presumably via `streamlit run <file>`), not when it is imported.
if __name__ == "__main__":
    main()