Spaces:
Running
Running
from document_to_gloss import DocumentToASLConverter | |
from document_parsing import DocumentParser | |
from vectorizer import Vectorizer | |
from video_gen import create_multi_stitched_video | |
import gradio as gr | |
import asyncio | |
import re | |
import boto3 | |
import os | |
from botocore.config import Config | |
from dotenv import load_dotenv | |
import requests | |
import tempfile | |
import uuid | |
import base64 | |
# Populate os.environ from a local .env file, if one is present.
load_dotenv()

# Cloudflare R2 (S3-compatible) settings; each is None when unset.
R2_ASL_VIDEOS_URL = os.getenv("R2_ASL_VIDEOS_URL")
R2_ENDPOINT = os.getenv("R2_ENDPOINT")
R2_ACCESS_KEY_ID = os.getenv("R2_ACCESS_KEY_ID")
R2_SECRET_ACCESS_KEY = os.getenv("R2_SECRET_ACCESS_KEY")

# Fail fast at import time. An unset *or empty* value counts as missing.
if any(not value for value in (R2_ASL_VIDEOS_URL, R2_ENDPOINT,
                               R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY)):
    raise ValueError(
        "Missing required R2 environment variables. "
        "Please check your .env file."
    )
# Page text shown by the Gradio interface (header, subtitle, footer HTML).
title = "AI-SL"
description = "Convert text to ASL!"
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/deenasun' target='_blank'>"
    "Deena Sun on Github</a></p>"
)

# Standalone component definitions.
# NOTE(review): create_interface() builds its own input/output components and
# does not reference `inputs` or `outputs`; confirm whether these are still used.
inputs = gr.File(label="Upload Document (pdf, txt, docx, or epub)")
outputs = [gr.JSON(label="Processing Results"), gr.Video(label="ASL Video Output"), gr.HTML(label="Download Link")]
# Long-lived helpers, created once at import time and shared by all requests.
# NOTE(review): `parser` is not referenced by the code visible in this file.
parser = DocumentParser()
asl_converter = DocumentToASLConverter()
vectorizer = Vectorizer()

# S3-compatible client pointed at the R2 endpoint: SigV4 signing, region "auto",
# credentials taken from the environment variables read earlier in the module.
session = boto3.session.Session()
s3 = session.client(
    service_name="s3",
    region_name="auto",
    endpoint_url=R2_ENDPOINT,
    aws_access_key_id=R2_ACCESS_KEY_ID,
    aws_secret_access_key=R2_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)
def clean_gloss_token(token):
    """
    Normalize a single gloss token for lookup.

    Characters that are neither word characters nor whitespace are removed.
    The text is lower-cased, and runs of whitespace collapse to one space.

    Returns:
        The normalized token, or None when the input is falsy or nothing is
        left after normalization.
    """
    if not token:
        return None
    without_punct = re.sub(r'[^\w\s]', '', token)
    collapsed = re.sub(r'\s+', ' ', without_punct.lower().strip()).strip()
    return collapsed or None
def verify_video_format(video_path):
    """
    Verify that a video file is in a browser-compatible format (H.264 MP4).

    Only the codec FourCC reported by OpenCV is inspected. The container
    format is not checked.

    Args:
        video_path: Path to the video file.

    Returns:
        A tuple ``(is_compatible, message)``. ``is_compatible`` is True when the
        FourCC is an H.264 tag (``avc1`` or ``h264``, compared
        case-insensitively). Any failure yields False with a descriptive
        message, including OpenCV (``cv2``) not being installed.
    """
    cap = None
    try:
        import cv2  # imported lazily so the module loads without OpenCV
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Could not open video file"
        fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
        # The FourCC is packed with character i in byte i (lowest byte first).
        codec = "".join(chr((fourcc >> 8 * i) & 0xFF) for i in range(4))
    except Exception as e:
        return False, f"Error checking video format: {e}"
    finally:
        # Release the capture on every path, including not-opened and errors.
        if cap is not None:
            cap.release()

    # OpenCV may report the H.264 tag in either case (e.g. 'avc1' / 'AVC1').
    if codec.lower() in ('avc1', 'h264'):
        return True, f"Video is H.264 encoded ({codec})"
    return False, f"Video codec {codec} may not be browser compatible"
