terry-li-hm committed
Commit 7bb49ef · 1 Parent(s): 3ae6033

`torch` only

Files changed (1):
  1. app.py +172 -171
app.py CHANGED
@@ -1,20 +1,21 @@
- import os

- import chainlit as cl
- import openai
  import torch
- from chainlit.input_widget import Select, Slider
- from llama_index import (
-     ServiceContext,
-     StorageContext,
-     TrafilaturaWebReader,
-     VectorStoreIndex,
-     load_index_from_storage,
- )
- from llama_index.callbacks.base import CallbackManager
- from llama_index.embeddings import HuggingFaceEmbedding
- from llama_index.llms import HuggingFaceLLM, LiteLLM, MessageRole, OpenAI
- from llama_index.prompts import PromptTemplate

  # from transformers import BitsAndBytesConfig

@@ -22,159 +23,159 @@ print(f"Is CUDA available: {torch.cuda.is_available()}")
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


- def get_api_key():
-     api_key = os.getenv("OPENAI_API_KEY")
-     if api_key is None:
-         print("OPENAI_API_KEY missing from environment variables")
-         api_key = input("Please enter your OPENAI_API_KEY: ")
-     return api_key
-
-
- openai.api_key = get_api_key()
-
-
- def load_index():
-     try:
-         storage_context = StorageContext.from_defaults(persist_dir="./storage")
-         index = load_index_from_storage(storage_context)
-     except FileNotFoundError:
-         print("Storage file not found. Loading from web.")
-         documents = TrafilaturaWebReader().load_data(["https://bit.ly/45BncJA"])
-         index = VectorStoreIndex.from_documents(documents)
-         index.storage_context.persist()
-     return index
-
-
- index = load_index()
-
- welcome_msg = (
-     "Hi there! I’m your China Life chatbot, specialising in answering "
-     "[frequently asked questions](https://bit.ly/45BncJA). "
-     "How may I assist you today? "
-     "Feel free to ask questions like, "
-     "“Is there any action required after receiving the policy?” or "
-     "“Can I settle using a demand draft?”"
- )
-
-
- @cl.on_chat_start
- async def start():
-     chat_profile = cl.user_session.get("chat_profile")
-     msg = cl.Message(content="")
-     for token in list(welcome_msg):
-         await cl.sleep(0.01)
-         await msg.stream_token(token)
-
-     await msg.send()
-
-     settings = await cl.ChatSettings(
-         [
-             Select(
-                 id="Model",
-                 label="Model",
-                 values=[
-                     "gpt-3.5-turbo",
-                     "gpt-4",
-                     "zephyr",
-                     "litellm-gpt-3.5-turbo",
-                     "litellm-opt-125m",
-                 ],
-                 initial_index=1,
-             ),
-             Slider(
-                 id="Temperature",
-                 label="Temperature",
-                 initial=0.0,
-                 min=0.0,
-                 max=2.0,
-                 step=0.1,
-             ),
-         ]
-     ).send()
-     await setup_query_engine(settings)
-
-
- @cl.on_settings_update
- async def setup_query_engine(settings):
-     print("on_settings_update", settings)
-
-     # def messages_to_prompt(messages):
-     #     prompt = ""
-     #     for message in messages:
-     #         if message.role == "system":
-     #             prompt += f"<|system|>\n{message.content}</s>\n"
-     #         elif message.role == "user":
-     #             prompt += f"<|user|>\n{message.content}</s>\n"
-     #         elif message.role == "assistant":
-     #             prompt += f"<|assistant|>\n{message.content}</s>\n"
-     #     if not prompt.startswith("<|system|>\n"):
-     #         prompt = "<|system|>\n</s>\n" + prompt
-     #     prompt = prompt + "<|assistant|>\n"
-     #     return prompt
-
-     if settings["Model"] == "zephyr":
-         # model_name = "HuggingFaceH4/zephyr-7b-beta"
-         # query_wrapper_prompt = PromptTemplate(
-         #     "<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"
-         # )
-         # quantization_config = BitsAndBytesConfig(
-         #     load_in_4bit=True,
-         #     bnb_4bit_compute_dtype=torch.bfloat16,
-         #     bnb_4bit_quant_type="nf4",
-         #     bnb_4bit_use_double_quant=True,
-         # )
-         # llm = HuggingFaceLLM(
-         #     model_name=model_name,
-         #     tokenizer_name=model_name,
-         #     query_wrapper_prompt=query_wrapper_prompt,
-         #     context_window=3900,
-         #     max_new_tokens=256,
-         #     model_kwargs={"quantization_config": quantization_config},
-         #     generate_kwargs={
-         #         "do_sample": True,
-         #         "temperature": settings["Temperature"],
-         #         "top_k": 50,
-         #         "top_p": 0.95,
-         #     },
-         #     messages_to_prompt=messages_to_prompt,
-         #     device_map="auto",
-         # )
-         llm = LiteLLM("gpt-3.5-turbo")
-     elif settings["Model"] == "litellm-gpt-3.5-turbo":
-         llm = LiteLLM("gpt-3.5-turbo")
-     elif settings["Model"] == "litellm-opt-125m":
-         llm = LiteLLM("vllm/facebook/opt-125m")
-     else:
-         llm = OpenAI(model=settings["Model"], temperature=settings["Temperature"])
-
-     # embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
-     service_context = ServiceContext.from_defaults(
-         llm=llm,
-         # embed_model=embed_model,
-         callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
-     )
-
-     query_engine = index.as_query_engine(
-         service_context=service_context,
-         streaming=True,
-     )
-
-     cl.user_session.set("query_engine", query_engine)
-
-
- @cl.on_message
- async def main(message: cl.Message):
-     query_engine = cl.user_session.get("query_engine")
-
-     if query_engine is None:
-         await start()
-         query_engine = cl.user_session.get("query_engine")
-
-     if query_engine:
-         query_result = await cl.make_async(query_engine.query)(message.content)
-         response_message = cl.Message(content=query_result.response_txt or "")
-
-         for token in query_result.response_gen:
-             await response_message.stream_token(token=token)
-
-         await response_message.send()
 
+ # import os

+ # import chainlit as cl
+ # import openai
  import torch
+
+ # from chainlit.input_widget import Select, Slider
+ # from llama_index import (
+ # ServiceContext,
+ # StorageContext,
+ # TrafilaturaWebReader,
+ # VectorStoreIndex,
+ # load_index_from_storage,
+ # )
+ # from llama_index.callbacks.base import CallbackManager
+ # from llama_index.embeddings import HuggingFaceEmbedding
+ # from llama_index.llms import HuggingFaceLLM, LiteLLM, MessageRole, OpenAI
+ # from llama_index.prompts import PromptTemplate

  # from transformers import BitsAndBytesConfig


  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


+ # def get_api_key():
+ # api_key = os.getenv("OPENAI_API_KEY")
+ # if api_key is None:
+ # print("OPENAI_API_KEY missing from environment variables")
+ # api_key = input("Please enter your OPENAI_API_KEY: ")
+ # return api_key
+
+
+ # openai.api_key = get_api_key()
+
+
+ # def load_index():
+ # try:
+ # storage_context = StorageContext.from_defaults(persist_dir="./storage")
+ # index = load_index_from_storage(storage_context)
+ # except FileNotFoundError:
+ # print("Storage file not found. Loading from web.")
+ # documents = TrafilaturaWebReader().load_data(["https://bit.ly/45BncJA"])
+ # index = VectorStoreIndex.from_documents(documents)
+ # index.storage_context.persist()
+ # return index
+
+
+ # index = load_index()
+
+ # welcome_msg = (
+ # "Hi there! I’m your China Life chatbot, specialising in answering "
+ # "[frequently asked questions](https://bit.ly/45BncJA). "
+ # "How may I assist you today? "
+ # "Feel free to ask questions like, "
+ # "“Is there any action required after receiving the policy?” or "
+ # "“Can I settle using a demand draft?”"
+ # )
+
+
+ # @cl.on_chat_start
+ # async def start():
+ # chat_profile = cl.user_session.get("chat_profile")
+ # msg = cl.Message(content="")
+ # for token in list(welcome_msg):
+ # await cl.sleep(0.01)
+ # await msg.stream_token(token)
+
+ # await msg.send()
+
+ # settings = await cl.ChatSettings(
+ # [
+ # Select(
+ # id="Model",
+ # label="Model",
+ # values=[
+ # "gpt-3.5-turbo",
+ # "gpt-4",
+ # "zephyr",
+ # "litellm-gpt-3.5-turbo",
+ # "litellm-opt-125m",
+ # ],
+ # initial_index=1,
+ # ),
+ # Slider(
+ # id="Temperature",
+ # label="Temperature",
+ # initial=0.0,
+ # min=0.0,
+ # max=2.0,
+ # step=0.1,
+ # ),
+ # ]
+ # ).send()
+ # await setup_query_engine(settings)
+
+
+ # @cl.on_settings_update
+ # async def setup_query_engine(settings):
+ # print("on_settings_update", settings)
+
+ # # def messages_to_prompt(messages):
+ # # prompt = ""
+ # # for message in messages:
+ # # if message.role == "system":
+ # # prompt += f"<|system|>\n{message.content}</s>\n"
+ # # elif message.role == "user":
+ # # prompt += f"<|user|>\n{message.content}</s>\n"
+ # # elif message.role == "assistant":
+ # # prompt += f"<|assistant|>\n{message.content}</s>\n"
+ # # if not prompt.startswith("<|system|>\n"):
+ # # prompt = "<|system|>\n</s>\n" + prompt
+ # # prompt = prompt + "<|assistant|>\n"
+ # # return prompt
+
+ # if settings["Model"] == "zephyr":
+ # # model_name = "HuggingFaceH4/zephyr-7b-beta"
+ # # query_wrapper_prompt = PromptTemplate(
+ # # "<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"
+ # # )
+ # # quantization_config = BitsAndBytesConfig(
+ # # load_in_4bit=True,
+ # # bnb_4bit_compute_dtype=torch.bfloat16,
+ # # bnb_4bit_quant_type="nf4",
+ # # bnb_4bit_use_double_quant=True,
+ # # )
+ # # llm = HuggingFaceLLM(
+ # # model_name=model_name,
+ # # tokenizer_name=model_name,
+ # # query_wrapper_prompt=query_wrapper_prompt,
+ # # context_window=3900,
+ # # max_new_tokens=256,
+ # # model_kwargs={"quantization_config": quantization_config},
+ # # generate_kwargs={
+ # # "do_sample": True,
+ # # "temperature": settings["Temperature"],
+ # # "top_k": 50,
+ # # "top_p": 0.95,
+ # # },
+ # # messages_to_prompt=messages_to_prompt,
+ # # device_map="auto",
+ # # )
+ # llm = LiteLLM("gpt-3.5-turbo")
+ # elif settings["Model"] == "litellm-gpt-3.5-turbo":
+ # llm = LiteLLM("gpt-3.5-turbo")
+ # elif settings["Model"] == "litellm-opt-125m":
+ # llm = LiteLLM("vllm/facebook/opt-125m")
+ # else:
+ # llm = OpenAI(model=settings["Model"], temperature=settings["Temperature"])
+
+ # # embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+ # service_context = ServiceContext.from_defaults(
+ # llm=llm,
+ # # embed_model=embed_model,
+ # callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
+ # )
+
+ # query_engine = index.as_query_engine(
+ # service_context=service_context,
+ # streaming=True,
+ # )
+
+ # cl.user_session.set("query_engine", query_engine)
+
+
+ # @cl.on_message
+ # async def main(message: cl.Message):
+ # query_engine = cl.user_session.get("query_engine")
+
+ # if query_engine is None:
+ # await start()
+ # query_engine = cl.user_session.get("query_engine")
+
+ # if query_engine:
+ # query_result = await cl.make_async(query_engine.query)(message.content)
+ # response_message = cl.Message(content=query_result.response_txt or "")
+
+ # for token in query_result.response_gen:
+ # await response_message.stream_token(token=token)
+
+ # await response_message.send()
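
After this commit, every line of app.py except the torch import and the two CUDA diagnostics is commented out, so the file effectively reduces to the short script below (taken directly from the uncommented lines in the diff above).

```python
# Effective content of app.py after commit 7bb49ef: only the torch import
# and the CUDA diagnostics remain active.
import torch

print(f"Is CUDA available: {torch.cuda.is_available()}")
# Note: current_device()/get_device_name() will raise if no CUDA device is present.
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
```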