richardkovacs commited on
Commit
4a2ada5
·
1 Parent(s): d368cb8

feat: add basic files

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. auto_train.py +29 -0
  3. requirements.txt +74 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv
2
+ flagged
auto_train.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ raw_model_name = 'distilbert-base-uncased'
5
+ raw_model = pipeline('sentiment-analysis', model=raw_model_name)
6
+ fine_tuned_model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
7
+ fine_tuned_model = pipeline('sentiment-analysis', model=fine_tuned_model_name)
8
+
9
+ def get_model_output(input_text, model_choice):
10
+ raw_result = raw_model(input_text)
11
+ fine_tuned_result = fine_tuned_model(input_text)
12
+ return format_model_output(raw_result[0]), format_model_output(fine_tuned_result[0])
13
+
14
+ def format_model_output(output):
15
+ return f"I am {output['score']*100:.2f}% sure that the sentiment is {output['label']}"
16
+
17
+ iface = gr.Interface(
18
+ fn=get_model_output,
19
+ title="DistilBERT Sentiment Analysis",
20
+ inputs=[
21
+ gr.Textbox(label="Input Text"),
22
+ ],
23
+ outputs=[
24
+ gr.Textbox(label="Base DistilBERT output (distilbert-base-uncased)"),
25
+ gr.Textbox(label="Fine-tuned DistilBERT output")
26
+ ],
27
+ )
28
+
29
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.1.2
3
+ annotated-types==0.6.0
4
+ anyio==3.7.1
5
+ attrs==23.1.0
6
+ certifi==2023.11.17
7
+ charset-normalizer==3.3.2
8
+ click==8.1.3
9
+ colorama==0.4.6
10
+ contourpy==1.2.0
11
+ cycler==0.12.1
12
+ exceptiongroup==1.2.0
13
+ fastapi==0.104.1
14
+ ffmpy==0.3.1
15
+ filelock==3.13.1
16
+ fonttools==4.45.0
17
+ fsspec==2023.10.0
18
+ gradio==4.5.0
19
+ gradio_client==0.7.0
20
+ h11==0.14.0
21
+ httpcore==1.0.2
22
+ httpx==0.25.1
23
+ huggingface-hub==0.19.4
24
+ idna==3.4
25
+ importlib-resources==6.1.1
26
+ Jinja2==3.1.2
27
+ jsonschema==4.20.0
28
+ jsonschema-specifications==2023.11.1
29
+ kiwisolver==1.4.5
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==2.1.3
32
+ matplotlib==3.8.2
33
+ mdurl==0.1.2
34
+ mpmath==1.3.0
35
+ networkx==3.2.1
36
+ numpy==1.26.2
37
+ orjson==3.9.10
38
+ packaging==23.2
39
+ pandas==2.1.3
40
+ Pillow==10.1.0
41
+ pydantic==2.5.1
42
+ pydantic_core==2.14.3
43
+ pydub==0.25.1
44
+ Pygments==2.17.1
45
+ pyparsing==3.1.1
46
+ pystache==0.6.0
47
+ python-dateutil==2.8.2
48
+ python-multipart==0.0.6
49
+ pytz==2023.3.post1
50
+ PyYAML==6.0.1
51
+ referencing==0.31.0
52
+ regex==2023.10.3
53
+ requests==2.31.0
54
+ rich==13.7.0
55
+ rpds-py==0.13.1
56
+ safetensors==0.4.0
57
+ semantic-version==2.10.0
58
+ shellingham==1.5.4
59
+ six==1.16.0
60
+ sniffio==1.3.0
61
+ starlette==0.27.0
62
+ sympy==1.12
63
+ tokenizers==0.15.0
64
+ tomlkit==0.12.0
65
+ toolz==0.12.0
66
+ torch==2.1.1
67
+ tqdm==4.66.1
68
+ transformers==4.35.2
69
+ typer==0.9.0
70
+ typing_extensions==4.8.0
71
+ tzdata==2023.3
72
+ urllib3==2.1.0
73
+ uvicorn==0.24.0.post1
74
+ websockets==11.0.3