sagar007 commited on
Commit
c39dc38
·
verified ·
1 Parent(s): 2796a5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  from threading import Thread
6
  from PIL import Image
7
  import subprocess
 
8
 
9
  # Install flash-attention
10
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -46,6 +47,7 @@ vision_model = AutoModelForCausalLM.from_pretrained(
46
  vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
47
 
48
  # Helper functions
 
49
  def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
50
  conversation = [{"role": "system", "content": system_prompt}]
51
  for prompt, answer in history:
@@ -78,6 +80,7 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
78
  buffer += new_text
79
  yield buffer
80
 
 
81
  def process_vision_query(image, text_input):
82
  prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
83
  image = Image.fromarray(image).convert("RGB")
 
5
  from threading import Thread
6
  from PIL import Image
7
  import subprocess
8
+ import spaces # Add this import
9
 
10
  # Install flash-attention
11
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
47
  vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
48
 
49
  # Helper functions
50
+ @spaces.GPU # Add this decorator
51
  def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
52
  conversation = [{"role": "system", "content": system_prompt}]
53
  for prompt, answer in history:
 
80
  buffer += new_text
81
  yield buffer
82
 
83
+ @spaces.GPU # Add this decorator
84
  def process_vision_query(image, text_input):
85
  prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
86
  image = Image.fromarray(image).convert("RGB")