awacke1 committed
Commit 4b3ee30 · verified · 1 Parent(s): 55c19e3

Update app.py

Files changed (1)
  app.py +183 -170
app.py CHANGED
@@ -1,178 +1,191 @@
-# app.py
-import gradio as gr
-from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
-from diffusers import StableDiffusionPipeline, DiffusionPipeline
-import torch
-from PIL import Image
-import numpy as np
+import anthropic
+import base64
+import json
 import os
-import tempfile
-import moviepy.editor as mpe
-import nltk
-from pydub import AudioSegment
-import warnings
-import asyncio
-import edge_tts
-import random
-from openai import OpenAI
-
-warnings.filterwarnings("ignore", category=UserWarning)
-
-# Ensure NLTK data is downloaded
-nltk.download('punkt')
-
-# LLM Inference Class
-class LLMInferenceNode:
-    def __init__(self):
-        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-        self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")
-
-        self.huggingface_client = OpenAI(
-            base_url="https://api-inference.huggingface.co/v1/",
-            api_key=self.huggingface_token,
-        )
-        self.sambanova_client = OpenAI(
-            api_key=self.sambanova_api_key,
-            base_url="https://api.sambanova.ai/v1",
-        )
-
-    def generate(self, input_text, long_talk=True, compress=False,
-                 compression_level="medium", poster=False, prompt_type="Short",
-                 provider="Hugging Face", model=None):
-        try:
-            # Define system message
-            system_message = "You are a helpful assistant. Try your best to give the best response possible to the user."
-
-            # Define base prompts based on type
-            prompts = {
-                "Short": """Create a brief, straightforward caption for this description, suitable for a text-to-image AI system.
-                Focus on the main elements, key characters, and overall scene without elaborate details.""",
-                "Long": """Create a detailed visually descriptive caption of this description for a text-to-image AI system.
-                Include detailed visual descriptions, cinematography, and lighting setup."""
-            }
-
-            base_prompt = prompts.get(prompt_type, prompts["Short"])
-            user_message = f"{base_prompt}\nDescription: {input_text}"
-
-            # Generate with selected provider
-            if provider == "Hugging Face":
-                client = self.huggingface_client
-            else:
-                client = self.sambanova_client
-
-            response = client.chat.completions.create(
-                model=model or "meta-llama/Meta-Llama-3.1-70B-Instruct",
-                max_tokens=1024,
-                temperature=1.0,
-                messages=[
-                    {"role": "system", "content": system_message},
-                    {"role": "user", "content": user_message},
-                ]
-            )
-
-            return response.choices[0].message.content.strip()
-
-        except Exception as e:
-            print(f"An error occurred: {e}")
-            return f"Error occurred while processing the request: {str(e)}"
-
-# Initialize models
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if device == "cuda" else torch.float32
-
-# Story generator
-story_generator = pipeline(
-    'text-generation',
-    model='gpt2-large',
-    device=0 if device == 'cuda' else -1
-)
-
-# Stable Diffusion model
-sd_pipe = StableDiffusionPipeline.from_pretrained(
-    "runwayml/stable-diffusion-v1-5",
-    torch_dtype=torch_dtype
-).to(device)
-
-# Text-to-Speech function using edge_tts
-async def _text2speech_async(text):
-    communicate = edge_tts.Communicate(text, voice="en-US-AriaNeural")
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
-        tmp_path = tmp_file.name
-    await communicate.save(tmp_path)
-    return tmp_path
-
-def text2speech(text):
+import pandas as pd
+import pytz
+import re
+import streamlit as st
+from datetime import datetime
+from gradio_client import Client
+from azure.cosmos import CosmosClient, exceptions
+
+# App Configuration
+title = "🤖 ArXiv and Claude AI Assistant"
+st.set_page_config(page_title=title, layout="wide")
+
+# Cosmos DB configuration
+ENDPOINT = "https://acae-afd.documents.azure.com:443/"
+Key = os.environ.get("Key")
+DATABASE_NAME = os.environ.get("COSMOS_DATABASE_NAME")
+CONTAINER_NAME = os.environ.get("COSMOS_CONTAINER_NAME")
+
+# Initialize Anthropic client
+anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
+
+# Initialize session state
+if "chat_history" not in st.session_state:
+    st.session_state.chat_history = []
+
+def generate_filename(prompt, file_type):
+    """Generate a filename with timestamp and sanitized prompt"""
+    central = pytz.timezone('US/Central')
+    safe_date_time = datetime.now(central).strftime("%m%d_%H%M")
+    safe_prompt = re.sub(r'\W+', '', prompt)[:90]
+    return f"{safe_date_time}{safe_prompt}.{file_type}"
+
+def create_file(filename, prompt, response, should_save=True):
+    """Create and save a file with prompt and response"""
+    if not should_save:
+        return
+    with open(filename, 'w', encoding='utf-8') as file:
+        file.write(f"Prompt:\n{prompt}\n\nResponse:\n{response}")
+
+def save_to_cosmos_db(container, query, response1, response2):
+    """Save interaction to Cosmos DB"""
     try:
