sillon commited on
Commit
01ad8f8
1 Parent(s): bfbbbf7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -37
README.md CHANGED
@@ -1,38 +1,22 @@
1
- ```
2
- Copied
3
- import json
4
- import requests
5
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
6
- API_URL = "https://api-inference.huggingface.co/models/sillon/DialoGPT-small-HospitalBot"
7
- def query(payload):
8
- data = json.dumps(payload)
9
- response = requests.request("POST", API_URL, headers=headers, data=data)
10
- return json.loads(response.content.decode("utf-8"))
11
- data = query(
12
- {
13
- "inputs": {
14
- "past_user_inputs": ["Which movie is the best ?"],
15
- "generated_responses": ["It's Die Hard for sure."],
16
- "text": "Can you explain why ?",
17
- },
18
- }
19
- )
20
- # Response
21
- self.assertEqual(
22
- data,
23
- {
24
- "generated_text": "It's the best movie ever.",
25
- "conversation": {
26
- "past_user_inputs": [
27
- "Which movie is the best ?",
28
- "Can you explain why ?",
29
- ],
30
- "generated_responses": [
31
- "It's Die Hard for sure.",
32
- "It's the best movie ever.",
33
- ],
34
- },
35
- "warnings": ["Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation."],
36
- },
37
- )
38
  ```
 
1
+ ```python
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained("sillon/DialoGPT-small-HospitalBot")
7
+ model = AutoModelForCausalLM.from_pretrained("sillon/DialoGPT-small-HospitalBot")
8
+
9
+ # Let's chat for 5 lines
10
+ for step in range(5):
11
+ # encode the new user input, add the eos_token and return a tensor in Pytorch
12
+ new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
13
+
14
+ # append the new user input tokens to the chat history
15
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
16
+
17
+ # generated a response while limiting the total chat history to 1000 tokens,
18
+ chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
19
+
20
+ # pretty print last ouput tokens from bot
21
+ print("HospitalBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ```