|
import torch |
|
import gradio as gr |
|
from transformers import pipeline |
|
from typing import List, Dict, Any, Tuple |
|
import csv |
|
from io import StringIO |
|
from PIL import Image, ImageDraw, ImageFont |
|
import requests |
|
from io import BytesIO |
|
import os |
|
from pathlib import Path |
|
import logging |
|
|
|
|
|
# Directory, relative to the current working directory, that holds
# downloaded .ttf files between runs; created at import time if missing.
FONT_CACHE_DIR: Path = Path("./font_cache")
FONT_CACHE_DIR.mkdir(exist_ok=True)

# Font display name -> download URL.  The names double as the choices of the
# "Font Family" dropdown and as the cached file names ("<name>.ttf").
FONT_SOURCES: Dict[str, str] = dict(
    [
        ("Arial", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Arial.ttf"),
        ("Arial Bold", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Arial_Bold.ttf"),
        ("Arial Bold Italic", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Arial_Bold_Italic.ttf"),
        ("Arial Italic", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Arial_Italic.ttf"),
        ("Courier New", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Courier_New.ttf"),
        ("Verdana", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Verdana.ttf"),
        ("Verdana Bold", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Verdana_Bold.ttf"),
        ("Verdana Bold Italic", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Verdana_Bold_Italic.ttf"),
        ("Verdana Italic", "https://github.com/matomo-org/travis-scripts/raw/master/fonts/Verdana_Italic.ttf"),
    ]
)

# Font display name -> local .ttf path; filled in by load_and_cache_fonts().
font_cache: Dict[str, str] = dict()
|
|
|
|
|
def load_and_cache_fonts(timeout: float = 30.0) -> None:
    """Populate ``font_cache`` with a local path for every font in FONT_SOURCES.

    A non-empty ``<name>.ttf`` already in FONT_CACHE_DIR is reused.
    Otherwise the font is downloaded from its URL.  Any failure is logged and
    that font is simply left out of ``font_cache``, so a missing font never
    stops the app from starting.

    Args:
        timeout: Seconds to wait on each download before giving up.  Without
            a timeout a stalled server would block startup indefinitely.
    """
    for font_name, url in FONT_SOURCES.items():
        font_path = FONT_CACHE_DIR / f"{font_name}.ttf"

        # Zero-byte files are treated as missing: they are what an
        # interrupted write leaves behind, and they can never load as a font.
        if font_path.is_file() and font_path.stat().st_size > 0:
            font_cache[font_name] = str(font_path)
            logging.info("Loaded cached font: %s", font_name)
            continue

        # Download into a side file and rename it into place.  A partial
        # download then never sits at font_path looking like a valid cache
        # entry on the next start.
        tmp_path = font_path.with_name(font_path.name + ".part")
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            tmp_path.write_bytes(response.content)
            tmp_path.replace(font_path)
        except Exception as e:  # best effort: one bad font must not abort startup
            logging.error("Error downloading font %s: %s", font_name, e)
            try:
                tmp_path.unlink()
            except OSError:
                pass  # nothing was written, or it is already gone
        else:
            font_cache[font_name] = str(font_path)
            logging.info("Downloaded and cached font: %s", font_name)
|
|
|
|
|
|
|
load_and_cache_fonts()

# The text-generation model is loaded once at import time; generate_response()
# calls `pipe` for every feed row.  Prefer the GPU, but fall back to the CPU
# so the module can still be imported on hosts without CUDA instead of
# crashing there.
if torch.cuda.is_available():
    _PIPELINE_DEVICE = "cuda"
else:
    _PIPELINE_DEVICE = "cpu"
    logging.warning("CUDA is not available; running the text-generation model on CPU")

pipe = pipeline(
    "text-generation",
    model="alpindale/Llama-3.2-3B-Instruct",
    # NOTE(review): bfloat16 runs on recent PyTorch CPU builds but may be
    # slow there; consider float32 for the CPU fallback.
    torch_dtype=torch.bfloat16,
    device=_PIPELINE_DEVICE,
)
|
|
|
|
|
def read_feed_data(feed_text: str) -> List[Dict[str, str]]:
    """Parse delimited feed text into one dict per data row.

    The first non-blank line is the header row.  The delimiter is whichever of
    ``|``, ``,``, ``;`` or tab occurs most often in that header line; ``|``
    wins ties and is also used when none of them occurs.  Blank lines are
    skipped everywhere.

    Args:
        feed_text: Raw feed contents, header row first.

    Returns:
        A list of ``{header: value}`` dicts in feed order.  The list is empty
        when *feed_text* is empty, blank, or holds only a header.  A row
        shorter than the header yields a dict without the trailing keys, and
        extra values beyond the header are dropped.
    """
    if not feed_text or not feed_text.strip():
        return []

    # Choose the delimiter from the header line only.  max() returns the first
    # candidate among equal counts, so "|" is the tie-breaker.
    header_line = next(
        (line.strip() for line in feed_text.splitlines() if line.strip()), ""
    )
    delimiter = max(["|", ",", ";", "\t"], key=header_line.count)

    # Skip rows with no content.  Otherwise they would become {} (or
    # whitespace-only) records that break str.format() on the prompt.
    rows = (
        row
        for row in csv.reader(StringIO(feed_text), delimiter=delimiter)
        if any(cell.strip() for cell in row)
    )
    headers = next(rows, None)
    if headers is None:
        return []
    return [dict(zip(headers, row)) for row in rows]
|
|
|
|
|
def _fetch_image(image_url: str, timeout: float) -> Image.Image:
    """Download *image_url* and open the response body with Pillow."""
    if not image_url:
        raise ValueError("image_url must be a non-empty URL")
    response = requests.get(image_url, timeout=timeout)
    # Fail on 4xx/5xx here rather than letting Pillow choke on an HTML page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def _load_font(font_family: str, font_size: float):
    """Load *font_family* at *font_size* px, falling back to Pillow's default.

    Cached font paths from ``font_cache`` take precedence.  Any other name is
    passed to ``ImageFont.truetype`` as given.
    """
    # gr.Number delivers floats; older Pillow rejects float sizes, and a size
    # below 1 is invalid.
    size = max(1, int(round(font_size)))
    try:
        return ImageFont.truetype(font_cache.get(font_family, font_family), size)
    except OSError:
        logging.warning("Failed to load font %s, using default", font_family)
        try:
            # Pillow >= 10.1 can render its built-in font at a given size.
            return ImageFont.load_default(size=size)
        except TypeError:
            return ImageFont.load_default()


def _parse_font_color(font_color: str):
    """Convert *font_color* into a Pillow fill value.

    An ``rgba(r, g, b, a)`` string, with *a* an opacity in 0-1, becomes an
    ``(r, g, b, a)`` tuple with every channel clamped to 0-255.  A malformed
    rgba string falls back to ``"#FFFFFF"``.  Any other string (for example
    ``"#FF0000"``) is returned unchanged for Pillow to interpret.
    """
    if not font_color.startswith("rgba"):
        return font_color
    try:
        r, g, b, alpha = [float(p) for p in font_color.strip("rgba()").split(",")][:4]
        return tuple(max(0, min(255, int(v))) for v in (r, g, b, alpha * 255))
    except (ValueError, IndexError, OverflowError):
        logging.warning(
            f"Invalid RGBA color format: {font_color}, falling back to white"
        )
        return "#FFFFFF"


def overlay_text_on_image(
    image_url: str,
    text: str,
    position: Tuple[int, int],
    font_size: int,
    font_color: str,
    font_family: str,
    timeout: float = 30.0,
) -> Image.Image:
    """Download an image and draw *text* on top of it.

    Args:
        image_url: HTTP(S) URL of the source image.
        text: Text to draw.
        position: (x, y) pixel position of the text.  Floats are rounded.
        font_size: Font size in pixels.  Rounded, and clamped to at least 1.
        font_color: A colour string Pillow understands (e.g. ``"#FFFFFF"``)
            or ``"rgba(r, g, b, a)"`` with *a* an opacity in 0-1.  The
            opacity is honoured.
        font_family: Key of ``font_cache``, or a font name/path that
            ``ImageFont.truetype`` can resolve.  Falls back to Pillow's
            default font when loading fails.
        timeout: Seconds to wait for the image download.

    Returns:
        The annotated image.  It is RGBA when the source image has
        transparency and RGB otherwise.

    Raises:
        ValueError: If *image_url* is empty.
        requests.RequestException: If the download fails or returns an HTTP
            error status.
        PIL.UnidentifiedImageError: If the downloaded body is not an image.
    """
    source = _fetch_image(image_url, timeout)
    font = _load_font(font_family, font_size)
    fill = _parse_font_color(font_color)
    x, y = position

    # Draw on a transparent layer and alpha-composite it onto the image.
    # This way a semi-transparent fill is blended with the picture instead of
    # having its alpha ignored, whatever the source image's mode.
    base = source.convert("RGBA")
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    ImageDraw.Draw(layer).text(
        (int(round(x)), int(round(y))), text, font=font, fill=fill
    )
    result = Image.alpha_composite(base, layer)

    keeps_alpha = "A" in source.getbands() or "transparency" in source.info
    return result if keeps_alpha else result.convert("RGB")
|
|
|
|
|
def generate_response(
    prompt: str,
    feed_text: str,
    text_x: int = 10,
    text_y: int = 10,
    font_size: int = 24,
    font_color: str = "#FFFFFF",
    font_family: str = "Arial",
    max_new_tokens: int = 256,
    temperature: float = 0.7,
) -> List[Image.Image]:
    """Generate text for every feed row and draw it on that row's image.

    For each row from ``read_feed_data``, the ``{column}`` placeholders in
    *prompt* are filled via ``str.format`` and the result is sent to the chat
    model ``pipe``.  The model's reply is then overlaid on the image at the
    row's ``image_link``.

    Args:
        prompt: Prompt template whose ``{column}`` placeholders name feed
            columns.
        feed_text: Raw delimited feed text, header row first.
        text_x: Horizontal pixel position of the overlaid text.
        text_y: Vertical pixel position of the overlaid text.
        font_size: Font size in pixels.
        font_color: Text colour, forwarded to ``overlay_text_on_image``.
        font_family: Font name, forwarded to ``overlay_text_on_image``.
        max_new_tokens: Maximum number of tokens the model may generate.
        temperature: Sampling temperature.  Sampling is enabled explicitly so
            this value takes effect.

    Returns:
        One annotated image per feed row, in feed order.

    Raises:
        ValueError: If *prompt* names a placeholder that is not a feed column.
    """
    system_prompt = "You are a helpful assistant that processes Meta Product Feeds."
    # Gradio number/slider widgets may deliver floats; the generation and
    # drawing code expect ints.
    position = (int(text_x), int(text_y))
    font_size = int(font_size)
    max_new_tokens = int(max_new_tokens)

    images = []
    for feed_data in read_feed_data(feed_text):
        try:
            formatted_prompt = prompt.format(**feed_data)
        except KeyError as err:
            raise ValueError(
                f"Prompt placeholder {err} does not match any feed column; "
                f"available columns: {', '.join(feed_data)}"
            ) from err
        logging.debug("Formatted prompt: %s", formatted_prompt)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": formatted_prompt},
        ]
        outputs = pipe(
            messages,
            max_new_tokens=max_new_tokens,
            # temperature is ignored by greedy decoding, so sample explicitly.
            do_sample=True,
            temperature=temperature,
        )

        # For chat-style (list of messages) input, "generated_text" holds the
        # message list, and the model's reply is its last entry.
        conversation = outputs[0]["generated_text"]
        generated_text = str(conversation[-1]["content"]) if conversation else ""

        images.append(
            overlay_text_on_image(
                image_url=feed_data.get("image_link", ""),
                text=generated_text,
                position=position,
                font_size=font_size,
                font_color=font_color,
                font_family=font_family,
            )
        )

    return images
|
|
|
|
|
|
|
# Gradio UI: the input widgets map positionally onto generate_response's
# parameters, and the resulting images are shown in a two-column gallery.
demo = gr.Interface(
    title="Meta Product Feed Chat",
    description="Chat with Llama 3.2 model using feed data. Use {field_name} in your prompt to include feed data. The feed should be in CSV format with headers in the first row.",
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Enter your prompt (use {field_name} for feed data)", lines=3),
        gr.Textbox(
            label="Feed data (CSV with auto-detected delimiter)", lines=10, value=""
        ),
        # precision=0 makes gradio round the value and pass an int, so the
        # overlay code never receives floats (older Pillow releases reject
        # float font sizes).
        gr.Number(label="Text X Position", value=10, precision=0),
        gr.Number(label="Text Y Position", value=10, precision=0),
        gr.Number(label="Font Size", value=24, precision=0),
        gr.ColorPicker(label="Font Color", value="#FFFFFF"),
        gr.Dropdown(
            label="Font Family",
            choices=list(FONT_SOURCES.keys()),
            value="Arial",
        ),
        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    ],
    outputs=[
        gr.Gallery(label="Product Images with Text", columns=2),
    ],
)

if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the
    # local server.
    demo.launch(share=True)
|
|