# article_writer / ai_generate.py
# minko186's picture
# merge main + multi pdfs + updated html cleaning + better references
# 43d4e83
# raw
# history blame
# 6.73 kB
import torch
from openai import OpenAI
import os
from transformers import pipeline
from groq import Groq
import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, FinishReason
import vertexai.preview.generative_models as generative_models
import google.generativeai as genai
import anthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import RetrievalQA
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_anthropic import ChatAnthropic
from dotenv import load_dotenv
# Load variables from a local .env file into os.environ before any of the
# clients below read their API keys from it.
load_dotenv()
# Quiet gRPC / glog console logging from the Google client libraries.
# NOTE(review): these are set after the vertexai / google.generativeai imports
# above; confirm the gRPC runtime still honours them when set this late.
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
# Provider clients, created once at import time and shared by the
# generate_* helpers in this module.
groq_client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Vertex AI authenticates with Application Default Credentials: on a GCP
# instance give the instance access to all APIs; locally run
#   gcloud auth application-default login
genai.configure(api_key=os.environ.get("GENAI_API_KEY"))
# Project and region are hard-coded; vertexai.init must run before the
# GenerativeModel below is created.
vertexai.init(project="proprietary-info-detection", location="us-central1")
# Gemini model passed to generate_gemini; its version is pinned here.
gemini_client = GenerativeModel("gemini-1.5-pro-001")
claude_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
# Sampling settings shared by every backend in this module.
# For GPT-4, 1 word is about 1.3 tokens, so max_tokens = 2048 allows for
# roughly 1,575 words of output.
temperature = 1.0
max_tokens = 2048

# LangChain chat models used by generate_rag, keyed by the same display names
# that generate() dispatches on. Every model gets the shared sampling settings.
_shared_llm_kwargs = dict(temperature=temperature, max_tokens=max_tokens)
rag_llms = {
    "LLaMA 3": ChatGroq(model_name="llama3-70b-8192", **_shared_llm_kwargs),
    "OpenAI GPT 4o Mini": ChatOpenAI(model_name="gpt-4o-mini", **_shared_llm_kwargs),
    "OpenAI GPT 4o": ChatOpenAI(model_name="gpt-4o", **_shared_llm_kwargs),
    "OpenAI GPT 4": ChatOpenAI(model_name="gpt-4-turbo", **_shared_llm_kwargs),
    # ChatGoogleGenerativeAI takes the model id as `model`, not `model_name`.
    "Gemini 1.5 Pro": ChatGoogleGenerativeAI(model="gemini-1.5-pro", **_shared_llm_kwargs),
    "Claude Sonnet 3.5": ChatAnthropic(model_name="claude-3-5-sonnet-20240620", **_shared_llm_kwargs),
}
# Helper only; keep the module namespace the same as before.
del _shared_llm_kwargs
def create_db_with_langchain(path):
    """Build a Chroma vector store from one or more PDF files.

    Each PDF is loaded with ``PyMuPDFLoader``, split into chunks of up to
    1000 characters with no overlap, embedded with the ``all-MiniLM-L6-v2``
    sentence-transformer model, and indexed in a new Chroma collection.

    Args:
        path: An iterable of PDF file paths. A single path (``str`` or
            ``os.PathLike``) is also accepted and treated as a one-item list.

    Returns:
        The populated ``Chroma`` vector store.
    """
    import uuid

    # A bare string is itself iterable; without this check the loop below
    # would try to open each character of the path as its own file.
    if isinstance(path, (str, os.PathLike)):
        path = [path]

    # The splitter holds only configuration, so one instance serves every file.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    all_docs = []
    for file in path:
        data = PyMuPDFLoader(file).load()
        all_docs.extend(text_splitter.split_documents(data))

    # Open-source embedding model; runs locally.
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    # Without an explicit client, Chroma uses an in-process client and a fixed
    # default collection name, so stores built by earlier calls could share
    # (and leak chunks into) this one. A unique name isolates each call.
    return Chroma.from_documents(
        all_docs,
        embedding_function,
        collection_name=f"rag-{uuid.uuid4().hex}",
    )
def generate_rag(text, model, path):
    """Answer ``text`` with retrieval-augmented generation over PDF file(s).

    A vector store is built from ``path``. An MMR retriever returns 4 chunks,
    chosen from 20 fetched candidates. The chunks are formatted into the
    ``rlm/rag-prompt`` hub prompt together with ``text`` as the question, and
    the prompt is sent to the LangChain chat model for ``model``.

    Args:
        text: The question/instruction for the model.
        model: Display name; must be a key of ``rag_llms``.
        path: PDF file path(s), passed to ``create_db_with_langchain``.

    Returns:
        The ``content`` of the model's reply message.

    Raises:
        KeyError: If ``model`` is not a key of ``rag_llms``.
    """
    print(f"Generating text using RAG for {model}...")
    llm = rag_llms[model]
    db = create_db_with_langchain(path)
    try:
        retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 4, "fetch_k": 20})
        # NOTE(review): fetched from the LangChain hub on every call.
        prompt = hub.pull("rlm/rag-prompt")

        def format_docs(docs):
            # Join retrieved chunks into one context string, blank-line separated.
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | llm
        return rag_chain.invoke(text).content
    finally:
        # The store is built per call from the given files. Drop its collection
        # so these chunks cannot be retrieved by a later request that may share
        # the same in-process Chroma client.
        db.delete_collection()
