# Classes/Owiki_Class.py — OWiki helper class for the Oracle Wikipage app.
import os
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import PromptTemplate
import json
import re
from Classes.Helper_Class import DB_Retriever
from typing import Optional
class OWiki:
    """Wrapper around a Google Gemini chat model (via LangChain).

    Provides helpers for conversation summarisation, SQL-query generation
    from table schemas, and retrieval-augmented Q&A over a FAISS store
    (through the project's ``DB_Retriever``).
    """

    def __init__(self, **kwargs):
        """Configure the Gemini chat model.

        Required kwargs: ``temperature``, ``summary_length``, ``model``,
        ``db_loc``, ``api_key``, ``model_embeddings``.
        A missing key raises ``KeyError``.
        """
        temperature = kwargs['temperature']
        # Maximum number of lines requested for generated summaries.
        self.summary = kwargs['summary_length']
        model = kwargs["model"]
        self.db_loc = kwargs["db_loc"]
        self.api_key = kwargs["api_key"]
        # Both langchain-google-genai and google.generativeai read the key;
        # set it through both paths so either client works.
        os.environ["GOOGLE_API_KEY"] = self.api_key
        genai.configure(api_key=self.api_key)
        self.llm = ChatGoogleGenerativeAI(model=model,
                                          temperature=temperature)
        self.model_embedding = kwargs['model_embeddings']

    def get_summary_template(self):
        """Return a PromptTemplate with ``summary`` (line budget) and ``text`` variables."""
        prompt = """Generate a summary for the following conversational data in less than {summary} lines.\nText:\n{text}\n\nSummary:"""
        prompt_template = PromptTemplate(template=prompt, input_variables=['summary', 'text'])
        return prompt_template

    def create_sql_prompt_template(self, schemas):
        """Build a PromptTemplate asking the model to write SQL for *schemas*.

        ``schemas`` maps table name -> {column_name: column_type}. Literal
        braces inside the schema text are escaped so PromptTemplate does not
        mistake them for input variables; ``question`` remains the only
        template variable.
        """
        prompt = """You are an expert in writing SQL commands as best as you can.
Here are the rules you must follow to complete your task.
1. If you need more context about question don't hesitate to ask but don't keep on asking all the time.
2. Please make sure you answer correctly so that you will get a very high reward.
For the below schemas write an SQL query. \nSQL Schema:"""
        # Join columns with ", " so types don't run together in the prompt.
        schema_text = ""
        for table_name, table_schema in schemas.items():
            columns = ", ".join(f"{col} {col_type}" for col, col_type in table_schema.items())
            schema_text += f"\nTable Name: {table_name} Schema: {columns}"
        # Escape braces coming from user-supplied schema content so the
        # template engine only sees {question} as a variable.
        prompt += schema_text.replace("{", "{{").replace("}", "}}")
        prompt += """\n\nQuestion:{question}\n\nAnswer:"""
        prompt_template = PromptTemplate(template=prompt, input_variables=['question'])
        return prompt_template

    def create_prompt_for_OIC_bot(self):
        """Return the RAG prompt for the OIC bot; variables: ``context`` and ``question``."""
        template = """You are expert OIC(Oracle Integration Cloud) Bot. You will be given a task to solve as best you can.
Here are the rules you should always follow to solve your task:
1. Always provide a Question Explanation with **Question Explanation:** Heading and Potential Solution with **Potential Solution:** Headings.
2. Take care of not being too long typically not exceeding 5000 tokens.
3. Response must contain all possible **Error Scenarios:** if applicable along with a **Summary:** Heading containing brief summary of the task solution at the end.
4. If you don't know the answer or if it is not in the context, please answer as **I am not trained on this topics due to limited resources.**
Now Begin!
Context:
{context}
Question: {question}
"""
        prompt = PromptTemplate.from_template(template)
        return prompt

    def create_sql_agent(self, question, schemas):
        """Generate a SQL answer for *question* over *schemas*; returns HTML-formatted text."""
        prompt_template = self.create_sql_prompt_template(schemas)
        chain = prompt_template | self.llm | StrOutputParser()
        response = chain.invoke({"question": question})
        return self.format_llm_response(response)

    def generate_summary(self, text):
        """Summarise *text* in at most ``self.summary`` lines; returns the raw LLM string."""
        prompt_template = self.get_summary_template()
        chain = prompt_template | self.llm | StrOutputParser()
        return chain.invoke({"text": text, "summary": self.summary})

    def format_llm_response(self, text):
        """Convert the LLM's Markdown-ish output to simple HTML.

        Newlines become ``<br>`` first, then ```` ``` ```` fences become
        ``<pre><code>``, ``**bold**`` becomes ``<b>``, and remaining
        ``*italic*`` becomes ``<i>``. Order matters: bold must run before
        italic so ``**`` is not consumed as two italics.
        """
        bold_pattern = r"\*\*(.*?)\*\*"
        italic_pattern = r"\*(.*?)\*"
        code_pattern = r"```(.*?)```"
        text = text.replace('\n', '<br>')
        formatted_text = re.sub(code_pattern, "<pre><code>\\1</code></pre>", text)
        formatted_text = re.sub(bold_pattern, "<b>\\1</b>", formatted_text)
        formatted_text = re.sub(italic_pattern, "<i>\\1</i>", formatted_text)
        return formatted_text

    def search_from_db(self, query: str, chat_history: Optional[str] = None) -> str:
        """Answer *query* via retrieval-augmented generation over the FAISS store.

        Returns the HTML-formatted model response.
        """
        db = DB_Retriever(self.db_loc, self.model_embedding)
        retriever = db.retrieve(query)
        prompt = self.create_prompt_for_OIC_bot()
        # NOTE(review): the summary result is never fed into the prompt (the
        # template only takes {context} and {question}); the call is kept for
        # parity but skipped when no history is supplied, avoiding a wasted
        # LLM round-trip on None.
        if chat_history:
            chat_history = self.generate_summary(chat_history)
        retrieval_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )
        response = retrieval_chain.invoke(query)
        return self.format_llm_response(response)