"""Gradio demo that answers free-text questions with the KatzBot Phi-2 fine-tune."""

from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
import torch
import gradio as gr

# Decide the device BEFORE making it the default, so the app still starts on
# machines without CUDA (unconditionally defaulting to "cuda" breaks there).
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)

# Hugging Face Hub repository holding both the tokenizer and the model weights.
MODEL_ID = "deepapaikar/katzbot-phi2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the weights once at start-up. Previously `model` was only the repo-id
# string, so every `model.*` call in predict_answer failed. Half precision is
# used only on GPU; float16 inference on CPU is slow and not supported by
# every operator.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)
model.eval()  # inference only: disables dropout


def predict_answer(question, token=25):
    """Generate a sampled answer to *question*.

    Args:
        question: The user's question text. It is wrapped as a single "user"
            chat message.
        token: Maximum number of NEW tokens to generate. It may arrive as a
            float from the Gradio slider and is converted with ``int``.

    Returns:
        The decoded output sequence, with special tokens stripped. As before,
        it contains the templated prompt followed by the generated answer.
    """
    messages = [{"role": "user", "content": f"{question}"}]

    # Render the conversation with the tokenizer's chat template. This assumes
    # the fine-tuned tokenizer ships one -- TODO confirm for this checkpoint.
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # A single prompt needs no padding. Requesting it raises if the tokenizer
    # has no pad token, which is the case for the Phi-2 base tokenizer.
    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Fall back to EOS for padding so generate() does not need a pad token.
    pad_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    output_sequences = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # Count only generated tokens instead of prompt length + budget.
        max_new_tokens=int(token),
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=pad_id,
    )

    return tokenizer.decode(output_sequences[0], skip_special_tokens=True)


def gradio_predict(question, token):
    """Gradio callback: forward the UI inputs to predict_answer."""
    answer = predict_answer(question, token)
    return answer


# Define the Gradio interface.
iface = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(
            label="Question",
            placeholder="e.g. Where is Yeshiva University located?",
            scale=4,
        ),
        # step=1 keeps the value integral, since it is a token count.
        gr.Slider(
            2,
            100,
            value=25,
            step=1,
            label="Token Count",
            info="Choose between 2 and 100",
        ),
    ],
    outputs=gr.TextArea(label="Answer"),
    title="KatzBot",
    description="Phi2-trial1",
)

if __name__ == "__main__":
    # Launch the app with request queuing enabled.
    iface.queue().launch(debug=True)