Spaces:
Running
Running
Commit
·
9e822e4
1
Parent(s):
d2eb85a
feat: working magic replace
Browse files- agent.py +15 -1
- src/agents/mask_generation_agent.py +54 -14
- src/hopter/client.py +50 -8
- src/services/generate_mask.py +7 -7
- src/services/image_uploader.py +161 -0
- src/services/openai_file_upload.py +13 -0
agent.py
CHANGED
@@ -17,6 +17,18 @@ from pydantic_ai.messages import (
|
|
17 |
)
|
18 |
import asyncio
|
19 |
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
class ChatMessage(TypedDict):
|
22 |
"""Format of messages sent to the browser/API."""
|
@@ -60,7 +72,9 @@ async def run_agent(user_input: str, image_b64: str):
|
|
60 |
]
|
61 |
deps = ImageEditDeps(
|
62 |
edit_instruction=user_input,
|
63 |
-
image_url=image_b64
|
|
|
|
|
64 |
)
|
65 |
async with mask_generation_agent.run_stream(
|
66 |
messages,
|
|
|
17 |
)
|
18 |
import asyncio
|
19 |
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
|
20 |
+
from src.hopter.client import Hopter, Environment
|
21 |
+
import os
|
22 |
+
from src.services.generate_mask import GenerateMaskService
|
23 |
+
from dotenv import load_dotenv
|
24 |
+
|
25 |
+
load_dotenv()
|
26 |
+
|
27 |
+
hopter = Hopter(
|
28 |
+
api_key=os.getenv("HOPTER_API_KEY"),
|
29 |
+
environment=Environment.STAGING
|
30 |
+
)
|
31 |
+
mask_service = GenerateMaskService(hopter=hopter)
|
32 |
|
33 |
class ChatMessage(TypedDict):
|
34 |
"""Format of messages sent to the browser/API."""
|
|
|
72 |
]
|
73 |
deps = ImageEditDeps(
|
74 |
edit_instruction=user_input,
|
75 |
+
image_url=image_b64,
|
76 |
+
hopter_client=hopter,
|
77 |
+
mask_service=mask_service
|
78 |
)
|
79 |
async with mask_generation_agent.run_stream(
|
80 |
messages,
|
src/agents/mask_generation_agent.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
from pydantic_ai import Agent, RunContext
|
2 |
-
from pydantic_ai.settings import ModelSettings
|
3 |
from pydantic_ai.models.openai import OpenAIModel
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
@@ -8,6 +7,7 @@ import base64
|
|
8 |
from dataclasses import dataclass
|
9 |
import logfire
|
10 |
from src.services.generate_mask import GenerateMaskService
|
|
|
11 |
|
12 |
load_dotenv()
|
13 |
|
@@ -23,6 +23,8 @@ if the edit instruction involved modifying parts of the image, please generate a
|
|
23 |
class ImageEditDeps:
|
24 |
edit_instruction: str
|
25 |
image_url: str
|
|
|
|
|
26 |
|
27 |
model = OpenAIModel(
|
28 |
"gpt-4o",
|
@@ -33,26 +35,53 @@ model = OpenAIModel(
|
|
33 |
class MaskGenerationResult:
|
34 |
mask_image_base64: str
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
mask_generation_agent = Agent(
|
37 |
model,
|
38 |
system_prompt=system_prompt,
|
39 |
deps_type=ImageEditDeps
|
40 |
)
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
@mask_generation_agent.tool
|
43 |
-
async def
|
44 |
"""
|
45 |
-
|
46 |
"""
|
47 |
-
print("Invoking
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
56 |
|
57 |
async def main():
|
58 |
image_file_path = "./assets/lakeview.jpg"
|
@@ -75,12 +104,23 @@ async def main():
|
|
75 |
}
|
76 |
]
|
77 |
|
|
|
|
|
|
|
|
|
|
|
78 |
deps = ImageEditDeps(
|
79 |
edit_instruction=prompt,
|
80 |
-
image_url=image_url
|
|
|
|
|
81 |
)
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
|
85 |
|
86 |
if __name__ == "__main__":
|
|
|
1 |
from pydantic_ai import Agent, RunContext
|
|
|
2 |
from pydantic_ai.models.openai import OpenAIModel
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
|
|
7 |
from dataclasses import dataclass
|
8 |
import logfire
|
9 |
from src.services.generate_mask import GenerateMaskService
|
10 |
+
from src.hopter.client import Hopter, Environment, MagicReplaceInput
|
11 |
|
12 |
load_dotenv()
|
13 |
|
|
|
23 |
class ImageEditDeps:
|
24 |
edit_instruction: str
|
25 |
image_url: str
|
26 |
+
hopter_client: Hopter
|
27 |
+
mask_service: GenerateMaskService
|
28 |
|
29 |
model = OpenAIModel(
|
30 |
"gpt-4o",
|
|
|
35 |
class MaskGenerationResult:
|
36 |
mask_image_base64: str
|
37 |
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class EditImageResult:
|
41 |
+
edited_image_base64: str
|
42 |
+
|
43 |
mask_generation_agent = Agent(
|
44 |
model,
|
45 |
system_prompt=system_prompt,
|
46 |
deps_type=ImageEditDeps
|
47 |
)
|
48 |
|
49 |
+
# @mask_generation_agent.tool
|
50 |
+
# async def generate_mask(ctx: RunContext[ImageEditDeps]) -> MaskGenerationResult:
|
51 |
+
# """
|
52 |
+
# Generate a mask for the image editing instruction.
|
53 |
+
# """
|
54 |
+
# print("Invoking generate_mask tool")
|
55 |
+
# service = GenerateMaskService()
|
56 |
+
# mask_instruction = await service.get_mask_generation_instruction(ctx.deps.edit_instruction, ctx.deps.image_url)
|
57 |
+
# response = mask_instruction.model_dump_json(indent=4)
|
58 |
+
# print(f"generate_mask tool response: {response}")
|
59 |
+
|
60 |
+
# mask = await service.generate_mask(mask_instruction, ctx.deps.image_url)
|
61 |
+
# print("Exiting generate_mask tool")
|
62 |
+
# return MaskGenerationResult(mask_image_base64=mask)
|
63 |
+
|
64 |
@mask_generation_agent.tool
|
65 |
+
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
66 |
"""
|
67 |
+
Edit an object in the image.
|
68 |
"""
|
69 |
+
print("Invoking edit_object tool")
|
70 |
+
edit_instruction = ctx.deps.edit_instruction
|
71 |
+
image_url = ctx.deps.image_url
|
72 |
+
mask_service = ctx.deps.mask_service
|
73 |
+
hopter_client = ctx.deps.hopter_client
|
74 |
+
|
75 |
+
# Generate mask
|
76 |
+
print("Generating mask")
|
77 |
+
mask_instruction = await mask_service.get_mask_generation_instruction(edit_instruction, image_url)
|
78 |
+
mask = await mask_service.generate_mask(mask_instruction, image_url)
|
79 |
|
80 |
+
# Magic replace
|
81 |
+
input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
|
82 |
+
result = await hopter_client.magic_replace(input)
|
83 |
+
print("Exiting edit_object tool: ", result)
|
84 |
+
return EditImageResult(edited_image_base64=result.base64_image)
|
85 |
|
86 |
async def main():
|
87 |
image_file_path = "./assets/lakeview.jpg"
|
|
|
104 |
}
|
105 |
]
|
106 |
|
107 |
+
# Initialize services
|
108 |
+
hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
|
109 |
+
mask_service = GenerateMaskService(hopter=hopter)
|
110 |
+
|
111 |
+
# Initialize dependencies
|
112 |
deps = ImageEditDeps(
|
113 |
edit_instruction=prompt,
|
114 |
+
image_url=image_url,
|
115 |
+
hopter_client=hopter,
|
116 |
+
mask_service=mask_service
|
117 |
)
|
118 |
+
async with mask_generation_agent.run_stream(
|
119 |
+
messages,
|
120 |
+
deps=deps
|
121 |
+
) as result:
|
122 |
+
async for message in result.stream():
|
123 |
+
print(message)
|
124 |
|
125 |
|
126 |
if __name__ == "__main__":
|
src/hopter/client.py
CHANGED
@@ -32,6 +32,14 @@ class RamGroundedSamResult(BaseModel):
|
|
32 |
confidence: float = Field(..., description="The confidence score of the mask.")
|
33 |
bbox: List[float] = Field(..., description="The bounding box of the mask in the format [x1, y1, x2, y2].")
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
class Hopter:
|
36 |
def __init__(
|
37 |
self,
|
@@ -64,7 +72,45 @@ class Hopter:
|
|
64 |
except Exception as exc:
|
65 |
print(f"An unexpected error occurred: {exc}")
|
66 |
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
async def main():
|
70 |
hopter = Hopter(
|
@@ -77,12 +123,8 @@ async def main():
|
|
77 |
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
78 |
image_url = f"data:image/jpeg;base64,{image_base64}"
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
)
|
84 |
-
mask = await hopter.generate_mask(input)
|
85 |
-
print(mask)
|
86 |
-
|
87 |
if __name__ == "__main__":
|
88 |
asyncio.run(main())
|
|
|
32 |
confidence: float = Field(..., description="The confidence score of the mask.")
|
33 |
bbox: List[float] = Field(..., description="The bounding box of the mask in the format [x1, y1, x2, y2].")
|
34 |
|
35 |
+
class MagicReplaceInput(BaseModel):
|
36 |
+
image: str = Field(..., description="The image in base64 format.")
|
37 |
+
mask: str = Field(..., description="The mask in base64 format.")
|
38 |
+
prompt: str = Field(..., description="The prompt for the magic replace.")
|
39 |
+
|
40 |
+
class MagicReplaceResult(BaseModel):
|
41 |
+
base64_image: str = Field(..., description="The edited image in base64 format.")
|
42 |
+
|
43 |
class Hopter:
|
44 |
def __init__(
|
45 |
self,
|
|
|
72 |
except Exception as exc:
|
73 |
print(f"An unexpected error occurred: {exc}")
|
74 |
|
75 |
+
async def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
|
76 |
+
print(f"Magic replacing with input: {input.prompt}")
|
77 |
+
try:
|
78 |
+
response = await self.client.post(
|
79 |
+
f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions",
|
80 |
+
headers={
|
81 |
+
"Authorization": f"Bearer {self.api_key}",
|
82 |
+
"Content-Type": "application/json"
|
83 |
+
},
|
84 |
+
json={
|
85 |
+
"input": input.model_dump()
|
86 |
+
}
|
87 |
+
)
|
88 |
+
print(response)
|
89 |
+
response.raise_for_status() # Raise an error for bad responses
|
90 |
+
instance = response.json().get("output")
|
91 |
+
print("Magic replaced.")
|
92 |
+
return MagicReplaceResult(**instance)
|
93 |
+
except httpx.HTTPStatusError as exc:
|
94 |
+
print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
|
95 |
+
except Exception as exc:
|
96 |
+
print(f"An unexpected error occurred: {exc}")
|
97 |
+
|
98 |
+
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
|
99 |
+
input = RamGroundedSamInput(
|
100 |
+
text_prompt="pole",
|
101 |
+
image_b64=image_url
|
102 |
+
)
|
103 |
+
mask = await hopter.generate_mask(input)
|
104 |
+
return mask.mask_b64
|
105 |
+
|
106 |
+
async def test_magic_replace(hopter: Hopter, image_url: str, mask: str, prompt: str) -> str:
|
107 |
+
input = MagicReplaceInput(
|
108 |
+
image=image_url,
|
109 |
+
mask=mask,
|
110 |
+
prompt=prompt
|
111 |
+
)
|
112 |
+
result = await hopter.magic_replace(input)
|
113 |
+
return result.base64_image
|
114 |
|
115 |
async def main():
|
116 |
hopter = Hopter(
|
|
|
123 |
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
124 |
image_url = f"data:image/jpeg;base64,{image_base64}"
|
125 |
|
126 |
+
mask = await test_generate_mask(hopter, image_url)
|
127 |
+
magic_replace_result = await test_magic_replace(hopter, image_url, mask, "remove the pole")
|
128 |
+
print(magic_replace_result)
|
|
|
|
|
|
|
|
|
129 |
if __name__ == "__main__":
|
130 |
asyncio.run(main())
|
src/services/generate_mask.py
CHANGED
@@ -6,6 +6,8 @@ import base64
|
|
6 |
import asyncio
|
7 |
from src.hopter.client import Hopter, RamGroundedSamInput, Environment
|
8 |
from src.models.generate_mask_instruction import GenerateMaskInstruction
|
|
|
|
|
9 |
load_dotenv()
|
10 |
|
11 |
system_prompt = """
|
@@ -38,13 +40,13 @@ Do not output 'sorry, xxx', even if it's a guess, directly output the answer you
|
|
38 |
"""
|
39 |
|
40 |
class GenerateMaskService:
|
41 |
-
def __init__(self):
|
42 |
self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
43 |
-
self.hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
|
44 |
self.model = "gpt-4o"
|
|
|
|
|
45 |
|
46 |
async def get_mask_generation_instruction(self, edit_instruction: str, image_url: str) -> GenerateMaskInstruction:
|
47 |
-
|
48 |
messages = [
|
49 |
{
|
50 |
"role": "system",
|
@@ -93,13 +95,11 @@ class GenerateMaskService:
|
|
93 |
return generate_mask_result.mask_b64
|
94 |
|
95 |
async def main():
|
96 |
-
service = GenerateMaskService()
|
97 |
edit_instruction = "remove the light post"
|
98 |
image_file_path = "./assets/lakeview.jpg"
|
99 |
with open(image_file_path, "rb") as image_file:
|
100 |
-
|
101 |
-
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
102 |
-
image_url = f"data:image/jpeg;base64,{image_base64}"
|
103 |
|
104 |
instruction = await service.get_mask_generation_instruction(edit_instruction, image_url)
|
105 |
print(instruction)
|
|
|
6 |
import asyncio
|
7 |
from src.hopter.client import Hopter, RamGroundedSamInput, Environment
|
8 |
from src.models.generate_mask_instruction import GenerateMaskInstruction
|
9 |
+
from src.services.openai_file_upload import OpenAIFileUpload
|
10 |
+
|
11 |
load_dotenv()
|
12 |
|
13 |
system_prompt = """
|
|
|
40 |
"""
|
41 |
|
42 |
class GenerateMaskService:
|
43 |
+
def __init__(self, hopter: Hopter):
|
44 |
self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
|
|
45 |
self.model = "gpt-4o"
|
46 |
+
self.openai_file_upload = OpenAIFileUpload()
|
47 |
+
self.hopter = hopter
|
48 |
|
49 |
async def get_mask_generation_instruction(self, edit_instruction: str, image_url: str) -> GenerateMaskInstruction:
|
|
|
50 |
messages = [
|
51 |
{
|
52 |
"role": "system",
|
|
|
95 |
return generate_mask_result.mask_b64
|
96 |
|
97 |
async def main():
|
98 |
+
service = GenerateMaskService(Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING))
|
99 |
edit_instruction = "remove the light post"
|
100 |
image_file_path = "./assets/lakeview.jpg"
|
101 |
with open(image_file_path, "rb") as image_file:
|
102 |
+
image_url = service.openai_file_upload.upload_image(image_file.read(), "vision")
|
|
|
|
|
103 |
|
104 |
instruction = await service.get_mask_generation_instruction(edit_instruction, image_url)
|
105 |
print(instruction)
|
src/services/image_uploader.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from typing import Optional, Union
|
3 |
+
import base64
|
4 |
+
from pathlib import Path
|
5 |
+
import os
|
6 |
+
from pydantic import BaseModel
|
7 |
+
|
8 |
+
class ImageInfo(BaseModel):
|
9 |
+
filename: str
|
10 |
+
name: str
|
11 |
+
mime: str
|
12 |
+
extension: str
|
13 |
+
url: str
|
14 |
+
|
15 |
+
class ImgBBData(BaseModel):
|
16 |
+
id: str
|
17 |
+
title: str
|
18 |
+
url_viewer: str
|
19 |
+
url: str
|
20 |
+
display_url: str
|
21 |
+
width: int
|
22 |
+
height: int
|
23 |
+
size: int
|
24 |
+
time: int
|
25 |
+
expiration: int
|
26 |
+
image: ImageInfo
|
27 |
+
thumb: ImageInfo
|
28 |
+
medium: ImageInfo
|
29 |
+
delete_url: str
|
30 |
+
|
31 |
+
class ImgBBResponse(BaseModel):
|
32 |
+
data: ImgBBData
|
33 |
+
success: bool
|
34 |
+
status: int
|
35 |
+
|
36 |
+
class ImageUploader:
|
37 |
+
"""A class to handle image uploads to ImgBB service."""
|
38 |
+
|
39 |
+
def __init__(self, api_key: str):
|
40 |
+
"""
|
41 |
+
Initialize the ImageUploader with an API key.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
api_key (str): The ImgBB API key
|
45 |
+
"""
|
46 |
+
self.api_key = api_key
|
47 |
+
self.base_url = "https://api.imgbb.com/1/upload"
|
48 |
+
|
49 |
+
def upload(
|
50 |
+
self,
|
51 |
+
image: Union[str, bytes, Path],
|
52 |
+
name: Optional[str] = None,
|
53 |
+
expiration: Optional[int] = None
|
54 |
+
) -> ImgBBResponse:
|
55 |
+
"""
|
56 |
+
Upload an image to ImgBB.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
image: Can be:
|
60 |
+
- A file path (str or Path)
|
61 |
+
- Base64 encoded string
|
62 |
+
- Base64 data URI (e.g., data:image/jpeg;base64,...)
|
63 |
+
- URL to an image
|
64 |
+
- Bytes of an image
|
65 |
+
name: Optional name for the uploaded file
|
66 |
+
expiration: Optional expiration time in seconds (60-15552000)
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
ImgBBResponse containing the parsed upload response from ImgBB
|
70 |
+
|
71 |
+
Raises:
|
72 |
+
ValueError: If the image format is invalid or upload fails
|
73 |
+
requests.RequestException: If the API request fails
|
74 |
+
"""
|
75 |
+
# Prepare the parameters
|
76 |
+
params = {'key': self.api_key}
|
77 |
+
if expiration:
|
78 |
+
if not 60 <= expiration <= 15552000:
|
79 |
+
raise ValueError("Expiration must be between 60 and 15552000 seconds")
|
80 |
+
params['expiration'] = expiration
|
81 |
+
|
82 |
+
# Handle different image input types
|
83 |
+
if isinstance(image, (str, Path)):
|
84 |
+
image_str = str(image)
|
85 |
+
files = {}
|
86 |
+
if os.path.isfile(image_str):
|
87 |
+
# It's a file path
|
88 |
+
with open(image_str, 'rb') as file:
|
89 |
+
files['image'] = file
|
90 |
+
elif image_str.startswith(('http://', 'https://')):
|
91 |
+
# It's a URL
|
92 |
+
files['image'] = (None, image_str)
|
93 |
+
elif image_str.startswith('data:image/'):
|
94 |
+
# It's a data URI
|
95 |
+
# Extract the base64 part after the comma
|
96 |
+
base64_data = image_str.split(',', 1)[1]
|
97 |
+
files['image'] = (None, base64_data)
|
98 |
+
else:
|
99 |
+
# Assume it's base64 data
|
100 |
+
files['image'] = (None, image_str)
|
101 |
+
|
102 |
+
if name:
|
103 |
+
files['name'] = (None, name)
|
104 |
+
response = requests.post(self.base_url, params=params, files=files)
|
105 |
+
elif isinstance(image, bytes):
|
106 |
+
# Convert bytes to base64
|
107 |
+
base64_image = base64.b64encode(image).decode('utf-8')
|
108 |
+
files = {
|
109 |
+
'image': (None, base64_image)
|
110 |
+
}
|
111 |
+
if name:
|
112 |
+
files['name'] = (None, name)
|
113 |
+
response = requests.post(self.base_url, params=params, files=files)
|
114 |
+
else:
|
115 |
+
raise ValueError("Invalid image format. Must be file path, URL, base64 string, or bytes")
|
116 |
+
|
117 |
+
# Check the response
|
118 |
+
if response.status_code != 200:
|
119 |
+
raise ValueError(f"Upload failed with status {response.status_code}: {response.text}")
|
120 |
+
|
121 |
+
# Parse the response using Pydantic model
|
122 |
+
response_json = response.json()
|
123 |
+
return ImgBBResponse.parse_obj(response_json)
|
124 |
+
|
125 |
+
def upload_file(
|
126 |
+
self,
|
127 |
+
file_path: Union[str, Path],
|
128 |
+
name: Optional[str] = None,
|
129 |
+
expiration: Optional[int] = None
|
130 |
+
) -> ImgBBResponse:
|
131 |
+
"""
|
132 |
+
Convenience method to upload an image file.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
file_path: Path to the image file
|
136 |
+
name: Optional name for the uploaded file
|
137 |
+
expiration: Optional expiration time in seconds (60-15552000)
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
ImgBBResponse containing the parsed upload response from ImgBB
|
141 |
+
"""
|
142 |
+
return self.upload(file_path, name=name, expiration=expiration)
|
143 |
+
|
144 |
+
def upload_url(
|
145 |
+
self,
|
146 |
+
image_url: str,
|
147 |
+
name: Optional[str] = None,
|
148 |
+
expiration: Optional[int] = None
|
149 |
+
) -> ImgBBResponse:
|
150 |
+
"""
|
151 |
+
Convenience method to upload an image from a URL.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
image_url: URL of the image to upload
|
155 |
+
name: Optional name for the uploaded file
|
156 |
+
expiration: Optional expiration time in seconds (60-15552000)
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
ImgBBResponse containing the parsed upload response from ImgBB
|
160 |
+
"""
|
161 |
+
return self.upload(image_url, name=name, expiration=expiration)
|
src/services/openai_file_upload.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
import os
|
4 |
+
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
class OpenAIFileUpload:
|
8 |
+
def __init__(self):
|
9 |
+
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
10 |
+
|
11 |
+
def upload_image(self, image, purpose) -> str:
|
12 |
+
file = self.client.files.create(file=image, purpose=purpose)
|
13 |
+
return file.id
|