Somekindofa committed on
Commit f7f7d8c · 1 Parent(s): 23caf6d

Feat/ Implemented CoT

Files changed (1):
  1. app.py +118 -45
app.py CHANGED
@@ -5,7 +5,7 @@ import accelerate
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
  import os
  import torch
- from typing import Optional, Iterator
+ from typing import Optional, Iterator, Dict, Any, List
  from threading import Thread
  from types import NoneType

@@ -827,6 +827,18 @@ if torch.cuda.is_available():
                                                device_map="auto")
      tokenizer = AutoTokenizer.from_pretrained(model_id)

+ # New helper function to create a thinking message
+ def create_thinking_message(content: str, status: str = None) -> Dict[str, Any]:
+     """Creates a thinking message with metadata for display in the chatbot."""
+     return {
+         "role": "assistant",
+         "content": content,
+         "metadata": {
+             "title": "🧠 Réflexion",
+             "status": status
+         }
+     }
+
  @spaces.GPU
  def generate(
      message: str,
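
The helper in the hunk above returns plain message dicts whose "metadata" block is what makes the chatbot render them as a collapsible "thinking" bubble. The sketch below is not part of the commit; it shows the same idea in isolation with gr.ChatMessage, assuming a recent Gradio release where a type="messages" chat function may yield ChatMessage objects (or dicts of this shape) and where a metadata "title" and "status" drive the collapsible thought display.

import gradio as gr

def respond(message, history):
    # A message whose metadata carries a "title" is shown as a collapsible bubble;
    # "status" switches its indicator between "pending" and "done".
    thought = gr.ChatMessage(
        content="Réflexion en cours...",
        metadata={"title": "🧠 Réflexion", "status": "pending"},
    )
    # Yield a growing list so the thought and the final answer stay visible together.
    yield [thought]
    thought.metadata["status"] = "done"
    yield [thought, gr.ChatMessage(content=f"Réponse à : {message}")]

demo = gr.ChatInterface(fn=respond, type="messages")

if __name__ == "__main__":
    demo.launch()
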
@@ -834,57 +846,120 @@ def generate(
      knowledge: str, # added knowledge parameter
      system_prompt: str = DEFAULT_SYSTEM_PROMPT,
      max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
+     temperature: float = 0.2,
+     top_p: float = 0.8,
      top_k: int = 50,
-     repetition_penalty: float = 1.2
- ) -> Iterator[str]:
+     repetition_penalty: float = 1.0
+ ) -> Iterator[Dict[str, Any]]:
      try:
-         conversation = []
+         thinking_conversation = []
          if system_prompt:
-             conversation.append({"role": "system", "content": system_prompt})
+             thinking_conversation.append({"role": "system",
+                                           "content": system_prompt})
          if knowledge:
-             conversation.append({"role": "assistant", "content": f"This is your knowledge: {knowledge}"})
-
-         conversation += chat_history
-         conversation.append({"role": "user", "content": message})
-
-         input_ids = tokenizer.apply_chat_template(conversation,
-                                                   return_tensors="pt",
-                                                   add_generation_prompt=True)
-
-         input_ids = input_ids.to(model.device)
-
-         streamer = TextIteratorStreamer(tokenizer,
-                                         timeout=2*60.0,
-                                         skip_prompt=True,
-                                         skip_special_tokens=True)
-
-         generate_kwargs = dict(
-             {"input_ids": input_ids},
-             streamer=streamer,
-             max_new_tokens=max_new_tokens,
-             do_sample=True,
-             top_p=top_p,
-             top_k=top_k,
-             temperature=temperature,
-             num_beams=1,
-             repetition_penalty=repetition_penalty,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-
-         t = Thread(target=model.generate,
-                    kwargs=generate_kwargs)
-         t.start()
-
-         outputs = []
-         for text in streamer:
-             outputs.append(text)
-             yield "".join(outputs)
-
+             thinking_conversation.append({"role": "assistant",
+                                           "content": f"Voici l'ontologie existante que je dois comprendre: {knowledge}\n\nJe vais l'analyser étape par étape."})
+
+         thinking_conversation += chat_history
+         # Thinking prompt
+         thinking_prompt = message + "\n\nRéfléchis étape par étape. Identifie d'abord les entités, puis les relations, puis organise hiérarchiquement avant de formaliser."
+
+         thinking_conversation.append({"role": "user",
+                                       "content": thinking_prompt})
+         thinking_message = create_thinking_message(content="Réflexion en cours...",
+                                                    status="pending")
+         yield thinking_message
+
+         thinking_result = generate_llm_response(
+             thinking_conversation,
+             max_new_tokens=max_new_tokens * 2,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty
+         )
+
+         thinking_message = create_thinking_message(thinking_result, status="done")
+         yield thinking_message
+
+         # Final Answer
+         final_conversation = []
+         final_conversation.append({"role": "system", "content": system_prompt})
+         if knowledge:
+             final_conversation.append({"role": "assistant", "content": f"J'ai analysé ce texte: {knowledge}"})
+         final_conversation += chat_history
+         # Give the final pass the reasoning and the new user message as well
+         final_conversation.append({"role": "assistant", "content": thinking_result})
+         final_conversation.append({"role": "user", "content": message})
+
+         final_answer = generate_llm_response(
+             final_conversation,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature * 0.8,  # Even lower temperature for final answer
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty
+         )
+
+         # Yield the final answer
+         yield {
+             "role": "assistant",
+             "content": final_answer
+         }
      except Exception as e:
-         yield f"An error occurred: {str(e)}"
+         yield {
+             "role": "assistant",
+             "content": f"An error occurred: {str(e)}"
+         }
+
+ # Helper function to generate responses from the LLM
+ def generate_llm_response(
+     conversation: List[Dict[str, str]],
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     top_k: int,
+     repetition_penalty: float
+ ) -> str:
+     """Generate a response from the LLM based on the conversation."""
+     input_ids = tokenizer.apply_chat_template(
+         conversation,
+         return_tensors="pt",
+         add_generation_prompt=True
+     )
+
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(
+         tokenizer,
+         timeout=2*60.0,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+
+     t = Thread(
+         target=model.generate,
+         kwargs=generate_kwargs
+     )
+     t.start()
+
+     # Collect the output
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+
+     return "".join(outputs)
+

  def append_text_knowledge(file_path: str) -> str:
      """
@@ -910,6 +985,7 @@ knowledge_textbox = gr.Textbox(
      lines= 20,
      visible=False
  )
+
  chat_interface = gr.ChatInterface(
      fn=generate,
      type="messages",
@@ -929,14 +1005,14 @@ chat_interface = gr.ChatInterface(
              minimum=0.1,
              maximum=4.0,
              step=0.1,
-             value=0.6,
+             value=0.2,
          ),
          gr.Slider(
              label="Top-p (nucleus sampling)",
              minimum=0.05,
              maximum=1.0,
              step=0.05,
-             value=0.9,
+             value=0.8,
          ),
          gr.Slider(
              label="Top-k",
@@ -950,13 +1026,10 @@ chat_interface = gr.ChatInterface(
              minimum=1.0,
              maximum=2.0,
              step=0.05,
-             value=1.2,
+             value=1.0,
          ),
      ],
      stop_btn=True,
-     examples=[
-         ["In bullet-points, give me the classes from that Turtle ontology :"]
-     ],
      cache_examples=False,
      show_progress="full",
      run_examples_on_click=False
 