-        output_path = asyncio.run(_text2speech_async(text))
-        return output_path
+        if container:
+            timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
+            record = {
+                "id": timestamp,
+                "name": timestamp,
+                "query": query,
+                "response1": response1,
+                "response2": response2,
+                "timestamp": datetime.utcnow().isoformat(),
+                "type": "ai_response",
+                "version": "1.0"
+            }
+            container.create_item(body=record)
+            st.success(f"Record saved to Cosmos DB with ID: {record['id']}")
     except Exception as e:
-        print(f"Error in text2speech: {str(e)}")
-        raise
-
-def generate_story(prompt):
-    generated = story_generator(prompt, max_length=500, num_return_sequences=1)
-    story = generated[0]['generated_text']
-    return story
-
-def split_story_into_sentences(story):
-    sentences = nltk.sent_tokenize(story)
-    return sentences
-
-def generate_images(sentences):
-    images = []
-    for idx, sentence in enumerate(sentences):
-        image = sd_pipe(sentence).images[0]
-        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{idx}.png")
-        image.save(temp_file.name)
-        images.append(temp_file.name)
-    return images
-
-def generate_audio(story_text):
-    audio_path = text2speech(story_text)
-    audio = AudioSegment.from_file(audio_path)
-    total_duration = len(audio) / 1000
-    return audio_path, total_duration
-
-def compute_sentence_durations(sentences, total_duration):
-    total_words = sum(len(sentence.split()) for sentence in sentences)
-    return [total_duration * (len(sentence.split()) / total_words) for sentence in sentences]
-
-def create_video(images, durations, audio_path):
-    clips = [mpe.ImageClip(img).set_duration(dur) for img, dur in zip(images, durations)]
-    video = mpe.concatenate_videoclips(clips, method='compose')
-    audio = mpe.AudioFileClip(audio_path)
-    video = video.set_audio(audio)
-    output_path = os.path.join(tempfile.gettempdir(), "final_video.mp4")
-    video.write_videofile(output_path, fps=1, codec='libx264')
-    return output_path
-
-def process_pipeline(prompt, progress=gr.Progress()):
+        st.error(f"Error saving to Cosmos DB: {str(e)}")
+
+def search_arxiv(query):
+    """Search ArXiv using Gradio client"""
     try:
-        story = generate_story(prompt)
-        sentences = split_story_into_sentences(story)
-        images = generate_images(sentences)
-        audio_path, total_duration = generate_audio(story)
-        durations = compute_sentence_durations(sentences, total_duration)
-        video_path = create_video(images, durations, audio_path)
-        return video_path
+        client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
+
+        # Get response from Mixtral model
+        result_mixtral = client.predict(
+            query,
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            True,
+            api_name="/ask_llm"
+        )
+
+        # Get response from Mistral model
+        result_mistral = client.predict(
+            query,
+            "mistralai/Mistral-7B-Instruct-v0.2",
+            True,
+            api_name="/ask_llm"
+        )
+
+        # Get RAG-enhanced response
+        result_rag = client.predict(
+            query,
+            10,  # llm_results_use
+            "Semantic Search",
+            "mistralai/Mistral-7B-Instruct-v0.2",
+            api_name="/update_with_rag_md"
+        )
+
+        return result_mixtral, result_mistral, result_rag
     except Exception as e:
