Chris4K committed on
Commit
bc53764
·
1 Parent(s): a7a5d1a
model/conversation_chain_singleton.py CHANGED
@@ -1,35 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain.memory import ConversationBufferMemory
2
  from langchain.chains import ConversationChain
3
  from langchain.llms import HuggingFaceHub
4
 
5
  class ConversationChainSingleton:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  _instance = None
7
 
8
  def __new__(cls, *args, **kwargs):
 
 
 
 
 
 
9
  if not cls._instance:
10
  cls._instance = super(ConversationChainSingleton, cls).__new__(cls)
11
  # Initialize your conversation chain here
12
- cls._instance.conversation_chain = get_conversation_chain()
13
  return cls._instance
14
 
15
  def get_conversation_chain(self):
16
- return self.conversation_chain
17
-
18
-
19
- def get_conversation_chain( ):
20
  """
21
- Create a conversational retrieval chain and a language model.
22
 
 
 
23
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- llm = HuggingFaceHub(
26
- repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
27
- model_kwargs={"max_length": 1048, "temperature":0.2, "max_new_tokens":256, "top_p":0.95, "repetition_penalty":1.0},
28
- )
29
- # llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
30
-
31
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
32
- conversation_chain = ConversationChain(
33
- llm=llm, verbose=True, memory=memory
34
- )
35
- return conversation_chain
 
1
+ """
2
+ Module: conversation_chain_singleton
3
+
4
+ This module provides a singleton class, ConversationChainSingleton, for managing a conversation chain instance.
5
+
6
+ Dependencies:
7
+ - langchain.memory: Module providing memory functionalities for conversation chains.
8
+ - langchain.chains: Module providing conversation chain functionalities.
9
+ - langchain.llms: Module providing language model functionalities, particularly from HuggingFaceHub.
10
+
11
+ Classes:
12
+ - ConversationChainSingleton: A singleton class for managing a conversation chain instance.
13
+ """
14
+
15
  from langchain.memory import ConversationBufferMemory
16
  from langchain.chains import ConversationChain
17
  from langchain.llms import HuggingFaceHub
18
 
class ConversationChainSingleton:
    """
    Process-wide holder for one shared conversation chain.

    Every call to ``ConversationChainSingleton()`` returns the same object.
    The chain is built once, on first instantiation, and kept on that object
    so that all callers talk to the same chain (and therefore the same memory).

    Attributes:
    - _instance: The singleton instance, or None until first instantiation.
    - conversation_chain: The chain built by ``_create_conversation_chain``.

    Methods:
    - __new__(cls, *args, **kwargs): Returns the singleton, creating it on first use.
    - get_conversation_chain(): Returns the shared conversation chain.
    - _create_conversation_chain(): Builds a fresh chain (internal factory).
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """
        Return the singleton instance, creating it on first use.

        Extra positional/keyword arguments are accepted and ignored.

        Returns:
        - ConversationChainSingleton: The singleton instance.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            # Build the chain before publishing the instance: if construction
            # raises, no half-initialised singleton (one without a
            # `conversation_chain` attribute) is left behind for later callers.
            instance.conversation_chain = cls._create_conversation_chain()
            cls._instance = instance
        return cls._instance

    @classmethod
    def get_conversation_chain(cls):
        """
        Get the shared conversation chain.

        Callable on the class or on an instance; both return the chain stored
        on the singleton. The factory that builds a chain lives under a
        different name (``_create_conversation_chain``) so it cannot shadow
        this accessor.

        Returns:
        - ConversationChain: The conversation chain held by the singleton.
        """
        return cls().conversation_chain

    @staticmethod
    def _create_conversation_chain():
        """
        Build a new conversation chain backed by a HuggingFaceHub model.

        Returns:
        - ConversationChain: A newly constructed chain with buffer memory.
        """
        # Alternative backend that was tried previously:
        # ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
        llm = HuggingFaceHub(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            model_kwargs={
                "max_length": 1048,
                "temperature": 0.2,
                "max_new_tokens": 256,
                "top_p": 0.95,
                "repetition_penalty": 1.0,
            },
        )

        # NOTE(review): ConversationChain's default prompt is believed to use
        # a `history` variable; confirm that memory_key="chat_history" passes
        # the chain's prompt/memory validation in the installed version.
        memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        return ConversationChain(llm=llm, verbose=True, memory=memory)
 
 
 
 
 
 
model/custom_agent.py CHANGED
@@ -1,17 +1,41 @@
1
- # custom_agent.py
2
- import os
3
- import base64
4
- import io
5
- import requests
 
 
 
 
 
 
 
 
 
 
6
  import time
 
7
  from transformers import Agent
8
  from utils.logger import log_response
9
 
10
- import time
11
- import torch
12
-
13
  class CustomHfAgent(Agent):
 
 
14
  def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  super().__init__(
16
  chat_prompt_template=chat_prompt_template,
17
  run_prompt_template=run_prompt_template,
@@ -22,6 +46,16 @@ class CustomHfAgent(Agent):
22
  self.input_params = input_params
23
 
24
  def generate_one(self, prompt, stop):
 
 
 
 
 
 
 
 
 
 
25
  headers = {"Authorization": self.token}
26
  max_new_tokens = self.input_params.get("max_new_tokens", 192)
27
  parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
@@ -29,12 +63,17 @@ class CustomHfAgent(Agent):
29
  "inputs": prompt,
30
  "parameters": parameters,
31
  }
32
- response = requests.post(self.url_endpoint, json=inputs, headers=headers)
33
-
 
 
 
 
 
34
  if response.status_code == 429:
35
  log_response("Getting rate-limited, waiting a tiny bit before trying again.")
36
  time.sleep(1)
37
- return self._generate_one(prompt)
38
  elif response.status_code != 200:
39
  raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
40
  log_response(response)
 
1
+ """
2
+ Module: custom_agent
3
+
4
+ This module provides a custom class, CustomHfAgent, for interacting with the Hugging Face model API.
5
+
6
+ Dependencies:
7
+ - time: Standard Python time module for time-related operations.
8
+ - requests: HTTP library for making requests.
9
+ - transformers: Hugging Face's transformers library for NLP tasks.
10
+ - utils.logger: Custom logger module for logging responses.
11
+
12
+ Classes:
13
+ - CustomHfAgent: A custom class for interacting with the Hugging Face model API.
14
+ """
15
+
16
  import time
17
+ import requests
18
  from transformers import Agent
19
  from utils.logger import log_response
20
 
 
 
 
21
  class CustomHfAgent(Agent):
22
+ """A custom class for interacting with the Hugging Face model API."""
23
+
24
  def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
25
+ """
26
+ Initialize the CustomHfAgent.
27
+
28
+ Args:
29
+ - url_endpoint (str): The URL endpoint for the Hugging Face model API.
30
+ - token (str): The authentication token required to access the API.
31
+ - chat_prompt_template (str): Template for chat prompts.
32
+ - run_prompt_template (str): Template for run prompts.
33
+ - additional_tools (list): Additional tools for the agent.
34
+ - input_params (dict): Additional parameters for input.
35
+
36
+ Returns:
37
+ - None
38
+ """
39
  super().__init__(
40
  chat_prompt_template=chat_prompt_template,
41
  run_prompt_template=run_prompt_template,
 
46
  self.input_params = input_params
47
 
48
  def generate_one(self, prompt, stop):
49
+ """
50
+ Generate one response from the Hugging Face model.
51
+
52
+ Args:
53
+ - prompt (str): The prompt to generate a response for.
54
+ - stop (list): A list of strings indicating where to stop generating text.
55
+
56
+ Returns:
57
+ - str: The generated response.
58
+ """
59
  headers = {"Authorization": self.token}
60
  max_new_tokens = self.input_params.get("max_new_tokens", 192)
61
  parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
 
63
  "inputs": prompt,
64
  "parameters": parameters,
65
  }
66
+ print(inputs)
67
+ try:
68
+ response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=300)
69
+ except requests.Timeout:
70
+ pass
71
+ except requests.ConnectionError:
72
+ pass
73
  if response.status_code == 429:
74
  log_response("Getting rate-limited, waiting a tiny bit before trying again.")
75
  time.sleep(1)
76
+ return self.generate_one(prompt, stop)
77
  elif response.status_code != 200:
78
  raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
79
  log_response(response)