Spaces:
Runtime error
Runtime error
asvonavnsnvononaon
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import transformers
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
class MRGenerator:
|
6 |
+
def __init__(self):
|
7 |
+
model_id = "arcee-ai/Llama-3.1-SuperNova-Lite"
|
8 |
+
self.pipeline = transformers.pipeline(
|
9 |
+
"text-generation",
|
10 |
+
model=model_id,
|
11 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
12 |
+
device_map="auto",)
|
13 |
+
self.system_prompt = {"role": "system",
|
14 |
+
"content": '''# CONTEXT #You are an expert in traffic rules and scene analysis.
|
15 |
+
#Key Concepts# 1. traffic rule: Define how the ego-vehicle should maneuver in the specific driving scenario. The ontology elements in driving scenario are classified into road_network and object_environment.
|
16 |
+
2. maneuver: A specific action or movement that an ego-vehicle performs.
|
17 |
+
3. road_network: Road elements are specified in the traffic rule, such as lanes, lines and crosswalks.
|
18 |
+
4. object_environment: One object or environment is specified in the traffic rule.'''}
|
19 |
+
|
20 |
+
self.terminators = [
|
21 |
+
self.pipeline.tokenizer.eos_token_id,
|
22 |
+
self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
23 |
+
]
|
24 |
+
self.max_new_tokens = 1024
|
25 |
+
self.do_sample = True
|
26 |
+
self.temperature = 0.1
|
27 |
+
self.top_p = 0.9
|
28 |
+
|
29 |
+
def find_maneuver(self, prompt):
|
30 |
+
user_message = {
|
31 |
+
"role": "user",
|
32 |
+
"content": f"""
|
33 |
+
# OBJECTIVE # Your task is to find the ego-vehicle's most dangerous maneuver in the traffic rules.
|
34 |
+
# STYLE # The supported maneuvers are limited to the following types:
|
35 |
+
slow down, stop, turn, turn left, turn right, keep the same.
|
36 |
+
# NOTICE # Some maneuvers could lead to the listed types, e.g., 'yield' can be interpreted as 'slow down'.
|
37 |
+
Generate answer for maneuver only. Do not generate other output.
|
38 |
+
# EXAMPLE #
|
39 |
+
Example Text: If you are driving on an unpaved road that intersects with a paved road, you must yield the right-of-way to vehicles traveling on the paved road.
|
40 |
+
Example Answer: slow down
|
41 |
+
Example Text: Steady Red Light (Stop) Stop before entering the crosswalk or intersection. You may turn right unless prohibited by law. You may also turn left if both streets are one way, unless prohibited by law. You must yield to all pedestrians and other traffic lawfully using the intersection.
|
42 |
+
Example Answer: stop
|
43 |
+
===== END OF EXAMPLE ======
|
44 |
+
Text: {prompt}
|
45 |
+
Answer: """
|
46 |
+
}
|
47 |
+
answer = self.LLM([self.system_prompt, user_message])
|
48 |
+
return answer
|
49 |
+
|
50 |
+
def find_road_network(self, prompt):
|
51 |
+
user_message = {
|
52 |
+
"role": "user",
|
53 |
+
"content": f"""
|
54 |
+
# OBJECTIVE # Your task is to find the most dangerous road_network where ego-vehicle will violate the traffic rule.
|
55 |
+
# NOTICE # Generate answer for road_network only. Do not generate other output. Before answering, verify the road_network describes road, not traffic participants.
|
56 |
+
road_network only refers to road types, such as: intersection, one-way street, two-way street, roundabout, highway, freeway, residential street, rural road, urban street, bridge, tunnel, parking lot, alley, T-junction, divided highway, bike lane, etc.
|
57 |
+
The answer must be ONE or TWO words only.
|
58 |
+
road_network should not conflict with object_environments. For instance, avoid conflicts between a red arrow light and the lane where a red cross-shaped light or arrow light is active. In case of any conflicts, set road_network to "any road". The answer should not exceed three words.
|
59 |
+
# EXAMPLE #
|
60 |
+
Example Text: If you are driving on an unpaved road that intersects with a paved road, you must yield the right-of-way to vehicles traveling on the paved road.
|
61 |
+
Example Answer: unpaved road
|
62 |
+
Example Text: Steady Red Light (Stop) Stop before entering the crosswalk or intersection. You may turn right unless prohibited by law. You may also turn left if both streets are one way, unless prohibited by law. You must yield to all pedestrians and other traffic lawfully using the intersection.
|
63 |
+
Example Answer: crosswalk
|
64 |
+
===== END OF EXAMPLE ======
|
65 |
+
Text: {prompt}
|
66 |
+
Answer: """
|
67 |
+
}
|
68 |
+
answer = self.LLM([self.system_prompt, user_message])
|
69 |
+
return answer
|
70 |
+
|
71 |
+
def find_object_environment(self, prompt):
|
72 |
+
user_message = {
|
73 |
+
"role": "user",
|
74 |
+
"content": f"""
|
75 |
+
# OBJECTIVE # Your task is to find all suitable object_environment, which will lead ego-vehicle to take the most dangerous maneuver.
|
76 |
+
# NOTICE # Generate answer for object_environment only. Do not generate other output.
|
77 |
+
Each item in the answer must be a single object_environment. Sentences containing "or" should be split into separate content.
|
78 |
+
Before answering, verify the logic: "In the road_network, object_environment cause the ego-vehicle to maneuver". If this logic doesn't hold true, please don't output this object_environment.
|
79 |
+
Every object_environment must be ONE to three words only.
|
80 |
+
# EXAMPLE #
|
81 |
+
Example Text: If you are driving on an unpaved road that intersects with a paved road, you must yield the right-of-way to vehicles traveling on the paved road.
|
82 |
+
Example Answer: "vehicle"
|
83 |
+
Example Text: Steady Red Light (Stop) Stop before entering the crosswalk or intersection. You may turn right unless prohibited by law. You may also turn left if both streets are one way, unless prohibited by law. You must yield to all pedestrians and other traffic lawfully using the intersection.
|
84 |
+
Example Answer: "red light"
|
85 |
+
Example Text: Lane signal lights indicate: (1) When the green arrow light is on, allow vehicles in the lane to pass in the direction indicated; (2) When the red cross-shaped light or arrow light is on, vehicles in the lane are prohibited from passing.
|
86 |
+
Example Answer: "red light"
|
87 |
+
Example Text: Flashing red light: Vehicles and streetcars/trams must stop at the stopping point before proceeding.
|
88 |
+
Example Answer: "flashing red light"
|
89 |
+
Example Text: Slow down on wet road. Do not suddenly turn, speed up, or stop.
|
90 |
+
Example Answer: "Wet"
|
91 |
+
===== END OF EXAMPLE ======
|
92 |
+
Text: {prompt}
|
93 |
+
Answer: """
|
94 |
+
}
|
95 |
+
answer = self.LLM([self.system_prompt, user_message])
|
96 |
+
answer = answer.strip('"')
|
97 |
+
answer = [item.strip().strip('"').strip("'") for item in answer.split(',')]
|
98 |
+
return answer[0] # Just return the first object_environment
|
99 |
+
|
100 |
+
def combine_to_MR(self, maneuver, road_network, object_environment):
|
101 |
+
user_message = {
|
102 |
+
"role": "user",
|
103 |
+
"content": f"""
|
104 |
+
# OBJECTIVE
|
105 |
+
Your task is to combine vehicle maneuver, road_network, object_environments into MR.
|
106 |
+
# EXAMPLE
|
107 |
+
Example Text: maneuver: "slow down", road_network: "unpaved road", object_environment: "vehicle"
|
108 |
+
Example Answer:Given the unpaved road
|
109 |
+
When ITMI add vehicle
|
110 |
+
Then ego-vehicle should slow down
|
111 |
+
Example Text: maneuver: "stop", road_network: "crosswalk", object_environment: "red light"
|
112 |
+
Example Answer:Given the crosswalk
|
113 |
+
When ITMI add red light
|
114 |
+
Then ego-vehicle should stop
|
115 |
+
User: maneuver: "{maneuver}", road_network: "{road_network}", object_environment: "{object_environment}"
|
116 |
+
"""}
|
117 |
+
answer = self.LLM([self.system_prompt, user_message])
|
118 |
+
return answer
|
119 |
+
|
120 |
+
def find_prompt(self, road_network, object_environment):
|
121 |
+
user_message = {
|
122 |
+
"role": "user",
|
123 |
+
"content": f'''road_network:{road_network},objects_environment:{object_environment}'''
|
124 |
+
}
|
125 |
+
answer = self.LLM([
|
126 |
+
{"role": "system", "content": '''Generate a sample diffusion inpainting prompt based on the given traffic rule and scenario.
|
127 |
+
Provide ONLY the prompt, with no additional explanation or content.
|
128 |
+
This prompt should describe the scene from the camera's perspective, focusing on the traffic rule.
|
129 |
+
Ensure the prompt is faithful to the original text and captures the key visual elements.
|
130 |
+
Describe the scene from the camera's perspective mounted on the ego vehicle that must change its state (e.g., yield, stop). Focus on the visual elements of the road network and environment without explicitly mentioning the ego vehicle as a subject.
|
131 |
+
IMPORTANT: Limit the prompt to a maximum of 50 words.
|
132 |
+
Dot not show any independent vehicle contorl wards like slow down, yield, prepare to stop in this prompt!.'''},
|
133 |
+
{"role": "user", "content": """road_network: intersection, objects_environment: a bicycle rider"""},
|
134 |
+
{"role": "assistant", "content": """You are driving approach to intersection, a bicycle rider on the road"""},
|
135 |
+
{"role": "user", "content": """road_network: intersection, objects_environment: turn left sign"""},
|
136 |
+
{"role": "assistant", "content": """You are driving approach to intersection, a turn left sign on the road"""},
|
137 |
+
{"role": "user", "content": """road_network: construction zone, objects_environment: traffic control devices"""},
|
138 |
+
{"role": "assistant", "content": """You are driving approach to construction zone, a traffic control devices on the road, """},
|
139 |
+
user_message
|
140 |
+
])
|
141 |
+
return answer
|
142 |
+
|
143 |
+
def LLM(self, messages):
|
144 |
+
outputs = self.pipeline(
|
145 |
+
messages,
|
146 |
+
max_new_tokens=self.max_new_tokens,
|
147 |
+
eos_token_id=self.terminators,
|
148 |
+
do_sample=self.do_sample,
|
149 |
+
temperature=self.temperature,
|
150 |
+
top_p=self.top_p,
|
151 |
+
pad_token_id=self.pipeline.tokenizer.eos_token_id
|
152 |
+
)
|
153 |
+
answer = outputs[0]["generated_text"][-1]['content']
|
154 |
+
return answer
|
155 |
+
|
156 |
+
def __call__(self, traffic_rule):
|
157 |
+
# Find components
|
158 |
+
maneuver = self.find_maneuver(traffic_rule)
|
159 |
+
road_network = self.find_road_network(traffic_rule)
|
160 |
+
object_environment = self.find_object_environment(traffic_rule)
|
161 |
+
|
162 |
+
# Generate MR and prompt
|
163 |
+
mr = self.combine_to_MR(maneuver, road_network, object_environment)
|
164 |
+
diffusion_prompt = self.find_prompt(road_network, object_environment)
|
165 |
+
|
166 |
+
result = {
|
167 |
+
"MR": mr,
|
168 |
+
"maneuver": maneuver,
|
169 |
+
"road_network": road_network,
|
170 |
+
"object_environment": object_environment,
|
171 |
+
"diffusion_prompt": diffusion_prompt
|
172 |
+
}
|
173 |
+
return result
|
174 |
+
|
175 |
+
def generate_mr(traffic_rule):
|
176 |
+
generator = MRGenerator()
|
177 |
+
results = generator(traffic_rule)
|
178 |
+
# Format the output as a string
|
179 |
+
output = f"""MR:
|
180 |
+
{results['MR']}
|
181 |
+
|
182 |
+
Components:
|
183 |
+
Maneuver: {results['maneuver']}
|
184 |
+
Road Network: {results['road_network']}
|
185 |
+
Object/Environment: {results['object_environment']}
|
186 |
+
|
187 |
+
Diffusion Prompt:
|
188 |
+
{results['diffusion_prompt']}"""
|
189 |
+
return output
|
190 |
+
|
191 |
+
# Create Gradio interface
|
192 |
+
demo = gr.Interface(
|
193 |
+
fn=generate_mr,
|
194 |
+
inputs=gr.Textbox(label="Enter Traffic Rule", lines=3),
|
195 |
+
outputs=gr.Textbox(label="Generated Output", lines=10),
|
196 |
+
title="Traffic Rule MR Generator",
|
197 |
+
description="Enter a traffic rule to generate its corresponding MR and components."
|
198 |
+
)
|
199 |
+
|
200 |
+
if __name__ == "__main__":
|
201 |
+
demo.launch()
|