sunbv56 commited on
Commit
e79e752
·
verified ·
1 Parent(s): 73aaa23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import gradio as gr
2
  import asyncio
3
  import os
 
 
 
 
4
  from google import genai
5
  from google.genai import types
6
  from PIL import Image
7
  from io import BytesIO
8
- from super_image import RcanModel, ImageLoader
9
 
10
  # Cấu hình API Key
11
  api_key = os.getenv("GEMINI_API_KEY")
@@ -14,20 +17,43 @@ if not api_key:
14
 
15
  client = genai.Client(api_key=api_key)
16
 
17
- # Load RCAN-BAM model
18
- model = RcanModel.from_pretrained('eugenesiow/rcan-bam', scale=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def upscale_image(image):
21
- """Nâng cấp độ phân giải ảnh bằng RCAN-BAM"""
22
- inputs = ImageLoader.load_image(image)
23
- preds = model(inputs)
24
- return ImageLoader.to_pil_image(preds)
 
 
 
 
 
 
25
 
26
  def load_image_as_bytes(image_path):
27
  """Chuyển ảnh thành dữ liệu nhị phân"""
28
  with Image.open(image_path) as img:
29
  img = img.convert("RGB") # Đảm bảo ảnh là RGB
30
- img = upscale_image(img) # RCAN-BAM xử lý
31
  img_bytes = BytesIO()
32
  img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
33
  return img_bytes.getvalue() # Lấy dữ liệu nhị phân
@@ -49,7 +75,7 @@ async def generate_image(image_bytes, text_input):
49
  for part in response.candidates[0].content.parts:
50
  if part.inline_data is not None:
51
  img = Image.open(BytesIO(part.inline_data.data))
52
- img = upscale_image(img) # RCAN-BAM sau khi nhận ảnh từ Gemini
53
  images.append(img)
54
  return images
55
 
@@ -76,7 +102,7 @@ demo = gr.Interface(
76
  gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Số lượng ảnh cần tạo")
77
  ],
78
  outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4),
79
- title="Chỉnh sửa ảnh bằng Gemini AI + RCAN-BAM",
80
  description="Upload ảnh và nhập yêu cầu chỉnh sửa. Ảnh được nâng cấp độ phân giải trước và sau khi xử lý.",
81
  )
82
 
 
1
  import gradio as gr
2
  import asyncio
3
  import os
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ import torchvision.transforms as transforms
8
  from google import genai
9
  from google.genai import types
10
  from PIL import Image
11
  from io import BytesIO
 
12
 
13
  # Cấu hình API Key
14
  api_key = os.getenv("GEMINI_API_KEY")
 
17
 
18
  client = genai.Client(api_key=api_key)
19
 
20
+ # Định nghĩa mô hình SRCNN
21
+ class SRCNN(nn.Module):
22
+ def __init__(self):
23
+ super(SRCNN, self).__init__()
24
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
25
+ self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
26
+ self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
27
+ self.relu = nn.ReLU()
28
+
29
+ def forward(self, x):
30
+ x = self.relu(self.conv1(x))
31
+ x = self.relu(self.conv2(x))
32
+ x = self.conv3(x)
33
+ return x
34
+
35
+ # Khởi tạo mô hình SRCNN
36
+ model = SRCNN()
37
+ model.load_state_dict(torch.load("srcnn.pth", map_location=torch.device('cpu')))
38
+ model.eval()
39
 
40
  def upscale_image(image):
41
+ """Nâng cấp độ phân giải ảnh bằng SRCNN"""
42
+ transform = transforms.Compose([
43
+ transforms.ToTensor(),
44
+ transforms.Lambda(lambda x: x.unsqueeze(0)) # Thêm batch dimension
45
+ ])
46
+ input_tensor = transform(image)
47
+ with torch.no_grad():
48
+ output_tensor = model(input_tensor)
49
+ output_image = transforms.ToPILImage()(output_tensor.squeeze(0))
50
+ return output_image
51
 
52
  def load_image_as_bytes(image_path):
53
  """Chuyển ảnh thành dữ liệu nhị phân"""
54
  with Image.open(image_path) as img:
55
  img = img.convert("RGB") # Đảm bảo ảnh là RGB
56
+ img = upscale_image(img) # SRCNN xử lý
57
  img_bytes = BytesIO()
58
  img.save(img_bytes, format="JPEG") # Lưu ảnh vào buffer
59
  return img_bytes.getvalue() # Lấy dữ liệu nhị phân
 
75
  for part in response.candidates[0].content.parts:
76
  if part.inline_data is not None:
77
  img = Image.open(BytesIO(part.inline_data.data))
78
+ img = upscale_image(img) # SRCNN sau khi nhận ảnh từ Gemini
79
  images.append(img)
80
  return images
81
 
 
102
  gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Số lượng ảnh cần tạo")
103
  ],
104
  outputs=gr.Gallery(label="Kết quả chỉnh sửa", columns=4),
105
+ title="Chỉnh sửa ảnh bằng Gemini AI + SRCNN",
106
  description="Upload ảnh và nhập yêu cầu chỉnh sửa. Ảnh được nâng cấp độ phân giải trước và sau khi xử lý.",
107
  )
108