awacke1 commited on
Commit
fef6f0f
·
verified ·
1 Parent(s): 4b3ee30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -169
app.py CHANGED
@@ -1,191 +1,204 @@
1
- import anthropic
2
- import base64
3
- import json
4
- import os
5
- import pandas as pd
6
- import pytz
7
- import re
8
- import streamlit as st
9
  from datetime import datetime
 
 
 
10
  from gradio_client import Client
11
- from azure.cosmos import CosmosClient, exceptions
12
-
13
# App Configuration
# NOTE(review): the title string looks mojibake-encoded (likely a 🤖 emoji
# read with the wrong codec) — confirm the intended character.
title = "๐Ÿค– ArXiv and Claude AI Assistant"
st.set_page_config(page_title=title, layout="wide")

# Cosmos DB configuration
ENDPOINT = "https://acae-afd.documents.azure.com:443/"  # fixed account endpoint
Key = os.environ.get("Key")  # Cosmos primary key; None when unset (persistence disabled)
DATABASE_NAME = os.environ.get("COSMOS_DATABASE_NAME")
CONTAINER_NAME = os.environ.get("COSMOS_CONTAINER_NAME")

# Initialize Anthropic client
anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

# Initialize session state
# chat_history holds dicts with "user"/"claude"/"timestamp" keys (see main()).
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
29
 
30
def generate_filename(prompt, file_type):
    """Build a timestamped filename from a sanitized prompt.

    Format: MMDD_HHMM in US/Central time, followed by the first 90
    word-characters of *prompt*, with *file_type* as the extension.
    """
    now_central = datetime.now(pytz.timezone('US/Central'))
    stamp = now_central.strftime("%m%d_%H%M")
    slug = re.sub(r'\W+', '', prompt)[:90]
    return f"{stamp}{slug}.{file_type}"
 
 
 
 
36
 
37
def create_file(filename, prompt, response, should_save=True):
    """Write a prompt/response pair to *filename* (UTF-8).

    Does nothing when *should_save* is False.
    """
    if should_save:
        with open(filename, 'w', encoding='utf-8') as out:
            out.write(f"Prompt:\n{prompt}\n\nResponse:\n{response}")
43
 
44
def save_to_cosmos_db(container, query, response1, response2):
    """Persist one query/response interaction to Cosmos DB.

    No-op when *container* is falsy; success or failure is reported
    through the Streamlit UI rather than raised.
    """
    try:
        if not container:
            return
        # NOTE(review): datetime.utcnow() is deprecated in Python 3.12;
        # switching to datetime.now(timezone.utc) would change the stored
        # isoformat string, so it is left as-is here.
        stamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
        record = {
            "id": stamp,
            "name": stamp,
            "query": query,
            "response1": response1,
            "response2": response2,
            "timestamp": datetime.utcnow().isoformat(),
            "type": "ai_response",
            "version": "1.0",
        }
        container.create_item(body=record)
        st.success(f"Record saved to Cosmos DB with ID: {record['id']}")
    except Exception as e:
        st.error(f"Error saving to Cosmos DB: {str(e)}")
 
63
 
64
def search_arxiv(query):
    """Query the ArXiv RAG Space with three model configurations.

    Returns (mixtral_answer, mistral_answer, rag_result), or
    (None, None, None) after surfacing any error in the Streamlit UI.
    """
    try:
        space = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")

        # Plain LLM answers from two different models on the same endpoint.
        def ask(model):
            return space.predict(query, model, True, api_name="/ask_llm")

        result_mixtral = ask("mistralai/Mixtral-8x7B-Instruct-v0.1")
        result_mistral = ask("mistralai/Mistral-7B-Instruct-v0.2")

        # RAG-enhanced answer grounded in retrieved papers.
        result_rag = space.predict(
            query,
            10,  # llm_results_use
            "Semantic Search",
            "mistralai/Mistral-7B-Instruct-v0.2",
            api_name="/update_with_rag_md",
        )
        return result_mixtral, result_mistral, result_rag
    except Exception as e:
        st.error(f"Error searching ArXiv: {str(e)}")
        return None, None, None