def upload_video_to_r2(video_path, bucket_name="asl-videos"):
    """
    Upload a video file to R2 and return its public URL.

    The object is stored under a random UUID key that keeps the source file's
    extension. It is uploaded with a ``public-read`` ACL, a 24-hour
    ``Cache-Control`` and ``Content-Disposition: inline``.

    Args:
        video_path: Local path of the video to upload.
        bucket_name: Target bucket (default ``"asl-videos"``).

    Returns:
        ``"{R2_ASL_VIDEOS_URL}/{key}"`` on success. Returns None when the
        public base URL is not configured or the upload fails; errors are
        printed, not raised.
    """
    # The returned URL is built from R2_ASL_VIDEOS_URL. Check it *before*
    # uploading so a misconfiguration does not leave an orphaned object.
    if not R2_ASL_VIDEOS_URL:
        print("R2_ASL_VIDEOS_URL is not configured")
        return None
    try:
        is_compatible, message = verify_video_format(video_path)
        print(f"Video format check: {message}")
        # Only advertise the H.264 codec when the check confirmed it. A wrong
        # `codecs=` hint may lead browsers to reject the file.
        if is_compatible:
            content_type = 'video/mp4; codecs="avc1.42E01E"'
        else:
            content_type = 'video/mp4'

        file_extension = os.path.splitext(video_path)[1]
        unique_filename = f"{uuid.uuid4()}{file_extension}"

        with open(video_path, 'rb') as video_file:
            s3.upload_fileobj(
                video_file,
                bucket_name,
                unique_filename,
                ExtraArgs={
                    'ACL': 'public-read',
                    'ContentType': content_type,
                    'CacheControl': 'max-age=86400',  # cache for 24 hours
                    'ContentDisposition': 'inline',  # display in-browser
                },
            )
        print(f"Video uploaded to R2: {bucket_name}/{unique_filename}")

        # Tolerate a trailing slash in the configured base URL.
        public_video_url = f"{R2_ASL_VIDEOS_URL.rstrip('/')}/{unique_filename}"
        print(f"Public video url: {public_video_url}")
        return public_video_url
    except Exception as e:
        print(f"Error uploading video to R2: {e}")
        return None
def video_to_base64(video_path):
    """
    Convert a video file to a base64 ``data:`` URI for direct download.

    The whole file is read into memory, and the encoded string is about 4/3 of
    its size. The MIME type is always ``video/mp4``, whatever the file holds.

    Args:
        video_path: Path to the video file, or None.

    Returns:
        ``"data:video/mp4;base64,<payload>"`` on success. Returns None when no
        path is given or the file cannot be read; read errors are printed.
    """
    # Callers pass None when no video was produced. Skip the read, and the
    # misleading error message it would print, in that case.
    if not video_path:
        return None
    try:
        with open(video_path, 'rb') as video_file:
            video_data = video_file.read()
        base64_data = base64.b64encode(video_data).decode('utf-8')
        return f"data:video/mp4;base64,{base64_data}"
    except Exception as e:
        print(f"Error converting video to base64: {e}")
        return None
