HazlamiMalek committed
Commit de5c1c1 · verified · 1 Parent(s): d3a3de7

Update app.py

Files changed (1)
  1. app.py +53 -55
app.py CHANGED
@@ -1,59 +1,57 @@
-
 import streamlit as st
 from PIL import Image
-from gtts import gTTS
-import os
-
-# Load your LLaVA model and processor here
 from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+#from gtts import gTTS
 import torch
-
-# Load the processor and model
-processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-model = LlavaNextForConditionalGeneration.from_pretrained(
-    "llava-hf/llava-v1.6-mistral-7b-hf",
-    torch_dtype=torch.float16,
-    low_cpu_mem_usage=True
-).to("cuda:0")
-
-# Streamlit Interface
-st.title("Image-to-Audio Description Generator")
-
-# Upload an image
-uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
-if uploaded_image:
-    # Load and preprocess the image
-    image = Image.open(uploaded_image).convert("RGB")
-    st.image(image, caption="Uploaded Image", use_column_width=True)
-
-    # Define the conversation template
-    conversation = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": "What is shown in this image?"},
-                {"type": "image"},
-            ],
-        },
-    ]
-
-    # Prepare inputs
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")
-
-    # Generate the description
-    output = model.generate(**inputs, max_new_tokens=100, pad_token_id=processor.tokenizer.eos_token_id)
-    description = processor.decode(output[0], skip_special_tokens=True)
-
-    # Display the description
-    st.write(f"Generated Description: {description}")
-
-    # Convert description to audio
-    tts = gTTS(description)
-    audio_path = "output.mp3"
-    tts.save(audio_path)
-
-    # Play the audio
-    audio_file = open(audio_path, "rb")
-    audio_bytes = audio_file.read()
-    st.audio(audio_bytes, format="audio/mp3")
+import cProfile
+import pstats
+torch_dtype = torch.float32
+
+# Profile your app
+with cProfile.Profile() as pr:
+
+    st.title("Image-to-Audio Description Generator")
+
+    # Load the processor and model
+    processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+    model = LlavaNextForConditionalGeneration.from_pretrained(
+        "llava-hf/llava-v1.6-mistral-7b-hf",
+        torch_dtype=torch_dtype,  # use the float32 defined above; float16 is unreliable on CPU
+        low_cpu_mem_usage=True
+    ).to("cpu")  # Use "cpu" instead of "cuda:0"
+
+    # File uploader
+    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
+    if uploaded_image:
+        image = Image.open(uploaded_image).convert("RGB")
+        image = image.resize((336, 336))  # Match the vision tower's 336x336 input
+        st.image(image, caption="Uploaded Image", use_container_width=True)
+
+        # Generate description
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What is shown in this image?"},
+                    {"type": "image"},
+                ],
+            },
+        ]
+        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to("cpu")
+        output = model.generate(**inputs, max_new_tokens=100, pad_token_id=processor.tokenizer.eos_token_id)
+        description = processor.decode(output[0], skip_special_tokens=True)
+        st.write(f"Generated Description: {description}")
+
+        # Convert description to audio
+        #tts = gTTS(description)
+        #audio_path = "output.mp3"
+        #tts.save(audio_path)
+
+        # Play audio
+        #st.audio(audio_path, format="audio/mp3")
+
+# Print profiling stats (runs after the profiled block exits)
+stats = pstats.Stats(pr)
+stats.sort_stats(pstats.SortKey.TIME)
+stats.print_stats()
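
A note on the model load in the new version: Streamlit reruns the whole script on every widget interaction, so the profiled block above reloads the 7B checkpoint each time. A minimal sketch of caching the load across reruns with Streamlit's st.cache_resource decorator; the helper name load_llava is hypothetical, everything else mirrors the commit:

# Sketch only, not part of this commit.
import streamlit as st
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

@st.cache_resource  # one shared copy per process, reused across script reruns
def load_llava(model_id="llava-hf/llava-v1.6-mistral-7b-hf"):
    processor = LlavaNextProcessor.from_pretrained(model_id)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # float32 for CPU inference, as in the commit
        low_cpu_mem_usage=True,
    ).to("cpu")
    return processor, model

processor, model = load_llava()

If the commented-out gTTS step is restored, the intermediate output.mp3 file can also be skipped: gTTS can write to an in-memory buffer via write_to_fp, and st.audio accepts raw bytes. A sketch under those assumptions (description_to_mp3 is a hypothetical helper):

# Sketch only, not part of this commit.
import io
from gtts import gTTS

def description_to_mp3(description):
    buffer = io.BytesIO()
    gTTS(description).write_to_fp(buffer)  # write MP3 bytes to the buffer instead of disk
    return buffer.getvalue()

# st.audio(description_to_mp3(description), format="audio/mp3")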