Spaces:
Sleeping
Sleeping
File size: 1,738 Bytes
0365501 c157cd5 0365501 c6e70f1 0365501 c6e70f1 0365501 c6e70f1 0365501 c6e70f1 0365501 c6e70f1 0365501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
from typing import Any, Optional, Tuple
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceHub
from langchain.llms import OpenAI
from threading import Lock
def load_chain_openai(api_key: str, *, temperature: float = 0):
    """Build a ``ConversationChain`` backed by LangChain's ``OpenAI`` LLM.

    Args:
        api_key: OpenAI API key; must be non-empty.
        temperature: Sampling temperature forwarded to the LLM (default 0).

    Returns:
        A new ``ConversationChain`` wrapping the configured LLM.

    Raises:
        ValueError: if ``api_key`` is empty.
    """
    # Without this guard an empty key would make LangChain fall back to
    # whatever OPENAI_API_KEY happens to be set in the process environment.
    if not api_key:
        raise ValueError("An OpenAI API key is required.")
    # Hand the key to the LLM directly instead of via os.environ: the
    # environment is process-wide (shared by every thread), and temporarily
    # setting then blanking it would clobber any pre-existing value and leak
    # the key if construction raised.
    llm = OpenAI(temperature=temperature, openai_api_key=api_key)
    return ConversationChain(llm=llm)
def load_chain_falcon(
    api_key: str,
    *,
    repo_id: str = "tiiuae/falcon-7b-instruct",
    temperature: float = 0.9,
):
    """Build a ``ConversationChain`` backed by a Hugging Face Hub model.

    Args:
        api_key: Hugging Face Hub API token; must be non-empty.
        repo_id: Hub repository of the model to query
            (default ``"tiiuae/falcon-7b-instruct"``).
        temperature: Sampling temperature sent in ``model_kwargs``
            (default 0.9).

    Returns:
        A new ``ConversationChain`` wrapping the configured LLM.

    Raises:
        ValueError: if ``api_key`` is empty.
    """
    # Without this guard an empty token would make LangChain fall back to
    # whatever HUGGINGFACEHUB_API_TOKEN is set in the process environment.
    if not api_key:
        raise ValueError("A Hugging Face Hub API token is required.")
    # Pass the token explicitly rather than through os.environ, which is
    # process-wide and would otherwise be clobbered / leaked on error.
    llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={"temperature": temperature},
        huggingfacehub_api_token=api_key,
    )
    return ConversationChain(llm=llm)
class ChatWrapper:
    """Callable that records a chat history and forwards each turn to a chain.

    The underlying chain is built lazily from the API key supplied to
    ``__call__``; ``chain_type`` selects the loader (``'openai'`` or
    ``'falcon'``). Calls on one instance are serialised by ``self.lock``.
    """

    def __init__(self, chain_type: str):
        """Create a wrapper that will build chains of kind ``chain_type``."""
        self.chain_type = chain_type
        # (user input, reply) pairs; the same list is returned on every call.
        self.history = []
        # Serialises concurrent calls on this instance (chain + history).
        self.lock = Lock()
        self.chain = None
        # Key the current chain was built with. Rebuilding only when the key
        # changes lets repeated turns reuse the chain (and any conversation
        # memory it keeps) instead of starting fresh on every message.
        self._api_key: Optional[str] = None

    def __call__(self, inp: str, api_key: str = ''):
        """Send ``inp`` to the chain and return the updated history.

        Args:
            inp: The user's message for this turn.
            api_key: Optional API key. When non-empty and different from the
                key the current chain was built with, a new chain is built.

        Returns:
            ``self.history`` with one ``(inp, reply)`` pair appended. The reply
            is the chain's output, a prompt to add an API key if no chain
            exists yet, or ``"An error occurred: ..."`` if anything raised.
        """
        with self.lock:
            try:
                if api_key and api_key != self._api_key:
                    self.chain = self._load_chain(api_key)
                    # Only remember the key once the chain built successfully,
                    # so a failed load is retried on the next call.
                    self._api_key = api_key
                if self.chain is None:
                    self.history.append((inp, "Please add your API key to proceed."))
                    return self.history
                output = self.chain.run(input=inp)
                self.history.append((inp, output))
            except Exception as e:
                # UI boundary: surface the failure in the chat instead of raising.
                self.history.append((inp, f"An error occurred: {e}"))
            return self.history

    def _load_chain(self, api_key: str):
        """Build a chain for ``self.chain_type`` using ``api_key``."""
        if self.chain_type == 'openai':
            return load_chain_openai(api_key)
        if self.chain_type == 'falcon':
            return load_chain_falcon(api_key)
        raise ValueError(f'Invalid chain_type: {self.chain_type}')