Samarth991 commited on
Commit
1cba1c6
·
verified ·
1 Parent(s): f579cb9

Update QnA.py

Browse files
Files changed (1) hide show
  1. QnA.py +15 -4
QnA.py CHANGED
@@ -54,17 +54,28 @@ def summarize(documents,llm):
54
  return results['output_text']
55
 
56
 
 
 
 
 
 
 
 
 
57
  def Q_A(vectorstore,question,API_KEY):
58
- os.environ["GROQ_API_KEY"] = API_KEY
59
- llm_groq = ChatGroq(model="llama3-8b-8192")
 
 
 
60
 
61
  # Create a retriever
62
  retriever = vectorstore.as_retriever(search_type = 'similarity',search_kwargs = {'k':2},)
63
  if 'reliable' in question.lower() or 'relaibility' in question.lower():
64
- question_answer_chain = create_stuff_documents_chain(llm_groq, prompt_template_for_relaibility())
65
 
66
  else:
67
- question_answer_chain = create_stuff_documents_chain(llm_groq, prompt_template_to_analyze_resume())
68
 
69
  chain = create_retrieval_chain(retriever, question_answer_chain)
70
  result = chain.invoke({'input':question})
 
54
  return results['output_text']
55
 
56
 
57
+ def get_hugging_face_model(model_id='mistralai/Mistral-7B-Instruct-v0.3',temperature=0.01,max_tokens=2048,api_key=None):
58
+ llm = HuggingFaceHub(
59
+ huggingfacehub_api_token =api_key ,
60
+ repo_id=model_id,
61
+ model_kwargs={"temperature":temperature, "max_new_tokens":max_tokens}
62
+ )
63
+ return llm
64
+
65
  def Q_A(vectorstore,question,API_KEY):
66
+ if API_KEY.startswith('gsk'):
67
+ os.environ["GROQ_API_KEY"] = API_KEY
68
+ chat_llm = ChatGroq(model="llama3-8b-8192")
69
+ elif API.startswith('hk'):
70
+ chat_llm = get_hugging_face_model(API_key=API_KEY)
71
 
72
  # Create a retriever
73
  retriever = vectorstore.as_retriever(search_type = 'similarity',search_kwargs = {'k':2},)
74
  if 'reliable' in question.lower() or 'relaibility' in question.lower():
75
+ question_answer_chain = create_stuff_documents_chain(chat_llm, prompt_template_for_relaibility())
76
 
77
  else:
78
+ question_answer_chain = create_stuff_documents_chain(chat_llm, prompt_template_to_analyze_resume())
79
 
80
  chain = create_retrieval_chain(retriever, question_answer_chain)
81
  result = chain.invoke({'input':question})