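"""The Seer's Legacy: a Gradio app that captions images with BLIP and generates stories with GPT-2."""
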
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline
import torch
import tempfile

# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the image captioning model and processor
image_processor = BlipProcessor.from_pretrained("AdhamEhab/ImageCaptioning")
image_model = BlipForConditionalGeneration.from_pretrained("AdhamEhab/ImageCaptioning").to(device)
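# BLIP pairs a vision encoder with a text decoder, so generate() can produce a caption directly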

# Load the story generation model and tokenizer
story_generation_model = GPT2LMHeadModel.from_pretrained("AdhamEhab/StoryGen")
story_generation_tokenizer = AutoTokenizer.from_pretrained("AdhamEhab/StoryGen")

generator = TextGenerationPipeline(model=story_generation_model, tokenizer=story_generation_tokenizer, device=0 if torch.cuda.is_available() else -1)
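# The pipeline returns a list of dicts, e.g. [{"generated_text": "..."}]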

# Define a function to generate caption from an image
def generate_caption(image):
    try:
        # If image is a file object, extract the file path
        if isinstance(image, tempfile._TemporaryFileWrapper):
            image_path = image.name
        else:
            image_path = image

        # Load the image and preprocess it into pixel values for BLIP
        image = Image.open(image_path)
        inputs = image_processor(images=image, return_tensors="pt")

        # Generate caption
        with torch.no_grad():
            caption_ids = image_model.generate(**inputs.to(device))

        # Decode the caption
        caption = image_processor.decode(caption_ids[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Define a function to generate a story based on a prompt
def generate_story(prompt):
    try:
        # Sample up to 200 tokens (prompt included) as the generated story
        story = generator(prompt, max_length=200, do_sample=True)[0]['generated_text']
        return story
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Standalone single-task interfaces (defined for reuse; the combined Blocks app below is what gets launched)
image_caption_interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs="text",
    title="Image Captioning",
    description="Generate a caption for the provided image."
)

story_generation_interface = gr.Interface(
    fn=generate_story,
    inputs="text",
    outputs="text",
    title="Story Generation",
    description="Generate a story based on the provided prompt."
)

# Combine both tools in a single Blocks app
with gr.Blocks(css="footer{display:none !important}") as combined_interface:
    gr.Markdown(
    """
    # The Seer's Legacy
    Crafting tales from the fabric of imagination.
    """)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Image")
            image_output = gr.Text(label="Image Caption")
            image_btn = gr.Button("Generate Image Caption")

        with gr.Column():
            text_input = gr.Textbox(label="Story Prompt")
            story_output = gr.Text(label="Generated Story")
            story_btn = gr.Button("Generate Story")

    image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
    story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])

# Launch the combined interface
combined_interface.launch()
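# Outside Hugging Face Spaces, launch(share=True) exposes a temporary public URL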