quelmap commited on
Commit
8a1e016
·
1 Parent(s): 13a1d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -1,7 +1,26 @@
1
  import gradio
 
 
 
 
 
2
 
3
  def my_inference_function(name):
4
- return "Hello " + name + "!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  gradio_interface = gradio.Interface(
7
  fn = my_inference_function,
 
1
  import gradio
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-small", device_map="auto", torch_dtype=torch.float16)
6
+ tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-small")
7
 
8
  def my_inference_function(name):
9
+ mail="お世話になっております。商品開発部の森本です。先日は貴重な"
10
+ inputs = tokenizer(f"(上司へのメール敬語で){mail}", return_tensors="pt").to(model.device)
11
+ with torch.no_grad():
12
+ tokens = model.generate(
13
+ **inputs,
14
+ max_new_tokens=4,
15
+ do_sample=True,
16
+ temperature=0.7,
17
+ top_p=0.9,
18
+ repetition_penalty=1.05,
19
+ pad_token_id=tokenizer.pad_token_id,
20
+ )
21
+ output = tokenizer.decode(tokens[0], skip_special_tokens=True)
22
+ output=output.replace(f"(上司へのメール敬語で){mail}","")
23
+ return output.replace("\n", "")
24
 
25
  gradio_interface = gradio.Interface(
26
  fn = my_inference_function,