Somekindofa committed
Commit · f7f7d8c
1 Parent(s): 23caf6d
Feat/ Implemented CoT
app.py CHANGED

@@ -5,7 +5,7 @@ import accelerate
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import os
 import torch
-from typing import Optional, Iterator
+from typing import Optional, Iterator, Dict, Any, List
 from threading import Thread
 from types import NoneType

@@ -827,6 +827,18 @@ if torch.cuda.is_available():
                                                  device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)

+# New helper function to create a thinking message
+def create_thinking_message(content: str, status: str = None) -> Dict[str, Any]:
+    """Creates a thinking message with metadata for display in the chatbot."""
+    return {
+        "role": "assistant",
+        "content": content,
+        "metadata": {
+            "title": "🧠 Réflexion",
+            "status": status
+        }
+    }
+
 @spaces.GPU
 def generate(
     message: str,

@@ -834,57 +846,120 @@ def generate(
     knowledge: str, # added knowledge parameter
     system_prompt: str = DEFAULT_SYSTEM_PROMPT,
     max_new_tokens: int = 1024,
-    temperature: float = 0.
-    top_p: float = 0.
+    temperature: float = 0.2,
+    top_p: float = 0.8,
     top_k: int = 50,
-    repetition_penalty: float = 1.
-) -> Iterator[str]:
+    repetition_penalty: float = 1.0
+) -> Iterator[Dict[str, Any]]:
     try:
-        [...]
+        thinking_conversation = []
         if system_prompt:
-            [...]
+            thinking_conversation.append({"role": "system",
+                                          "content": system_prompt})
         if knowledge:
-            [...]
+            thinking_conversation.append({"role": "assistant",
+                                          "content": f"Voici l'ontologie existante que je dois comprendre: {knowledge}\n\nJe vais l'analyser étape par étape."})
+
+        thinking_conversation += chat_history
+        # Thinking prompt
+        thinking_prompt = message + "\n\nRéfléchis étape par étape. Identifie d'abord les entités, puis les relations, puis organise hiérarchiquement avant de formaliser."
+
+        thinking_conversation.append({"role": "user",
+                                      "content": thinking_prompt})
+        thinking_message = create_thinking_message(content="Réflexion en cours...",
+                                                   status="pending")
+        yield thinking_message
+
+        thinking_result = generate_llm_response(
+            thinking_conversation,
+            max_new_tokens=max_new_tokens * 2,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty
+        )

+        thinking_message = create_thinking_message(thinking_result, status="done")
+        yield thinking_message

-        [...]
+        # Final Answer
+        final_conversation = []
+        final_conversation.append({"role": "system", "content": system_prompt})
+        if knowledge:
+            final_conversation.append({"role": "assistant", "content": f"J'ai analysé ce texte: {knowledge}"})
+        final_conversation += chat_history

-        [...]
-            streamer=streamer,
+        final_answer = generate_llm_response(
+            final_conversation,
             max_new_tokens=max_new_tokens,
-            [...]
+            temperature=temperature * 0.8, # Even lower temperature for final answer
             top_p=top_p,
             top_k=top_k,
-            [...]
-            num_beams=1,
-            repetition_penalty=repetition_penalty,
-            pad_token_id=tokenizer.eos_token_id,
+            repetition_penalty=repetition_penalty
         )

-        [...]
-        for text in streamer:
-            outputs.append(text)
-            yield "".join(outputs)
-        [...]
+        # Yield the final answer
+        yield {
+            "role": "assistant",
+            "content": final_answer
+        }
     except Exception as e:
-        yield
+        yield {
+            "role": "assistant",
+            "content": f"An error occurred: {str(e)}"
+        }
+
+# Helper function to generate responses from the LLM
+def generate_llm_response(
+    conversation: List[Dict[str, str]],
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    repetition_penalty: float
+) -> str:
+    """Generate a response from the LLM based on the conversation."""
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        return_tensors="pt",
+        add_generation_prompt=True
+    )
+
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        timeout=2*60.0,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+
+    t = Thread(
+        target=model.generate,
+        kwargs=generate_kwargs
+    )
+    t.start()
+
+    # Collect the output
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+
+    return "".join(outputs)
+

 def append_text_knowledge(file_path: str) -> str:
     """

@@ -910,6 +985,7 @@ knowledge_textbox = gr.Textbox(
     lines= 20,
     visible=False
 )
+
 chat_interface = gr.ChatInterface(
     fn=generate,
     type="messages",

@@ -929,14 +1005,14 @@ chat_interface = gr.ChatInterface(
             minimum=0.1,
             maximum=4.0,
             step=0.1,
-            value=0.
+            value=0.2,
         ),
         gr.Slider(
             label="Top-p (nucleus sampling)",
             minimum=0.05,
             maximum=1.0,
             step=0.05,
-            value=0.
+            value=0.8,
         ),
         gr.Slider(
             label="Top-k",

@@ -950,13 +1026,10 @@ chat_interface = gr.ChatInterface(
             minimum=1.0,
             maximum=2.0,
             step=0.05,
-            value=1.
+            value=1.0,
         ),
     ],
     stop_btn=True,
-    examples=[
-        ["In bullet-points, give me the classes from that Turtle ontology :"]
-    ],
     cache_examples=False,
     show_progress="full",
     run_examples_on_click=False
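For context, a minimal sketch of the two-pass chain-of-thought flow this commit adds, with the model calls stubbed out so it runs standalone. It only mirrors the message shapes used in app.py (the "🧠 Réflexion" metadata title and pending/done status, which Gradio can render as a collapsible thought bubble); the names fake_llm and generate_cot are illustrative and not part of the commit.

# Illustrative sketch only -- fake_llm and generate_cot are not in the commit.
from typing import Any, Dict, Iterator, List

def create_thinking_message(content: str, status: str = None) -> Dict[str, Any]:
    # Same message shape as in app.py: metadata title/status mark it as a "thought".
    return {"role": "assistant", "content": content,
            "metadata": {"title": "🧠 Réflexion", "status": status}}

def fake_llm(conversation: List[Dict[str, str]]) -> str:
    # Stand-in for generate_llm_response(), which streams from the real model.
    return f"(simulated reply to {len(conversation)} message(s))"

def generate_cot(message: str) -> Iterator[Dict[str, Any]]:
    # Pass 1: reasoning, surfaced as a metadata-tagged thinking message.
    yield create_thinking_message("Réflexion en cours...", status="pending")
    thinking = fake_llm([{"role": "user",
                          "content": message + "\n\nRéfléchis étape par étape."}])
    yield create_thinking_message(thinking, status="done")
    # Pass 2: final answer from a separate conversation, as a plain assistant message.
    yield {"role": "assistant", "content": fake_llm([{"role": "user", "content": message}])}

for update in generate_cot("Liste les classes de cette ontologie Turtle."):
    print(update.get("metadata", {}).get("title"), "->", update["content"])

In the actual generate() added by this commit, both passes go through generate_llm_response(), the thinking pass gets a doubled token budget, and the final pass reuses the chat history at temperature * 0.8.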