"""The Seer's Legacy: a Gradio app that captions images and generates stories.

A BLIP model ("AdhamEhab/ImageCaptioning") produces an image caption, and a
GPT-2 model ("AdhamEhab/StoryGen") continues a free-text prompt into a story.
"""
import logging
import os
import subprocess
import sys
import tempfile

# Pin Gradio *before* importing it. Running pip after `import gradio` (as this
# script used to) cannot change the module already loaded in this process, so
# the pin was silently ineffective. `sys.executable -m pip` targets the running
# interpreter rather than whichever `pip` happens to be first on PATH.
# Failures are ignored (check=False), as they were with os.system.
# NOTE(review): a requirements.txt entry is the proper home for this pin.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.50"], check=False)

# Third-party imports deliberately follow the pip pin above.
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    GPT2LMHeadModel,
    TextGenerationPipeline,
)

logger = logging.getLogger(__name__)

# Initialize the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the image captioning model and processor
image_processor = BlipProcessor.from_pretrained("AdhamEhab/ImageCaptioning")
image_model = BlipForConditionalGeneration.from_pretrained("AdhamEhab/ImageCaptioning").to(device)

# Load the story generation model and tokenizer. The pipeline is placed on the
# same device as the captioning model; without `device` it stayed on CPU.
story_generation_model = GPT2LMHeadModel.from_pretrained("AdhamEhab/StoryGen")
story_generation_tokenizer = AutoTokenizer.from_pretrained("AdhamEhab/StoryGen")
generator = TextGenerationPipeline(
    model=story_generation_model,
    tokenizer=story_generation_tokenizer,
    device=device,
)


def _open_rgb_image(image):
    """Return an RGB PIL image for *image*.

    *image* may be a file path, a tempfile wrapper (whose ``.name`` is the
    on-disk path), or an already-loaded PIL image.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, tempfile._TemporaryFileWrapper):
        image = image.name
    # Convert inside the `with` so the pixels are read before the file handle
    # is closed. Converting to RGB also normalises RGBA, greyscale and palette
    # uploads into the 3-channel input the processor expects.
    with Image.open(image) as img:
        return img.convert("RGB")


def generate_caption(image):
    """Generate a caption for *image* with the BLIP model.

    Args:
        image: File path (as delivered by ``gr.Image(type="filepath")``),
            tempfile wrapper, or PIL image. ``None`` when nothing was uploaded.

    Returns:
        The decoded caption. On failure, returns a human-readable message
        instead: the result is shown directly in the UI, so errors are logged
        rather than raised.
    """
    if image is None:
        return "Please upload an image."
    try:
        pil_image = _open_rgb_image(image)
        # padding/truncation only affect text inputs, so they are not passed
        # for this image-only call.
        inputs = image_processor(pil_image, return_tensors="pt").to(device)
        with torch.no_grad():
            caption_ids = image_model.generate(**inputs)
        return image_processor.decode(caption_ids[0], skip_special_tokens=True)
    except Exception as e:
        logger.exception("Caption generation failed")
        return f"An error occurred: {str(e)}"


def generate_story(prompt):
    """Continue *prompt* into a story with the GPT-2 story model.

    Args:
        prompt: Text to start the story from. Empty or whitespace-only text
            is rejected with a message.

    Returns:
        The generated text, which includes the prompt. On failure, returns a
        human-readable message; the error is logged rather than raised.
    """
    if not prompt or not prompt.strip():
        return "Please enter a story prompt."
    try:
        outputs = generator(
            prompt,
            max_length=200,  # total token budget, prompt tokens included
            do_sample=True,
            # GPT-2 has no pad token. Passing EOS explicitly matches the
            # generate() fallback without emitting a warning on every call.
            pad_token_id=story_generation_tokenizer.eos_token_id,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        logger.exception("Story generation failed")
        return f"An error occurred: {str(e)}"


# Standalone single-task interfaces. They are built but never launched; only
# the combined Blocks app below is served. gr.Image / gr.Textbox replace the
# deprecated gr.inputs.* namespace, which Gradio 4 removed.
image_caption_interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs="text",
    title="Image Captioning",
    description="Generate a caption for the provided image.",
)

story_generation_interface = gr.Interface(
    fn=generate_story,
    inputs="text",
    outputs="text",
    title="Story Generation",
    description="Generate a story based on the provided prompt.",
)

# Combined app: captioning and story generation side by side, footer hidden.
with gr.Blocks(css="footer{display:none !important}") as combined_interface:
    gr.Markdown(
        "# The Seer's Legacy\n"
        "Crafting tales from the fabric of imagination."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Image")
            image_output = gr.Text(label="Image Caption")
            image_btn = gr.Button("Generate Image Caption")
        with gr.Column():
            text_input = gr.Textbox(label="Story Prompt")
            story_output = gr.Text(label="Generated Story")
            story_btn = gr.Button("Generate Story")

    image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
    story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])

# Launch the combined interface
combined_interface.launch()