pratham0011 committed
Commit d89ceaa · verified · 1 Parent(s): d3b6224

Upload 8 files

app.py ADDED
@@ -0,0 +1,100 @@
+ import asyncio
+ import logging
+ import gradio as gr
+
+ from services.qwen import respond
+
+
+ logger = logging.getLogger(__name__)
+
+ # Track conversation state
+ conversation_history = []
+
+ def clear_conversation():
+     global conversation_history
+     conversation_history = []
+     return [], None
+
+ def sync_respond(audio, text_input, do_search, history):
+     if not audio and not text_input:
+         return None, history
+
+     logger.info(f"Processing request with search enabled: {do_search}")
+     result = asyncio.run(respond(audio, text_input, do_search, history))
+     audio_path, response_text = result
+
+     if audio:
+         user_message = {"role": "user", "content": "Voice message"}
+     else:
+         user_message = {"role": "user", "content": text_input}
+
+     assistant_message = {"role": "assistant", "content": response_text}
+     history.extend([user_message, assistant_message])
+
+     return audio_path, history
+
+ # Build Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft()) as interface:
+     gr.Markdown(
+         """
+         <div style="text-align: center; margin-bottom: 1rem;">
+             <h1 style="font-weight: bold;">ConversAI: AI Voice & Chat Assistant</h1>
+         </div>
+         """,
+         show_label=False
+     )
+
+     # Input components (left column)
+     with gr.Row():
+         with gr.Column(scale=2):
+             audio_input = gr.Audio(
+                 label="Your Voice Input",
+                 type="filepath",
+                 sources=["microphone"]
+             )
+             text_input = gr.Textbox(
+                 label="Or Type Your Message",
+                 placeholder="Type here..."
+             )
+             search_checkbox = gr.Checkbox(
+                 label="Enable web search",
+                 value=False
+             )
+             clear_btn = gr.Button("Clear Chat")
+
+         # Output components (right column)
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(label="Conversation", type="messages")
+             audio_output = gr.Audio(
+                 label="AI Voice Response",
+                 type="filepath",
+                 autoplay=True
+             )
+
+     # Define input event handlers
+     input_events = [
+         audio_input.change(
+             fn=sync_respond,
+             inputs=[audio_input, text_input, search_checkbox, chatbot],
+             outputs=[audio_output, chatbot]
+         ),
+         text_input.submit(
+             fn=sync_respond,
+             inputs=[audio_input, text_input, search_checkbox, chatbot],
+             outputs=[audio_output, chatbot]
+         )
+     ]
+
+     # Clear chat button handler
+     clear_btn.click(
+         fn=clear_conversation,
+         outputs=[chatbot, audio_output]
+     )
+
+ # Start server
+ if __name__ == "__main__":
+     interface.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         debug=True
+     )
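
Note that the input_events list is assigned but never read; Gradio registers the callbacks as a side effect of the .change() and .submit() calls, so the assignment is optional. For a quick check of the handler outside the UI, a minimal sketch (text-only input, search disabled; assumes the repo root is the working directory):

    # smoke-test sync_respond without launching the Gradio server
    history = []
    audio_path, history = sync_respond(None, "Hello there", False, history)
    print(history[-1]["content"])   # assistant reply text
    print(audio_path)               # path to the generated MP3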
config/__init__.py ADDED
File without changes
config/config.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import logging
+ from dotenv import load_dotenv
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+ token = os.getenv("hf_key")
+
+ # Set compute device (cpu/cuda)
+ device = "cpu"
+ logger.info(f"Device set to use {device}")
+
+ # AI Assistant Configuration
+ SYSTEM_PROMPT = """You are ConversAI, a helpful AI assistant who remembers conversation history. Keep responses clear, friendly and natural. Always refer to previous context when responding."""
+
+ # Text-to-Speech Voice Settings (primary/backup)
+ VOICE = "en-US-JennyNeural"
+ FALLBACK_VOICES = ["en-US-ChristopherNeural", "en-US-EricNeural"]
+
+ # Audio Output Configuration
+ OUTPUT_FORMAT = "audio-24khz-48kbit-mono-mp3"
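
config.py expects the Hugging Face token in an environment variable named hf_key, loaded via python-dotenv. A minimal .env in the project root might look like this (placeholder value):

    # .env — not committed to version control
    hf_key=hf_xxxxxxxxxxxxxxxxxxxxxxxx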
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ gradio
+ edge-tts
+ numpy
+ soxr
+ pydub
+ torch
+ transformers
+ sentencepiece
+ onnxruntime
+ huggingface-hub
+ python-dotenv
+ librosa
+ beautifulsoup4
+ requests
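
After pip install -r requirements.txt, a throwaway script can confirm that the third-party packages the services import actually resolve (a sanity check only, not part of the app):

    # check_env.py — verify the runtime dependencies import cleanly
    import gradio, torch, transformers, librosa, edge_tts, bs4, requests
    print("all imports OK")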
services/__init__.py ADDED
File without changes
services/qwen.py ADDED
@@ -0,0 +1,92 @@
+ import logging
+ from typing import List, Dict, Optional, Tuple
+
+ import torch
+ from transformers import pipeline
+
+ from config.config import token, device, SYSTEM_PROMPT
+ from services.whisper import generate_speech, transcribe
+ from services.search import WebSearcher
+
+ logger = logging.getLogger(__name__)
+
+ searcher = WebSearcher()
+
+ # Qwen Configuration
+ model_kwargs = {
+     "low_cpu_mem_usage": True,
+     "torch_dtype": torch.float32,
+     "use_cache": True
+ }
+ client = pipeline(
+     "text-generation",
+     model="Qwen/Qwen2.5-0.5B-Instruct",
+     token=token,
+     trust_remote_code=True,
+     device=device,
+     model_kwargs=model_kwargs
+ )
+
+ async def respond(
+     audio: Optional[str] = None,
+     text: Optional[str] = None,
+     do_search: bool = False,
+     history: Optional[List[Dict]] = None
+ ) -> Tuple[Optional[str], str]:
+     try:
+         if text:
+             user_text = text.strip()
+         elif audio:
+             user_text = await transcribe(audio)
+         else:
+             return None, "No input provided"
+
+         # Build conversation context
+         messages = []
+         messages.append({"role": "system", "content": SYSTEM_PROMPT})
+
+         if history:
+             messages.extend(history)
+
+         # Format message history for Qwen (ChatML-style tags)
+         prompt = ""
+         for msg in messages:
+             role = msg["role"]
+             content = msg["content"]
+             prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+
+         # Add current user message
+         prompt += f"<|im_start|>user\n{user_text}<|im_end|>\n<|im_start|>assistant\n"
+
+         # Add web-search context if enabled
+         if do_search:
+             results = searcher.search(user_text)
+             if results:
+                 search_context = "Based on search results:\n"
+                 for result in results:
+                     snippet = result['content'][:500].strip()
+                     search_context += f"{snippet}\n"
+                 prompt = prompt.replace(SYSTEM_PROMPT, f"{SYSTEM_PROMPT}\n{search_context}")
+
+         # Generate response
+         reply = client(
+             prompt,
+             max_new_tokens=400,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+             num_return_sequences=1
+         )
+
+         # Extract and clean assistant response
+         assistant_response = reply[0]['generated_text']
+         assistant_response = assistant_response.split("<|im_start|>assistant\n")[-1]
+         assistant_response = assistant_response.split("<|im_end|>")[0].strip()
+
+         # Convert response to speech
+         audio_path = await generate_speech(assistant_response)
+         return audio_path, assistant_response
+
+     except Exception as e:
+         logger.error(f"Error in respond: {str(e)}")
+         return None, "Sorry, I encountered an error"
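
respond is a coroutine, so callers outside the Gradio handler need an event loop. A minimal text-only invocation (no audio, search disabled) might look like:

    import asyncio
    from services.qwen import respond

    audio_path, reply = asyncio.run(
        respond(text="What is the capital of France?", do_search=False, history=[])
    )
    print(reply)        # assistant text
    print(audio_path)   # MP3 produced by generate_speech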
services/search.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ from typing import List, Dict
+
+ import requests
+ from bs4 import BeautifulSoup
+ from urllib3.exceptions import InsecureRequestWarning
+
+ # Disable SSL warnings for requests
+ requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
+
+ logger = logging.getLogger(__name__)
+
+ class WebSearcher:
+     def __init__(self):
+         self.headers = {
+             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+         }
+
+     def extract_text(self, html_content: str) -> str:
+         soup = BeautifulSoup(html_content, 'html.parser')
+         # Remove unwanted elements
+         for element in soup(['script', 'style', 'nav', 'header', 'footer', 'iframe']):
+             element.decompose()
+         text = ' '.join(soup.stripped_strings)
+         return text[:8000]  # Limit text length
+
+     def search(self, query: str, max_results: int = 3) -> List[Dict]:
+         results = []
+         try:
+             with requests.Session() as session:
+                 # Google search parameters
+                 search_url = "https://www.google.com/search"
+                 params = {
+                     "q": query,
+                     "num": max_results,
+                     "hl": "en"
+                 }
+
+                 response = session.get(
+                     search_url,
+                     headers=self.headers,
+                     params=params,
+                     timeout=10,
+                     verify=False
+                 )
+                 response.raise_for_status()
+
+                 # Parse search results
+                 soup = BeautifulSoup(response.text, 'html.parser')
+                 search_results = soup.select('div.g')
+
+                 for result in search_results[:max_results]:
+                     link = result.find('a')
+                     if not link:
+                         continue
+
+                     url = link.get('href', '')
+                     if not url.startswith('http'):
+                         continue
+
+                     try:
+                         # Fetch webpage content
+                         page_response = session.get(
+                             url,
+                             headers=self.headers,
+                             timeout=5,
+                             verify=False
+                         )
+                         page_response.raise_for_status()
+
+                         content = self.extract_text(page_response.text)
+                         results.append({
+                             "url": url,
+                             "content": content
+                         })
+                         logger.info(f"Successfully fetched content from {url}")
+
+                     except Exception as e:
+                         logger.warning(f"Failed to fetch {url}: {str(e)}")
+                         continue
+
+         except Exception as e:
+             logger.error(f"Search failed: {str(e)}")
+
+         return results[:max_results]
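
WebSearcher can be exercised standalone; search returns a list of dicts with url and content keys, where content is capped at 8,000 characters by extract_text. A sketch of direct use (results depend on Google's live markup, which the div.g selector may not always match):

    from services.search import WebSearcher

    searcher = WebSearcher()
    for result in searcher.search("latest Python release", max_results=2):
        print(result["url"])
        print(result["content"][:200], "...")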
services/whisper.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import tempfile
+ import logging
+ from typing import Optional
+
+ import torch
+ import librosa
+ import edge_tts
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+ from config.config import VOICE, FALLBACK_VOICES
+
+
+ logger = logging.getLogger(__name__)
+
+ # Whisper model for speech to text
+ processor = WhisperProcessor.from_pretrained(
+     "openai/whisper-tiny",
+     local_files_only=False
+ )
+ model = WhisperForConditionalGeneration.from_pretrained(
+     "openai/whisper-tiny",
+     local_files_only=False,
+     low_cpu_mem_usage=True,
+     torch_dtype=torch.float32,
+ ).to("cpu")
+
+ # Voice selection handling
+ async def get_valid_voice() -> str:
+     available_voices = await edge_tts.list_voices()
+     voice_names = [VOICE] + FALLBACK_VOICES
+
+     available_voice_names = {v["ShortName"] for v in available_voices}
+     for voice in voice_names:
+         if voice in available_voice_names:
+             return voice
+
+     raise RuntimeError("No valid voice found")
+
+ # Text-to-speech conversion using Edge TTS
+ async def generate_speech(text: str) -> Optional[str]:
+     if not text or not isinstance(text, str):
+         raise ValueError("Invalid text input")
+
+     voice = await get_valid_voice()
+     logger.info(f"Using voice: {voice}")
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
+         tmp_path = tmp_file.name
+
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(tmp_path)
+
+     if not os.path.exists(tmp_path) or os.path.getsize(tmp_path) == 0:
+         raise RuntimeError("Speech file empty or not created")
+
+     logger.info(f"Speech generated successfully: {tmp_path}")
+     return tmp_path
+
+ # Speech-to-text using Whisper
+ async def transcribe(audio_file: str) -> str:
+     audio, sr = librosa.load(
+         audio_file,
+         sr=16000,
+         mono=True,
+         duration=30
+     )
+
+     inputs = processor(
+         audio,
+         sampling_rate=sr,
+         return_tensors="pt",
+         return_attention_mask=True
+     ).to(model.device)
+
+     with torch.no_grad():
+         generated_ids = model.generate(
+             input_features=inputs.input_features,
+             attention_mask=inputs.attention_mask,
+             language="en",
+             task="transcribe",
+             max_length=448,
+             temperature=0.0
+         )
+
+     transcription = processor.batch_decode(
+         generated_ids,
+         skip_special_tokens=True
+     )[0].strip()
+
+     logger.info(f"Transcribed text: {transcription}")
+     return transcription
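
Both helpers are coroutines. A round-trip sketch that synthesizes speech and transcribes it back (assumes librosa can decode MP3, i.e. an audioread backend such as ffmpeg is available):

    import asyncio
    from services.whisper import generate_speech, transcribe

    async def round_trip():
        path = await generate_speech("Testing one two three")
        text = await transcribe(path)
        print(path, "->", text)

    asyncio.run(round_trip())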