Yoxas committed on
Commit
d1e3096
·
verified ·
1 Parent(s): 4761f6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -18,9 +18,9 @@ embeddings = torch.tensor(df['embedding'].tolist(), device=device)
18
  # Load the Sentence Transformer model
19
  model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
20
 
21
- # Load the LLaMA model for response generation
22
- llama_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
23
- llama_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(device)
24
 
25
  # Define the function to find the most relevant document
26
  @spaces.GPU(duration=120)
@@ -35,9 +35,9 @@ def retrieve_relevant_doc(query):
35
  def generate_response(query):
36
  relevant_doc = retrieve_relevant_doc(query)
37
  input_text = f"Document: {relevant_doc}\n\nQuestion: {query}\n\nAnswer:"
38
- inputs = llama_tokenizer(input_text, return_tensors="pt").to(device)
39
- outputs = llama_model.generate(inputs["input_ids"], max_length=1024)
40
- response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
41
  return response
42
 
43
  # Create a Gradio interface
 
18
  # Load the Sentence Transformer model
19
  model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
20
 
21
+ # Load the ai model for response generation
22
+ ai_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
23
+ ai_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(device)
24
 
25
  # Define the function to find the most relevant document
26
  @spaces.GPU(duration=120)
 
35
  def generate_response(query):
36
  relevant_doc = retrieve_relevant_doc(query)
37
  input_text = f"Document: {relevant_doc}\n\nQuestion: {query}\n\nAnswer:"
38
+ inputs = ai_tokenizer(input_text, return_tensors="pt").to(device)
39
+ outputs = ai_model.generate(inputs["input_ids"], max_length=1024)
40
+ response = ai_tokenizer.decode(outputs[0], skip_special_tokens=True)
41
  return response
42
 
43
  # Create a Gradio interface