Shanat committed on
Commit e807162
1 Parent(s): bc15fbb

Update app.py

Files changed (1):
  1. app.py +51 -1
app.py CHANGED
@@ -13,7 +13,35 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
 
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
+from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.postprocessor import SimilarityPostprocessor
+
+Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
+Settings.llm = None
+Settings.chunk_size = 256
+Settings.chunk_overlap = 25
+documents = SimpleDirectoryReader("/test").load_data()
+index = VectorStoreIndex.from_documents(documents)
+
+top_k = 6
+
+# configure retriever
+retriever = VectorIndexRetriever(
+    index=index,
+    similarity_top_k=top_k,
+)
+
+query_engine = RetrieverQueryEngine(
+    retriever=retriever,
+    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.5)],
+)
+
 chatbot = pipeline(model="microsoft/Phi-3.5-mini-instruct")
+
+
 #token = os.getenv("HF_TOKEN")
 #login(token = os.getenv('HF_TOKEN'))
 #chatbot = pipeline(model="meta-llama/Llama-3.2-1B")
@@ -27,16 +55,38 @@ chatbot = pipeline(model="microsoft/Phi-3.5-mini-instruct")
 
 #chatbot = pipeline(model="facebook/blenderbot-400M-distill")
 
+prompt_template_w_context = lambda context, comment: f"""{context}
+Please respond to the following comment. Use the context above if it is helpful.
+{comment}
+[/INST]
+"""
+
+
 message_list = []
 response_list = []
 
 
 def vanilla_chatbot(message, history):
+    response = query_engine.query(message)
+    # reformat response
+    context = "Context:\n"
+    for i in range(len(response.source_nodes)):
+        context = context + response.source_nodes[i].text + "\n\n"
+    #print(context)
+    prompt = prompt_template_w_context(context, message)
+    #inputs = tokenizer(prompt, return_tensors="pt")
+    #outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=280)
+    #print(tokenizer.batch_decode(outputs)[0])
+    #conversation = pipe(message, temperature=0.1)
+    #ot=tokenizer.batch_decode(outputs)[0]
+    #context_length=len(prompt)
+    #new_sentence = ot[context_length+3:]
+    #return new_sentence
     #inputs = tokenizer(message, return_tensors="pt").to("cpu")
     #with torch.no_grad():
     #    outputs = model.generate(inputs.input_ids, max_length=100)
     #return tokenizer.decode(outputs[0], skip_special_tokens=True)
-    conversation = chatbot(message)
+    conversation = chatbot(prompt)
 
     return conversation[0]['generated_text']
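
For context, the change fronts the existing microsoft/Phi-3.5-mini-instruct pipeline with a llama_index retrieval step over the files in /test. Below is a minimal sketch of how the new pieces would be exercised end to end, assuming it runs in the same module as the code above; the sample query string and the gr.ChatInterface wiring are assumptions on my part, since the diff does not show how vanilla_chatbot is actually served.

    import gradio as gr

    # Smoke-test the retrieval added in this commit: run one query against the
    # vector index and inspect the nodes that survive the 0.5 similarity cutoff.
    # The query string is illustrative only.
    response = query_engine.query("What do the documents in /test cover?")
    for node in response.source_nodes:
        print(node.score, node.text[:80])

    # vanilla_chatbot has the (message, history) signature that gr.ChatInterface
    # expects; this wiring is an assumption, not code from this commit.
    demo = gr.ChatInterface(fn=vanilla_chatbot)
    demo.launch()

Because Settings.llm is set to None, the query engine is used here only for retrieval: vanilla_chatbot reads response.source_nodes and leaves generation to the Phi-3.5 pipeline via the assembled prompt. The new imports also presumably require llama-index and llama-index-embeddings-huggingface to be present in the Space's requirements.txt.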