gufett0 committed
Commit: baf000f
Parent: ed51056

added async

Files changed (1):
  1. backend.py +21 -32
backend.py CHANGED
@@ -24,11 +24,14 @@ model_id = "google/gemma-2-2b-it"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    device_map="auto",
+    device_map="auto", ## change this back to auto!!!
     torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    token=True
-    )
+    token=True)
 model.eval()
+
+#from accelerate import disk_offload
+#disk_offload(model=model, offload_dir="offload")
+
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
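
The hunk above keeps device_map="auto" (with a reminder comment) and leaves an accelerate disk-offload path commented out as a fallback. For reference, a minimal sketch of how that commented-out path could be wired in when GPU memory is tight; it assumes the accelerate package is installed and uses the "offload" directory named in the comment:

```python
# Sketch only: disk offload as an alternative to device_map="auto".
# Mirrors the commented-out lines in backend.py; assumes accelerate is installed.
import torch
from accelerate import disk_offload
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Move the weights to ./offload; they are loaded back on demand at inference time.
disk_offload(model=model, offload_dir="offload")
model.eval()
```

Disk offload trades speed for memory: weights are paged back in from the offload directory during each forward pass, so it is usually a last resort compared with keeping the model resident via device_map="auto".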
 
@@ -54,8 +57,7 @@ def build_index():
 
 
 @spaces.GPU(duration=20)
-def handle_query(query_str, chathistory):
-
+async def handle_query(query_str, chathistory):
     index = build_index()
 
     qa_prompt_str = (

@@ -71,45 +73,32 @@ def handle_query(query_str, chathistory)
     chat_text_qa_msgs = [
         (
             "system",
-            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
+            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti.",
         ),
         ("user", qa_prompt_str),
     ]
     text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
 
     try:
-        # Create a streaming query engine
-        """query_engine = index.as_query_engine(text_qa_template=text_qa_template, streaming=False, similarity_top_k=1)
-
-        # Execute the query
-        streaming_response = query_engine.query(query_str)
-
-        r = streaming_response.response
-        cleaned_result = r.replace("<end_of_turn>", "").strip()
-        yield cleaned_result"""
-
-        # Stream the response
-        """outputs = []
-        for text in streaming_response.response_gen:
-
-            outputs.append(str(text))
-            yield "".join(outputs)"""
-
         memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
         chat_engine = index.as_chat_engine(
-            chat_mode="context",
-            memory=memory,
-            system_prompt=(
-                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
-            ),
+            chat_mode="context",
+            memory=memory,
+            system_prompt=(
+                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti."
+            ),
         )
 
+        # Stream the response
         response = chat_engine.stream_chat(query_str)
-        #response = chat_engine.chat(query_str)
-        for token in response.response_gen:
-            yield token
-
+        outputs = []
+
+        async for token in response.response_gen:
+            outputs.append(token)
+            yield "".join(outputs)
 
+    except StopAsyncIteration:
+        yield "No more responses to stream."
     except Exception as e:
         yield f"Error processing query: {str(e)}"
 
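
A note on the new streaming loop: stream_chat returns a response whose response_gen is a plain synchronous generator in current llama_index releases, so driving it with async for will typically raise a TypeError even though handle_query is now declared async. A minimal sketch of a fully asynchronous variant, assuming a recent llama_index.core that exposes astream_chat and async_response_gen(), and reusing build_index and ChatMemoryBuffer from backend.py:

```python
# Sketch only: a fully asynchronous variant of handle_query, assuming a recent
# llama_index.core that provides astream_chat() and async_response_gen().
from llama_index.core.memory import ChatMemoryBuffer


async def handle_query(query_str, chathistory):
    # build_index() comes from backend.py and returns the LlamaIndex index.
    index = build_index()

    memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
    chat_engine = index.as_chat_engine(
        chat_mode="context",
        memory=memory,
        system_prompt=(
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti."
        ),
    )
    try:
        # astream_chat is awaited and hands back an async token generator,
        # so async for is valid here.
        response = await chat_engine.astream_chat(query_str)
        outputs = []
        async for token in response.async_response_gen():
            outputs.append(token)
            yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"
```

If the synchronous stream_chat call is kept instead, a plain for loop over response.response_gen inside the async handler also works, as the pre-async version of this function did.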