sunbv56 commited on
Commit
3b810d1
·
verified ·
1 Parent(s): dd04ca5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -21
app.py CHANGED
@@ -1,13 +1,11 @@
1
  import gradio as gr
2
  import asyncio
3
  import os
4
- import torch
5
- import torchvision.transforms as transforms
6
- from torchvision.utils import save_image
7
  from google import genai
8
  from google.genai import types
9
  from PIL import Image
10
  from io import BytesIO
 
11
 
12
  # Cấu hình API Key
13
  api_key = os.getenv("GEMINI_API_KEY")
@@ -16,28 +14,20 @@ if not api_key:
16
 
17
  client = genai.Client(api_key=api_key)
18
 
19
- # Load SRCNN từ Torch Hub
20
- model = torch.hub.load('pytorch/vision:v0.10.0', 'srcnn', pretrained=True)
21
- model.eval()
22
 
23
- def upscale_image(image, target_resolution=(2560, 1440)):
24
- """Nâng cấp độ phân giải ảnh bằng nội suy trước khi qua SRCNN"""
25
- image = image.resize(target_resolution, Image.BICUBIC) # Nội suy trước khi SRCNN xử lý
26
- transform = transforms.Compose([
27
- transforms.ToTensor(),
28
- transforms.Lambda(lambda x: x.unsqueeze(0)) # Thêm batch dimension
29
- ])
30
- img_tensor = transform(image)
31
- with torch.no_grad():
32
- upscaled_tensor = model(img_tensor)
33
- upscaled_image = transforms.ToPILImage()(upscaled_tensor.squeeze(0))
34
- return upscaled_image
35
 
36
  def load_image_as_bytes(image_path):
37
  """Chuyển ảnh thành dữ liệu nhị phân"""
38
  with Image.open(image_path) as img:
39
  img = img.convert("RGB") # Đảm bảo ảnh là RGB
40
- img = upscale_image(img) # SRCNN trước khi gửi đi
41
  img_bytes = BytesIO()
42
  img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
43
  return img_bytes.getvalue() # Lấy dữ liệu nhị phân
@@ -59,7 +49,7 @@ async def generate_image(image_bytes, text_input):
59
  for part in response.candidates[0].content.parts:
60
  if part.inline_data is not None:
61
  img = Image.open(BytesIO(part.inline_data.data))
62
- img = upscale_image(img) # SRCNN sau khi nhận ảnh từ Gemini
63
  images.append(img)
64
  return images
65
 
@@ -86,7 +76,7 @@ demo = gr.Interface(
86
  gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Số lượng ảnh cần tạo")
87
  ],
88
  outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4),
89
- title="Chỉnh sửa ảnh bằng Gemini AI + SRCNN",
90
  description="Upload ảnh và nhập yêu cầu chỉnh sửa. Ảnh được nâng cấp độ phân giải trước và sau khi xử lý.",
91
  )
92
 
 
1
  import gradio as gr
2
  import asyncio
3
  import os
 
 
 
4
  from google import genai
5
  from google.genai import types
6
  from PIL import Image
7
  from io import BytesIO
8
+ from super_image import RcanModel, ImageLoader
9
 
10
  # Cấu hình API Key
11
  api_key = os.getenv("GEMINI_API_KEY")
 
14
 
15
  client = genai.Client(api_key=api_key)
16
 
17
+ # Load RCAN-BAM model
18
+ model = RcanModel.from_pretrained('eugenesiow/rcan-bam', scale=2)
 
19
 
20
+ def upscale_image(image):
21
+ """Nâng cấp độ phân giải ảnh bằng RCAN-BAM"""
22
+ inputs = ImageLoader.load_image(image)
23
+ preds = model(inputs)
24
+ return ImageLoader.to_pil_image(preds)
 
 
 
 
 
 
 
25
 
26
  def load_image_as_bytes(image_path):
27
  """Chuyển ảnh thành dữ liệu nhị phân"""
28
  with Image.open(image_path) as img:
29
  img = img.convert("RGB") # Đảm bảo ảnh là RGB
30
+ img = upscale_image(img) # RCAN-BAM xử
31
  img_bytes = BytesIO()
32
  img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
33
  return img_bytes.getvalue() # Lấy dữ liệu nhị phân
 
49
  for part in response.candidates[0].content.parts:
50
  if part.inline_data is not None:
51
  img = Image.open(BytesIO(part.inline_data.data))
52
+ img = upscale_image(img) # RCAN-BAM sau khi nhận ảnh từ Gemini
53
  images.append(img)
54
  return images
55
 
 
76
  gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Số lượng ảnh cần tạo")
77
  ],
78
  outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4),
79
+ title="Chỉnh sửa ảnh bằng Gemini AI + RCAN-BAM",
80
  description="Upload ảnh và nhập yêu cầu chỉnh sửa. Ảnh được nâng cấp độ phân giải trước và sau khi xử lý.",
81
  )
82