def generate_groq(text, model):
    """Generate a completion for ``text`` with a Groq-hosted chat model.

    The completion is requested as a stream and the text deltas of all
    chunks are concatenated into one string.

    Args:
        text: The user prompt.
        model: Groq model identifier, e.g. ``"llama3-70b-8192"``.

    Returns:
        The generated text; an empty string if no chunk carried content.
    """
    completion = groq_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": text},
            # NOTE(review): this instruction is sent in the *assistant* role,
            # after the user turn. A leading "system" message may be what was
            # intended; confirm before changing, since it alters the prompt.
            {
                "role": "assistant",
                "content": "Please follow the instruction and write about the given topic in approximately the given number of words",
            },
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
        stop=None,
    )
    # Keep every chunk, including the first: it usually carries only the role
    # with empty content, but dropping it unconditionally would lose text
    # whenever it does carry some. Chunks without choices, or whose delta
    # content is None, contribute nothing.
    parts = []
    for chunk in completion:
        if chunk.choices:
            parts.append(chunk.choices[0].delta.content or "")
    return "".join(parts)
def generate_openai(text, model, openai_client):
    """Send ``text`` as a single user turn to an OpenAI chat model.

    Args:
        text: The user prompt.
        model: OpenAI model identifier, e.g. ``"gpt-4o"``.
        openai_client: Client exposing ``chat.completions.create``.

    Returns:
        The message content of the first returned choice.
    """
    completion = openai_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": text}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def generate_gemini(text, model, gemini_client):
    """Generate text for ``text`` with a Vertex AI Gemini model.

    Hate speech, dangerous content, sexually explicit content and harassment
    are all blocked at the BLOCK_MEDIUM_AND_ABOVE threshold.

    Args:
        text: The prompt.
        model: Unused; the model is whichever one ``gemini_client`` was
            created with.
        gemini_client: Object exposing ``generate_content`` (the module-level
            ``GenerativeModel`` instance in this file).

    Returns:
        ``response.text`` of the non-streamed response.
    """
    harm = generative_models.HarmCategory
    block_medium = generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    safety_settings = {
        category: block_medium
        for category in (
            harm.HARM_CATEGORY_HATE_SPEECH,
            harm.HARM_CATEGORY_DANGEROUS_CONTENT,
            harm.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            harm.HARM_CATEGORY_HARASSMENT,
        )
    }
    response = gemini_client.generate_content(
        [text],
        generation_config={"max_output_tokens": max_tokens, "temperature": temperature},
        safety_settings=safety_settings,
        stream=False,
    )
    # NOTE(review): the SDK may raise here if the reply was blocked and has no
    # text; callers currently see that exception unchanged.
    return response.text
def generate_claude(text, model, claude_client):
    """Send ``text`` to an Anthropic Claude model and return its reply.

    Args:
        text: The user prompt.
        model: Anthropic model identifier, e.g.
            ``"claude-3-5-sonnet-20240620"``.
        claude_client: Client exposing ``messages.create``.

    Returns:
        Text of the first content block, with surrounding whitespace stripped.
    """
    user_turn = {"role": "user", "content": [{"type": "text", "text": text}]}
    reply = claude_client.messages.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        system="You are helpful assistant.",
        messages=[user_turn],
    )
    first_block = reply.content[0]
    return first_block.text.strip()
def generate(text, model, path, api=None):
    """Generate text for ``text`` with the model chosen by its display name.

    If ``path`` is truthy, generation goes through retrieval-augmented
    generation over those files (``generate_rag``). Otherwise the matching
    provider is called directly.

    Args:
        text: The prompt.
        model: Display name, e.g. ``"LLaMA 3"``, ``"OpenAI GPT 4o"``,
            ``"Gemini 1.5 Pro"`` or ``"Claude Sonnet 3.5"``.
        path: PDF file path(s) for RAG; falsy to call the model directly.
        api: Unused; kept so existing callers keep working.

    Returns:
        The generated text.

    Raises:
        ValueError: If ``model`` is not a supported display name on the
            direct path. Previously this silently returned ``None``.
        KeyError: If ``model`` is unknown on the RAG path (from ``rag_llms``).
    """
    if path:
        return generate_rag(text, model, path)

    print(f"Generating text for {model}...")
    # Display name -> zero-argument call into the matching provider wrapper.
    # The lambdas defer every client lookup until a model is actually chosen.
    backends = {
        "LLaMA 3": lambda: generate_groq(text, "llama3-70b-8192"),
        "OpenAI GPT 4o Mini": lambda: generate_openai(text, "gpt-4o-mini", openai_client),
        "OpenAI GPT 4o": lambda: generate_openai(text, "gpt-4o", openai_client),
        "OpenAI GPT 4": lambda: generate_openai(text, "gpt-4-turbo", openai_client),
        # generate_gemini ignores its model argument; gemini_client fixes the model.
        "Gemini 1.5 Pro": lambda: generate_gemini(text, "", gemini_client),
        "Claude Sonnet 3.5": lambda: generate_claude(text, "claude-3-5-sonnet-20240620", claude_client),
    }
    try:
        backend = backends[model]
    except KeyError:
        raise ValueError(f"Unknown model: {model!r}") from None
    return backend()