diff --git a/py/db.py b/py/db.py index 5a15957..a773552 100644 --- a/py/db.py +++ b/py/db.py @@ -9,11 +9,11 @@ class DB: self.database = {} def ensure_username(self, data): - if 'username' in data: - return data.get('username') - elif 'email' in data: - for index, entry in self.database: - if entry.get('email') == data.get('email'): + if "username" in data: + return data.get("username") + elif "email" in data: + for index, entry in self.database.items(): + if entry.get("email") == data.get("email"): return index @staticmethod @@ -23,11 +23,16 @@ class DB: return hashed_password def add_user(self, data): - username = data.get('username') - password = data.get('password') - email = data.get('email') + username = data.get("username") + password = data.get("password") + email = data.get("email") hashed_password = self.hash_password(password) - user_data = {"hashed_password": hashed_password, "email": email, "settings": None, "history": None} + user_data = { + "hashed_password": hashed_password, + "email": email, + "settings": None, + "history": None, + } if username not in self.database: self.database[username] = user_data self.save_database() @@ -45,7 +50,7 @@ class DB: def update_password(self, data): username = self.ensure_username(data) - new_password = data.get('new_password') + new_password = data.get("new_password") if not self.check_credentials(data): return False @@ -56,7 +61,7 @@ class DB: def check_credentials(self, data): username = self.ensure_username(data) - password = data.get('password') + password = data.get("password") if username not in self.database: return False @@ -70,7 +75,7 @@ class DB: if not self.check_credentials(data): return False - self.database[username]['settings'] = data.get('data') + self.database[username]["settings"] = data.get("data") self.save_database() return True @@ -79,7 +84,7 @@ class DB: if not self.check_credentials(data): return None - send_back = self.database[username].get('settings') + send_back = self.database[username].get("settings") return send_back def change_history(self, data): @@ -87,7 +92,7 @@ class DB: if not self.check_credentials(data): return False - self.database[username]['history'] = data.get('data') + self.database[username]["history"] = data.get("data") self.save_database() return True @@ -96,7 +101,7 @@ class DB: if not self.check_credentials(data): return None - send_back = self.database[username].get('history') + send_back = self.database[username].get("history") return send_back def get_email(self, data): @@ -104,11 +109,10 @@ class DB: if not self.check_credentials(data): return None - send_back = self.database[username].get('email') + send_back = self.database[username].get("email") return send_back def get_name(self, data): - username = self.ensure_username(data) if not self.check_credentials(data): return None @@ -116,18 +120,18 @@ class DB: return send_back def save_database(self): - if os.environ.get('PRODUCTION') == "YES": + if os.environ.get("PRODUCTION") == "YES": server = pycouchdb.Server("http://admin:admin@localhost:5984/") db = server.database("interstellar_ai") db.save(self.database) else: - with open("database.json", 'w') as file: + with open("database.json", "w") as file: print("saving") json.dump(self.database, file) def load_database(self): - if os.environ.get('PRODUCTION') == "YES": + if os.environ.get("PRODUCTION") == "YES": server = pycouchdb.Server("http://admin:admin@localhost:5984/") db = server.database("interstellar_ai") if db: @@ -138,7 +142,7 @@ class DB: db.save(self.database) else: try: - with open("database.json", 'r') as file: + with open("database.json", "r") as file: self.database = json.load(file) except FileNotFoundError: pass