LaserOverrider commited on
Commit
52fe558
·
verified ·
1 Parent(s): fb9f031

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -3
README.md CHANGED
@@ -41,13 +41,12 @@ Use the code below to get started with the model.
41
  from transformers import AutoModelForCausalLM, AutoTokenizer
42
  import torch
43
 
44
- model = AutoModelForCausalLM.from_pretrained("DIAG-PSSeng/cicero_v2-phi1.5", torch_dtype=torch.float16).to("cuda")
45
 
46
  tokenizer = AutoTokenizer.from_pretrained("DIAG-PSSeng/cicero_v2-phi1.5", trust_remote_code=True)
47
 
48
  def generate_text(model, tokenizer, prompt, length=50, do_sample=True):
49
- device="cuda"
50
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
51
  gen_tokens = model.generate(**inputs,do_sample=True,temperature=0.9, min_length=length,max_length=length)
52
  generated_text = tokenizer.batch_decode(gen_tokens)
53
  return generated_text
 
41
  from transformers import AutoModelForCausalLM, AutoTokenizer
42
  import torch
43
 
44
+ model = AutoModelForCausalLM.from_pretrained("DIAG-PSSeng/cicero_v2-phi1.5", trust_remote_code=True ,torch_dtype=torch.float16).to("cuda")
45
 
46
  tokenizer = AutoTokenizer.from_pretrained("DIAG-PSSeng/cicero_v2-phi1.5", trust_remote_code=True)
47
 
48
  def generate_text(model, tokenizer, prompt, length=50, do_sample=True):
49
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
 
50
  gen_tokens = model.generate(**inputs,do_sample=True,temperature=0.9, min_length=length,max_length=length)
51
  generated_text = tokenizer.batch_decode(gen_tokens)
52
  return generated_text