File size: 3,720 Bytes
05140a3
7eaefa4
5f079a9
 
7eaefa4
 
69889de
7eaefa4
5f079a9
1baddb8
5f079a9
1baddb8
5f079a9
 
 
1baddb8
 
 
8f687fb
 
de2ebc5
5f079a9
cf17b13
7eaefa4
 
cf17b13
7eaefa4
 
 
 
 
 
 
 
 
 
 
 
cf17b13
7eaefa4
 
 
 
 
 
 
 
cf17b13
7eaefa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f687fb
 
 
 
7eaefa4
 
 
 
 
 
 
 
 
 
 
 
de2ebc5
 
7eaefa4
7b9061c
7eaefa4
 
 
 
 
5f079a9
3c17d63
41d1888
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import json
from pyvis.network import Network

from transformers import pipeline
# Hugging Face pipeline for the GPT-Tour model, built once at import time and
# reused by getTour for every generation request.
_TOUR_MODEL_ID = "manhan/GPT-Tour"
tourModel = pipeline(model=_TOUR_MODEL_ID)

def vizTour(activityList):
    """Draw an activity sequence as a directed pyvis graph for the Gradio UI.

    Every distinct activity becomes one node and every consecutive pair of
    activities becomes a directed edge, so a repeated activity (e.g. coming
    back "Home") points back at the node that already exists. The graph page
    is written to ``tour.html`` in the current working directory.

    Args:
        activityList: Ordered activity labels, e.g. ``["Home", "Work", "Home"]``.

    Returns:
        A ``gr.HTML`` component holding an iframe that points at ``tour.html``,
        or an empty ``gr.HTML`` when ``activityList`` is empty.
    """
    if not activityList:
        # Nothing to draw, and no first activity to seed the graph with.
        return gr.HTML("")
    net = Network(directed=True,notebook=False)
    # Remember added nodes in a set: O(1) membership per activity instead of
    # rescanning the node list returned by net.get_nodes() each time.
    added = set()
    for activity in activityList:
        if activity not in added:
            net.add_node(activity)
            added.add(activity)
    # Link each activity to the one that follows it, in tour order.
    for source, target in zip(activityList, activityList[1:]):
        net.add_edge(source, target)
    # NOTE(review): every session overwrites this same file, and the iframe
    # below requests it by a relative URL; confirm the Gradio server actually
    # serves working-directory files at that path.
    net.save_graph('tour.html')
    tour_html = gr.HTML("<iframe src=\"tour.html\"></iframe>")
    return tour_html

def _parse_income(income):
    """Parse the income textbox as an int, tolerating spaces, '$' and ','."""
    return int(str(income).strip().replace('$', '').replace(',', ''))


def _income_bracket(hh_income):
    """Map a household income to a bracket label, or None for bad (negative) data.

    Values up to 100 are read as income-category codes:
    <4 poor (under $25,000), <6 low (under $50,000), <7 medium (under $75,000),
    <9 high (under $125,000), otherwise affluent. Larger values are read as
    annual dollars -- what the UI label asks users to type -- and binned at
    the same dollar cut-offs. Before this, any dollar amount fell through to
    'affluent'.
    """
    if hh_income < 0:
        return None  # bad data: leave 'hh_inc' out of the prompt
    if hh_income > 100:
        cutoffs = ((25000, 'poor'), (50000, 'low'), (75000, 'medium'), (125000, 'high'))
    else:
        cutoffs = ((4, 'poor'), (6, 'low'), (7, 'medium'), (9, 'high'))
    for upper, label in cutoffs:
        if hh_income < upper:
            return label
    return 'affluent'


