OmkarThawakar committed on
Commit
582ec29
·
1 Parent(s): 3cf14d0

added zerogpu support

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -15,6 +15,7 @@ from src.data.transforms import transform_test
15
 
16
  from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
17
  import gradio as gr
 
18
 
19
  from langchain_core.output_parsers import StrOutputParser
20
  from langchain_core.prompts import ChatPromptTemplate
@@ -26,7 +27,7 @@ GROQ_API_KEY = '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'
26
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
27
 
28
  # Initialize LLM
29
- llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_tokens=1024, max_retries=2)
30
 
31
  # QA system prompt and chain
32
  qa_system_prompt = """
@@ -179,9 +180,10 @@ custom_css = """
179
  }
180
  """
181
 
182
-
183
  def respond_to_user(image, message):
184
  # Process the image and message here
 
185
  chat = Chat(model,transform,df,tar_img_feats, device)
186
  chat.encode_image(image)
187
  data = chat.ask()
 
15
 
16
  from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
17
  import gradio as gr
18
+ import spaces
19
 
20
  from langchain_core.output_parsers import StrOutputParser
21
  from langchain_core.prompts import ChatPromptTemplate
 
27
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
28
 
29
  # Initialize LLM
30
+ llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2)
31
 
32
  # QA system prompt and chain
33
  qa_system_prompt = """
 
180
  }
181
  """
182
 
183
+ @spaces.GPU
184
  def respond_to_user(image, message):
185
  # Process the image and message here
186
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
187
  chat = Chat(model,transform,df,tar_img_feats, device)
188
  chat.encode_image(image)
189
  data = chat.ask()