diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 302 |
1 files changed, 302 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..17e07f7 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,302 @@ +use std::convert::Infallible; +use std::net::{IpAddr, SocketAddr, Ipv4Addr}; +use std::str::FromStr; +use std::sync::mpsc::{Sender, channel}; +use std::{thread, time::Duration}; +use std::sync::Arc; +use std::fs; + +use serde_derive::{Serialize, Deserialize}; +use toml; + +use hyper::server::conn::AddrStream; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Request, Response, Server, StatusCode, Method, header::HeaderValue}; + +use rusqlite::{Connection, Transaction}; + +use blake3; + +enum VustQuery { + Get, + Like, + Commit, +} + +struct VustMessage { + req_type: VustQuery, + ip: SocketAddr, + path: String, + res: Sender<Response<Body>>, +} + +fn respond(sc :StatusCode, mes : String) -> Response<Body> { + Response::builder() + .status(sc) + .body(Body::from(mes)) + .unwrap() +} + +fn wrapCORS(mut res: Response<Body>, config :Arc<Config>) -> Response<Body> { + res.headers_mut().insert("Access-Control-Allow-Origin", HeaderValue::from_str(&config.cors_hosts).unwrap() ); + res.headers_mut().insert("Access-Control-Allow-Methods", HeaderValue::from_static("*")); + res.headers_mut().insert("Access-Control-Allow-Headers", HeaderValue::from_static("*")); + res +} + +fn handle(req: Request<Body>, addr: SocketAddr, tx: Sender<VustMessage>, config: Arc<Config>) -> Response<Body> { + const PREFIX_PATH: &str = "/like/"; + + let path: String = req.uri().path().to_string(); + + let query_type = match *req.method() { + Method::POST => VustQuery::Like, + Method::PUT => VustQuery::Commit, + Method::GET => VustQuery::Get, + Method::OPTIONS => { + return respond(StatusCode::OK, "".to_string()); + }, + _ => { + return respond(StatusCode::METHOD_NOT_ALLOWED, "Tu t'attends à quoi au juste ?".to_string()); + }, + }; + + if !path.starts_with(PREFIX_PATH) { + return respond(StatusCode::BAD_REQUEST, format!("Le point d'entrée est {}", PREFIX_PATH).to_string()); + } + + let path : String = path.chars().skip(PREFIX_PATH.len()).collect(); + + if path.contains("/") || path.len() > 128 || path.len() == 0 || !config.list_articles.contains(&path) { + return respond(StatusCode::BAD_REQUEST, "Tu t'attends à quoi au juste ?".to_string()); + } + + let (ttx , rx) = channel(); + let message = VustMessage{ + req_type: query_type, + ip: addr, + path, + res: ttx, + }; + + tx.send(message).unwrap(); + + match *req.method() { + Method::PUT => respond(StatusCode::OK, "Fait !".to_string()), + _ => rx.recv().unwrap(), + } +} + +fn do_get(tr: &Transaction, ip : SocketAddr, path : String) -> Response<Body> { + + for _ in [1..3] { + let mut req_prepared = tr.prepare("SELECT cast(SUM(number) as text) FROM likes WHERE path = ?").unwrap(); + let mut rows = req_prepared.query(rusqlite::params![path.as_str()]).unwrap(); + + let first_row = rows.next(); + + // Error while fetching + // By example busy database + if first_row.is_err() { + continue; + } + let first_row = first_row.unwrap(); + + // Empty row ! Nobody like what I do. 😭 + if first_row.is_none() { + return respond(StatusCode::OK, "0".to_string()) + } + + let first_row = first_row.unwrap(); + + match first_row.get(0) { + Ok(nb_likes) => return respond(StatusCode::OK, nb_likes), + // In case of NULL or not a Integer value: + Err(_) => break, + } + } + + respond(StatusCode::OK, "❓".to_string()) +} + +fn do_like(tr: &Transaction, ip : SocketAddr, path : String) -> Response<Body> { + let hash_ip = match ip.ip() { + IpAddr::V4(ip) => blake3::hash(&ip.octets()).to_hex(), + IpAddr::V6(ip) => blake3::hash(&ip.octets()).to_hex(), + }; + + for _ in [1..7] { + let mut req_prepared = tr.prepare("SELECT number, cast(lastMod as UNSIGNED INT), cast(unixepoch() as UNSIGNED INT) FROM likes WHERE ip_hash = ? and path = ?").unwrap(); // , path.as_str()]).unwrap(); + let mut rows = req_prepared.query(rusqlite::params![hash_ip.as_str(), path.as_str()]).unwrap(); + + let first_row = rows.next(); + + match first_row { + Ok(None) => { + let res = tr.execute("INSERT OR IGNORE INTO likes VALUES (?, ?, unixepoch(), 1)", [hash_ip.as_str(), path.as_str()]); + if res.is_err() { + println!("Error doing the request {:?}", res.err()); + + continue; + } + }, + Ok(Some(t)) => { + let number : u64 = t.get(0).unwrap(); + let time : u64 = t.get(1).unwrap(); + let now : u64= t.get(2).unwrap(); + + if number > 31 { + return respond(StatusCode::RANGE_NOT_SATISFIABLE, format!("Trop de cœurs ! 💕 x ({})", number).to_string()); + } + + let limite = (1 << number) / 10; // 2^likes / Cst + let dtime = now - time; + + if dtime < limite { + let time_remaining = limite - dtime; + return respond(StatusCode::TOO_MANY_REQUESTS, format!("Attendez {}s avant de pouvoir envoyer un autre cœur.", time_remaining)); + } + + let res = tr.execute("UPDATE likes SET number = number + 1, lastMod = unixepoch() WHERE ip_hash = ? and path = ?", [hash_ip.as_str(), path.as_str()]); + + if res.is_err() { + println!("Error doing the request {:?}", res.err()); + + continue + } + return respond(StatusCode::OK, format!("Merci ! 💕 x {}", number + 1).to_string()) + }, + Err(_) => { + continue + }, + }; + } + + respond(StatusCode::INTERNAL_SERVER_ERROR, "💕 Erreur, il y a un soucis. (>﹏<)".to_string()) +} + +#[derive(Serialize, Deserialize)] +struct Config { + ip: String, + port: u16, + cors_hosts: String, + list_articles: Vec<String>, +} + +fn get_config() -> String { + let path = if std::path::Path::new("vust.conf").exists() { + "vust.conf" + } else if std::path::Path::new("/etc/vust.conf").exists() { + "/etc/vust.conf" + } else { + return r#" + ip = "127.0.0.1" + port = 3000 + + # A comma sseparated list of hosts + cors_hosts = '*' + list_articles = [ + 'bizarreries-du-langage-c', + 'retour-sur-laoc-2021-semaine-1', + '2FA-discord-sur-pc', + 'duckduckgo-google-en-mieux', + ] + "#.to_string(); + }; + + fs::read_to_string(path).expect("Unable to read config file") +} + +#[tokio::main] +async fn main() { + let config: Arc<Config> = Arc::new(toml::from_str(get_config().as_str()).unwrap()); + let ip = IpAddr::from_str(config.ip.as_str()).expect("Invalid IP address"); + let addr = SocketAddr::new(ip, config.port); + eprintln!("Listening on {}", addr); + + let (tx , rx) = channel(); + + + let ttx = tx.clone(); + thread::spawn(move || { + let tx = ttx.clone(); + loop { + let (txx , _) = channel(); + let res = tx.send(VustMessage{ + req_type: VustQuery::Commit, + ip: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), + path: "".to_string(), + res: txx, + }); + + if res.is_ok() || res.is_err() { + thread::sleep(Duration::from_secs(10)); + } + } + }); + + // This thread handle the sqlite connection + thread::spawn(move || { + let mut conn = Connection::open("likes.db").unwrap(); + loop { + + let tr = conn.transaction().unwrap(); + let mut should_commit = false; + + loop { + let recv : VustMessage = rx.recv().unwrap(); + match recv.req_type { + VustQuery::Like => { + let res = do_like(&tr, recv.ip, recv.path); + let res = recv.res.send(res); + if res.is_ok() { + should_commit = true; + } + + continue; + }, + VustQuery::Get => { + let res = recv.res.send(do_get(&tr, recv.ip, recv.path)); + if res.is_ok() || res.is_err() { + continue; + } + }, + VustQuery::Commit => { + if should_commit { + break; + } + } + } + } + tr.commit().unwrap(); + } + }); + + // The closure passed to `make_service_fn` is executed each time a new + // connection is established and returns a future that resolves to a + // service. + let make_service = make_service_fn(|conn: &AddrStream| { + // The closure passed to `service_fn` is executed each time a request + // arrives on the connection and returns a future that resolves + // to a response. + let remote_addr = conn.remote_addr(); + let tx = tx.clone(); + let config = config.clone(); + + async move { + Ok::<_, Infallible>(service_fn( move |req| { + let tx = tx.clone(); + let config = config.clone(); + async move { + Ok::<_, Infallible>(wrapCORS(handle(req, remote_addr, tx, config.clone()), config.clone())) + } + })) + } + }); + + // Start the server. + if let Err(e) = Server::bind(&addr).serve(make_service).await { + eprintln!("Error: {:#}", e); + std::process::exit(1); + } +} |