MatteoScript commited on
Commit
a644f61
·
verified ·
1 Parent(s): 347ae61

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +33 -1
main.py CHANGED
@@ -11,6 +11,7 @@ import socket
11
  import time
12
  from enum import Enum
13
  import random
 
14
 
15
  #--------------------------------------------------- Definizione Server FAST API ------------------------------------------------------
16
  app = FastAPI()
@@ -35,7 +36,8 @@ class PostSpazio(BaseModel):
35
  nomeSpazio: str
36
  input: str = ''
37
  api_name: str = "/chat"
38
-
 
39
  #--------------------------------------------------- Generazione TESTO ------------------------------------------------------
40
  @app.post("/Genera")
41
  def read_root(request: Request, input_data: InputData):
@@ -74,6 +76,36 @@ def format_prompt(message, history):
74
  prompt += f"[{now}] [INST] {message} [/INST]"
75
  return prompt
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  #--------------------------------------------------- Generazione IMMAGINE ------------------------------------------------------
79
  style_image = {
 
11
  import time
12
  from enum import Enum
13
  import random
14
+ import asyncio
15
 
16
  #--------------------------------------------------- Definizione Server FAST API ------------------------------------------------------
17
  app = FastAPI()
 
36
  nomeSpazio: str
37
  input: str = ''
38
  api_name: str = "/chat"
39
+
40
+
41
  #--------------------------------------------------- Generazione TESTO ------------------------------------------------------
42
  @app.post("/Genera")
43
  def read_root(request: Request, input_data: InputData):
 
76
  prompt += f"[{now}] [INST] {message} [/INST]"
77
  return prompt
78
 
79
+ #--------------------------------------------------- Generazione TESTO Asincrono ------------------------------------------------------
80
+ @app.post("/GeneraAsincrono")
81
+ def read_root_async(request: Request, input_data: InputData):
82
+ input_text = input_data.input
83
+ temperature = input_data.temperature
84
+ max_new_tokens = input_data.max_new_tokens
85
+ top_p = input_data.top_p
86
+ repetition_penalty = input_data.repetition_penalty
87
+ history = []
88
+ async with aiohttp.ClientSession() as session:
89
+ tasks = [generate_async(input_data.input, history, input_data.temperature, input_data.max_new_tokens, input_data.top_p, input_data.repetition_penalty) for _ in range(10)]
90
+ responses = await asyncio.gather(*tasks)
91
+ return {"responses": responses}
92
+
93
+ async def generate_async(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
94
+ temperature = float(temperature)
95
+ if temperature < 1e-2:
96
+ temperature = 1e-2
97
+ top_p = float(top_p)
98
+ generate_kwargs = dict(
99
+ temperature=temperature,
100
+ max_new_tokens=max_new_tokens,
101
+ top_p=top_p,
102
+ repetition_penalty=repetition_penalty,
103
+ do_sample=True,
104
+ seed=42,
105
+ )
106
+ formatted_prompt = format_prompt(prompt, history)
107
+ output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
108
+ return output
109
 
110
  #--------------------------------------------------- Generazione IMMAGINE ------------------------------------------------------
111
  style_image = {