AdhamEhab commited on
Commit
e265f46
·
verified ·
1 Parent(s): f103a37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -93
app.py CHANGED
@@ -1,93 +1,97 @@
1
- import requests
2
- from PIL import Image
3
- import gradio as gr
4
- from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline
5
- import torch
6
- import tempfile
7
-
8
- # Initialize the device
9
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
-
11
- # Load the image captioning model and processor
12
- image_processor = BlipProcessor.from_pretrained("AdhamEhab/ImageCaptioning")
13
- image_model = BlipForConditionalGeneration.from_pretrained("AdhamEhab/ImageCaptioning").to(device)
14
-
15
- # Load the story generation model and tokenizer
16
- story_generation_model = GPT2LMHeadModel.from_pretrained("AdhamEhab/StoryGen")
17
- story_generation_tokenizer = AutoTokenizer.from_pretrained("AdhamEhab/StoryGen")
18
-
19
- generator = TextGenerationPipeline(model=story_generation_model, tokenizer=story_generation_tokenizer)
20
-
21
- # Define a function to generate caption from an image
22
- def generate_caption(image):
23
- try:
24
- # If image is a file object, extract the file path
25
- if isinstance(image, tempfile._TemporaryFileWrapper):
26
- image_path = image.name
27
- else:
28
- image_path = image
29
-
30
- # Load and preprocess the image
31
- image = Image.open(image_path)
32
- inputs = image_processor(image, return_tensors="pt", padding="max_length", truncation=True)
33
-
34
- # Generate caption
35
- with torch.no_grad():
36
- caption_ids = image_model.generate(**inputs.to(device))
37
-
38
- # Decode the caption
39
- caption = image_processor.decode(caption_ids[0], skip_special_tokens=True)
40
- return caption
41
- except Exception as e:
42
- return f"An error occurred: {str(e)}"
43
-
44
- # Define a function to generate a story based on a prompt
45
- def generate_story(prompt):
46
- try:
47
- input_prompt = prompt
48
- story = generator(input_prompt, max_length=200, do_sample=True)[0]['generated_text']
49
- return story
50
- except Exception as e:
51
- return f"An error occurred: {str(e)}"
52
-
53
- # Create Gradio interfaces
54
- image_caption_interface = gr.Interface(
55
- fn=generate_caption,
56
- inputs=gr.inputs.Image(type="filepath", label="Upload Image"),
57
- outputs="text",
58
- title="Image Captioning",
59
- description="Generate a caption for the provided image."
60
- )
61
-
62
- story_generation_interface = gr.Interface(
63
- fn=generate_story,
64
- inputs="text",
65
- outputs="text",
66
- title="Story Generation",
67
- description="Generate a story based on the provided prompt."
68
- )
69
-
70
- # Create Gradio interfaces
71
- with gr.Blocks(css="footer{display:none !important}") as combined_interface:
72
- gr.Markdown(
73
- """
74
- # The Seer's Legacy
75
- Crafting tales from the fabric of imagination.
76
- """)
77
- with gr.Row():
78
- with gr.Column():
79
- image_input = gr.inputs.Image(type="filepath", label="Upload Image")
80
- image_output = gr.Text(label="Image Caption")
81
- image_btn = gr.Button("Generate Image Caption")
82
-
83
- with gr.Column():
84
- text_input = gr.inputs.Textbox(label="Story Prompt")
85
- story_output = gr.Text(label="Generated Story")
86
- story_btn = gr.Button("Generate Story")
87
-
88
- image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
89
- story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])
90
-
91
- # Launch the combined interface
92
- combined_interface.launch()
93
-
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ import gradio as gr
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline
5
+ import torch
6
+ import tempfile
7
+ import os
8
+
9
+ os.system("pip uninstall -y gradio")
10
+ os.system("pip install gradio==2.6.4")
11
+
12
+ # Initialize the device
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+
15
+ # Load the image captioning model and processor
16
+ image_processor = BlipProcessor.from_pretrained("AdhamEhab/ImageCaptioning")
17
+ image_model = BlipForConditionalGeneration.from_pretrained("AdhamEhab/ImageCaptioning").to(device)
18
+
19
+ # Load the story generation model and tokenizer
20
+ story_generation_model = GPT2LMHeadModel.from_pretrained("AdhamEhab/StoryGen")
21
+ story_generation_tokenizer = AutoTokenizer.from_pretrained("AdhamEhab/StoryGen")
22
+
23
+ generator = TextGenerationPipeline(model=story_generation_model, tokenizer=story_generation_tokenizer)
24
+
25
+ # Define a function to generate caption from an image
26
+ def generate_caption(image):
27
+ try:
28
+ # If image is a file object, extract the file path
29
+ if isinstance(image, tempfile._TemporaryFileWrapper):
30
+ image_path = image.name
31
+ else:
32
+ image_path = image
33
+
34
+ # Load and preprocess the image
35
+ image = Image.open(image_path)
36
+ inputs = image_processor(image, return_tensors="pt", padding="max_length", truncation=True)
37
+
38
+ # Generate caption
39
+ with torch.no_grad():
40
+ caption_ids = image_model.generate(**inputs.to(device))
41
+
42
+ # Decode the caption
43
+ caption = image_processor.decode(caption_ids[0], skip_special_tokens=True)
44
+ return caption
45
+ except Exception as e:
46
+ return f"An error occurred: {str(e)}"
47
+
48
+ # Define a function to generate a story based on a prompt
49
+ def generate_story(prompt):
50
+ try:
51
+ input_prompt = prompt
52
+ story = generator(input_prompt, max_length=200, do_sample=True)[0]['generated_text']
53
+ return story
54
+ except Exception as e:
55
+ return f"An error occurred: {str(e)}"
56
+
57
+ # Create Gradio interfaces
58
+ image_caption_interface = gr.Interface(
59
+ fn=generate_caption,
60
+ inputs=gr.inputs.Image(type="filepath", label="Upload Image"),
61
+ outputs="text",
62
+ title="Image Captioning",
63
+ description="Generate a caption for the provided image."
64
+ )
65
+
66
+ story_generation_interface = gr.Interface(
67
+ fn=generate_story,
68
+ inputs="text",
69
+ outputs="text",
70
+ title="Story Generation",
71
+ description="Generate a story based on the provided prompt."
72
+ )
73
+
74
+ # Create Gradio interfaces
75
+ with gr.Blocks(css="footer{display:none !important}") as combined_interface:
76
+ gr.Markdown(
77
+ """
78
+ # The Seer's Legacy
79
+ Crafting tales from the fabric of imagination.
80
+ """)
81
+ with gr.Row():
82
+ with gr.Column():
83
+ image_input = gr.inputs.Image(type="filepath", label="Upload Image")
84
+ image_output = gr.Text(label="Image Caption")
85
+ image_btn = gr.Button("Generate Image Caption")
86
+
87
+ with gr.Column():
88
+ text_input = gr.inputs.Textbox(label="Story Prompt")
89
+ story_output = gr.Text(label="Generated Story")
90
+ story_btn = gr.Button("Generate Story")
91
+
92
+ image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
93
+ story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])
94
+
95
+ # Launch the combined interface
96
+ combined_interface.launch()
97
+