ybelkada commited on
Commit
c6df388
·
1 Parent(s): 1f0c4d6
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -1,7 +1,30 @@
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import transformers
4
  import gradio as gr
5
 
6
+ from src.client import DistributedBloomForCausalLM
 
7
 
8
# Bootstrap multiaddr(s) used to join the distributed swarm serving the model
# layers — NOTE(review): hard-coded test peer; confirm it is still reachable.
INITIAL_PEERS = ['/ip4/193.106.95.184/tcp/31337/p2p/QmUigSxrVz9x5FR9ZYr4iRfEX2vDxihL2YZtDd7sp2eKnM']
# The tokenizer is loaded locally from the Hub.
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
# Distributed client model: presumably only the embeddings/head live locally while
# transformer blocks run on swarm peers (see `inference_session` usage below) —
# TODO confirm against src.client.  low_cpu_mem_usage reduces peak RAM during loading.
model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
11
+
12
def inference(text, seq_length=1):
    """Autoregressively sample ``seq_length`` new tokens from the distributed BLOOM model.

    Args:
        text: Prompt string; tokenized and fed to the model on the first step.
        seq_length: Number of new tokens to sample (default 1).

    Returns:
        The decoded text of all newly sampled tokens.  (Fix: the previous
        version returned only the LAST sampled token, silently discarding
        earlier samples whenever ``seq_length > 1``.)
    """
    input_ids = tokenizer(text, return_tensors='pt')['input_ids']
    generated_ids = []
    with torch.inference_mode(), model.transformer.h.inference_session() as remote_transformer:
        for _ in range(seq_length):
            # Embed locally whatever is fed this step: the full prompt on the
            # first step, a single new token afterwards.
            h = model.transformer.word_embeddings(input_ids)
            h = model.transformer.word_embeddings_layernorm(h)

            # Remote transformer blocks; the session keeps state between steps.
            # note [yozh]: this line currently freezes for ~10 seconds the first
            # time only; a fix is planned in an upcoming PR.
            h = remote_transformer.step(h)

            h = model.transformer.ln_f(h)
            # Tied input embeddings double as the LM head.
            # note: this line takes a while, will also be fixed.
            logits = F.linear(h, weight=model.transformer.word_embeddings.weight)
            # Sample the next token from the last position at temperature 0.8.
            next_token_ix = torch.multinomial((logits[0, -1] / 0.8).softmax(-1), 1)

            generated_ids.append(next_token_ix.item())
            # The remote session is stateful, so only the new token is fed next.
            input_ids = next_token_ix.view(1, 1)
    return tokenizer.decode(generated_ids)
28
+
29
# Expose the sampler as a simple text-in/text-out Gradio demo and serve it
# (blocks here until the server is stopped).
iface = gr.Interface(fn=inference, inputs="text", outputs="text")
iface.launch()