terry-li-hm commited on
Commit
2990b6a
·
1 Parent(s): e2c69ce

Add `zephyr`

Browse files
Files changed (2) hide show
  1. app.py +56 -4
  2. requirements.txt +5 -0
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
 
3
  import chainlit as cl
4
  import openai
 
5
  from chainlit.input_widget import Select, Slider, Switch
6
  from langchain.chat_models import ChatOpenAI
7
  from llama_index import (
@@ -13,7 +14,13 @@ from llama_index import (
13
  load_index_from_storage,
14
  )
15
  from llama_index.callbacks.base import CallbackManager
 
 
16
  from llama_index.llms import ChatMessage, HuggingFaceLLM, MessageRole, OpenAI
 
 
 
 
17
 
18
 
19
  def get_api_key():
@@ -66,7 +73,7 @@ async def start():
66
  Select(
67
  id="Model",
68
  label="Model",
69
- values=["gpt-3.5-turbo", "gpt-4"],
70
  initial_index=1,
71
  ),
72
  Slider(
@@ -86,10 +93,55 @@ async def start():
86
  async def setup_query_engine(settings):
87
  print("on_settings_update", settings)
88
 
89
- llm = OpenAI(model=settings["Model"], temperature=settings["Temperature"])
90
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  service_context = ServiceContext.from_defaults(
92
- llm=llm, callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()])
 
 
93
  )
94
 
95
  query_engine = index.as_query_engine(
 
2
 
3
  import chainlit as cl
4
  import openai
5
+ import torch
6
  from chainlit.input_widget import Select, Slider, Switch
7
  from langchain.chat_models import ChatOpenAI
8
  from llama_index import (
 
14
  load_index_from_storage,
15
  )
16
  from llama_index.callbacks.base import CallbackManager
17
+ from llama_index.chat_engine import CondenseQuestionChatEngine
18
+ from llama_index.embeddings import HuggingFaceEmbedding
19
  from llama_index.llms import ChatMessage, HuggingFaceLLM, MessageRole, OpenAI
20
+ from llama_index.prompts import PromptTemplate
21
+ from llama_index.query_engine import SubQuestionQueryEngine
22
+ from llama_index.tools import QueryEngineTool, ToolMetadata
23
+ from transformers import BitsAndBytesConfig
24
 
25
 
26
  def get_api_key():
 
73
  Select(
74
  id="Model",
75
  label="Model",
76
+ values=["gpt-3.5-turbo", "gpt-4", "zephyr"],
77
  initial_index=1,
78
  ),
79
  Slider(
 
93
  async def setup_query_engine(settings):
94
  print("on_settings_update", settings)
95
 
96
+ def messages_to_prompt(messages):
97
+ prompt = ""
98
+ for message in messages:
99
+ if message.role == "system":
100
+ prompt += f"<|system|>\n{message.content}</s>\n"
101
+ elif message.role == "user":
102
+ prompt += f"<|user|>\n{message.content}</s>\n"
103
+ elif message.role == "assistant":
104
+ prompt += f"<|assistant|>\n{message.content}</s>\n"
105
+ if not prompt.startswith("<|system|>\n"):
106
+ prompt = "<|system|>\n</s>\n" + prompt
107
+ prompt = prompt + "<|assistant|>\n"
108
+ return prompt
109
+
110
+ if settings["Model"] == "zephyr":
111
+ model_name = "HuggingFaceH4/zephyr-7b-beta"
112
+ query_wrapper_prompt = PromptTemplate(
113
+ "<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"
114
+ )
115
+ quantization_config = BitsAndBytesConfig(
116
+ load_in_4bit=True,
117
+ bnb_4bit_compute_dtype=torch.float16,
118
+ bnb_4bit_quant_type="nf4",
119
+ bnb_4bit_use_double_quant=True,
120
+ )
121
+ llm = HuggingFaceLLM(
122
+ model_name=model_name,
123
+ tokenizer_name=model_name,
124
+ query_wrapper_prompt=query_wrapper_prompt,
125
+ context_window=3900,
126
+ max_new_tokens=256,
127
+ model_kwargs={"quantization_config": quantization_config},
128
+ generate_kwargs={
129
+ "do_sample": True,
130
+ "temperature": settings["Temperature"],
131
+ "top_k": 50,
132
+ "top_p": 0.95,
133
+ },
134
+ messages_to_prompt=messages_to_prompt,
135
+ device_map="auto",
136
+ )
137
+ else:
138
+ llm = OpenAI(model=settings["Model"], temperature=settings["Temperature"])
139
+
140
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
141
  service_context = ServiceContext.from_defaults(
142
+ llm=llm,
143
+ embed_model=embed_model,
144
+ callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
145
  )
146
 
147
  query_engine = index.as_query_engine(
requirements.txt CHANGED
@@ -2,3 +2,8 @@ chainlit
2
  llama-index
3
  trafilatura
4
  openai
 
 
 
 
 
 
2
  llama-index
3
  trafilatura
4
  openai
5
+ torch
6
+ transformers
7
+ accelerate
8
+ scipy
9
+ bitsandbytes