import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)  # a StreamHandler alone won't emit INFO records without this


class LocalModel:
    def __init__(self, model_name: str, max_tokens: int, temperature: float):
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Load the model locally. For a demo, you may choose a lighter model if needed.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, prompt: str, **kwargs) -> str:
        # Adjust the call signature as needed by your agent.
        result = self.pipeline(
            prompt,
            max_new_tokens=self.max_tokens,
            do_sample=True,  # without this, transformers ignores `temperature` and warns
            temperature=self.temperature,
            **kwargs,
        )
        # The pipeline returns a list with one dict; note that
        # "generated_text" includes the original prompt by default.
        output = result[0]["generated_text"]
        logger.info(f"Model output: {output}")
        return output


if __name__ == "__main__":
    local_model = LocalModel("Qwen/Qwen2.5-1.5B", max_tokens=100, temperature=0.5)
    output = local_model("A big foot")
    print(output)
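
As written, the float16 weights are loaded onto the CPU, which is slow for generation. If a GPU is available, a minimal variation is to let transformers place the weights automatically. This is a sketch, assuming the accelerate package is installed (which device_map="auto" requires); the model name and prompt are carried over from the snippet above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "Qwen/Qwen2.5-1.5B"

# device_map="auto" (requires the `accelerate` package) places the fp16
# weights on available GPUs, falling back to CPU if none is found.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

print(generator("A big foot", max_new_tokens=50)[0]["generated_text"])

The rest of the LocalModel class is unchanged by this; only the from_pretrained call differs.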