def download_video_from_url(video_url, timeout=30):
    """
    Download a video from a public R2 URL into a temporary ``.mp4`` file.

    The caller owns the returned file and is responsible for deleting it
    (e.g. with ``cleanup_temp_video``).

    Args:
        video_url: URL to fetch.
        timeout: Seconds passed to ``requests.get`` as the connect/read
            timeout, so a stalled server cannot hang the request forever.

    Returns:
        The local path of the downloaded file. Returns None on any error
        (printed); no temp file is left behind in that case.
    """
    temp_path = None
    try:
        print(f"Downloading video from: {video_url}")
        # The context manager closes the streamed connection when finished.
        with requests.get(video_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        temp_file.write(chunk)
        print(f"Video downloaded to: {temp_path}")
        return temp_path
    except Exception as e:
        print(f"Error downloading video: {e}")
        # Remove a partially written file so failed downloads do not leak.
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError:
                pass  # best-effort cleanup; the download already failed
        return None
def cleanup_temp_video(file_path):
    """
    Delete a temporary video file, if it exists (best effort).

    A falsy path or a missing file is a no-op. Any error while deleting is
    printed rather than raised.
    """
    try:
        if not file_path or not os.path.exists(file_path):
            return
        os.remove(file_path)
        print(f"Cleaned up: {file_path}")
    except Exception as err:
        print(f"Error cleaning up file: {err}")
def determine_input_type(input_data):
    """
    Classify raw input and extract the value needed to process it.

    Recognized inputs:
      * A string ending in a supported document extension (.pdf, .txt, .docx,
        .doc, .epub; case-insensitive) is a file path.
      * A string holding a serialized ``gradio.FileData`` dict with a ``path``
        key (JSON or Python-literal syntax) yields that path.
      * Any other string is text, returned stripped.
      * A dict with a ``path`` key (FileData from API calls) yields that path.
      * An object with a ``name`` attribute (uploaded file) yields its name.

    Returns:
        ``(input_type, processed_data)``, where ``input_type`` is ``'text'``,
        ``'file_path'`` or ``'unknown'``. ``processed_data`` is None for
        ``'unknown'``. This includes a FileData-looking string that cannot be
        parsed or has no ``path``; such input used to fall through and return
        None instead of a tuple.
    """
    import ast
    import json

    if isinstance(input_data, str):
        # Match the suffix only. A substring test misclassified ordinary text
        # that mentions e.g. "report.pdf", and also swallowed serialized
        # FileData whose path contained a document extension.
        if input_data.rstrip().lower().endswith(('.pdf', '.txt', '.docx', '.doc', '.epub')):
            return 'file_path', input_data
        if input_data.startswith('{') and 'gradio.FileData' in input_data:
            try:
                try:
                    file_data = json.loads(input_data)
                except json.JSONDecodeError:
                    # Fall back to Python-literal syntax (e.g. a dict repr).
                    # literal_eval only evaluates literals, never code.
                    file_data = ast.literal_eval(input_data)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                # TypeError: e.g. an unhashable dict key.
                # MemoryError/RecursionError: pathologically nested input.
                print(f"Error parsing FileData string: {e}")
                print(f"Input data: {input_data}")
                return 'unknown', None
            if isinstance(file_data, dict) and 'path' in file_data:
                print(f"Parsed FileData: {file_data}")
                return 'file_path', file_data['path']
            return 'unknown', None
        return 'text', input_data.strip()
    if isinstance(input_data, dict) and 'path' in input_data:
        # gradio.FileData object from API calls
        return 'file_path', input_data['path']
    if hasattr(input_data, 'name'):
        # Regular (uploaded) file object
        return 'file_path', input_data.name
    return 'unknown', None
def process_input(input_data):
    """
    Turn raw UI/API input into the string used for sign-video lookup.

    Text input is returned as classified by ``determine_input_type`` (already
    stripped). File paths go through ``asl_converter.convert_document``.

    Returns:
        The text or converted gloss. Returns None when the input type is
        unsupported or document conversion fails; errors are printed, not
        raised.

    NOTE(review): text input is not passed through the ASL converter, so it
    reaches the video lookup as plain text rather than gloss. Confirm this is
    intended.
    """
    kind, payload = determine_input_type(input_data)

    if kind == 'text':
        return payload

    if kind != 'file_path':
        print(f"Unsupported input type: {type(input_data)}")
        return None

    try:
        print(f"Processing file: {payload}")
        converted = asl_converter.convert_document(payload)
        print(f"Converted gloss: {converted[:100]}...")
        return converted
    except Exception as err:
        print(f"Error processing file: {err}")
        return None
def _clean_gloss_tokens(gloss):
    """Split ``gloss`` on whitespace and return the non-empty cleaned tokens."""
    return [token for token in map(clean_gloss_token, gloss.split()) if token]


async def _fetch_token_videos(tokens):
    """Look up and download the sign video for each token.

    Returns (matched video URLs, local paths of successful downloads).
    """
    videos = []
    video_files = []
    for g in tokens:
        print(f"Processing {g}")
        try:
            result = await vectorizer.vector_query_from_supabase(query=g)
            print("result", result)
            if result.get("match", False):
                video_url = result["video_url"]
                videos.append(video_url)
                local_path = download_video_from_url(video_url)
                if local_path:
                    video_files.append(local_path)
        except Exception as e:
            # One failing token must not abort the rest of the sentence.
            print(f"Error processing {g}: {e}")
    return videos, video_files


def _assemble_video(video_files):
    """Return one video path covering ``video_files`` (stitched if needed), or None."""
    if not video_files:
        return None
    if len(video_files) == 1:
        # A single clip is used directly; no stitching needed.
        return video_files[0]

    output_path = None
    try:
        print(f"Creating stitched video from {len(video_files)} videos...")
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
            output_path = tmp.name
        create_multi_stitched_video(video_files, output_path)
    except Exception as e:
        print(f"Error creating stitched video: {e}")
        # Don't leak the (empty or partial) output file on failure.
        if output_path:
            cleanup_temp_video(output_path)
        return None
    print(f"Stitched video created: {output_path}")
    return output_path


async def parse_vectorize_and_search_unified(input_data):
    """
    Unified pipeline for text and file inputs: gloss -> clip lookup -> video.

    Steps: extract the gloss, clean its tokens, fetch a sign clip per token,
    stitch the clips into one video, upload it to R2 and base64-encode it.

    Returns:
        ``(result_dict, video_path)``. ``result_dict`` has ``status`` "success"
        or "error". On success it also contains the clip URLs, token data, the
        R2 download URL and the base64 data URI (each None if unavailable).
        ``video_path`` is the local final video, or None. The caller is
        responsible for deleting it. Intermediate clip downloads are deleted
        here.
    """
    gloss = process_input(input_data)
    if not gloss:
        return {
            "status": "error",
            "message": "Failed to process input"
        }, None

    print("ASL", gloss)
    cleaned_tokens = _clean_gloss_tokens(gloss)
    print("Cleaned tokens:", cleaned_tokens)

    videos, video_files = await _fetch_token_videos(cleaned_tokens)
    stitched_video_path = _assemble_video(video_files)

    video_download_url = None
    video64 = None
    if stitched_video_path:
        # The local file is kept: it is returned for the frontend to display.
        video_download_url = upload_video_to_r2(stitched_video_path)
        video64 = video_to_base64(stitched_video_path)

    # Delete individual clips, but never the final output file.
    for video_file in video_files:
        if video_file != stitched_video_path:
            cleanup_temp_video(video_file)

    return {
        "status": "success",
        "videos": videos,
        "video_count": len(videos),
        "gloss": gloss,
        "cleaned_tokens": cleaned_tokens,
        "video_download_url": video_download_url,
        "video_as_base_64": video64
    }, stitched_video_path
def parse_vectorize_and_search_unified_sync(input_data):
    """Blocking wrapper: run the async pipeline to completion on a fresh event loop.

    ``asyncio.run`` cannot be used while another event loop is running in the
    same thread, so this must be called from synchronous code.
    """
    pipeline = parse_vectorize_and_search_unified(input_data)
    return asyncio.run(pipeline)
def predict_unified(input_data):
    """
    Run the full pipeline for text or file input and return Gradio outputs.

    On success with a video, the local video file is deleted 30 seconds later
    on a daemon timer thread, which gives Gradio time to load it.

    Returns:
        ``(result_dict, video_path_or_None)``. Any exception is converted into
        an error dict instead of being raised.
    """
    try:
        if input_data is None:
            return {
                "status": "error",
                "message": "Please provide text or upload a document"
            }, None

        outcome = parse_vectorize_and_search_unified_sync(input_data)
        json_data, local_video_path = outcome

        if not local_video_path or json_data.get("status") != "success":
            return outcome

        import threading

        # Deferred cleanup: wait 30 seconds, then remove the file.
        timer = threading.Timer(30, cleanup_temp_video, args=(local_video_path,))
        timer.daemon = True
        timer.start()
        return json_data, local_video_path
    except Exception as exc:
        print(f"Error in predict_unified function: {exc}")
        return {
            "status": "error",
            "message": f"An error occurred: {str(exc)}"
        }, None
# Create the Gradio interface
def create_interface():
    """
    Build, but do not launch, the Gradio ``Interface`` for the app.

    Inputs are a 5-line textbox and a document upload restricted to
    .pdf/.txt/.docx/.epub; both are passed to ``predict(text, file)``.
    Outputs are a JSON results panel and a video player.
    """
    text_input = gr.Textbox(
        label="Enter text to convert to ASL",
        placeholder="Type or paste your text here...",
        lines=5,
    )
    file_input = gr.File(
        label="Upload Document (pdf, txt, docx, or epub)",
        file_types=[".pdf", ".txt", ".docx", ".epub"],
    )
    results_panel = gr.JSON(label="Results")
    video_panel = gr.Video(label="ASL Video")

    return gr.Interface(
        fn=predict,
        inputs=[text_input, file_input],
        outputs=[results_panel, video_panel],
        title=title,
        description=description,
        article=article,
    )
# Add a predict function for Hugging Face API access
def predict(text, file):
    """
    Predict function for Hugging Face API access.

    This function is exposed as the /predict endpoint. Non-blank text takes
    precedence over an uploaded file.

    Args:
        text: Text to convert (may be None or blank).
        file: Uploaded file or FileData, or None.

    Returns:
        ``(result_dict, video_path_or_None)``. When neither input is provided,
        this is an error dict paired with None.
    """
    if text and text.strip():
        input_data = text.strip()
    elif file is not None:
        # Let the centralized processor work out the file's representation.
        input_data = file
    else:
        return {
            "status": "error",
            "message": "Please provide either text or upload a file"
        }, None

    print("Input to the prediction function", input_data)
    # Previously printed type(input), the builtin, instead of the actual input.
    print("Input type:", type(input_data))
    return predict_unified(input_data)
# For Hugging Face Spaces, use the Interface
if __name__ == "__main__":
    demo = create_interface()
    # Listen on all network interfaces, port 7860. share=True also asks Gradio
    # for a temporary public share link, which is useful when running locally.
    # NOTE(review): share links are presumably not created when running on
    # Hugging Face Spaces; confirm for the Gradio version in use.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)