-        print(f"Error in process_pipeline: {str(e)}")
-        raise gr.Error(f"An error occurred: {str(e)}")
-
-# Gradio Interface
-title = """<h1 align="center">AI Story Video Generator 🎥</h1>
-<p align="center">Generate a story from a prompt, create images for each sentence, and produce a video with narration!</p>"""
-
-with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
-    gr.HTML(title)
-
-    with gr.Row():
-        with gr.Column():
-            prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
-            generate_button = gr.Button("Generate Video")
-        with gr.Column():
-            video_output = gr.Video(label="Generated Video")
-
-    generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)
-
-demo.launch(debug=True)
+        st.error(f"Error searching ArXiv: {str(e)}")
+        return None, None, None
+
+def main():
+    st.title(title)
+
+    # Initialize Cosmos DB client if key is available
+    if Key:
+        cosmos_client = CosmosClient(ENDPOINT, credential=Key)
+        try:
+            database = cosmos_client.get_database_client(DATABASE_NAME)
+            container = database.get_container_client(CONTAINER_NAME)
+        except Exception as e:
+            st.error(f"Error connecting to Cosmos DB: {str(e)}")
+            container = None
+    else:
+        st.warning("Cosmos DB Key not found in environment variables")
+        container = None
+
+    # Create tabs for different functionalities
+    arxiv_tab, claude_tab, history_tab = st.tabs(["ArXiv Search", "Chat with Claude", "History"])
+
+    with arxiv_tab:
+        st.header("🔍 ArXiv Search")
+        arxiv_query = st.text_area("Enter your research query:", height=100)
+        if st.button("Search ArXiv"):
+            if arxiv_query:
+                with st.spinner("Searching ArXiv..."):
+                    result_mixtral, result_mistral, result_rag = search_arxiv(arxiv_query)
+
+                    if result_mixtral:
+                        st.subheader("Mixtral Model Response")
+                        st.markdown(result_mixtral)
+
+                        st.subheader("Mistral Model Response")
+                        st.markdown(result_mistral)
+
+                        st.subheader("RAG-Enhanced Response")
+                        if isinstance(result_rag, (list, tuple)) and len(result_rag) > 0:
+                            st.markdown(result_rag[0])
+                            if len(result_rag) > 1:
+                                st.markdown(result_rag[1])
+
+                        # Save results
+                        filename = generate_filename(arxiv_query, "md")
+                        create_file(filename, arxiv_query, f"{result_mixtral}\n\n{result_mistral}")
+
+                        if container:
+                            save_to_cosmos_db(container, arxiv_query, result_mixtral, result_mistral)
+
+    with claude_tab:
+        st.header("💬 Chat with Claude")
+        user_input = st.text_area("Your message:", height=100)
+        if st.button("Send"):
+            if user_input:
+                with st.spinner("Claude is thinking..."):
+                    try:
+                        response = anthropic_client.messages.create(
+                            model="claude-3-sonnet-20240229",
+                            max_tokens=1000,
+                            messages=[{"role": "user", "content": user_input}]
+                        )
+
+                        claude_response = response.content[0].text
+                        st.markdown("### Claude's Response:")
+                        st.markdown(claude_response)
+
+                        # Save chat history
+                        st.session_state.chat_history.append({
+                            "user": user_input,
+                            "claude": claude_response,
+                            "timestamp": datetime.now().isoformat()
+                        })
+
+                        # Save to file
+                        filename = generate_filename(user_input, "md")
+                        create_file(filename, user_input, claude_response)
+
+                        # Save to Cosmos DB
+                        if container:
+                            save_to_cosmos_db(container, user_input, claude_response, "")
+
+                    except Exception as e:
+                        st.error(f"Error communicating with Claude: {str(e)}")
+
+    with history_tab:
+        st.header("📜 Chat History")
+        for chat in reversed(st.session_state.chat_history):
+            with st.expander(f"Conversation from {chat.get('timestamp', 'Unknown time')}"):
+                st.markdown("**Your message:**")
+                st.markdown(chat["user"])
+                st.markdown("**Claude's response:**")
+                st.markdown(chat["claude"])
+
+if __name__ == "__main__":
+    main()