change model to llm write
Browse files- app.py +92 -54
- requirements.txt +10 -3
app.py
CHANGED
@@ -1,20 +1,29 @@
|
|
1 |
#load package
|
2 |
-
from fastapi import FastAPI
|
3 |
from pydantic import BaseModel
|
4 |
-
import uvicorn
|
5 |
-
import logging
|
6 |
import torch
|
7 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import os
|
9 |
-
|
|
|
|
|
|
|
10 |
|
11 |
# Configurer les répertoires de cache
|
12 |
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
|
13 |
os.environ['HF_HOME'] = '/app/.cache'
|
14 |
# Charger le modèle et le tokenizer
|
15 |
-
|
16 |
-
tokenizer =
|
17 |
-
|
18 |
|
19 |
#Additional information
|
20 |
|
@@ -34,57 +43,86 @@ app =FastAPI(
|
|
34 |
logging.basicConfig(level=logging.INFO)
|
35 |
logger =logging.getLogger(__name__)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
default_prompt = """Bonjour,
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
2) Rajoutes les informations relatives au Client pour être précis sur la connaissance de ce dernier.
|
50 |
-
3) Rajoutes des éléments de dates (remontée, transfert, prise en charge, résolution, clôture, etc…) ainsi que les délais (par exemple de réponse des différents acteurs ou experts de la chaine de traitement) pour mieux apprécier l'efficacité du traitement de la plainte.
|
51 |
-
4) Rajoutes à la fin une recommandation importante afin d'éviter le mécontentement du Client par exemple pour éviter qu’une Plainte ne soit clôturée sans solution pour le Client notamment et à titre illustratif seulement dans certains cas pour un Client qui a payé pour un service et ne l'a pas obtenu, On ne peut décemment pas clôturer sa plainte sans solution en lui disant d’être plus vigilant, il faut recommander à l’équipe en charge de la plainte de le rembourser ou de trouver un moyen de donner au Client le service pour lequel il a payé (à défaut de le rembourser).
|
52 |
-
5) N’hésites pas à innover sur le ton à utiliser car n’oublies pas que tu dois faire comme si tu parlais à un humain. Ce ton peut être adapté et ne pas toujours être le même en fonction des cas.
|
53 |
"""
|
54 |
-
class
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
if __name__ == "__main__":
|
86 |
uvicorn.run("app:app",reload=True)
|
87 |
|
88 |
|
89 |
-
|
90 |
|
|
|
1 |
#load package
|
2 |
+
from fastapi import FastAPI
|
3 |
from pydantic import BaseModel
|
|
|
|
|
4 |
import torch
|
5 |
+
from transformers import (
|
6 |
+
AutoModelForCausalLM,
|
7 |
+
AutoTokenizer,
|
8 |
+
StoppingCriteria,
|
9 |
+
StoppingCriteriaList,
|
10 |
+
TextIteratorStreamer
|
11 |
+
)
|
12 |
+
from typing import List, Tuple
|
13 |
+
from threading import Thread
|
14 |
import os
|
15 |
+
from pydantic import BaseModel
|
16 |
+
import logging
|
17 |
+
import uvicorn
|
18 |
+
|
19 |
|
20 |
# Configurer les répertoires de cache
|
21 |
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
|
22 |
os.environ['HF_HOME'] = '/app/.cache'
|
23 |
# Charger le modèle et le tokenizer
|
24 |
+
model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)
|
26 |
+
|
27 |
|
28 |
#Additional information
|
29 |
|
|
|
43 |
logging.basicConfig(level=logging.INFO)
|
44 |
logger =logging.getLogger(__name__)
|
45 |
|
46 |
+
class StopOnTokens(StoppingCriteria):
|
47 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
48 |
+
stop_ids = model.config.eos_token_id
|
49 |
+
for stop_id in stop_ids:
|
50 |
+
if input_ids[0][-1] == stop_id:
|
51 |
+
return True
|
52 |
+
return False
|
53 |
|
54 |
default_prompt = """Bonjour,
|
55 |
+
|
56 |
+
En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
|
57 |
+
1. **Informations Client** : Indique des détails pertinents sur le client.
|
58 |
+
2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.).
|
59 |
+
3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints).
|
60 |
+
|
61 |
+
Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant.
|
62 |
+
|
63 |
+
Merci !
|
64 |
+
|
|
|
|
|
|
|
|
|
65 |
"""
|
66 |
+
class PredictionRequest(BaseModel):
|
67 |
+
history: List[Tuple[str, str]] = []
|
68 |
+
prompt: str = default_prompt
|
69 |
+
max_length: int = 128000
|
70 |
+
top_p: float = 0.8
|
71 |
+
temperature: float = 0.6
|
72 |
+
@app.post("/generate/")
|
73 |
+
async def predict(request: PredictionRequest):
|
74 |
+
history = request.history
|
75 |
+
prompt = request.prompt
|
76 |
+
max_length = request.max_length
|
77 |
+
top_p = request.top_p
|
78 |
+
temperature = request.temperature
|
79 |
+
|
80 |
+
stop = StopOnTokens()
|
81 |
+
messages = []
|
82 |
+
if prompt:
|
83 |
+
messages.append({"role": "system", "content": prompt})
|
84 |
+
for idx, (user_msg, model_msg) in enumerate(history):
|
85 |
+
if prompt and idx == 0:
|
86 |
+
continue
|
87 |
+
if idx == len(history) - 1 and not model_msg:
|
88 |
+
query = user_msg
|
89 |
+
break
|
90 |
+
if user_msg:
|
91 |
+
messages.append({"role": "user", "content": user_msg})
|
92 |
+
if model_msg:
|
93 |
+
messages.append({"role": "assistant", "content": model_msg})
|
94 |
+
|
95 |
+
model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
|
96 |
+
next(model.parameters()).device)
|
97 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
|
98 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
99 |
+
tokenizer.get_command("<|observation|>")]
|
100 |
+
generate_kwargs = {
|
101 |
+
"input_ids": model_inputs,
|
102 |
+
"streamer": streamer,
|
103 |
+
"max_new_tokens": max_length,
|
104 |
+
"do_sample": True,
|
105 |
+
"top_p": top_p,
|
106 |
+
"temperature": temperature,
|
107 |
+
"stopping_criteria": StoppingCriteriaList([stop]),
|
108 |
+
"repetition_penalty": 1,
|
109 |
+
"eos_token_id": eos_token_id,
|
110 |
+
}
|
111 |
+
|
112 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
113 |
+
t.start()
|
114 |
+
|
115 |
+
generated_text = ""
|
116 |
+
for new_token in streamer:
|
117 |
+
if new_token and '<|user|>' in new_token:
|
118 |
+
new_token = new_token.split('<|user|>')[0]
|
119 |
+
if new_token:
|
120 |
+
generated_text += new_token
|
121 |
+
history[-1][1] = generated_text
|
122 |
+
|
123 |
+
return {"history": history}
|
124 |
if __name__ == "__main__":
|
125 |
uvicorn.run("app:app",reload=True)
|
126 |
|
127 |
|
|
|
128 |
|
requirements.txt
CHANGED
@@ -1,13 +1,20 @@
|
|
1 |
fastapi==0.111.0
|
2 |
-
torch==2.
|
3 |
-
transformers==4.44.
|
4 |
uvicorn==0.30.1
|
5 |
pydantic==2.7.4
|
6 |
pillow==10.3.0
|
7 |
numpy
|
8 |
scipy==1.11.3
|
9 |
-
sentencepiece==0.2.0
|
10 |
pytesseract==0.3.10
|
11 |
Pillow==10.3.0
|
12 |
BeautifulSoup4==4.12.3
|
13 |
protobuf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
fastapi==0.111.0
|
2 |
+
torch==2.2.0
|
3 |
+
transformers==4.44.0
|
4 |
uvicorn==0.30.1
|
5 |
pydantic==2.7.4
|
6 |
pillow==10.3.0
|
7 |
numpy
|
8 |
scipy==1.11.3
|
|
|
9 |
pytesseract==0.3.10
|
10 |
Pillow==10.3.0
|
11 |
BeautifulSoup4==4.12.3
|
12 |
protobuf
|
13 |
+
spaces==0.29.2
|
14 |
+
accelerate==0.33.0
|
15 |
+
sentencepiece==0.2.0
|
16 |
+
huggingface-hub==0.24.5
|
17 |
+
jinja2==3.1.4
|
18 |
+
sentence_transformers==3.0.1
|
19 |
+
tiktoken==0.7.0
|
20 |
+
einops==0.8.0
|