Spaces:
Sleeping
Sleeping
File size: 3,720 Bytes
05140a3 7eaefa4 5f079a9 7eaefa4 69889de 7eaefa4 5f079a9 1baddb8 5f079a9 1baddb8 5f079a9 1baddb8 8f687fb de2ebc5 5f079a9 cf17b13 7eaefa4 cf17b13 7eaefa4 cf17b13 7eaefa4 cf17b13 7eaefa4 8f687fb 7eaefa4 de2ebc5 7eaefa4 7b9061c 7eaefa4 5f079a9 3c17d63 41d1888 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import gradio as gr
import json
from pyvis.network import Network
from transformers import pipeline
# Build the Hugging Face pipeline for the "manhan/GPT-Tour" model once, at
# import time; getTour() calls it with a text prompt on every request.
tourModel = pipeline(model="manhan/GPT-Tour")
def vizTour(activityList):
    """Draw a tour as a directed graph and return an HTML component framing it.

    Each distinct activity becomes one node, and every consecutive pair of
    activities becomes a directed edge from the earlier to the later one.
    The graph is saved to 'tour.html' in the current working directory.

    Args:
        activityList: Ordered, non-empty sequence of activity labels.

    Returns:
        A gr.HTML component whose markup is an iframe pointing at 'tour.html'.

    Raises:
        IndexError: If activityList is empty.
    """
    graph = Network(directed=True, notebook=False)
    graph.add_node(activityList[0])

    # Walk the tour as (origin, destination) pairs of neighbouring stops.
    for origin, destination in zip(activityList, activityList[1:]):
        # A revisited activity (e.g. coming back home) reuses its node.
        if destination not in graph.get_nodes():
            graph.add_node(destination)
        graph.add_edge(origin, destination)

    graph.save_graph('tour.html')
    return gr.HTML("<iframe src=\"tour.html\"></iframe>")
def _parse_whole_number(raw, field):
    """Convert a form value such as "$52,000" or " 3 " into an int."""
    if not isinstance(raw, str):
        return int(raw)
    cleaned = raw.replace("$", "").replace(",", "").strip()
    try:
        return int(cleaned)
    except ValueError as err:
        raise ValueError(f"{field} must be a whole number, got {raw!r}") from err


def _describe_person(income, size, years, sex, edu, wrk):
    """Bucket the raw form values into the categorical attributes of the prompt."""
    # Insertion order matters: the dict is serialized verbatim into the prompt.
    person = {}

    hh_income = _parse_whole_number(income, "Household income")
    # The form asks for income in dollars, so the brackets are compared
    # against dollar amounts. A negative income is treated as bad data and
    # the 'hh_inc' attribute is simply left out.
    if hh_income >= 0:
        if hh_income < 25000:
            person['hh_inc'] = 'poor'
        elif hh_income < 50000:
            person['hh_inc'] = 'low'
        elif hh_income < 75000:
            person['hh_inc'] = 'medium'
        elif hh_income < 125000:
            person['hh_inc'] = 'high'
        else:
            person['hh_inc'] = 'affluent'

    hh_size = _parse_whole_number(size, "Household size")
    if hh_size == 1:
        person['hh_size'] = 'single'
    elif hh_size == 2:
        person['hh_size'] = 'couple'
    elif hh_size <= 4:
        person['hh_size'] = 'small'
    else:  # more than four people
        person['hh_size'] = 'large'

    age = _parse_whole_number(years, "Traveler age")
    if age < 18:
        person['age_grp'] = 'child'
    elif age < 45:
        person['age_grp'] = 'younger'
    elif age < 65:
        person['age_grp'] = 'older'
    else:
        person['age_grp'] = 'senior'

    person['sex'] = sex
    person['edu'] = edu
    person['wrk'] = wrk
    return person


def _extract_activity_list(generated):
    """Return the first [...] span of the model output as a list, or [] if unusable."""
    start_pos = generated.find('[')
    # Only look for the closing bracket after the opening one, so a stray ']'
    # earlier in the text cannot yield an empty or misaligned slice.
    end_pos = generated.find(']', start_pos) if start_pos != -1 else -1
    activity_list_str = generated[start_pos:end_pos + 1] if end_pos != -1 else ''
    print(f"Extracted: '{activity_list_str}'")

    if not activity_list_str:
        print("Nothing extracted!")
        return []

    try:
        parsed = json.loads(activity_list_str)
    except json.JSONDecodeError as e:
        print("Error parsing activity list")
        print(e)
        return []

    # Accept only a non-empty tour that finishes back at 'Home'.
    if not parsed or parsed[-1] != 'Home':
        return []
    return parsed


def getTour(income, size, years, sex, edu, wrk):
    """Generate a tour for one person and render it as a graph.

    The inputs are bucketed into categorical attributes, serialized into a
    prompt, and sent to the tourModel pipeline. Generation is repeated until
    the output contains a JSON list whose last entry is 'Home'.

    NOTE: the retry loop has no upper bound; it keeps sampling until the model
    produces an acceptable list.

    Args:
        income: Annual household income in dollars (text or number).
        size: Number of people in the household (text or number).
        years: Traveler age in years (text or number).
        sex: Gender/sex value, copied into the prompt unchanged.
        edu: Educational attainment value, copied into the prompt unchanged.
        wrk: Worker status value, copied into the prompt unchanged.

    Returns:
        A tuple (activity_list, tour_html): the parsed list of activities and
        whatever vizTour returns for that list.

    Raises:
        ValueError: If income, size or years cannot be read as a whole number.
    """
    person = _describe_person(income, size, years, sex, edu, wrk)
    # Drop the closing brace of the JSON object and append the cue that the
    # model is expected to continue from.
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)

    activity_list = []
    while not activity_list:
        generated = tourModel(prompt, return_full_text=False, max_length=250, temperature=0.9)[0]['generated_text']
        activity_list = _extract_activity_list(generated)

    tour_html = vizTour(activity_list)
    return activity_list, tour_html
# Seed the HTML output with the previously saved tour graph, if any. Reading
# through a context manager closes the file again, and a missing tour.html
# (no tour generated yet) yields an empty panel instead of failing at startup.
try:
    with open("tour.html", 'r') as saved_tour:
        initial_tour_html = saved_tour.read()
except FileNotFoundError:
    initial_tour_html = ""

# Build the form around getTour and start serving it. The six inputs map
# positionally onto getTour's parameters; the two outputs receive the
# activity list (as JSON) and the rendered tour.
with gr.Interface(
    fn=getTour,
    inputs=[
        gr.inputs.Textbox(label="Annual Household Income (in dollars)"),
        gr.inputs.Textbox(label="Household Size (number of people)"),
        gr.inputs.Textbox(label="Traveler Age (years)"),
        gr.inputs.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
        gr.inputs.Dropdown(["unknown", "grade school","highschool", "associates", "bachelors", "graduate"], label="Educational attainment level"),
        gr.inputs.Dropdown(["unknown", "yes","no"], label="Worker status"),
    ],
    outputs=["json", gr.HTML(value=initial_tour_html)],
    title="GPT-Tour",
    description="Generate a tour for a person",
    allow_flagging=False,
    allow_screenshot=False,
    allow_embedding=False,
) as iface:
    iface.launch()