nursulu committed
Commit · ceec8fc
1 Parent(s): 6c696fb

add functions
Files changed:
- app.py +85 -8
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/image_utils.cpython-311.pyc +0 -0
- utils/__pycache__/model_utils.cpython-311.pyc +0 -0
- utils/image_utils.py +148 -0
- utils/model_utils.py +79 -0
app.py
CHANGED
@@ -1,17 +1,94 @@
 import streamlit as st
 from PIL import Image
+import base64
+import requests
+import json
+import os
+import re
+import torch
+from peft import PeftModel, PeftConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import argparse
 import io
 
-
-
+from utils.model_utils import get_model_caption
+from utils.image_utils import overlay_caption
 
+@st.cache_resource
+def load_models():
+    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+    model_angry = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_angry")
+    model_happy = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_happy")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    base_model.to(device)
+    model_happy.to(device)
+    model_angry.to(device)
+
+    # Load the adapters for specific moods
+    base_model.load_adapter("NursNurs/outputs_gemma2b_happy", "happy")
+    base_model.load_adapter("NursNurs/outputs_gemma2b_angry", "angry")
+
+    return base_model, tokenizer, model_happy, model_angry, device
+
+# x = st.slider('Select a value')
+# st.write(x, 'squared is', x * x)
+def generate_meme_from_image(img_path, base_model, tokenizer, hf_token, output_dir, device='cuda'):
+    caption = get_model_caption(img_path, base_model, tokenizer, hf_token)
+    image = overlay_caption(caption, img_path, output_dir)
+    return image, caption
 
 st.title("Image Upload and Processing App")
 
-# Upload the image
-uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
 
-
-
-
-
+def main():
+    st.title("Meme Generator with Mood")
+
+    base_model, tokenizer, model_happy, model_angry, device = load_models()
+
+    # Input widget to upload an image
+    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
+
+    # Input widget to add Hugging Face token
+    hf_token = st.text_input("Enter your Hugging Face Token", type="password")
+
+    # Dropdown to select mood
+    # mood = st.selectbox("Select Mood", options=["happy", "angry"])
+
+    # Directory for saving the meme (optional, but you can let users set this if needed)
+    output_dir = "results"
+
+    if uploaded_image is not None and hf_token:
+        # The downstream helpers expect a file path, so save the upload to disk first
+        os.makedirs(output_dir, exist_ok=True)
+        img_path = os.path.join(output_dir, uploaded_image.name)
+        Image.open(uploaded_image).save(img_path)
+
+        # Generate meme when button is pressed
+        if st.button("Generate Meme"):
+            with st.spinner('Generating meme...'):
+                image, caption = generate_meme_from_image(img_path, base_model, tokenizer, hf_token, output_dir, device)
+
+            # Display the output
+            st.image(image, caption=f"Generated Meme: {caption}")
+
+            # Optionally allow downloading the meme
+            buf = io.BytesIO()
+            image.save(buf, format="PNG")
+            byte_im = buf.getvalue()
+
+            st.download_button(
+                label="Download Meme",
+                data=byte_im,
+                file_name="generated_meme.png",
+                mime="image/png"
+            )
+
+if __name__ == '__main__':
+    main()
+# # Upload the image
+# uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
+
+# # Process and display if image is uploaded
+# if uploaded_image is not None:
+#     image = Image.open(uploaded_image)
+#     st.image(image, caption="Uploaded Image", use_column_width=True)
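For context, here is a minimal sketch of exercising the same pipeline outside the Streamlit UI. This is not part of the commit: "demo.jpg" is a hypothetical test image, HF_TOKEN is assumed to hold a valid Hugging Face API token, and calling a Streamlit-cached function as plain Python may emit harmless runtime warnings:

    # Hypothetical smoke test for the meme pipeline (not in the commit).
    import os
    from app import load_models, generate_meme_from_image

    os.makedirs("results", exist_ok=True)
    base_model, tokenizer, model_happy, model_angry, device = load_models()
    image, caption = generate_meme_from_image(
        "demo.jpg",              # hypothetical local test image
        base_model,
        tokenizer,
        os.environ["HF_TOKEN"],  # your Hugging Face API token
        "results",               # output_dir, matching the app default
        device,
    )
    print("Caption:", caption)

Inside the Space the same flow is driven by streamlit run app.py.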
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (236 Bytes)

utils/__pycache__/image_utils.cpython-311.pyc
ADDED
Binary file (6.98 kB)

utils/__pycache__/model_utils.cpython-311.pyc
ADDED
Binary file (5.05 kB)
utils/image_utils.py
ADDED
@@ -0,0 +1,148 @@
+import os
+import re
+from PIL import Image, ImageDraw, ImageFont
+import textwrap
+
+
+def get_unique_filename(filename):
+    """
+    Generate a unique filename by appending a number if a file with the same name already exists.
+    """
+    if not os.path.exists(filename):
+        return filename
+
+    base, ext = os.path.splitext(filename)
+    counter = 1
+    new_filename = f"{base}_{counter}{ext}"
+
+    while os.path.exists(new_filename):
+        counter += 1
+        new_filename = f"{base}_{counter}{ext}"
+
+    return new_filename
+
+
+def save_image_with_unique_name(image, path):
+    unique_path = get_unique_filename(path)
+    image.save(unique_path)
+    print(f"Image saved as: {unique_path}")
+
+def find_text_in_answer(text):
+    print("Full caption:", text)
+    text = text.split("Caption:")[1]
+    text = text.replace("\n", "")
+    text = text.replace("model", "")
+    # Remove everything that looks like <>
+    text = re.sub(r'<[^>]*>', '', text)
+
+    # Remove non-alphanumeric characters (keeping spaces)
+    text = re.sub(r'[^a-zA-Z0-9\?\!\s]', '', text)
+    print("Filtered caption:", text)
+    if text:
+        return text
+    else:
+        return "Me when I couldn't parse the model's answer but I still want you to smile :)"
+
+
+def draw_text(draw, text, position, font, max_width, outline_color="black", text_color="white", outline_width=2):
+    """
+    Draw text on the image with an outline, splitting it into lines if necessary and returning the total height used by the text.
+    The text is horizontally centered in the specified max_width.
+    """
+    print("Adding the caption on the image...")
+
+    # Split the text into multiple lines based on the max width
+    lines = []
+    words = text.split()
+    line = ''
+    for word in words:
+        test_line = f'{line} {word}'.strip()
+        bbox = draw.textbbox((0, 0), test_line, font=font)
+        width = bbox[2] - bbox[0]  # Width of the text
+        if width <= max_width:
+            line = test_line
+        else:
+            if line:  # Avoid appending empty lines
+                lines.append(line)
+            line = word
+    if line:
+        lines.append(line)
+
+    y = position[1]
+
+    # Draw the text with an outline (black) first, centered horizontally
+    for line in lines:
+        # Calculate the width of the line and adjust the x position to center it
+        bbox = draw.textbbox((0, 0), line, font=font)
+        line_width = bbox[2] - bbox[0]
+        x = (max_width - line_width) // 2 + position[0]
+
+        # Draw the outline by drawing the text multiple times around the original position
+        for offset_x in [-outline_width, 0, outline_width]:
+            for offset_y in [-outline_width, 0, outline_width]:
+                if offset_x != 0 or offset_y != 0:
+                    draw.text((x + offset_x, y + offset_y), line, font=font, fill=outline_color)
+
+        # Draw the main text (white) on top of the outline
+        draw.text((x, y), line, font=font, fill=text_color)
+        y += bbox[3] - bbox[1]  # Update y position based on line height
+
+    return y - position[1]  # Return the total height used by the text
+
+def calculate_text_height(caption, font, max_width):
+    """
+    Calculate the height of the text when drawn, given the caption, font, and maximum width.
+    """
+    image = Image.new('RGB', (max_width, 1))
+    draw = ImageDraw.Draw(image)
+    return draw_text(draw, caption, (0, 0), font, max_width)
+
+def add_caption(image_path, caption, output_path, top_margin=10, bottom_margin=10, max_caption_length=10, min_distance_from_bottom_mm=10):
+    image = Image.open(image_path)
+    draw = ImageDraw.Draw(image)
+    width, height = image.size
+
+    # Convert mm to pixels (assuming 96 DPI)
+    dpi = 96
+    min_distance_from_bottom_px = min_distance_from_bottom_mm * dpi / 25.4
+
+    # Split the caption into two parts if it is too long
+    if len(caption.split()) > max_caption_length:
+        font_size = 20
+        total_len = len(caption.split())
+        mid = int(total_len / 2)
+
+        top_caption = caption.split()[:mid]
+        bottom_caption = caption.split()[mid:]
+
+        top_caption = " ".join(top_caption)
+        bottom_caption = " ".join(bottom_caption)
+    else:
+        top_caption = ""
+        bottom_caption = caption
+        font_size = 30
+
+    # Load a font
+    font = ImageFont.truetype(r"fonts/Anton/Anton-Regular.ttf", font_size)
+
+    # Top caption
+    top_caption_position = (width // 10, top_margin)
+    draw_text(draw, top_caption, top_caption_position, font, width - 2 * (width // 10))
+
+    # Bottom caption
+    if bottom_caption:  # Draw bottom caption only if it's not empty
+        # Calculate the height of the bottom caption
+        bottom_caption_height = calculate_text_height(bottom_caption, font, width - 2 * (width // 10))
+        bottom_caption_position = (width // 10, height - min_distance_from_bottom_px - bottom_caption_height)
+        draw_text(draw, bottom_caption, bottom_caption_position, font, width - 2 * (width // 10))
+
+    save_image_with_unique_name(image, output_path)
+    return image
+
+
+def overlay_caption(text, img_path, output_dir):
+    img_name = img_path.split("/")[-1]
+    text = find_text_in_answer(text)
+    text = text.strip(".")
+    image = add_caption(img_path, text, output_dir + "/" + img_name)
+    return image
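Saving always goes through get_unique_filename, so repeated runs never overwrite earlier memes. A quick, self-contained illustration (a sketch; the files below are created only for demonstration):

    # Hypothetical demonstration of get_unique_filename (not in the commit).
    from utils.image_utils import get_unique_filename

    open("meme.png", "w").close()           # simulate an existing result
    print(get_unique_filename("meme.png"))  # -> meme_1.png
    open("meme_1.png", "w").close()
    print(get_unique_filename("meme.png"))  # -> meme_2.png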
utils/model_utils.py
ADDED
@@ -0,0 +1,79 @@
+import base64
+import requests
+import json
+import pandas as pd
+import os
+from tqdm import tqdm
+import re
+import torch
+
+
+
+def query_clip(data, hf_token):
+    API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-base-patch32"
+    headers = {"Authorization": f"Bearer {hf_token}"}
+    with open(data["image_path"], "rb") as f:
+        img = f.read()
+    payload = {
+        "parameters": data["parameters"],
+        "inputs": base64.b64encode(img).decode("utf-8")
+    }
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
+
+def get_sentiment(img_path, hf_token):
+    print("Getting the sentiment of the image...")
+    output = query_clip({
+        "image_path": img_path,
+        "parameters": {"candidate_labels": ["angry", "happy"]},
+    }, hf_token)
+    try:
+        print("Sentiment:", output[0]['label'])
+        return output[0]['label']
+    except Exception:
+        print(output)
+        print("If the model is loading, try again in a minute. If you've reached a query limit (300 per hour), try again within the next hour.")
+
+
+def query_blip(filename, hf_token):
+    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
+    headers = {"Authorization": f"Bearer {hf_token}"}
+    with open(filename, "rb") as f:
+        file = f.read()
+    response = requests.post(API_URL, headers=headers, data=file)
+    return response.json()
+
+
+def get_description(img_path, hf_token):
+    print("Getting the context of the image...")
+    output = query_blip(img_path, hf_token)
+
+    try:
+        print("Context:", output[0]['generated_text'])
+        return output[0]['generated_text']
+    except Exception:
+        print(output)
+        print("The model is not available right now due to query limits. Try running again now or within the next hour.")
+
+
+def get_model_caption(img_path, base_model, tokenizer, hf_token, device='cuda'):
+    sentiment = get_sentiment(img_path, hf_token)
+    description = get_description(img_path, hf_token)
+
+    prompt_template = """
+    Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
+    You are given a topic. Your task is to generate a meme caption based on the topic. Only output the meme caption and nothing more.
+    Topic: {query}
+    <end_of_turn>\n<start_of_turn>model Caption:
+    """
+    prompt = prompt_template.format(query=description)
+
+    print("Generating captions...")
+    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+    model_inputs = encodeds.to(device)
+    base_model.set_adapter(sentiment)
+    base_model.to(device)
+    generated_ids = base_model.generate(**model_inputs, max_new_tokens=20, do_sample=True, pad_token_id=tokenizer.eos_token_id)
+    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    return decoded
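To make the contract between the two utils modules concrete, here is how a decoded Gemma output is expected to flow through find_text_in_answer from utils/image_utils.py (a sketch; the decoded string is invented for illustration, not real model output):

    # Hypothetical example of the caption-parsing step (not in the commit).
    from utils.image_utils import find_text_in_answer

    decoded = (
        "Topic: a cat sitting on a laptop\n"
        "<start_of_turn>model Caption: When you finish the report! <end_of_turn>"
    )
    caption = find_text_in_answer(decoded)
    print(caption)  # prints the cleaned caption, e.g. "When you finish the report!"

Everything before the "Caption:" marker is dropped, tag-like spans such as <end_of_turn> are removed, and only alphanumerics, whitespace, "?" and "!" survive, which is why overlay_caption can draw the result directly onto the image.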