Kishorekumar7 committed (verified)
Commit 9c054fd · 1 Parent(s): 7cc571c

Update app.py

Files changed (1):
  1. app.py +81 -85
app.py CHANGED
@@ -1,98 +1,94 @@
 import streamlit as st
-import os
+import torchaudio
 import torch
-import soundfile as sf
-from groq import Groq
-from diffusers import AutoPipelineForText2Image
-from streamlit_webrtc import webrtc_streamer, AudioRecorder
-
-# Load API keys
-GROQ_API_KEY = os.getenv("GROQ_API_KEY")
-HF_API_KEY = os.getenv("HF_API_KEY")
-
-# Initialize Groq client
-client = Groq(api_key=GROQ_API_KEY)
-
-# Load image generation model
-device = "cuda" if torch.cuda.is_available() else "cpu"
-image_gen = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo").to(device)
-
-# Function to transcribe audio
-def transcribe(audio_path):
-    with open(audio_path, "rb") as file:
-        transcription = client.audio.transcriptions.create(
-            file=(audio_path, file.read()),
-            model="whisper-large-v3",
-            language="ta",
-            response_format="verbose_json"
-        )
-    return transcription["text"]
-
-# Function to translate Tamil to English
-def translate_text(tamil_text):
-    response = client.chat.completions.create(
-        model="gemma-7b-it",
-        messages=[{"role": "user", "content": f"Translate this Tamil text to English: {tamil_text}"}]
-    )
-    return response.choices[0].message.content
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+from diffusers import StableDiffusionPipeline
+from io import BytesIO
+import tempfile
+import os
+
+st.set_page_config(page_title="Tamil Voice to Story & Image Generator", layout="wide")
+st.title("🎤 Tamil Voice to Story & Image Generator")
 
-# Function to generate text
-def generate_text(prompt):
-    response = client.chat.completions.create(
-        model="deepseek-coder-r1-7b",
-        messages=[{"role": "user", "content": f"Write a short story about: {prompt}"}]
+# Load models only once
+@st.cache_resource
+def load_models():
+    # 1. Whisper small for speech recognition
+    whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=0 if torch.cuda.is_available() else -1)
+
+    # 2. NLLB for Tamil to English translation
+    tokenizer_trans = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+    model_trans = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+
+    # 3. Tiny Story Generator
+    story_gen = pipeline("text-generation", model="sshleifer/tiny-gpt2", device=0 if torch.cuda.is_available() else -1)
+
+    # 4. Image Generator
+    image_pipe = StableDiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
     )
-    return response.choices[0].message.content
+    if torch.cuda.is_available():
+        image_pipe.to("cuda")
+
+    return whisper_pipe, tokenizer_trans, model_trans, story_gen, image_pipe
 
-# Function to generate an image
+whisper_pipe, tokenizer_trans, model_trans, story_gen, image_pipe = load_models()
+
+# Function: Translate Tamil to English
+def translate_ta_to_en(text):
+    inputs = tokenizer_trans(text, return_tensors="pt", padding=True)
+    translated = model_trans.generate(**inputs, forced_bos_token_id=tokenizer_trans.lang_code_to_id["eng_Latn"])
+    return tokenizer_trans.batch_decode(translated, skip_special_tokens=True)[0]
+
+# Function: Generate story
+def generate_story(prompt):
+    story = story_gen(prompt, max_length=100, num_return_sequences=1)
+    return story[0]['generated_text']
+
+# Function: Generate image
 def generate_image(prompt):
-    img = image_gen(prompt=prompt).images[0]
-    return img
-
-# Streamlit UI
-st.title("Tamil Speech to Image & Story Generator")
-
-# Choose input method
-input_method = st.radio("Choose Input Method:", ("Record Audio", "Upload Audio"))
-
-audio_path = None
-
-if input_method == "Record Audio":
-    st.subheader("Record your Tamil speech")
-    recorder = webrtc_streamer(key="record_audio", audio=True)
-
-    if recorder.audio_receiver:
-        audio_data = recorder.audio_receiver.get_frames()  # Get recorded audio
-        audio_path = "recorded_audio.wav"
-        sf.write(audio_path, audio_data, 16000)  # Save recorded audio
-elif input_method == "Upload Audio":
-    uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3"])
-    if uploaded_file:
-        audio_path = "uploaded_audio.wav"
-        with open(audio_path, "wb") as f:
-            f.write(uploaded_file.getbuffer())
-
-if st.button("Generate"):
-    if not audio_path:
-        st.error("Please provide an audio file.")
-        st.stop()
-
-    # Process audio
-    tamil_text = transcribe(audio_path)
-    english_text = translate_text(tamil_text)
-    story = generate_text(english_text)
-    image = generate_image(english_text)
-
-    # Display results
-    st.subheader("Tamil Transcription")
+    image = image_pipe(prompt).images[0]
+    return image
+
+# Upload or Record
+input_method = st.radio("Select Input Method", ["Upload Audio", "Record Live"])
+
+if input_method == "Upload Audio":
+    audio_file = st.file_uploader("Upload Tamil Audio", type=["wav", "mp3", "m4a"])
+else:
+    audio_bytes = st.audio("Record or Upload Audio Below", format='audio/wav')
+    audio_file = None
+    if audio_bytes:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
+            tmpfile.write(audio_bytes.read())
+            audio_file = tmpfile.name
+
+# Process Button
+if st.button("Generate from Audio") and audio_file:
+    with st.spinner("🔄 Transcribing Tamil audio..."):
+        result = whisper_pipe(audio_file)
+        tamil_text = result['text']
+
+    st.success("✅ Tamil Transcription")
     st.write(tamil_text)
-
-    st.subheader("English Translation")
+
+    with st.spinner("🌐 Translating to English..."):
+        english_text = translate_ta_to_en(tamil_text)
+
+    st.success("✅ English Translation")
     st.write(english_text)
 
-    st.subheader("Generated Story")
+    with st.spinner("✍️ Generating Story..."):
+        story = generate_story(english_text)
+
+    st.success("✅ Story Generated")
     st.write(story)
 
-    st.subheader("Generated Image")
+    with st.spinner("🎨 Generating Image..."):
+        image = generate_image(english_text)
+
     st.image(image, caption="Generated Image")
 
+elif st.button("Generate from Audio") and not audio_file:
+    st.warning("Please upload or record an audio file.")
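
A few review notes on the new version, with hedged sketches rather than definitive fixes. As a side observation, `torchaudio`, `AutoModelForCausalLM`, `BytesIO`, and `os` are imported but never used. More importantly, the "Record Live" branch cannot capture audio as written: `st.audio` is a playback widget that returns `None`, so `audio_file` is never set on that path. A minimal sketch of a working branch, assuming a Streamlit release recent enough to ship the `st.audio_input` recording widget:

```python
# Sketch: st.audio_input (recent Streamlit releases) records from the mic
# and returns the clip as a WAV UploadedFile, or None before recording.
audio_file = None
recording = st.audio_input("Record Tamil speech")
if recording is not None:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
        tmpfile.write(recording.read())  # persist for the Whisper pipeline
        audio_file = tmpfile.name
```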
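
Next, `whisper_pipe(audio_file)` leaves Whisper to auto-detect the language, which can mis-route short Tamil clips. The transformers ASR pipeline forwards `generate_kwargs` to Whisper's `generate()`, so Tamil can be pinned explicitly; a sketch using the pipeline object returned by `load_models()`:

```python
# Sketch: force Tamil transcription instead of relying on auto-detection.
result = whisper_pipe(
    audio_file,
    generate_kwargs={"language": "tamil", "task": "transcribe"},
)
tamil_text = result["text"]
```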
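
In `translate_ta_to_en`, the tokenizer is never told that the input is Tamil, so the text is encoded under NLLB's default source language, and newer transformers releases have removed the `lang_code_to_id` mapping used for the target. A sketch of the same function with the source language set and the portable token lookup:

```python
# Sketch: set the FLORES-200 source code before encoding, and use
# convert_tokens_to_ids, which also works on transformers versions
# that no longer expose lang_code_to_id.
def translate_ta_to_en(text):
    tokenizer_trans.src_lang = "tam_Taml"  # Tamil in FLORES-200 codes
    inputs = tokenizer_trans(text, return_tensors="pt", padding=True)
    translated = model_trans.generate(
        **inputs,
        forced_bos_token_id=tokenizer_trans.convert_tokens_to_ids("eng_Latn"),
    )
    return tokenizer_trans.batch_decode(translated, skip_special_tokens=True)[0]
```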
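
For the story step, `sshleifer/tiny-gpt2` is, as far as I can tell, a tiny CI-test checkpoint rather than a trained language model, so its "stories" will read as noise; `distilgpt2` is a small trained model with a comparably light footprint. A sketch, also swapping `max_length` (which counts the prompt) for `max_new_tokens`:

```python
# Sketch: a small trained model; max_new_tokens counts only generated
# tokens, whereas max_length includes the prompt.
story_gen = pipeline(
    "text-generation",
    model="distilgpt2",
    device=0 if torch.cuda.is_available() else -1,
)

def generate_story(prompt):
    story = story_gen(prompt, max_new_tokens=100, num_return_sequences=1)
    return story[0]["generated_text"]
```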
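
Finally, `st.button("Generate from Audio")` is called twice. Streamlit derives widget IDs from a widget's parameters, so the second call raises a `DuplicateWidgetID` error, and even with distinct keys the two calls would be independent buttons. Reading the button once and branching covers both cases:

```python
# Sketch: one button read, then branch on whether audio is available.
if st.button("Generate from Audio"):
    if audio_file:
        ...  # transcribe, translate, and generate the story and image as above
    else:
        st.warning("Please upload or record an audio file.")
```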