simonlee-cb committed on
Commit
ad3aed5
·
1 Parent(s): 24c3cc8

feat: fix chat

Browse files
gradio_chat.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
- from src.utils import image_path_to_uri
8
  from pydantic_ai.messages import (
9
  ToolCallPart,
10
  ToolReturnPart
@@ -46,9 +46,7 @@ def build_user_message(chat_input):
46
 
47
  def build_messages_for_agent(chat_input, past_messages):
48
  # filter out image messages from past messages to save on tokens
49
- messages = [msg for msg in past_messages
50
- if not (isinstance(msg, dict)
51
- and msg.get("type") == "image_url")]
52
 
53
  # add the user's text message
54
  if chat_input["text"]:
@@ -59,7 +57,7 @@ def build_messages_for_agent(chat_input, past_messages):
59
 
60
  # add the user's image message
61
  files = chat_input.get("files", [])
62
- image_url = image_path_to_uri(files[0]) if files else None
63
  if image_url:
64
  messages.append({
65
  "type": "image_url",
@@ -77,7 +75,7 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
77
  # Prepare messages for the agent
78
  text = chat_input["text"]
79
  files = chat_input.get("files", [])
80
- image_url = image_path_to_uri(files[0]) if files else None
81
  messages = [
82
  {
83
  "type": "text",
@@ -99,10 +97,12 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
99
  hopter_client=hopter,
100
  mask_service=mask_service
101
  )
 
102
  # Run the agent
103
  async with mask_generation_agent.run_stream(
104
  messages,
105
  deps=deps,
 
106
  ) as result:
107
  for message in result.new_messages():
108
  for call in message.parts:
 
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
+ from src.utils import image_path_to_uri, upload_image
8
  from pydantic_ai.messages import (
9
  ToolCallPart,
10
  ToolReturnPart
 
46
 
47
  def build_messages_for_agent(chat_input, past_messages):
48
  # filter out image messages from past messages to save on tokens
49
+ messages = past_messages
 
 
50
 
51
  # add the user's text message
52
  if chat_input["text"]:
 
57
 
58
  # add the user's image message
59
  files = chat_input.get("files", [])
60
+ image_url = upload_image(files[0]) if files else None
61
  if image_url:
62
  messages.append({
63
  "type": "image_url",
 
75
  # Prepare messages for the agent
76
  text = chat_input["text"]
77
  files = chat_input.get("files", [])
78
+ image_url = upload_image(files[0]) if files else None
79
  messages = [
80
  {
81
  "type": "text",
 
97
  hopter_client=hopter,
98
  mask_service=mask_service
99
  )
100
+
101
  # Run the agent
102
  async with mask_generation_agent.run_stream(
103
  messages,
104
  deps=deps,
105
+ message_history=past_messages
106
  ) as result:
107
  for message in result.new_messages():
108
  for call in message.parts:
gradio_demo.py CHANGED
@@ -4,20 +4,17 @@ import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
- from src.utils import image_path_to_uri
8
  from pydantic_ai.messages import (
9
- ToolCallPart,
10
  ToolReturnPart
11
  )
12
  from src.agents.mask_generation_agent import EditImageResult
13
- from pydantic_ai.agent import Agent
14
- from pydantic_ai.models.openai import OpenAIModel
15
-
16
  load_dotenv()
17
 
18
  async def process_edit(image, instruction):
19
  hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
20
  mask_service = GenerateMaskService(hopter=hopter)
 
21
  messages = [
22
  {
23
  "type": "text",
@@ -26,11 +23,11 @@ async def process_edit(image, instruction):
26
  ]
27
  if image:
28
  messages.append(
29
- {"type": "image_url", "image_url": {"url": image_path_to_uri(image)}}
30
  )
31
  deps = ImageEditDeps(
32
  edit_instruction=instruction,
33
- image_url=image_path_to_uri(image),
34
  hopter_client=hopter,
35
  mask_service=mask_service
36
  )
 
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
 
7
  from pydantic_ai.messages import (
 
8
  ToolReturnPart
9
  )
10
  from src.agents.mask_generation_agent import EditImageResult
11
+ from src.utils import upload_image
 
 
12
  load_dotenv()
13
 
14
  async def process_edit(image, instruction):
15
  hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
16
  mask_service = GenerateMaskService(hopter=hopter)
17
+ image_url = upload_image(image)
18
  messages = [
19
  {
20
  "type": "text",
 
23
  ]
24
  if image:
25
  messages.append(
26
+ {"type": "image_url", "image_url": {"url": image_url}}
27
  )
28
  deps = ImageEditDeps(
29
  edit_instruction=instruction,
30
+ image_url=image_url,
31
  hopter_client=hopter,
32
  mask_service=mask_service
33
  )
poetry.lock CHANGED
@@ -526,6 +526,30 @@ gitdb = ">=4.0.1,<5"
526
  doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"]
527
  test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  [[package]]
530
  name = "google-auth"
531
  version = "2.38.0"
@@ -550,6 +574,104 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
550
  reauth = ["pyu2f (>=0.1.5)"]
551
  requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  [[package]]
554
  name = "googleapis-common-protos"
555
  version = "1.67.0"
@@ -1633,6 +1755,23 @@ tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "ole
1633
  typing = ["typing-extensions"]
1634
  xmp = ["defusedxml"]
1635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1636
  [[package]]
1637
  name = "protobuf"
1638
  version = "5.29.3"
@@ -2911,4 +3050,4 @@ type = ["pytest-mypy"]
2911
  [metadata]
2912
  lock-version = "2.0"
2913
  python-versions = "3.10"
2914
- content-hash = "83c3d0c47f98107284fb95080f8f9a6ba3c17d9a89f13ccbd66b73ca99863b4d"
 
526
  doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"]
527
  test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
528
 
529
+ [[package]]
530
+ name = "google-api-core"
531
+ version = "2.24.1"
532
+ description = "Google API client core library"
533
+ optional = false
534
+ python-versions = ">=3.7"
535
+ files = [
536
+ {file = "google_api_core-2.24.1-py3-none-any.whl", hash = "sha256:bc78d608f5a5bf853b80bd70a795f703294de656c096c0968320830a4bc280f1"},
537
+ {file = "google_api_core-2.24.1.tar.gz", hash = "sha256:f8b36f5456ab0dd99a1b693a40a31d1e7757beea380ad1b38faaf8941eae9d8a"},
538
+ ]
539
+
540
+ [package.dependencies]
541
+ google-auth = ">=2.14.1,<3.0.dev0"
542
+ googleapis-common-protos = ">=1.56.2,<2.0.dev0"
543
+ proto-plus = ">=1.22.3,<2.0.0dev"
544
+ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
545
+ requests = ">=2.18.0,<3.0.0.dev0"
546
+
547
+ [package.extras]
548
+ async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"]
549
+ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
550
+ grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
551
+ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
552
+
553
  [[package]]
554
  name = "google-auth"
555
  version = "2.38.0"
 
574
  reauth = ["pyu2f (>=0.1.5)"]
575
  requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
576
 
577
+ [[package]]
578
+ name = "google-cloud-core"
579
+ version = "2.4.1"
580
+ description = "Google Cloud API client core library"
581
+ optional = false
582
+ python-versions = ">=3.7"
583
+ files = [
584
+ {file = "google-cloud-core-2.4.1.tar.gz", hash = "sha256:9b7749272a812bde58fff28868d0c5e2f585b82f37e09a1f6ed2d4d10f134073"},
585
+ {file = "google_cloud_core-2.4.1-py2.py3-none-any.whl", hash = "sha256:a9e6a4422b9ac5c29f79a0ede9485473338e2ce78d91f2370c01e730eab22e61"},
586
+ ]
587
+
588
+ [package.dependencies]
589
+ google-api-core = ">=1.31.6,<2.0.dev0 || >2.3.0,<3.0.0dev"
590
+ google-auth = ">=1.25.0,<3.0dev"
591
+
592
+ [package.extras]
593
+ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"]
594
+
595
+ [[package]]
596
+ name = "google-cloud-storage"
597
+ version = "3.0.0"
598
+ description = "Google Cloud Storage API client library"
599
+ optional = false
600
+ python-versions = ">=3.7"
601
+ files = [
602
+ {file = "google_cloud_storage-3.0.0-py2.py3-none-any.whl", hash = "sha256:f85fd059650d2dbb0ac158a9a6b304b66143b35ed2419afec2905ca522eb2c6a"},
603
+ {file = "google_cloud_storage-3.0.0.tar.gz", hash = "sha256:2accb3e828e584888beff1165e5f3ac61aa9088965eb0165794a82d8c7f95297"},
604
+ ]
605
+
606
+ [package.dependencies]
607
+ google-api-core = ">=2.15.0,<3.0.0dev"
608
+ google-auth = ">=2.26.1,<3.0dev"
609
+ google-cloud-core = ">=2.3.0,<3.0dev"
610
+ google-crc32c = ">=1.0,<2.0dev"
611
+ google-resumable-media = ">=2.7.2"
612
+ requests = ">=2.18.0,<3.0.0dev"
613
+
614
+ [package.extras]
615
+ protobuf = ["protobuf (<6.0.0dev)"]
616
+ tracing = ["opentelemetry-api (>=1.1.0)"]
617
+
618
+ [[package]]
619
+ name = "google-crc32c"
620
+ version = "1.6.0"
621
+ description = "A python wrapper of the C library 'Google CRC32C'"
622
+ optional = false
623
+ python-versions = ">=3.9"
624
+ files = [
625
+ {file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5bcc90b34df28a4b38653c36bb5ada35671ad105c99cfe915fb5bed7ad6924aa"},
626
+ {file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d9e9913f7bd69e093b81da4535ce27af842e7bf371cde42d1ae9e9bd382dc0e9"},
627
+ {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a184243544811e4a50d345838a883733461e67578959ac59964e43cca2c791e7"},
628
+ {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:236c87a46cdf06384f614e9092b82c05f81bd34b80248021f729396a78e55d7e"},
629
+ {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebab974b1687509e5c973b5c4b8b146683e101e102e17a86bd196ecaa4d099fc"},
630
+ {file = "google_crc32c-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:50cf2a96da226dcbff8671233ecf37bf6e95de98b2a2ebadbfdf455e6d05df42"},
631
+ {file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f7a1fc29803712f80879b0806cb83ab24ce62fc8daf0569f2204a0cfd7f68ed4"},
632
+ {file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:40b05ab32a5067525670880eb5d169529089a26fe35dce8891127aeddc1950e8"},
633
+ {file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e4b426c3702f3cd23b933436487eb34e01e00327fac20c9aebb68ccf34117d"},
634
+ {file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51c4f54dd8c6dfeb58d1df5e4f7f97df8abf17a36626a217f169893d1d7f3e9f"},
635
+ {file = "google_crc32c-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:bb8b3c75bd157010459b15222c3fd30577042a7060e29d42dabce449c087f2b3"},
636
+ {file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ed767bf4ba90104c1216b68111613f0d5926fb3780660ea1198fc469af410e9d"},
637
+ {file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:62f6d4a29fea082ac4a3c9be5e415218255cf11684ac6ef5488eea0c9132689b"},
638
+ {file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c87d98c7c4a69066fd31701c4e10d178a648c2cac3452e62c6b24dc51f9fcc00"},
639
+ {file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5e7d2445d1a958c266bfa5d04c39932dc54093fa391736dbfdb0f1929c1fb3"},
640
+ {file = "google_crc32c-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7aec8e88a3583515f9e0957fe4f5f6d8d4997e36d0f61624e70469771584c760"},
641
+ {file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e2806553238cd076f0a55bddab37a532b53580e699ed8e5606d0de1f856b5205"},
642
+ {file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:bb0966e1c50d0ef5bc743312cc730b533491d60585a9a08f897274e57c3f70e0"},
643
+ {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:386122eeaaa76951a8196310432c5b0ef3b53590ef4c317ec7588ec554fec5d2"},
644
+ {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2952396dc604544ea7476b33fe87faedc24d666fb0c2d5ac971a2b9576ab871"},
645
+ {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35834855408429cecf495cac67ccbab802de269e948e27478b1e47dfb6465e57"},
646
+ {file = "google_crc32c-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8797406499f28b5ef791f339594b0b5fdedf54e203b5066675c406ba69d705c"},
647
+ {file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48abd62ca76a2cbe034542ed1b6aee851b6f28aaca4e6551b5599b6f3ef175cc"},
648
+ {file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d"},
649
+ {file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05e2d8c9a2f853ff116db9706b4a27350587f341eda835f46db3c0a8c8ce2f24"},
650
+ {file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91ca8145b060679ec9176e6de4f89b07363d6805bd4760631ef254905503598d"},
651
+ {file = "google_crc32c-1.6.0.tar.gz", hash = "sha256:6eceb6ad197656a1ff49ebfbbfa870678c75be4344feb35ac1edf694309413dc"},
652
+ ]
653
+
654
+ [package.extras]
655
+ testing = ["pytest"]
656
+
657
+ [[package]]
658
+ name = "google-resumable-media"
659
+ version = "2.7.2"
660
+ description = "Utilities for Google Media Downloads and Resumable Uploads"
661
+ optional = false
662
+ python-versions = ">=3.7"
663
+ files = [
664
+ {file = "google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa"},
665
+ {file = "google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0"},
666
+ ]
667
+
668
+ [package.dependencies]
669
+ google-crc32c = ">=1.0,<2.0dev"
670
+
671
+ [package.extras]
672
+ aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "google-auth (>=1.22.0,<2.0dev)"]
673
+ requests = ["requests (>=2.18.0,<3.0.0dev)"]
674
+
675
  [[package]]
676
  name = "googleapis-common-protos"
677
  version = "1.67.0"
 
1755
  typing = ["typing-extensions"]
1756
  xmp = ["defusedxml"]
1757
 
1758
+ [[package]]
1759
+ name = "proto-plus"
1760
+ version = "1.26.0"
1761
+ description = "Beautiful, Pythonic protocol buffers"
1762
+ optional = false
1763
+ python-versions = ">=3.7"
1764
+ files = [
1765
+ {file = "proto_plus-1.26.0-py3-none-any.whl", hash = "sha256:bf2dfaa3da281fc3187d12d224c707cb57214fb2c22ba854eb0c105a3fb2d4d7"},
1766
+ {file = "proto_plus-1.26.0.tar.gz", hash = "sha256:6e93d5f5ca267b54300880fff156b6a3386b3fa3f43b1da62e680fc0c586ef22"},
1767
+ ]
1768
+
1769
+ [package.dependencies]
1770
+ protobuf = ">=3.19.0,<6.0.0dev"
1771
+
1772
+ [package.extras]
1773
+ testing = ["google-api-core (>=1.31.5)"]
1774
+
1775
  [[package]]
1776
  name = "protobuf"
1777
  version = "5.29.3"
 
3050
  [metadata]
3051
  lock-version = "2.0"
3052
  python-versions = "3.10"
3053
+ content-hash = "3cfa3697a8b8f9c2ebbcd6e6fe3a7230ed78cafe541cd534df5bbb3b0cac9654"
pyproject.toml CHANGED
@@ -14,6 +14,7 @@ pydantic-ai = "^0.0.24"
14
  python-dotenv = "^1.0.1"
15
  logfire = "^3.5.3"
16
  gradio = "^5.16.1"
 
17
 
18
 
19
  [build-system]
 
14
  python-dotenv = "^1.0.1"
15
  logfire = "^3.5.3"
16
  gradio = "^5.16.1"
17
+ google-cloud-storage = "^3.0.0"
18
 
19
 
20
  [build-system]
src/agents/mask_generation_agent.py CHANGED
@@ -9,7 +9,10 @@ import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
  from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
11
  from src.services.image_uploader import ImageUploader
12
- from src.utils import image_path_to_uri
 
 
 
13
 
14
  load_dotenv()
15
 
@@ -49,6 +52,16 @@ mask_generation_agent = Agent(
49
  deps_type=ImageEditDeps
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
52
  @mask_generation_agent.tool
53
  async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
54
  """
@@ -65,16 +78,17 @@ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
65
  mask_service = ctx.deps.mask_service
66
  hopter_client = ctx.deps.hopter_client
67
 
 
 
68
  # Generate mask
69
  mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
70
- mask = mask_service.generate_mask(mask_instruction, image_url)
71
 
72
  # Magic replace
73
- input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
74
  result = hopter_client.magic_replace(input)
75
- uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
76
- uploaded_image = uploader.upload_url(result.base64_image)
77
- return EditImageResult(edited_image_url=uploaded_image.data.url)
78
 
79
  @mask_generation_agent.tool
80
  async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
@@ -86,9 +100,8 @@ async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
86
 
87
  input = SuperResolutionInput(image_b64=image_url, scale=4, use_face_enhancement=False)
88
  result = hopter_client.super_resolution(input)
89
- uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
90
- uploaded_image = uploader.upload_url(result.scaled_image)
91
- return EditImageResult(edited_image_url=uploaded_image.data.url)
92
 
93
  async def main():
94
  image_file_path = "./assets/lakeview.jpg"
 
9
  from src.services.generate_mask import GenerateMaskService
10
  from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
11
  from src.services.image_uploader import ImageUploader
12
+ from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
13
+ import base64
14
+ import tempfile
15
+ from PIL import Image
16
 
17
  load_dotenv()
18
 
 
52
  deps_type=ImageEditDeps
53
  )
54
 
55
def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64 image, stage it in a temporary file and upload it.

    Args:
        base64_image: Either a data URI (``data:image/jpeg;base64,<payload>``)
            or a bare base64 payload without the ``data:`` header.

    Returns:
        Whatever ``upload_image`` returns for the staged file.
    """
    import os

    # Split on the first comma only; a bare payload has no header at all.
    header, separator, payload = base64_image.partition(",")
    if not separator:
        header, payload = "", base64_image

    image_data = base64.b64decode(payload)

    # The header looks like "data:image/jpeg;base64", so test for the MIME
    # type as a substring instead of comparing the whole header to it.
    suffix = ".jpg" if "image/jpeg" in header.lower() else ".png"

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    try:
        return upload_image(temp_filename)
    finally:
        # delete=False means nobody else cleans this file up for us.
        os.remove(temp_filename)
64
+
65
  @mask_generation_agent.tool
66
  async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
67
  """
 
78
  mask_service = ctx.deps.mask_service
79
  hopter_client = ctx.deps.hopter_client
80
 
81
+ image_uri = download_image_to_data_uri(image_url)
82
+
83
  # Generate mask
84
  mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
85
+ mask = mask_service.generate_mask(mask_instruction, image_uri)
86
 
87
  # Magic replace
88
+ input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
89
  result = hopter_client.magic_replace(input)
90
+ uploaded_image = upload_image_from_base64(result.base64_image)
91
+ return EditImageResult(edited_image_url=uploaded_image)
 
92
 
93
  @mask_generation_agent.tool
94
  async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
 
100
 
101
  input = SuperResolutionInput(image_b64=image_url, scale=4, use_face_enhancement=False)
102
  result = hopter_client.super_resolution(input)
103
+ uploaded_image = upload_image_from_base64(result.scaled_image)
104
+ return EditImageResult(edited_image_url=uploaded_image)
 
105
 
106
  async def main():
107
  image_file_path = "./assets/lakeview.jpg"
src/services/generate_mask.py CHANGED
@@ -7,7 +7,7 @@ import asyncio
7
  from src.hopter.client import Hopter, RamGroundedSamInput, Environment
8
  from src.models.generate_mask_instruction import GenerateMaskInstruction
9
  from src.services.openai_file_upload import OpenAIFileUpload
10
-
11
  load_dotenv()
12
 
13
  system_prompt = """
@@ -87,9 +87,10 @@ class GenerateMaskService:
87
  Returns:
88
  str: The mask image in base64 format.
89
  """
 
90
  input = RamGroundedSamInput(
91
  text_prompt=mask_instruction.subject,
92
- image_b64=image_url
93
  )
94
  generate_mask_result = self.hopter.generate_mask(input)
95
  return generate_mask_result.mask_b64
 
7
  from src.hopter.client import Hopter, RamGroundedSamInput, Environment
8
  from src.models.generate_mask_instruction import GenerateMaskInstruction
9
  from src.services.openai_file_upload import OpenAIFileUpload
10
+ from src.utils import download_image_to_data_uri
11
  load_dotenv()
12
 
13
  system_prompt = """
 
87
  Returns:
88
  str: The mask image in base64 format.
89
  """
90
+ image_uri = download_image_to_data_uri(image_url)
91
  input = RamGroundedSamInput(
92
  text_prompt=mask_instruction.subject,
93
+ image_b64=image_uri
94
  )
95
  generate_mask_result = self.hopter.generate_mask(input)
96
  return generate_mask_result.mask_b64
src/services/google_cloud_image_upload.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google.cloud import storage
2
+ from PIL import Image
3
+ from dotenv import load_dotenv
4
+ import os
5
+ import uuid
6
+ import tempfile
7
+
8
+ load_dotenv()
9
+
10
class GoogleCloudImageUploadService:
    """Resize images and upload them to a Google Cloud Storage bucket."""

    # Destination bucket for every upload.
    BUCKET_NAME = "picchat-assets"
    # Longest allowed edge, in pixels; larger images are scaled down.
    MAX_DIMENSION = 1024

    def __init__(self):
        # Using API key here as per your original code. Note that for production,
        # service account credentials are generally recommended.
        self.storage_client = storage.Client(client_options={"api_key": os.environ.get("GOOGLE_API_KEY")})

    def upload_image_to_gcs(self, source_file_name):
        """
        Uploads an image to the specified Google Cloud Storage bucket.
        Supports both JPEG and PNG formats; any other input format is
        re-encoded as JPEG.

        Args:
            source_file_name: Path of the local image file to upload.

        Returns:
            The blob's public URL, or None when any step fails (the error is
            printed rather than raised).
        """
        try:
            bucket = self.storage_client.bucket(self.BUCKET_NAME)
            blob_name = str(uuid.uuid4())
            blob = bucket.blob(blob_name)

            temp_filename = None
            try:
                with Image.open(source_file_name) as image:
                    # Keep JPEG/PNG as they are; everything else becomes JPEG.
                    original_format = image.format if image.format in ("JPEG", "PNG") else "JPEG"

                    # Resize if needed (thumbnail keeps the aspect ratio).
                    if image.width > self.MAX_DIMENSION or image.height > self.MAX_DIMENSION:
                        image.thumbnail((self.MAX_DIMENSION, self.MAX_DIMENSION))

                    # JPEG cannot store alpha or palette data, so the JPEG
                    # fallback must convert such images before saving.
                    to_save = image
                    if original_format == "JPEG" and image.mode not in ("RGB", "L", "CMYK"):
                        to_save = image.convert("RGB")

                    suffix = ".jpg" if original_format == "JPEG" else ".png"

                    # Write through the open handle instead of reopening the
                    # file by name while it is still held open.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                        temp_filename = temp_file.name
                        to_save.save(temp_file, format=original_format)

                # Set content type based on the image format.
                content_type = "image/jpeg" if original_format == "JPEG" else "image/png"
                blob.upload_from_filename(temp_filename, content_type=content_type)
                blob.make_public()
            finally:
                # Remove the temporary file even when saving or uploading failed.
                if temp_filename is not None and os.path.exists(temp_filename):
                    os.remove(temp_filename)

            print(f"File {source_file_name} uploaded to {blob_name} in bucket {self.BUCKET_NAME}.")
            return blob.public_url
        except Exception as e:
            # Deliberate best-effort: callers treat a falsy result as "no image".
            print(f"An error occurred: {e}")
            return None
60
+
61
+ if __name__ == "__main__":
62
+ image = "./assets/lakeview.jpg" # Replace with your JPEG or PNG image path.
63
+ upload_service = GoogleCloudImageUploadService()
64
+ url = upload_service.upload_image_to_gcs(image)
65
+ print(url)
src/utils.py CHANGED
@@ -1,5 +1,9 @@
1
  import base64
2
  from fastapi import UploadFile
 
 
 
 
3
  def image_path_to_base64(image_path: str) -> str:
4
  with open(image_path, "rb") as image_file:
5
  return base64.b64encode(image_file.read()).decode("utf-8")
@@ -8,4 +12,38 @@ def upload_file_to_base64(file: UploadFile) -> str:
8
  return base64.b64encode(file.file.read()).decode("utf-8")
9
 
10
  def image_path_to_uri(image_path: str) -> str:
11
- return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
  from fastapi import UploadFile
3
+ from src.services.google_cloud_image_upload import GoogleCloudImageUploadService
4
+ from PIL import Image
5
+ from urllib.request import urlopen
6
+ import io
7
  def image_path_to_base64(image_path: str) -> str:
8
  with open(image_path, "rb") as image_file:
9
  return base64.b64encode(image_file.read()).decode("utf-8")
 
12
  return base64.b64encode(file.file.read()).decode("utf-8")
13
 
14
  def image_path_to_uri(image_path: str) -> str:
15
+ return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
16
+
17
def upload_image(image_path: str) -> str:
    """
    Send a local image file to Google Cloud Storage.

    Args:
        image_path (str): The path to the image file.

    Returns:
        str: The public URL reported by the upload service.
    """
    return GoogleCloudImageUploadService().upload_image_to_gcs(image_path)
29
+
30
def download_image_to_data_uri(image_url: str) -> str:
    """Fetch an image from a URL and return it as a base64 data URI.

    The image is decoded and re-encoded in its detected format (JPEG when the
    format cannot be determined).

    Args:
        image_url: URL to fetch; passed directly to ``urlopen``.

    Returns:
        A string of the form ``data:<mime type>;base64,<payload>``.
    """
    # Read the whole body up front so the connection is closed promptly.
    with urlopen(image_url) as response:
        raw_bytes = response.read()

    with Image.open(io.BytesIO(raw_bytes)) as img:
        # Determine the image format; default to 'JPEG' if not found
        image_format = img.format if img.format is not None else "JPEG"

        # Save the image to an in-memory buffer using the detected format
        buffered = io.BytesIO()
        img.save(buffered, format=image_format)

    # Build the MIME type; for 'JPEG', use 'image/jpeg'
    mime_type = "image/jpeg" if image_format.upper() == "JPEG" else f"image/{image_format.lower()}"

    # Encode the image bytes to base64
    img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Return the data URI with the correct MIME type
    return f"data:{mime_type};base64,{img_base64}"