Spaces:
Running
Running
Commit
·
ad3aed5
1
Parent(s):
24c3cc8
feat: fix chat
Browse files- gradio_chat.py +6 -6
- gradio_demo.py +4 -7
- poetry.lock +140 -1
- pyproject.toml +1 -0
- src/agents/mask_generation_agent.py +22 -9
- src/services/generate_mask.py +3 -2
- src/services/google_cloud_image_upload.py +65 -0
- src/utils.py +39 -1
gradio_chat.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4 |
from src.hopter.client import Hopter, Environment
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
-
from src.utils import image_path_to_uri
|
8 |
from pydantic_ai.messages import (
|
9 |
ToolCallPart,
|
10 |
ToolReturnPart
|
@@ -46,9 +46,7 @@ def build_user_message(chat_input):
|
|
46 |
|
47 |
def build_messages_for_agent(chat_input, past_messages):
|
48 |
# filter out image messages from past messages to save on tokens
|
49 |
-
messages =
|
50 |
-
if not (isinstance(msg, dict)
|
51 |
-
and msg.get("type") == "image_url")]
|
52 |
|
53 |
# add the user's text message
|
54 |
if chat_input["text"]:
|
@@ -59,7 +57,7 @@ def build_messages_for_agent(chat_input, past_messages):
|
|
59 |
|
60 |
# add the user's image message
|
61 |
files = chat_input.get("files", [])
|
62 |
-
image_url =
|
63 |
if image_url:
|
64 |
messages.append({
|
65 |
"type": "image_url",
|
@@ -77,7 +75,7 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
|
|
77 |
# Prepare messages for the agent
|
78 |
text = chat_input["text"]
|
79 |
files = chat_input.get("files", [])
|
80 |
-
image_url =
|
81 |
messages = [
|
82 |
{
|
83 |
"type": "text",
|
@@ -99,10 +97,12 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
|
|
99 |
hopter_client=hopter,
|
100 |
mask_service=mask_service
|
101 |
)
|
|
|
102 |
# Run the agent
|
103 |
async with mask_generation_agent.run_stream(
|
104 |
messages,
|
105 |
deps=deps,
|
|
|
106 |
) as result:
|
107 |
for message in result.new_messages():
|
108 |
for call in message.parts:
|
|
|
4 |
from src.hopter.client import Hopter, Environment
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
+
from src.utils import image_path_to_uri, upload_image
|
8 |
from pydantic_ai.messages import (
|
9 |
ToolCallPart,
|
10 |
ToolReturnPart
|
|
|
46 |
|
47 |
def build_messages_for_agent(chat_input, past_messages):
|
48 |
# filter out image messages from past messages to save on tokens
|
49 |
+
messages = past_messages
|
|
|
|
|
50 |
|
51 |
# add the user's text message
|
52 |
if chat_input["text"]:
|
|
|
57 |
|
58 |
# add the user's image message
|
59 |
files = chat_input.get("files", [])
|
60 |
+
image_url = upload_image(files[0]) if files else None
|
61 |
if image_url:
|
62 |
messages.append({
|
63 |
"type": "image_url",
|
|
|
75 |
# Prepare messages for the agent
|
76 |
text = chat_input["text"]
|
77 |
files = chat_input.get("files", [])
|
78 |
+
image_url = upload_image(files[0]) if files else None
|
79 |
messages = [
|
80 |
{
|
81 |
"type": "text",
|
|
|
97 |
hopter_client=hopter,
|
98 |
mask_service=mask_service
|
99 |
)
|
100 |
+
|
101 |
# Run the agent
|
102 |
async with mask_generation_agent.run_stream(
|
103 |
messages,
|
104 |
deps=deps,
|
105 |
+
message_history=past_messages
|
106 |
) as result:
|
107 |
for message in result.new_messages():
|
108 |
for call in message.parts:
|
gradio_demo.py
CHANGED
@@ -4,20 +4,17 @@ import os
|
|
4 |
from src.hopter.client import Hopter, Environment
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
-
from src.utils import image_path_to_uri
|
8 |
from pydantic_ai.messages import (
|
9 |
-
ToolCallPart,
|
10 |
ToolReturnPart
|
11 |
)
|
12 |
from src.agents.mask_generation_agent import EditImageResult
|
13 |
-
from
|
14 |
-
from pydantic_ai.models.openai import OpenAIModel
|
15 |
-
|
16 |
load_dotenv()
|
17 |
|
18 |
async def process_edit(image, instruction):
|
19 |
hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
|
20 |
mask_service = GenerateMaskService(hopter=hopter)
|
|
|
21 |
messages = [
|
22 |
{
|
23 |
"type": "text",
|
@@ -26,11 +23,11 @@ async def process_edit(image, instruction):
|
|
26 |
]
|
27 |
if image:
|
28 |
messages.append(
|
29 |
-
{"type": "image_url", "image_url": {"url":
|
30 |
)
|
31 |
deps = ImageEditDeps(
|
32 |
edit_instruction=instruction,
|
33 |
-
image_url=
|
34 |
hopter_client=hopter,
|
35 |
mask_service=mask_service
|
36 |
)
|
|
|
4 |
from src.hopter.client import Hopter, Environment
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
|
|
7 |
from pydantic_ai.messages import (
|
|
|
8 |
ToolReturnPart
|
9 |
)
|
10 |
from src.agents.mask_generation_agent import EditImageResult
|
11 |
+
from src.utils import upload_image
|
|
|
|
|
12 |
load_dotenv()
|
13 |
|
14 |
async def process_edit(image, instruction):
|
15 |
hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
|
16 |
mask_service = GenerateMaskService(hopter=hopter)
|
17 |
+
image_url = upload_image(image)
|
18 |
messages = [
|
19 |
{
|
20 |
"type": "text",
|
|
|
23 |
]
|
24 |
if image:
|
25 |
messages.append(
|
26 |
+
{"type": "image_url", "image_url": {"url": image_url}}
|
27 |
)
|
28 |
deps = ImageEditDeps(
|
29 |
edit_instruction=instruction,
|
30 |
+
image_url=image_url,
|
31 |
hopter_client=hopter,
|
32 |
mask_service=mask_service
|
33 |
)
|
poetry.lock
CHANGED
@@ -526,6 +526,30 @@ gitdb = ">=4.0.1,<5"
|
|
526 |
doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"]
|
527 |
test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
|
528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
[[package]]
|
530 |
name = "google-auth"
|
531 |
version = "2.38.0"
|
@@ -550,6 +574,104 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
|
|
550 |
reauth = ["pyu2f (>=0.1.5)"]
|
551 |
requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
|
552 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
[[package]]
|
554 |
name = "googleapis-common-protos"
|
555 |
version = "1.67.0"
|
@@ -1633,6 +1755,23 @@ tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "ole
|
|
1633 |
typing = ["typing-extensions"]
|
1634 |
xmp = ["defusedxml"]
|
1635 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1636 |
[[package]]
|
1637 |
name = "protobuf"
|
1638 |
version = "5.29.3"
|
@@ -2911,4 +3050,4 @@ type = ["pytest-mypy"]
|
|
2911 |
[metadata]
|
2912 |
lock-version = "2.0"
|
2913 |
python-versions = "3.10"
|
2914 |
-
content-hash = "
|
|
|
526 |
doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"]
|
527 |
test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
|
528 |
|
529 |
+
[[package]]
|
530 |
+
name = "google-api-core"
|
531 |
+
version = "2.24.1"
|
532 |
+
description = "Google API client core library"
|
533 |
+
optional = false
|
534 |
+
python-versions = ">=3.7"
|
535 |
+
files = [
|
536 |
+
{file = "google_api_core-2.24.1-py3-none-any.whl", hash = "sha256:bc78d608f5a5bf853b80bd70a795f703294de656c096c0968320830a4bc280f1"},
|
537 |
+
{file = "google_api_core-2.24.1.tar.gz", hash = "sha256:f8b36f5456ab0dd99a1b693a40a31d1e7757beea380ad1b38faaf8941eae9d8a"},
|
538 |
+
]
|
539 |
+
|
540 |
+
[package.dependencies]
|
541 |
+
google-auth = ">=2.14.1,<3.0.dev0"
|
542 |
+
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
543 |
+
proto-plus = ">=1.22.3,<2.0.0dev"
|
544 |
+
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
|
545 |
+
requests = ">=2.18.0,<3.0.0.dev0"
|
546 |
+
|
547 |
+
[package.extras]
|
548 |
+
async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"]
|
549 |
+
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
|
550 |
+
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
551 |
+
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
552 |
+
|
553 |
[[package]]
|
554 |
name = "google-auth"
|
555 |
version = "2.38.0"
|
|
|
574 |
reauth = ["pyu2f (>=0.1.5)"]
|
575 |
requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
|
576 |
|
577 |
+
[[package]]
|
578 |
+
name = "google-cloud-core"
|
579 |
+
version = "2.4.1"
|
580 |
+
description = "Google Cloud API client core library"
|
581 |
+
optional = false
|
582 |
+
python-versions = ">=3.7"
|
583 |
+
files = [
|
584 |
+
{file = "google-cloud-core-2.4.1.tar.gz", hash = "sha256:9b7749272a812bde58fff28868d0c5e2f585b82f37e09a1f6ed2d4d10f134073"},
|
585 |
+
{file = "google_cloud_core-2.4.1-py2.py3-none-any.whl", hash = "sha256:a9e6a4422b9ac5c29f79a0ede9485473338e2ce78d91f2370c01e730eab22e61"},
|
586 |
+
]
|
587 |
+
|
588 |
+
[package.dependencies]
|
589 |
+
google-api-core = ">=1.31.6,<2.0.dev0 || >2.3.0,<3.0.0dev"
|
590 |
+
google-auth = ">=1.25.0,<3.0dev"
|
591 |
+
|
592 |
+
[package.extras]
|
593 |
+
grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"]
|
594 |
+
|
595 |
+
[[package]]
|
596 |
+
name = "google-cloud-storage"
|
597 |
+
version = "3.0.0"
|
598 |
+
description = "Google Cloud Storage API client library"
|
599 |
+
optional = false
|
600 |
+
python-versions = ">=3.7"
|
601 |
+
files = [
|
602 |
+
{file = "google_cloud_storage-3.0.0-py2.py3-none-any.whl", hash = "sha256:f85fd059650d2dbb0ac158a9a6b304b66143b35ed2419afec2905ca522eb2c6a"},
|
603 |
+
{file = "google_cloud_storage-3.0.0.tar.gz", hash = "sha256:2accb3e828e584888beff1165e5f3ac61aa9088965eb0165794a82d8c7f95297"},
|
604 |
+
]
|
605 |
+
|
606 |
+
[package.dependencies]
|
607 |
+
google-api-core = ">=2.15.0,<3.0.0dev"
|
608 |
+
google-auth = ">=2.26.1,<3.0dev"
|
609 |
+
google-cloud-core = ">=2.3.0,<3.0dev"
|
610 |
+
google-crc32c = ">=1.0,<2.0dev"
|
611 |
+
google-resumable-media = ">=2.7.2"
|
612 |
+
requests = ">=2.18.0,<3.0.0dev"
|
613 |
+
|
614 |
+
[package.extras]
|
615 |
+
protobuf = ["protobuf (<6.0.0dev)"]
|
616 |
+
tracing = ["opentelemetry-api (>=1.1.0)"]
|
617 |
+
|
618 |
+
[[package]]
|
619 |
+
name = "google-crc32c"
|
620 |
+
version = "1.6.0"
|
621 |
+
description = "A python wrapper of the C library 'Google CRC32C'"
|
622 |
+
optional = false
|
623 |
+
python-versions = ">=3.9"
|
624 |
+
files = [
|
625 |
+
{file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5bcc90b34df28a4b38653c36bb5ada35671ad105c99cfe915fb5bed7ad6924aa"},
|
626 |
+
{file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d9e9913f7bd69e093b81da4535ce27af842e7bf371cde42d1ae9e9bd382dc0e9"},
|
627 |
+
{file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a184243544811e4a50d345838a883733461e67578959ac59964e43cca2c791e7"},
|
628 |
+
{file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:236c87a46cdf06384f614e9092b82c05f81bd34b80248021f729396a78e55d7e"},
|
629 |
+
{file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebab974b1687509e5c973b5c4b8b146683e101e102e17a86bd196ecaa4d099fc"},
|
630 |
+
{file = "google_crc32c-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:50cf2a96da226dcbff8671233ecf37bf6e95de98b2a2ebadbfdf455e6d05df42"},
|
631 |
+
{file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f7a1fc29803712f80879b0806cb83ab24ce62fc8daf0569f2204a0cfd7f68ed4"},
|
632 |
+
{file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:40b05ab32a5067525670880eb5d169529089a26fe35dce8891127aeddc1950e8"},
|
633 |
+
{file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e4b426c3702f3cd23b933436487eb34e01e00327fac20c9aebb68ccf34117d"},
|
634 |
+
{file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51c4f54dd8c6dfeb58d1df5e4f7f97df8abf17a36626a217f169893d1d7f3e9f"},
|
635 |
+
{file = "google_crc32c-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:bb8b3c75bd157010459b15222c3fd30577042a7060e29d42dabce449c087f2b3"},
|
636 |
+
{file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ed767bf4ba90104c1216b68111613f0d5926fb3780660ea1198fc469af410e9d"},
|
637 |
+
{file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:62f6d4a29fea082ac4a3c9be5e415218255cf11684ac6ef5488eea0c9132689b"},
|
638 |
+
{file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c87d98c7c4a69066fd31701c4e10d178a648c2cac3452e62c6b24dc51f9fcc00"},
|
639 |
+
{file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5e7d2445d1a958c266bfa5d04c39932dc54093fa391736dbfdb0f1929c1fb3"},
|
640 |
+
{file = "google_crc32c-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7aec8e88a3583515f9e0957fe4f5f6d8d4997e36d0f61624e70469771584c760"},
|
641 |
+
{file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e2806553238cd076f0a55bddab37a532b53580e699ed8e5606d0de1f856b5205"},
|
642 |
+
{file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:bb0966e1c50d0ef5bc743312cc730b533491d60585a9a08f897274e57c3f70e0"},
|
643 |
+
{file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:386122eeaaa76951a8196310432c5b0ef3b53590ef4c317ec7588ec554fec5d2"},
|
644 |
+
{file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2952396dc604544ea7476b33fe87faedc24d666fb0c2d5ac971a2b9576ab871"},
|
645 |
+
{file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35834855408429cecf495cac67ccbab802de269e948e27478b1e47dfb6465e57"},
|
646 |
+
{file = "google_crc32c-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8797406499f28b5ef791f339594b0b5fdedf54e203b5066675c406ba69d705c"},
|
647 |
+
{file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48abd62ca76a2cbe034542ed1b6aee851b6f28aaca4e6551b5599b6f3ef175cc"},
|
648 |
+
{file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d"},
|
649 |
+
{file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05e2d8c9a2f853ff116db9706b4a27350587f341eda835f46db3c0a8c8ce2f24"},
|
650 |
+
{file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91ca8145b060679ec9176e6de4f89b07363d6805bd4760631ef254905503598d"},
|
651 |
+
{file = "google_crc32c-1.6.0.tar.gz", hash = "sha256:6eceb6ad197656a1ff49ebfbbfa870678c75be4344feb35ac1edf694309413dc"},
|
652 |
+
]
|
653 |
+
|
654 |
+
[package.extras]
|
655 |
+
testing = ["pytest"]
|
656 |
+
|
657 |
+
[[package]]
|
658 |
+
name = "google-resumable-media"
|
659 |
+
version = "2.7.2"
|
660 |
+
description = "Utilities for Google Media Downloads and Resumable Uploads"
|
661 |
+
optional = false
|
662 |
+
python-versions = ">=3.7"
|
663 |
+
files = [
|
664 |
+
{file = "google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa"},
|
665 |
+
{file = "google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0"},
|
666 |
+
]
|
667 |
+
|
668 |
+
[package.dependencies]
|
669 |
+
google-crc32c = ">=1.0,<2.0dev"
|
670 |
+
|
671 |
+
[package.extras]
|
672 |
+
aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "google-auth (>=1.22.0,<2.0dev)"]
|
673 |
+
requests = ["requests (>=2.18.0,<3.0.0dev)"]
|
674 |
+
|
675 |
[[package]]
|
676 |
name = "googleapis-common-protos"
|
677 |
version = "1.67.0"
|
|
|
1755 |
typing = ["typing-extensions"]
|
1756 |
xmp = ["defusedxml"]
|
1757 |
|
1758 |
+
[[package]]
|
1759 |
+
name = "proto-plus"
|
1760 |
+
version = "1.26.0"
|
1761 |
+
description = "Beautiful, Pythonic protocol buffers"
|
1762 |
+
optional = false
|
1763 |
+
python-versions = ">=3.7"
|
1764 |
+
files = [
|
1765 |
+
{file = "proto_plus-1.26.0-py3-none-any.whl", hash = "sha256:bf2dfaa3da281fc3187d12d224c707cb57214fb2c22ba854eb0c105a3fb2d4d7"},
|
1766 |
+
{file = "proto_plus-1.26.0.tar.gz", hash = "sha256:6e93d5f5ca267b54300880fff156b6a3386b3fa3f43b1da62e680fc0c586ef22"},
|
1767 |
+
]
|
1768 |
+
|
1769 |
+
[package.dependencies]
|
1770 |
+
protobuf = ">=3.19.0,<6.0.0dev"
|
1771 |
+
|
1772 |
+
[package.extras]
|
1773 |
+
testing = ["google-api-core (>=1.31.5)"]
|
1774 |
+
|
1775 |
[[package]]
|
1776 |
name = "protobuf"
|
1777 |
version = "5.29.3"
|
|
|
3050 |
[metadata]
|
3051 |
lock-version = "2.0"
|
3052 |
python-versions = "3.10"
|
3053 |
+
content-hash = "3cfa3697a8b8f9c2ebbcd6e6fe3a7230ed78cafe541cd534df5bbb3b0cac9654"
|
pyproject.toml
CHANGED
@@ -14,6 +14,7 @@ pydantic-ai = "^0.0.24"
|
|
14 |
python-dotenv = "^1.0.1"
|
15 |
logfire = "^3.5.3"
|
16 |
gradio = "^5.16.1"
|
|
|
17 |
|
18 |
|
19 |
[build-system]
|
|
|
14 |
python-dotenv = "^1.0.1"
|
15 |
logfire = "^3.5.3"
|
16 |
gradio = "^5.16.1"
|
17 |
+
google-cloud-storage = "^3.0.0"
|
18 |
|
19 |
|
20 |
[build-system]
|
src/agents/mask_generation_agent.py
CHANGED
@@ -9,7 +9,10 @@ import logfire
|
|
9 |
from src.services.generate_mask import GenerateMaskService
|
10 |
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
|
11 |
from src.services.image_uploader import ImageUploader
|
12 |
-
from src.utils import image_path_to_uri
|
|
|
|
|
|
|
13 |
|
14 |
load_dotenv()
|
15 |
|
@@ -49,6 +52,16 @@ mask_generation_agent = Agent(
|
|
49 |
deps_type=ImageEditDeps
|
50 |
)
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
@mask_generation_agent.tool
|
53 |
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
54 |
"""
|
@@ -65,16 +78,17 @@ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
|
65 |
mask_service = ctx.deps.mask_service
|
66 |
hopter_client = ctx.deps.hopter_client
|
67 |
|
|
|
|
|
68 |
# Generate mask
|
69 |
mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
|
70 |
-
mask = mask_service.generate_mask(mask_instruction,
|
71 |
|
72 |
# Magic replace
|
73 |
-
input = MagicReplaceInput(image=
|
74 |
result = hopter_client.magic_replace(input)
|
75 |
-
|
76 |
-
|
77 |
-
return EditImageResult(edited_image_url=uploaded_image.data.url)
|
78 |
|
79 |
@mask_generation_agent.tool
|
80 |
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
@@ -86,9 +100,8 @@ async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
|
86 |
|
87 |
input = SuperResolutionInput(image_b64=image_url, scale=4, use_face_enhancement=False)
|
88 |
result = hopter_client.super_resolution(input)
|
89 |
-
|
90 |
-
|
91 |
-
return EditImageResult(edited_image_url=uploaded_image.data.url)
|
92 |
|
93 |
async def main():
|
94 |
image_file_path = "./assets/lakeview.jpg"
|
|
|
9 |
from src.services.generate_mask import GenerateMaskService
|
10 |
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
|
11 |
from src.services.image_uploader import ImageUploader
|
12 |
+
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
|
13 |
+
import base64
|
14 |
+
import tempfile
|
15 |
+
from PIL import Image
|
16 |
|
17 |
load_dotenv()
|
18 |
|
|
|
52 |
deps_type=ImageEditDeps
|
53 |
)
|
54 |
|
55 |
+
def upload_image_from_base64(base64_image: str) -> str:
    """
    Decode a base64-encoded image, stage it in a temporary file and upload it.

    Args:
        base64_image (str): Either a data URI of the form
            ``data:<mime>;base64,<payload>`` or a bare base64 payload.

    Returns:
        str: The value returned by ``upload_image`` for the staged file.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    import os  # local import: only needed here for temp-file cleanup

    # Separate the optional "data:<mime>;base64" header from the payload.
    # A bare payload has no comma, so the whole string is the payload.
    header, separator, payload = base64_image.partition(",")
    if not separator:
        header, payload = "", base64_image

    # The header reads "data:image/jpeg;base64"; keep only the MIME part so the
    # comparison below can actually match.
    mime_type = header.split(";", 1)[0].rpartition(":")[2].strip().lower()
    suffix = ".jpg" if mime_type == "image/jpeg" else ".png"

    image_data = base64.b64decode(payload)

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    try:
        return upload_image(temp_filename)
    finally:
        # delete=False leaves the staged file on disk; remove it once the
        # upload attempt has finished, whether or not it succeeded.
        os.remove(temp_filename)
|
64 |
+
|
65 |
@mask_generation_agent.tool
|
66 |
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
67 |
"""
|
|
|
78 |
mask_service = ctx.deps.mask_service
|
79 |
hopter_client = ctx.deps.hopter_client
|
80 |
|
81 |
+
image_uri = download_image_to_data_uri(image_url)
|
82 |
+
|
83 |
# Generate mask
|
84 |
mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
|
85 |
+
mask = mask_service.generate_mask(mask_instruction, image_uri)
|
86 |
|
87 |
# Magic replace
|
88 |
+
input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
|
89 |
result = hopter_client.magic_replace(input)
|
90 |
+
uploaded_image = upload_image_from_base64(result.base64_image)
|
91 |
+
return EditImageResult(edited_image_url=uploaded_image)
|
|
|
92 |
|
93 |
@mask_generation_agent.tool
|
94 |
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
|
|
100 |
|
101 |
input = SuperResolutionInput(image_b64=image_url, scale=4, use_face_enhancement=False)
|
102 |
result = hopter_client.super_resolution(input)
|
103 |
+
uploaded_image = upload_image_from_base64(result.scaled_image)
|
104 |
+
return EditImageResult(edited_image_url=uploaded_image)
|
|
|
105 |
|
106 |
async def main():
|
107 |
image_file_path = "./assets/lakeview.jpg"
|
src/services/generate_mask.py
CHANGED
@@ -7,7 +7,7 @@ import asyncio
|
|
7 |
from src.hopter.client import Hopter, RamGroundedSamInput, Environment
|
8 |
from src.models.generate_mask_instruction import GenerateMaskInstruction
|
9 |
from src.services.openai_file_upload import OpenAIFileUpload
|
10 |
-
|
11 |
load_dotenv()
|
12 |
|
13 |
system_prompt = """
|
@@ -87,9 +87,10 @@ class GenerateMaskService:
|
|
87 |
Returns:
|
88 |
str: The mask image in base64 format.
|
89 |
"""
|
|
|
90 |
input = RamGroundedSamInput(
|
91 |
text_prompt=mask_instruction.subject,
|
92 |
-
image_b64=
|
93 |
)
|
94 |
generate_mask_result = self.hopter.generate_mask(input)
|
95 |
return generate_mask_result.mask_b64
|
|
|
7 |
from src.hopter.client import Hopter, RamGroundedSamInput, Environment
|
8 |
from src.models.generate_mask_instruction import GenerateMaskInstruction
|
9 |
from src.services.openai_file_upload import OpenAIFileUpload
|
10 |
+
from src.utils import download_image_to_data_uri
|
11 |
load_dotenv()
|
12 |
|
13 |
system_prompt = """
|
|
|
87 |
Returns:
|
88 |
str: The mask image in base64 format.
|
89 |
"""
|
90 |
+
image_uri = download_image_to_data_uri(image_url)
|
91 |
input = RamGroundedSamInput(
|
92 |
text_prompt=mask_instruction.subject,
|
93 |
+
image_b64=image_uri
|
94 |
)
|
95 |
generate_mask_result = self.hopter.generate_mask(input)
|
96 |
return generate_mask_result.mask_b64
|
src/services/google_cloud_image_upload.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from google.cloud import storage
|
2 |
+
from PIL import Image
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
import os
|
5 |
+
import uuid
|
6 |
+
import tempfile
|
7 |
+
|
8 |
+
load_dotenv()
|
9 |
+
|
10 |
+
class GoogleCloudImageUploadService:
    """Upload images to a Google Cloud Storage bucket and expose them publicly."""

    # Bucket that receives every upload.
    BUCKET_NAME = "picchat-assets"
    # Longest allowed edge in pixels; larger images are shrunk before upload.
    MAX_DIMENSION = 1024
    # Pillow modes the JPEG encoder can write directly; anything else is
    # converted to RGB first.
    _JPEG_SAFE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self):
        # The client is authenticated with the GOOGLE_API_KEY environment
        # variable. For production, service account credentials are generally
        # recommended instead of an API key.
        self.storage_client = storage.Client(client_options={"api_key": os.environ.get("GOOGLE_API_KEY")})

    def upload_image_to_gcs(self, source_file_name):
        """
        Uploads an image to the specified Google Cloud Storage bucket.
        Supports both JPEG and PNG formats; any other format is re-encoded
        as JPEG.

        Args:
            source_file_name: Path (or file object) of the image to upload.

        Returns:
            The public URL of the uploaded blob, or None if any step failed.
        """
        temp_filename = None
        try:
            bucket = self.storage_client.bucket(self.BUCKET_NAME)
            blob_name = str(uuid.uuid4())
            blob = bucket.blob(blob_name)

            # Open and optionally resize the image, then save to a temporary file.
            with Image.open(source_file_name) as image:
                # Keep JPEG/PNG as they are; everything else defaults to JPEG.
                original_format = image.format if image.format in ("JPEG", "PNG") else "JPEG"
                is_jpeg = original_format == "JPEG"
                suffix = ".jpg" if is_jpeg else ".png"
                content_type = "image/jpeg" if is_jpeg else "image/png"

                # Resize if needed.
                if image.width > self.MAX_DIMENSION or image.height > self.MAX_DIMENSION:
                    image.thumbnail((self.MAX_DIMENSION, self.MAX_DIMENSION))

                # JPEG has no alpha/palette support, so images that fell back
                # to JPEG (e.g. RGBA or P mode) must be converted before saving.
                to_save = image
                if is_jpeg and image.mode not in self._JPEG_SAFE_MODES:
                    to_save = image.convert("RGB")

                # Reserve a temporary path, close the handle, then write to it.
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                    temp_filename = temp_file.name
                to_save.save(temp_filename, format=original_format)

            blob.upload_from_filename(temp_filename, content_type=content_type)
            blob.make_public()

            print(f"File {source_file_name} uploaded to {blob_name} in bucket {self.BUCKET_NAME}.")
            return blob.public_url
        except Exception as e:
            # Deliberate best-effort: callers test the result for truthiness,
            # so failures are reported and turned into None rather than raised.
            print(f"An error occurred: {e}")
            return None
        finally:
            # Remove the temporary file even when saving or uploading failed.
            if temp_filename is not None and os.path.exists(temp_filename):
                os.remove(temp_filename)
|
60 |
+
|
61 |
+
if __name__ == "__main__":
|
62 |
+
image = "./assets/lakeview.jpg" # Replace with your JPEG or PNG image path.
|
63 |
+
upload_service = GoogleCloudImageUploadService()
|
64 |
+
url = upload_service.upload_image_to_gcs(image)
|
65 |
+
print(url)
|
src/utils.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
import base64
|
2 |
from fastapi import UploadFile
|
|
|
|
|
|
|
|
|
3 |
def image_path_to_base64(image_path: str) -> str:
|
4 |
with open(image_path, "rb") as image_file:
|
5 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
@@ -8,4 +12,38 @@ def upload_file_to_base64(file: UploadFile) -> str:
|
|
8 |
return base64.b64encode(file.file.read()).decode("utf-8")
|
9 |
|
10 |
def image_path_to_uri(image_path: str) -> str:
|
11 |
-
return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import base64
|
2 |
from fastapi import UploadFile
|
3 |
+
from src.services.google_cloud_image_upload import GoogleCloudImageUploadService
|
4 |
+
from PIL import Image
|
5 |
+
from urllib.request import urlopen
|
6 |
+
import io
|
7 |
def image_path_to_base64(image_path: str) -> str:
|
8 |
with open(image_path, "rb") as image_file:
|
9 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
|
|
12 |
return base64.b64encode(file.file.read()).decode("utf-8")
|
13 |
|
14 |
def image_path_to_uri(image_path: str) -> str:
    """
    Build a base64 data URI for the image stored at ``image_path``.

    Args:
        image_path (str): The path to the image file.

    Returns:
        str: A ``data:<mime>;base64,<payload>`` URI. The MIME type is guessed
        from the file extension and falls back to ``image/jpeg`` when the
        extension is unknown or not an image type.
    """
    import mimetypes  # local import: only this helper needs it

    guessed_type, _ = mimetypes.guess_type(image_path)
    # Keep the previous default for anything that is not clearly an image.
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = "image/jpeg"
    return f"data:{mime_type};base64,{image_path_to_base64(image_path)}"
|
16 |
+
|
17 |
+
def upload_image(image_path: str) -> str:
|
18 |
+
"""
|
19 |
+
Upload an image to Google Cloud Storage and return the public URL.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
image (str): The path to the image file.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
str: The public URL of the uploaded image.
|
26 |
+
"""
|
27 |
+
upload_service = GoogleCloudImageUploadService()
|
28 |
+
return upload_service.upload_image_to_gcs(image_path)
|
29 |
+
|
30 |
+
def download_image_to_data_uri(image_url: str) -> str:
    """
    Fetch an image and return it as a base64 data URI.

    Args:
        image_url (str): URL of the image. A value that is already a
            ``data:`` URI is returned unchanged.

    Returns:
        str: A ``data:<mime>;base64,<payload>`` URI for the image.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
    """
    # Already a data URI (for example the output of an earlier call): decoding
    # and re-encoding it would only cost time and, for JPEG, quality.
    if image_url[:5].lower() == "data:":
        return image_url

    # NOTE(review): urlopen follows whatever scheme it is given; if image_url
    # can come from untrusted input, restrict it to http/https at the caller.
    # Context managers release the response and the decoder's resources.
    with urlopen(image_url) as response, Image.open(response) as img:
        # Determine the image format; default to 'JPEG' if not found
        image_format = img.format if img.format is not None else "JPEG"

        # Build the MIME type; for 'JPEG', use 'image/jpeg'
        mime_type = "image/jpeg" if image_format.upper() == "JPEG" else f"image/{image_format.lower()}"

        # Save the image to an in-memory buffer using the detected format
        buffered = io.BytesIO()
        img.save(buffered, format=image_format)

    # Encode the image bytes to base64
    img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Return the data URI with the correct MIME type
    return f"data:{mime_type};base64,{img_base64}"
|