kh-CHEUNG committed
Commit 1ac2a58 · verified · 1 Parent(s): 7be6ab0

Update app.py

Files changed (1):
  1. app.py +51 -7
app.py CHANGED

@@ -11,6 +11,8 @@ import os
 from transformers import pipeline
 from transformers.pipelines.audio_utils import ffmpeg_read
 
+from sentence_transformers import SentenceTransformer
+
 from langchain.prompts import PromptTemplate
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
@@ -19,10 +21,20 @@ from langchain_text_splitters import SentenceTransformersTokenTextSplitter
 
 from PIL import Image
 
-from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
-processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
-model.to("cuda:0")
+# from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
+# processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+# model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+).to("cuda:0")
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
 
 embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 
@@ -35,6 +47,7 @@ from huggingface_hub import InferenceClient
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
+# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 device = 0 if torch.cuda.is_available() else "cpu"
@@ -47,7 +60,7 @@ asr_pl = pipeline(
 )
 
 application_title = "Enlight Innovations Limited -- Demo"
-application_description = "This demo is desgined to illustrate our basic ideas and feasibility in implementation."
+application_description = "This demo is designed to illustrate our basic ideas and feasibility in implementation."
 
 @spaces.GPU
 def respond(
@@ -69,8 +82,38 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    response = ""
+    input_ids = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        eos_token_id=terminators,
+    )
+    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        #print(outputs)
+        yield "".join(outputs)
 
+    """
+    response = ""
+
     for message in client.chat_completion(
         messages,
         max_tokens=max_tokens,
@@ -82,6 +125,7 @@ def respond(
 
         response += token
         yield response
+    """
 
 @spaces.GPU
 def transcribe(asr_inputs, task):
@@ -90,7 +134,7 @@ def transcribe(asr_inputs, task):
         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
 
     text = asr_pl(asr_inputs, batch_size=ASR_BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
-    return text
+    return text.strip()
 
 
 """Gradio User Interface"""