# Rohan Kataria
# trying to delete
# c6e70f1
# raw
# history blame
# 1.74 kB
import os
from typing import Any, Optional, Tuple
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceHub
from langchain.llms import OpenAI
from threading import Lock
def load_chain_openai(api_key: str):
    """Build a ConversationChain backed by the OpenAI LLM.

    The key is exposed through the ``OPENAI_API_KEY`` environment variable
    only while the LLM and chain are being constructed, then blanked again.

    Args:
        api_key: OpenAI API key used to construct the LLM.

    Returns:
        A ``ConversationChain`` wrapping ``OpenAI(temperature=0)``.

    Raises:
        Whatever ``OpenAI`` or ``ConversationChain`` raise during
        construction; the environment variable is blanked in that case too.
    """
    # Presumably the OpenAI wrapper picks the key up from the environment at
    # construction time -- verify against the installed langchain version.
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        llm = OpenAI(temperature=0)
        return ConversationChain(llm=llm)
    finally:
        # Always blank the key, even when construction fails, so the secret
        # never lingers in the process environment.
        # NOTE(review): os.environ is process-wide; concurrent loaders with
        # different keys could race here -- confirm callers serialize this.
        os.environ["OPENAI_API_KEY"] = ""
def load_chain_falcon(api_key: str):
    """Build a ConversationChain backed by Falcon-7B-Instruct on the HF Hub.

    The token is exposed through the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable only while the LLM and chain are being constructed, then blanked
    again.

    Args:
        api_key: Hugging Face Hub API token used to construct the LLM.

    Returns:
        A ``ConversationChain`` wrapping ``HuggingFaceHub`` configured for
        ``tiiuae/falcon-7b-instruct`` with ``temperature=0.9``.

    Raises:
        Whatever ``HuggingFaceHub`` or ``ConversationChain`` raise during
        construction; the environment variable is blanked in that case too.
    """
    # Presumably HuggingFaceHub picks the token up from the environment at
    # construction time -- verify against the installed langchain version.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
    try:
        llm = HuggingFaceHub(
            repo_id="tiiuae/falcon-7b-instruct",
            model_kwargs={"temperature": 0.9},
        )
        return ConversationChain(llm=llm)
    finally:
        # Always blank the token, even when construction fails, so the secret
        # never lingers in the process environment.
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = ""
class ChatWrapper:
    """Callable chat front-end around a lazily built conversation chain.

    Each call appends an ``(input, reply)`` tuple to ``history`` and returns
    the whole history list. Calls are serialized with a per-instance lock.
    Errors are reported as a reply in the history instead of being raised.
    """

    def __init__(self, chain_type: str):
        """Create a wrapper for the given backend.

        Args:
            chain_type: ``'openai'`` or ``'falcon'``. Any other value is only
                rejected later, when a chain is first built in ``__call__``.
        """
        self.chain_type = chain_type
        self.history = []
        self.lock = Lock()
        self.chain = None
        # Key the current chain was built with; lets __call__ skip rebuilding
        # the chain when the same key is passed again on later turns.
        self._chain_api_key: Optional[str] = None

    def _load_chain(self, api_key: str):
        """Build a chain for ``self.chain_type``; ValueError if unknown."""
        if self.chain_type == 'openai':
            return load_chain_openai(api_key)
        if self.chain_type == 'falcon':
            return load_chain_falcon(api_key)
        raise ValueError(f'Invalid chain_type: {self.chain_type}')

    def __call__(self, inp: str, api_key: str = ''):
        """Run one chat turn and return the accumulated history.

        Args:
            inp: The user's message.
            api_key: Backend API key. When non-empty and different from the
                key the current chain was built with, a new chain is built.
                When empty, the existing chain (if any) is reused.

        Returns:
            ``self.history``: the list of ``(input, reply)`` tuples, including
            this turn. The reply is the chain output, a prompt to supply an
            API key when no chain exists yet, or an error message.
        """
        with self.lock:
            try:
                # Rebuild only for a new key. Rebuilding on every call would
                # replace the chain object (and whatever state it carries)
                # each turn.
                if api_key and api_key != self._chain_api_key:
                    self.chain = self._load_chain(api_key)
                    # Recorded only after a successful build so that a failed
                    # build is retried on the next call.
                    self._chain_api_key = api_key
                if self.chain is None:
                    reply = "Please add your API key to proceed."
                else:
                    reply = self.chain.run(input=inp)
            except Exception as e:
                # Surface failures in the chat transcript rather than raising.
                reply = f"An error occurred: {e}"
            self.history.append((inp, reply))
        return self.history