tdurbor commited on
Commit
f50f18c
1 Parent(s): d19c70c

Synch with datasets when starting

Browse files
Files changed (2) hide show
  1. app.py +17 -11
  2. db.py +31 -3
app.py CHANGED
@@ -1,27 +1,33 @@
1
  import os
 
 
 
2
  import logging
 
 
 
3
  from typing import Tuple
4
- from dotenv import load_dotenv
5
- import gradio as gr
6
  import numpy as np
 
7
  from PIL import Image
8
- import random
9
- from db import compute_elo_scores, get_all_votes, add_vote, is_running_in_space
10
- import json
11
- from pathlib import Path
12
- from uuid import uuid4
13
- import logging
14
- import threading
15
- import time
16
  from datasets import load_dataset
17
  from huggingface_hub import CommitScheduler
18
 
19
-
 
 
 
 
 
 
20
 
21
  token = os.getenv("HUGGINGFACE_HUB_TOKEN")
22
 
23
  # Load datasets
24
  dataset = load_dataset("bgsys/background-removal-arena-green", split='train')
 
25
 
26
  # Configure logging
27
  logging.basicConfig(level=logging.INFO)
 
1
  import os
2
+ import json
3
+ import time
4
+ import random
5
  import logging
6
+ import threading
7
+ from pathlib import Path
8
+ from uuid import uuid4
9
  from typing import Tuple
10
+
 
11
  import numpy as np
12
+ import gradio as gr
13
  from PIL import Image
14
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
15
  from datasets import load_dataset
16
  from huggingface_hub import CommitScheduler
17
 
18
+ from db import (
19
+ compute_elo_scores,
20
+ get_all_votes,
21
+ add_vote,
22
+ is_running_in_space,
23
+ fill_database_once
24
+ )
25
 
26
  token = os.getenv("HUGGINGFACE_HUB_TOKEN")
27
 
28
  # Load datasets
29
  dataset = load_dataset("bgsys/background-removal-arena-green", split='train')
30
+ fill_database_once()
31
 
32
  # Configure logging
33
  logging.basicConfig(level=logging.INFO)
db.py CHANGED
@@ -5,16 +5,16 @@ from sqlalchemy.ext.declarative import declarative_base
5
  from sqlalchemy.orm import sessionmaker, Session
6
  from datetime import datetime
7
  import pandas as pd
8
- import uuid
9
  from rating_systems import compute_elo
10
 
11
  def is_running_in_space():
12
  return "SPACE_ID" in os.environ
13
 
14
  if is_running_in_space():
15
- DATABASE_URL = "sqlite:///./data/newvotes.db"
16
  else:
17
- DATABASE_URL = "sqlite:///./data/local.db"
18
 
19
  engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
20
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -43,6 +43,34 @@ def get_db():
43
  finally:
44
  db.close()
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def add_vote(vote_data):
47
  with SessionLocal() as db:
48
  db_vote = Vote(**vote_data)
 
5
  from sqlalchemy.orm import sessionmaker, Session
6
  from datetime import datetime
7
  import pandas as pd
8
+ from datasets import load_dataset
9
  from rating_systems import compute_elo
10
 
11
  def is_running_in_space():
12
  return "SPACE_ID" in os.environ
13
 
14
  if is_running_in_space():
15
+ DATABASE_URL = "sqlite:///./data/hf-votes.db"
16
  else:
17
+ DATABASE_URL = "sqlite:///./data/local2.db"
18
 
19
  engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
20
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
43
  finally:
44
  db.close()
45
 
46
+ def fill_database_once(dataset_name="bgsys/votes_datasets_test2"):
47
+ with SessionLocal() as db:
48
+ # Check if the database is already filled
49
+ if db.query(Vote).first() is None:
50
+ dataset = load_dataset(dataset_name)
51
+ for record in dataset['train']:
52
+ # Ensure the timestamp is a string
53
+ timestamp_str = record.get("timestamp", datetime.utcnow().isoformat())
54
+ if not isinstance(timestamp_str, str):
55
+ timestamp_str = datetime.utcnow().isoformat()
56
+
57
+ vote_data = {
58
+ "image_id": record.get("image_id", ""),
59
+ "model_a": record.get("model_a", ""),
60
+ "model_b": record.get("model_b", ""),
61
+ "winner": record.get("winner", ""),
62
+ "user_id": record.get("user_id", ""),
63
+ "fpath_a": record.get("fpath_a", ""),
64
+ "fpath_b": record.get("fpath_b", ""),
65
+ "timestamp": datetime.fromisoformat(timestamp_str)
66
+ }
67
+ db_vote = Vote(**vote_data)
68
+ db.add(db_vote)
69
+ db.commit()
70
+ logging.info("Database filled with data from Hugging Face dataset: %s", dataset_name)
71
+ else:
72
+ logging.info("Database already filled, skipping dataset loading.")
73
+
74
  def add_vote(vote_data):
75
  with SessionLocal() as db:
76
  db_vote = Vote(**vote_data)