pratham0011 committed
Commit dc4a1e0 · verified · 1 Parent(s): e1834c5

Upload 8 files

Files changed (3):
  1. services/qwen.py +18 -16
  2. services/search.py +1 -1
  3. services/whisper.py +20 -44
services/qwen.py CHANGED
@@ -2,10 +2,10 @@ import logging
 from typing import List, Dict, Optional, Tuple
 
 import torch
-from transformers import pipeline
-from transformers import pipeline
+# from transformers import pipeline
+from huggingface_hub import InferenceClient
 
-from config.config import token, device, SYSTEM_PROMPT
+from config.config import token, SYSTEM_PROMPT
 from services.whisper import generate_speech, transcribe
 from services.search import WebSearcher
 
@@ -19,13 +19,12 @@ model_kwargs = {
     "torch_dtype": torch.float32,
     'use_cache': True
 }
-client = pipeline(
-    "text-generation",
+client = InferenceClient(
     model="Qwen/Qwen2.5-0.5B-Instruct",
-    token=token,
-    trust_remote_code=True,
-    device=device,
-    model_kwargs=model_kwargs
+    token=token
+    # trust_remote_code=True,
+    # device=device,
+    # model_kwargs=model_kwargs
 )
 
 async def respond(
@@ -65,24 +64,27 @@ async def respond(
     if results:
         search_context = "Based on search results:\n"
         for result in results:
-            snippet = result['content'][:500].strip()
+            snippet = result['content'][:5000].strip()
            search_context += f"{snippet}\n"
         prompt = prompt.replace(SYSTEM_PROMPT, f"{SYSTEM_PROMPT}\n{search_context}")
 
     # Generate response
-    reply = client(
+    reply = client.text_generation(
         prompt,
-        max_new_tokens=400,
+        max_new_tokens=300,
         do_sample=True,
         temperature=0.7,
         top_p=0.9,
-        num_return_sequences=1
+        return_full_text=False
     )
 
     # Extract and clean assistant response
-    assistant_response = reply[0]['generated_text']
-    assistant_response = assistant_response.split("<|im_start|>assistant\n")[-1]
-    assistant_response = assistant_response.split("<|im_end|>")[0].strip()
+    assistant_response = reply  # Reply is already the generated text string
+    if "<|im_start|>assistant\n" in assistant_response:
+        assistant_response = assistant_response.split("<|im_start|>assistant\n")[-1]
+    if "<|im_end|>" in assistant_response:
+        assistant_response = assistant_response.split("<|im_end|>")[0]
+    assistant_response = assistant_response.strip()
 
     # Convert response to speech
     audio_path = await generate_speech(assistant_response)
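Note: this change swaps a locally loaded transformers pipeline for huggingface_hub's hosted InferenceClient, so generation no longer loads Qwen weights in-process. Below is a minimal standalone sketch of the same call pattern; the prompt text and token value are placeholders, not taken from this repo.

# Sketch only: assumed usage of huggingface_hub.InferenceClient for text generation.
from huggingface_hub import InferenceClient

client = InferenceClient(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    token="hf_xxx",  # placeholder; the commit reads the real token from config.config
)

# Qwen2.5 chat-template prompt built by hand for illustration.
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nSay hello in one sentence.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

reply = client.text_generation(
    prompt,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    return_full_text=False,  # return only the completion, not the prompt
)

# reply is already a plain string; trim the end-of-turn marker if present.
print(reply.split("<|im_end|>")[0].strip())

With return_full_text=False the client returns only the newly generated string, which is why the diff drops the reply[0]['generated_text'] indexing and only strips the chat-template markers defensively.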
services/search.py CHANGED
@@ -40,7 +40,7 @@ class WebSearcher:
            search_url,
            headers=self.headers,
            params=params,
-            timeout=10,
+            timeout=3,
            verify=False
        )
        response.raise_for_status()
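Only the request timeout changes here, from 10 seconds to 3, so a slow search backend fails fast instead of stalling the whole respond() flow. For context, a minimal sketch of an equivalent call; the URL, headers, and params below are placeholders, not the values used by WebSearcher.

import requests

# Hypothetical standalone version of the search request with the tighter timeout.
response = requests.get(
    "https://example.com/search",           # placeholder; the real search_url is defined in search.py
    headers={"User-Agent": "Mozilla/5.0"},  # placeholder headers
    params={"q": "example query"},
    timeout=3,      # raises requests.exceptions.Timeout after 3 s instead of 10 s
    verify=False,   # kept from the original call; TLS verification stays disabled
)
response.raise_for_status()
print(response.text[:200])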
services/whisper.py CHANGED
@@ -1,29 +1,19 @@
 import os
 import tempfile
 import logging
+import requests
 from typing import Optional
 
-import torch
-import librosa
 import edge_tts
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
-from config.config import VOICE, FALLBACK_VOICES
+from config.config import VOICE, FALLBACK_VOICES, token
 
 
 logger = logging.getLogger(__name__)
 
 # Whisper model for speech to text
-processor = WhisperProcessor.from_pretrained(
-    "openai/whisper-tiny",
-    local_files_only=False
-)
-model = WhisperForConditionalGeneration.from_pretrained(
-    "openai/whisper-tiny",
-    local_files_only=False,
-    low_cpu_mem_usage=True,
-    torch_dtype=torch.float32,
-).to("cpu")
+API_URL = "https://api-inference.huggingface.co/models/openai/whisper-tiny"
+headers = {"Authorization": f"Bearer {token}"}
 
 # Voice selection handling
 async def get_valid_voice() -> str:
@@ -59,34 +49,20 @@ async def generate_speech(text: str) -> Optional[str]:
 
 # Speech-to-text using Whisper
 async def transcribe(audio_file: str) -> str:
-    audio, sr = librosa.load(
-        audio_file,
-        sr=16000,
-        mono=True,
-        duration=30
-    )
-
-    inputs = processor(
-        audio,
-        sampling_rate=sr,
-        return_tensors="pt",
-        return_attention_mask=True
-    ).to(model.device)
-
-    with torch.no_grad():
-        generated_ids = model.generate(
-            input_features=inputs.input_features,
-            attention_mask=inputs.attention_mask,
-            language="en",
-            task="transcribe",
-            max_length=448,
-            temperature=0.0
-        )
+    try:
+        with open(audio_file, "rb") as f:
+            data = f.read()
 
-    transcription = processor.batch_decode(
-        generated_ids,
-        skip_special_tokens=True
-    )[0].strip()
-
-    logger.info(f"Transcribed text: {transcription}")
-    return transcription
+        response = requests.post(API_URL, headers=headers, data=data)
+        result = response.json()
+
+        if "text" in result:
+            transcription = result["text"].strip()
+            logger.info(f"Transcribed text: {transcription}")
+            return transcription
+        else:
+            raise ValueError("No transcription in response")
+
+    except Exception as e:
+        logger.error(f"Transcription error: {str(e)}")
+        raise RuntimeError(f"Failed to transcribe audio: {str(e)}")
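Here the local Whisper stack (torch, librosa, WhisperProcessor/WhisperForConditionalGeneration) is dropped entirely in favor of the hosted Inference API, so transcription becomes a single HTTP POST of the raw audio bytes. A minimal standalone sketch of that call; the file name and token value are placeholders.

import requests

API_URL = "https://api-inference.huggingface.co/models/openai/whisper-tiny"
headers = {"Authorization": "Bearer hf_xxx"}  # placeholder; the commit takes the real token from config.config

# Read the audio file as raw bytes and post it to the hosted model.
with open("sample.wav", "rb") as f:           # sample.wav is a hypothetical input file
    audio_bytes = f.read()

response = requests.post(API_URL, headers=headers, data=audio_bytes)
response.raise_for_status()                   # surface HTTP errors before JSON parsing
result = response.json()                      # expected shape on success: {"text": "..."}
print(result.get("text", "").strip())

One trade-off worth noting: requests.post is blocking, so inside async def transcribe it will hold up the event loop for the duration of the upload; the commit accepts that in exchange for removing the torch and librosa dependencies.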