|
from io import BytesIO |
|
from PIL import Image |
|
import requests |
|
import time |
|
import os |
|
import boto3 |
|
|
|
|
|
# Settings for the S3-compatible bucket used to hand input images to Replicate
# by URL. Non-secret values keep their previous defaults but may be overridden
# through environment variables of the same name.
S3_REGION = os.environ.get("S3_REGION", "fra1")

# Credentials must never be committed to source control, so read them from the
# environment. When unset they are None, and boto3 then falls back to its
# default credential resolution chain.
# NOTE(review): keys that were previously hard-coded here are exposed in the
# repository history and should be rotated.
S3_ACCESS_ID = os.environ.get("S3_ACCESS_ID")

S3_ACCESS_SECRET = os.environ.get("S3_ACCESS_SECRET")

S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "https://s3.solarcom.ch")

S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "pissnelke")
|
|
|
# A single boto3 session and S3 client, shared by every upload in this module
# and configured from the module's S3_* settings.
s3_session = boto3.session.Session()

s3 = s3_session.client(
    "s3",
    endpoint_url=S3_ENDPOINT_URL,
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_ID,
    aws_secret_access_key=S3_ACCESS_SECRET,
)
|
|
|
_REPLICATE_PREDICTIONS_URL = "https://api.replicate.com/v1/predictions"

# Pinned model version, unchanged from the original script.
_MASK_MODEL_VERSION = "ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c"

# Per-request network timeout in seconds, so a stalled connection cannot hang forever.
_HTTP_TIMEOUT = 30

# Replicate statuses after which a prediction will never produce output.
_FAILED_STATUSES = ("failed", "canceled")

# Image modes Pillow's PNG writer accepts; any other mode (e.g. CMYK) is converted to RGB.
_PNG_MODES = {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}


def _upload_input_image(input_pil, s3filepath):
    """Upload *input_pil* as a PNG to the S3 bucket under key *s3filepath*."""
    image = input_pil if input_pil.mode in _PNG_MODES else input_pil.convert("RGB")
    buffer = BytesIO()
    # Encode as PNG so the bytes match the ".png" key and the declared content type.
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.put_object(
        Bucket=S3_BUCKET_NAME,
        Key=s3filepath,
        Body=buffer,
        ContentType="image/png",
    )


def _delete_upload(s3filepath):
    """Best-effort removal of the temporary upload; a failure is logged, not raised."""
    import logging

    try:
        s3.delete_object(Bucket=S3_BUCKET_NAME, Key=s3filepath)
    except Exception:  # cleanup must not mask the prediction's own result or error
        logging.getLogger(__name__).warning(
            "could not delete temporary upload %s/%s", S3_BUCKET_NAME, s3filepath, exc_info=True
        )


def _wait_for_prediction(prediction_url, headers, poll_interval, timeout):
    """Poll *prediction_url* until the prediction succeeds and return its JSON.

    Raises RuntimeError when the prediction fails or is canceled, and
    TimeoutError when *timeout* seconds pass first (``None`` waits indefinitely).
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        response = requests.get(prediction_url, headers=headers, timeout=_HTTP_TIMEOUT)
        response.raise_for_status()
        prediction = response.json()
        status = prediction.get("status")

        if status == "succeeded":
            return prediction
        if status in _FAILED_STATUSES:
            raise RuntimeError(prediction.get("error") or f"Replicate prediction {status}")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"Replicate prediction still {status!r} after {timeout} s")

        time.sleep(poll_interval)


def get_mask_replicate(input_pil, positive_prompt, expand_by=0, negative_prompt="", replicate_api_key="",
                       *, output_index=2, poll_interval=1.0, timeout=None):
    """Request a segmentation mask for *input_pil* from a Replicate model.

    The image is uploaded to the S3 bucket so Replicate can fetch it by URL.
    A prediction is then started with the given prompts, and the call blocks
    until that prediction finishes. The temporary upload is removed afterwards.

    Args:
        input_pil: PIL image to send to the model.
        positive_prompt: text sent as the model's ``mask_prompt``.
        expand_by: value sent as the model's ``adjustment_factor``
            (presumably grows/shrinks the mask; confirm with the model docs).
        negative_prompt: text sent as the model's ``negative_mask_prompt``.
        replicate_api_key: Replicate API token; if empty, the
            ``REPLICATE_API_TOKEN`` environment variable is used.
        output_index: which entry of the prediction's ``output`` list to
            download. 2 is the index the original code used; presumably the
            mask image (confirm with the model docs).
        poll_interval: seconds to sleep between status checks.
        timeout: maximum seconds to wait for the prediction; ``None`` waits
            indefinitely, as the original code did.

    Returns:
        The downloaded output as a fully loaded PIL image.

    Raises:
        requests.HTTPError: an HTTP request returned an error status.
        RuntimeError: the prediction failed or was canceled.
        TimeoutError: *timeout* elapsed before the prediction finished.
    """
    api_key = replicate_api_key or os.environ.get("REPLICATE_API_TOKEN", "")
    headers = {"Authorization": f"Token {api_key}"}

    s3filepath = f"target/{os.urandom(20).hex()}.png"
    _upload_input_image(input_pil, s3filepath)
    try:
        data = {
            "version": _MASK_MODEL_VERSION,
            "input": {
                # NOTE(review): assumes the bucket is publicly readable at a
                # path-style URL; otherwise Replicate cannot fetch the image.
                "image": f"{S3_ENDPOINT_URL}/{S3_BUCKET_NAME}/{s3filepath}",
                "mask_prompt": positive_prompt,
                "negative_mask_prompt": negative_prompt,
                "adjustment_factor": expand_by,
            },
        }
        response = requests.post(_REPLICATE_PREDICTIONS_URL, json=data, headers=headers, timeout=_HTTP_TIMEOUT)
        # Fail with the real HTTP error (e.g. a bad token) rather than a KeyError on 'id'.
        response.raise_for_status()
        prediction_id = response.json()["id"]

        prediction = _wait_for_prediction(
            f"{_REPLICATE_PREDICTIONS_URL}/{prediction_id}", headers, poll_interval, timeout
        )
    finally:
        # The upload is only needed while Replicate reads it; don't leave user images behind.
        _delete_upload(s3filepath)

    output_link = prediction["output"][output_index]
    output_response = requests.get(output_link, timeout=_HTTP_TIMEOUT)
    output_response.raise_for_status()

    output_image = Image.open(BytesIO(output_response.content))
    # Decode now, so a corrupt download raises here rather than at first use.
    output_image.load()
    return output_image
|
|
|
if __name__ == "__main__":
    # Demo run, guarded so importing this module does not call external services.
    # The API token comes from the environment; the token that was previously
    # hard-coded here is exposed in history and should be revoked.
    api_token = os.environ.get("REPLICATE_API_TOKEN")
    if not api_token:
        raise SystemExit("Set REPLICATE_API_TOKEN to run this script.")

    # The context manager closes the input file once the upload has been made.
    with Image.open("sport.jpg") as input_image:
        mask = get_mask_replicate(
            input_image,
            "bra . blouse . skirt . dress",
            negative_prompt="face",
            expand_by=10,
            replicate_api_key=api_token,
        )

    mask.save("hallo.png")
    print(mask)