# aiinference222 / test_it.py
# Author: mart9992
# Commit message: d
# Commit: 0893e31
from io import BytesIO
from PIL import Image
import requests
import time
import os
import boto3
# Connection settings for the S3-compatible bucket that hosts uploaded input
# images. Non-secret settings can be overridden from the environment and keep
# their previous values as defaults. Credentials are read ONLY from the
# environment so they are never committed to source control.
# NOTE(review): an access key pair was previously committed in this file;
# it should be revoked/rotated.
S3_REGION = os.environ.get("S3_REGION", "fra1")
S3_ACCESS_ID = os.environ.get("S3_ACCESS_ID")
S3_ACCESS_SECRET = os.environ.get("S3_ACCESS_SECRET")
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "https://s3.solarcom.ch")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "pissnelke")

# When both credentials are None, boto3 falls back to its default credential
# chain (e.g. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY, shared config files).
s3_session = boto3.session.Session()
s3 = s3_session.client(
    service_name="s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_ID,
    aws_secret_access_key=S3_ACCESS_SECRET,
    endpoint_url=S3_ENDPOINT_URL,
)
# Per-request timeout (seconds) for every HTTP call, so that a stalled
# connection cannot hang the script indefinitely.
_HTTP_TIMEOUT = 60


def _upload_input_image(input_pil):
    """Upload ``input_pil`` as PNG to the configured bucket and return its public URL."""
    s3filepath = f"target/{os.urandom(20).hex()}.png"
    input_buffer = BytesIO()
    # Save as PNG so the bytes match the ".png" key. PNG also accepts
    # RGBA/P-mode images, which PIL's JPEG writer rejects.
    input_pil.save(input_buffer, "PNG")
    input_buffer.seek(0)
    s3.put_object(
        Bucket=S3_BUCKET_NAME,
        Key=s3filepath,
        Body=input_buffer,
        ContentType="image/png",
    )
    # NOTE(review): Replicate fetches this URL without credentials, so this
    # presumably relies on a bucket policy allowing public reads. Confirm.
    return f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3filepath}"


def _wait_for_prediction(prediction_url, headers, poll_interval=1.0):
    """Poll a Replicate prediction until it succeeds and return its final JSON.

    Raises:
        requests.HTTPError: If a status request returns an error code.
        RuntimeError: If the prediction ends as ``failed`` or ``canceled``.
    """
    while True:
        status_response = requests.get(prediction_url, headers=headers, timeout=_HTTP_TIMEOUT)
        status_response.raise_for_status()
        prediction_data = status_response.json()
        status = prediction_data.get("status")
        if status == "succeeded":
            return prediction_data
        if status == "failed":
            raise RuntimeError(prediction_data.get("error"))
        if status == "canceled":
            # Terminal state. Without this check the loop would poll forever.
            raise RuntimeError(f"Replicate prediction was canceled: {prediction_url}")
        # Avoid spamming the server between status checks.
        time.sleep(poll_interval)


def get_mask_replicate(input_pil, positive_prompt, expand_by=0, negative_prompt="", replicate_api_key=""):
    """Run a Replicate mask model on ``input_pil`` and return the chosen output image.

    The image is uploaded to the module-level S3 bucket so that Replicate can
    fetch it by URL. A prediction is then created, and the function blocks
    until the prediction reaches a terminal state.

    Args:
        input_pil: PIL image to send to the model.
        positive_prompt: Sent to the model as ``mask_prompt``.
        expand_by: Sent to the model as ``adjustment_factor``.
        negative_prompt: Sent to the model as ``negative_mask_prompt``.
        replicate_api_key: Replicate API token used in the Authorization header.

    Returns:
        The PIL image downloaded from ``output[2]`` of the finished prediction.
        NOTE(review): which output index 2 represents depends on the model
        version. Verify against the model's documentation.

    Raises:
        requests.HTTPError: If creating, polling, or downloading fails.
        RuntimeError: If the prediction ends as ``failed`` or ``canceled``.
    """
    api_endpoint = "https://api.replicate.com/v1/predictions"
    headers = {
        "Authorization": f"Token {replicate_api_key}"
    }

    image_url = _upload_input_image(input_pil)

    data = {
        "version": "ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c",
        "input": {
            "image": image_url,
            "mask_prompt": positive_prompt,
            "negative_mask_prompt": negative_prompt,
            "adjustment_factor": expand_by,
        }
    }

    response = requests.post(api_endpoint, json=data, headers=headers, timeout=_HTTP_TIMEOUT)
    # Fail with the HTTP error (e.g. 401) instead of a confusing KeyError on 'id'.
    response.raise_for_status()
    response_data = response.json()
    print(response_data)

    prediction_data = _wait_for_prediction(f"{api_endpoint}/{response_data['id']}", headers)
    output_link = prediction_data['output'][2]

    output_response = requests.get(output_link, timeout=_HTTP_TIMEOUT)
    output_response.raise_for_status()
    # The BytesIO stays referenced by the image, so lazy decoding still works.
    return Image.open(BytesIO(output_response.content))
if __name__ == "__main__":
    # Manual smoke test: run get_mask_replicate on sport.jpg and save the
    # result to hallo.png. The guard keeps these network calls from running
    # on import (e.g. when pytest collects this test_*.py file).
    # The Replicate token comes from the environment, not from source code.
    # NOTE(review): a token was previously committed here; revoke/rotate it.
    replicate_token = os.environ.get("REPLICATE_API_TOKEN")
    if not replicate_token:
        raise SystemExit("Set the REPLICATE_API_TOKEN environment variable.")
    with Image.open("sport.jpg") as verrueckt_pil:
        x = get_mask_replicate(
            verrueckt_pil,
            "bra . blouse . skirt . dress",
            negative_prompt="face",
            expand_by=10,
            replicate_api_key=replicate_token,
        )
    x.save("hallo.png")
    print(x)