Chris4K committed
Commit 2c2fb5c · verified · 1 Parent(s): eceeded

Update text_generator.py

Files changed (1)
  1. text_generator.py +106 -58
text_generator.py CHANGED
@@ -1,68 +1,116 @@
-import requests
 import os
-from transformers import pipeline
-
-
-from transformers import Tool
-# Import other necessary libraries if needed
+import requests
+import gradio as gr
+from transformers import pipeline, Tool
 
 class TextGenerationTool(Tool):
     name = "text_generator"
-    description = (
-        "This is a tool for text generation. It takes a prompt as input and returns the generated text."
-    )
-
+    description = "This is a tool for text generation. It takes a prompt as input and returns the generated text."
+
     inputs = ["text"]
     outputs = ["text"]
-
+
+    # Available text generation models
+    models = {
+        "orca": "microsoft/Orca-2-13b",
+        "gpt2-dolly": "lgaalves/gpt2-dolly",
+        "gpt2": "gpt2",
+        "bloom": "bigscience/bloom-560m",
+        "openchat": "openchat/openchat_3.5"
+    }
+
+    def __init__(self, default_model="orca", use_api=False):
+        """Initialize with a default model and API preference."""
+        self.default_model = default_model
+        self.use_api = use_api
+        self._pipelines = {}
+
+        # Check for API token
+        self.token = os.environ.get('HF_token')
+        if self.token is None and use_api:
+            print("Warning: HF_token environment variable not set. API calls will fail.")
+
     def __call__(self, prompt: str):
-        #API_URL = "https://api-inference.huggingface.co/models/openchat/openchat_3.5"
-        #headers = {"Authorization": "Bearer " + os.environ['hf']}
-        token=os.environ['HF_token']
-        #payload = {
-        #    "inputs": prompt # Adjust this based on your model's input format
-        #}
-
-        #payload = {
-        #    "inputs": "Can you please let us know more details about your ",
-        #    }
+        """Process the input prompt and generate text."""
+        return self.generate_text(prompt)
+
+    def generate_text(self, prompt, model_key=None, max_length=500, temperature=0.7):
+        """Generate text based on the prompt using the specified or default model."""
+        # Determine which model to use
+        model_key = model_key or self.default_model
+        model_name = self.models.get(model_key, self.models[self.default_model])
 
-        #def query(payload):
-        #    generated_text = requests.post(API_URL, headers=headers, json=payload).json()
-        #    print(generated_text)
-        #    return generated_text["text"]
-
-        # Replace the following line with your text generation logic
-        #generated_text = f"Generated text based on the prompt: '{prompt}'"
-
-        # Initialize the text generation pipeline
-        #text_generator = pipeline(model="lgaalves/gpt2-dolly", token=token)
-        text_generator = pipeline(model="microsoft/Orca-2-13b", token=token)
-
-        # Generate text based on a prompt
-        generated_text = text_generator(prompt, max_length=500, num_return_sequences=1, temperature=0.7)
-
-        # Print the generated text
-        print(generated_text)
-
-
-
-        return generated_text
+        # Generate using API if specified
+        if self.use_api and model_key == "openchat":
+            return self._generate_via_api(prompt, model_name)
+
+        # Otherwise use local pipeline
+        return self._generate_via_pipeline(prompt, model_name, max_length, temperature)
+
+    def _generate_via_pipeline(self, prompt, model_name, max_length, temperature):
+        """Generate text using a local pipeline."""
+        # Get or create the pipeline
+        if model_name not in self._pipelines:
+            self._pipelines[model_name] = pipeline(
+                "text-generation",
+                model=model_name,
+                token=self.token
+            )
+
+        generator = self._pipelines[model_name]
+
+        # Generate text
+        result = generator(
+            prompt,
+            max_length=max_length,
+            num_return_sequences=1,
+            temperature=temperature
+        )
 
-        # Define the payload for the request
-        #payload = {
-        #    "inputs": prompt # Adjust this based on your model's input format
-        #}
-
-        # Make the request to the API
-        #generated_text = requests.post(API_URL, headers=headers, json=payload).json()
-
         # Extract and return the generated text
-        #return generated_text["generated_text"]
-
-        # Uncomment and customize the following lines based on your text generation needs
-        # text_generator = pipeline(model="gpt2")
-        # generated_text = text_generator(prompt, max_length=500, num_return_sequences=1, temperature=0.7)
-
-        # Print the generated text if needed
-        # print(generated_text)
+        if isinstance(result, list) and len(result) > 0:
+            if isinstance(result[0], dict) and 'generated_text' in result[0]:
+                return result[0]['generated_text']
+            return result[0]
+
+        return str(result)
+
+    def _generate_via_api(self, prompt, model_name):
+        """Generate text by calling the Hugging Face API."""
+        if not self.token:
+            return "Error: HF_token not set. Cannot use API."
+
+        api_url = f"https://api-inference.huggingface.co/models/{model_name}"
+        headers = {"Authorization": f"Bearer {self.token}"}
+        payload = {"inputs": prompt}
+
+        try:
+            response = requests.post(api_url, headers=headers, json=payload)
+            response.raise_for_status()  # Raise exception for HTTP errors
+
+            result = response.json()
+
+            # Handle different response formats
+            if isinstance(result, list) and len(result) > 0:
+                if isinstance(result[0], dict) and 'generated_text' in result[0]:
+                    return result[0]['generated_text']
+            elif isinstance(result, dict) and 'generated_text' in result:
+                return result['generated_text']
+
+            # Fall back to returning the raw response
+            return str(result)
+
+        except Exception as e:
+            return f"Error generating text: {str(e)}"
+
+# For standalone testing
+if __name__ == "__main__":
+    # Create an instance of the TextGenerationTool
+    text_generator = TextGenerationTool(default_model="gpt2")
+
+    # Test with a simple prompt
+    test_prompt = "Once upon a time in a digital world,"
+    result = text_generator(test_prompt)
+
+    print(f"Prompt: {test_prompt}")
+    print(f"Generated text:\n{result}")