Nicolai Berk committed on
Commit
9629f65
·
1 Parent(s): a92e9d3

Adjust GPU decorator

Browse files
Files changed (1) hide show
  1. app.py +3 -16
app.py CHANGED
@@ -8,16 +8,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  import os
9
  import spaces
10
 
11
- print("CUDA available:", torch.cuda.is_available())
12
-
13
- @spaces.GPU
14
- def claim_gpu():
15
- # Dummy function to make Spaces detect GPU usage
16
- pass
17
-
18
- claim_gpu()
19
-
20
-
21
  # Login automatically if HF_TOKEN is present
22
  hf_token = os.getenv("HF_TOKEN")
23
  if hf_token:
@@ -52,9 +42,10 @@ index.add(corpus_embeddings_np)
52
 
53
  # Generator (choose one: local HF model or OpenAI)
54
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
55
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", device_map="auto", torch_dtype=torch.float16)
56
  generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150)
57
 
 
58
  def rag_pipeline(query):
59
  # Embed query
60
  query_embedding = embedder.encode([query], convert_to_tensor=True, device='cpu').numpy()
@@ -69,11 +60,7 @@ def rag_pipeline(query):
69
  print("-", repr(doc))
70
 
71
  # # Rerank
72
- # rerank_pairs = [[str(query), str(doc)] for doc in retrieved_docs if isinstance(doc, str) and doc.strip()]
73
- # if not rerank_pairs:
74
- # return "No valid documents found to rerank."
75
- # scores = reranker.predict(rerank_pairs)
76
-
77
  # scores = reranker.predict(rerank_pairs)
78
  # reranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
79
 
 
8
  import os
9
  import spaces
10
 
 
 
 
 
 
 
 
 
 
 
11
  # Login automatically if HF_TOKEN is present
12
  hf_token = os.getenv("HF_TOKEN")
13
  if hf_token:
 
42
 
43
  # Generator (choose one: local HF model or OpenAI)
44
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
45
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.float16)
46
  generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150)
47
 
48
+ @spaces.GPU
49
  def rag_pipeline(query):
50
  # Embed query
51
  query_embedding = embedder.encode([query], convert_to_tensor=True, device='cpu').numpy()
 
60
  print("-", repr(doc))
61
 
62
  # # Rerank
63
+ # rerank_pairs = [[str(query), str(doc)] for doc in retrieved_docs]
 
 
 
 
64
  # scores = reranker.predict(rerank_pairs)
65
  # reranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
66