Spaces:
Runtime error
Runtime error
Commit
·
1839808
1
Parent(s):
b6eb5c9
Make `temperature` and `max_length` configurable request parameters
Browse files
main.py
CHANGED
@@ -35,18 +35,19 @@ token_size_limit = 512
|
|
35 |
app = FastAPI()
|
36 |
|
37 |
|
|
|
38 |
@app.post('/reply')
|
39 |
async def Reply(req: Request):
|
40 |
request = await req.json()
|
41 |
-
msg = request
|
42 |
print(f'MSG: {msg}')
|
43 |
|
44 |
input_ids = tokenizer(msg, return_tensors='pt').input_ids # .to('cuda')
|
45 |
output = model.generate(
|
46 |
input_ids[:, -token_size_limit:],
|
47 |
do_sample=True,
|
48 |
-
temperature=0.9,
|
49 |
-
max_length=100,
|
50 |
)
|
51 |
reply = tokenizer.batch_decode(output)[0]
|
52 |
print(f'REPLY: {reply}')
|
|
|
35 |
app = FastAPI()
|
36 |
|
37 |
|
38 |
+
# { msg: string, temperature: float, max_length: number }
|
39 |
@app.post('/reply')
|
40 |
async def Reply(req: Request):
|
41 |
request = await req.json()
|
42 |
+
msg = request.get('msg')
|
43 |
print(f'MSG: {msg}')
|
44 |
|
45 |
input_ids = tokenizer(msg, return_tensors='pt').input_ids # .to('cuda')
|
46 |
output = model.generate(
|
47 |
input_ids[:, -token_size_limit:],
|
48 |
do_sample=True,
|
49 |
+
temperature=request.get('temperature', 0.9),
|
50 |
+
max_length=request.get('max_length', 100),
|
51 |
)
|
52 |
reply = tokenizer.batch_decode(output)[0]
|
53 |
print(f'REPLY: {reply}')
|