RyanMullins committed on
Commit
2a04008
β€’
1 Parent(s): ff97c38

Adding generation with GPT-2 as a mock model

Browse files
Files changed (2) hide show
  1. app.py +80 -6
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,23 +1,94 @@
1
  from collections.abc import Sequence
2
  import random
 
3
 
4
  import gradio as gr
 
 
 
5
 
6
  # If the watermark is not detected, consider the use case. Could be because of
7
  # the nature of the task (e.g., factual responses are lower entropy) or it could
8
  # be another
9
 
10
- _GEMMA_2B = 'google/gemma-2b'
11
 
12
  _PROMPTS: tuple[str] = (
13
  'prompt 1',
14
  'prompt 2',
15
  'prompt 3',
16
- 'prompt 4',
17
  )
18
 
19
  _CORRECT_ANSWERS: dict[str, bool] = {}
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  with gr.Blocks() as demo:
22
  prompt_inputs = [
23
  gr.Textbox(value=prompt, lines=4, label='Prompt')
@@ -43,14 +114,17 @@ with gr.Blocks() as demo:
43
  detect_btn = gr.Button('Detect', visible=False)
44
 
45
  def generate(*prompts):
46
- standard = [f'{prompt} response' for prompt in prompts]
47
- watermarked = [f'{prompt} watermarked response' for prompt in prompts]
 
 
 
48
  responses = standard + watermarked
49
  random.shuffle(responses)
50
 
51
  _CORRECT_ANSWERS.update({
52
- response: response in watermarked
53
- for response in responses
54
  })
55
 
56
  # Load model
 
1
  from collections.abc import Sequence
2
  import random
3
+ from typing import Optional
4
 
5
  import gradio as gr
6
+ import spaces
7
+ import torch
8
+ import transformers
9
 
10
  # If the watermark is not detected, consider the use case. Could be because of
11
  # the nature of the task (e.g., factual responses are lower entropy) or it could
12
  # be another
13
 
14
+ _MODEL_IDENTIFIER = 'hf-internal-testing/tiny-random-gpt2'
15
 
16
  _PROMPTS: tuple[str] = (
17
  'prompt 1',
18
  'prompt 2',
19
  'prompt 3',
 
20
  )
21
 
22
  _CORRECT_ANSWERS: dict[str, bool] = {}
23
 
24
+ _TORCH_DEVICE = (
25
+ torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
26
+ )
27
+
28
+ _WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
29
+ ngram_len=5,
30
+ keys=[
31
+ 654,
32
+ 400,
33
+ 836,
34
+ 123,
35
+ 340,
36
+ 443,
37
+ 597,
38
+ 160,
39
+ 57,
40
+ 29,
41
+ 590,
42
+ 639,
43
+ 13,
44
+ 715,
45
+ 468,
46
+ 990,
47
+ 966,
48
+ 226,
49
+ 324,
50
+ 585,
51
+ 118,
52
+ 504,
53
+ 421,
54
+ 521,
55
+ 129,
56
+ 669,
57
+ 732,
58
+ 225,
59
+ 90,
60
+ 960,
61
+ ],
62
+ sampling_table_size=2**16,
63
+ sampling_table_seed=0,
64
+ context_history_size=1024,
65
+ )
66
+
67
+
68
+ tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER)
69
+ tokenizer.pad_token_id = tokenizer.eos_token_id
70
+ model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
71
+ model.to(_TORCH_DEVICE)
72
+
73
+
74
+ @spaces.GPU
75
+ def generate_outputs(
76
+ prompts: Sequence[str],
77
+ watermarking_config: Optional[
78
+ transformers.generation.SynthIDTextWatermarkingConfig
79
+ ] = None,
80
+ ) -> Sequence[str]:
81
+ tokenized_prompts = tokenizer(prompts, return_tensors='pt').to(_TORCH_DEVICE)
82
+ output_sequences = model.generate(
83
+ **tokenized_prompts,
84
+ watermarking_config=watermarking_config,
85
+ do_sample=True,
86
+ max_length=500,
87
+ top_k=40,
88
+ )
89
+ return tokenizer.batch_decode(output_sequences)
90
+
91
+
92
  with gr.Blocks() as demo:
93
  prompt_inputs = [
94
  gr.Textbox(value=prompt, lines=4, label='Prompt')
 
114
  detect_btn = gr.Button('Detect', visible=False)
115
 
116
  def generate(*prompts):
117
+ standard = generate_outputs(prompts=prompts)
118
+ watermarked = generate_outputs(
119
+ prompts=prompts,
120
+ watermarking_config=_WATERMARK_CONFIG,
121
+ )
122
  responses = standard + watermarked
123
  random.shuffle(responses)
124
 
125
  _CORRECT_ANSWERS.update({
126
+ response: response in watermarked
127
+ for response in responses
128
  })
129
 
130
  # Load model
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  gradio
2
  spaces
3
  transformers @ git+https://github.com/sumedhghaisas2/transformers_private@synthid_text
 
 
 
 
1
  gradio
2
  spaces
3
  transformers @ git+https://github.com/sumedhghaisas2/transformers_private@synthid_text
4
+
5
+ --extra-index-url https://download.pytorch.org/whl/cu113
6
+ torch