import os

import anthropic
import google.generativeai as genai
import vertexai
from dotenv import load_dotenv
from groq import Groq
from openai import OpenAI
from vertexai.generative_models import GenerativeModel

from langchain import hub
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
load_dotenv()
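# Silence noisy gRPC / absl logging from the Google clients.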
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
groq_client = Groq(
api_key=os.environ.get("GROQ_API_KEY"),
)
openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Vertex AI uses application-default credentials; on a GCP instance, run:
#   gcloud auth application-default login
genai.configure(api_key=os.environ.get("GENAI_API_KEY"))
vertexai.init(project="proprietary-info-detection", location="us-central1")
gemini_client = GenerativeModel("gemini-1.5-pro-001")
claude_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
# LLM params (rule of thumb: for GPT-4, one English word is roughly 1.3 tokens)
temperature = 1.0
max_tokens = 2048
# RAG params
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
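# MMR retrieval: fetch FETCH_K candidate chunks, then keep the K most diverse.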
K = 10
FETCH_K = 20
llm_model_translation = {
"LLaMA 3": "llama3-70b-8192",
"OpenAI GPT 4o Mini": "gpt-4o-mini",
"OpenAI GPT 4o": "gpt-4o",
"OpenAI GPT 4": "gpt-4-turbo",
"Gemini 1.5 Pro": "gemini-1.5-pro",
"Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}
llm_classes = {
"llama3-70b-8192": ChatGroq,
"gpt-4o-mini": ChatOpenAI,
"gpt-4o": ChatOpenAI,
"gpt-4-turbo": ChatOpenAI,
"gemini-1.5-pro": ChatGoogleGenerativeAI,
"claude-3-5-sonnet-20240620": ChatAnthropic,
}
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
model_name = llm_model_translation.get(model)
llm_class = llm_classes.get(model_name)
if not llm_class:
raise ValueError(f"Model {model} not supported.")
    try:
        # All of these chat classes accept `model=`; `model_name=` is not a
        # recognized field on ChatGoogleGenerativeAI. Each client reads its
        # API key from the environment, so `api_key` is unused here.
        # Depending on the installed version, Gemini may expect
        # `max_output_tokens` rather than `max_tokens`; failures are caught below.
        llm = llm_class(model=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred while loading the LLM: {e}")
        llm = None
return llm
def create_db_with_langchain(path: list[str], url_content: dict):
all_docs = []
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
if path:
for file in path:
loader = PyMuPDFLoader(file)
data = loader.load()
# split it into chunks
docs = text_splitter.split_documents(data)
all_docs.extend(docs)
if url_content:
for url, content in url_content.items():
doc = Document(page_content=content, metadata={"source": url})
# split it into chunks
docs = text_splitter.split_documents([doc])
all_docs.extend(docs)
# print docs
for idx, doc in enumerate(all_docs):
print(f"Doc: {idx} | Length = {len(doc.page_content)}")
    assert len(all_docs) > 0, "No PDFs or scraped data provided"
db = Chroma.from_documents(all_docs, embedding_function)
return db
def generate_rag(
prompt: str,
topic: str,
model: str,
url_content: dict,
path: list[str],
temperature: float = 1.0,
max_length: int = 2048,
api_key: str = "",
sys_message="",
):
llm = load_llm(model, api_key, temperature, max_length)
if llm is None:
print("Failed to load LLM. Aborting operation.")
return None
db = create_db_with_langchain(path, url_content)
retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
rag_prompt = hub.pull("rlm/rag-prompt")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
    # Retrieve once against the topic and reuse the formatted context in the chain.
    docs = retriever.invoke(topic)
    formatted_docs = format_docs(docs)
rag_chain = (
{"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
)
return rag_chain.invoke(prompt)
def generate_base(
prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
llm = load_llm(model, api_key, temperature, max_length)
if llm is None:
print("Failed to load LLM. Aborting operation.")
return None
try:
output = llm.invoke(prompt).content
return output
except Exception as e:
print(f"An error occurred while running the model: {e}")
return None
def generate(
prompt: str,
topic: str,
model: str,
url_content: dict,
path: list[str],
temperature: float = 1.0,
max_length: int = 2048,
api_key: str = "",
sys_message="",
):
if path or url_content:
return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
else:
return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
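
# Minimal usage sketch (an assumption, not part of the Space's UI wiring):
# exercises the base path with no PDFs or scraped URLs, assuming the provider
# API keys loaded above are present in the environment. The prompt and topic
# strings are hypothetical placeholders.
if __name__ == "__main__":
    answer = generate(
        prompt="Summarize the benefits of retrieval-augmented generation.",
        topic="retrieval-augmented generation",
        model="OpenAI GPT 4o Mini",
        url_content={},
        path=[],
        temperature=0.7,
        max_length=512,
    )
    print(answer)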