pip install transformers gradio Pillow requests torch sentencepiece

import io
import os

import requests
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import MarianMTModel, MarianTokenizer, AutoModelForCausalLM, AutoTokenizer
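
# Run the models on GPU when available, otherwise fall back to CPU.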
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
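
# Tamil-to-English translation model (MarianMT, multilingual-to-English).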
model_name = "Helsinki-NLP/opus-mt-mul-en"
translation_model = MarianMTModel.from_pretrained(model_name).to(device)
translation_tokenizer = MarianTokenizer.from_pretrained(model_name)
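
# Causal language model used to expand the translation into creative text.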
text_generation_model_name = "EleutherAI/gpt-neo-1.3B"
text_generation_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name).to(device)
text_generation_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
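
# GPT-Neo has no pad token by default; add one so padded inputs tokenize cleanly.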
if text_generation_tokenizer.pad_token is None:
    text_generation_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # Resize the embeddings so the newly added pad token has an embedding row.
    text_generation_model.resize_token_embeddings(len(text_generation_tokenizer))
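
# Replace 'Your_HF_TOKEN' with your Hugging Face access token, or set HF_API_KEY
# in the environment beforehand and remove the hard-coded assignment.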
os.environ['HF_API_KEY'] = 'Your_HF_TOKEN'
api_key = os.getenv('HF_API_KEY')
if api_key is None:
    raise ValueError("Hugging Face API key is not set. Please set it in your environment.")

headers = {"Authorization": f"Bearer {api_key}"}
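
# Hugging Face Inference API endpoint for the FLUX.1-schnell text-to-image model.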
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
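
# POST the prompt to the Inference API and return the raw image bytes, or None on failure.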
def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code != 200:
        print(f"Error: Received status code {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.content
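
# Translate Tamil text to English with the MarianMT model.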
def translate_text(tamil_text):
    inputs = translation_tokenizer(tamil_text, return_tensors="pt", padding=True, truncation=True).to(device)
    translated_tokens = translation_model.generate(**inputs)
    translation = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    return translation
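
# Generate an image for the prompt; return a red placeholder image if the request
# or the image decoding fails.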
def generate_image(prompt):
    image_bytes = query({"inputs": prompt})

    if image_bytes is None:
        error_img = Image.new('RGB', (300, 300), color=(255, 0, 0))
        d = ImageDraw.Draw(error_img)
        d.text((10, 150), "Image Generation Failed", fill=(255, 255, 255))
        return error_img

    try:
        image = Image.open(io.BytesIO(image_bytes))
        return image
    except Exception as e:
        print(f"Error: {e}")
        error_img = Image.new('RGB', (300, 300), color=(255, 0, 0))
        d = ImageDraw.Draw(error_img)
        d.text((10, 150), "Invalid Image Data", fill=(255, 255, 255))
        return error_img
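
# Continue the translated text with GPT-Neo to produce a short creative passage.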
def generate_creative_text(translated_text):
    inputs = text_generation_tokenizer(translated_text, return_tensors="pt", padding=True, truncation=True).to(device)
    generated_tokens = text_generation_model.generate(**inputs, max_length=100)
    creative_text = text_generation_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    return creative_text
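
# Full pipeline: translate the Tamil input, then generate an image and creative text
# from the English translation.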
def translate_generate_image_and_text(tamil_text):
    translated_text = translate_text(tamil_text)
    image = generate_image(translated_text)
    creative_text = generate_creative_text(translated_text)
    return translated_text, creative_text, image
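
# Custom CSS for the Gradio interface.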
css = """
#transart-title {
    font-size: 2.5em;
    font-weight: bold;
    color: #4CAF50;
    text-align: center;
    margin-bottom: 10px;
}
#transart-subtitle {
    font-size: 1.25em;
    text-align: center;
    color: #555555;
    margin-bottom: 20px;
}
body {
    background-color: #f0f0f5;
}
.gradio-container {
    font-family: 'Arial', sans-serif;
}
"""
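
# Title and subtitle shown at the top of the app (rendered as Markdown/HTML).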
title_markdown = """
# <div id="transart-title">TransArt</div>
### <div id="transart-subtitle">Tamil to English Translation, Creative Text & Image Generation</div>
"""
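
# Gradio UI: Tamil input on the left; translation, creative text, and generated image on the right.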
with gr.Blocks(css=css) as interface:
    gr.Markdown(title_markdown)
    with gr.Row():
        with gr.Column():
            tamil_input = gr.Textbox(label="Enter Tamil Text", placeholder="Type Tamil text here...", lines=3)
        with gr.Column():
            translated_output = gr.Textbox(label="Translated Text", interactive=False)
            creative_text_output = gr.Textbox(label="Creative Generated Text", interactive=False)
            generated_image_output = gr.Image(label="Generated Image")

    gr.Button("Generate").click(
        fn=translate_generate_image_and_text,
        inputs=tamil_input,
        outputs=[translated_output, creative_text_output, generated_image_output],
    )
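
# Launch the app; server_name="0.0.0.0" exposes it on all network interfaces.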
interface.launch(debug=True, server_name="0.0.0.0")