"""Streaming smoke test for the simplified Phi-2 model.

Loads the reference ``microsoft/phi-2`` weights into the local
``PhiForCausalLM`` implementation, streams a generation through
``TextIteratorStreamer`` on a background thread, and compares the streamed
text against a stored reference output.
"""

import json
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch

from .configuration_phi import PhiConfig
from .modeling_phi import PhiForCausalLM

# ---------------------------------------------------------------------------
# Earlier attempt 1 (kept for reference, disabled):
# This works, but is not streaming
#
# if __name__ == "__main__":
#     device = "cuda"
#
#     model_config = PhiConfig(**json.load(open("simplified_phi2/config.json")))
#     model = PhiForCausalLM(model_config).to(device)
#     phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
#     model.load_state_dict(phi_model.state_dict())
#
#     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
#
#     text = "Write an essay on sea monkeys: "
#     tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False).to(device)
#     outputs = model.generate(**tokens, max_length=200)
#     text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
#     print(text)
#
# ---------------------------------------------------------------------------
# Earlier attempt 2 (kept for reference, disabled):
# This is streaming, but does not work because you can't set trust_remote_code=True
#
# if __name__ == "__main__":
#     client = InferenceClient(model="microsoft/phi-2")
#     text = "How do you make cheese?"
#     for token in client.text_generation(text, max_new_tokens=500, stream=True):
#         print(token, end="")
# ---------------------------------------------------------------------------

MODEL_NAME = "microsoft/phi-2"
CONFIG_PATH = "simplified_phi2/config.json"
DEVICE = "cuda"
PROMPT = "Here is an essay on sea monkeys: "
MAX_NEW_TOKENS = 500

# Reference text the streamed output is compared against. It starts with the
# prompt and ends with the end-of-text marker.
# NOTE(review): the whitespace inside this literal is reproduced exactly as it
# appears in the file; confirm it matches the model's real output (paragraph
# breaks in particular) before relying on the comparison.
EXPECTED_OUTPUT = """Here is an essay on sea monkeys: Sea monkeys are a type of brine shrimp that are often sold as pets in kits. They are easy to care for and can be a fun addition to any aquarium. However, it is important to understand the proper care and feeding of sea monkeys to ensure their health and longevity. Sea monkeys are sensitive to changes in temperature, pH, and salinity. It is important to keep their environment at a consistent temperature of around 75-80 degrees Fahrenheit. The pH level should be between 7.5 and 8.5, and the salinity should be around 1.023. These conditions can be achieved by using a saltwater mix and a thermometer. Feeding sea monkeys is also important. They can be fed a variety of foods, including flakes, pellets, and freeze-dried bloodworms. 
It is important to feed them regularly, but not too much. Overfeeding can lead to health problems and a shorter lifespan. In addition to proper care and feeding, it is important to understand the different types of sea monkeys. There are several different species, including the brine shrimp, the fairy shrimp, and the tadpole shrimp. Each species has its own unique characteristics and requirements. Sea monkeys are also known for their ability to reproduce quickly. They can lay hundreds of eggs at a time, which can lead to a large population in a short amount of time. However, it is important to keep their environment clean and free of debris to prevent overcrowding and disease. In conclusion, sea monkeys are a fascinating and easy-to-care-for pet. However, it is important to understand their unique requirements and to provide them with proper care and feeding. By doing so, you can ensure their health and longevity and enjoy watching them grow and thrive in their aquarium.<|endoftext|>"""


def _load_model(device: str) -> PhiForCausalLM:
    """Build the local Phi model on *device* and copy in the reference weights."""
    # Use a context manager so the config file handle is closed promptly.
    with open(CONFIG_PATH, encoding="utf-8") as config_file:
        model_config = PhiConfig(**json.load(config_file))
    model = PhiForCausalLM(model_config).to(device)

    # The reference checkpoint is only needed as a source of weights.
    reference_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model.load_state_dict(reference_model.state_dict())
    return model


def main() -> None:
    """Stream a generation from the local Phi model and check it against the reference.

    Raises:
        AssertionError: if the streamed text differs from ``EXPECTED_OUTPUT``.
    """
    # make and load tokenizer, use tokenizer to initialize token_streamer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)

    # make model and run model.generate(streamer=token_streamer) on a thread
    model = _load_model(DEVICE)

    model_inputs = tokenizer(
        PROMPT,
        return_tensors="pt",
        return_attention_mask=False,
    ).to(DEVICE)
    # The tokenizer output is mapping-like, so dict() merges its entries with
    # the extra generation keyword arguments.
    generation_kwargs = dict(
        model_inputs,
        streamer=token_streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        eos_token_id=tokenizer.eos_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # generate: consume text chunks from the streamer as the worker produces them
    chunks = []
    for new_text in token_streamer:
        chunks.append(new_text)
        # Flush on every chunk on purpose: the point is to see text appear live.
        print(new_text, end="", flush=True)
    print()

    # Wait for the worker to finish so it is not left running behind the check.
    thread.join()
    my_output = "".join(chunks)

    # check output
    # Raise explicitly instead of using `assert`, which is removed under
    # `python -O` and would let "Test passed" print unconditionally.
    if my_output != EXPECTED_OUTPUT:
        raise AssertionError("streamed output does not match the expected reference text")
    print("Test passed")


# This is trying the TextIteratorStreamer class
if __name__ == "__main__":
    main()