Chris4K commited on
Commit
ba1082d
1 Parent(s): efe65ac

Update text_generator.py

Browse files
Files changed (1) hide show
  1. text_generator.py +18 -9
text_generator.py CHANGED
@@ -1,5 +1,5 @@
1
- import os
2
  import requests
 
3
 
4
  from transformers import Tool
5
  # Import other necessary libraries if needed
@@ -14,19 +14,28 @@ class TextGenerationTool(Tool):
14
  outputs = ["text"]
15
 
16
  def __call__(self, prompt: str):
17
- API_URL = "https://api-inference.huggingface.co/models/gpt2"
18
- headers = {"Authorization": "Bearer " + os.environ['hf']}
19
-
 
 
 
 
 
 
 
 
 
20
  # Define the payload for the request
21
- payload = {
22
- "inputs": prompt # Adjust this based on your model's input format
23
- }
24
 
25
  # Make the request to the API
26
- generated_text = requests.post(API_URL, headers=headers, json=payload).json()
27
 
28
  # Extract and return the generated text
29
- return generated_text["generated_text"]
30
 
31
  # Uncomment and customize the following lines based on your text generation needs
32
  # text_generator = pipeline(model="gpt2")
 
 
1
  import requests
2
+ import os
3
 
4
  from transformers import Tool
5
  # Import other necessary libraries if needed
 
14
  outputs = ["text"]
15
 
16
  def __call__(self, prompt: str):
17
+ API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
18
+ headers = {"Authorization": + os.environ['hf']}
19
+
20
+ def query(payload):
21
+ response = requests.post(API_URL, headers=headers, json=payload)
22
+ print(response)
23
+ return response.json()
24
+
25
+ output = query({
26
+ "inputs": prompt,
27
+ })
28
+
29
  # Define the payload for the request
30
+ #payload = {
31
+ # "inputs": prompt # Adjust this based on your model's input format
32
+ #}
33
 
34
  # Make the request to the API
35
+ #generated_text = requests.post(API_URL, headers=headers, json=payload).json()
36
 
37
  # Extract and return the generated text
38
+ #return generated_text["generated_text"]
39
 
40
  # Uncomment and customize the following lines based on your text generation needs
41
  # text_generator = pipeline(model="gpt2")