huckiyang committed
Commit a024afa · 1 Parent(s): 941b6eb

[test] demo

Files changed (3):
  1. app.py +202 -142
  2. requirements.txt +5 -4
  3. translation_model.py +158 -0
app.py CHANGED
@@ -1,154 +1,214 @@
 import gradio as gr
-import numpy as np
-import random
-
-# import spaces  # [uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU  # [uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
 ]

-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
             )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
             )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
             )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
             with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-    gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
         ],
-        outputs=[result, seed],
     )
-
 if __name__ == "__main__":
     demo.launch()
 
 import gradio as gr
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
+
+# Set device and dtype
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch_dtype = torch.bfloat16
+
+# Load models only once at startup
+print("Loading models...")
+model_id = "meta-llama/Meta-Llama-3.1-8B"  # Replace with your actual model ID
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+lm_model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch_dtype,
+    device_map="auto",
+)
+
+# Load the reward model
+RM = AutoModelForCausalLMWithValueHead.from_pretrained(
+    "ray24724919/plan2align_rm",
+    torch_dtype=torch_dtype,
+    device_map="auto",
+)
+RM.eval()
+print("Models loaded successfully!")
+
+# Self-contained translation and evaluation functions
+def translate(source_text, target_language="English"):
+    """
+    Translate text from Chinese to the specified target language.
+
+    Args:
+        source_text (str): The Chinese text to translate
+        target_language (str): The target language for translation
+
+    Returns:
+        str: The translated text
+    """
+    # Format the input as per the system prompt
+    messages = [
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
+    ]
+
+    # Format messages for the model (requires a tokenizer that ships a chat
+    # template; base checkpoints may not, unlike -Instruct variants)
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    # Tokenize the input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Generate translation
+    with torch.no_grad():
+        outputs = lm_model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+
+    # Decode only the newly generated tokens
+    translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
+    return translation
+
+def evaluate_translation(source_text, translation, target_language="English"):
+    """
+    Evaluate the quality of a translation using the reward model.
+
+    Args:
+        source_text (str): The original Chinese text
+        translation (str): The translated text
+        target_language (str): The target language of the translation
+
+    Returns:
+        float: The reward score
+    """
+    messages = [
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
+        {"role": "assistant", "content": translation},
+    ]
+
+    # Format messages for the reward model
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+
+    # Tokenize the input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Get reward score; the value-head model's forward returns a
+    # (lm_logits, loss, values) tuple, so use the final token's value
+    # as the sequence-level score
+    with torch.no_grad():
+        _, _, values = RM(input_ids=inputs.input_ids)
+    reward_score = values[0, -1].item()
+
+    return reward_score
+
+# Combined function for the Gradio interface
+def translate_text(source_text, target_language):
+    """
+    Translate text and get the reward score.
+
+    Args:
+        source_text (str): The Chinese text to translate
+        target_language (str): The target language for translation
+
+    Returns:
+        tuple: (translation, reward_score)
+    """
+    if not source_text.strip():
+        return "Please enter some text to translate.", 0.0
+
+    try:
+        translation = translate(source_text, target_language)
+        reward_score = evaluate_translation(source_text, translation, target_language)
+        return translation, float(reward_score)
+    except Exception as e:
+        return f"Error: {str(e)}", 0.0
+
+# Define available target languages
+target_languages = [
+    "English", "French", "Spanish", "German", "Italian",
+    "Portuguese", "Russian", "Japanese", "Korean", "Arabic",
 ]

+# Create the Gradio interface
+with gr.Blocks(title="Chinese Translation with Reward Scoring") as demo:
+    gr.Markdown("# Chinese to Any Language Translation")
+    gr.Markdown("This demo translates Chinese text to your chosen language and provides a quality score from our reward model.")
+
+    with gr.Row():
+        with gr.Column():
+            source_text = gr.Textbox(
+                label="Chinese Text",
+                placeholder="Enter Chinese text here...",
+                lines=5,
             )
+            target_language = gr.Dropdown(
+                choices=target_languages,
+                value="English",
+                label="Target Language",
             )
+            translate_button = gr.Button("Translate")
+
+        with gr.Column():
+            translation_output = gr.Textbox(
+                label="Translation",
+                lines=5,
+                interactive=False,
             )
+            reward_score = gr.Number(
+                label="Translation Quality Score (higher is better)",
+                precision=4,
+                interactive=False,
+            )
+
     with gr.Row():
+        score_indicator = gr.Label(label="Quality Rating")
+
+    # Map the score to a quality rating (thresholds assume scores roughly in [0, 1])
+    def update_quality_rating(score):
+        if score >= 0.8:
+            return "Excellent"
+        elif score >= 0.6:
+            return "Good"
+        elif score >= 0.4:
+            return "Average"
+        elif score >= 0.2:
+            return "Poor"
+        else:
+            return "Very Poor"
+
+    # Set up the translation flow
+    translate_button.click(
+        fn=translate_text,
+        inputs=[source_text, target_language],
+        outputs=[translation_output, reward_score],
+    )
+
+    # Update the quality rating whenever the reward score changes
+    reward_score.change(
+        fn=update_quality_rating,
+        inputs=[reward_score],
+        outputs=[score_indicator],
+    )
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["你好,世界!", "English"],
+            ["我喜欢学习新的语言。", "Spanish"],
+            ["北京烤鴨很好吃。", "French"],
+            ["人工智能正在改变世界。", "German"],
+            ["今天天气真好。", "Japanese"],
         ],
+        inputs=[source_text, target_language],
+        outputs=[translation_output, reward_score],
+        fn=translate_text,
     )
+
+    gr.Markdown("## How It Works")
+    gr.Markdown("""
+    1. Enter Chinese text in the input box
+    2. Select your desired target language
+    3. Click 'Translate' to get the translation
+    4. The system will display the translation and a quality score
+
+    The quality score is generated by a reward model trained to evaluate translation quality.
+    Higher scores indicate better translations.
+    """)
+
+# Launch the app
 if __name__ == "__main__":
     demo.launch()
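Once the Space is running, the flow described under "How It Works" can also be driven programmatically. A minimal sketch using `gradio_client`; the Space ID here is hypothetical (replace with the real `user/space` once deployed), and the endpoint name assumes Gradio's default of exposing the handler under the wrapped function's name:

```python
from gradio_client import Client

# Hypothetical Space ID; substitute the actual deployment.
client = Client("huckiyang/plan2align-demo")

# translate_text returns (translation, reward_score); api_name "/translate_text"
# assumes the default naming derived from the function, which may differ.
translation, score = client.predict(
    "今天天气真好。",  # source_text
    "English",         # target_language
    api_name="/translate_text",
)
print(translation, score)
```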
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 accelerate
-diffusers
-invisible_watermark
-torch
-transformers
+gradio
+safetensors
+torch>=2.0.0
+transformers>=4.30.0
+trl>=0.7.1
 xformers
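The new pins are version floors rather than exact versions. A small sketch for sanity-checking that an environment satisfies them (standard-library only, no assumptions beyond the package names above):

```python
from importlib.metadata import version, PackageNotFoundError

# Print installed versions to compare against the floors in requirements.txt
# (torch>=2.0.0, transformers>=4.30.0, trl>=0.7.1).
for pkg in ("accelerate", "gradio", "safetensors", "torch", "transformers", "trl", "xformers"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```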
translation_model.py ADDED
@@ -0,0 +1,158 @@
+import torch
+import safetensors.torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
+
+# Set device and dtype
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch_dtype = torch.bfloat16
+
+# Load the base LLaMA 3.1 8B model for translation
+model_id = "meta-llama/Meta-Llama-3.1-8B"  # Replace with your actual model ID
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+lm_model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch_dtype,
+    device_map="auto",
+)
+
+# Load the reward model
+RM = AutoModelForCausalLMWithValueHead.from_pretrained(
+    "ray24724919/plan2align_rm",
+    torch_dtype=torch_dtype,
+    device_map="auto",
+)
+RM.eval()
+RM.gradient_checkpointing_enable()  # only matters if gradients are computed; inert in eval mode
+
+# Thin wrapper over safetensors loading
+def load_file(file_path):
+    return safetensors.torch.load_file(file_path)
+
+# Load value head weights if you have the file; otherwise the head
+# initialized by from_pretrained is used as-is
+try:
+    value_head_weights = load_file("value_head.safetensors")  # Replace with actual path
+    # Strip the "v_head." prefix so keys match RM.v_head's state dict
+    new_state_dict = {
+        key.replace("v_head.", "") if key.startswith("v_head.") else key: value
+        for key, value in value_head_weights.items()
+    }
+    RM.v_head.load_state_dict(new_state_dict)
+except FileNotFoundError:
+    print("Value head weights file not found. Using default weights.")
+
+# Define translation function with more flexibility
+def translate(source_text, target_language="English", model=lm_model):
+    """
+    Translate text from Chinese to the specified target language.
+
+    Args:
+        source_text (str): The Chinese text to translate
+        target_language (str): The target language for translation
+        model: The model to use for translation
+
+    Returns:
+        str: The translated text
+    """
+    # Format the input as per the system prompt
+    messages = [
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
+    ]
+
+    # Format messages for the model
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    # Tokenize the input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Generate translation
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+
+    # Decode only the newly generated tokens
+    translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
+    return translation
+
+# Evaluate the translation using the reward model
+def evaluate_translation(source_text, translation, target_language="English"):
+    """
+    Evaluate the quality of a translation using the reward model.
+
+    Args:
+        source_text (str): The original Chinese text
+        translation (str): The translated text
+        target_language (str): The target language of the translation
+
+    Returns:
+        float: The reward score
+    """
+    messages = [
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
+        {"role": "assistant", "content": translation},
+    ]
+
+    # Format messages for the reward model
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+
+    # Tokenize the input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Get reward score; the value-head model's forward returns a
+    # (lm_logits, loss, values) tuple, so use the final token's value
+    # as the sequence-level score
+    with torch.no_grad():
+        _, _, values = RM(input_ids=inputs.input_ids)
+    reward_score = values[0, -1].item()
+
+    return reward_score
+
+# Function to translate and evaluate in one step
+def translate_and_evaluate(source_text, target_language="English"):
+    """
+    Translate text and evaluate the translation quality in one step.
+
+    Args:
+        source_text (str): The Chinese text to translate
+        target_language (str): The target language for translation
+
+    Returns:
+        tuple: (translation, reward_score)
+    """
+    translation = translate(source_text, target_language)
+    reward_score = evaluate_translation(source_text, translation, target_language)
+    return translation, reward_score
+
+# Example usage
+if __name__ == "__main__":
+    # Example with the default target language (English)
+    source = "你好世界"
+    translation, reward_score = translate_and_evaluate(source)
+    print(f"Source: {source}")
+    print(f"Translation to English: {translation}")
+    print(f"Reward Score: {reward_score}")
+
+    # Example with a custom target language
+    target_language = "French"
+    translation, reward_score = translate_and_evaluate(source, target_language)
+    print(f"\nSource: {source}")
+    print(f"Translation to {target_language}: {translation}")
+    print(f"Reward Score: {reward_score}")
+
+    # Interactive mode
+    print("\n=== Interactive Translation Mode ===")
+    print("Enter 'quit' to exit")
+    while True:
+        user_input = input("\nEnter Chinese text to translate: ")
+        if user_input.lower() == "quit":
+            break
+
+        target = input("Enter target language (default: English): ").strip()
+        if not target:
+            target = "English"
+
+        translation, reward_score = translate_and_evaluate(user_input, target)
+        print(f"Translation to {target}: {translation}")
+        print(f"Reward Score: {reward_score}")