def _build_person(income, size, years, sex, edu, wrk):
    """Encode the raw form inputs as the categorical attribute dict for the prompt.

    Keys are inserted in a fixed order (hh_inc, hh_size, age_grp, sex, edu, wrk)
    because the dict is serialized verbatim into the prompt.

    Raises:
        ValueError: if income, size or years is not an integer.
    """
    person = {}
    bracket = _income_bracket(_parse_income(income))
    if bracket is not None:
        person['hh_inc'] = bracket
    hh_size = int(size)
    if hh_size == 1:
        person['hh_size'] = 'single'
    elif hh_size == 2:
        person['hh_size'] = 'couple'
    elif hh_size <= 4:
        person['hh_size'] = 'small'
    else:  # more than four people
        person['hh_size'] = 'large'
    age = int(years)
    if age < 18:
        person['age_grp'] = 'child'
    elif age < 45:
        person['age_grp'] = 'younger'
    elif age < 65:
        person['age_grp'] = 'older'
    else:
        person['age_grp'] = 'senior'
    person['sex'] = sex
    person['edu'] = edu
    person['wrk'] = wrk
    return person


def _extract_bracketed(generated):
    """Return the first '[...]' span in *generated*, or '' if there is none."""
    start_pos = generated.find('[')
    if start_pos == -1:
        return ''
    # Look for the closing bracket only after the opening one.
    end_pos = generated.find(']', start_pos)
    if end_pos == -1:
        return ''
    return generated[start_pos:end_pos + 1]


def _is_valid_tour(activities):
    """A usable tour is a non-empty list of activity strings ending at 'Home'."""
    return (
        isinstance(activities, list)
        and bool(activities)
        and all(isinstance(activity, str) for activity in activities)
        and activities[-1] == 'Home'
    )


def getTour(income,size,years,sex,edu,wrk):
    """Generate one activity tour for the described person and draw it.

    Args:
        income: Household income, as a dollar amount or a small category code.
        size: Number of people in the household.
        years: Traveler age in years.
        sex, edu, wrk: Dropdown selections, passed into the prompt unchanged.

    Returns:
        ``(activity_list, tour_html)``: the parsed list of activity names
        (ending in 'Home') and the HTML component produced by ``vizTour``.

    Raises:
        ValueError: if income, size or years is not an integer.
        RuntimeError: if the model yields no valid tour within the retry budget.
    """
    person = _build_person(income, size, years, sex, edu, wrk)
    # Prompt = the person's JSON without its closing brace, continued with
    # ", pattern: " (presumably the format the model was fine-tuned on).
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)
    # Bounded retries: previously an unusable model could hang the request forever.
    max_attempts = 20
    for _ in range(max_attempts):
        # do_sample=True: temperature only applies when sampling. Under greedy
        # decoding every retry would reproduce the same rejected output.
        generated = tourModel(prompt, return_full_text=False, max_length=250,
                              temperature=0.9, do_sample=True)[0]['generated_text']
        activity_list_str = _extract_bracketed(generated)
        print(f"Extracted: '{activity_list_str}'")
        if not activity_list_str:
            print("Nothing extracted!")
            continue
        try:
            activity_list = json.loads(activity_list_str)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            print("Error parsing activity list")
            print(e)
            continue
        if _is_valid_tour(activity_list):
            tour_html = vizTour(activity_list)
            return activity_list, tour_html
    raise RuntimeError(
        f"Model did not produce a valid tour after {max_attempts} attempts")

# Gradio UI: six person/household inputs feed getTour, whose two return values
# fill the JSON output and the HTML (graph) output.
iface = gr.Interface(fn=getTour, inputs=[
    gr.inputs.Textbox(label="Annual Household Income (in dollars)"),
    gr.inputs.Textbox(label="Household Size (number of people)"),
    gr.inputs.Textbox(label="Traveler Age (years)"),
    gr.inputs.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
    gr.inputs.Dropdown(["unknown", "grade school","highschool", "associates", "bachelors", "graduate"], label="Educational attainment level"),
    gr.inputs.Dropdown(["unknown", "yes","no"], label="Worker status")],
    # The HTML output starts empty; getTour replaces it on every run. Passing
    # open("tour.html") here handed gr.HTML a file object rather than a
    # string, leaked the handle, and failed at startup when the file was absent.
    outputs=["json", gr.HTML()],
    title="GPT-Tour", description="Generate a tour for a person",
    allow_flagging=False, allow_screenshot=False, allow_embedding=False)
iface.launch()