yash-srivastava19 commited on
Commit
c9946ac
·
1 Parent(s): c626d36

Create custom_class.py

Browse files

Created `custom_class.py`

Files changed (1) hide show
  1. custom_class.py +50 -0
custom_class.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Wrap our custom fine-tuned model into a LangChain LLM class, and make it
# available for the chainlit factory.
import os

import chainlit as cl
import cohere
from langchain.llms.base import LLM
from metaphor_python import Metaphor

# SECURITY: API keys must never be hardcoded in source control (the original
# commit leaked live keys in plain text). Read them from the environment
# instead; this fails fast with a KeyError if a key is missing.
metaphor = Metaphor(os.environ["METAPHOR_API_KEY"])  # Metaphor search client
co = cohere.Client(os.environ["COHERE_API_KEY"])  # Cohere chat client
12
class SourcesString:
    """Pretty-print a list of Metaphor search results as a Markdown list.

    Renders exactly the results passed in (the caller currently supplies
    the top 3), each as a numbered Markdown link.
    """

    def __init__(self, results_list) -> None:
        # Each result is expected to expose `.title` and `.url` attributes.
        self.res_list = results_list

    def get_parsed_string(self) -> str:
        """Return a Markdown section listing each result as `n. [title](url)`."""
        links = "\n".join(
            f"{i}. [{res.title}]({res.url})"
            for i, res in enumerate(self.res_list, start=1)
        )
        return "### Additional Sources:\n" + links
20
+
21
+
22
class CustomLLM(LLM):
    """LangChain LLM wrapper around the Cohere chat model, with Metaphor
    search results appended to each answer to mimic RAG.

    Relies on the module-level `co` (Cohere client) and `metaphor`
    (Metaphor client) instances.
    """

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses when serializing / logging this LLM.
        return "custom"

    def _call(self, prompt: str, stop=None, run_manager=None) -> str:
        """Answer `prompt` via Cohere chat and append the top-3 Metaphor
        search results as a Markdown "Additional Sources" section.

        Raises:
            ValueError: if `stop` sequences are supplied (not supported).
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # Cohere's chat endpoint is not natively supported by LangChain
        # (at the time of writing), hence this custom wrapper.
        response = co.chat(prompt, model="command")

        # Fetch supporting links to display beneath the generated answer.
        results = metaphor.search(prompt, use_autoprompt=True, num_results=3)

        # NOTE: the original `.strip("```")` stripped individual backtick
        # characters, not code fences; it was a no-op on the left (the string
        # starts with "### ") and could only corrupt a trailing URL, so it
        # has been removed.
        sources_text = SourcesString(results.results).get_parsed_string()

        return response.text + "\n\n\n" + sources_text

    @property
    def _identifying_params(self):
        """Get the identifying parameters."""
        return {"model_type": "COHERE_CHAT_MODEL"}