98
-
99
def main():
    """Streamlit entry point: ArXiv search, Claude chat, and history tabs."""
    st.title(title)

    # Connect to Cosmos DB only when a key was provided via the environment;
    # the app degrades gracefully (container=None, no persistence) otherwise.
    if Key:
        cosmos_client = CosmosClient(ENDPOINT, credential=Key)
        try:
            database = cosmos_client.get_database_client(DATABASE_NAME)
            container = database.get_container_client(CONTAINER_NAME)
        except Exception as e:
            st.error(f"Error connecting to Cosmos DB: {str(e)}")
            container = None
    else:
        st.warning("Cosmos DB Key not found in environment variables")
        container = None

    # Create tabs for different functionalities
    arxiv_tab, claude_tab, history_tab = st.tabs(["ArXiv Search", "Chat with Claude", "History"])

    with arxiv_tab:
        st.header("๐Ÿ” ArXiv Search")
        arxiv_query = st.text_area("Enter your research query:", height=100)
        if st.button("Search ArXiv"):
            if arxiv_query:
                with st.spinner("Searching ArXiv..."):
                    result_mixtral, result_mistral, result_rag = search_arxiv(arxiv_query)

                    # All three results come back None on failure, so one
                    # truthiness check gates the whole display/save path.
                    if result_mixtral:
                        st.subheader("Mixtral Model Response")
                        st.markdown(result_mixtral)

                        st.subheader("Mistral Model Response")
                        st.markdown(result_mistral)

                        st.subheader("RAG-Enhanced Response")
                        # The RAG endpoint may return a (markdown, extra) pair;
                        # render whichever parts are present.
                        if isinstance(result_rag, (list, tuple)) and len(result_rag) > 0:
                            st.markdown(result_rag[0])
                            if len(result_rag) > 1:
                                st.markdown(result_rag[1])

                        # Save results
                        filename = generate_filename(arxiv_query, "md")
                        create_file(filename, arxiv_query, f"{result_mixtral}\n\n{result_mistral}")

                        if container:
                            save_to_cosmos_db(container, arxiv_query, result_mixtral, result_mistral)

    with claude_tab:
        st.header("๐Ÿ’ฌ Chat with Claude")
        user_input = st.text_area("Your message:", height=100)
        if st.button("Send"):
            if user_input:
                with st.spinner("Claude is thinking..."):
                    try:
                        response = anthropic_client.messages.create(
                            model="claude-3-sonnet-20240229",
                            max_tokens=1000,
                            messages=[{"role": "user", "content": user_input}]
                        )

                        # First content block carries the text reply.
                        claude_response = response.content[0].text
                        st.markdown("### Claude's Response:")
                        st.markdown(claude_response)

                        # Save chat history
                        st.session_state.chat_history.append({
                            "user": user_input,
                            "claude": claude_response,
                            "timestamp": datetime.now().isoformat()
                        })

                        # Save to file
                        filename = generate_filename(user_input, "md")
                        create_file(filename, user_input, claude_response)

                        # Save to Cosmos DB
                        if container:
                            save_to_cosmos_db(container, user_input, claude_response, "")

                    except Exception as e:
                        st.error(f"Error communicating with Claude: {str(e)}")

    with history_tab:
        st.header("๐Ÿ“œ Chat History")
        # Newest conversation first.
        for chat in reversed(st.session_state.chat_history):
            with st.expander(f"Conversation from {chat.get('timestamp', 'Unknown time')}"):
                st.markdown("**Your message:**")
                st.markdown(chat["user"])
                st.markdown("**Claude's response:**")
                st.markdown(chat["claude"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
# Script entry point.
if __name__ == "__main__":
    main()
 
1
+ import gradio as gr
2
+ import random
3
+ import time
 
 
 
 
 
4
  from datetime import datetime
5
+ import tempfile
6
+ import os
7
+ from moviepy.editor import ImageClip, concatenate_videoclips
8
  from gradio_client import Client
9
+ from PIL import Image
10
+ import edge_tts
11
+ import asyncio
12
+ import warnings
13
+ import numpy as np
 
 
 
 
 
 
14
 
15
# Silence noisy library warnings (moviepy/PIL emit many).
# NOTE(review): a blanket 'ignore' also hides deprecation warnings —
# consider narrowing the filter.
warnings.filterwarnings('ignore')

# Initialize the Gradio client for model access.
#   client       — Stable Diffusion XL Space, used for text-to-image scene renders
#   arxiv_client — LLM QA Space, reused here as a plain text-generation endpoint
client = Client("stabilityai/stable-diffusion-xl-base-1.0")
arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")

# Genres offered in the UI dropdown.
STORY_GENRES = [
    "Science Fiction",
    "Fantasy",
    "Mystery",
    "Romance",
    "Horror",
    "Adventure",
    "Historical Fiction",
    "Comedy"
]

# Narrative templates; the value strings are injected verbatim into the LLM prompt.
STORY_STRUCTURES = {
    "Three Act": "Setup (Introduction, Inciting Incident) -> Confrontation (Rising Action, Climax) -> Resolution (Falling Action, Conclusion)",
    "Hero's Journey": "Ordinary World -> Call to Adventure -> Trials -> Transformation -> Return",
    "Five Act": "Exposition -> Rising Action -> Climax -> Falling Action -> Resolution",
    "Seven Point": "Hook -> Plot Turn 1 -> Pinch Point 1 -> Midpoint -> Pinch Point 2 -> Plot Turn 2 -> Resolution"
}
38
 
39
async def generate_speech(text, voice="en-US-AriaNeural"):
    """Synthesize *text* to an MP3 via edge-tts and return the file path.

    Errors are logged and re-raised so callers can decide how to recover.
    """
    try:
        # Reserve a temp filename; the handle closes immediately but the
        # path persists because delete=False.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as handle:
            out_path = handle.name
        tts = edge_tts.Communicate(text, voice)
        await tts.save(out_path)
        return out_path
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        raise
50
 
51
def generate_story_prompt(base_prompt, genre, structure):
    """Generate an expanded story prompt based on genre and structure"""
    # Looks up the human-readable structure description from STORY_STRUCTURES
    # (raises KeyError for an unknown structure name).
    # NOTE(review): the leading whitespace of the continuation lines is part
    # of the string sent to the LLM — harmless, but intentional-looking.
    prompt = f"""Create a {genre} story using this concept: '{base_prompt}'
    Follow this structure: {STORY_STRUCTURES[structure]}
    Include vivid descriptions and sensory details.
    Make it engaging and suitable for visualization.
    Keep each scene description clear and detailed enough for image generation.
    Limit the story to 5-7 key scenes.
    """
    return prompt
61
+
62
def generate_story(prompt, model_choice):
    """Ask the remote LLM Space to write the story.

    Returns the model's text, or an error string on failure (the caller
    treats both uniformly as story text).
    """
    try:
        # Same /ask_llm endpoint the original app used; the boolean enables
        # whatever toggle the Space exposes as its third argument.
        return arxiv_client.predict(
            prompt,
            model_choice,
            True,
            api_name="/ask_llm",
        )
    except Exception as e:
        return f"Error generating story: {str(e)}"
 
 
 
 
74
 
75
def generate_image_from_text(text_prompt):
    """Render one scene image from a text description.

    Calls the SDXL Space via the module-level `client`. Returns whatever
    the Space's /text2image endpoint yields, or None on failure.
    """
    try:
        return client.predict(
            text_prompt,
            num_inference_steps=30,  # quality/speed trade-off
            guidance_scale=7.5,      # standard SDXL guidance strength
            width=768,
            height=512,
            api_name="/text2image"
        )
    except Exception as e:
        # FIX: previously the exception was swallowed silently; log it so
        # failed scenes are diagnosable. Callers still get None and skip
        # the scene.
        print(f"Error generating image: {str(e)}")
        return None
89
 
90
def create_video_from_images(image_paths, durations):
    """Create a slideshow video from still images.

    Each image becomes a clip lasting the paired duration (seconds); the
    clips are concatenated and rendered to an MP4 whose path is returned.
    """
    clips = [ImageClip(img_path).set_duration(dur)
             for img_path, dur in zip(image_paths, durations)]
    final_clip = concatenate_videoclips(clips, method="compose")
    # FIX: tempfile.mktemp is deprecated and race-prone; reserve the name
    # with NamedTemporaryFile(delete=False) instead.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        output_path = tmp.name
    final_clip.write_videofile(output_path, fps=24)
    return output_path
97
 
98
def process_story(story_text, num_scenes=5):
    """Split a story into at most *num_scenes* scene descriptions.

    Sentences are naively delimited by '.', grouped into evenly sized
    chunks, rejoined with '. ', stripped, and empty chunks dropped.
    """
    sentences = story_text.split('.')
    chunk = max(1, len(sentences) // num_scenes)
    scenes = [
        text
        for start in range(0, len(sentences), chunk)
        if (text := '. '.join(sentences[start:start + chunk]).strip())
    ]
    return scenes[:num_scenes]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
def story_generator_interface(prompt, genre, structure, model_choice, num_scenes, words_per_scene):
    """Main story generation and multimedia creation function.

    Pipeline: expand prompt -> LLM story -> split into scenes ->
    one image per scene -> TTS narration -> slideshow video.

    Returns (story_text, image_paths, audio_path, video_path); video_path
    is None when no scene image could be generated.

    NOTE(review): words_per_scene is accepted for UI parity but is not yet
    used anywhere — wire it into generate_story_prompt when supported.
    """
    # Generate expanded prompt and the story itself
    story_prompt = generate_story_prompt(prompt, genre, structure)
    story = generate_story(story_prompt, model_choice)

    # Process story into scenes
    scenes = process_story(story, num_scenes)

    # Generate images for each scene; failures return None and are skipped
    image_paths = []
    for scene in scenes:
        image = generate_image_from_text(scene)
        if image is not None:
            # FIX: tempfile.mktemp is deprecated/race-prone.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                temp_path = tmp.name
            # NOTE(review): assumes the Space returns a numpy array; if it
            # returns a file path instead, load it first — verify.
            Image.fromarray(image).save(temp_path)
            image_paths.append(temp_path)

    # Generate speech narration for the full story
    audio_path = asyncio.run(generate_speech(story))

    # FIX: only build a video when at least one scene image succeeded;
    # concatenating zero clips raises inside moviepy.
    video_path = None
    if image_paths:
        scene_durations = [5.0] * len(image_paths)  # 5 seconds per scene
        video_path = create_video_from_images(image_paths, scene_durations)

    return story, image_paths, audio_path, video_path
 
 
 
 
 
140
 
141
# Create Gradio interface: inputs column, outputs (story/gallery/audio/video),
# and the single button that runs the whole pipeline.
with gr.Blocks(title="AI Story Generator & Visualizer") as demo:
    gr.Markdown("# ๐ŸŽญ AI Story Generator & Visualizer")

    with gr.Row():
        with gr.Column():
            # --- Story inputs ---
            prompt_input = gr.Textbox(
                label="Story Concept",
                placeholder="Enter your story idea...",
                lines=3
            )
            genre_input = gr.Dropdown(
                label="Genre",
                choices=STORY_GENRES,
                value="Fantasy"
            )
            structure_input = gr.Dropdown(
                label="Story Structure",
                choices=list(STORY_STRUCTURES.keys()),
                value="Three Act"
            )
            model_choice = gr.Dropdown(
                label="Model",
                choices=["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"],
                value="mistralai/Mixtral-8x7B-Instruct-v0.1"
            )
            num_scenes = gr.Slider(
                label="Number of Scenes",
                minimum=3,
                maximum=7,
                value=5,
                step=1
            )
            words_per_scene = gr.Slider(
                label="Words per Scene",
                minimum=20,
                maximum=100,
                value=50,
                step=10
            )
            generate_btn = gr.Button("Generate Story & Media")

    with gr.Row():
        with gr.Column():
            story_output = gr.Textbox(
                label="Generated Story",
                lines=10,
                # FIX: gr.Textbox has no `readonly` parameter; a
                # non-editable output box is interactive=False.
                interactive=False
            )
        with gr.Column():
            gallery = gr.Gallery(label="Scene Visualizations")

    with gr.Row():
        audio_output = gr.Audio(label="Story Narration")
        video_output = gr.Video(label="Story Video")

    # Wire the button to the generation pipeline.
    generate_btn.click(
        fn=story_generator_interface,
        inputs=[prompt_input, genre_input, structure_input, model_choice, num_scenes, words_per_scene],
        outputs=[story_output, gallery, audio_output, video_output]
    )
202
 
203
  if __name__ == "__main__":
204
+ demo.launch(reload=True)